In [None]:
!pip install torch numpy matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.nn import TransformerEncoder, TransformerDecoder



In [None]:

# ============================================================================
# 1. POSITIONAL ENCODING
# ============================================================================

class PositionalEncoding(nn.Module):
    """
    Implements the sinusoidal positional encoding from "Attention Is All You Need".
    
    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    
    The table is precomputed for positions [0, max_len), added to the input,
    and dropout is applied to the sum. Works for both even and odd d_model.
    """
    
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        """
        Args:
            d_model: Embedding dimension
            max_len: Maximum sequence length (inputs longer than this cannot be encoded)
            dropout: Dropout probability applied after adding the encoding
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        
        # div_term[i] = 10000^(-2i/d_model), computed in log space for numerical
        # stability; multiplying by it equals dividing by 10000^(2i/d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                            (-math.log(10000.0) / d_model))
        
        # Apply sin to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        
        # Apply cos to odd indices. For odd d_model there is one fewer odd
        # column than even column, so trim div_term to match the slice width.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        
        # Add batch dimension: [max_len, d_model] -> [1, max_len, d_model]
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter, but part of state; moves with .to())
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Args:
            x: Input tensor of shape [batch_size, seq_len, d_model], seq_len <= max_len
        
        Returns:
            Tensor of the same shape with positional encoding added (then dropout)
        """
        # x.shape = [batch_size, seq_len, d_model]
        # pe[:, :x.size(1)] = [1, seq_len, d_model], broadcast over the batch
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


# ============================================================================
# 2. SCALED DOT-PRODUCT ATTENTION
# ============================================================================

def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    """
    Compute scaled dot-product attention.
    
    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    
    Args:
        query: Query tensor [batch, heads, seq_len_q, d_k]
        key: Key tensor [batch, heads, seq_len_k, d_k]
        value: Value tensor [batch, heads, seq_len_k, d_v]
        mask: Optional tensor broadcastable to [batch, heads, seq_len_q, seq_len_k];
            positions where mask == 0 (or False) are excluded from attention.
        dropout: Optional dropout module applied to the attention weights
    
    Returns:
        output: Attention output [batch, heads, seq_len_q, d_v]
        attention_weights: Attention weights [batch, heads, seq_len_q, seq_len_k]
    """
    d_k = query.size(-1)
    
    # QK^T scaled by sqrt(d_k):
    # [batch, heads, seq_q, d_k] x [batch, heads, d_k, seq_k] -> [batch, heads, seq_q, seq_k]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    if mask is not None:
        # Fill with the most negative finite value of the score dtype rather than
        # a hard-coded -1e9: -1e9 does not fit in float16 and masked_fill raises.
        # Blocked positions get zero weight after softmax; a fully-masked row
        # becomes uniform (all scores equal) rather than NaN.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    
    attention_weights = F.softmax(scores, dim=-1)
    
    if dropout is not None:
        attention_weights = dropout(attention_weights)
    
    # [batch, heads, seq_q, seq_k] x [batch, heads, seq_k, d_v] -> [batch, heads, seq_q, d_v]
    output = torch.matmul(attention_weights, value)
    
    return output, attention_weights


# ============================================================================
# 3. MULTI-HEAD ATTENTION
# ============================================================================

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism.
    
    MultiHead(Q,K,V) = Concat(head_1, ..., head_h)W^O
    where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
    """
    
    def __init__(self, d_model, num_heads, dropout=0.1):
        """
        Args:
            d_model: Model dimension (must be divisible by num_heads)
            num_heads: Number of attention heads
            dropout: Dropout probability applied to the attention weights
        """
        super(MultiHeadAttention, self).__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head
        
        # Linear projections for Q, K, V (one for each)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def split_heads(self, x):
        """
        Split the last dimension into (num_heads, d_k).
        
        Args:
            x: [batch_size, seq_len, d_model]
        
        Returns:
            [batch_size, num_heads, seq_len, d_k]
        """
        batch_size, seq_len, _ = x.size()
        
        # [batch, seq_len, d_model] -> [batch, seq_len, num_heads, d_k] -> [batch, num_heads, seq_len, d_k]
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        """
        Inverse of split_heads.
        
        Args:
            x: [batch_size, num_heads, seq_len, d_k]
        
        Returns:
            [batch_size, seq_len, d_model]
        """
        batch_size, _, seq_len, _ = x.size()
        
        # transpose makes the tensor non-contiguous, so .contiguous() is needed before .view
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch_size, seq_len_q, d_model]
            key: [batch_size, seq_len_k, d_model]
            value: [batch_size, seq_len_k, d_model]
            mask: Optional mask where 0/False blocks attention. Either broadcastable
                to [batch, num_heads, seq_len_q, seq_len_k] (e.g. [batch, 1, 1, seq_len_k]
                or [seq_len_q, seq_len_k]), or 3-D [batch, seq_len_q, seq_len_k], which
                is expanded over the heads axis.
        
        Returns:
            output: [batch_size, seq_len_q, d_model]
            attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k]
        """
        # Project, then split into heads: [batch, num_heads, seq_len, d_k]
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))
        
        if mask is not None and mask.dim() == 3:
            # A per-example [batch, q, k] mask would otherwise broadcast with its
            # batch axis aligned to the heads axis; insert the heads axis explicitly.
            mask = mask.unsqueeze(1)
        
        attn_output, attention_weights = scaled_dot_product_attention(
            Q, K, V, mask, self.dropout
        )
        # attn_output: [batch, num_heads, seq_len_q, d_k]
        
        # Combine heads -> [batch, seq_len_q, d_model], then final linear projection
        output = self.W_o(self.combine_heads(attn_output))
        
        return output, attention_weights


# ============================================================================
# 4. POSITION-WISE FEED-FORWARD NETWORK
# ============================================================================

class PositionWiseFeedForward(nn.Module):
    """
    Two-layer MLP applied independently at every sequence position.

    FFN(x) = max(0, xW_1 + b_1)W_2 + b_2

    Dropout is applied to the hidden (post-ReLU) activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Args:
            d_model: Input/output width
            d_ff: Hidden width (4 * d_model in the original paper)
            dropout: Dropout probability on the hidden activations
        """
        super().__init__()

        # Keep creation order: it fixes parameter order, state_dict keys and init RNG use.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]

        Returns:
            [batch_size, seq_len, d_model]
        """
        hidden = F.relu(self.linear1(x))   # [batch, seq, d_ff]
        hidden = self.dropout(hidden)
        return self.linear2(hidden)        # [batch, seq, d_model]


# ============================================================================
# 5. ENCODER LAYER
# ============================================================================

class EncoderLayer(nn.Module):
    """
    One post-norm Transformer encoder block.

    Two sub-layers, each wrapped as LayerNorm(x + Dropout(sublayer(x))):
      * multi-head self-attention
      * position-wise feed-forward network
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            d_ff: Feed-forward hidden dimension
            dropout: Dropout probability (attention weights, FFN hidden, residual branches)
        """
        super().__init__()

        # Submodules are created in the same order as before so parameter order,
        # state_dict keys and RNG-driven initialisation are unchanged.
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            mask: Optional attention mask (0/False = blocked)

        Returns:
            [batch_size, seq_len, d_model]
        """
        # Sub-layer 1: self-attention (attention weights are discarded)
        attended = self.self_attn(x, x, x, mask)[0]
        x = self.norm1(x + self.dropout1(attended))

        # Sub-layer 2: feed-forward
        return self.norm2(x + self.dropout2(self.feed_forward(x)))


# ============================================================================
# 6. DECODER LAYER
# ============================================================================

class DecoderLayer(nn.Module):
    """
    One post-norm Transformer decoder block.

    Three sub-layers, each wrapped as LayerNorm(x + Dropout(sublayer(x))):
      * masked multi-head self-attention over the target sequence
      * multi-head cross-attention (queries from the decoder, keys/values
        from the encoder output)
      * position-wise feed-forward network
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            d_ff: Feed-forward hidden dimension
            dropout: Dropout probability
        """
        super().__init__()

        # Same creation order as before (parameter order / state_dict / init RNG).
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)

        self.norm1, self.norm2, self.norm3 = (
            nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        )
        self.dropout1, self.dropout2, self.dropout3 = (
            nn.Dropout(dropout), nn.Dropout(dropout), nn.Dropout(dropout)
        )

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: [batch_size, tgt_seq_len, d_model]
            encoder_output: [batch_size, src_seq_len, d_model]
            src_mask: Optional mask over encoder positions (used in cross-attention)
            tgt_mask: Optional mask for self-attention (e.g. to hide future positions)

        Returns:
            [batch_size, tgt_seq_len, d_model]
        """
        # Sub-layer 1: masked self-attention
        self_ctx = self.self_attn(x, x, x, tgt_mask)[0]
        x = self.norm1(x + self.dropout1(self_ctx))

        # Sub-layer 2: cross-attention into the encoder output
        cross_ctx = self.cross_attn(x, encoder_output, encoder_output, src_mask)[0]
        x = self.norm2(x + self.dropout2(cross_ctx))

        # Sub-layer 3: feed-forward
        return self.norm3(x + self.dropout3(self.feed_forward(x)))


# ============================================================================
# 7. COMPLETE TRANSFORMER MODEL
# ============================================================================

class Transformer(nn.Module):
    """
    Complete encoder-decoder Transformer for sequence-to-sequence tasks.
    
    Pipeline: token embedding (scaled by sqrt(d_model)) + sinusoidal positional
    encoding -> stack of EncoderLayer -> stack of DecoderLayer (attending to the
    encoder output) -> linear projection to target-vocabulary logits.
    
    Masks follow scaled_dot_product_attention's convention: positions where the
    mask is 0/False are blocked. forward() does not build any mask itself; pass
    a causal tgt_mask (e.g. from generate_square_subsequent_mask) for training.
    """
    
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        d_model=512,
        num_heads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        d_ff=2048,
        max_seq_length=5000,
        dropout=0.1
    ):
        """
        Args:
            src_vocab_size: Source vocabulary size
            tgt_vocab_size: Target vocabulary size
            d_model: Model dimension
            num_heads: Number of attention heads
            num_encoder_layers: Number of encoder layers
            num_decoder_layers: Number of decoder layers
            d_ff: Feed-forward dimension
            max_seq_length: Maximum sequence length (size of the positional table)
            dropout: Dropout probability
        """
        super(Transformer, self).__init__()
        
        self.d_model = d_model
        
        # Embedding layers
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding (one module shared by the source and target sides)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length, dropout)
        
        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        
        # Decoder layers
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        
        # Final linear layer
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        
        # NOTE(review): not used by encode/decode/forward (dropout is applied inside
        # PositionalEncoding and each layer); kept so the attribute still exists.
        self.dropout = nn.Dropout(dropout)
        
        # Initialize parameters
        self._init_parameters()
    
    def _init_parameters(self):
        """Xavier-uniform init for every parameter with more than one dimension
        (weight matrices and embeddings); biases and LayerNorm params keep defaults."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def generate_square_subsequent_mask(self, sz, device=None):
        """
        Generate a causal mask for the decoder (prevents attending to future positions).
        
        The mask uses this model's convention (1 = may attend, 0 = blocked), so it can
        be passed directly as tgt_mask. It is NOT the additive 0/-inf format used by
        torch.nn.Transformer.
        
        Args:
            sz: Sequence length
            device: Optional device for the returned tensor
        
        Returns:
            Float mask [sz, sz]: 1.0 where key j <= query i, 0.0 where j > i
        """
        return torch.tril(torch.ones(sz, sz, device=device))
    
    def encode(self, src, src_mask=None):
        """
        Encode source sequence.
        
        Args:
            src: [batch_size, src_seq_len] token ids
            src_mask: Optional source mask
        
        Returns:
            [batch_size, src_seq_len, d_model]
        """
        # Embedding (scaled by sqrt(d_model), as in the paper) + positional encoding
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        
        return x
    
    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """
        Decode target sequence.
        
        Args:
            tgt: [batch_size, tgt_seq_len] token ids
            encoder_output: [batch_size, src_seq_len, d_model]
            src_mask: Optional source mask (used in cross-attention)
            tgt_mask: Optional target mask (self-attention)
        
        Returns:
            [batch_size, tgt_seq_len, d_model]
        """
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        
        return x
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Forward pass.
        
        Args:
            src: [batch_size, src_seq_len]
            tgt: [batch_size, tgt_seq_len]
            src_mask: Optional source mask
            tgt_mask: Optional target mask. If None, decoder self-attention is
                unmasked and every position can see future target tokens.
        
        Returns:
            [batch_size, tgt_seq_len, tgt_vocab_size] unnormalized logits
        """
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.fc_out(decoder_output)


# ============================================================================
# 8. UTILITY FUNCTIONS
# ============================================================================

def create_padding_mask(seq, pad_idx=0):
    """
    Build a boolean key-padding mask: True at real tokens, False at padding.

    Two singleton axes are inserted after the batch axis so the result
    broadcasts against attention scores [batch, heads, seq_q, seq_k].

    Args:
        seq: [batch_size, seq_len] token ids
        pad_idx: Token id that marks padding

    Returns:
        Bool mask [batch_size, 1, 1, seq_len]
    """
    keep = seq.ne(pad_idx)
    return keep.unsqueeze(1).unsqueeze(1)


def create_look_ahead_mask(size, device=None):
    """
    Create look-ahead (causal) mask for decoder (prevents attending to future).
    
    Uses the convention of scaled_dot_product_attention, where positions with
    mask == 0 are blocked: query i may attend key j only when j <= i.
    
    Args:
        size: Sequence length
        device: Optional device for the returned tensor
    
    Returns:
        Float mask [1, 1, size, size]: 1.0 on and below the diagonal, 0.0 above
    """
    # Lower triangle (incl. diagonal) = allowed. The previous triu + masked_fill
    # produced an all-zero matrix, which blocked every position.
    mask = torch.tril(torch.ones(size, size, device=device))
    return mask.unsqueeze(0).unsqueeze(0)


# ============================================================================
# 9. EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":
    # Smoke test: build a full-size model, run one masked forward pass, then
    # exercise each component. The shape assertions are what make the final
    # "All tests passed" message meaningful. (In a notebook cell,
    # __name__ == "__main__", so this block always runs.)
    torch.manual_seed(0)  # reproducible weights and random inputs
    
    print("="*80)
    print("TRANSFORMER MODEL IMPLEMENTATION")
    print("="*80)
    
    # Model parameters
    src_vocab_size = 10000
    tgt_vocab_size = 10000
    d_model = 512
    num_heads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    d_ff = 2048
    max_seq_length = 100
    dropout = 0.1
    
    # Create model
    model = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        d_ff=d_ff,
        max_seq_length=max_seq_length,
        dropout=dropout
    )
    
    print(f"\nModel created with {sum(p.numel() for p in model.parameters()):,} parameters")
    
    # Example input; token ids start at 1 because 0 is the padding id in create_padding_mask
    batch_size = 32
    src_seq_len = 20
    tgt_seq_len = 15
    
    src = torch.randint(1, src_vocab_size, (batch_size, src_seq_len))
    tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_seq_len))
    
    print(f"\nInput shapes:")
    print(f"  Source: {src.shape}")
    print(f"  Target: {tgt.shape}")
    
    # Create masks
    src_mask = create_padding_mask(src)
    tgt_mask = create_look_ahead_mask(tgt_seq_len)
    
    # Forward pass (shape check only, so no autograd graph is needed)
    with torch.no_grad():
        output = model(src, tgt, src_mask, tgt_mask)
    
    print(f"\nOutput shape: {output.shape}")
    print(f"  Expected: [batch_size={batch_size}, tgt_seq_len={tgt_seq_len}, tgt_vocab_size={tgt_vocab_size}]")
    assert output.shape == (batch_size, tgt_seq_len, tgt_vocab_size)
    
    # Detailed component test
    print("\n" + "="*80)
    print("COMPONENT TESTING")
    print("="*80)
    
    # Test 1: Positional Encoding
    print("\n1. Positional Encoding:")
    pe = PositionalEncoding(d_model=8, max_len=10)
    x = torch.randn(2, 5, 8)  # [batch, seq_len, d_model]
    x_with_pe = pe(x)
    print(f"   Input shape: {x.shape}")
    print(f"   Output shape: {x_with_pe.shape}")
    print(f"   First position encoding:\n{pe.pe[0, 0, :]}")
    assert x_with_pe.shape == x.shape
    # Position 0: sin(0) = 0 on even dims, cos(0) = 1 on odd dims
    assert torch.allclose(pe.pe[0, 0, 0::2], torch.zeros(4))
    assert torch.allclose(pe.pe[0, 0, 1::2], torch.ones(4))
    
    # Test 2: Multi-Head Attention
    print("\n2. Multi-Head Attention:")
    mha = MultiHeadAttention(d_model=512, num_heads=8)
    q = k = v = torch.randn(2, 10, 512)  # [batch, seq_len, d_model]
    attn_output, attn_weights = mha(q, k, v)
    print(f"   Input shape: {q.shape}")
    print(f"   Output shape: {attn_output.shape}")
    print(f"   Attention weights shape: {attn_weights.shape}")
    assert attn_output.shape == q.shape
    assert attn_weights.shape == (2, 8, 10, 10)
    
    # Test 3: Feed-Forward Network
    print("\n3. Feed-Forward Network:")
    ffn = PositionWiseFeedForward(d_model=512, d_ff=2048)
    x = torch.randn(2, 10, 512)
    ffn_output = ffn(x)
    print(f"   Input shape: {x.shape}")
    print(f"   Output shape: {ffn_output.shape}")
    assert ffn_output.shape == x.shape
    
    # Test 4: Encoder Layer
    print("\n4. Encoder Layer:")
    enc_layer = EncoderLayer(d_model=512, num_heads=8, d_ff=2048)
    x = torch.randn(2, 10, 512)
    enc_output = enc_layer(x)
    print(f"   Input shape: {x.shape}")
    print(f"   Output shape: {enc_output.shape}")
    assert enc_output.shape == x.shape
    
    # Test 5: Decoder Layer
    print("\n5. Decoder Layer:")
    dec_layer = DecoderLayer(d_model=512, num_heads=8, d_ff=2048)
    x = torch.randn(2, 8, 512)
    encoder_output = torch.randn(2, 10, 512)
    dec_output = dec_layer(x, encoder_output)
    print(f"   Decoder input shape: {x.shape}")
    print(f"   Encoder output shape: {encoder_output.shape}")
    print(f"   Decoder output shape: {dec_output.shape}")
    assert dec_output.shape == x.shape
    
    print("\n" + "="*80)
    print("All tests passed successfully!")
    print("="*80)


# ============================================================================
# 10. TRAINING EXAMPLE
# ============================================================================

class SimpleTrainer:
    """Simple trainer for demonstration purposes.
    
    Both steps use teacher forcing. The target is split into a decoder input
    (all but the last token) and a prediction target (all but the first token).
    """
    
    def __init__(self, model, optimizer, criterion, device='cpu'):
        """
        Args:
            model: Model called as model(src, tgt_input, src_mask, tgt_mask) that
                returns logits [batch, tgt_len, vocab]; moved to `device` here
            optimizer: Optimizer over model's parameters
            criterion: Loss taking (logits [N, vocab], targets [N])
            device: Device that batches and masks are moved to
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.model.to(device)
    
    def _compute_loss(self, src, tgt):
        """Teacher-forced forward pass shared by train_step/evaluate; returns the loss tensor."""
        src = src.to(self.device)
        tgt = tgt.to(self.device)
        
        tgt_input = tgt[:, :-1]  # All but last token
        tgt_output = tgt[:, 1:]  # All but first token
        
        # The padding mask inherits src's device, so it needs no extra .to()
        src_mask = create_padding_mask(src)
        tgt_mask = create_look_ahead_mask(tgt_input.size(1)).to(self.device)
        
        output = self.model(src, tgt_input, src_mask, tgt_mask)
        
        # reshape handles the non-contiguous tgt[:, 1:] slice (copies only if needed)
        return self.criterion(
            output.reshape(-1, output.size(-1)),
            tgt_output.reshape(-1)
        )
    
    def train_step(self, src, tgt):
        """Single training step; returns the batch loss as a Python float."""
        self.model.train()
        
        loss = self._compute_loss(src, tgt)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss.item()
    
    def evaluate(self, src, tgt):
        """Evaluation step (eval mode, no gradients); returns the batch loss as a float."""
        self.model.eval()
        
        with torch.no_grad():
            loss = self._compute_loss(src, tgt)
        
        return loss.item()


# ============================================================================
# 11. INFERENCE EXAMPLE
# ============================================================================

def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol, device='cpu'):
    """
    Greedy (argmax) autoregressive decoding for a single source sequence.
    
    Args:
        model: Trained transformer exposing encode/decode/fc_out. It is switched
            to eval mode but not moved; it is expected to already be on `device`.
        src: Source sequence [1, src_len] (batch size 1 is required: the next
            token is read with .item())
        src_mask: Source mask (e.g. from create_padding_mask), or None for no masking
        max_len: Maximum length of the returned sequence, including start_symbol
        start_symbol: Start token ID
        end_symbol: End token ID
        device: Device to run on
    
    Returns:
        Tensor [1, generated_len] with src's dtype. It begins with start_symbol and
        ends with end_symbol if that token was generated within max_len.
    """
    model.eval()
    
    src = src.to(device)
    if src_mask is not None:
        src_mask = src_mask.to(device)
    
    with torch.no_grad():
        # The encoder output is computed once and reused at every step
        encoder_output = model.encode(src, src_mask)
        
        tgt = torch.full((1, 1), start_symbol, dtype=src.dtype, device=src.device)
        
        for _ in range(max_len - 1):
            tgt_mask = create_look_ahead_mask(tgt.size(1)).to(device)
            decoder_output = model.decode(tgt, encoder_output, src_mask, tgt_mask)
            
            # Only the last position predicts the next token, so project just that one
            logits = model.fc_out(decoder_output[:, -1, :])
            next_token = logits.argmax(dim=-1).item()
            
            tgt = torch.cat(
                [tgt, torch.full((1, 1), next_token, dtype=tgt.dtype, device=tgt.device)],
                dim=1,
            )
            
            if next_token == end_symbol:
                break
    
    return tgt


# ============================================================================
# Example Training Loop
# ============================================================================

def example_training(seed=0):
    """
    Example training loop followed by greedy decoding, on random dummy data.
    
    Builds a small Transformer (d_model=256, 3 encoder + 3 decoder layers), runs
    `num_epochs` optimization steps (each "epoch" here is one random batch), and
    then greedy-decodes one random source sequence.
    
    Args:
        seed: If not None, passed to torch.manual_seed before the model is built,
            so weights, dummy data and the decoded output are reproducible.
            Note that this reseeds the global torch RNG.
    """
    if seed is not None:
        torch.manual_seed(seed)
    
    print("\n" + "="*80)
    print("EXAMPLE TRAINING LOOP")
    print("="*80)
    
    # Parameters
    src_vocab_size = 1000
    tgt_vocab_size = 1000
    d_model = 256
    num_heads = 8
    num_layers = 3
    d_ff = 1024
    max_seq_length = 50
    dropout = 0.1
    
    # Create model
    model = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        num_encoder_layers=num_layers,
        num_decoder_layers=num_layers,
        d_ff=d_ff,
        max_seq_length=max_seq_length,
        dropout=dropout
    )
    
    # Optimizer (Adam settings from the paper) and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding (id 0)
    
    # Create trainer
    trainer = SimpleTrainer(model, optimizer, criterion)
    
    # Dummy data
    num_epochs = 3
    batch_size = 16
    
    print(f"\nTraining for {num_epochs} epochs...")
    
    for epoch in range(num_epochs):
        # One fresh random batch per "epoch"; ids start at 1 so nothing is padding
        src = torch.randint(1, src_vocab_size, (batch_size, 20))
        tgt = torch.randint(1, tgt_vocab_size, (batch_size, 15))
        
        loss = trainer.train_step(src, tgt)
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")
    
    print("\nTraining complete!")
    
    # Example inference
    print("\n" + "="*80)
    print("EXAMPLE INFERENCE")
    print("="*80)
    
    # Create dummy source
    src = torch.randint(1, src_vocab_size, (1, 10))
    src_mask = create_padding_mask(src)
    
    # Decode
    output = greedy_decode(
        model, src, src_mask,
        max_len=20,
        start_symbol=1,
        end_symbol=2
    )
    
    print(f"\nSource shape: {src.shape}")
    print(f"Generated output shape: {output.shape}")
    print(f"Generated tokens: {output[0].tolist()}")


if __name__ == "__main__":
    # Run example training: a few optimization steps on random data, then
    # greedy-decode one random source sequence (see example_training).
    example_training()

In [None]:
# Create model (remaining hyperparameters use the Transformer defaults)
model = Transformer(
    src_vocab_size=10000,
    tgt_vocab_size=10000,
    d_model=512,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6
)

# Prepare data: random token ids in [1, 10000); 0 is reserved for padding
src = torch.randint(1, 10000, (32, 20))  # [batch, src_len]
tgt = torch.randint(1, 10000, (32, 15))  # [batch, tgt_len]

# Masks: hide padded source keys, and hide future target positions from the
# decoder. Without tgt_mask, decoder self-attention sees every target token.
src_mask = create_padding_mask(src)
tgt_mask = create_look_ahead_mask(tgt.size(1))

# Forward pass (inference only here, so skip building the autograd graph)
with torch.no_grad():
    output = model(src, tgt, src_mask, tgt_mask)  # [batch, tgt_len, vocab_size]

output.shape

In [None]:
# Setup. These names were previously undefined on a fresh kernel.
num_epochs = 1
vocab_size = model.fc_out.out_features  # target vocabulary size of `model`

# Toy stand-in for a real DataLoader: a couple of (src, tgt) batches of random
# token ids. Ids start at 1 because 0 is the padding id.
dataloader = [
    (torch.randint(1, vocab_size, (8, 20)), torch.randint(1, vocab_size, (8, 15)))
    for _ in range(2)
]

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # don't train on padding targets

# Training loop (teacher forcing: input tgt[:, :-1], predict tgt[:, 1:])
model.train()
for epoch in range(num_epochs):
    for src, tgt in dataloader:
        tgt_input, tgt_output = tgt[:, :-1], tgt[:, 1:]
        src_mask = create_padding_mask(src)
        tgt_mask = create_look_ahead_mask(tgt_input.size(1))

        # Forward
        output = model(src, tgt_input, src_mask, tgt_mask)
        # reshape, not view: tgt[:, 1:] is a non-contiguous slice and .view(-1) raises
        loss = criterion(output.reshape(-1, vocab_size), tgt_output.reshape(-1))

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

loss.item()

In [None]:
# Greedy decoding of one random source sequence with `model` from above.
# Token-id conventions are placeholders: 0 = padding, 1 = start, 2 = end.
START_TOKEN = 1
END_TOKEN = 2

# [1, src_len]; sample ids >= 3 so the source contains no pad/start/end tokens
source = torch.randint(3, model.src_embedding.num_embeddings, (1, 10))
source_mask = create_padding_mask(source)

output = greedy_decode(
    model, source, source_mask,
    max_len=50,
    start_symbol=START_TOKEN,
    end_symbol=END_TOKEN
)
output

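Component               Parameters   (vocab = 10,000, d_model = 512, d_ff = 2048)
-----------------------------------------
Embeddings (2x)         ~10.2M   (2 x 10,000 x 512)
Encoder (6 layers)      ~18.9M   (~3.15M per layer: self-attn + FFN + 2 LayerNorms)
Decoder (6 layers)      ~25.2M   (~4.20M per layer: self-attn + cross-attn + FFN + 3 LayerNorms)
Output projection       ~5.1M    (512 x 10,000 + bias)
-----------------------------------------
Total                   ~59.5M parameters   (positional encoding is a buffer, not a parameter)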