# Implementing the Transformer

Reference: [Implementation_Tutorial](Transformer_Implementation_Tutorial.ipynb)

In [7]:
from torch import nn 
import torch.nn.functional as F
import torch
from math import log, sqrt

In [8]:
# Pick the Metal (MPS) backend when this PyTorch build reports it available;
# otherwise fall back to the CPU. Every module below allocates on `dev`.
if torch.backends.mps.is_available():
    dev = 'mps'
else:
    dev = 'cpu'

## Embedding and Positional Encoding Module

In [9]:
class EmbeddingWithPositionalEncoding(nn.Module):
    """Token embedding, projection to ``d_model`` and sinusoidal positional encoding.

    ``forward`` performs these steps in order:

    1. Look up ``d_embed``-wide token embeddings.
    2. Project them linearly to ``d_model`` and scale by ``sqrt(d_model)``.
    3. Add the sinusoidal positional encoding (paper Sec. 3.5).
    4. Apply LayerNorm and then dropout.

    Args:
        vocab_size: number of rows in the embedding table.
        d_embed: width of the raw embedding vectors.
        d_model: model width produced by the projection.
        dropout_p: dropout probability applied to the final output.
    """

    def __init__(self, vocab_size: int,
                 d_embed: int,
                 d_model: int,
                 dropout_p: float = 0.1
                 ):
        super().__init__()
        self.d_model = d_model
        self.d_embed = d_embed
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_embed,
            device=dev
        )
        self.projection = nn.Linear(
            in_features=d_embed,
            out_features=d_model,
            device=dev
        )
        # The paper scales embeddings by sqrt(d_model). Here the scaling is
        # applied after the projection, once the width is d_model.
        self.scaling = float(sqrt(self.d_model))

        # NOTE: normalising (embedding + PE) is specific to this implementation;
        # the paper applies only dropout at this point.
        self.layerNorm = nn.LayerNorm(
            self.d_model,
            device=dev
        )

        self.dropout = nn.Dropout(p=dropout_p)

    @staticmethod  # does not read any instance state
    def create_positional_encoding(seq_length: int,
                                   d_model: int,
                                   batch_size: int,
                                   device=None,
                                   dtype: torch.dtype = torch.float32
                                   ) -> torch.Tensor:
        """Return sinusoidal positional encodings of shape (batch_size, seq_length, d_model).

        ``PE[pos, 2i] = sin(pos / 10000^(2i/d_model))`` and
        ``PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))``.

        Args:
            seq_length: number of positions to encode.
            d_model: encoding width. Odd widths are supported; the last column
                is then a sine column.
            batch_size: the (seq_length, d_model) table is broadcast to this
                batch size with ``expand``, which makes a view, not a copy.
            device: device of the result; defaults to the module-level ``dev``.
            dtype: floating dtype of the result. Values are computed in float32
                and then cast.
        """
        if device is None:
            device = dev
        positions = torch.arange(seq_length, dtype=torch.float32, device=device).unsqueeze(1)  # (seq_length, 1)
        two_i = torch.arange(0, d_model, 2, dtype=torch.float32, device=device)  # (ceil(d_model/2),)
        # 1 / 10000^(2i/d_model) == exp(-(2i/d_model) * log(10000))
        div_term = torch.exp(two_i / d_model * (-log(10000.0)))
        angles = positions * div_term  # (seq_length, ceil(d_model/2))

        # Must be a floating tensor: an integer tensor would truncate every
        # sin/cos value to -1, 0 or 1.
        pe = torch.zeros(seq_length, d_model, dtype=torch.float32, device=device)
        pe[:, 0::2] = torch.sin(angles)  # even columns
        # Odd columns: for odd d_model there is one fewer cosine than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.to(dtype).unsqueeze(0).expand(batch_size, -1, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of token ids.

        Args:
            x: token ids of shape (batch_size, seq_length).

        Returns:
            Tensor of shape (batch_size, seq_length, d_model).
        """
        batch_size, seq_length = x.shape

        # step 1: raw embeddings, (batch_size, seq_length, d_embed)
        token_embedding = self.embedding(x)

        # step 2: d_embed -> d_model, scaled by sqrt(d_model)
        token_embedding = self.projection(token_embedding) * self.scaling

        # step 3: positional encoding on the same device/dtype as the embeddings
        pos_encoding = self.create_positional_encoding(
            seq_length=seq_length,
            d_model=self.d_model,
            batch_size=batch_size,
            device=token_embedding.device,
            dtype=token_embedding.dtype,
        )

        # step 4: normalise the sum, then dropout
        return self.dropout(self.layerNorm(pos_encoding + token_embedding))



## Attention Module

- Two types of attention I learned:
  - **Self-Attention:** queries, keys, and values all come from the same input tensor.
  - **Cross-Attention:** queries come from one sequence (the decoder's), while keys and values come from a different one (the encoder's output).

In [54]:
class TransformerAttention(nn.Module):
    """Multi-head scaled dot-product attention (paper Sec. 3.2).

    Without ``key_value_states`` this is self-attention: queries, keys and
    values are all projected from ``seq``. With ``key_value_states`` it is
    cross-attention: queries come from ``seq``, keys and values from
    ``key_value_states``.

    Args:
        d_model: width of the inputs and of the output.
        num_heads: number of heads; must be positive and divide ``d_model``.
        dropout_p: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 dropout_p: float = 0.1
                 ):
        super().__init__()
        # Check the sign first: 0 would otherwise raise ZeroDivisionError in
        # the modulo, and a negative value can pass the modulo check.
        if num_heads <= 0:
            raise ValueError(f'`num_heads` must be positive, got {num_heads}')
        if (d_model % num_heads) != 0:
            raise ValueError(
                f'`d_model` not divisible by `num_heads` '
                f'(d_model={d_model}, num_heads={num_heads})'
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_heads = self.d_model // self.num_heads
        self.scale_factor = float(1.0 / sqrt(self.d_heads))
        self.dropout = nn.Dropout(p=dropout_p)

        # Linear projections for queries, keys and values, and for the merged heads.
        self.q_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model, device=dev)
        self.k_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model, device=dev)
        self.v_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model, device=dev)
        self.output_proj = nn.Linear(in_features=self.d_model, out_features=self.d_model, device=dev)

    def _split_heads(self, states: torch.Tensor, length: int) -> torch.Tensor:
        """Reshape (batch, length, d_model) into (batch, num_heads, length, d_heads)."""
        return states.view(states.size(0), length, self.num_heads, self.d_heads).transpose(1, 2)

    def forward(self,
                seq: torch.Tensor,
                key_value_states: torch.Tensor = None,
                att_mask: torch.Tensor = None) -> torch.Tensor:
        """Attend from ``seq`` to itself, or to ``key_value_states`` if given.

        Args:
            seq: query input of shape (batch, seq_length, d_model).
            key_value_states: optional key/value input of shape
                (batch, kv_seq_len, d_model); ``None`` means self-attention.
            att_mask: optional *additive* mask, broadcastable to the score
                shape (batch, num_heads, seq_length, kv_seq_len). An entry of 0
                keeps a position and ``-inf`` removes it. A row that is
                entirely ``-inf`` produces NaN weights.

        Returns:
            Tensor of shape (batch, seq_length, d_model).

        Side effect: the masked, pre-softmax scores are stored in
        ``self.att_matrix`` for inspection. The module therefore keeps a
        reference to them, including their autograd graph, until the next call.
        """
        batch_size, seq_length, _ = seq.size()
        kv_input = seq if key_value_states is None else key_value_states
        kv_seq_len = kv_input.size(1)

        Q_state = self._split_heads(self.q_proj(seq), seq_length)
        K_state = self._split_heads(self.k_proj(kv_input), kv_seq_len)
        V_state = self._split_heads(self.v_proj(kv_input), kv_seq_len)

        # Scaling Q before the matmul is equivalent to scaling the scores by 1/sqrt(d_heads).
        scores = torch.matmul(Q_state * self.scale_factor, K_state.transpose(-1, -2))
        if att_mask is not None:
            # The mask is added, not multiplied: -inf entries become exactly 0 after softmax.
            scores = scores + att_mask
        self.att_matrix = scores

        att_score = self.dropout(F.softmax(scores, dim=-1))
        att_op = torch.matmul(att_score, V_state)  # (batch, num_heads, seq_length, d_heads)

        # Concatenate the heads back into (batch, seq_length, d_model).
        att_op = att_op.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        return self.output_proj(att_op)



## Quick Check: Embedding + Attention

In [11]:
# Smoke test: embed one 3-token sequence into d_model=256 and run 8-head
# self-attention over it. The final expression is displayed by the notebook.
enc = EmbeddingWithPositionalEncoding(vocab_size=100, d_embed=512, d_model=256)
att_layer = TransformerAttention(d_model=256, num_heads=8)

x = enc(torch.tensor([[1, 2, 3]], device=dev))  # token ids of shape (1, 3)
att_layer(x)

tensor([[[-5.8576e-01, -3.7510e-02, -3.2933e-02, -1.5562e-01, -1.5543e-01,
          -9.6675e-02, -8.3409e-02, -4.0597e-01, -2.9541e-01, -2.0960e-01,
          -3.6246e-01, -3.4334e-01, -4.5305e-02, -1.9015e-01, -1.8598e-01,
          -2.7759e-01,  6.2402e-01,  9.8864e-02,  3.1582e-01, -1.7067e-02,
          -8.0224e-02, -1.5658e-01, -5.4462e-02,  5.5463e-02,  9.8601e-02,
           3.6265e-01, -5.2200e-01, -1.9980e-01, -2.0509e-02, -1.6520e-01,
           1.5930e-01,  7.7368e-02, -3.3910e-02, -1.4300e-02, -3.5297e-02,
          -6.4087e-02,  4.5576e-02, -3.7363e-02, -4.2275e-01,  2.5096e-01,
           5.1844e-02,  2.2476e-01, -2.3435e-01, -3.9022e-02,  1.4253e-01,
           2.1255e-01, -2.2203e-01,  1.8360e-01, -3.0399e-01, -8.9777e-02,
           1.0599e-01, -2.9458e-01, -3.7979e-01,  1.3672e-01,  3.4986e-01,
           9.6259e-02, -1.7183e-02,  2.4675e-01, -9.0852e-02, -1.0352e-01,
           1.7198e-01,  5.5760e-01,  2.8286e-02, -3.9149e-01, -2.6553e-01,
           5.9615e-02, -1

## Feed-Forward Network

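- According to Section 3.3 of the paper, this consists of two linear layers with a ReLU activation in between.
- Shapes: d_model -> d_ff -> d_model
- The same parameters are applied at every position (it is position-wise), though each layer has its own parameters.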

In [12]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network, ``fc2(relu(fc1(x)))`` (paper Sec. 3.3).

    Maps the last dimension d_model -> d_ff -> d_model. The same weights act
    independently on every position, so any number of leading dimensions is
    accepted, e.g. (batch, seq_length, d_model) or (tokens, d_model).

    Args:
        d_model: input and output width.
        d_ff: hidden width.
    """

    def __init__(self,
                 d_model: int,
                 d_ff: int):
        super().__init__()

        self.d_model = d_model
        self.d_ff = d_ff

        self.fc1 = nn.Linear(in_features=self.d_model, out_features=self.d_ff, device=dev)
        self.fc2 = nn.Linear(in_features=self.d_ff, out_features=self.d_model, device=dev)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the two-layer FFN over the last dimension of ``input``."""
        # `input` shadows the builtin, but the name is kept because callers may
        # pass it by keyword.
        return self.fc2(F.relu(self.fc1(input)))


## The Encoder Module

In [42]:
class TransformerEncoder(nn.Module):
    """One post-LN encoder layer (paper Sec. 3.1).

    The layer has two sub-layers, self-attention and a position-wise FFN. Each
    sub-layer output goes through dropout, is added back to its input
    (residual), and is then normalised with LayerNorm.

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward network.
        dropout_p: dropout probability for attention weights and sub-layer outputs.
    """

    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 dropout_p: float = 0.1
                 ):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout_p = dropout_p
        self.d_ff = d_ff

        self.att_layer = TransformerAttention(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_p=self.dropout_p
        )

        self.ffn = FeedForwardNetwork(
            d_model=self.d_model,
            d_ff=self.d_ff
        )

        self.norm1 = nn.LayerNorm(self.d_model, device=dev)
        self.norm2 = nn.LayerNorm(self.d_model, device=dev)

        self.dropout = nn.Dropout(p=self.dropout_p)

    def forward(self, x: torch.Tensor, att_mask: torch.Tensor = None) -> torch.Tensor:
        """Run one encoder layer.

        Args:
            x: input of shape (batch, seq_length, d_model).
            att_mask: optional additive mask for the self-attention scores
                (0 keeps a position, -inf hides it), e.g. to hide padding.
                ``None`` means no masking.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # sub-layer 1: self-attention + residual + LayerNorm
        x_att = self.dropout(self.att_layer(x, att_mask=att_mask))
        x_norm1 = self.norm1(x + x_att)

        # sub-layer 2: feed-forward + residual + LayerNorm
        x_ff = self.dropout(self.ffn(x_norm1))
        return self.norm2(x_ff + x_norm1)

In [39]:
# A single encoder layer: width 512, 4 attention heads, FFN hidden width 2048.
encoder = TransformerEncoder(d_model=512, num_heads=4, d_ff=2048)

In [41]:
enc = EmbeddingWithPositionalEncoding(vocab_size=100, d_embed=1024, d_model=512)

# Run the encoder once on a (1, 3) batch of token ids. The result is discarded:
# this cell only checks that the forward pass runs, then displays the module.
encoder(enc(torch.tensor([[1, 2, 4]], device=dev)))
encoder

TransformerEncoder(
  (att_layer): TransformerAttention(
    (dropout): Dropout(p=0.1, inplace=False)
    (q_proj): Linear(in_features=512, out_features=512, bias=True)
    (k_proj): Linear(in_features=512, out_features=512, bias=True)
    (v_proj): Linear(in_features=512, out_features=512, bias=True)
    (output_proj): Linear(in_features=512, out_features=512, bias=True)
  )
  (ffn): FeedForwardNetwork(
    (fc1): Linear(in_features=512, out_features=2048, bias=True)
    (fc2): Linear(in_features=2048, out_features=512, bias=True)
  )
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

## Transformer Decoder

Only the final encoder output is used, as the keys and values of cross-attention, in every decoder layer.

In [56]:
class TransformerDecoder(nn.Module):
    """One post-LN decoder layer (paper Sec. 3.1).

    The layer has three sub-layers: causally masked self-attention,
    cross-attention to the encoder output, and a position-wise FFN. Each
    sub-layer output goes through dropout, is added back to its input
    (residual), and is then normalised with LayerNorm.

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward network.
        dropout_p: dropout probability for attention weights and sub-layer outputs.
    """

    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 dropout_p: float = 0.1):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.dropout_p = dropout_p

        self.att_layer1 = TransformerAttention(  # masked self-attention
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_p=self.dropout_p
        )

        self.norm1 = nn.LayerNorm(self.d_model, device=dev)

        self.att_layer2 = TransformerAttention(  # cross-attention
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_p=self.dropout_p
        )

        self.norm2 = nn.LayerNorm(self.d_model, device=dev)

        self.ffn = FeedForwardNetwork(
            d_model=self.d_model,
            d_ff=self.d_ff
        )

        self.norm3 = nn.LayerNorm(self.d_model, device=dev)

        self.dropout = nn.Dropout(p=self.dropout_p)

    @staticmethod
    def create_causal_mask(seq_len: int, device=None, dtype: torch.dtype = None) -> torch.Tensor:
        """Return an additive causal mask of shape (seq_len, seq_len).

        Entries strictly above the diagonal are ``-inf``, so future positions
        are hidden; all other entries are 0.

        Args:
            seq_len: target sequence length.
            device: device of the mask; defaults to the module-level ``dev``.
            dtype: floating dtype; ``None`` uses torch's default dtype.
        """
        if device is None:
            device = dev
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=dtype), diagonal=1)
        return mask.masked_fill(mask == 1, float('-inf'))

    def forward(self, x: torch.Tensor,
                cross_input: torch.Tensor,
                padding_mask: torch.Tensor = None
                ) -> torch.Tensor:
        """Run one decoder layer.

        Args:
            x: decoder input of shape (batch, seq_length, d_model).
            cross_input: encoder output used as the cross-attention keys and
                values, shape (batch, src_len, d_model).
            padding_mask: optional additive mask for the cross-attention scores,
                broadcastable to (batch, num_heads, seq_length, src_len).

        Returns:
            Tensor of shape (batch, seq_length, d_model).
        """
        seq_length = x.size(1)

        # A (1, 1, L, L) mask broadcasts over batch and heads; build it where x lives.
        causal_mask = self.create_causal_mask(seq_len=seq_length, device=x.device, dtype=x.dtype)
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(1)

        # sub-layer 1: masked self-attention + residual + LayerNorm
        x_att1 = self.dropout(self.att_layer1(seq=x, att_mask=causal_mask))
        x_norm1 = self.norm1(x_att1 + x)

        # sub-layer 2: cross-attention to the encoder output + residual + LayerNorm
        x_att2 = self.dropout(self.att_layer2(
            seq=x_norm1,
            key_value_states=cross_input,
            att_mask=padding_mask
        ))
        x_norm2 = self.norm2(x_att2 + x_norm1)

        # sub-layer 3: feed-forward + residual + LayerNorm
        x_ff = self.dropout(self.ffn(x_norm2))
        return self.norm3(x_ff + x_norm2)


In [57]:
enc = EmbeddingWithPositionalEncoding(vocab_size=100, d_embed=1024, d_model=512)

decoder = TransformerDecoder(d_model=512, num_heads=4, d_ff=2048)

tokens = torch.tensor([1, 2, 4], device=dev).unsqueeze(0)  # token ids, shape (1, 3)

# Encoder output: the cross-attention input for the decoder.
# `encoder` is the d_model=512 layer built in an earlier cell.
x = encoder(enc(tokens))

# Call the module itself rather than `.forward` so nn.Module hooks run.
# The same tokens are embedded again to serve as the decoder input.
decoder(enc(tokens), x)


tensor([[[ 0.5318, -0.0829, -0.0234,  ..., -0.5824, -1.5428,  0.5702],
         [-0.1794,  1.1398,  0.2151,  ...,  0.1587, -0.6725,  0.7063],
         [-0.3512,  1.3743,  0.6984,  ...,  0.2048, -0.0566,  1.0124]]],
       device='mps:0', grad_fn=<NativeLayerNormBackward0>)

## Transformer Encoder and Decoder 