# Implementing the Transformer ("Attention Is All You Need") from Scratch in PyTorch

<img src="resources/transformer.png" alt="Transformer model architecture" width="500" height="500">


In [1]:
!pip install torchtext==0.1.1

Collecting torchtext==0.1.1
  Downloading torchtext-0.1.1-py3-none-any.whl.metadata (386 bytes)
Downloading torchtext-0.1.1-py3-none-any.whl (24 kB)
Installing collected packages: torchtext
Successfully installed torchtext-0.1.1


In [2]:
import torchtext

In [3]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import math,re
import warnings
import numpy as np
import torchtext
import torch.optim as optim
warnings.simplefilter("ignore")
print(torch.__version__)

2.5.1+cu121


### Token Embedding in Transformers  

- The `Embedding` class maps each token index to a dense vector of size `d_model`, giving the model a learnable representation of every token.  
- Think of it as a lookup table where each word (or subword) gets its own unique representation, which the model learns and refines during training.  


In [4]:
class Embedding(nn.Module):
    """Learned token embedding: a lookup table from token index to a ``d_model`` vector.

    Args:
        vocab_size: number of distinct token indices.
        d_model: length of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """Look up every index in ``x``; the result gains a trailing axis of size ``d_model``."""
        embedded = self.embedding_layer(x)
        return embedded

### Positional Encoding in Transformers  
  
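- This implementation uses sine and cosine functions at different frequencies to give every position a unique pattern, which lets the model distinguish token positions. The sinusoidal encoding from the paper is defined as:
  $$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
- The input embeddings are multiplied by $\sqrt{d_{model}}$ before the encoding is added.  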


In [5]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from "Attention Is All You Need".

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Both columns of a sin/cos pair share one frequency. For odd ``d_model`` the
    final (even-indexed) column gets a sine term. The table is precomputed for
    ``max_len`` positions and registered as a buffer (moves with the module,
    not a trainable parameter).

    Args:
        d_model: embedding width.
        max_len: longest sequence length the table covers.
    """

    def __init__(self, d_model, max_len):
        super().__init__()
        self.d_model = d_model

        # Compute in float64 (like the scalar math.sin/math.cos version) and
        # store as float32.
        positions = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)  # (max_len, 1)
        # Exponent 2i/d_model for every even column index 2i; the paired odd
        # column 2i+1 reuses the same frequency.
        even_cols = torch.arange(0, d_model, 2, dtype=torch.float64)
        angles = positions / torch.pow(10000.0, even_cols / d_model)  # (max_len, ceil(d_model/2))

        pos_matrix = torch.zeros((max_len, d_model))
        pos_matrix[:, 0::2] = torch.sin(angles).to(pos_matrix.dtype)
        pos_matrix[:, 1::2] = torch.cos(angles[:, : d_model // 2]).to(pos_matrix.dtype)

        # Leading axis of size 1 so the table broadcasts over the batch: (1, max_len, d_model).
        self.register_buffer('pos_matrix', pos_matrix.unsqueeze(0))

    def forward(self, x):
        """Scale ``x`` by sqrt(d_model) and add the encoding for its positions.

        Args:
            x: embeddings of shape (batch, seq_len, d_model).

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.shape[1]
        max_len = self.pos_matrix.shape[1]
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        return x * math.sqrt(self.d_model) + self.pos_matrix[:, :seq_len]
        

## Multi-Head Attention

In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected by one combined ``d_model x d_model``
    linear layer. The result is split into ``num_heads`` heads of width
    ``d_key = d_model // num_heads``, and every head attends independently. The
    heads are then concatenated back to ``d_model`` and passed through the
    output projection ``O``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_key = d_model // num_heads

        # Projection matrices for all heads are combined into single layers.
        self.Q = nn.Linear(d_model, d_model, bias=False)
        self.K = nn.Linear(d_model, d_model, bias=False)
        self.V = nn.Linear(d_model, d_model, bias=False)
        self.O = nn.Linear(d_model, d_model, bias=True)

        self.num_heads = num_heads
        self.d_model = d_model

    def _split_heads(self, x):
        """Reshape (batch, seq, d_model) -> (batch, num_heads, seq, d_key)."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.num_heads, self.d_key).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Attend from the queries ``Q`` over the keys ``K`` / values ``V``.

        Args:
            Q: (batch, len_query, d_model).
            K: (batch, len_key, d_model).
            V: (batch, len_key, d_model).
            mask: optional tensor broadcastable to
                (batch, num_heads, len_query, len_key). Positions where
                ``mask == 0`` receive (effectively) zero attention weight.

        Returns:
            Tensor of shape (batch, len_query, d_model). The output length
            follows the queries, which may differ from the key length in
            cross-attention.
        """
        batch_size, len_query = Q.shape[0], Q.shape[1]

        q = self._split_heads(self.Q(Q))  # (batch, heads, len_query, d_key)
        k = self._split_heads(self.K(K))  # (batch, heads, len_key, d_key)
        v = self._split_heads(self.V(V))  # (batch, heads, len_key, d_key)

        # Per head: q @ k^T gives one score per (query, key) pair, scaled by
        # sqrt(d_key) to keep softmax inputs in a reasonable range.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_key)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)  # (batch, heads, len_query, len_key)

        context = torch.matmul(weights, v)  # (batch, heads, len_query, d_key)
        # Concatenate heads back to d_model; the length is the *query* length.
        context = context.transpose(1, 2).contiguous().view(batch_size, len_query, self.d_model)

        return self.O(context)  # (batch, len_query, d_model)


In [8]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Each position's ``d_model`` features are expanded to ``hidden_dim``,
    rectified, and projected back to ``d_model``.
    """

    def __init__(self, d_model, hidden_dim):
        super().__init__()
        # Keep creation order fc1, relu, fc2: it fixes parameter registration
        # and initialization order.
        self.fc1 = nn.Linear(in_features=d_model, out_features=hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=d_model)

    def forward(self, x):
        """Apply fc1, ReLU and fc2 along the last dimension of ``x``."""
        return self.fc2(self.relu(self.fc1(x)))

## Encoder Layer

In [9]:
class EncoderLayer(nn.Module):
    """One post-norm encoder layer: a self-attention sublayer then a feed-forward sublayer.

    Each sublayer's output passes through dropout, is added to the sublayer's
    input (residual connection), and the sum is layer-normalized.
    """

    def __init__(self, d_model, num_heads, hidden_dim, dropout_rate=0.2):
        super().__init__()
        # Self-attention sublayer.
        self.att = MultiHeadAttention(d_model, num_heads)
        self.drop1 = nn.Dropout(dropout_rate)
        self.norm1 = nn.LayerNorm(d_model)
        # Feed-forward sublayer.
        self.ff = FeedForward(d_model, hidden_dim)
        self.drop2 = nn.Dropout(dropout_rate)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run both sublayers; ``mask`` is forwarded to the self-attention."""
        attended = self.att(x, x, x, mask)
        hidden = self.norm1(x + self.drop1(attended))
        transformed = self.ff(hidden)
        return self.norm2(hidden + self.drop2(transformed))

## Decoder Layer

In [10]:
class DecoderLayer(nn.Module):
    """One post-norm decoder layer: masked self-attention, cross-attention, then feed-forward.

    Every sublayer's output goes through dropout, is added back to the
    sublayer's input, and the sum is layer-normalized.
    """

    def __init__(self, d_model, num_heads, hidden_dim, dropout=0.2):
        super().__init__()
        # Sublayers (registration order kept: it fixes parameter init order).
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, hidden_dim)
        # One LayerNorm and one Dropout per sublayer.
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))
        self.drop1, self.drop2, self.drop3 = (nn.Dropout(dropout) for _ in range(3))

    def forward(self, x, attention_output, source_mask, target_mask):
        """Run the three sublayers.

        Args:
            x: decoder input.
            attention_output: encoder output that the cross-attention attends over.
            source_mask: mask applied in cross-attention.
            target_mask: mask applied in self-attention.
        """
        memory = attention_output
        self_attended = self.self_attention(x, x, x, target_mask)
        hidden = self.norm1(x + self.drop1(self_attended))
        cross_attended = self.cross_attention(hidden, memory, memory, source_mask)
        hidden = self.norm2(hidden + self.drop2(cross_attended))
        return self.norm3(hidden + self.drop3(self.feed_forward(hidden)))

## Putting It All Together: the Transformer

In [11]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    Token ids are embedded and given a (shared) sinusoidal positional encoding,
    followed by dropout. The encoder and decoder each apply ``num_layers``
    layers, and a final linear layer produces target-vocabulary logits.
    Token id 0 is treated as padding when the attention masks are built.

    Args:
        d_model: model width.
        num_layers: number of encoder layers and, separately, of decoder layers.
        num_heads: attention heads per layer; must divide ``d_model``.
        hidden_dim: inner width of each feed-forward sublayer.
        max_len: longest sequence supported by the positional encoding.
        source_vocab_size: size of the source vocabulary.
        target_vocab_size: size of the target vocabulary (width of the output logits).
        dropout: dropout probability used throughout.
    """

    def __init__(self,
                 d_model,
                 num_layers,
                 num_heads,
                 hidden_dim,
                 max_len,
                 source_vocab_size,
                 target_vocab_size,
                 dropout=0.2):
        super().__init__()
        # Embeddings
        self.encoder_emb = nn.Embedding(source_vocab_size, d_model)
        self.decoder_emb = nn.Embedding(target_vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)

        # Dropout applied after embedding + positional encoding
        self.dropout = nn.Dropout(dropout)

        # Encoder and decoder layers
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, hidden_dim, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, hidden_dim, dropout) for _ in range(num_layers)]
        )

        # Final linear layer mapping to target vocab size
        self.fc = nn.Linear(d_model, target_vocab_size)

    def generate_mask(self, src, tgt):
        """Build attention masks from token-id batches (id 0 = padding).

        Args:
            src: source token ids, shape (batch, src_len).
            tgt: target token ids, shape (batch, tgt_len).

        Returns:
            src_mask: bool (batch, 1, 1, src_len), False at source padding. It
                broadcasts over heads and query positions.
            tgt_mask: bool (batch, 1, tgt_len, tgt_len). This combines a causal
                (no-peek) mask, so position t sees only positions <= t, with a
                mask that blanks whole query rows at target padding positions.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        # Lower-triangular "no peek" mask, created on tgt's device so the
        # `&` below works when the model and inputs live on a GPU.
        nopeak_mask = torch.tril(
            torch.ones(seq_length, seq_length, device=tgt.device)
        ).bool().unsqueeze(0)
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, source, target):
        """Run the full encoder-decoder.

        Args:
            source: source token ids, (batch, src_len).
            target: decoder input token ids, (batch, tgt_len).

        Returns:
            Logits of shape (batch, tgt_len, target_vocab_size).
        """
        source_mask, target_mask = self.generate_mask(source, target)

        # Embedding + positional encoding, then dropout
        source_emb = self.dropout(self.pos_enc(self.encoder_emb(source)))
        target_emb = self.dropout(self.pos_enc(self.decoder_emb(target)))

        # Encoder stack
        enc_output = source_emb
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, source_mask)

        # Decoder stack; every layer cross-attends to the final encoder output
        dec_output = target_emb
        for layer in self.decoder_layers:
            dec_output = layer(dec_output, enc_output, source_mask, target_mask)

        # Final linear mapping to vocabulary logits
        return self.fc(dec_output)


In [17]:
# Seed torch's RNG before the model is built and the sample data is drawn below,
# so weight initialization and random tokens are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

d_model = 512             # model / embedding width
num_layers = 6            # number of encoder layers and of decoder layers
num_heads = 4             # attention heads; must divide d_model
hidden_dim = 2048         # inner width of each feed-forward sublayer
MAX_LEN = 100             # longest sequence the positional-encoding table covers
source_vocab_size = 1000
target_vocab_size = 1000
dropout = 0.1

In [18]:
# Build the full encoder-decoder model from the hyperparameters defined above.
transformer = Transformer(
    d_model=d_model,
    num_layers=num_layers,
    num_heads=num_heads,
    hidden_dim=hidden_dim,
    max_len=MAX_LEN,
    source_vocab_size=source_vocab_size,
    target_vocab_size=target_vocab_size,
    dropout=dropout,
)

## Simulate Some Random Tokens

In [44]:
# Generate random sample data of shape (batch_size, MAX_LEN). Token ids are drawn
# from [1, vocab_size), so none equal 0, the id treated as padding by the masks.
BATCH_SIZE = 64
source_data = torch.randint(1, source_vocab_size, (BATCH_SIZE, MAX_LEN))  # (batch_size, seq_length)
target_data = torch.randint(1, target_vocab_size, (BATCH_SIZE, MAX_LEN))  # (batch_size, seq_length)

## The Training Loop

In [None]:
# Training loop (disabled; uncomment to run).
# Teacher forcing: the decoder receives target_data[:, :-1] and is trained to
# predict target_data[:, 1:], the same sequence shifted left by one token.
# The loss ignores token id 0 (padding). Logits are flattened to
# (batch * seq, target_vocab_size) to match CrossEntropyLoss's expected input.
# criterion = nn.CrossEntropyLoss(ignore_index=0)
# optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

# transformer.train()

# for epoch in range(100):
#     optimizer.zero_grad()
#     output = transformer(source_data, target_data[:, :-1])
#     loss = criterion(output.contiguous().view(-1, target_vocab_size), target_data[:, 1:].contiguous().view(-1))
#     loss.backward()
#     optimizer.step()
#     print(f"Epoch: {epoch+1}, Loss: {loss.item()}")