### Transformer

<div align="center">
  <img src="https://machinelearningmastery.com/wp-content/uploads/2021/08/attention_research_1.png" alt="Transformer" width="300">
</div>

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy

In [34]:
class Embedding(nn.Module):
    """Token-id lookup table whose output is multiplied by sqrt(embed_dim).

    Args:
        vocab_size: number of distinct token ids the table can look up.
        embed_dim: length of each embedding vector.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.embed_dim = embed_dim
        # Overwrite PyTorch's default N(0, 1) initialisation with a much smaller spread.
        self.init_weights()

    def init_weights(self):
        """Re-draw every embedding entry from a normal distribution with mean 0 and std 0.02."""
        nn.init.normal_(self.embed.weight, mean=0, std=0.02)

    def forward(self, x):
        """Look up `x` (token ids, [batch_size, seq_len]) and scale the result.

        Returns:
            Tensor of shape [batch_size, seq_len, embed_dim].
        """
        scale = math.sqrt(self.embed_dim)
        vectors = self.embed(x)
        return vectors * scale

In [35]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings, then apply dropout.

    PE[pos, 2i]   = sin(pos / 10000^(2i / embed_dim))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / embed_dim))
    """

    def __init__(self, embed_dim, max_seq_len, dropout=0.1):
        """
        Args:
            embed_dim: Embedding dimension (odd values are supported).
            max_seq_len: Longest sequence length the encoding table covers.
            dropout: Dropout probability applied after adding the encoding
                (defaults to 0.1, the previously hard-coded value).
        """
        super().__init__()
        pe = torch.zeros(max_seq_len, embed_dim)  # (max_seq_len, embed_dim)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)  # (max_seq_len, 1)
        # 10000^(2i/embed_dim) for each even index 2i; length ceil(embed_dim / 2).
        div_term = torch.pow(10000, torch.arange(0, embed_dim, 2).float() / embed_dim)

        # Even columns get sine, odd columns get cosine.
        pe[:, 0::2] = torch.sin(position / div_term)
        # There are only floor(embed_dim / 2) odd columns, so trim div_term to fit;
        # without this slice an odd embed_dim raises a shape-mismatch error.
        pe[:, 1::2] = torch.cos(position / div_term[: embed_dim // 2])

        # Buffer: saved in state_dict and moved with .to(), but not trained.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_seq_len, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Embeddings of shape [batch_size, seq_len, embed_dim].

        Returns:
            Tensor of the same shape: x + PE[:seq_len], after dropout.

        Raises:
            ValueError: if seq_len exceeds the max_seq_len given at construction.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

In [36]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Q, K and V each get their own linear projection. The projections are split
    into `num_heads` heads of size `embed_dim // num_heads`, attention runs
    independently per head, and the heads are concatenated and passed through
    an output projection.
    """

    def __init__(self, embed_dim=512, num_heads=8, dropout=0.1):
        """
        Args:
            embed_dim: Model dimension; must be divisible by num_heads.
            num_heads: Number of attention heads.
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)

        # Final output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """Reshape [batch, length, embed_dim] -> [batch, num_heads, length, head_dim]."""
        batch_size, length, _ = x.shape
        return x.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None, return_attention=False):
        """
        Args:
            query: Query tensor [batch_size, query_len, embed_dim]
            key: Key tensor [batch_size, key_len, embed_dim]
            value: Value tensor [batch_size, key_len, embed_dim]
            mask: Optional mask; positions where mask == 0 are excluded. Accepted shapes:
                [batch_size, key_len] -> broadcast over heads and queries,
                [batch_size, 1 or query_len, key_len] -> broadcast over heads,
                any 4-D shape broadcastable to [batch_size, num_heads, query_len, key_len].
            return_attention: Whether to also return the attention weights.

        Returns:
            output: Attention output [batch_size, query_len, embed_dim]
            attention_weights: (only if return_attention) post-dropout weights
                [batch_size, num_heads, query_len, key_len]
        """
        batch_size, query_len, _ = query.shape

        q = self._split_heads(self.q_linear(query))  # [B, H, query_len, head_dim]
        k = self._split_heads(self.k_linear(key))    # [B, H, key_len, head_dim]
        v = self._split_heads(self.v_linear(value))  # [B, H, key_len, head_dim]

        # Scaled dot-product scores: [B, H, query_len, key_len]
        attn_scores = torch.matmul(q, k.transpose(-2, -1))
        attn_scores = attn_scores / math.sqrt(self.head_dim)

        if mask is not None:
            if mask.dim() == 2:
                # [B, key_len] -> [B, 1, 1, key_len]
                mask = mask.unsqueeze(1).unsqueeze(2)
            elif mask.dim() == 3:
                # [B, 1 or query_len, key_len] -> [B, 1, 1 or query_len, key_len]
                mask = mask.unsqueeze(1)
            # Use the most negative finite value rather than -inf. With -inf, a
            # row whose keys are all masked (e.g. an all-padding sequence) makes
            # softmax return NaN, and the NaN spreads through the whole batch.
            # With a finite value, masked entries still get exactly zero weight
            # whenever any key in the row is unmasked.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        attention_weights = F.softmax(attn_scores, dim=-1)  # [B, H, query_len, key_len]
        attention_weights = self.dropout(attention_weights)

        attn_output = torch.matmul(attention_weights, v)  # [B, H, query_len, head_dim]

        # Merge heads back: num_heads x head_dim -> embed_dim
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.embed_dim)

        output = self.out_proj(attn_output)  # [B, query_len, embed_dim]

        if return_attention:
            return output, attention_weights
        return output

In [37]:
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        embed_dim: size of the input and output features.
        ff_hidden_dim: size of the hidden (expanded) layer.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, embed_dim, ff_hidden_dim, dropout=0.1):
        super().__init__()
        # Creation order matters for reproducible seeded initialisation: expand first, then project back.
        self.linear1 = nn.Linear(embed_dim, ff_hidden_dim)
        self.linear2 = nn.Linear(ff_hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP independently at every position.

        Args:
            x: tensor of shape [batch_size, seq_len, embed_dim].

        Returns:
            Tensor of shape [batch_size, seq_len, embed_dim].
        """
        hidden = self.linear1(x)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

In [38]:
class EncoderLayer(nn.Module):
    """Pre-LayerNorm transformer encoder layer.

    Two residual sub-layers run in sequence, self-attention and then the
    feed-forward network. Each computes `x + dropout(sublayer(norm(x)))`, so
    the residual path itself is never normalised inside the layer.
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        """
        Args:
            embed_dim: Dimension of embeddings
            num_heads: Number of attention heads
            ff_hidden_dim: Hidden dimension of feed-forward network
            dropout: Dropout probability
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ff = FeedForward(embed_dim, ff_hidden_dim, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def _residual(self, x, norm, sublayer):
        """Return x + dropout(sublayer(norm(x)))."""
        return x + self.dropout(sublayer(norm(x)))

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor [batch_size, seq_len, embed_dim]
            mask: Optional attention mask forwarded to self-attention

        Returns:
            Tensor of shape [batch_size, seq_len, embed_dim]
        """
        x = self._residual(x, self.norm1, lambda h: self.self_attn(h, h, h, mask))
        return self._residual(x, self.norm2, self.ff)

In [39]:
class DecoderLayer(nn.Module):
    """Pre-LayerNorm transformer decoder layer.

    Three residual sub-layers run in sequence: masked self-attention,
    cross-attention over the encoder output, and the feed-forward network.
    Each computes `x + dropout(sublayer(norm(x)))`.
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        """
        Args:
            embed_dim: Dimension of embeddings
            num_heads: Number of attention heads
            ff_hidden_dim: Hidden dimension of feed-forward network
            dropout: Dropout probability
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)

        self.cross_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.ff = FeedForward(embed_dim, ff_hidden_dim, dropout)
        self.norm3 = nn.LayerNorm(embed_dim)

        self.dropout = nn.Dropout(dropout)

    def _residual(self, x, norm, sublayer):
        """Return x + dropout(sublayer(norm(x)))."""
        return x + self.dropout(sublayer(norm(x)))

    def forward(self, x, enc_output, self_attn_mask=None, cross_attn_mask=None):
        """
        Args:
            x: Decoder input [batch_size, seq_len, embed_dim]
            enc_output: Encoder output [batch_size, enc_seq_len, embed_dim], used as keys/values
            self_attn_mask: Mask for self-attention (usually causal + padding)
            cross_attn_mask: Mask for cross-attention over enc_output

        Returns:
            Tensor of shape [batch_size, seq_len, embed_dim]
        """
        x = self._residual(x, self.norm1, lambda h: self.self_attn(h, h, h, self_attn_mask))
        x = self._residual(
            x, self.norm2, lambda h: self.cross_attn(h, enc_output, enc_output, cross_attn_mask)
        )
        return self._residual(x, self.norm3, self.ff)

In [40]:
class TransformerEncoder(nn.Module):
    """Encoder stack: token embedding, positional encoding, `num_layers` EncoderLayers, final LayerNorm."""

    def __init__(self, max_seq_len, vocab_size, embed_dim, num_heads, ff_hidden_dim, num_layers, dropout=0.1):
        """
        Args:
            max_seq_len: Maximum sequence length
            vocab_size: Size of vocabulary
            embed_dim: Dimension of embeddings
            num_heads: Number of attention heads
            ff_hidden_dim: Hidden dimension of feed-forward network
            num_layers: Number of encoder layers
            dropout: Dropout probability, used by every layer and by the
                positional-encoding dropout
        """
        super().__init__()
        self.embedding = Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_seq_len)
        # Make the positional-encoding dropout use this module's `dropout` rate
        # instead of a fixed one, so e.g. dropout=0.0 really disables all dropout.
        self.pos_encoding.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            EncoderLayer(embed_dim, num_heads, ff_hidden_dim, dropout)
            for _ in range(num_layers)
        ])

        # The layers are pre-LN (their residual output is never normalised),
        # so normalise once at the end of the stack.
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input token indices [batch_size, seq_len]
            mask: Optional padding mask, forwarded to every layer

        Returns:
            Encoder output [batch_size, seq_len, embed_dim]
        """
        x = self.pos_encoding(self.embedding(x))
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [41]:
class TransformerDecoder(nn.Module):
    """Decoder stack: embedding, positional encoding, `num_layers` DecoderLayers, final LayerNorm, vocabulary projection."""

    def __init__(self, max_seq_len, vocab_size, embed_dim, num_heads, ff_hidden_dim, num_layers, dropout=0.1):
        """
        Args:
            max_seq_len: Maximum sequence length
            vocab_size: Size of vocabulary (also the size of the output logits)
            embed_dim: Dimension of embeddings
            num_heads: Number of attention heads
            ff_hidden_dim: Hidden dimension of feed-forward network
            num_layers: Number of decoder layers
            dropout: Dropout probability, used by every layer and by the
                positional-encoding dropout
        """
        super().__init__()
        self.embedding = Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_seq_len)
        # Make the positional-encoding dropout use this module's `dropout` rate
        # instead of a fixed one, so e.g. dropout=0.0 really disables all dropout.
        self.pos_encoding.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            DecoderLayer(embed_dim, num_heads, ff_hidden_dim, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.output_proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, x, enc_output, self_attn_mask=None, cross_attn_mask=None):
        """
        Args:
            x: Input token indices [batch_size, seq_len]
            enc_output: Encoder output [batch_size, enc_seq_len, embed_dim]
            self_attn_mask: Mask for self-attention (usually causal + padding)
            cross_attn_mask: Mask for cross-attention

        Returns:
            Output logits [batch_size, seq_len, vocab_size]
        """
        x = self.pos_encoding(self.embedding(x))
        for layer in self.layers:
            x = layer(x, enc_output, self_attn_mask, cross_attn_mask)
        # Final normalisation (the layers are pre-LN), then project to vocabulary logits.
        return self.output_proj(self.norm(x))

In [48]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping source token ids to target-vocabulary logits.

    Token id 0 is treated as padding when masks are built.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, max_seq_len=512,
                 embed_dim=512, num_heads=8, ff_hidden_dim=2048,
                 num_encoder_layers=6, num_decoder_layers=6, dropout=0.1):
        """
        Args:
            src_vocab_size: Size of source vocabulary
            tgt_vocab_size: Size of target vocabulary
            max_seq_len: Maximum sequence length
            embed_dim: Dimension of embeddings
            num_heads: Number of attention heads
            ff_hidden_dim: Hidden dimension of feed-forward network
            num_encoder_layers: Number of encoder layers
            num_decoder_layers: Number of decoder layers
            dropout: Dropout probability
        """
        super().__init__()

        self.encoder = TransformerEncoder(
            max_seq_len=max_seq_len,
            vocab_size=src_vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            ff_hidden_dim=ff_hidden_dim,
            num_layers=num_encoder_layers,
            dropout=dropout
        )

        self.decoder = TransformerDecoder(
            max_seq_len=max_seq_len,
            vocab_size=tgt_vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            ff_hidden_dim=ff_hidden_dim,
            num_layers=num_decoder_layers,
            dropout=dropout
        )

    def create_masks(self, src, tgt):
        """
        Build boolean attention masks (True = may attend) for training.

        Args:
            src: Source token ids [batch_size, src_len]
            tgt: Target token ids [batch_size, tgt_len]

        Returns:
            src_mask: Source padding mask [batch_size, 1, 1, src_len]
            tgt_mask: Target padding AND causal mask [batch_size, 1, tgt_len, tgt_len]
            src_tgt_mask: Cross-attention mask hiding source padding [batch_size, 1, 1, src_len]
        """
        # Padding masks: True where the token is real (id != 0).
        src_pad_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, src_len)
        tgt_pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, tgt_len)

        # Causal mask must be bool: `bool & float` is rejected by bitwise_and.
        tgt_len = tgt.size(1)
        tgt_causal_mask = torch.tril(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device)
        )
        tgt_causal_mask = tgt_causal_mask.unsqueeze(0).unsqueeze(1)  # (1, 1, tgt_len, tgt_len)

        # Decoder self-attention: a key is visible only if it is real AND not in the future.
        tgt_mask = tgt_pad_mask & tgt_causal_mask  # (batch_size, 1, tgt_len, tgt_len)

        # Cross-attention scores are [batch, heads, tgt_len, src_len], so padding
        # must be masked along the last (source/key) axis. The source padding mask
        # already has that layout and broadcasts over every target query.
        src_tgt_mask = src_pad_mask  # (batch_size, 1, 1, src_len)

        return src_pad_mask, tgt_mask, src_tgt_mask

    def forward(self, src, tgt):
        """
        Args:
            src: Source sequence [batch_size, src_len]
            tgt: Target sequence [batch_size, tgt_len]

        Returns:
            Output logits [batch_size, tgt_len, tgt_vocab_size]
        """
        src_mask, tgt_mask, src_tgt_mask = self.create_masks(src, tgt)
        enc_output = self.encoder(src, src_mask)
        return self.decoder(tgt, enc_output, tgt_mask, src_tgt_mask)

### Types of Masks:
1. **Source Padding Mask `(src_pad_mask)`:** This mask is used to mask out the padding tokens in the source sequence (those with value `0`), ensuring that padding tokens do not influence the attention calculations.
2. **Target Padding Mask `(tgt_pad_mask)`:** Similar to the source padding mask, this mask ensures that padding tokens in the target sequence (with value `0`) are not attended to during the self-attention operation in the decoder.
3. **Target Causal Mask `(tgt_causal_mask)`:** A lower-triangular mask used in the decoder's self-attention layer so that the prediction at position `t` can only attend to positions `≤ t` (i.e., the model cannot look ahead). It is combined with the target padding mask using a logical AND to form `tgt_mask`.
4. **Source-Target Mask `(src_tgt_mask)`:** Used in the decoder's cross-attention layer so that no target position attends to padding tokens in the source sequence (the encoder output).

In [49]:
# Example usage
def transformer_example():
    """Build a default-sized Transformer and run one forward pass on random tokens.

    Token ids are drawn from [1, vocab_size), so no position is treated as
    padding (id 0). Prints the source, target and output shapes and returns
    the output logits of shape [batch_size, tgt_len, tgt_vocab_size].
    """
    hparams = {
        "src_vocab_size": 10000,
        "tgt_vocab_size": 10000,
        "max_seq_len": 512,
        "embed_dim": 512,
        "num_heads": 8,
        "ff_hidden_dim": 2048,
        "num_encoder_layers": 6,
        "num_decoder_layers": 6,
    }
    model = Transformer(**hparams)

    # Random input batch
    batch_size, src_len, tgt_len = 16, 32, 24
    src = torch.randint(1, hparams["src_vocab_size"], (batch_size, src_len))
    tgt = torch.randint(1, hparams["tgt_vocab_size"], (batch_size, tgt_len))

    output = model(src, tgt)

    print(f"Source shape: {src.shape}")
    print(f"Target shape: {tgt.shape}")
    print(f"Output shape: {output.shape}")

    return output

In [None]:
# Run the demo: builds the model, runs one forward pass on random token ids,
# and prints the source/target/output shapes.
transformer_example()