# Transformer Implementation Using PyTorch

## To-Do List
This list will be updated as the implementation progresses.

In [2]:
#TODO: figure out inference for transformer model - use of self.training
#TODO: look into torchtext
#TODO: get the data
#TODO: pre-process the data - tokenize it as well
#TODO: implement the training loop
#TODO: Train the model

## Implementation

In [13]:
# Author: Rishabh Agarwal
# All the imports for the code

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader

import math

In [5]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cpu'

In [6]:
# Feed Forward Network
# Used in the Transformer Block
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear(d_model -> dff) -> GELU -> Linear(dff -> d_model).

    GELU replaces the ReLU used in the original Transformer paper. With ReLU, a large
    share of units can end up outputting zero and receive no gradient. GELU is smooth
    near zero and differentiable everywhere, and it keeps a small non-zero gradient for
    negative inputs, which mitigates that problem.

    Args:
        device: Device on which the parameters are created; ``forward`` moves its input here.
        d_model: Size of the last dimension of the input and of the output.
        dff: Hidden (inner) layer size.
    """

    def __init__(self, device, d_model=512, dff = 2048):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dff, device=device)
        self.linear2 = nn.Linear(dff, d_model, device=device)
        self.gelu = nn.GELU()
        self.device = device

    def forward(self, x):
        """Apply the FFN over the last dimension of ``x`` (size ``d_model``); leading dims are kept."""
        x = x.to(self.device)
        return self.linear2(self.gelu(self.linear1(x)))

In [17]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (Vaswani et al., 2017).

    Args:
        num_heads: Number of attention heads; must divide ``d_model``.
        d_model: Feature size of query/key/value inputs and of the output.
        dropout: Dropout probability applied to the attention weights.
        device: Device for the parameters. ``None`` (default) picks ``"cuda"`` when
            available and ``"cpu"`` otherwise (a hard-coded ``"cuda"`` default fails on
            CPU-only machines).
    """

    def __init__(self, num_heads, d_model, dropout=0.1, device=None):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.d_k = d_model // num_heads  # per-head feature size
        self.num_heads = num_heads

        # Separate linear projections for Q, K and V, plus the output projection.
        self.q_proj = nn.Linear(d_model, d_model, device=self.device)
        self.k_proj = nn.Linear(d_model, d_model, device=self.device)
        self.v_proj = nn.Linear(d_model, d_model, device=self.device)
        self.out_proj = nn.Linear(d_model, d_model, device=self.device)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: (batch, q_len, d_model) tensor.
            key: (batch, k_len, d_model) tensor.
            value: (batch, k_len, d_model) tensor.
            mask: Optional; positions where ``mask == 0`` receive no attention. It must
                broadcast to (batch, num_heads, q_len, k_len). A 3-D mask is read as
                (batch, q_len, k_len) and broadcast over heads. A 2-D (q_len, k_len)
                mask (e.g. causal) applies to every batch element and head.

        Returns:
            (batch, q_len, d_model) tensor.
        """
        batch_size = query.size(0)

        # Project, then split the feature dim into heads: (batch, heads, len, d_k).
        q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product scores (F.scaled_dot_product_attention is a built-in alternative).
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        if mask is not None:
            mask = mask.to(scores.device)
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # (batch, q_len, k_len) -> (batch, 1, q_len, k_len)
            # Most negative finite value instead of -inf: a fully masked row then gives a
            # uniform distribution rather than NaNs from softmax over all -inf.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Apply attention to values
        context = torch.matmul(attn, v)

        # Merge heads back to (batch, q_len, d_model) and apply the output projection.
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.out_proj(context)

In [22]:
# Test the MultiHeadedAttention class
def test_multi_headed_attention(device):
    """Smoke test: MultiHeadAttention maps (batch, seq, d_model) inputs to an output of the same shape."""
    num_heads, d_model, batch_size, seq_len = 8, 512, 32, 10

    attention = MultiHeadAttention(num_heads=num_heads, d_model=d_model, device=device)
    query, key, value = (
        torch.randn(batch_size, seq_len, d_model, device=device) for _ in range(3)
    )

    result = attention(query, key, value)
    assert result.size() == (batch_size, seq_len, d_model)

    print("MultiHeadedAttention test passed")
test_multi_headed_attention(device)

MultiHeadedAttention test passed


In [27]:
# One of N Encoder Blocks in the Transformer
class Encoder(nn.Module):
    """One post-norm Transformer encoder layer.

    The layer has two sub-layers, self-attention and a feed-forward network. Each is
    combined with its input as ``LayerNorm(x + Dropout(sublayer(x)))``, matching
    the Decoder block.

    Args:
        num_heads: Number of attention heads.
        d_model: Feature size of the input and output.
        device: Device for the parameters; ``forward`` moves its input here.
        dropout: Probability used for attention-weight dropout and for the
            residual dropout on each sub-layer output.
    """

    def __init__(self, num_heads, d_model, device, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads, d_model, dropout, device=device)
        self.layernorm1 = nn.LayerNorm(d_model, device=device)

        self.ffn = FFN(d_model=d_model, device=device)
        self.layernorm2 = nn.LayerNorm(d_model, device=device)

        # Residual dropout on each sub-layer output (previously missing here).
        self.dropout = nn.Dropout(dropout)
        self.device = device

    def forward(self, input_emb, mask=None):
        """Encode ``input_emb``; ``mask`` is forwarded to self-attention. Output has the input's shape."""
        input_emb = input_emb.to(self.device)
        attn_out = self.mha(input_emb, input_emb, input_emb, mask)
        x = self.layernorm1(input_emb + self.dropout(attn_out))
        ffn_out = self.ffn(x)  # FFN already returns on self.device
        return self.layernorm2(x + self.dropout(ffn_out))

In [28]:
# Test the encoder block
def test_encoder_block(device):
    """Smoke test: one Encoder layer preserves the (batch, seq, d_model) shape of its input."""
    expected_shape = (32, 10, 512)  # (batch_size, seq_len, d_model)

    block = Encoder(num_heads=8, d_model=expected_shape[-1], device=device)
    result = block(torch.randn(*expected_shape, device=device))

    assert result.size() == expected_shape
    print("Encoder block test passed")
test_encoder_block(device)

Encoder block test passed


In [29]:
# One of N Decoders' Implementation in the Transformer
class Decoder(nn.Module):
    """One post-norm Transformer decoder layer.

    The layer has three sub-layers, each combined with its input as
    ``LayerNorm(x + Dropout(sublayer))``:
      1. self-attention over the target sequence, masked by ``tgt_mask``;
      2. cross-attention over the encoder output, masked by ``src_mask``;
      3. a position-wise feed-forward network.

    Args:
        num_heads: Attention heads in both attention sub-layers.
        d_model: Feature size of the inputs and outputs.
        device: Device on which the sub-modules' parameters are created.
        dropout: Probability for attention-weight dropout and residual dropout.
    """

    def __init__(self, num_heads, d_model, device, dropout=0.1):
        super().__init__()
        # Creation order fixes the parameter-initialisation RNG order; keep it stable.
        self.self_attn = MultiHeadAttention(num_heads, d_model, dropout, device=device)
        self.cross_attn = MultiHeadAttention(num_heads, d_model, dropout, device=device)
        self.ffn = FFN(d_model=d_model, device=device)
        self.layernorm1, self.layernorm2, self.layernorm3 = (
            nn.LayerNorm(d_model, device=device) for _ in range(3)
        )
        self.device = device
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Run the three sub-layers on ``x`` while attending over ``enc_output``; returns a tensor shaped like ``x``."""
        x = self._add_and_norm(x, self.self_attn(x, x, x, tgt_mask), self.layernorm1)
        x = self._add_and_norm(
            x, self.cross_attn(x, enc_output, enc_output, src_mask), self.layernorm2
        )
        return self._add_and_norm(x, self.ffn(x), self.layernorm3)

    def _add_and_norm(self, residual, sublayer_out, norm):
        """Residual connection with dropout on the sub-layer output, followed by ``norm``."""
        return norm(residual + self.dropout(sublayer_out))

In [30]:
# Test the decoder block
def test_decoder_block(device):
    """Smoke test: one Decoder layer, attending over a random stand-in for the encoder
    output, preserves the (batch, seq, d_model) shape of its input."""
    shape = (32, 10, 512)  # (batch_size, seq_len, d_model)

    block = Decoder(8, shape[-1], device=device)
    target_emb = torch.randn(*shape, device=device)
    encoder_out = torch.randn(*shape, device=device)

    result = block(target_emb, encoder_out)
    assert result.size() == shape

    print("Decoder block test passed")
test_decoder_block(device)


AttributeError: 'Decoder' object has no attribute 'dropout'

In [18]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from ``Encoder`` and ``Decoder`` layers.

    Args:
        n: Number of encoder layers and (separately) of decoder layers.
        h: Number of attention heads in every layer.
        d_k: Per-head feature size. Attention derives it as ``d_model // h``, so this
            argument is only checked for consistency with that value.
        seq_len: Fallback vocabulary size for the embeddings and the output layer when
            ``src_vocab_size`` / ``tgt_vocab_size`` are not given.
        d_model: Embedding / hidden size.
        device: Device for parameters and for tensors created in ``forward``.
        sos_token: Id written into the first position of the right-shifted decoder
            input; 0 is used when this is ``None``.
        dropout: Dropout probability passed to every Encoder/Decoder layer.
        src_vocab_size: Number of source token ids (rows of ``embed1``).
        tgt_vocab_size: Number of target token ids (rows of ``embed2`` and width of the logits).

    Raises:
        ValueError: if ``d_k * h != d_model``.
    """

    def __init__(self, n, h, d_k, seq_len, d_model, device, sos_token=None, dropout=0.1,
                 src_vocab_size=None, tgt_vocab_size=None):
        super().__init__()
        if d_k is not None and d_k * h != d_model:
            raise ValueError(f"d_k * h must equal d_model, got d_k={d_k}, h={h}, d_model={d_model}")
        if src_vocab_size is None:
            src_vocab_size = seq_len
        if tgt_vocab_size is None:
            tgt_vocab_size = seq_len

        self.sos_token = sos_token
        # nn.Embedding's first argument is the number of distinct token ids.
        self.embed1 = nn.Embedding(src_vocab_size, d_model, device=device)  # source tokens
        self.embed2 = nn.Embedding(tgt_vocab_size, d_model, device=device)  # target tokens

        self.n = n
        self.d_model = d_model
        self.device = device
        # Encoder/Decoder take (num_heads, d_model, device, dropout); d_k is derived
        # inside MultiHeadAttention and must not be passed.
        self.encoder = nn.ModuleList([Encoder(h, d_model, device, dropout) for _ in range(n)])
        self.decoder = nn.ModuleList([Decoder(h, d_model, device, dropout) for _ in range(n)])
        # One raw logit per target-vocabulary entry (cross_entropy applies log-softmax).
        self.final = nn.Sequential(
            nn.Linear(d_model, tgt_vocab_size, device=device),
        )

    def shift_right(self, x):
        """Shift ``x`` one step right along dim 1 (the sequence axis).

        The first position is filled with ``sos_token`` (0 when it is ``None``) and the
        last position is dropped, so the result has the same shape and dtype as ``x``.
        Intended for (batch, seq) token ids; higher-rank inputs are shifted the same way.
        """
        fill = 0 if self.sos_token is None else self.sos_token
        start = torch.full_like(x[:, :1], fill)
        shifted = torch.cat([start, x[:, :-1]], dim=1)
        return shifted.to(self.device)

    def _sinusoidal_encoding(self, seq_len):
        """Return the (1, seq_len, d_model) sinusoidal positional encoding on ``self.device``."""
        position = torch.arange(seq_len, dtype=torch.float32, device=self.device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32, device=self.device)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term  # (seq_len, ceil(d_model / 2))

        pos_enc = torch.zeros((seq_len, self.d_model), dtype=torch.float32, device=self.device)
        pos_enc[:, 0::2] = torch.sin(angles)  # even feature indices
        # Odd feature indices; for odd d_model there is one fewer odd column than even.
        pos_enc[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        return pos_enc.unsqueeze(0)

    def positional_embedding(self, input_emb, output_emb):
        """Add sinusoidal positional encodings to the source and target embeddings.

        Each tensor gets an encoding for its own length (dim 1), so source and target
        sequences may differ in length. Returns ``(input_emb, output_emb)``.
        """
        input_emb = input_emb + self._sinusoidal_encoding(input_emb.size(1))
        output_emb = output_emb + self._sinusoidal_encoding(output_emb.size(1))
        return input_emb, output_emb

    def _causal_mask(self, size):
        """(size, size) lower-triangular mask: position i may attend to positions <= i (0 = blocked)."""
        return torch.tril(torch.ones(size, size, device=self.device))

    def infer(self):
        """Placeholder for autoregressive decoding (generation)."""
        #TODO: implement autoregressive decoding starting from sos_token
        raise NotImplementedError("Autoregressive inference is not implemented yet")

    def forward(self, inp, out, targets =None):
        """Teacher-forced forward pass, used for both training and evaluation.

        Args:
            inp: Source token ids, presumably (batch, src_len). TODO confirm against the data pipeline.
            out: Target token ids (batch, tgt_len). They are shifted right internally,
                and a causal mask ensures position t only sees target tokens before t.
            targets: Optional target ids with the same shape as ``out`` (presumably the
                same tensor) used to compute the loss.

        Returns:
            ``(logits, loss)``. ``logits`` is (batch, tgt_len, tgt_vocab_size) raw scores.
            ``loss`` is the mean cross-entropy over all positions, or ``None`` when
            ``targets`` is ``None``.

        Dropout is switched by ``train()`` / ``eval()``, so calling this in eval mode gives
        validation logits/loss. Autoregressive generation belongs in ``infer``.
        """
        inp = inp.to(self.device)
        out = out.to(self.device)

        input_emb = self.embed1(inp)
        # Shift token ids (not embeddings) so the decoder input starts with the sos id.
        output_emb = self.embed2(self.shift_right(out))
        input_emb, output_emb = self.positional_embedding(input_emb, output_emb)

        # Run the full encoder stack first; every decoder layer attends to its final output.
        memory = input_emb
        for layer in self.encoder:
            memory = layer(memory)

        # Without a causal mask the decoder could attend to the tokens it must predict.
        tgt_mask = self._causal_mask(output_emb.size(1))
        hidden = output_emb
        for layer in self.decoder:
            hidden = layer(hidden, memory, tgt_mask=tgt_mask)

        logits = self.final(hidden)  # raw logits; no softmax since cross_entropy expects logits
        loss = None
        if targets is not None:
            targets = targets.to(self.device)
            # cross_entropy expects (N, C) logits with (N,) class indices.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

## Data Processing and Training