## Transformer Architecture


In [1]:
## Positional Embedding
import torch
import torch.nn as nn

class PositionalEmbedding(nn.Module):
    """Learnable absolute position vectors added to input embeddings.

    Args:
        max_length: Number of positions that get their own learned vector.
        embed_dim: Width of each position vector; must match the last
            dimension of the input.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, max_length, embed_dim, dropout=0.1):
        super().__init__()
        # One trainable row per position, initialised with small noise
        # (standard normal scaled by 0.02).
        initial = 0.02 * torch.randn(max_length, embed_dim)
        self.pos_embed = nn.Parameter(initial)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X):
        """Return ``dropout(X + pos_embed[:seq_len])``.

        Dimension 1 of ``X`` is taken as the sequence length, so the input is
        presumably batch-first ``(batch, seq_len, embed_dim)`` — TODO confirm
        with callers. Only the first ``seq_len`` learned rows are used.
        """
        seq_len = X.size(1)
        positions = self.pos_embed[:seq_len]
        return self.dropout(X + positions)
    
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention (batch-first).

    ``split_heads`` treats dimension 0 of its input as batch and dimension 1
    as sequence length, so tensors are ``(batch, seq_len, embed_dim)``.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of parallel attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        # Fail fast: otherwise the mismatch only surfaces later as an opaque
        # ``view`` shape error inside ``split_heads``.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.h = num_heads  # Number of attention heads
        self.d = embed_dim // num_heads  # Dimension per head
        # Linear projections for queries, keys, values
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Output projection to combine heads
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, X):
        """Split the embedding dimension into heads.

        (B, L, embed_dim) -> (B, L, h, d) -> (B, h, L, d)
        """
        return X.view(X.size(0), X.size(1), self.h, self.d).transpose(1, 2)

    @staticmethod
    def _apply_mask(scores, mask):
        """Return ``scores`` with ``mask`` applied.

        Floating-point masks are additive biases (e.g. 0 for keep, -inf for
        block). Any other mask is passed to ``masked_fill``, where True marks a
        blocked position; non-boolean integer masks still raise there, exactly
        as before.
        """
        if mask.is_floating_point():
            return scores + mask
        return scores.masked_fill(mask, -torch.inf)

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: ``(B, Lq, embed_dim)`` tensor.
            key: ``(B, Lk, embed_dim)`` tensor.
            value: ``(B, Lk, embed_dim)`` tensor (same length as ``key``).
            attn_mask: Optional mask broadcastable to ``(B, h, Lq, Lk)``,
                e.g. ``(Lq, Lk)`` for a causal mask. Boolean (True = blocked)
                or floating point (added to the scores).
            key_padding_mask: Optional ``(B, Lk)`` mask. Boolean
                (True = padding) or floating point (added to the scores).

        Returns:
            Tuple ``(output, weights)``: ``output`` is ``(B, Lq, embed_dim)``
            and ``weights`` are the per-head attention probabilities
            ``(B, h, Lq, Lk)``, taken before dropout.
        """
        # Project and split into heads
        q = self.split_heads(self.q_proj(query))  # (B, h, Lq, d)
        k = self.split_heads(self.k_proj(key))    # (B, h, Lk, d)
        v = self.split_heads(self.v_proj(value))  # (B, h, Lk, d)

        # Scaled dot-product scores: Q K^T / sqrt(d)
        scores = q @ k.transpose(2, 3) / self.d**0.5  # (B, h, Lq, Lk)

        if attn_mask is not None:
            scores = self._apply_mask(scores, attn_mask)

        if key_padding_mask is not None:
            # (B, Lk) -> (B, 1, 1, Lk) so it broadcasts over heads and queries.
            pad = key_padding_mask.unsqueeze(1).unsqueeze(2)
            scores = self._apply_mask(scores, pad)

        # NOTE(review): a query row whose keys are ALL masked becomes all -inf
        # and softmax yields NaN for that row; callers must avoid such masks.
        weights = scores.softmax(dim=-1)  # (B, h, Lq, Lk)

        Z = self.dropout(weights) @ v  # (B, h, Lq, d)

        # Merge heads back: (B, h, Lq, d) -> (B, Lq, h, d) -> (B, Lq, embed_dim)
        Z = Z.transpose(1, 2)
        Z = Z.reshape(Z.size(0), Z.size(1), self.h * self.d)

        return self.out_proj(Z), weights


In [None]:

class TransformerEncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention then a feed-forward net.

    Each sub-layer output passes through dropout, is added back to its input
    (residual), and the sum is layer-normalised.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the two-layer feed-forward net.
        dropout: Dropout probability (shared by all dropout sites).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Sub-modules are created in this exact order so parameter
        # initialisation consumes the RNG identically.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Self-attention over ``x`` followed by residual dropout."""
        attended = self.self_attn(x, x, x,
                                  attn_mask=attn_mask,
                                  key_padding_mask=key_padding_mask)[0]
        return self.dropout(attended)

    def _ff_block(self, x):
        """linear1 -> ReLU -> dropout -> linear2 -> dropout."""
        hidden = self.dropout(torch.relu(self.linear1(x)))
        return self.dropout(self.linear2(hidden))

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode ``src``; masks are forwarded to the self-attention."""
        x = self.norm1(src + self._sa_block(src, src_mask,
                                            src_key_padding_mask))
        return self.norm2(x + self._ff_block(x))


class TransformerDecoderLayer(nn.Module):
    """One post-norm decoder layer.

    Runs three sub-layers in order: self-attention over the target,
    cross-attention from the target to the encoder ``memory``, and a two-layer
    feed-forward net. Each sub-layer output passes through dropout, is added
    to its input, and is layer-normalised.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward net.
        dropout: Dropout probability (shared by all dropout sites).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Creation order is kept fixed so parameter init uses the RNG the same way.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout)
        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Target self-attention followed by residual dropout."""
        attended = self.self_attn(x, x, x,
                                  attn_mask=attn_mask,
                                  key_padding_mask=key_padding_mask)[0]
        return self.dropout(attended)

    def _mha_block(self, x, mem, attn_mask, key_padding_mask):
        """Cross-attention (queries from ``x``, keys/values from ``mem``)."""
        attended = self.multihead_attn(x, mem, mem,
                                       attn_mask=attn_mask,
                                       key_padding_mask=key_padding_mask)[0]
        return self.dropout(attended)

    def _ff_block(self, x):
        """linear1 -> ReLU -> dropout -> linear2 -> dropout."""
        hidden = self.dropout(torch.relu(self.linear1(x)))
        return self.dropout(self.linear2(hidden))

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Decode ``tgt`` against ``memory``.

        ``tgt_mask``/``tgt_key_padding_mask`` go to the self-attention;
        ``memory_mask``/``memory_key_padding_mask`` go to the cross-attention.
        """
        x = self.norm1(tgt + self._sa_block(tgt, tgt_mask,
                                            tgt_key_padding_mask))
        x = self.norm2(x + self._mha_block(x, memory, memory_mask,
                                           memory_key_padding_mask))
        return self.norm3(x + self._ff_block(x))


In [None]:

from copy import deepcopy

class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` independent clones of one encoder layer.

    Args:
        encoder_layer: Template layer. It is deep-copied ``num_layers`` times,
            so the clones share no parameters; the template itself is not kept.
        num_layers: Number of clones to stack.
        norm: Optional module applied to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        clones = [deepcopy(encoder_layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(clones)
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None):
        """Feed ``src`` through every layer in order, then the optional norm.

        ``mask`` and ``src_key_padding_mask`` are passed positionally and
        unchanged to each layer.
        """
        hidden = src
        for block in self.layers:
            hidden = block(hidden, mask, src_key_padding_mask)
        return hidden if self.norm is None else self.norm(hidden)


class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` independent clones of one decoder layer.

    Args:
        decoder_layer: Template layer. It is deep-copied ``num_layers`` times,
            so the clones share no parameters; the template itself is not kept.
        num_layers: Number of clones to stack.
        norm: Optional module applied to the output of the last layer.
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        clones = [deepcopy(decoder_layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(clones)
        self.norm = norm

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Feed ``tgt`` through every layer (each sees the same ``memory``).

        All four masks are passed positionally and unchanged to every layer;
        the optional final norm is applied last.
        """
        hidden = tgt
        for block in self.layers:
            hidden = block(hidden, memory, tgt_mask, memory_mask,
                           tgt_key_padding_mask, memory_key_padding_mask)
        return hidden if self.norm is None else self.norm(hidden)


In [None]:
class Transformer(nn.Module):
    """Encoder-decoder transformer over already-embedded inputs.

    There are no token embeddings, positional encodings or output head here.
    The encoder stack and the decoder stack each end with their own final
    ``LayerNorm(d_model)``.

    Args:
        d_model: Model width.
        nhead: Number of attention heads per attention module.
        num_encoder_layers: Depth of the encoder stack.
        num_decoder_layers: Depth of the decoder stack.
        dim_feedforward: Hidden width of every feed-forward sub-layer.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Build order (encoder template, then decoder template) is unchanged,
        # so seeded parameter initialisation is reproducible.
        enc_template = TransformerEncoderLayer(d_model, nhead,
                                               dim_feedforward, dropout)
        self.encoder = TransformerEncoder(enc_template, num_encoder_layers,
                                          nn.LayerNorm(d_model))

        dec_template = TransformerDecoderLayer(d_model, nhead,
                                               dim_feedforward, dropout)
        self.decoder = TransformerDecoder(dec_template, num_decoder_layers,
                                          nn.LayerNorm(d_model))

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Encode ``src`` into memory, then decode ``tgt`` against it.

        Returns the decoder-stack output (same shape as ``tgt``).
        """
        memory = self.encoder(src, mask=src_mask,
                              src_key_padding_mask=src_key_padding_mask)
        return self.decoder(tgt, memory,
                            tgt_mask=tgt_mask,
                            memory_mask=memory_mask,
                            tgt_key_padding_mask=tgt_key_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)