In [1]:
import torch
import torch.nn as nn
import math
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import time

In [2]:
class PositionalEncoding(nn.Module):
    """
    Token embedding followed by a fixed sinusoidal positional encoding.

    The encoding table is computed once for ``seq_len`` positions and stored as a
    non-trainable buffer. ``forward`` embeds the token ids and adds the slice of
    the table that matches the input length.
    """
    def __init__(self, seq_len, d_model, vocab_size):
        """
        Args:
            seq_len (int): Maximum length of input sequences.
            d_model (int): Dimensionality of the embeddings (even or odd).
            vocab_size (int): Size of the vocabulary for embedding lookup.
        """
        super().__init__()

        # Table of shape (seq_len, d_model) holding the encodings
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len).unsqueeze(1)  # Shape: (seq_len, 1)

        # One frequency per even column index: 10000^(-2i / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # Shape: (seq_len, ceil(d_model / 2))

        # Sine on even columns, cosine on odd columns.
        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even columns,
        # so only the first d_model // 2 frequencies are used for the cosine part.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Register as buffer so it is saved with the model but not a parameter
        self.register_buffer("pos_enc", pe.unsqueeze(0))  # Shape: (1, seq_len, d_model)

        # Token embedding layer
        self.embedder = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """
        Args:
            x (Tensor): Input indices of shape (batch_size, seq_len)
        Returns:
            Tensor: Embedded input + positional encoding, shape (batch_size, seq_len, d_model)
        """
        x = self.embedder(x)  # Shape: (batch_size, seq_len, d_model)
        # Add positional encoding (broadcasts over batch dimension)
        return x + self.pos_enc[:, :x.size(1), :]

In [3]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``num_heads``
    heads, attended independently, and the per-head results are concatenated
    back to ``d_model`` features. No output projection is applied.
    """
    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model (int): Dimensionality of the model (embedding size).
            num_heads (int): Number of attention heads.

        Raises:
            ValueError: If d_model is not divisible by num_heads.
        """
        super().__init__()
        # Validate before building any layers so a bad configuration fails fast
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.q_linear = nn.Linear(d_model, d_model)  # Linear layer for queries
        self.k_linear = nn.Linear(d_model, d_model)  # Linear layer for keys
        self.v_linear = nn.Linear(d_model, d_model)  # Linear layer for values
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dims = d_model // num_heads

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query (Tensor): Shape (batch_size, q_len, d_model)
            key (Tensor): Shape (batch_size, k_len, d_model)
            value (Tensor): Shape (batch_size, k_len, d_model)
            mask (Tensor, optional): Broadcastable to (batch_size, num_heads, q_len, k_len).
                Positions where the mask equals 0 are excluded from attention.

        Returns:
            Tensor: Output after multi-head attention, shape (batch_size, q_len, d_model)
        """
        # Linear projections
        q = self.q_linear(query)
        k = self.k_linear(key)
        v = self.v_linear(value)

        batch_size, q_len, _ = query.shape
        k_len = key.shape[1]
        v_len = value.shape[1]

        # Split the feature dimension into heads: (batch_size, num_heads, len, head_dims)
        q = q.reshape(batch_size, q_len, self.num_heads, self.head_dims).permute(0, 2, 1, 3)
        k = k.reshape(batch_size, k_len, self.num_heads, self.head_dims).permute(0, 2, 1, 3)
        v = v.reshape(batch_size, v_len, self.num_heads, self.head_dims).permute(0, 2, 1, 3)

        # Scaled dot-product scores: (batch, heads, q_len, k_len)
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dims)

        if mask is not None:
            # Use the most negative finite value instead of -inf: masked keys still
            # receive zero weight, but a query row whose keys are ALL masked yields
            # a uniform distribution rather than NaN from softmax.
            attn = attn.masked_fill(mask == 0, torch.finfo(attn.dtype).min)

        attn = torch.softmax(attn, dim=-1)
        attn_out = torch.matmul(attn, v)  # (batch, heads, q_len, head_dims)

        # Concatenate heads back to (batch_size, q_len, d_model)
        attn_out = attn_out.permute(0, 2, 1, 3).reshape(batch_size, q_len, self.d_model)
        return attn_out

In [4]:
class FFN(nn.Module):
    """
    Position-wise feed-forward sub-layer of a Transformer block.
    Computes layer2(dropout(relu(layer1(x)))) with a hidden width of 4 * d_model.
    """
    def __init__(self, d_model):
        """
        Args:
            d_model (int): Dimensionality of the model (embedding size).
        """
        super().__init__()
        hidden_dim = d_model * 4  # expansion factor of 4 (e.g., 512 -> 2048)
        self.layer1 = nn.Linear(d_model, hidden_dim)   # expand
        self.layer2 = nn.Linear(hidden_dim, d_model)   # project back
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(p=0.1)

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Tensor: Output tensor of shape (batch_size, seq_len, d_model)
        """
        hidden = self.relu(self.layer1(x))
        return self.layer2(self.drop(hidden))

In [5]:
class Encoder(nn.Module):
    """
    One Transformer encoder layer.
    Self-attention and a feed-forward network, each wrapped as
    LayerNorm(input + sublayer(input)).
    """
    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model (int): Dimensionality of the model (embedding size).
            num_heads (int): Number of attention heads.
        """
        super().__init__()
        self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)  # self-attention sub-layer
        self.ffn = FFN(d_model)                                              # feed-forward sub-layer
        self.norm1 = nn.LayerNorm(d_model)                                   # applied after the attention residual
        self.norm2 = nn.LayerNorm(d_model)                                   # applied after the FFN residual

    def forward(self, x, mask=None):
        """
        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, d_model)
            mask (Tensor, optional): Attention mask forwarded unchanged to the attention sub-layer.
        Returns:
            Tensor: Output tensor of shape (batch_size, seq_len, d_model)
        """
        # Self-attention: x serves as query, key and value
        attended = self.norm1(x + self.mha(x, x, x, mask))
        # Feed-forward on the normalized attention result, then residual + norm
        return self.norm2(attended + self.ffn(attended))

In [6]:
class EncoderBlock(nn.Module):
    """
    Complete encoder: token embedding with positional encoding, a stack of
    Encoder layers, and a final LayerNorm.
    """
    def __init__(self, d_model, num_heads, num_layers, src_vocab_size, seq_len):
        """
        Args:
            d_model (int): Dimensionality of the model (embedding size).
            num_heads (int): Number of attention heads.
            num_layers (int): Number of Encoder layers to stack.
            src_vocab_size (int): Vocabulary size for source language.
            seq_len (int): Maximum input sequence length.
        """
        super().__init__()
        # Embedding + positional encoding for the source tokens
        self.pos_enc = PositionalEncoding(seq_len=seq_len, d_model=d_model, vocab_size=src_vocab_size)

        # num_layers independent Encoder layers
        self.layers = nn.ModuleList(
            [Encoder(d_model=d_model, num_heads=num_heads) for _ in range(num_layers)]
        )

        # Normalization applied to the output of the last layer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x (Tensor): Input token indices of shape (batch_size, seq_len)
            mask (Tensor, optional): Mask tensor passed to every encoder layer.
        Returns:
            Tensor: Output tensor of shape (batch_size, seq_len, d_model)
        """
        hidden = self.pos_enc(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

In [7]:
class MaskedMultiheadAttention(nn.Module):
    """
    Multi-head self-attention with causal masking (for decoder blocks in Transformers).

    Each position can only attend to itself and earlier positions. An optional
    extra mask (e.g., for padding) is combined with the causal mask rather than
    replacing it.
    """
    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model (int): Dimensionality of the model (embedding size).
            num_heads (int): Number of attention heads.

        Raises:
            ValueError: If d_model is not divisible by num_heads.
        """
        super().__init__()
        # Validate before building any layers so a bad configuration fails fast
        if d_model % num_heads != 0:
            raise ValueError("d_model not divisible by num_heads")

        self.q_linear = nn.Linear(d_model, d_model)  # Linear layer for queries
        self.k_linear = nn.Linear(d_model, d_model)  # Linear layer for keys
        self.v_linear = nn.Linear(d_model, d_model)  # Linear layer for values
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.d_model = d_model

    def forward(self, x, mask=None):
        """
        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, d_model)
            mask (Tensor, optional): Extra mask broadcastable to
                (batch_size, num_heads, seq_len, seq_len). Positions where the mask
                equals 0 are blocked (same convention as MultiHeadAttention). It is
                applied in addition to the causal mask.
        Returns:
            Tensor: Output tensor after masked multi-head attention, shape (batch_size, seq_len, d_model)
        """
        # Linear projections for Q, K, V
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)

        batch_size, seq_len, _ = x.shape

        # Split the feature dimension into heads: (batch, heads, seq_len, head_dim)
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled dot-product scores: (batch, heads, seq_len, seq_len)
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal (look-ahead) mask: True strictly above the diagonal = blocked.
        # It is always applied; previously a caller-supplied mask disabled all
        # masking because the fill only ran when no mask was given.
        blocked = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        if mask is not None:
            blocked = blocked | (mask == 0)

        # Most negative finite value instead of -inf: blocked keys still get zero
        # weight, but a row that ends up fully blocked (e.g., a padded query) gives
        # a uniform distribution instead of NaN.
        attn = attn.masked_fill(blocked, torch.finfo(attn.dtype).min)

        attn = torch.softmax(attn, dim=-1)
        attn = torch.matmul(attn, v)

        # Concatenate heads back to (batch_size, seq_len, d_model)
        out = attn.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)
        return out

In [8]:
class Decoder(nn.Module):
    """
    One Transformer decoder layer.
    Three sub-layers — masked self-attention, encoder-decoder (cross) attention and a
    feed-forward network — each applied as x + sublayer(LayerNorm(x)), followed by
    a dropout on the layer output.
    """
    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model (int): Dimensionality of the model (embedding size).
            num_heads (int): Number of attention heads.
        """
        super().__init__()
        self.masked_attn = MaskedMultiheadAttention(d_model=d_model, num_heads=num_heads)  # causal self-attention
        self.multihead_cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)  # attention over encoder output
        self.ffn = FFN(d_model=d_model)     # position-wise feed-forward
        self.norm1 = nn.LayerNorm(d_model)  # normalizes the input of the self-attention
        self.norm2 = nn.LayerNorm(d_model)  # normalizes the query of the cross attention
        self.norm3 = nn.LayerNorm(d_model)  # normalizes the input of the FFN
        self.drop = nn.Dropout(p=0.2)       # dropout on the layer output

    def forward(self, x, encoder_out, mask=None):
        """
        Args:
            x (Tensor): Decoder input (batch_size, seq_len, d_model)
            encoder_out (Tensor): Encoder output (batch_size, seq_len, d_model)
            mask (Tensor, optional): Mask tensor forwarded to the masked self-attention
        Returns:
            Tensor: Output tensor (batch_size, seq_len, d_model)
        """
        # Masked self-attention on the normalized input, added back as a residual
        x = x + self.masked_attn(self.norm1(x), mask)

        # Cross attention: normalized decoder state as query, encoder output as key/value
        cross_query = self.norm2(x)
        x = x + self.multihead_cross_attn(query=cross_query, key=encoder_out, value=encoder_out)

        # Feed-forward on the normalized state, added back as a residual
        x = x + self.ffn(self.norm3(x))

        return self.drop(x)

In [10]:
class DecoderBlock(nn.Module):
    """
    Complete decoder: target token embedding with positional encoding followed by
    a stack of Decoder layers.
    """
    def __init__(self, num_heads, d_model, decoder_layer, seq_len, trg_vocab):
        """
        Args:
            num_heads (int): Number of attention heads.
            d_model (int): Dimensionality of the model (embedding size).
            decoder_layer (int): Number of Decoder layers to stack.
            seq_len (int): Maximum target sequence length.
            trg_vocab (int): Vocabulary size for the target embedding.
        """
        super().__init__()
        # Target embedding + positional encoding
        self.pos = PositionalEncoding(seq_len=seq_len, d_model=d_model, vocab_size=trg_vocab)
        # Stack of decoder layers
        self.layer = nn.ModuleList(
            [Decoder(d_model=d_model, num_heads=num_heads) for _ in range(decoder_layer)]
        )

    def forward(self, x, encoder_out, mask=None):
        """
        Args:
            x (Tensor): Target token indices of shape (batch_size, seq_len)
            encoder_out (Tensor): Encoder output passed to every decoder layer.
            mask (Tensor, optional): Mask tensor passed to every decoder layer.
        Returns:
            Tensor: Output of the last decoder layer, shape (batch_size, seq_len, d_model)
        """
        hidden = self.pos(x)
        for decoder in self.layer:
            hidden = decoder(hidden, encoder_out, mask)
        return hidden

In [11]:
class Transformer(nn.Module):
    """
    Full Transformer model with encoder, decoder, and final output layer.
    """
    def __init__(self, d_model, num_heads, num_encoder_layer, src_vocab_size, trg_vocab_size,
                 seq_len, decoder_layer):
        """
        Args:
            d_model (int): Dimensionality of the model (embedding size).
            num_heads (int): Number of attention heads.
            num_encoder_layer (int): Number of encoder layers.
            src_vocab_size (int): Vocabulary size of the source embedding.
            trg_vocab_size (int): Vocabulary size of the target embedding and output layer.
            seq_len (int): Maximum sequence length for both positional encodings.
            decoder_layer (int): Number of decoder layers.
        """
        super().__init__()
        self.enc = EncoderBlock(
            d_model=d_model,
            num_heads=num_heads,
            num_layers=num_encoder_layer,
            src_vocab_size=src_vocab_size,
            seq_len=seq_len
        )
        self.dec = DecoderBlock(
            num_heads=num_heads,
            d_model=d_model,
            decoder_layer=decoder_layer,
            seq_len=seq_len,
            trg_vocab=trg_vocab_size
        )
        self.fc_out = nn.Linear(d_model, trg_vocab_size)  # Output logits for target vocab
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, source, target, mask=None):
        """
        Args:
            source (Tensor): Source token indices, shape (batch, src_seq_len)
            target (Tensor): Target token indices, shape (batch, trg_seq_len)
            mask (Tensor, optional): Mask forwarded to the decoder. The encoder is
                called without a mask.
        Returns:
            Tensor: Logits of shape (batch, trg_seq_len, trg_vocab_size)
        """
        enc_out = self.enc(source)
        dec_out = self.dec(target, enc_out, mask)
        # Regularize the decoder features, not the logits: dropout on the logits
        # would zero random vocabulary entries and rescale the rest during
        # training, distorting the distribution the loss is computed on.
        # In eval mode dropout is the identity, so inference output is unchanged.
        logits = self.fc_out(self.dropout(dec_out))
        return logits  # (batch, trg_seq_len, trg_vocab_size)