# Implementing the Transformer

Reference: [Implementation_Tutorial](Transformer_Implementation_Tutorial.ipynb)

In [2]:
from torch import nn 
import torch.nn.functional as F
import torch
from math import log, sqrt

In [3]:
dev = 'mps' if torch.backends.mps.is_available() else 'cpu'

## Embdedding and Position Encoding Module

In [4]:
class EmbeddingWithPositionalEncoding(nn.Module):
    def __init__(self, vocab_size: int, 
                 d_embed: int, 
                 d_model: int,
                 dropout_p: float = 0.1
                 ):
        super().__init__()
        self.d_model = d_model
        self.d_embed = d_embed
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_embed,
            device=dev
        )
        self.projection = nn.Linear(
            in_features=d_embed,
            out_features=d_model,
            device=dev
        )
        self.scaling = float(sqrt(self.d_model))

        self.layerNorm = nn.LayerNorm(
            self.d_model,
            device=dev
        )
        
        self.dropout = nn.Dropout(p=dropout_p)

    @staticmethod # decorator that indicates that the following function doesn't operate on `self`
    def create_positional_encoding(seq_length:int, 
                                   d_model:int, 
                                   batch_size:int
                                   ):

        positions = torch.arange(seq_length, dtype=torch.long, device=dev)\
            .unsqueeze(1) # shape (seq_length, 1) i.e. makes it vertical
        
        div_term = torch.exp(
            (torch.arange(0, d_model, 2)/d_model)*(-4)*log(10)
        ).to(dev)
        
        pe = torch.zeros(size=(seq_length, d_model), dtype=torch.long, device=dev) # the tensor to be multiplied to positions tensor to get pe
        pe[:, 0::2] = torch.sin(positions*div_term) # for even dimensions
        pe[:, 1::2] = torch.cos(positions*div_term) # for odd dimensions
        pe = pe.unsqueeze(0).expand(batch_size, -1, -1) # copy out the encodings for each batch
        return pe
    
    def forward(self, x):
        batch_size, seq_length = x.shape

        # step 1: make embeddings
        token_embedding = self.embedding(x)

        # step 2: go from d_embed to d_model
        token_embedding = self.projection(token_embedding) \
            * self.scaling # multiplying with scaling factor, just like in the paper

        # step 3: add positional encoding
        pos_encoding = self.create_positional_encoding(
            seq_length=seq_length, 
            d_model = self.d_model,
            batch_size=batch_size
        )

        #step 4: normalize the sum of pos encoding and token_embed
        norm_sum = self.layerNorm(pos_encoding + token_embedding)
        op = self.dropout(norm_sum)
        return op



## Attention Module

- Two types of attention I learnt:
  - **Self-Attention:** key values come from the same input tensor
  - **Cross-Attention:** key values come fromt he output of a different multi-head attention module

In [5]:
class TransformerAttention(nn.Module):
    def __init__(self, 
                 d_model: int,
                 num_heads: int,
                 dropout_p: float = 0.1
                 ):
        super().__init__()
        if (d_model % num_heads) != 0: raise ValueError(f'`d_model` not divisible by `num_heads`')
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_heads = self.d_model // self.num_heads
        self.scale_factor = float(1.0 / sqrt(self.d_heads))
        self.dropout = nn.Dropout(p=dropout_p)

        #linear transformations
        self.q_proj = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model,
            device=dev
        )

        self.k_proj = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model,
            device=dev
        )

        self.v_proj = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model,
            device=dev
        )

        self.output_proj = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model,
            device=dev
        )

    def forward(self, 
                seq: torch.Tensor, 
                key_value_states:torch.Tensor = None, 
                att_mask: torch.Tensor = None):
        batch_size, seq_length, d_model = seq.size()

        Q_state: torch.Tensor = self.q_proj(seq)
        if key_value_states is not None:
            kv_seq_len = key_value_states.size(1)
            K_state: torch.Tensor = self.k_proj(key_value_states)
            V_state: torch.Tensor = self.v_proj(key_value_states)
        else:
            kv_seq_len = seq_length
            K_state: torch.Tensor = self.k_proj(seq)
            V_state: torch.Tensor = self.v_proj(seq)

        Q_state = Q_state.view(batch_size, seq_length, self.num_heads, self.d_heads).transpose(1, 2)
        K_state = K_state.view(batch_size, kv_seq_len, self.num_heads, self.d_heads).transpose(1, 2)
        V_state = V_state.view(batch_size, kv_seq_len, self.num_heads, self.d_heads).transpose(1, 2)

        Q_state = Q_state * self.scale_factor
        
        self.att_matrix = torch.matmul(Q_state, K_state.transpose(-1, -2))
        

        if att_mask is not None:
            self.att_matrix = self.att_matrix + att_mask # yes, in this case the mask is not multiplied, but added. This is to ensure that after softmax the things to be excluded are 0
        
        att_score = F.softmax(self.att_matrix, dim=-1) # torch.nn.Softmax() is used in __init__, F.softmax() is used for these inline operations.
        att_score = self.dropout(att_score)
        att_op = torch.matmul(att_score, V_state)

        #concatenating all heads 
        att_op = att_op.transpose(1, 2)
        att_op = att_op.contiguous().view(batch_size, seq_length, self.num_heads*self.d_heads)

        att_op = self.output_proj(att_op)

        return att_op



##

In [6]:
enc = EmbeddingWithPositionalEncoding(
    vocab_size=100,
    d_embed=512, 
    d_model=256
)

att_layer = TransformerAttention(
    d_model=256, 
    num_heads=8
)

x = enc(torch.tensor([1, 2, 3], device=dev).unsqueeze(0))
att_layer(x)

tensor([[[-5.6083e-02, -2.7584e-01,  1.2233e-01,  1.0724e-01,  1.3438e-01,
           5.8057e-02, -1.5747e-01,  6.2093e-02, -1.0438e-01,  1.4907e-01,
           1.0775e-01, -1.8889e-01, -4.2672e-01, -2.6931e-01,  1.7088e-01,
           9.8167e-02, -2.5291e-01,  6.9719e-02,  2.9859e-01, -4.9563e-02,
          -9.1499e-02,  6.5453e-02,  3.4010e-01,  2.3936e-01,  4.1414e-01,
          -2.1820e-01, -1.4620e-01, -1.7827e-01, -1.0829e-01,  1.2271e-01,
           6.5892e-02,  9.5665e-02, -3.1794e-01, -4.1636e-02,  9.9729e-02,
          -4.3897e-01, -2.3396e-01,  1.1752e-01, -1.2415e-01,  1.5225e-01,
           1.6180e-01,  2.2708e-01,  1.4612e-01, -4.0081e-01, -2.2000e-01,
           1.8675e-01,  2.6911e-01, -6.4734e-02, -1.2841e-01,  9.5680e-02,
          -3.5189e-01,  1.7776e-01,  1.4775e-01, -2.2859e-01,  1.3291e-01,
          -3.9022e-01,  1.0306e-01, -1.6720e-01,  9.4699e-02,  1.5673e-01,
          -2.1388e-01,  1.0639e-01,  1.4410e-01, -8.3046e-02, -4.2238e-02,
           1.6983e-01, -1

## Feed-Forward Network

- According, to section 3.3 of the paper, this has 2 layers
- d_model -> d_ff -> d_model
- same parameters for every position.

In [7]:
class FeedForwardNetwork(nn.Module):
    def __init__(self,
                 d_model: int,
                 d_ff: int):
        
        super().__init__()

        self.d_model = d_model
        self.d_ff = d_ff

        self.fc1 = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_ff,
            device=dev
        )

        self.fc2 = nn.Linear(
            in_features=self.d_ff,
            out_features=self.d_model,
            device=dev
        )
        
    def forward(self, input:torch.Tensor):
        batch_size, seq_length, d_input = input.size()
        f1 = F.relu(self.fc1(input))
        f2 = self.fc2(f1)
        return f2


## The Encoder Module

In [8]:
class TransformerEncoder(nn.Module):
    def __init__(self,
                 d_model: int, 
                 num_heads: int,
                 d_ff: int,
                 dropout_p = 0.1
                 ):
        
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout_p = dropout_p
        self.d_ff = d_ff

        self.att_layer = TransformerAttention(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_p=self.dropout_p
        )

        self.ffn = FeedForwardNetwork(
            d_model=self.d_model,
            d_ff = self.d_ff
        )

        self.norm1 = nn.LayerNorm(self.d_model, device=dev)
        self.norm2 = nn.LayerNorm(self.d_model, device=dev)

        self.dropout = nn.Dropout(p=self.dropout_p)
        
    def forward(self, x: torch.Tensor):
        x_att = self.att_layer(x)

        x_att = self.dropout(x_att)
        x_norm1 = self.norm1(x + x_att)

        x_ff = self.ffn(x_norm1)

        x_ff = self.dropout(x_ff)
        x_norm2 = self.norm2(x_ff + x_norm1)
        
        return x_norm2

In [9]:
encoder = TransformerEncoder(
    d_model= 512,
    d_ff=2048,
    num_heads=4
)

In [10]:
enc = EmbeddingWithPositionalEncoding(
    vocab_size=100,
    d_embed=1024, 
    d_model=512
)

encoder(
    enc(torch.tensor([1, 2, 4], device=dev).unsqueeze(0))
    )
encoder

TransformerEncoder(
  (att_layer): TransformerAttention(
    (dropout): Dropout(p=0.1, inplace=False)
    (q_proj): Linear(in_features=512, out_features=512, bias=True)
    (k_proj): Linear(in_features=512, out_features=512, bias=True)
    (v_proj): Linear(in_features=512, out_features=512, bias=True)
    (output_proj): Linear(in_features=512, out_features=512, bias=True)
  )
  (ffn): FeedForwardNetwork(
    (fc1): Linear(in_features=512, out_features=2048, bias=True)
    (fc2): Linear(in_features=2048, out_features=512, bias=True)
  )
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

## Transformer Decoder

only the final encoder output is used for all the decoder layers.

In [11]:
class TransformerDecoder(nn.Module):
    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 dropout_p = 0.1):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.dropout_p = dropout_p

        self.att_layer1 = TransformerAttention(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_p=self.dropout_p
        )

        self.norm1 = nn.LayerNorm(self.d_model, device=dev)

        self.att_layer2 = TransformerAttention(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_p=self.dropout_p
        )

        self.norm2 = nn.LayerNorm(self.d_model, device=dev)

        self.ffn = FeedForwardNetwork(
            d_model=self.d_model,
            d_ff = self.d_ff
        )

        self.norm3 = nn.LayerNorm(self.d_model, device=dev)

        self.dropout = nn.Dropout(p=self.dropout_p)

    @staticmethod
    def create_causal_mask(seq_len: int) -> torch.Tensor:
        mask = torch.triu(torch.ones(seq_len, seq_len, device=dev), diagonal=1)
        mask = mask.masked_fill(mask == 1, value=float('-inf'))
        return mask

    def forward(self, x: torch.Tensor, 
                cross_input:torch.Tensor,
                padding_mask:torch.Tensor = None
                ):
        batch_size, seq_length, d_model = x.size()

        causal_mask = self.create_causal_mask(seq_len=seq_length)
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(1) #unsqeeze the mast for self attention

        x_att1 = self.att_layer1( #self-attention
            seq=x,
            att_mask = causal_mask
        )

        x_att1 = self.dropout(x_att1)
        x_norm1 = self.norm1(x_att1 + x)

        x_att2 = self.att_layer2( #cross attention
            seq=x_norm1,
            key_value_states=cross_input,
            att_mask = padding_mask
        )

        x_att2 = self.dropout(x_att2)
        x_norm2 = self.norm2(x_att2)

        x_ff = self.ffn(x_norm2)

        x_ff = self.dropout(x_ff)
        x_norm3 = self.norm3(x_ff)
        
        return x_norm3


In [12]:
enc = EmbeddingWithPositionalEncoding(
    vocab_size=100,
    d_embed=1024, 
    d_model=512
)

decoder = TransformerDecoder(
    d_model=512,
    num_heads=4,
    d_ff=2048
)

x = encoder(
    enc(torch.tensor([1, 2, 4], device=dev).unsqueeze(0))
    )

decoder.forward(enc(torch.tensor([1, 2, 4], device=dev).unsqueeze(0)), x)


tensor([[[-1.1775, -2.3496,  0.2176,  ..., -0.0558,  0.3318,  0.4629],
         [-1.0657, -2.5011, -0.0984,  ..., -0.0984, -0.0984, -0.2286],
         [-0.9635, -2.5325,  0.1274,  ..., -0.8475,  0.0596,  0.2640]]],
       device='mps:0', grad_fn=<NativeLayerNormBackward0>)

## Transformer Encoder and Decoder 

In [21]:
class TransformerEncoderDecoder(nn.Module):
    def __init__(self,
                 N_enc: int,
                 N_dec: int, 
                 d_model:int,
                 num_heads: int, 
                 d_ff: int,
                 dropout_p = 0.1
                 ):
        
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.dropout_p = dropout_p

        self.encoder_stack = nn.ModuleList([
            TransformerEncoder(
                d_model=self.d_model,
                num_heads=self.num_heads,
                d_ff=self.d_ff,
                dropout_p=self.dropout_p
            ) for _ in range(N_enc)
        ])

        self.decoder_stack = nn.ModuleList([
            TransformerDecoder(
                d_model=self.d_model,
                num_heads=self.num_heads,
                d_ff=self.d_ff,
                dropout_p=self.dropout_p
            )
        ])

    def forward(self, x: torch.Tensor, y:torch.Tensor) -> torch.Tensor:
        #pass through the encoder stack
        encoder_output = x
        for encoder in self.encoder_stack:
            encoder_output = encoder(encoder_output)

        #pass through the decoder stack
        #uses only the final encoder input
        decoder_output = y
        for decoder in self.decoder_stack:
            decoder_output = decoder(decoder_output, cross_input=encoder_output)

        return decoder_output

In [22]:
enc = EmbeddingWithPositionalEncoding(
    vocab_size=100,
    d_embed=1024, 
    d_model=512
)

tf = TransformerEncoderDecoder(
    N_enc = 6,
    N_dec=6,
    d_model=512,
    num_heads=4,
    d_ff=2048
)

tf(enc(
    torch.tensor([1, 2, 3], device=dev).unsqueeze(0)
    ), enc(
    torch.tensor([1, 2, 3], device=dev).unsqueeze(0)
    ))

tensor([[[-1.1387, -0.0440, -1.0679,  ..., -0.0614,  0.7674,  0.8670],
         [-1.0070,  1.0224, -0.7265,  ...,  0.3842,  0.6753,  0.3150],
         [-1.0564,  1.3978, -1.8241,  ..., -0.2619,  0.8187, -0.0142]]],
       device='mps:0', grad_fn=<NativeLayerNormBackward0>)

## Full Transformer

In [67]:
class Transformer(nn.Module):
    def __init__(self,
                 N_enc: int,
                 N_dec: int,
                 vocab_size:int, 
                 d_embed: int,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 d_tgt_vocab: int,
                 dropout_p = 0.1):
        super().__init__()

        self.N_enc = N_enc
        self.N_dec = N_dec
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_embed = d_embed
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.dropout_p = dropout_p
        self.d_tgt_vocab = d_tgt_vocab

        self.src_embedder = EmbeddingWithPositionalEncoding(
            vocab_size=self.vocab_size,
            d_embed=self.d_embed,
            d_model=self.d_model,
            dropout_p=self.dropout_p
        )

        self.tgt_embedder = EmbeddingWithPositionalEncoding(
            vocab_size=self.vocab_size,
            d_embed=self.d_embed,
            d_model=self.d_model,
            dropout_p=self.dropout_p
        )

        self.encoder_decoder_stack = TransformerEncoderDecoder(
            N_enc=self.N_enc,
            N_dec=self.N_dec,
            d_model=self.d_model,
            num_heads=self.num_heads,
            d_ff=self.d_ff,
            dropout_p=self.dropout_p
        )

        self.output_proj = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_tgt_vocab,
            device=dev
        )

        self.softmax = nn.LogSoftmax(dim=-1)

    @staticmethod
    def shift_target_right(tgt_tokens: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = tgt_tokens.size() # no d_model since, no Embedding done
        zer = torch.zeros(
            size=(batch_size, 1),
            device=dev,
            dtype=torch.long
        )
        return torch.concat([
            zer, 
            tgt_tokens[:, :-1]], 
            dim=1).to(dev)

    def forward(self, 
                src_tokens:torch.Tensor, 
                tgt_tokens:torch.Tensor
                ) -> torch.Tensor:
        
        tgt_tokens = self.shift_target_right(tgt_tokens) 
        # shifting is needed to prevent information leakage. 
        # it allows parallel traning in spite of hiding the token 
        # to be predicted.
        inp_embed = self.src_embedder(src_tokens)
        tgt_embed = self.tgt_embedder(tgt_tokens)

        enc_dec_out = self.encoder_decoder_stack.forward(inp_embed, tgt_embed)
        
        out = self.output_proj(enc_dec_out)
        log_probs = self.softmax(out)
        return log_probs


In [68]:
tf = Transformer(
    N_enc = 6,
    N_dec=6,
    d_model=512,
    num_heads=4,
    d_ff=2048,
    vocab_size=100,
    d_embed=1024,
    d_tgt_vocab=100
)

In [69]:
total_params = sum(p.numel() for p in tf.parameters())
print(f"Total parameters: {total_params:,}")

Total parameters: 24,426,084


In [70]:
tf(
    src_tokens=torch.tensor([1., 2., 3.], device=dev, dtype=torch.long).unsqueeze(0),
    tgt_tokens=torch.tensor([1., 2., 3.], device=dev, dtype=torch.long).unsqueeze(0),
).shape

torch.Size([1, 3, 100])