# Implementing the Transformer

Reference: [Implementation_Tutorial](Transformer_Implementation_Tutorial.ipynb)

In [73]:
from torch import nn 
import torch.nn.functional as F
import torch
from math import log, sqrt

In [52]:
# Device on which every module below allocates its parameters:
# Apple's Metal backend (MPS) when present, otherwise the CPU.
dev = 'cpu'
if torch.backends.mps.is_available():
    dev = 'mps'

## Embedding and Positional Encoding Module

In [None]:
class EmbeddingWithPositionalEncoding(nn.Module):
    """Token embedding + projection to ``d_model`` + sinusoidal positional encoding.

    ``forward`` embeds token ids into ``d_embed`` dimensions, projects them
    linearly to ``d_model``, scales by ``sqrt(d_model)``, adds the sinusoidal
    positional encoding, then applies LayerNorm and dropout.

    Args:
        vocab_size: number of distinct token ids.
        d_embed: width of the raw token embedding.
        d_model: width of the output representation.
        dropout_p: dropout probability applied to the output.
    """

    def __init__(self, vocab_size: int,
                 d_embed: int,
                 d_model: int,
                 dropout_p: float = 0.1
                 ):
        super().__init__()
        self.d_model = d_model
        self.d_embed = d_embed
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_embed,
            device=dev
        )
        self.projection = nn.Linear(
            in_features=d_embed,
            out_features=d_model,
            device=dev
        )
        # Embeddings are multiplied by sqrt(d_model), as in the paper.
        self.scaling = float(sqrt(self.d_model))

        # Note: the paper applies only dropout to (embedding + PE); this
        # LayerNorm is an addition made by this implementation.
        self.layerNorm = nn.LayerNorm(
            self.d_model,
            device=dev
        )

        self.dropout = nn.Dropout(p=dropout_p)

    @staticmethod  # does not use `self`; callable as EmbeddingWithPositionalEncoding.create_positional_encoding(...)
    def create_positional_encoding(seq_length: int,
                                   d_model: int,
                                   batch_size: int,
                                   device=None,
                                   dtype=torch.float32
                                   ):
        """Build the sinusoidal positional-encoding table.

        Even columns hold ``sin(pos / 10000^(2i/d_model))`` and odd columns hold
        the matching ``cos``.

        Args:
            seq_length: number of positions.
            d_model: encoding width (odd values are supported).
            batch_size: size of the leading batch dimension of the result.
            device: where to allocate the table; defaults to the module-level ``dev``.
            dtype: floating-point dtype of the result (default float32).

        Returns:
            Tensor of shape (batch_size, seq_length, d_model). It is an
            ``expand``-ed view, so the batch entries share memory.
        """
        if device is None:
            device = dev

        # Column vector of positions: shape (seq_length, 1).
        positions = torch.arange(seq_length, dtype=torch.float32, device=device).unsqueeze(1)

        # 1 / 10000^(2i/d_model), computed as exp(-(2i/d_model) * ln(10000)).
        div_term = torch.exp(
            (torch.arange(0, d_model, 2, dtype=torch.float32, device=device) / d_model)
            * (-log(10000.0))
        )
        angles = positions * div_term  # (seq_length, ceil(d_model / 2))

        # The table must be floating point: with an integer dtype the sin/cos
        # values written below would be truncated toward zero.
        pe = torch.zeros(seq_length, d_model, dtype=torch.float32, device=device)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Share the same table across the batch without copying.
        return pe.to(dtype).unsqueeze(0).expand(batch_size, -1, -1)

    def forward(self, x):
        """Map token ids of shape (batch, seq_length) to (batch, seq_length, d_model)."""
        batch_size, seq_length = x.shape

        # Steps 1-2: embed, project d_embed -> d_model, and scale.
        token_embedding = self.projection(self.embedding(x)) * self.scaling

        # Step 3: positional encoding, built where and in the dtype the
        # embedding actually lives, so moving the module with .to() keeps working.
        pos_encoding = self.create_positional_encoding(
            seq_length=seq_length,
            d_model=self.d_model,
            batch_size=batch_size,
            device=token_embedding.device,
            dtype=token_embedding.dtype,
        )

        # Step 4: normalize the sum, then apply dropout.
        return self.dropout(self.layerNorm(pos_encoding + token_embedding))



## Attention Module

- Two types of attention I learned about:
  - **Self-Attention:** the queries, keys, and values all come from the same input tensor.
  - **Cross-Attention:** the queries come from one input, while the keys and values come from a different one (e.g. the output of another multi-head attention module).

In [80]:
class TransformerAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Performs self-attention when ``key_value_states`` is None. Otherwise it
    performs cross-attention: queries come from ``seq``, and keys/values come
    from ``key_value_states``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout_p: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 dropout_p: float = 0.1
                 ):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError('`d_model` not divisible by `num_heads`')
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_heads = self.d_model // self.num_heads
        # 1/sqrt(d_k) from the paper, applied to Q before the dot product.
        self.scale_factor = float(1.0 / sqrt(self.d_heads))
        self.dropout = nn.Dropout(p=dropout_p)

        # Linear transformations (all d_model -> d_model).
        self.q_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model, device=dev)
        self.k_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model, device=dev)
        self.v_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model, device=dev)
        self.output_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model, device=dev)

    def _split_heads(self, states: torch.Tensor, length: int) -> torch.Tensor:
        """Reshape (batch, length, d_model) into (batch, num_heads, length, d_heads)."""
        return states.view(states.size(0), length, self.num_heads, self.d_heads).transpose(1, 2)

    def forward(self,
                seq: torch.Tensor,
                key_value_states: torch.Tensor = None,
                att_mask: torch.Tensor = None):
        """Attend from ``seq`` over itself or over ``key_value_states``.

        Args:
            seq: query input of shape (batch, seq_length, d_model).
            key_value_states: optional (batch, kv_seq_len, d_model) source of the
                keys and values. When None, they come from ``seq``.
            att_mask: optional ADDITIVE mask, broadcastable to
                (batch, num_heads, seq_length, kv_seq_len). It is added to the
                scores before the softmax, so positions to exclude should hold
                -inf (or a large negative number) and end up with weight 0.

        Returns:
            Tensor of shape (batch, seq_length, d_model).

        Side effect:
            Stores the (masked) pre-softmax scores in ``self.att_matrix``.
        """
        batch_size, seq_length, _ = seq.size()
        kv_source = seq if key_value_states is None else key_value_states
        kv_seq_len = kv_source.size(1)

        Q_state = self._split_heads(self.q_proj(seq), seq_length)
        K_state = self._split_heads(self.k_proj(kv_source), kv_seq_len)
        V_state = self._split_heads(self.v_proj(kv_source), kv_seq_len)

        Q_state = Q_state * self.scale_factor

        # Scores: (batch, num_heads, seq_length, kv_seq_len).
        scores = torch.matmul(Q_state, K_state.transpose(-1, -2))
        if att_mask is not None:
            # The mask is added, not multiplied: -inf entries become 0 after the softmax.
            scores = scores + att_mask
        self.att_matrix = scores

        att_score = F.softmax(scores, dim=-1)
        att_score = self.dropout(att_score)
        att_op = torch.matmul(att_score, V_state)

        # Concatenate the heads back to (batch, seq_length, d_model).
        att_op = att_op.transpose(1, 2).reshape(batch_size, seq_length, self.d_model)

        return self.output_proj(att_op)



## Smoke Test

In [83]:
# Embed one 3-token sequence (batch of 1) and run self-attention over it.
enc = EmbeddingWithPositionalEncoding(vocab_size=100, d_embed=512, d_model=256)
att_layer = TransformerAttention(d_model=256, num_heads=8)

x = enc(torch.tensor([[1, 2, 3]], device=dev))  # token ids of shape (1, 3)
att_layer(x)

tensor([[[-1.3663e-01,  1.7272e-01, -3.8721e-01,  1.1135e-01, -5.6162e-03,
           2.7718e-01, -2.6944e-01,  3.0791e-01,  6.9384e-02,  1.1480e-01,
           2.6803e-02, -1.0110e-01,  5.3058e-02,  4.6782e-02,  3.4256e-01,
           3.2629e-01,  2.2553e-01,  2.3477e-02,  3.6560e-01,  2.5196e-01,
          -1.0592e-01, -4.7758e-01,  4.6523e-02, -2.1700e-01,  3.0365e-01,
           1.8905e-01,  1.3182e-01, -2.1174e-01,  2.8109e-01, -1.6727e-02,
           1.6279e-01,  1.3043e-01,  2.8903e-01, -3.8033e-01, -4.1511e-02,
           1.8432e-01,  5.2712e-01,  2.4933e-01, -4.0932e-01, -7.4459e-03,
          -2.3427e-01,  1.5465e-01,  1.8812e-01,  1.6551e-01,  1.7161e-01,
           1.4275e-02, -4.1186e-03,  3.1627e-02,  1.5110e-01,  2.0662e-01,
           1.9985e-01,  1.1315e-01, -3.1759e-01,  4.6601e-01, -3.9809e-01,
          -7.4844e-02, -1.2481e-01, -1.2842e-01, -9.1034e-02, -2.1741e-01,
          -1.0017e-01, -6.9226e-02,  7.2272e-02,  2.8972e-02, -7.2722e-02,
          -4.0630e-02,  2

## Feed-Forward Network

- According to Section 3.3 of the paper, this network has two linear layers with a ReLU in between.
- Dimensions: d_model -> d_ff -> d_model.
- The same parameters are applied independently at every position.

In [None]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network: ``fc2(relu(fc1(x)))``.

    Maps the last dimension d_model -> d_ff -> d_model. The same weights are
    applied independently at every position.

    Args:
        d_model: input and output width.
        d_ff: hidden width.
    """

    def __init__(self,
                 d_model: int,
                 d_ff: int):

        super().__init__()

        self.d_model = d_model
        self.d_ff = d_ff

        self.fc1 = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_ff,
            device=dev
        )

        self.fc2 = nn.Linear(
            in_features=self.d_ff,
            out_features=self.d_model,
            device=dev
        )

    def forward(self, input: torch.Tensor):
        """Apply the FFN over the last dimension: (..., d_model) -> (..., d_model)."""
        hidden = F.relu(self.fc1(input))
        # Project back from d_ff to d_model with the SECOND layer. Reusing fc1
        # here fails whenever d_ff != d_model and leaves fc2 unused.
        return self.fc2(hidden)
