# Transformer implementation Using Pytorch

## Below is the Todo-list
Will be updated as we go about the code

In [42]:
#TODO: look into torchtext
#TODO: get the data
#TODO: pre-process the data - tokenize it as well
#TODO: implement the training loop
#TODO: Train the model

## Implementation Starts here

In [43]:
# Author: Rishabh Agarwal
# All the imports for the code

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader

import math

In [44]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [45]:
# Feed Forward Network
# Used in the Transformer Block
class FFN(nn.Module):
    def __init__(self, device, d_model=512, dff = 2048):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dff, device=device)
        self.linear2 = nn.Linear(dff, d_model, device=device)
        self.gelu = nn.GELU()
        # One reason of choosing GELU - not in the initial implementation of the transformer in the paper
        """
        relu can suffer from "problems where significant amount of neuron in the network become zero and don’t practically do anything." 
        gelu is smoother near zero and "is differentiable in all ranges, and allows to have gradients(although small) in negative range" which helps with this problem.
        """

        self.device = device

    def forward(self, x):
        x = x.to(self.device)
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        return x 

In [46]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, d_model, dropout=0.1, device = 'cuda'):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.device = device
        self.d_k = d_model // num_heads # using d_model and number of heads for d_k calculation
        self.num_heads = num_heads
        
        # Combined linear projections for Q, K, V
        self.q_proj = nn.Linear(d_model, d_model,device=self.device)
        self.k_proj = nn.Linear(d_model, d_model, device=self.device)
        self.v_proj = nn.Linear(d_model, d_model, device=self.device)
        self.out_proj = nn.Linear(d_model, d_model, device= self.device)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # Linear projections and reshape for attention
        q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Scaled dot-product attention (can use internal implementation as well)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        # Apply attention to values
        context = torch.matmul(attn, v)
        
        # Reshape and apply output projection
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.out_proj(context)

In [47]:
# Test the MultiHeadedAttention class
def test_multi_headed_attention(device):
    h = 8
    d_model = 512
    batch_size = 32

    mha = MultiHeadAttention(num_heads=h, d_model= d_model, device=device)

    Q = torch.randn(batch_size, 10, d_model, device=device)
    K = torch.randn(batch_size, 10, d_model, device=device)
    V = torch.randn(batch_size, 10, d_model, device=device)

    out = mha(Q, K, V)

    assert out.size() == (batch_size, 10, d_model)

    print("MultiHeadedAttention test passed")
test_multi_headed_attention(device)

MultiHeadedAttention test passed


In [48]:
# One of N Encoder Blocks in the Transformer
class Encoder(nn.Module):
    def __init__(self, num_heads, d_model, dff,  device, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads, d_model, dropout, device=device)
        self.layernorm1 = nn.LayerNorm(d_model, device=device)

        self.ffn = FFN(d_model=d_model, device=device, dff=dff)
        self.layernorm2 = nn.LayerNorm(d_model, device =device)

        self.device = device
    def forward(self, input_emb, mask=None):
        input_emb = input_emb.to(self.device)
        x = self.mha(input_emb, input_emb, input_emb, mask)
        x = self.layernorm1(x + input_emb)
        ffn_out = self.ffn(x).to(self.device)
        x = self.layernorm2(x + ffn_out)
        return x

In [49]:
# Test the encoder block
def test_encoder_block(device):
    num_heads = 8
    d_model = 512
    dff = 2048
    batch_size = 32

    encoder_block = Encoder(num_heads, d_model, dff= dff, device= device)

    input_emb = torch.randn(batch_size, 10, d_model, device=device)

    out = encoder_block(input_emb)

    assert out.size() == (batch_size, 10, d_model)

    print("Encoder block test passed")
test_encoder_block(device)

Encoder block test passed


In [50]:
# One of N Decoders' Implementation in the Transformer
class Decoder(nn.Module):
    def __init__(self,num_heads, d_model, dff, device, dropout=0.1 ):
        super().__init__()
        self.self_attn = MultiHeadAttention(num_heads, d_model, dropout, device=device)
        self.cross_attn = MultiHeadAttention(num_heads, d_model,dropout, device=device)

        self.ffn = FFN(d_model =d_model, device=device, dff= dff)

        self.layernorm1 = nn.LayerNorm(d_model, device=device)
        self.layernorm2 = nn.LayerNorm(d_model, device=device)
        self.layernorm3 = nn.LayerNorm(d_model, device=device)

        self.device = device

        self.dropout = nn.Dropout(dropout)
    def forward(self, x, residual, src_mask=None, tgt_mask=None):
        # Self-attention
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.layernorm1(x + self.dropout(attn_output))
        
        # Cross-attention
        attn_output = self.cross_attn(x, residual, residual, src_mask)
        x = self.layernorm2(x + self.dropout(attn_output))
        
        # FFN
        ffn_output = self.ffn(x)
        return self.layernorm3(x + self.dropout(ffn_output))

In [51]:
# Test the decoder block
def test_decoder_block(device):
    h = 8
    d_model = 512
    dff = 2048
    batch_size = 32

    decoder_block = Decoder(h,  d_model, device=device, dff=dff)

    output_emb = torch.randn(batch_size, 10, d_model, device=device)
    residual = torch.randn(batch_size, 10, d_model, device=device)

    out = decoder_block(output_emb, residual)

    assert out.size() == (batch_size, 10, d_model)

    print("Decoder block test passed")
test_decoder_block(device)


Decoder block test passed


In [64]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 num_encoder_layers=6, num_decoder_layers=6, d_ff=2048, dropout=0.1,
                 pad_idx=0, max_seq_len=1000, device='cuda'):
        """
        
        src stands for source/input
        tgt stands for target/output

        """
        super().__init__()

        self.pad_index = pad_idx
        self.d_model = d_model
        self.device = device


        # Embeddings
        self.src_embed = nn.Embedding(src_vocab_size, d_model, padding_idx=self.pad_index)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model, padding_idx=self.pad_index)

        # Positional encoding
        self.pos_encoding = self.create_positional_encoding(max_seq_len, d_model)  # Max sequence length of 1000
        # rather than calculating this every forward pass we calculate this once and store it

        
        self.encoder = nn.ModuleList([Encoder(num_heads=num_heads, d_model=d_model, dropout=dropout, device=device, dff=d_ff) for _ in range(num_encoder_layers)])
        self.decoder = nn.ModuleList([Decoder(num_heads=num_heads, d_model=d_model,dropout= dropout, device=device, dff=d_ff) for _ in range(num_decoder_layers)])
        self.final = nn.Sequential(
            nn.Linear(d_model, tgt_vocab_size, device=device),
        )

        self.dropout = nn.Dropout(dropout) 

        self.init_parameters()

    def create_positional_encoding(self, max_seq_len, d_model):
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        return pe.unsqueeze(0)

    def create_padding_mask(self, src):
        return (src != self.pad_index).unsqueeze(1).unsqueeze(2)

    def create_causal_mask(self, tgt):
        seq_len = tgt.size(1)
        return torch.tril(torch.ones(seq_len, seq_len)).bool()

    def init_parameters(self):
        # Utilizaing Xavier Initialization
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def forward(self, input_emb, output_emb):

        src_mask = self.create_padding_mask(input_emb)
        tgt_mask = self.create_causal_mask(output_emb)
        
        # Embedding and positional encoding for source
        src_embedded = self.src_embed(input_emb) * math.sqrt(self.d_model)
        src_embedded = src_embedded + self.pos_encoding[:, :src_embedded.size(1)].to(src_embedded.device)
        enc_output = self.dropout(src_embedded)

        for enc in self.encoder:
            enc_output = enc(enc_output, src_mask)
        
        # Embedding and positional encoding for target
        tgt_embedded = self.tgt_embed(output_emb) * math.sqrt(self.d_model)
        tgt_embedded = tgt_embedded + self.pos_encoding[:, :tgt_embedded.size(1)].to(tgt_embedded.device)
        dec_output = self.dropout(tgt_embedded)

        # Decoder
        for dec in self.decoder:
            dec_output = dec(x = dec_output, residual = enc_output, src_mask=src_mask, tgt_mask=tgt_mask)
        
        return self.final(dec_output)
    
    def generate(self, inputs, max_len, sos_token, eos_token):
        self.eval()
        with torch.inference_mode():
            # Encode source sequence
            src_mask = self.create_padding_mask(inputs)
            
            src_embedded = self.src_embed(inputs) * math.sqrt(self.d_model)
            src_embedded = src_embedded + self.pos_encoding[:, :src_embedded.size(1)].to(src_embedded.device)
            enc_output = self.dropout(src_embedded)
            
            for enc in self.encoder:
                enc_output = enc(enc_output, src_mask)
            
            # Initialize target sequence with SOS token
            target = torch.full((inputs.size(0), 1), sos_token, dtype=torch.long, device=self.device)
            
            # Generate tokens one by one
            for _ in range(max_len - 1):
                tgt_mask = self.create_causal_mask(target)
                
                tgt_embedded = self.tgt_embed(target) * math.sqrt(self.d_model)
                tgt_embedded = tgt_embedded + self.pos_encoding[:, :tgt_embedded.size(1)].to(tgt_embedded.device)
                dec_output = self.dropout(tgt_embedded)
                
                for dec in self.decoder:
                    dec_output = dec(x = dec_output, residual = enc_output, src_mask=src_mask, tgt_mask=tgt_mask)
                
                output = self.final_layer(dec_output)
                next_token = output[:, -1].argmax(dim=-1).unsqueeze(1)
                target = torch.cat([target, next_token], dim=1)
                
                # Stop if EOS token is generated
                if next_token.item() == eos_token:
                    break
            
            return target



In [65]:
def test_transformer(device):
    src_vocab_size = 1000
    tgt_vocab_size = 1000
    d_model = 512
    num_heads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    d_ff = 2048
    dropout = 0.1
    pad_idx = 0
    max_seq_len = 1000  
    model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, d_ff, dropout, pad_idx, max_seq_len, device).to(device)
    
    test_inp = torch.randint(0, src_vocab_size, (32, 10), device=device)
    test_out = torch.randint(0, tgt_vocab_size, (32, 10), device=device)
    
    out = model(test_inp, test_out)
    
    assert out.size() == (32, 10, tgt_vocab_size)
    
    print(model)
    print("Transformer test passed")

test_transformer(device)

Transformer(
  (src_embed): Embedding(1000, 512, padding_idx=0)
  (tgt_embed): Embedding(1000, 512, padding_idx=0)
  (encoder): ModuleList(
    (0-5): 6 x Encoder(
      (mha): MultiHeadAttention(
        (q_proj): Linear(in_features=512, out_features=512, bias=True)
        (k_proj): Linear(in_features=512, out_features=512, bias=True)
        (v_proj): Linear(in_features=512, out_features=512, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ffn): FFN(
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (gelu): GELU(approximate='none')
      )
      (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
  )
  (decoder): ModuleList(
    (0-5): 6 x Decoder(
      (self_attn): MultiHeadAttention(
 