In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math

# Reproducibility: seed every RNG this notebook uses.
# torch.manual_seed seeds the generators on all devices (CPU and CUDA).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# matplotlib >= 3.6 renamed the bundled seaborn style to 'seaborn-v0_8';
# fall back to the pre-3.6 name so the notebook also runs on older installs.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
sns.set_palette("husl")

print("Ready to implement Transformer attention components!")


In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (sin on even dims, cos on odd dims).

    Precomputes a table ``pe`` of shape ``(max_len, 1, d_model)`` with
    ``pe[pos, 0, 2i]   = sin(pos / 10000^(2i / d_model))`` and
    ``pe[pos, 0, 2i+1] = cos(pos / 10000^(2i / d_model))``.
    ``forward`` adds the first ``seq_len`` rows of that table to its input.

    Args:
        d_model: Embedding size. Odd values are supported; the last
            dimension is then a sine term with no cosine partner.
        max_len: Number of positions precomputed; inputs must satisfy
            ``seq_len <= max_len``.
        batch_first: If False (default), ``forward`` expects
            ``(seq_len, batch, d_model)``; if True, ``(batch, seq_len, d_model)``.
    """

    def __init__(self, d_model, max_len=5000, batch_first=False):
        super().__init__()
        self.batch_first = batch_first

        # Column vector of positions 0..max_len-1, shape (max_len, 1)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # 1 / 10000^(2i / d_model) for each even index 2i, computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than frequencies;
        # trim div_term so the shapes match instead of raising a size mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Store as (max_len, 1, d_model) so it broadcasts over the batch axis of
        # sequence-first input. A buffer follows .to(device) but is not a parameter.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the positional encoding for each sequence position to ``x``.

        Args:
            x: ``(seq_len, batch, d_model)`` tensor, or ``(batch, seq_len, d_model)``
               when the module was constructed with ``batch_first=True``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        if self.batch_first:
            # (seq_len, 1, d_model) -> (1, seq_len, d_model) to broadcast over batch
            return x + self.pe[: x.size(1)].transpose(0, 1)
        return x + self.pe[: x.size(0)]

def visualize_positional_encoding(d_model=128, max_len=100):
    """Plot a sinusoidal positional-encoding table three ways and summarize it.

    Builds ``PositionalEncoding(d_model, max_len)`` and draws:
      1. a heatmap of the whole table (position on x, dimension on y),
      2. selected dimensions traced over the first 50 positions,
      3. selected positions traced over the first 30 dimensions.
    Then prints a short list of the encoding's properties.

    Args:
        d_model: Embedding size of the encoding to visualize.
        max_len: Number of positions to compute and plot.

    Returns:
        numpy.ndarray of shape ``(max_len, d_model)`` with the encoding values.
    """
    pos_encoding = PositionalEncoding(d_model, max_len)

    # The buffer is (max_len, 1, d_model). Index the singleton batch axis
    # explicitly: squeeze() would also drop a size-1 max_len or d_model axis
    # and break the 2-D indexing below.
    pe_matrix = pos_encoding.pe[:, 0, :].cpu().numpy()

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 10))

    # Panel 1: full table; transpose so position runs along the x-axis
    im1 = ax1.imshow(pe_matrix.T, cmap='RdBu', aspect='auto')
    ax1.set_title('Positional Encoding Matrix\n(Position vs Embedding Dimension)')
    ax1.set_xlabel('Position')
    ax1.set_ylabel('Embedding Dimension')
    fig.colorbar(im1, ax=ax1, shrink=0.8)

    # Panel 2: sin/cos pairs at a few frequencies; skip dims beyond d_model
    dimensions_to_plot = [dim for dim in (0, 1, 4, 5, 8, 9) if dim < d_model]
    for dim in dimensions_to_plot:
        ax2.plot(pe_matrix[:50, dim], label=f'Dim {dim}')
    ax2.set_title('Positional Encoding for Different Dimensions')
    ax2.set_xlabel('Position')
    ax2.set_ylabel('Encoding Value')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Panel 3: per-position profiles; skip positions beyond max_len
    positions_to_plot = [pos for pos in (0, 5, 10, 20, 30) if pos < max_len]
    for pos in positions_to_plot:
        ax3.plot(pe_matrix[pos, :30], label=f'Pos {pos}')
    ax3.set_title('Encoding Patterns for Different Positions')
    ax3.set_xlabel('Embedding Dimension')
    ax3.set_ylabel('Encoding Value')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("POSITIONAL ENCODING PROPERTIES:")
    print("=" * 40)
    print("✅ Each position has a unique encoding pattern")
    print("✅ Sine/cosine allows model to learn relative positions")
    print("✅ Different frequencies for different dimensions")
    print("✅ Smooth patterns enable generalization to longer sequences")

    return pe_matrix

# Render the three encoding plots and keep the returned (max_len, d_model)
# array around so later cells can inspect the raw values.
pe_matrix = visualize_positional_encoding()
