# Transformer Architecture from Scratch: Complete Implementation and Analysis

**"Attention Is All You Need": A Comprehensive Implementation**

**Authors:** PyTorch Mastery Hub Team  
**Institution:** Deep Learning Research Institute  
**Course:** Advanced Transformers and Attention Mechanisms  
**Date:** December 2024

## Overview

This notebook provides a complete implementation of the original Transformer architecture from the groundbreaking paper "Attention Is All You Need" by Vaswani et al. (2017). We build every component from scratch, including multi-head self-attention, positional encoding, and the encoder and decoder stacks, together with a training pipeline, detailed analysis, and visualizations.

## Key Objectives
1. Implement the multi-head self-attention mechanism from first principles
2. Build sinusoidal and learned positional encodings
3. Construct complete encoder and decoder Transformer blocks
4. Assemble the full Transformer architecture for sequence-to-sequence tasks
5. Train and evaluate on a copy task to validate the implementation
6. Visualize attention patterns and analyze model behavior
7. Provide production-ready code with detailed documentation

## 1. Setup and Environment Configuration

```python
# Import required libraries for comprehensive Transformer implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
import copy
import json
import pickle
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Configure plotting environment
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

# Select compute device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🤖 Transformer Architecture Implementation Initialized")
print(f"   Device: {device}")
print(f"   PyTorch Version: {torch.__version__}")
print(f"   CUDA Available: {torch.cuda.is_available()}")

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print("✅ Environment configured with deterministic settings")

# Create results directory for this notebook
notebook_results_dir = Path('results/transformers/from_scratch')
notebook_results_dir.mkdir(parents=True, exist_ok=True)

print(f"📁 Results will be saved to: {notebook_results_dir}")
```

## 2. Multi-Head Attention Mechanism Implementation

The core innovation of the Transformer is scaled dot-product attention, computed in parallel by several heads so that each head can attend to information in a different learned representation subspace.
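Before the PyTorch module, it may help to see the formula `softmax(Q @ K.T / sqrt(d_k)) @ V` on its own. The NumPy sketch below is illustrative rather than the notebook's implementation: the function name `sdpa_reference` is hypothetical, and it assumes the same mask convention used later in this notebook (True/1 = the query may attend to that key, False/0 = blocked).

```python
import numpy as np

# Illustrative sketch; `sdpa_reference` is a hypothetical helper, not part of the notebook.
def sdpa_reference(Q, K, V, mask=None):
    """Scaled dot-product attention for arrays shaped (..., length, d_k).

    mask: boolean array broadcastable to (..., q_len, kv_len);
          True means the query may attend to that key.
    """
    d_k = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)    # (..., q_len, kv_len)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)              # blocked keys get ~zero weight
    scores = scores - scores.max(axis=-1, keepdims=True)   # stabilize the softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 4))   # 3 queries, d_k = 4
K = rng.standard_normal((5, 4))   # 5 keys
V = rng.standard_normal((5, 2))   # 5 values, d_v = 2
mask = np.array([True, True, True, False, False])  # e.g. the last two keys are padding

out, w = sdpa_reference(Q, K, V, mask)
print(out.shape, w.shape)   # (3, 2) (3, 5)
print(w.sum(axis=-1))       # each row sums to 1
print(w[:, 3:].max())       # masked keys receive ~0 weight
```

Dividing by `sqrt(d_k)` keeps the score variance roughly independent of the head size, which stops the softmax from saturating as `d_k` grows.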

```python
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism from 'Attention is All You Need'.
    
    This implementation includes:
    - Scaled dot-product attention
    - Multi-head parallel processing
    - Linear projections for Q, K, V
    - Attention dropout and output projection
    """
    
    def __init__(self, d_model, n_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        
        assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        """Initialize weights using Xavier uniform initialization."""
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Compute scaled dot-product attention.
        
        Args:
            Q: Queries (batch_size, n_heads, q_len, d_k)
            K, V: Keys and values (batch_size, n_heads, kv_len, d_k)
            mask: Mask broadcastable to (batch_size, n_heads, q_len, kv_len);
                  positions where mask == 0 are blocked
            
        Returns:
            output: Attended values (batch_size, n_heads, q_len, d_k)
            attention_weights: Attention probabilities (batch_size, n_heads, q_len, kv_len)
        """
        # Compute attention scores: Q K^T / sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply mask if provided (set masked positions to large negative value)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # Apply attention to values
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights
    
    def forward(self, query, key, value, mask=None):
        """
        Forward pass through multi-head attention.
        
        Args:
            query: Query input (batch_size, seq_len, d_model)
            key, value: Key/value inputs (batch_size, kv_len, d_model); kv_len can
                differ from seq_len, e.g. in decoder cross-attention
            mask: Attention mask broadcastable to (batch_size, n_heads, seq_len, kv_len)
            
        Returns:
            output: Multi-head attention output (batch_size, seq_len, d_model)
            attention_weights: Attention weights (batch_size, n_heads, seq_len, kv_len)
        """
        batch_size, seq_len, _ = query.size()  # seq_len is the query length
        kv_len = key.size(1)                   # key/value length
        
        # Linear projections, then split d_model into n_heads chunks of size d_k
        Q = self.W_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, kv_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, kv_len, self.n_heads, self.d_k).transpose(1, 2)
        # Q: (batch_size, n_heads, seq_len, d_k); K, V: (batch_size, n_heads, kv_len, d_k)
        
        # Apply scaled dot-product attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate heads and put through final linear layer
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        
        # Final linear projection
        output = self.W_o(attention_output)
        
        return output, attention_weights

def test_multihead_attention():
    """Test multi-head attention implementation with comprehensive analysis."""
    print("🧠 Testing Multi-Head Attention Implementation...")
    
    # Test parameters
    d_model = 512
    n_heads = 8
    seq_len = 10
    batch_size = 2
    
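    # Create multi-head attention module. eval() disables attention dropout;
    # in training mode dropout rescales the weights and rows no longer sum to 1.
    mha = MultiHeadAttention(d_model, n_heads)
    mha.eval()
    
    # Test input
    x = torch.randn(batch_size, seq_len, d_model)
    
    # Forward pass (self-attention: query = key = value)
    with torch.no_grad():
        output, attention_weights = mha(x, x, x)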
    
    print(f"✅ Multi-Head Attention Test Results:")
    print(f"   Input shape: {x.shape}")
    print(f"   Output shape: {output.shape}")
    print(f"   Attention weights shape: {attention_weights.shape}")
    print(f"   Parameters: {sum(p.numel() for p in mha.parameters()):,}")
    print(f"   Parameter memory (float32): ~{sum(p.numel() for p in mha.parameters()) * 4 / 1024**2:.1f} MB")
    
    # Verify attention weights sum to 1
    attention_sum = attention_weights.sum(dim=-1)
    print(f"   Attention weights sum check: {torch.allclose(attention_sum, torch.ones_like(attention_sum))}")
    
    return mha, output, attention_weights

# Test multi-head attention
mha_module, test_output, test_attention = test_multihead_attention()

# Visualize attention patterns
def visualize_attention_heads(attention_weights, save_path=None):
    """Visualize attention patterns for different heads."""
    _, n_heads, _, _ = attention_weights.shape  # (batch_size, n_heads, q_len, kv_len)
    
    # Plot attention patterns for first sample, first 8 heads
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    axes = axes.flatten()
    
    for head in range(min(8, n_heads)):
        # Get attention matrix for first sample, specific head
        attn_matrix = attention_weights[0, head].detach().cpu().numpy()
        
        sns.heatmap(attn_matrix, cmap='Blues', ax=axes[head], 
                   cbar=True, square=True, cbar_kws={'shrink': 0.8})
        axes[head].set_title(f'Head {head + 1}')
        axes[head].set_xlabel('Key Position')
        axes[head].set_ylabel('Query Position')
    
    plt.suptitle('Multi-Head Attention Patterns', fontsize=16)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Visualize test attention patterns
visualize_attention_heads(test_attention, notebook_results_dir / 'multihead_attention_patterns.png')

print("✅ Multi-head attention implementation complete and tested!")
```
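The head split in `forward` is only a reshape: `d_model` is cut into `n_heads` contiguous chunks of size `d_k`, each head attends independently, and the final `transpose`/`view` concatenates the heads again. The NumPy sketch below mirrors that round trip under stated assumptions: random matrices stand in for trained `W_q`, `W_k`, `W_v`, `W_o` (right-multiplied, whereas `nn.Linear` computes `x @ W.T`), there is no dropout or mask, and `split_heads`, `merge_heads` and `multi_head_reference` are hypothetical names. It also shows that queries and keys may have different lengths, as in decoder cross-attention.

```python
import numpy as np

# Illustrative sketch; all helper names are hypothetical, not the notebook's API.
def split_heads(x, n_heads):
    """(batch, length, d_model) -> (batch, n_heads, length, d_k)."""
    b, n, d_model = x.shape
    return x.reshape(b, n, n_heads, d_model // n_heads).transpose(0, 2, 1, 3)

def merge_heads(x):
    """(batch, n_heads, length, d_k) -> (batch, length, n_heads * d_k)."""
    b, h, n, d_k = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, n, h * d_k)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_reference(query, key, value, W_q, W_k, W_v, W_o, n_heads):
    Q = split_heads(query @ W_q, n_heads)   # (b, h, q_len, d_k)
    K = split_heads(key @ W_k, n_heads)     # (b, h, kv_len, d_k)
    V = split_heads(value @ W_v, n_heads)   # (b, h, kv_len, d_k)
    d_k = Q.shape[-1]
    weights = softmax(Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k))  # (b, h, q_len, kv_len)
    return merge_heads(weights @ V) @ W_o, weights

rng = np.random.default_rng(0)
batch, d_model, n_heads = 2, 16, 4
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))

decoder_states = rng.standard_normal((batch, 3, d_model))  # q_len = 3
encoder_states = rng.standard_normal((batch, 7, d_model))  # kv_len = 7

out, attn = multi_head_reference(decoder_states, encoder_states, encoder_states,
                                 W_q, W_k, W_v, W_o, n_heads)
print(out.shape)   # (2, 3, 16)
print(attn.shape)  # (2, 4, 3, 7)

# split_heads and merge_heads are exact inverses
x = rng.standard_normal((batch, 5, d_model))
print(np.array_equal(merge_heads(split_heads(x, n_heads)), x))  # True
```

The output always follows the query length, while the attention matrix is `q_len x kv_len`; this is why a cross-attention layer must reshape keys and values with their own length rather than the query's.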

## 3. Positional Encoding Systems

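Self-attention treats its input as an unordered set: permuting the input tokens simply permutes the outputs. To give the model a sense of order, we add a positional encoding to the token embeddings, implementing both the fixed sinusoidal version from the paper and a learned alternative.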
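For reference, the paper's sinusoidal encoding is `PE(pos, 2i) = sin(pos / 10000^(2i/d_model))` and `PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))`. Each (sin, cos) pair rotates at a fixed angular frequency `w_i`, so shifting every position by a constant `k` is one and the same linear map (a rotation by `k * w_i` in each pair), whatever `pos` is. The NumPy sketch below checks this numerically; `sinusoidal_table` and `shift_matrix` are hypothetical helpers written for illustration, not part of the notebook's model code.

```python
import numpy as np

# Illustrative sketch; helper names are hypothetical.
def sinusoidal_table(max_len, d_model):
    """PE[pos, 2i] = sin(pos * w_i), PE[pos, 2i+1] = cos(pos * w_i), w_i = 10000^(-2i/d_model)."""
    pos = np.arange(max_len)[:, None]                          # (max_len, 1)
    omega = 10000.0 ** (-np.arange(0, d_model, 2) / d_model)   # (d_model / 2,)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(pos * omega)
    pe[:, 1::2] = np.cos(pos * omega)
    return pe, omega

def shift_matrix(k, omega):
    """Block-diagonal matrix M_k with PE[pos + k] = M_k @ PE[pos] for every pos."""
    d_model = 2 * len(omega)
    M = np.zeros((d_model, d_model))
    for i, w in enumerate(omega):
        c, s = np.cos(k * w), np.sin(k * w)
        # sin(a + b) = sin a cos b + cos a sin b ; cos(a + b) = cos a cos b - sin a sin b
        M[2 * i, 2 * i], M[2 * i, 2 * i + 1] = c, s
        M[2 * i + 1, 2 * i], M[2 * i + 1, 2 * i + 1] = -s, c
    return M

pe, omega = sinusoidal_table(max_len=100, d_model=16)
M5 = shift_matrix(5, omega)
print(np.allclose(pe[10 + 5], M5 @ pe[10]))  # True
print(np.allclose(pe[60 + 5], M5 @ pe[60]))  # True: the same matrix works for every pos
```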

```python
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding from 'Attention is All You Need'.
    
    Uses sine and cosine functions of geometrically spaced frequencies. For any fixed
    offset k, PE(pos + k) is a linear function of PE(pos), which the authors hypothesized
    makes it easy for the model to attend by relative position.
    """
    
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # div_term[i] = 10000^(-2i / d_model), computed in log space via exp(-2i * ln(10000) / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        
        # Apply sin to even indices and cos to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension and register as buffer (not a parameter)
        pe = pe.unsqueeze(0)  # Shape: (1, max_len, d_model)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        """
        Add positional encoding to input embeddings.
        
        Args:
            x: Input embeddings (batch_size, seq_len, d_model)
            
        Returns:
            x + positional encoding, with dropout applied
        """
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

class LearnedPositionalEncoding(nn.Module):
    """
    Learned positional encoding alternative to sinusoidal encoding.
    
    Uses a trainable embedding table to learn position representations from data.
    Unlike the sinusoidal version, it has no representation for positions >= max_len.
    """
    
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(LearnedPositionalEncoding, self).__init__()
        
        self.dropout = nn.Dropout(dropout)
        self.pe = nn.Embedding(max_len, d_model)
        
        # Initialize with small random values
        nn.init.uniform_(self.pe.weight, -0.1, 0.1)
        
    def forward(self, x):
        """
        Add learned positional encoding to input embeddings.
        
        Args:
            x: Input embeddings (batch_size, seq_len, d_model)
            
        Returns:
            x + learned positional encoding, with dropout applied
        """
        batch_size, seq_len, d_model = x.size()
        
        # Create position indices
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        
        # Add learned positional embeddings
        pos_encoding = self.pe(positions)
        x = x + pos_encoding
        
        return self.dropout(x)

def analyze_positional_encoding():
    """Comprehensive analysis of positional encoding mechanisms."""
    print("📍 Analyzing Positional Encoding Systems...")
    
    d_model = 512
    max_len = 100
    
    # Create both types of positional encoding
    sinusoidal_pe = PositionalEncoding(d_model, max_len)
    learned_pe = LearnedPositionalEncoding(d_model, max_len)
    
    # Test input
    test_input = torch.randn(4, 50, d_model)
    
    # Apply encodings
    sinusoidal_output = sinusoidal_pe(test_input)
    learned_output = learned_pe(test_input)
    
    print(f"✅ Positional Encoding Analysis:")
    print(f"   Input shape: {test_input.shape}")
    print(f"   Sinusoidal output shape: {sinusoidal_output.shape}")
    print(f"   Learned output shape: {learned_output.shape}")
    print(f"   Sinusoidal parameters: {sum(p.numel() for p in sinusoidal_pe.parameters()):,}")
    print(f"   Learned parameters: {sum(p.numel() for p in learned_pe.parameters()):,}")
    
    return sinusoidal_pe, learned_pe

# Analyze positional encodings
sin_pe, learned_pe = analyze_positional_encoding()

def visualize_positional_encoding():
    """Create comprehensive visualizations of positional encoding patterns."""
    print("🎨 Creating Positional Encoding Visualizations...")
    
    d_model = 512
    max_positions = 100
    
    # Get sinusoidal positional encoding matrix
    pe_matrix = sin_pe.pe.squeeze(0).numpy()  # Shape: (max_len, d_model)
    
    # Create comprehensive visualization
    fig = plt.figure(figsize=(20, 15))
    
    # 1. Full positional encoding heatmap
    plt.subplot(3, 2, 1)
    plt.imshow(pe_matrix[:max_positions, :128].T, cmap='RdYlBu', aspect='auto')
    plt.title('Sinusoidal Positional Encoding\n(First 128 dimensions, 100 positions)')
    plt.xlabel('Position')
    plt.ylabel('Dimension')
    plt.colorbar(shrink=0.8)
    
    # 2. Specific dimension patterns over positions
    plt.subplot(3, 2, 2)
    positions = np.arange(max_positions)
    for dim in [0, 1, 16, 17, 64, 65]:
        plt.plot(positions, pe_matrix[:max_positions, dim], 
                label=f'Dim {dim}', linewidth=2, alpha=0.8)
    plt.title('Positional Encoding Patterns by Dimension')
    plt.xlabel('Position')
    plt.ylabel('Encoding Value')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 3. Frequency analysis
    plt.subplot(3, 2, 3)
    frequencies = []
    for i in range(d_model // 2):
        # Angular frequency shared by dims 2i (sin) and 2i+1 (cos)
        omega = math.exp(2 * i * (-math.log(10000.0) / d_model))
        freq = omega / (2 * math.pi)  # cycles per position (period = 2*pi / omega)
        frequencies.append(freq)
    
    plt.semilogy(range(d_model // 2), frequencies, 'o-', alpha=0.7)
    plt.title('Frequency by Dimension Pair')
    plt.xlabel('Dimension pair index i (dims 2i, 2i+1)')
    plt.ylabel('Frequency, cycles per position (log scale)')
    
    # 4. Distance analysis between positions
    plt.subplot(3, 2, 4)
    reference_positions = [10, 20, 30, 40]
    for ref_pos in reference_positions:
        distances = []
        for pos in range(max_positions):
            dist = np.linalg.norm(pe_matrix[pos, :] - pe_matrix[ref_pos, :])
            distances.append(dist)
        plt.plot(distances, label=f'Distance from pos {ref_pos}', alpha=0.8)
    
    plt.title('Euclidean Distance Between Positions')
    plt.xlabel('Position')
    plt.ylabel('Distance')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 5. Attention pattern simulation
    plt.subplot(3, 2, 5)
    # Dot-product similarity between PE vectors (first 64 dims only): a rough proxy for
    # position-only attention scores, ignoring the learned Q/K projections
    query_pos = 25
    similarities = []
    for pos in range(max_positions):
        similarity = np.dot(pe_matrix[query_pos, :64], pe_matrix[pos, :64])
        similarities.append(similarity)
    
    plt.plot(similarities, 'g-', linewidth=2, alpha=0.8)
    plt.axvline(query_pos, color='red', linestyle='--', alpha=0.8, label=f'Query position ({query_pos})')
    plt.title('Positional Similarity Pattern\n(Dot product with query position)')
    plt.xlabel('Key Position')
    plt.ylabel('Similarity Score')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 6. Comparison: Sinusoidal vs Random
    plt.subplot(3, 2, 6)
    random_encoding = np.random.randn(max_positions, d_model) * 0.1
    
    sin_distances = [np.linalg.norm(pe_matrix[i+1, :] - pe_matrix[i, :]) 
                    for i in range(max_positions-1)]
    random_distances = [np.linalg.norm(random_encoding[i+1, :] - random_encoding[i, :]) 
                       for i in range(max_positions-1)]
    
    plt.plot(sin_distances, label='Sinusoidal PE', alpha=0.8, linewidth=2)
    plt.plot(random_distances, label='Random PE', alpha=0.8, linewidth=2)
    plt.title('Consecutive Position Distances\n(Sinusoidal vs Random)')
    plt.xlabel('Position')
    plt.ylabel('Distance to Next Position')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(notebook_results_dir / 'positional_encoding_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Create summary statistics
    pe_stats = {
        'sinusoidal_properties': {
            'mean_encoding_value': float(pe_matrix.mean()),
            'std_encoding_value': float(pe_matrix.std()),
            'min_encoding_value': float(pe_matrix.min()),
            'max_encoding_value': float(pe_matrix.max()),
            'consecutive_position_distance_mean': float(np.mean(sin_distances)),
            'consecutive_position_distance_std': float(np.std(sin_distances))
        },
        'frequency_analysis': {
            'lowest_frequency': float(min(frequencies)),
            'highest_frequency': float(max(frequencies)),
            'frequency_range_log10': float(np.log10(max(frequencies)) - np.log10(min(frequencies)))
        }
    }
    
    return pe_stats

# Visualize and analyze positional encoding
pe_analysis_stats = visualize_positional_encoding()

print("✅ Positional encoding systems implemented and analyzed!")
```
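The distance curves in panel 4 look like shifted copies of one another, and this follows from the formula. For each (sin, cos) pair, the squared difference between positions `a` and `b` is `2 - 2*cos((a - b) * w_i)`, so `||PE(a) - PE(b)||^2 = sum_i (2 - 2*cos((a - b) * w_i))` depends only on the offset `a - b`. Likewise, the dot product in panel 5 equals `sum_i cos((a - b) * w_i)`. The NumPy check below uses a hypothetical `sinusoidal_table` helper that mirrors `PositionalEncoding.pe` without dropout.

```python
import numpy as np

# Illustrative sketch; helper names are hypothetical.
def sinusoidal_table(max_len, d_model):
    # Same construction as PositionalEncoding.pe, without the batch dim or dropout
    pos = np.arange(max_len)[:, None]
    omega = 10000.0 ** (-np.arange(0, d_model, 2) / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(pos * omega)
    pe[:, 1::2] = np.cos(pos * omega)
    return pe, omega

pe, omega = sinusoidal_table(max_len=100, d_model=512)

def dist(a, b):
    return np.linalg.norm(pe[a] - pe[b])

def dist_closed_form(offset):
    # ||PE(a) - PE(b)||^2 = sum_i (2 - 2 cos(offset * w_i)), offset = a - b
    return np.sqrt(np.sum(2.0 - 2.0 * np.cos(offset * omega)))

print(np.isclose(dist(10, 13), dist(40, 43)))         # True: only the offset matters
print(np.isclose(dist(10, 13), dist_closed_form(3)))  # True
# Panel 5 uses the first 64 dims, i.e. the first 32 frequency pairs
print(np.isclose(pe[25, :64] @ pe[30, :64], np.sum(np.cos(5 * omega[:32]))))  # True
```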

## 4. Transformer Building Blocks

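This section builds the remaining components: the position-wise feed-forward network, the encoder and decoder layers (each sublayer wrapped in a residual connection followed by layer normalization), and the padding and look-ahead attention masks.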
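Every sublayer below follows the post-norm pattern of the original paper, `x = LayerNorm(x + Dropout(Sublayer(x)))`, and the feed-forward network is position-wise: the same two linear maps are applied to each position independently. The NumPy sketch below illustrates both points under simplifying assumptions: random weights, no dropout, and LayerNorm without its learnable scale and shift. The helper names are hypothetical.

```python
import numpy as np

# Illustrative sketch; helper names are hypothetical, not the notebook's classes.
def layer_norm(x, eps=1e-5):
    # Normalize over the feature dimension (as nn.LayerNorm(d_model) does), without gamma/beta
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as nn.LayerNorm uses
    return (x - mean) / np.sqrt(var + eps)

def feed_forward(x, W1, b1, W2, b2):
    # FFN(x) = max(0, x W1 + b1) W2 + b2, applied over the last axis only
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

def post_norm_sublayer(x, sublayer):
    # Add & Norm: residual connection first, then LayerNorm (dropout omitted)
    return layer_norm(x + sublayer(x))

rng = np.random.default_rng(0)
batch, seq_len, d_model, d_ff = 2, 5, 8, 32
W1, b1 = rng.standard_normal((d_model, d_ff)) * 0.1, np.zeros(d_ff)
W2, b2 = rng.standard_normal((d_ff, d_model)) * 0.1, np.zeros(d_model)

def ffn(h):
    return feed_forward(h, W1, b1, W2, b2)

x = rng.standard_normal((batch, seq_len, d_model))
y = post_norm_sublayer(x, ffn)

# Position-wise: running the FFN on one position alone gives the same result
print(np.allclose(ffn(x)[:, 3], ffn(x[:, 3])))  # True
# After Add & Norm every position has ~zero mean and ~unit variance over d_model
print(np.allclose(y.mean(-1), 0.0, atol=1e-6), np.allclose(y.var(-1), 1.0, atol=1e-2))  # True True
```

Because the FFN never mixes positions, all token-to-token interaction in a layer happens in the attention sublayer.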

```python
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network used in Transformer layers.
    
    Implements: FFN(x) = max(0, xW1 + b1)W2 + b2
    Two linear transformations with ReLU activation in between.
    """
    
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
        # Initialize weights
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.xavier_uniform_(self.linear2.weight)
        nn.init.zeros_(self.linear1.bias)
        nn.init.zeros_(self.linear2.bias)
        
    def forward(self, x):
        """
        Forward pass through feed-forward network.
        
        Args:
            x: Input tensor (batch_size, seq_len, d_model)
            
        Returns:
            Output tensor (batch_size, seq_len, d_model)
        """
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

class TransformerEncoderLayer(nn.Module):
    """
    Single layer of Transformer encoder.
    
    Consists of:
    1. Multi-head self-attention
    2. Add & Norm (residual connection + layer normalization)
    3. Feed-forward network
    4. Add & Norm (residual connection + layer normalization)
    """
    
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        """
        Forward pass through encoder layer.
        
        Args:
            x: Input tensor (batch_size, seq_len, d_model)
            mask: Attention mask
            
        Returns:
            output: Transformed tensor (batch_size, seq_len, d_model)
            attention_weights: Self-attention weights
        """
        # Self-attention with residual connection and layer norm
        attn_output, attention_weights = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward with residual connection and layer norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x, attention_weights

class TransformerDecoderLayer(nn.Module):
    """
    Single layer of Transformer decoder.
    
    Consists of:
    1. Masked multi-head self-attention
    2. Add & Norm
    3. Multi-head cross-attention to encoder output
    4. Add & Norm
    5. Feed-forward network
    6. Add & Norm
    """
    
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super(TransformerDecoderLayer, self).__init__()
        
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Forward pass through decoder layer.
        
        Args:
            x: Decoder input (batch_size, tgt_seq_len, d_model)
            encoder_output: Encoder output (batch_size, src_seq_len, d_model)
            src_mask: Source attention mask
            tgt_mask: Target attention mask (with look-ahead masking)
            
        Returns:
            output: Transformed tensor
            self_attention_weights: Self-attention weights
            cross_attention_weights: Cross-attention weights
        """
        # Masked self-attention
        self_attn_output, self_attention_weights = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        
        # Cross-attention to encoder output
        cross_attn_output, cross_attention_weights = self.cross_attention(
            x, encoder_output, encoder_output, src_mask
        )
        x = self.norm2(x + self.dropout(cross_attn_output))
        
        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x, self_attention_weights, cross_attention_weights

def create_padding_mask(seq, pad_idx=0):
    """
    Create padding mask to ignore padded tokens in attention.
    
    Args:
        seq: Input sequence (batch_size, seq_len)
        pad_idx: Padding token index
        
    Returns:
        mask: Boolean padding mask (batch_size, 1, 1, seq_len); True for real tokens,
              False for padding, broadcastable over heads and query positions
    """
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)

def create_look_ahead_mask(size):
    """
    Create look-ahead mask for decoder to prevent attending to future tokens.
    
    Args:
        size: Sequence length
        
    Returns:
        mask: Boolean lower-triangular mask (size, size); True where key position <= query position
    """
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask == 0

def test_transformer_blocks():
    """Test Transformer encoder and decoder layers."""
    print("🧱 Testing Transformer Building Blocks...")
    
    # Test parameters
    d_model = 512
    n_heads = 8
    d_ff = 2048
    seq_len = 20
    batch_size = 4
    
    # Create layers
    encoder_layer = TransformerEncoderLayer(d_model, n_heads, d_ff)
    decoder_layer = TransformerDecoderLayer(d_model, n_heads, d_ff)
    
    # Test inputs
    encoder_input = torch.randn(batch_size, seq_len, d_model)
    decoder_input = torch.randn(batch_size, seq_len, d_model)
    
    # Create masks
    src_seq = torch.randint(1, 1000, (batch_size, seq_len))
    tgt_seq = torch.randint(1, 1000, (batch_size, seq_len))
    
    src_mask = create_padding_mask(src_seq)
    tgt_mask = create_look_ahead_mask(seq_len).unsqueeze(0).unsqueeze(0)
    
    print(f"✅ Building Blocks Test Results:")
    print(f"   Encoder layer parameters: {sum(p.numel() for p in encoder_layer.parameters()):,}")
    print(f"   Decoder layer parameters: {sum(p.numel() for p in decoder_layer.parameters()):,}")
    
    # Test encoder layer
    encoder_output, encoder_attn = encoder_layer(encoder_input, src_mask)
    print(f"   Encoder output shape: {encoder_output.shape}")
    print(f"   Encoder attention shape: {encoder_attn.shape}")
    
    # Test decoder layer
    decoder_output, self_attn, cross_attn = decoder_layer(
        decoder_input, encoder_output, src_mask, tgt_mask
    )
    print(f"   Decoder output shape: {decoder_output.shape}")
    print(f"   Self-attention shape: {self_attn.shape}")
    print(f"   Cross-attention shape: {cross_attn.shape}")
    
    return encoder_layer, decoder_layer, src_mask, tgt_mask

# Test building blocks
test_enc_layer, test_dec_layer, test_src_mask, test_tgt_mask = test_transformer_blocks()

# Visualize attention masks
def visualize_attention_masks():
    """Visualize different types of attention masks."""
    print("🎭 Visualizing Attention Masks...")
    
    seq_len = 10
    batch_size = 1
    
    # Create sample sequences and masks
    src_seq = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0, 0, 0]])  # Padded sequence
    tgt_seq = torch.tensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])  # Padded sequence
    
    src_mask = create_padding_mask(src_seq)
    look_ahead = create_look_ahead_mask(seq_len)
    tgt_padding_mask = create_padding_mask(tgt_seq)
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Source padding mask
    axes[0].imshow(src_mask[0, 0].numpy(), cmap='Blues', vmin=0, vmax=1)
    axes[0].set_title('Source Padding Mask\n(1=attend, 0=ignore; same row for every query)')
    axes[0].set_xlabel('Source Position')
    axes[0].set_ylabel('(broadcast over queries)')
    
    # Look-ahead mask
    axes[1].imshow(look_ahead.numpy(), cmap='Reds', vmin=0, vmax=1)
    axes[1].set_title('Look-Ahead Mask\n(1=attend, 0=ignore future)')
    axes[1].set_xlabel('Target Position')
    axes[1].set_ylabel('Query Position')
    
    # Combined target mask
    combined_tgt_mask = tgt_padding_mask[0, 0] & look_ahead
    axes[2].imshow(combined_tgt_mask.numpy(), cmap='Greens', vmin=0, vmax=1)
    axes[2].set_title('Combined Target Mask\n(padding + look-ahead)')
    axes[2].set_xlabel('Target Position')
    axes[2].set_ylabel('Query Position')
    
    plt.tight_layout()
    plt.savefig(notebook_results_dir / 'attention_masks.png', dpi=300, bbox_inches='tight')
    plt.show()

# Visualize masks
visualize_attention_masks()

print("✅ Transformer building blocks implemented and tested!")
```
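The masks above rely on broadcasting: `create_padding_mask` returns shape `(batch, 1, 1, seq_len)`, the look-ahead mask is `(seq_len, seq_len)`, and their logical AND is `(batch, 1, seq_len, seq_len)`, which then broadcasts over the head dimension of the `(batch, n_heads, q_len, kv_len)` scores. A NumPy mirror of the two helpers (the names `padding_mask` and `look_ahead_mask` are hypothetical) makes the shapes and the combined pattern concrete.

```python
import numpy as np

# Illustrative sketch mirroring create_padding_mask / create_look_ahead_mask; names are hypothetical.
def padding_mask(seq, pad_idx=0):
    # True for real tokens, False for padding; shape (batch, 1, 1, seq_len)
    return (seq != pad_idx)[:, None, None, :]

def look_ahead_mask(size):
    # True where key position <= query position; shape (size, size)
    return np.triu(np.ones((size, size)), k=1) == 0

tgt = np.array([[1, 2, 3, 4, 0, 0],
                [5, 6, 0, 0, 0, 0]])
combined = padding_mask(tgt) & look_ahead_mask(tgt.shape[1])
print(combined.shape)             # (2, 1, 6, 6)
print(combined[1, 0].astype(int))
# Row q marks the keys query q may attend to: causal AND non-padding.
# Rows for padded queries still see the real prefix, so their softmax stays well defined.

scores = np.zeros((2, 8, 6, 6))   # (batch, n_heads, q_len, kv_len)
masked = np.where(combined, scores, -1e9)
print(masked.shape)               # (2, 8, 6, 6)
```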

## 5. Complete Transformer Architecture

Full encoder-decoder Transformer implementation for sequence-to-sequence tasks.
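Before reading the full model, it can help to know what parameter counts to expect. The arithmetic below is a sketch that assumes exactly the layer definitions in this notebook: bias-free Q/K/V projections with a biased output projection, a biased two-layer FFN, LayerNorm with weight and bias, separate source and target embeddings, a biased output projection, and no final LayerNorm on either stack. The function names are hypothetical; if those assumptions hold, the totals should agree with `sum(p.numel() for p in module.parameters())` on the corresponding modules.

```python
# Illustrative parameter-count arithmetic; function names are hypothetical.
def mha_params(d_model):
    # W_q, W_k, W_v without bias + W_o with bias
    return 3 * d_model * d_model + (d_model * d_model + d_model)

def ffn_params(d_model, d_ff):
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)

def layer_norm_params(d_model):
    return 2 * d_model  # weight + bias

def encoder_layer_params(d_model, d_ff):
    return mha_params(d_model) + ffn_params(d_model, d_ff) + 2 * layer_norm_params(d_model)

def decoder_layer_params(d_model, d_ff):
    # self-attention + cross-attention + FFN + three LayerNorms
    return 2 * mha_params(d_model) + ffn_params(d_model, d_ff) + 3 * layer_norm_params(d_model)

def transformer_params(src_vocab, tgt_vocab, d_model=512, d_ff=2048,
                       n_enc=6, n_dec=6, max_len=5000, learned_pe=False):
    embeddings = (src_vocab + tgt_vocab) * d_model
    pos_enc = max_len * d_model if learned_pe else 0  # the sinusoidal table is a buffer
    stacks = n_enc * encoder_layer_params(d_model, d_ff) + n_dec * decoder_layer_params(d_model, d_ff)
    output_proj = d_model * tgt_vocab + tgt_vocab
    return embeddings + pos_enc + stacks + output_proj

print(f"{mha_params(512):,}")                  # 1,049,088
print(f"{encoder_layer_params(512, 2048):,}")  # 3,150,848
print(f"{decoder_layer_params(512, 2048):,}")  # 4,200,960
```

For small vocabularies, the encoder and decoder stacks dominate the total; the embeddings and output projection grow linearly with vocabulary size.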

```python
class TransformerEncoder(nn.Module):
    """
    Stack of Transformer encoder layers.
    
    Processes source sequences with self-attention to create contextualized representations.
    """
    
    def __init__(self, num_layers, d_model, n_heads, d_ff, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.num_layers = num_layers
        
    def forward(self, x, mask=None):
        """
        Forward pass through encoder stack.
        
        Args:
            x: Input embeddings (batch_size, src_seq_len, d_model)
            mask: Source padding mask
            
        Returns:
            output: Encoded representations (batch_size, src_seq_len, d_model)
            attention_weights: List of attention weights from each layer
        """
        attention_weights = []
        
        for layer in self.layers:
            x, attn = layer(x, mask)
            attention_weights.append(attn)
            
        return x, attention_weights

class TransformerDecoder(nn.Module):
    """
    Stack of Transformer decoder layers.
    
    Generates target sequences using masked self-attention and cross-attention to encoder output.
    """
    
    def __init__(self, num_layers, d_model, n_heads, d_ff, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.num_layers = num_layers
        
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Forward pass through decoder stack.
        
        Args:
            x: Target embeddings (batch_size, tgt_seq_len, d_model)
            encoder_output: Encoder output (batch_size, src_seq_len, d_model)
            src_mask: Source padding mask
            tgt_mask: Target mask (padding + look-ahead)
            
        Returns:
            output: Decoded representations (batch_size, tgt_seq_len, d_model)
            self_attention_weights: List of self-attention weights
            cross_attention_weights: List of cross-attention weights
        """
        self_attention_weights = []
        cross_attention_weights = []
        
        for layer in self.layers:
            x, self_attn, cross_attn = layer(x, encoder_output, src_mask, tgt_mask)
            self_attention_weights.append(self_attn)
            cross_attention_weights.append(cross_attn)
            
        return x, self_attention_weights, cross_attention_weights

class Transformer(nn.Module):
    """
    Complete Transformer model for sequence-to-sequence tasks.
    
    Implements the full architecture from "Attention is All You Need" including:
    - Source and target embeddings
    - Positional encoding
    - Encoder and decoder stacks
    - Output projection layer
    """
    
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8, 
                 num_encoder_layers=6, num_decoder_layers=6, d_ff=2048, 
                 max_seq_length=5000, dropout=0.1, pad_idx=0, use_learned_pe=False):
        super(Transformer, self).__init__()
        
        self.d_model = d_model
        self.pad_idx = pad_idx
        
        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model, padding_idx=pad_idx)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model, padding_idx=pad_idx)
        
        # Positional encoding
        if use_learned_pe:
            self.pos_encoding = LearnedPositionalEncoding(d_model, max_seq_length, dropout)
        else:
            self.pos_encoding = PositionalEncoding(d_model, max_seq_length, dropout)
        
        # Encoder and Decoder
        self.encoder = TransformerEncoder(num_encoder_layers, d_model, n_heads, d_ff, dropout)
        self.decoder = TransformerDecoder(num_decoder_layers, d_model, n_heads, d_ff, dropout)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        
        # Store architecture parameters for analysis
        self.architecture_config = {
            'src_vocab_size': src_vocab_size,
            'tgt_vocab_size': tgt_vocab_size,
            'd_model': d_model,
            'n_heads': n_heads,
            'num_encoder_layers': num_encoder_layers,
            'num_decoder_layers': num_decoder_layers,
            'd_ff': d_ff,
            'max_seq_length': max_seq_length,
            'dropout': dropout,
            'use_learned_pe': use_learned_pe
        }
        
        # Initialize weights
        self._init_weights()
        
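    def _init_weights(self):
        """Initialize weights with Xavier uniform; zero biases and padding embeddings."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.xavier_uniform_(module.weight)
                # Xavier init overwrites the padding row; restore it to zeros
                # (padding_idx rows receive no gradient, so they would stay random otherwise)
                if module.padding_idx is not None:
                    with torch.no_grad():
                        module.weight[module.padding_idx].zero_()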
    
    def create_masks(self, src, tgt):
        """
        Create all necessary masks for attention mechanisms.
        
        Args:
            src: Source sequences (batch_size, src_seq_len)
            tgt: Target sequences (batch_size, tgt_seq_len)
            
        Returns:
            src_mask: Source padding mask
            tgt_mask: Combined target mask (padding + look-ahead)
        """
        # Source mask (padding mask)
        src_mask = create_padding_mask(src, self.pad_idx)
        
        # Target mask (padding + look-ahead)
        tgt_seq_len = tgt.size(1)
        tgt_padding_mask = create_padding_mask(tgt, self.pad_idx)
        tgt_look_ahead_mask = create_look_ahead_mask(tgt_seq_len).to(tgt.device)
        
        # Combine masks: both must be True for attention
        tgt_mask = tgt_padding_mask & tgt_look_ahead_mask.unsqueeze(0).unsqueeze(0)
        
        return src_mask, tgt_mask
    
    def encode(self, src, src_mask=None):
        """
        Encode source sequence.
        
        Args:
            src: Source sequences (batch_size, src_seq_len)
            src_mask: Source padding mask
            
        Returns:
            encoder_output: Encoded representations
            encoder_attention: Encoder attention weights
        """
        if src_mask is None:
            src_mask = create_padding_mask(src, self.pad_idx)
            
        # Embedding + positional encoding
        src_embedded = self.src_embedding(src) * math.sqrt(self.d_model)
        src_embedded = self.pos_encoding(src_embedded)
        
        # Encode
        encoder_output, encoder_attention = self.encoder(src_embedded, src_mask)
        
        return encoder_output, encoder_attention
    
    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """
        Decode target sequence.
        
        Args:
            tgt: Target sequences (batch_size, tgt_seq_len)
            encoder_output: Encoder output
            src_mask: Source padding mask
            tgt_mask: Target mask
            
        Returns:
            decoder_output: Decoded representations
            self_attention: Decoder self-attention weights
            cross_attention: Decoder cross-attention weights
        """
        if tgt_mask is None:
            tgt_seq_len = tgt.size(1)
            tgt_mask = create_look_ahead_mask(tgt_seq_len).to(tgt.device)
            tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)
            
        # Embedding + positional encoding
        tgt_embedded = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_embedded = self.pos_encoding(tgt_embedded)
        
        # Decode
        decoder_output, self_attention, cross_attention = self.decoder(
            tgt_embedded, encoder_output, src_mask, tgt_mask
        )
        
        return decoder_output, self_attention, cross_attention
    
    def forward(self, src, tgt):
        """
        Complete forward pass through Transformer.
        
        Args:
            src: Source sequences (batch_size, src_seq_len)
            tgt: Target sequences (batch_size, tgt_seq_len)
            
        Returns:
            output: Logits for next token prediction (batch_size, tgt_seq_len-1, tgt_vocab_size)
            attention_weights: Dictionary of all attention weights
        """
        # Create masks
        src_mask, tgt_mask = self.create_masks(src, tgt)
        
        # Encode source
        encoder_output, encoder_attention = self.encode(src, src_mask)
        
        # Decode target (exclude last token for teacher forcing)
        decoder_input = tgt[:, :-1]
        decoder_tgt_mask = tgt_mask[:, :, :-1, :-1]
        
        decoder_output, self_attention, cross_attention = self.decode(
            decoder_input, encoder_output, src_mask, decoder_tgt_mask
        )
        
        # Project to vocabulary
        output = self.output_projection(decoder_output)
        
        return output, {
            'encoder_attention': encoder_attention,
            'decoder_self_attention': self_attention,
            'decoder_cross_attention': cross_attention
        }
    
    def get_model_info(self):
        """Get comprehensive model information and statistics."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        
        # Parameter storage only (float32 = 4 bytes per parameter);
        # excludes gradients, optimizer state and activations
        memory_usage_mb = total_params * 4 / (1024**2)
        
        # Component parameter breakdown
        component_params = {
            'embeddings': (sum(p.numel() for p in self.src_embedding.parameters()) + 
                          sum(p.numel() for p in self.tgt_embedding.parameters())),
            'positional_encoding': sum(p.numel() for p in self.pos_encoding.parameters()),
            'encoder': sum(p.numel() for p in self.encoder.parameters()),
            'decoder': sum(p.numel() for p in self.decoder.parameters()),
            'output_projection': sum(p.numel() for p in self.output_projection.parameters())
        }
        
        return {
            'architecture': self.architecture_config,
            'parameters': {
                'total': total_params,
                'trainable': trainable_params,
                'component_breakdown': component_params
            },
            'memory_usage_mb': memory_usage_mb,
            'model_size_mb': total_params * 4 / (1024**2)
        }

def test_complete_transformer():
    """Test the complete Transformer architecture."""
    print("🤖 Testing Complete Transformer Architecture...")
    
    # Model parameters
    src_vocab_size = 1000
    tgt_vocab_size = 1000
    d_model = 512
    n_heads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    d_ff = 2048
    max_seq_length = 100
    
    # Create Transformer
    transformer = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        d_ff=d_ff,
        max_seq_length=max_seq_length,
        use_learned_pe=False
    ).to(device)
    
    # Test input
    batch_size = 4
    src_seq_len = 20
    tgt_seq_len = 15
    
    src = torch.randint(1, src_vocab_size, (batch_size, src_seq_len)).to(device)
    tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_seq_len)).to(device)
    
    # Get model info
    model_info = transformer.get_model_info()
    
    print(f"✅ Complete Transformer Test Results:")
    print(f"   Total parameters: {model_info['parameters']['total']:,}")
    print(f"   Parameter memory (fp32): ~{model_info['memory_usage_mb']:.1f} MB")
    print(f"   Component breakdown:")
    for component, params in model_info['parameters']['component_breakdown'].items():
        percentage = (params / model_info['parameters']['total']) * 100
        print(f"     {component}: {params:,} ({percentage:.1f}%)")
    
    # Forward pass
    with torch.no_grad():
        output, attention_weights = transformer(src, tgt)
    
    print(f"   Input shapes - Source: {src.shape}, Target: {tgt.shape}")
    print(f"   Output shape: {output.shape}")
    print(f"   Expected shape: ({batch_size}, {tgt_seq_len-1}, {tgt_vocab_size})")
    
    # Analyze attention patterns
    encoder_attn = attention_weights['encoder_attention']
    decoder_self_attn = attention_weights['decoder_self_attention']
    decoder_cross_attn = attention_weights['decoder_cross_attention']
    
    print(f"   Attention Analysis:")
    print(f"     Encoder layers: {len(encoder_attn)}, each shape: {encoder_attn[0].shape}")
    print(f"     Decoder self-attention: {len(decoder_self_attn)}, each shape: {decoder_self_attn[0].shape}")
    print(f"     Decoder cross-attention: {len(decoder_cross_attn)}, each shape: {decoder_cross_attn[0].shape}")
    
    return transformer, model_info, (src, tgt), attention_weights

# Test complete Transformer
full_transformer, transformer_info, test_inputs, test_attention_weights = test_complete_transformer()

# Save architecture information
with open(notebook_results_dir / 'transformer_architecture.json', 'w') as f:
    json.dump(transformer_info, f, indent=2)

print("✅ Complete Transformer architecture implemented and tested!")
```
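`create_masks` relies on broadcasting to merge the padding mask with the look-ahead mask. The helpers `create_padding_mask` and `create_look_ahead_mask` are defined earlier in the notebook and are not shown in this section, so the sketch below uses hypothetical stand-ins (`padding_mask`, `look_ahead_mask`, `combined_target_mask`). It assumes the conventions implied by the `&` and `unsqueeze` calls in `create_masks`: boolean masks where `True` means "may attend", shaped `(batch, 1, 1, seq_len)` and `(seq_len, seq_len)`.

```python
import torch

def padding_mask(seq, pad_idx=0):
    # Hypothetical stand-in: True where the key position holds a real token.
    # Shape (batch, 1, 1, seq_len) so it broadcasts over heads and query positions.
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)

def look_ahead_mask(size):
    # Lower-triangular: query position i may attend to key positions 0..i
    return torch.tril(torch.ones(size, size, dtype=torch.bool))

def combined_target_mask(tgt, pad_idx=0):
    # (batch, 1, 1, T) & (1, 1, T, T) broadcasts to (batch, 1, T, T)
    causal = look_ahead_mask(tgt.size(1)).unsqueeze(0).unsqueeze(0)
    return padding_mask(tgt, pad_idx) & causal

tgt = torch.tensor([[1, 5, 6, 2, 0]])  # SOS, two content tokens, EOS, one PAD
mask = combined_target_mask(tgt)
print(mask.shape)        # torch.Size([1, 1, 5, 5])
print(mask[0, 0].int())  # rows are query positions; the PAD column (index 4) is never attended

# forward() slices the full mask to [:, :, :-1, :-1] for the shifted decoder input;
# the slice equals the mask built directly from tgt[:, :-1]
print(torch.equal(mask[:, :, :-1, :-1], combined_target_mask(tgt[:, :-1])))  # True
```

Because padding and causality are combined with a logical AND, a position is attended only if it is both a real token and not in the future.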

## 6. Training Pipeline and Copy Task

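A training pipeline that validates the implementation on a copy task. Reproducing an arbitrary input sequence requires attention, masking, and teacher forcing to work together, so the task is a quick end-to-end check of the architecture.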
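In each copy-task example the decoder reads the target shifted right by one token and is scored on the target shifted left: `Transformer.forward` feeds `tgt[:, :-1]` to the decoder, and the trainer compares the output with `tgt[:, 1:]`. A plain-Python illustration on a single toy example (the helper names are hypothetical and not part of the notebook):

```python
SOS, EOS, PAD = 1, 2, 0  # CopyTaskDataset's default special-token ids

def make_copy_pair(sequence, seq_len):
    """Build (src, tgt) like CopyTaskDataset: SOS + content + EOS, right-padded with PAD."""
    row = [SOS] + list(sequence) + [EOS]
    row = row + [PAD] * (seq_len - len(row))
    return row, list(row)  # for the copy task the target equals the source

def teacher_forcing_split(tgt):
    """Decoder input drops the last token; the labels drop the leading SOS."""
    return tgt[:-1], tgt[1:]

src, tgt = make_copy_pair([7, 4, 9], seq_len=8)
decoder_input, labels = teacher_forcing_split(tgt)
print(src)            # [1, 7, 4, 9, 2, 0, 0, 0]
print(decoder_input)  # [1, 7, 4, 9, 2, 0, 0]
print(labels)         # [7, 4, 9, 2, 0, 0, 0]
# Position i sees decoder_input[: i + 1] and is scored against labels[i];
# PAD labels are skipped by CrossEntropyLoss(ignore_index=0)
```

The first prediction is therefore made from SOS alone, and the last scored non-PAD label is EOS.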

```python
import time

class CopyTaskDataset(Dataset):
    """
    Copy task dataset for testing Transformer implementation.
    
    The model learns to copy input sequences, which tests:
    - Attention mechanisms
    - Sequence modeling
    - Teacher forcing during training
    - Auto-regressive generation during inference
    """
    
    def __init__(self, vocab_size, seq_len, num_samples, sos_token=1, eos_token=2, pad_token=0):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
        self.sos_token = sos_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        
        self.data = self._generate_data()
        
        print(f"📝 Copy Task Dataset Created:")
        print(f"   Vocabulary size: {vocab_size}")
        print(f"   Sequence length: {seq_len}")
        print(f"   Number of samples: {num_samples}")
        print(f"   Special tokens - SOS: {sos_token}, EOS: {eos_token}, PAD: {pad_token}")
    
    def _generate_data(self):
        """Generate copy task data with variable sequence lengths."""
        data = []
        
        for _ in range(self.num_samples):
            # Random content length in [3, seq_len - 3] (randint's upper bound is exclusive),
            # so SOS + content + EOS always fits within seq_len
            seq_length = torch.randint(3, self.seq_len - 2, (1,)).item()
            
            # Generate random sequence (excluding special tokens 0, 1, 2)
            sequence = torch.randint(3, self.vocab_size, (seq_length,))
            
            # Create source: SOS + sequence + EOS
            src = torch.cat([
                torch.tensor([self.sos_token]), 
                sequence, 
                torch.tensor([self.eos_token])
            ])
            
            # Create target: SOS + sequence + EOS (same as source for copy task)
            tgt = torch.cat([
                torch.tensor([self.sos_token]), 
                sequence, 
                torch.tensor([self.eos_token])
            ])
            
            # Pad to fixed length
            src_len = len(src)
            tgt_len = len(tgt)
            
            if src_len < self.seq_len:
                src = torch.cat([src, torch.tensor([self.pad_token] * (self.seq_len - src_len))])
            
            if tgt_len < self.seq_len:
                tgt = torch.cat([tgt, torch.tensor([self.pad_token] * (self.seq_len - tgt_len))])
            
            # Ensure exact length
            src = src[:self.seq_len]
            tgt = tgt[:self.seq_len]
            
            data.append((src, tgt))
        
        return data
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        return self.data[idx]
    
    def get_sample_info(self, idx):
        """Get human-readable information about a sample."""
        src, tgt = self.data[idx]
        
        # Find actual sequence (between SOS and EOS/PAD)
        src_seq = []
        tgt_seq = []
        
        for token in src:
            if token == self.eos_token or token == self.pad_token:
                break
            if token != self.sos_token:
                src_seq.append(int(token))
        
        for token in tgt:
            if token == self.eos_token or token == self.pad_token:
                break
            if token != self.sos_token:
                tgt_seq.append(int(token))
        
        return {
            'source_sequence': src_seq,
            'target_sequence': tgt_seq,
            'source_full': src.tolist(),
            'target_full': tgt.tolist(),
            'sequence_length': len(src_seq)
        }

class TransformerTrainer:
    """
    Trainer for the Transformer on sequence-to-sequence data.
    
    Includes:
    - Warmup learning rate scheduling
    - Gradient accumulation and clipping
    - Early stopping on validation loss
    - Loss, accuracy and perplexity tracking
    - Validation after every epoch
    """
    
    def __init__(self, model, device, pad_token=0):
        self.model = model
        self.device = device
        self.pad_token = pad_token
        
        # Training history
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'learning_rates': [], 'epoch_times': [],
            'train_perplexity': [], 'val_perplexity': []
        }
        
        # Best model tracking
        self.best_val_loss = float('inf')
        self.best_model_state = None
        self.patience_counter = 0
        
        print(f"🚂 Transformer Trainer Initialized")
        print(f"   Device: {device}")
        print(f"   Pad token: {pad_token}")
    
    def calculate_accuracy(self, outputs, targets, pad_token=0):
        """Calculate token-level accuracy ignoring padding tokens."""
        predictions = torch.argmax(outputs, dim=-1)
        
        # Create mask to ignore padding tokens
        mask = (targets != pad_token)
        
        # Calculate accuracy
        correct = (predictions == targets) & mask
        total = mask.sum()
        
        if total == 0:
            # Return a tensor so callers can still call .item()
            return torch.tensor(0.0, device=outputs.device)
        
        return correct.sum().float() / total.float()
    
    def calculate_perplexity(self, loss):
        """Calculate perplexity from cross-entropy loss."""
        return torch.exp(torch.tensor(loss))
    
    def train_epoch(self, dataloader, optimizer, criterion, accumulation_steps=1):
        """Train for one epoch with gradient accumulation."""
        self.model.train()
        total_loss = 0
        total_accuracy = 0
        num_batches = 0
        
        optimizer.zero_grad()
        
        progress_bar = tqdm(dataloader, desc="Training", leave=False)
        for batch_idx, (src, tgt) in enumerate(progress_bar):
            src, tgt = src.to(self.device), tgt.to(self.device)
            
            # Forward pass
            outputs, _ = self.model(src, tgt)
            
            # Calculate loss (exclude SOS token from target)
            targets = tgt[:, 1:]  # Remove SOS token
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
            
            # Scale loss for gradient accumulation
            loss = loss / accumulation_steps
            loss.backward()
            
            # Gradient accumulation
            if (batch_idx + 1) % accumulation_steps == 0:
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad()
            
            # Calculate metrics
            accuracy = self.calculate_accuracy(outputs, targets, self.pad_token)
            
            total_loss += loss.item() * accumulation_steps
            total_accuracy += accuracy.item()
            num_batches += 1
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{loss.item() * accumulation_steps:.4f}',
                'acc': f'{accuracy.item():.4f}'
            })
        
        # Final gradient step if needed
        if len(dataloader) % accumulation_steps != 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
        
        avg_loss = total_loss / num_batches
        avg_accuracy = total_accuracy / num_batches
        
        return avg_loss, avg_accuracy
    
    def evaluate(self, dataloader, criterion):
        """Evaluate the model on given dataloader."""
        self.model.eval()
        total_loss = 0
        total_accuracy = 0
        num_batches = 0
        
        with torch.no_grad():
            progress_bar = tqdm(dataloader, desc="Evaluating", leave=False)
            for src, tgt in progress_bar:
                src, tgt = src.to(self.device), tgt.to(self.device)
                
                outputs, _ = self.model(src, tgt)
                
                targets = tgt[:, 1:]
                loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
                
                accuracy = self.calculate_accuracy(outputs, targets, self.pad_token)
                
                total_loss += loss.item()
                total_accuracy += accuracy.item()
                num_batches += 1
                
                # Update progress bar
                progress_bar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{accuracy.item():.4f}'
                })
        
        avg_loss = total_loss / num_batches
        avg_accuracy = total_accuracy / num_batches
        
        return avg_loss, avg_accuracy
    
    def train(self, train_loader, val_loader, num_epochs, lr=0.0001, patience=5, 
              accumulation_steps=1, warmup_steps=4000):
        """
        Complete training loop with advanced features.
        
        Args:
            train_loader: Training data loader
            val_loader: Validation data loader
            num_epochs: Number of training epochs
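            lr: Peak learning rate, reached after warmup_steps optimizer updates
            patience: Early stopping patience
            accumulation_steps: Gradient accumulation steps
            warmup_steps: Learning rate warmup steps (counted in optimizer updates)
        """
        print(f"🚀 Starting Transformer Training...")
        print(f"   Epochs: {num_epochs}")
        print(f"   Peak learning rate: {lr}")
        print(f"   Warmup steps: {warmup_steps}")
        print(f"   Patience: {patience}")
        print(f"   Accumulation steps: {accumulation_steps}")
        
        # Loss function
        criterion = nn.CrossEntropyLoss(ignore_index=self.pad_token)
        
        # Optimizer: betas and eps follow the paper (which used plain Adam);
        # AdamW's decoupled weight decay is an addition
        optimizer = optim.AdamW(
            self.model.parameters(), 
            lr=lr, 
            betas=(0.9, 0.98), 
            eps=1e-9,
            weight_decay=0.01
        )
        
        # Learning rate schedule: the paper's linear warmup followed by inverse-square-root
        # decay, rescaled so the multiplier equals 1.0 at step == warmup_steps (`lr` is the peak).
        # The schedule is defined per optimizer update, so it is stepped per update, not per epoch.
        def lr_lambda(step):
            step = max(1, step)
            return min(step / warmup_steps, (warmup_steps / step) ** 0.5)
        
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        
        class ScheduledOptimizer:
            """Minimal wrapper: every optimizer update also advances the LR schedule."""
            def __init__(self, optimizer, scheduler):
                self.optimizer = optimizer
                self.scheduler = scheduler
            
            def zero_grad(self):
                self.optimizer.zero_grad()
            
            def step(self):
                self.optimizer.step()
                self.scheduler.step()
        
        scheduled_optimizer = ScheduledOptimizer(optimizer, scheduler)
        
        # Training loop
        start_time = time.time()
        
        for epoch in range(num_epochs):
            epoch_start_time = time.time()
            
            # Training (the LR schedule advances inside train_epoch, once per update)
            train_loss, train_acc = self.train_epoch(
                train_loader, scheduled_optimizer, criterion, accumulation_steps
            )
            
            # Validation
            val_loss, val_acc = self.evaluate(val_loader, criterion)
            
            # Calculate perplexity
            train_perplexity = self.calculate_perplexity(train_loss)
            val_perplexity = self.calculate_perplexity(val_loss)
            
            # Learning rate in effect at the end of this epoch
            current_lr = optimizer.param_groups[0]['lr']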
            
            # Calculate epoch time
            epoch_time = time.time() - epoch_start_time
            
            # Store history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_acc'].append(val_acc)
            self.history['train_perplexity'].append(float(train_perplexity))
            self.history['val_perplexity'].append(float(val_perplexity))
            self.history['learning_rates'].append(current_lr)
            self.history['epoch_times'].append(epoch_time)
            
            # Early stopping check
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_model_state = copy.deepcopy(self.model.state_dict())
                self.patience_counter = 0
                improvement = "✅"
            else:
                self.patience_counter += 1
                improvement = "⏸️" if self.patience_counter >= patience else ""
            
            # Print progress
            print(f"Epoch {epoch+1:2d}/{num_epochs}: "
                  f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, "
                  f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}, "
                  f"Perplexity={float(val_perplexity):.2f}, "
                  f"LR={current_lr:.2e}, Time={epoch_time:.1f}s {improvement}")
            
            # Early stopping
            if self.patience_counter >= patience:
                print(f"🛑 Early stopping after {epoch+1} epochs")
                break
        
        # Load best model
        if self.best_model_state:
            self.model.load_state_dict(self.best_model_state)
            print(f"🏆 Best model loaded (Val Loss: {self.best_val_loss:.4f})")
        
        total_time = time.time() - start_time
        print(f"✅ Training completed in {total_time:.1f}s")
        print(f"   Average epoch time: {np.mean(self.history['epoch_times']):.1f}s")
        
        return self.best_val_loss

# Create copy task datasets
def create_copy_task_data():
    """Create copy task datasets for training and validation."""
    print("📝 Creating Copy Task Datasets...")
    
    vocab_size = 100
    seq_len = 20
    train_samples = 5000
    val_samples = 1000
    
    train_dataset = CopyTaskDataset(vocab_size, seq_len, train_samples)
    val_dataset = CopyTaskDataset(vocab_size, seq_len, val_samples)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
    
    print(f"✅ Copy Task Data Created:")
    print(f"   Training batches: {len(train_loader)}")
    print(f"   Validation batches: {len(val_loader)}")
    
    # Display sample data
    sample_idx = 0
    sample_info = train_dataset.get_sample_info(sample_idx)
    print(f"   Sample {sample_idx}:")
    print(f"     Source sequence: {sample_info['source_sequence']}")
    print(f"     Target sequence: {sample_info['target_sequence']}")
    print(f"     Sequence length: {sample_info['sequence_length']}")
    
    return train_dataset, val_dataset, train_loader, val_loader, vocab_size, seq_len

# Create datasets and loaders
copy_train_dataset, copy_val_dataset, copy_train_loader, copy_val_loader, copy_vocab_size, copy_seq_len = create_copy_task_data()

# Create smaller Transformer for training
def create_training_transformer():
    """Create appropriately sized Transformer for copy task training."""
    print("🔧 Creating Training Transformer...")
    
    training_transformer = Transformer(
        src_vocab_size=copy_vocab_size,
        tgt_vocab_size=copy_vocab_size,
        d_model=256,
        n_heads=8,
        num_encoder_layers=3,
        num_decoder_layers=3,
        d_ff=1024,
        max_seq_length=copy_seq_len,
        dropout=0.1,
        pad_idx=0,
        use_learned_pe=False
    ).to(device)
    
    model_info = training_transformer.get_model_info()
    print(f"✅ Training Transformer Created:")
    print(f"   Parameters: {model_info['parameters']['total']:,}")
    print(f"   Parameter memory (fp32): ~{model_info['memory_usage_mb']:.1f} MB")
    
    return training_transformer

# Create training model
training_transformer = create_training_transformer()

# Train the model
print("\n🚂 STARTING TRANSFORMER TRAINING ON COPY TASK")
print("=" * 60)

trainer = TransformerTrainer(training_transformer, device, pad_token=0)

# Start training
import time
training_start_time = time.time()

best_val_loss = trainer.train(
    train_loader=copy_train_loader,
    val_loader=copy_val_loader,
    num_epochs=15,
    lr=0.0001,
    patience=7,
    accumulation_steps=1,
    warmup_steps=1000
)

training_end_time = time.time()
total_training_time = training_end_time - training_start_time

print(f"\n🎯 Training Summary:")
print(f"   Best validation loss: {best_val_loss:.4f}")
print(f"   Total training time: {total_training_time:.1f}s")
print(f"   Last-epoch validation accuracy: {trainer.history['val_acc'][-1]:.4f}")

# Save trained model
model_save_path = notebook_results_dir / 'trained_transformer.pth'
torch.save({
    'model_state_dict': training_transformer.state_dict(),
    'model_config': training_transformer.architecture_config,
    'training_history': trainer.history,
    'best_val_loss': best_val_loss,
    'total_training_time': total_training_time
}, model_save_path)

print(f"💾 Model saved to: {model_save_path}")

print("✅ Transformer training completed successfully!")
```
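The warmup schedule is defined in terms of optimizer updates: the learning rate grows linearly for `warmup_steps` updates and then decays as the inverse square root of the step. The plain-Python sketch below (function names are hypothetical) compares the paper's formula with a rescaled variant whose peak equals a chosen learning rate; the rescaling is an illustrative convention for using `lr` as the peak value, not something the paper prescribes.

```python
import math

def noam_lr(step, d_model, warmup_steps):
    """Schedule from Vaswani et al.: d_model^-0.5 * min(step^-0.5, step * warmup^-1.5).

    `step` counts optimizer updates, not epochs.
    """
    step = max(1, step)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

def peak_normalized_lr(step, peak_lr, warmup_steps):
    """Same warmup / inverse-square-root shape, rescaled so step == warmup_steps gives peak_lr."""
    step = max(1, step)
    return peak_lr * min(step / warmup_steps, math.sqrt(warmup_steps / step))

# The paper's peak for d_model=512, warmup=4000 is (512 * 4000) ** -0.5
print(f"{noam_lr(4000, 512, 4000):.3e}")  # 6.988e-04

# Shape check: linear warmup up to step 1000, then decay proportional to 1/sqrt(step)
for step in (500, 1000, 4000):
    print(step, peak_normalized_lr(step, peak_lr=1e-4, warmup_steps=1000))

# Using the raw formula as a multiplier on lr=1e-4 and advancing it only
# 15 times (once per epoch) leaves the learning rate vanishingly small:
print(f"{1e-4 * noam_lr(15, 256, 1000):.2e}")  # 2.96e-09
```

The two forms differ only by a constant factor, so choosing between them is a matter of whether `lr` should mean "peak learning rate" or "scale on the paper's formula".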

## 7. Training Analysis and Visualization

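Plots and summary statistics for the training run: loss, token-level accuracy, perplexity, learning rate, time per epoch, and the gap between validation and training loss.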
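`calculate_perplexity` exponentiates the average cross-entropy loss. Since `nn.CrossEntropyLoss` uses the natural log, perplexity reads as an effective number of equally likely choices per token, which gives a reference point for the copy task's 100-token vocabulary. A short plain-Python sketch (`perplexity` is a hypothetical helper) that also shows a caveat of the epoch averages:

```python
import math

def perplexity(mean_ce_loss):
    """Perplexity = exp(mean per-token cross-entropy in nats)."""
    return math.exp(mean_ce_loss)

# A model that spreads probability uniformly over a 100-token vocabulary
# has loss ln(100) and therefore perplexity 100
print(round(perplexity(math.log(100)), 6))  # 100.0
print(round(perplexity(0.05), 4))           # 1.0513

# Averaging per-batch mean losses with equal weight (as train_epoch and evaluate do)
# differs from a token-weighted mean when batches hold different numbers of non-pad tokens
batch_losses = [0.2, 1.0]
batch_tokens = [300, 100]
equal_weight = sum(batch_losses) / len(batch_losses)
token_weight = sum(loss * n for loss, n in zip(batch_losses, batch_tokens)) / sum(batch_tokens)
print(round(equal_weight, 4), round(token_weight, 4))  # 0.6 0.4
```

For the copy task the batches are similar in size, so the equal-weight average is usually a close approximation.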

```python
def plot_training_history(trainer, save_path=None):
    """Create comprehensive training history visualization."""
    history = trainer.history
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Loss curves
    axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Training Loss', linewidth=2)
    axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training and Validation Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Accuracy curves
    axes[0, 1].plot(epochs, history['train_acc'], 'b-', label='Training Accuracy', linewidth=2)
    axes[0, 1].plot(epochs, history['val_acc'], 'r-', label='Validation Accuracy', linewidth=2)
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].set_title('Training and Validation Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # Perplexity curves
    axes[0, 2].plot(epochs, history['train_perplexity'], 'b-', label='Training Perplexity', linewidth=2)
    axes[0, 2].plot(epochs, history['val_perplexity'], 'r-', label='Validation Perplexity', linewidth=2)
    axes[0, 2].set_xlabel('Epoch')
    axes[0, 2].set_ylabel('Perplexity')
    axes[0, 2].set_title('Training and Validation Perplexity')
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)
    
    # Learning rate schedule
    axes[1, 0].plot(epochs, history['learning_rates'], 'g-', linewidth=2)
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].set_title('Learning Rate Schedule')
    axes[1, 0].set_yscale('log')
    axes[1, 0].grid(True, alpha=0.3)
    
    # Epoch times
    axes[1, 1].plot(epochs, history['epoch_times'], 'm-', linewidth=2)
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Time (seconds)')
    axes[1, 1].set_title('Training Time per Epoch')
    axes[1, 1].grid(True, alpha=0.3)
    
    # Overfitting indicator
    loss_gap = [val - train for train, val in zip(history['train_loss'], history['val_loss'])]
    axes[1, 2].plot(epochs, loss_gap, 'orange', linewidth=2)
    axes[1, 2].axhline(y=0, color='black', linestyle='-', alpha=0.5)
    axes[1, 2].set_xlabel('Epoch')
    axes[1, 2].set_ylabel('Val Loss - Train Loss')
    axes[1, 2].set_title('Overfitting Indicator')
    axes[1, 2].grid(True, alpha=0.3)
    
    plt.suptitle('Transformer Training History', fontsize=16)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()

# Plot training history
plot_training_history(trainer, notebook_results_dir / 'training_history.png')

def analyze_training_metrics(trainer):
    """Analyze training metrics and provide insights."""
    history = trainer.history
    
    print("📊 Training Metrics Analysis:")
    print("=" * 50)
    
    # Final metrics
    final_train_loss = history['train_loss'][-1]
    final_val_loss = history['val_loss'][-1]
    final_train_acc = history['train_acc'][-1]
    final_val_acc = history['val_acc'][-1]
    
    print(f"📈 Final Performance:")
    print(f"   Training Loss: {final_train_loss:.4f}")
    print(f"   Validation Loss: {final_val_loss:.4f}")
    print(f"   Training Accuracy: {final_train_acc:.4f}")
    print(f"   Validation Accuracy: {final_val_acc:.4f}")
    
    # Convergence analysis
    loss_improvement = history['val_loss'][0] - history['val_loss'][-1]
    acc_improvement = history['val_acc'][-1] - history['val_acc'][0]
    
    print(f"\n📊 Learning Progress:")
    print(f"   Loss improvement: {loss_improvement:.4f}")
    print(f"   Accuracy improvement: {acc_improvement:.4f}")
    print(f"   Best validation loss: {trainer.best_val_loss:.4f}")
    
    # Overfitting analysis
    final_gap = final_val_loss - final_train_loss
    print(f"\n🔍 Overfitting Analysis:")
    print(f"   Final train-val gap: {final_gap:.4f}")
    if final_gap < 0.1:
        print("   Status: ✅ No significant overfitting")
    elif final_gap < 0.3:
        print("   Status: ⚠️ Mild overfitting")
    else:
        print("   Status: ❌ Significant overfitting")
    
    # Training efficiency
    avg_epoch_time = np.mean(history['epoch_times'])
    total_time = sum(history['epoch_times'])
    
    print(f"\n⏱️ Training Efficiency:")
    print(f"   Average epoch time: {avg_epoch_time:.1f}s")
    print(f"   Total training time: {total_time:.1f}s")
    print(f"   Training samples per second (epoch time includes validation): {len(copy_train_dataset) * len(history['train_loss']) / total_time:.1f}")
    
    return {
        'final_metrics': {
            'train_loss': final_train_loss,
            'val_loss': final_val_loss,
            'train_acc': final_train_acc,
            'val_acc': final_val_acc
        },
        'improvements': {
            'loss_improvement': loss_improvement,
            'acc_improvement': acc_improvement
        },
        'overfitting': {
            'train_val_gap': final_gap
        },
        'efficiency': {
            'avg_epoch_time': avg_epoch_time,
            'total_time': total_time
        }
    }

# Analyze training metrics
training_analysis = analyze_training_metrics(trainer)

print("✅ Training analysis completed!")
```
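The accuracy curves above are token-level: `calculate_accuracy` counts correct non-padding positions. For the copy task, a sequence-level exact-match rate (every token right) is stricter and is what the per-sample test below reports. A toy comparison in plain Python (hypothetical helper names):

```python
def token_accuracy(predictions, targets, pad=0):
    """Share of non-pad target positions predicted correctly (the idea behind calculate_accuracy)."""
    correct = total = 0
    for pred_row, tgt_row in zip(predictions, targets):
        for p, t in zip(pred_row, tgt_row):
            if t != pad:
                total += 1
                correct += int(p == t)
    return correct / total if total else 0.0

def exact_match(predictions, targets, pad=0):
    """Share of sequences whose non-pad positions are all predicted correctly."""
    if not targets:
        return 0.0
    hits = sum(
        all(p == t for p, t in zip(pred_row, tgt_row) if t != pad)
        for pred_row, tgt_row in zip(predictions, targets)
    )
    return hits / len(targets)

targets     = [[7, 4, 9, 2, 0], [5, 5, 8, 3, 2]]
predictions = [[7, 4, 9, 2, 6], [5, 5, 8, 3, 7]]  # row 1's PAD slot is ignored; row 2 misses EOS
print(round(token_accuracy(predictions, targets), 4))  # 0.8889
print(exact_match(predictions, targets))               # 0.5
```

A single wrong token barely moves token accuracy but fails the whole sequence, which is why the two numbers can diverge noticeably on longer sequences.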

## 8. Model Testing and Copy Task Evaluation

Sequence-level testing of the trained Transformer on held-out copy-task samples: a sample counts as correct only if every copied token is predicted correctly.
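The evaluation below calls `model(src_batch, tgt_batch)`, which feeds the ground-truth target to the decoder (teacher forcing), so it measures next-token prediction given a correct prefix rather than free-running generation. A greedy decoding loop is a common way to test true auto-regressive copying. A minimal sketch, assuming the `encode`/`decode`/`output_projection` interface of the `Transformer` class above and a boolean source mask of shape `(batch, 1, 1, src_len)` with `True` meaning "attend" (the layout `create_masks` appears to rely on); `greedy_decode` is a hypothetical helper, not part of the notebook:

```python
import torch

@torch.no_grad()
def greedy_decode(model, src, max_len, sos_token=1, eos_token=2, pad_idx=0):
    """Greedy decoding for a model with encode/decode/output_projection.

    src: (batch_size, src_seq_len) token ids.
    Returns (batch_size, <= max_len) token ids, starting with SOS.
    """
    model.eval()
    # Assumed mask convention: True = attend, shape (batch, 1, 1, src_seq_len)
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
    memory, _ = model.encode(src, src_mask)

    ys = torch.full((src.size(0), 1), sos_token, dtype=torch.long, device=src.device)
    finished = torch.zeros(src.size(0), dtype=torch.bool, device=src.device)
    for _ in range(max_len - 1):
        # decode() builds the look-ahead mask itself when tgt_mask is None
        out, _, _ = model.decode(ys, memory, src_mask)
        # Pick the most likely next token from the last decoder position
        next_token = model.output_projection(out[:, -1]).argmax(dim=-1)
        # Sequences that already emitted EOS only receive padding
        next_token = torch.where(finished, torch.full_like(next_token, pad_idx), next_token)
        ys = torch.cat([ys, next_token.unsqueeze(1)], dim=1)
        finished |= next_token == eos_token
        if finished.all():
            break
    return ys
```

For instance, `greedy_decode(training_transformer, src.unsqueeze(0).to(device), max_len=copy_seq_len)` would generate one sequence; comparing its tokens between SOS and EOS with the source gives a free-running copy accuracy that is typically lower than the teacher-forced one, since early mistakes are fed back into later steps.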

```python
def test_copy_task_performance(model, dataset, device, num_samples=10):
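    """
    Per-sample copy-task evaluation under teacher forcing.
    
    The ground-truth target is fed to the decoder, so this measures next-token prediction
    given a correct prefix; a sample counts as correct only if every content token matches.
    """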
    print("🧪 Testing Copy Task Performance:")
    print("=" * 50)
    
    model.eval()
    
    correct_copies = 0
    total_samples = 0
    detailed_results = []
    
    for i in range(num_samples):
        src, tgt = dataset[i]
        src_batch = src.unsqueeze(0).to(device)
        tgt_batch = tgt.unsqueeze(0).to(device)
        
        with torch.no_grad():
            outputs, attention_weights = model(src_batch, tgt_batch)
            predictions = torch.argmax(outputs, dim=-1)
            
            # Get sample info
            sample_info = dataset.get_sample_info(i)
            
            # Reference sequence: the target without its trailing EOS (token id 2)
            target_full = sample_info['target_sequence']
            if target_full and target_full[-1] == 2:
                target_seq = target_full[:-1]
            else:
                target_seq = target_full
            
            # Keep one prediction per reference token so the two sequences line up
            pred_seq = predictions.squeeze(0).cpu().numpy()
            pred_tokens = [int(token) for token in pred_seq[:len(target_seq)]]
            source_seq = sample_info['source_sequence']
            
            # Check if copy is correct
            is_correct = (pred_tokens == target_seq)
            if is_correct:
                correct_copies += 1
            
            total_samples += 1
            
            # Store detailed results
            detailed_results.append({
                'sample_id': i,
                'source': source_seq,
                'target': target_seq,
                'predicted': pred_tokens,
                'correct': is_correct,
                'attention_weights': attention_weights
            })
            
            # Display sample
            print(f"\nSample {i+1}:")
            print(f"  Source:    {source_seq}")
            print(f"  Target:    {target_seq}")
            print(f"  Predicted: {pred_tokens}")
            print(f"  Correct:   {'✅' if is_correct else '❌'}")
    
    accuracy = correct_copies / max(total_samples, 1)  # guard against an empty evaluation
    print(f"\n🎯 Copy Task Results:")
    print(f"   Accuracy: {accuracy:.2%} ({correct_copies}/{total_samples})")
    print(f"   Perfect copies: {correct_copies}")
    print(f"   Failed copies: {total_samples - correct_copies}")
    
    return accuracy, detailed_results

# Test copy task performance
copy_accuracy, copy_results = test_copy_task_performance(
    training_transformer, copy_val_dataset, device, num_samples=10
)

def analyze_copy_task_errors(results):
    """Analyze patterns in copy task errors."""
    print("\n🔍 Copy Task Error Analysis:")
    print("=" * 40)
    
    correct_samples = [r for r in results if r['correct']]
    incorrect_samples = [r for r in results if not r['correct']]
    
    print(f"📊 Success Rate Analysis:")
    print(f"   Correct copies: {len(correct_samples)}")
    print(f"   Incorrect copies: {len(incorrect_samples)}")
    
    if incorrect_samples:
        print(f"\n❌ Error Patterns:")
        
        # Analyze error types
        length_errors = 0
        position_errors = []
        
        for sample in incorrect_samples:
            target = sample['target']
            predicted = sample['predicted']
            
            # Length analysis
            if len(predicted) != len(target):
                length_errors += 1
            
            # Position error analysis over the positions both sequences share
            for pos in range(min(len(target), len(predicted))):
                if target[pos] != predicted[pos]:
                    position_errors.append(pos)
        
        print(f"   Length errors: {length_errors}")
        if position_errors:
            position_error_freq = Counter(position_errors)
            print(f"   Most error-prone positions: {position_error_freq.most_common(3)}")
        
        # Show detailed error examples
        print(f"\n🔎 Error Examples:")
        for i, sample in enumerate(incorrect_samples[:3]):
            print(f"   Example {i+1}:")
            print(f"     Source:    {sample['source']}")
            print(f"     Expected:  {sample['target']}")
            print(f"     Predicted: {sample['predicted']}")
    else:
        print("   🎉 No errors found in sample!")
    
    return len(correct_samples), len(incorrect_samples)

# Analyze copy task errors
correct_count, error_count = analyze_copy_task_errors(copy_results)

print("✅ Copy task evaluation completed!")
```
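
The evaluation above passes the ground-truth target into the decoder (teacher forcing). Each position is therefore predicted from the correct prefix. Free-running greedy decoding instead feeds the model's own predictions back in. It is a stricter test of whether the model can actually copy. Below is a minimal sketch. The `step_fn` interface and the SOS id of `1` are assumptions, while EOS id `2` matches the check used in the evaluation code. A toy step function stands in for the model so the example runs on its own.

```python
import torch


def greedy_decode(step_fn, src, sos_id, eos_id, max_len):
    """Free-running greedy decoding for a single source sequence.

    step_fn(src, prefix) -> logits of shape (vocab,) for the next token.
    This callable is a hypothetical stand-in for one decoder step of the model.
    """
    prefix = [sos_id]
    for _ in range(max_len):
        logits = step_fn(src, torch.tensor(prefix))
        next_id = int(torch.argmax(logits))
        prefix.append(next_id)          # feed the model's own prediction back in
        if next_id == eos_id:
            break
    return prefix[1:]  # drop SOS; includes EOS if it was generated


def toy_copy_step(src, prefix, vocab_size=10, eos_id=2):
    """Toy step function that copies the source, then emits EOS."""
    pos = len(prefix) - 1                      # tokens generated so far
    target_id = int(src[pos]) if pos < len(src) else eos_id
    logits = torch.full((vocab_size,), -1e9)
    logits[target_id] = 0.0
    return logits


src = torch.tensor([5, 7, 3])
print(greedy_decode(toy_copy_step, src, sos_id=1, eos_id=2, max_len=10))  # [5, 7, 3, 2]
```

With the notebook's model, `step_fn` might be something like `lambda s, p: model(s.unsqueeze(0), p.unsqueeze(0))[0][0, -1]`. Run it under `model.eval()` and `torch.no_grad()` with tensors on the right device. This assumes the forward pass accepts a partial target prefix and returns logits of shape `(batch, tgt_len, vocab)`, as the evaluation code above suggests. Re-running the whole decoder on the growing prefix at each step costs time quadratic in the output length, but it keeps the sketch simple.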

## 9. Attention Pattern Visualization and Analysis

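Visualize the encoder self-attention, decoder self-attention, and decoder cross-attention of the trained model on a single validation sample, then summarize per-head entropy and cross-attention alignment.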
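
Two of the summary panels rest on simple properties of softmax attention. Below is a small NumPy sketch of both, assuming the returned weights are post-softmax (each query row sums to 1). First, the per-query Shannon entropy measures how focused a head is: 0 for one-hot attention, `log(num_keys)` for uniform attention. Second, because rows sum to 1, the plain mean of an attention matrix is always `1/num_keys`, so by itself it cannot distinguish layers. The function name `row_entropy` is illustrative.

```python
import numpy as np


def row_entropy(attn: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Shannon entropy (nats) of each query's attention distribution.

    attn: array of shape (num_queries, num_keys) whose rows sum to 1.
    """
    return -(attn * np.log(attn + eps)).sum(axis=-1)


n = 4
uniform = np.full((n, n), 1.0 / n)
one_hot = np.eye(n)

print(row_entropy(uniform))  # each entry approximately log(4) = 1.386
print(row_entropy(one_hot))  # each entry approximately 0 (tiny negatives come from eps)

# Softmax rows sum to 1, so the mean weight is 1/num_keys whatever the scores are
rng = np.random.default_rng(0)
scores = rng.normal(size=(n, n))
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
print(weights.mean())  # 0.25 up to floating-point rounding
```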

```python
def visualize_transformer_attention(model, src, tgt, sample_idx=0):
    """Create comprehensive attention visualization for trained Transformer."""
    print("🎨 Visualizing Transformer Attention Patterns...")
    
    model.eval()
    
    with torch.no_grad():
        outputs, attention_weights = model(src, tgt)
        
        # Extract attention weights
        encoder_attn = attention_weights['encoder_attention']
        decoder_self_attn = attention_weights['decoder_self_attention'] 
        decoder_cross_attn = attention_weights['decoder_cross_attention']
        
        # Focus on specified sample
        sample_encoder_attn = [layer_attn[sample_idx] for layer_attn in encoder_attn]
        sample_decoder_self_attn = [layer_attn[sample_idx] for layer_attn in decoder_self_attn]
        sample_decoder_cross_attn = [layer_attn[sample_idx] for layer_attn in decoder_cross_attn]
        
        # Create comprehensive visualization
        fig = plt.figure(figsize=(20, 16))
        
        # 1. Encoder Self-Attention (Layer 0, averaged over all heads)
        plt.subplot(4, 4, 1)
        enc_attn_avg = sample_encoder_attn[0].mean(0).cpu().numpy()
        sns.heatmap(enc_attn_avg, cmap='Blues', cbar=True, square=True)
        plt.title('Encoder Self-Attention\n(Layer 0, All Heads Avg)')
        plt.xlabel('Key Position')
        plt.ylabel('Query Position')
        
        # 2-4. Encoder Self-Attention (individual heads)
        for head in range(min(3, sample_encoder_attn[0].size(0))):  # up to 3 heads
            plt.subplot(4, 4, 2 + head)
            enc_head_attn = sample_encoder_attn[0][head].cpu().numpy()
            sns.heatmap(enc_head_attn, cmap='Blues', cbar=True, square=True)
            plt.title(f'Encoder Head {head + 1}')
            plt.xlabel('Key Position')
            plt.ylabel('Query Position')
        
        # 5. Decoder Self-Attention (Layer 0, average)
        plt.subplot(4, 4, 5)
        dec_self_attn_avg = sample_decoder_self_attn[0].mean(0).cpu().numpy()
        sns.heatmap(dec_self_attn_avg, cmap='Reds', cbar=True, square=True)
        plt.title('Decoder Self-Attention\n(Layer 0, All Heads Avg)')
        plt.xlabel('Key Position')
        plt.ylabel('Query Position')
        
        # 6-8. Decoder Self-Attention (individual heads)
        for head in range(min(3, sample_decoder_self_attn[0].size(0))):  # up to 3 heads
            plt.subplot(4, 4, 6 + head)
            dec_self_head_attn = sample_decoder_self_attn[0][head].cpu().numpy()
            sns.heatmap(dec_self_head_attn, cmap='Reds', cbar=True, square=True)
            plt.title(f'Decoder Self Head {head + 1}')
            plt.xlabel('Key Position')
            plt.ylabel('Query Position')
        
        # 9. Decoder Cross-Attention (Layer 0, average)
        plt.subplot(4, 4, 9)
        dec_cross_attn_avg = sample_decoder_cross_attn[0].mean(0).cpu().numpy()
        sns.heatmap(dec_cross_attn_avg, cmap='Greens', cbar=True)
        plt.title('Decoder Cross-Attention\n(Layer 0, All Heads Avg)')
        plt.xlabel('Encoder Position')
        plt.ylabel('Decoder Position')
        
        # 10-12. Decoder Cross-Attention (individual heads)
        for head in range(min(3, sample_decoder_cross_attn[0].size(0))):  # up to 3 heads
            plt.subplot(4, 4, 10 + head)
            dec_cross_head_attn = sample_decoder_cross_attn[0][head].cpu().numpy()
            sns.heatmap(dec_cross_head_attn, cmap='Greens', cbar=True)
            plt.title(f'Cross-Attention Head {head + 1}')
            plt.xlabel('Encoder Position')
            plt.ylabel('Decoder Position')
        
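        # 13. Mean attention weight per encoder layer. For post-softmax weights every
        # query row sums to 1, so this mean equals 1/key_length in every layer; treat
        # the panel as a sanity check rather than a measure of layer behavior.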
        plt.subplot(4, 4, 13)
        layer_avg_attn = []
        for layer_attn in sample_encoder_attn:
            layer_avg = layer_attn.mean().cpu().numpy()
            layer_avg_attn.append(layer_avg)
        
        plt.bar(range(len(layer_avg_attn)), layer_avg_attn, color='skyblue', alpha=0.8)
        plt.title('Average Attention by Encoder Layer')
        plt.xlabel('Layer')
        plt.ylabel('Average Attention')
        plt.grid(True, alpha=0.3)
        
        # 14. Head specialization analysis
        plt.subplot(4, 4, 14)
        head_entropies = []
        for head in range(sample_encoder_attn[0].size(0)):
            head_attn = sample_encoder_attn[0][head].cpu().numpy()
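            # Mean per-query entropy in nats (0 = attends to a single key,
            # log(num_keys) = uniform); summing over the flattened matrix would
            # instead scale with the number of queries
            entropy = -np.sum(head_attn * np.log(head_attn + 1e-10), axis=-1).mean()
            head_entropies.append(float(entropy))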
        
        plt.bar(range(len(head_entropies)), head_entropies, color='orange', alpha=0.8)
        plt.title('Attention Entropy by Head\n(Layer 0, Encoder)')
        plt.xlabel('Head')
        plt.ylabel('Entropy')
        plt.grid(True, alpha=0.3)
        
        # 15. Cross-attention alignment visualization
        plt.subplot(4, 4, 15)
        # Show how well decoder positions align with encoder positions
        cross_attn_matrix = dec_cross_attn_avg
        alignment_scores = []
        for dec_pos in range(cross_attn_matrix.shape[0]):
            # Find which encoder position gets most attention
            max_attn_pos = np.argmax(cross_attn_matrix[dec_pos, :])
            alignment_scores.append(max_attn_pos)
        
        plt.plot(range(len(alignment_scores)), alignment_scores, 'o-', color='purple', alpha=0.8)
        plt.plot(range(len(alignment_scores)), range(len(alignment_scores)), '--', color='gray', alpha=0.5, label='Perfect Alignment')
        plt.title('Cross-Attention Alignment')
        plt.xlabel('Decoder Position')
        plt.ylabel('Most Attended Encoder Position')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # 16. Attention pattern summary
        plt.subplot(4, 4, 16)
        plt.text(0.1, 0.8, f'Sample: {sample_idx}', fontsize=12, transform=plt.gca().transAxes)
        plt.text(0.1, 0.7, f'Encoder Layers: {len(encoder_attn)}', fontsize=10, transform=plt.gca().transAxes)
        plt.text(0.1, 0.6, f'Decoder Layers: {len(decoder_self_attn)}', fontsize=10, transform=plt.gca().transAxes)
        plt.text(0.1, 0.5, f'Attention Heads: {sample_encoder_attn[0].size(0)}', fontsize=10, transform=plt.gca().transAxes)
        plt.text(0.1, 0.4, f'Seq Length: {sample_encoder_attn[0].size(1)}', fontsize=10, transform=plt.gca().transAxes)
        plt.text(0.1, 0.2, f'Avg Encoder Attn: {np.mean(layer_avg_attn):.4f}', fontsize=10, transform=plt.gca().transAxes)
        plt.title('Attention Summary')
        plt.axis('off')
        
        plt.suptitle('Comprehensive Transformer Attention Analysis', fontsize=16)
        plt.tight_layout()
        plt.savefig(notebook_results_dir / 'comprehensive_attention_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        return {
            'encoder_attention': encoder_attn,
            'decoder_self_attention': decoder_self_attn,
            'decoder_cross_attention': decoder_cross_attn,
            'attention_summary': {
                'avg_encoder_attention': np.mean(layer_avg_attn),
                'head_entropies': head_entropies,
                'alignment_scores': alignment_scores
            }
        }

# Visualize attention patterns
sample_src, sample_tgt = copy_val_dataset[0]
sample_src_batch = sample_src.unsqueeze(0).to(device)
sample_tgt_batch = sample_tgt.unsqueeze(0).to(device)

attention_analysis = visualize_transformer_attention(
    training_transformer, sample_src_batch, sample_tgt_batch, sample_idx=0
)

def analyze_attention_patterns(attention_analysis):
    """Analyze learned attention patterns for insights."""
    print("\n🧠 Attention Pattern Analysis:")
    print("=" * 50)
    
    summary = attention_analysis['attention_summary']
    
    print(f"📊 Attention Statistics:")
    print(f"   Average encoder attention: {summary['avg_encoder_attention']:.4f}")
    print(f"   Head entropy range: {min(summary['head_entropies']):.2f} - {max(summary['head_entropies']):.2f}")
    
    # Analyze head specialization
    head_entropies = summary['head_entropies']
    low_entropy_heads = sum(1 for e in head_entropies if e < np.mean(head_entropies) - np.std(head_entropies))
    high_entropy_heads = sum(1 for e in head_entropies if e > np.mean(head_entropies) + np.std(head_entropies))
    
    print(f"\n🎯 Head Specialization:")
    print(f"   Focused heads (low entropy): {low_entropy_heads}")
    print(f"   Broad heads (high entropy): {high_entropy_heads}")
    print(f"   Regular heads: {len(head_entropies) - low_entropy_heads - high_entropy_heads}")
    
    # Analyze cross-attention alignment
    alignment_scores = summary['alignment_scores']
    perfect_alignment = sum(1 for i, pos in enumerate(alignment_scores) if abs(pos - i) <= 1)
    alignment_rate = perfect_alignment / len(alignment_scores)
    
    print(f"\n🎯 Cross-Attention Alignment:")
    print(f"   Perfect/near-perfect alignment: {perfect_alignment}/{len(alignment_scores)} ({alignment_rate:.1%})")
    
    if alignment_rate > 0.8:
        print("   Status: ✅ Excellent alignment - model learned to copy well")
    elif alignment_rate > 0.6:
        print("   Status: ⚠️ Good alignment - model mostly learned to copy")
    else:
        print("   Status: ❌ Poor alignment - model struggled to learn copying")
    
    return {
        'attention_focus': {
            'focused_heads': low_entropy_heads,
            'broad_heads': high_entropy_heads
        },
        'alignment_quality': {
            'alignment_rate': alignment_rate,
            'perfect_alignments': perfect_alignment
        }
    }

# Analyze attention patterns
pattern_analysis = analyze_attention_patterns(attention_analysis)

print("✅ Attention analysis completed!")
```
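
The two statistics in `analyze_attention_patterns` can be restated as small pure functions, which makes them easy to check on hand-built matrices. Below is a minimal NumPy sketch. The names `alignment_rate` and `classify_heads` are illustrative. The ±1 tolerance and the mean ± one standard deviation thresholds mirror the code above.

```python
import numpy as np


def alignment_rate(cross_attn: np.ndarray, tolerance: int = 1) -> float:
    """Fraction of decoder positions whose most-attended encoder position lies
    within `tolerance` of the diagonal (the ideal pattern for a copy task)."""
    argmax_pos = cross_attn.argmax(axis=-1)            # shape: (decoder_len,)
    decoder_pos = np.arange(cross_attn.shape[0])
    return float(np.mean(np.abs(argmax_pos - decoder_pos) <= tolerance))


def classify_heads(entropies) -> dict:
    """Label heads as focused / broad / regular using mean +/- one std."""
    e = np.asarray(entropies, dtype=float)
    low, high = e.mean() - e.std(), e.mean() + e.std()
    return {
        "focused": int((e < low).sum()),
        "broad": int((e > high).sum()),
        "regular": int(((e >= low) & (e <= high)).sum()),
    }


# A perfectly diagonal 5x5 cross-attention map aligns at every position
print(alignment_rate(np.eye(5)))                       # 1.0
# Every row attends 2 or more positions off the diagonal: fails the +/-1 tolerance
print(alignment_rate(np.roll(np.eye(5), 3, axis=1)))   # 0.0
print(classify_heads([0.2, 1.0, 1.1, 1.2, 2.0]))       # {'focused': 1, 'broad': 1, 'regular': 3}
```

The argmax-based rate ignores how sharply each row peaks. Pairing it with the entropy statistic distinguishes a confident diagonal from a diffuse one that happens to peak near the diagonal.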

## 10. Comprehensive Results Summary and Conclusions

Aggregate the architecture, training, copy-task, and attention metrics from the previous sections into a single summary, then save the results and a README to the results directory.
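
`save_all_results` passes `default=str` to `json.dump`. Any value the `json` module cannot serialize natively is therefore written as a string. For example, the `np.int64` positions returned by `np.argmax` in the alignment scores come out as `"3"` rather than `3`. Below is a minimal sketch of a `default` hook that converts NumPy and PyTorch values to native Python numbers and falls back to `str` otherwise. The function name `to_jsonable` is hypothetical.

```python
import json

import numpy as np
import torch


def to_jsonable(obj):
    """json.dump `default` hook: convert NumPy/PyTorch values to Python types."""
    if isinstance(obj, np.generic):          # np.int64, np.float32, np.bool_, ...
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    return str(obj)                          # fallback, same as default=str


record = {
    "alignment_scores": [np.int64(0), np.int64(1)],
    "head_entropies": np.array([0.5, 1.25]),
    "device": torch.device("cpu"),
}
print(json.dumps(record, default=to_jsonable))
# {"alignment_scores": [0, 1], "head_entropies": [0.5, 1.25], "device": "cpu"}
```

`np.float64` subclasses Python `float`, so it serializes without the hook. The hook matters mainly for NumPy integers, booleans, arrays, and tensors.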

```python
import pandas as pd  # provides pd.Timestamp for the timestamps below

def generate_comprehensive_results():
    """Generate comprehensive results summary of the Transformer implementation."""
    print("\n" + "="*80)
    print("📊 COMPREHENSIVE TRANSFORMER IMPLEMENTATION RESULTS")
    print("="*80)
    
    # Collect all results
    results_summary = {
        'implementation_info': {
            'timestamp': pd.Timestamp.now().isoformat(),
            'device_used': str(device),
            'pytorch_version': torch.__version__
        },
        'architecture_details': training_transformer.architecture_config,
        'model_statistics': training_transformer.get_model_info(),
        'training_results': {
            'best_validation_loss': trainer.best_val_loss,
            'final_metrics': training_analysis['final_metrics'],
            'training_time': training_analysis['efficiency']['total_time'],
            'convergence_epochs': len(trainer.history['train_loss'])
        },
        'copy_task_performance': {
            'accuracy': copy_accuracy,
            'correct_samples': correct_count,
            'total_samples': correct_count + error_count
        },
        'attention_analysis': {
            'pattern_quality': pattern_analysis['alignment_quality'],
            'head_specialization': pattern_analysis['attention_focus']
        }
    }
    
    # Display comprehensive summary
    print(f"\n🤖 ARCHITECTURE SUMMARY:")
    arch = results_summary['architecture_details']
    print(f"   Model: Transformer (Encoder-Decoder)")
    print(f"   Parameters: {results_summary['model_statistics']['parameters']['total']:,}")
    print(f"   Layers: {arch['num_encoder_layers']} encoder, {arch['num_decoder_layers']} decoder")
    print(f"   Attention heads: {arch['n_heads']}")
    print(f"   Model dimension: {arch['d_model']}")
    print(f"   Feed-forward dimension: {arch['d_ff']}")
    print(f"   Vocabulary size: {arch['src_vocab_size']}")
    
    print(f"\n📈 TRAINING RESULTS:")
    training_res = results_summary['training_results']
    print(f"   Best validation loss: {training_res['best_validation_loss']:.4f}")
    print(f"   Final validation accuracy: {training_res['final_metrics']['val_acc']:.4f}")
    print(f"   Training time: {training_res['training_time']:.1f} seconds")
    print(f"   Epochs trained: {training_res['convergence_epochs']}")
    
    print(f"\n🎯 COPY TASK PERFORMANCE:")
    copy_perf = results_summary['copy_task_performance']
    print(f"   Copy accuracy: {copy_perf['accuracy']:.1%}")
    print(f"   Perfect copies: {copy_perf['correct_samples']}/{copy_perf['total_samples']}")
    
    if copy_perf['accuracy'] >= 0.9:
        print("   Status: ✅ Excellent - Model successfully learned sequence copying")
    elif copy_perf['accuracy'] >= 0.7:
        print("   Status: ⚠️ Good - Model partially learned sequence copying")
    else:
        print("   Status: ❌ Poor - Model struggled with sequence copying")
    
    print(f"\n🧠 ATTENTION ANALYSIS:")
    attn_analysis = results_summary['attention_analysis']
    print(f"   Cross-attention alignment: {attn_analysis['pattern_quality']['alignment_rate']:.1%}")
    print(f"   Focused attention heads: {attn_analysis['head_specialization']['focused_heads']}")
    print(f"   Broad attention heads: {attn_analysis['head_specialization']['broad_heads']}")
    
    print(f"\n🔍 KEY INSIGHTS:")
    insights = []
    
    # Training insights
    if training_res['final_metrics']['val_acc'] > 0.9:
        insights.append("✅ Model achieved excellent accuracy on copy task")
    
    if training_res['best_validation_loss'] < 0.1:
        insights.append("✅ Model converged to low loss, indicating good learning")
    
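    # Architecture insights (the total includes embeddings and the output
    # projection, so this overstates the size of each attention/FFN layer)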
    params_per_layer = results_summary['model_statistics']['parameters']['total'] / (arch['num_encoder_layers'] + arch['num_decoder_layers'])
    if params_per_layer < 1000000:
        insights.append("✅ Efficient architecture with reasonable parameter count per layer")
    
    # Attention insights
    if attn_analysis['pattern_quality']['alignment_rate'] > 0.8:
        insights.append("✅ Attention mechanism learned proper sequence alignment")
    
    if attn_analysis['head_specialization']['focused_heads'] > 0:
        insights.append("✅ Some attention heads specialized for focused attention")
    
    # Performance insights: training samples processed per second across all epochs
    throughput = len(copy_train_dataset) * training_res['convergence_epochs'] / training_res['training_time']
    if throughput > 100:
        insights.append(f"✅ Good training throughput achieved ({throughput:.0f} samples/s)")
    
    for insight in insights:
        print(f"   {insight}")
    
    print(f"\n📚 IMPLEMENTATION ACHIEVEMENTS:")
    achievements = [
        "🤖 Complete Transformer architecture implemented from scratch",
        "🧠 Multi-head self-attention mechanism with proper scaling",
        "📍 Both sinusoidal and learned positional encoding options",
        "🔄 Full encoder-decoder architecture with proper masking",
        "🚂 Comprehensive training pipeline with modern optimizations",
        "📊 Detailed attention visualization and analysis tools",
        "🎯 Successful validation on copy task demonstrating functionality",
        "📈 Professional-grade code with extensive documentation"
    ]
    
    for achievement in achievements:
        print(f"   {achievement}")
    
    print(f"\n💡 POTENTIAL IMPROVEMENTS:")
    improvements = [
        "🔧 Implement more sophisticated tasks (translation, summarization)",
        "⚡ Add model parallelism for larger architectures",
        "🎯 Implement beam search for better generation quality",
        "📊 Add more comprehensive evaluation metrics",
        "🔍 Implement attention head pruning for efficiency",
        "📈 Add further positional encoding schemes (e.g., relative or rotary)",
        "🚀 Optimize for production deployment with ONNX/TensorRT"
    ]
    
    for improvement in improvements:
        print(f"   {improvement}")
    
    return results_summary

# Generate comprehensive results
final_results = generate_comprehensive_results()

# Save all results
def save_all_results():
    """Save comprehensive results and artifacts."""
    print(f"\n💾 Saving Comprehensive Results...")
    
    # Save main results
    with open(notebook_results_dir / 'comprehensive_results.json', 'w') as f:
        json.dump(final_results, f, indent=2, default=str)
    
    # Save training history
    with open(notebook_results_dir / 'training_history.pkl', 'wb') as f:
        pickle.dump(trainer.history, f)
    
    # Save copy task results
    copy_task_results = {
        'accuracy': copy_accuracy,
        'detailed_results': [
            {
                'sample_id': r['sample_id'],
                'source': r['source'],
                'target': r['target'],
                'predicted': r['predicted'],
                'correct': r['correct']
            }
            for r in copy_results
        ]
    }
    
    with open(notebook_results_dir / 'copy_task_results.json', 'w') as f:
        json.dump(copy_task_results, f, indent=2, default=str)
    
    # Save attention analysis
    attention_summary = {
        'attention_statistics': attention_analysis['attention_summary'],
        'pattern_analysis': pattern_analysis
    }
    
    with open(notebook_results_dir / 'attention_analysis.json', 'w') as f:
        json.dump(attention_summary, f, indent=2, default=str)
    
    # Create README for results directory
    readme_content = f"""# Transformer from Scratch - Results

## Overview
Complete implementation and training results for Transformer architecture built from scratch.

## Files
- `comprehensive_results.json`: Complete results summary
- `training_history.pkl`: Detailed training metrics
- `copy_task_results.json`: Copy task evaluation results
- `attention_analysis.json`: Attention pattern analysis
- `transformer_architecture.json`: Model architecture details
- `trained_transformer.pth`: Complete trained model checkpoint

## Key Results
- **Copy Task Accuracy**: {copy_accuracy:.1%}
- **Best Validation Loss**: {trainer.best_val_loss:.4f}
- **Total Parameters**: {final_results['model_statistics']['parameters']['total']:,}
- **Training Time**: {final_results['training_results']['training_time']:.1f} seconds

## Visualizations
- `multihead_attention_patterns.png`: Multi-head attention visualization
- `positional_encoding_analysis.png`: Positional encoding analysis
- `attention_masks.png`: Attention mask examples
- `training_history.png`: Training progress curves
- `comprehensive_attention_analysis.png`: Complete attention analysis

## Implementation Status
✅ **COMPLETE**: Full Transformer implementation with successful training and validation

Generated: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
"""
    
    with open(notebook_results_dir / 'README.md', 'w') as f:
        f.write(readme_content)
    
    # List all files created
    print(f"   📁 Results directory: {notebook_results_dir}")
    all_files = [p for p in notebook_results_dir.glob('*') if p.is_file()]
    total_size_mb = sum(p.stat().st_size for p in all_files) / (1024 * 1024)
    
    print(f"   📄 Files in directory: {len(all_files)}")
    print(f"   💾 Total size: {total_size_mb:.1f} MB")
    
    for file_path in sorted(all_files):
        if file_path.is_file():
            size_mb = file_path.stat().st_size / (1024 * 1024)
            print(f"     📄 {file_path.name} ({size_mb:.2f} MB)")

# Save all results
save_all_results()

print(f"\n" + "="*80)
print("🎉 TRANSFORMER FROM SCRATCH - IMPLEMENTATION COMPLETE!")
print("="*80)

implementation_summary = f"""
🏆 **SUCCESS METRICS**:
   ✅ Complete Transformer architecture: {final_results['model_statistics']['parameters']['total']:,} parameters
   ✅ Training completed: {final_results['training_results']['convergence_epochs']} epochs
   ✅ Copy task mastery: {copy_accuracy:.1%} accuracy
   ✅ Attention mechanism validation: {pattern_analysis['alignment_quality']['alignment_rate']:.1%} alignment
   ✅ Professional implementation: Comprehensive documentation and analysis

🎯 **TECHNICAL ACHIEVEMENTS**:
   🤖 Multi-head attention with {final_results['architecture_details']['n_heads']} heads
   📍 Positional encoding (sinusoidal and learned options)
   🔄 Complete encoder-decoder with proper masking
   🚂 Modern training pipeline with warmup and scheduling
   👁️ Comprehensive attention visualization system
   📊 Detailed performance analysis and metrics

🚀 **READY FOR**:
   📚 Advanced sequence-to-sequence tasks
   🔬 Research experimentation and modification
   🏭 Production deployment and scaling
   📖 Educational demonstrations and tutorials
   🎯 Further architectural improvements
"""

print(implementation_summary)
print("✨ Transformer Implementation Journey Complete! ✨")
```
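
To reuse the saved artifacts in a later session, the JSON summary and the pickled training history can be read back directly. Below is a minimal sketch. It uses the file names listed in the generated README, and the helper name `load_saved_results` is hypothetical. The round-trip demo writes illustrative values to a temporary directory standing in for `notebook_results_dir`. They are not results from this notebook.

```python
import json
import pickle
import tempfile
from pathlib import Path


def load_saved_results(results_dir):
    """Reload the JSON summary and pickled training history from a results dir.

    Only unpickle files you created yourself: pickle can execute arbitrary code.
    """
    results_dir = Path(results_dir)
    summary = json.loads((results_dir / "comprehensive_results.json").read_text())
    with open(results_dir / "training_history.pkl", "rb") as f:
        history = pickle.load(f)
    return summary, history


# Round-trip demo with illustrative values only
with tempfile.TemporaryDirectory() as tmp:
    tmp_dir = Path(tmp)
    (tmp_dir / "comprehensive_results.json").write_text(
        json.dumps({"copy_task_performance": {"accuracy": 0.9}})
    )
    with open(tmp_dir / "training_history.pkl", "wb") as f:
        pickle.dump({"train_loss": [1.2, 0.4]}, f)
    summary, history = load_saved_results(tmp_dir)
    print(summary["copy_task_performance"]["accuracy"], history["train_loss"])  # 0.9 [1.2, 0.4]
```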