# Day 6: Scaled Dot-Product Attention - Implementation and Examples

This notebook contains all the code examples, visualizations, and hands-on exercises for Day 6 of Week 2.

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Seed PyTorch's and NumPy's global random number generators with a fixed value
# so random draws (e.g. torch.randn) repeat on a fresh Restart Kernel -> Run All.
torch.manual_seed(42)
np.random.seed(42)

# Plotting setup: use matplotlib's 'default' style sheet and seaborn's 'husl'
# colour palette for the figures in this notebook.
plt.style.use('default')
sns.set_palette('husl')

## 1. Understanding Q, K, V Components

In [None]:
def explain_attention_components():
    """Print an intuitive description of the Query, Key and Value roles.

    Uses the sentence 'The cat sat on the mat' and the token 'sat' as a
    running example. Produces console output only and returns None.
    """
    # Every entry is emitted with its own print() call, in order; a leading
    # "\n" yields the blank line that separates the sections.
    script = (
        "Understanding Q, K, V with Intuitive Examples",
        "=" * 50,
        "Example sentence: 'The cat sat on the mat'",
        "\nIntuitive Understanding:",
        "- Query (Q): 'What should I pay attention to?'",
        "- Key (K): 'What information do I have?'",
        "- Value (V): 'What is the actual content?'",
        "\nFor token 'sat':",
        "- Query: Looking for subject and object relationships",
        "- Keys: All tokens offer their relationship information",
        "- Values: Actual semantic content of each token",
        "\nAttention weights tell us:",
        "- How much 'sat' should focus on 'cat' (subject)",
        "- How much 'sat' should focus on 'mat' (object)",
        "- Less attention to articles 'the', 'the'",
    )
    for text in script:
        print(text)

explain_attention_components()

## 2. Basic Attention Implementation

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q·K^T / √d_k) · V.

    Args:
        d_model: Nominal feature size of the inputs. Stored on the module for
            reference; the scaling factor itself is taken from the last
            dimension of the query passed to ``forward``.
        dropout: Dropout probability applied to the attention weights. Dropout
            is only active in training mode; call ``.eval()`` on the module to
            obtain weights whose rows sum to 1.
    """

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Leading dimensions (batch, heads, ...) are broadcast by ``torch.matmul``,
        so both [batch, seq, d] and [batch, heads, seq, d] inputs are accepted.

        Args:
            query: [..., q_len, d_k]
            key: [..., k_len, d_k]
            value: [..., k_len, d_v]
            mask: broadcastable to [..., q_len, k_len], or None. Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            output: [..., q_len, d_v]
            attention_weights: [..., q_len, k_len], taken AFTER dropout, so in
                training mode rows may contain zeros and need not sum to 1.
        """
        # Read d_k from the last axis only, so inputs of any rank >= 2 work
        # (unpacking query.size() into three names would require rank 3).
        d_k = query.size(-1)

        # Step 1: Compute attention scores (Q·K^T)
        scores = torch.matmul(query, key.transpose(-2, -1))

        # Step 2: Scale by √d_k
        scores = scores / (d_k ** 0.5)

        # Step 3: Apply mask if provided
        if mask is not None:
            # Use the most negative finite value of the scores' dtype rather
            # than a hard-coded -1e9, which overflows for float16 scores.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Step 4: Apply softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Step 5: Apply attention weights to values
        output = torch.matmul(attention_weights, value)

        return output, attention_weights

In [None]:
# Test the implementation
def test_basic_attention():
    """Smoke-test self-attention on one random sequence of 4 tokens (d_model=8).

    Prints the input/output/weight shapes and the per-row sums of the
    attention weights, and asserts that every row sums to 1.

    Returns:
        output: attention output, shape [1, 4, 8]
        weights: attention weights, shape [1, 4, 4]
    """

    batch_size, seq_len, d_model = 1, 4, 8

    # Create simple input embeddings
    embeddings = torch.randn(batch_size, seq_len, d_model)

    # Initialize attention
    attention = ScaledDotProductAttention(d_model)
    # A freshly built module is in training mode, where dropout randomly zeroes
    # and rescales attention weights so rows no longer sum to 1. Switch to
    # eval mode so the check below is meaningful.
    attention.eval()

    # Self-attention: Q, K, V are all the same
    output, weights = attention(embeddings, embeddings, embeddings)

    row_sums = weights.sum(dim=-1)

    print("Basic Attention Test:")
    print(f"Input shape: {embeddings.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {weights.shape}")
    print(f"Attention weights sum (should be ~1.0): {row_sums}")

    # Each row is a softmax distribution and must sum to 1.
    assert torch.allclose(row_sums, torch.ones(batch_size, seq_len), atol=1e-5), \
        "attention weights do not sum to 1 per query"

    return output, weights

output, weights = test_basic_attention()

## 3. Manual Computation Example

In [None]:
def manual_attention_computation():
    """Walk through scaled dot-product attention on a 3-token, 4-dimensional toy example.

    Prints Q, K and V followed by every intermediate result: raw scores,
    scaled scores, softmax weights, their row sums and the final output.

    Returns:
        Tuple ``(Q, K, V, attention_weights, output)`` of float32 tensors.
    """
    print("Manual Attention Computation")
    print("=" * 40)

    # Feature width of the toy example; each matrix below has one row per token.
    d_model = 4

    # Row i is the query of token i+1.
    Q = torch.tensor(
        [[1.0, 0.0, 1.0, 0.0],
         [0.0, 1.0, 0.0, 1.0],
         [1.0, 1.0, 0.0, 0.0]],
        dtype=torch.float32,
    )
    # Row i is the key of token i+1.
    K = torch.tensor(
        [[1.0, 0.0, 0.0, 1.0],
         [0.0, 1.0, 1.0, 0.0],
         [1.0, 0.0, 1.0, 1.0]],
        dtype=torch.float32,
    )
    # Row i is the value of token i+1.
    V = torch.tensor(
        [[2.0, 0.0, 1.0, 0.0],
         [0.0, 2.0, 0.0, 1.0],
         [1.0, 1.0, 2.0, 2.0]],
        dtype=torch.float32,
    )

    for heading, matrix in (
        ("Query matrix Q:", Q),
        ("\nKey matrix K:", K),
        ("\nValue matrix V:", V),
    ):
        print(heading)
        print(matrix.numpy())

    # Step 1: similarity of every query with every key.
    scores = Q @ K.T
    print("\nStep 1 - Attention scores (Q·K^T):")
    print(scores.numpy())

    # Step 2: divide by √d_k.
    root_dk = np.sqrt(d_model)
    scaled_scores = scores / root_dk
    print(f"\nStep 2 - Scaled scores (÷√{d_model} = ÷{root_dk:.2f}):")
    print(scaled_scores.numpy())

    # Step 3: normalise each row into a probability distribution.
    attention_weights = torch.softmax(scaled_scores, dim=-1)
    print("\nStep 3 - Attention weights (softmax):")
    print(attention_weights.numpy())

    # Sanity check: each row should add up to 1.
    row_totals = attention_weights.sum(dim=-1)
    print(f"\nWeights sum per row: {row_totals.numpy()}")

    # Step 4: weighted average of the value rows.
    output = attention_weights @ V
    print("\nStep 4 - Final output (weights × V):")
    print(output.numpy())

    return Q, K, V, attention_weights, output

Q, K, V, manual_weights, manual_output = manual_attention_computation()

## 4. Attention Pattern Visualization

In [None]:
def visualize_attention_patterns():
    """Plot self-attention over a toy sentence and print each token's top-2 targets.

    Builds random embeddings for "The cat sat on the mat", makes "the" a noisy
    copy of "The" and "mat" a noisy copy of "cat", runs self-attention and
    draws four panels: the attention heatmap, the weights for "sat", the
    scaled scores before softmax, and the per-token attention entropy.

    Returns:
        weights: attention weights, shape [1, 6, 6]
        tokens: list of the six token strings
    """

    # Simulate embeddings for: "The cat sat on the mat"
    tokens = ["The", "cat", "sat", "on", "the", "mat"]
    seq_len, d_model = len(tokens), 16

    # Create embeddings with some structure (re-seeded so the figure is
    # repeatable regardless of what ran before this cell)
    torch.manual_seed(42)
    embeddings = torch.randn(1, seq_len, d_model)

    # Make some tokens more similar (e.g., "The" and "the")
    embeddings[0, 4] = embeddings[0, 0] + 0.1 * torch.randn(d_model)

    # Make "cat" and "mat" somewhat similar (both nouns)
    embeddings[0, 5] = embeddings[0, 1] + 0.3 * torch.randn(d_model)

    # Apply attention. A freshly built module is in training mode, where
    # dropout randomly zeroes/rescales weights; that would put spurious zeros
    # in the heatmap and make it disagree with the scores panel. eval()
    # disables dropout so the plots show the true softmax distribution.
    attention = ScaledDotProductAttention(d_model)
    attention.eval()
    _, weights = attention(embeddings, embeddings, embeddings)

    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Attention heatmap
    sns.heatmap(weights[0].detach().numpy(),
                xticklabels=tokens, yticklabels=tokens,
                annot=True, fmt='.3f', cmap='Blues',
                ax=axes[0, 0])
    axes[0, 0].set_title('Attention Weights Heatmap')
    axes[0, 0].set_xlabel('Key (attending to)')
    axes[0, 0].set_ylabel('Query (attending from)')

    # 2. Attention weights for specific token
    token_idx = 2  # "sat"
    axes[0, 1].bar(tokens, weights[0, token_idx].detach().numpy())
    axes[0, 1].set_title(f'Attention weights for "{tokens[token_idx]}"')
    axes[0, 1].set_ylabel('Attention Weight')
    axes[0, 1].tick_params(axis='x', rotation=45)

    # 3. Raw attention scores (scaled by √d_k, before softmax)
    raw_scores = torch.matmul(embeddings, embeddings.transpose(-2, -1)) / np.sqrt(d_model)
    sns.heatmap(raw_scores[0].detach().numpy(),
                xticklabels=tokens, yticklabels=tokens,
                annot=True, fmt='.2f', cmap='RdBu_r', center=0,
                ax=axes[1, 0])
    axes[1, 0].set_title('Raw Attention Scores (before softmax)')
    axes[1, 0].set_xlabel('Key (attending to)')
    axes[1, 0].set_ylabel('Query (attending from)')

    # 4. Attention entropy (how focused/distributed); the 1e-9 guards log(0)
    entropy = -torch.sum(weights * torch.log(weights + 1e-9), dim=-1)
    axes[1, 1].bar(tokens, entropy[0].detach().numpy())
    axes[1, 1].set_title('Attention Entropy (higher = more distributed)')
    axes[1, 1].set_ylabel('Entropy')
    axes[1, 1].tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

    # Print interpretation
    print("Attention Pattern Interpretation:")
    print("=" * 40)

    for i, token in enumerate(tokens):
        top_attention = torch.topk(weights[0, i], 2)
        top_tokens = [tokens[idx] for idx in top_attention.indices.tolist()]
        top_weights = top_attention.values.tolist()

        print(f"'{token}' attends most to:")
        for rank, (att_token, weight) in enumerate(zip(top_tokens, top_weights), start=1):
            print(f"  {rank}. '{att_token}' (weight: {weight:.3f})")
        print()

    return weights, tokens

attention_weights, tokens = visualize_attention_patterns()

## 5. Scaling Factor Demonstration

In [None]:
def demonstrate_scaling_importance():
    """Compare softmax attention with and without the 1/√d_k scaling.

    For d_model in 4, 16, 64 and 256, draws random Q and K of shape
    [1, 4, d_model], computes Q·K^T both unscaled and divided by √d_model, and
    records the standard deviation of the scores plus the mean row entropy and
    the largest weight of the resulting softmax. Prints a summary table.

    Returns:
        Dict mapping each d_model to a dict of the recorded statistics.
    """

    def _mean_row_entropy(probs):
        # Shannon entropy of each row, averaged over rows; 1e-9 guards log(0).
        return -torch.sum(probs * torch.log(probs + 1e-9), dim=-1).mean()

    print("Why Scale by √d_k?")
    print("=" * 30)

    seq_len = 4
    results = {}

    for d_model in (4, 16, 64, 256):
        # Random queries and keys (Q is drawn before K).
        Q = torch.randn(1, seq_len, d_model)
        K = torch.randn(1, seq_len, d_model)

        raw = torch.matmul(Q, K.transpose(-2, -1))
        scaled = raw / np.sqrt(d_model)

        probs_raw = F.softmax(raw, dim=-1)
        probs_scaled = F.softmax(scaled, dim=-1)

        results[d_model] = {
            'scores_std_unscaled': raw.std().item(),
            'scores_std_scaled': scaled.std().item(),
            'entropy_unscaled': _mean_row_entropy(probs_raw).item(),
            'entropy_scaled': _mean_row_entropy(probs_scaled).item(),
            'max_weight_unscaled': probs_raw.max().item(),
            'max_weight_scaled': probs_scaled.max().item(),
        }

    # Summary table, one row per dimension.
    print("Dimension | Scores Std (Unscaled) | Scores Std (Scaled) | Max Weight (Unscaled) | Max Weight (Scaled)")
    print("-" * 100)

    for d_model, stats in results.items():
        cells = [
            f"{d_model:8d}",
            f"{stats['scores_std_unscaled']:17.3f}",
            f"{stats['scores_std_scaled']:16.3f}",
            f"{stats['max_weight_unscaled']:18.3f}",
            f"{stats['max_weight_scaled']:17.3f}",
        ]
        print(" | ".join(cells))

    print("\nObservations:")
    for remark in (
        "- Without scaling: larger dimensions → larger scores → sharper attention",
        "- With scaling: attention sharpness remains consistent across dimensions",
        "- Scaling prevents attention from becoming too concentrated",
    ):
        print(remark)

    return results

scaling_results = demonstrate_scaling_importance()

## 6. Exercise: Hand Computation

In [None]:
def exercise_hand_computation():
    """Exercise: Compute attention by hand for very small example."""
    
    print("Exercise 1: Hand Computation")
    print("=" * 35)
    
    print("Given:")
    print("Q = [[1, 0], [0, 1]]")
    print("K = [[1, 1], [1, 0]]") 
    print("V = [[2, 1], [1, 2]]")
    print("d_k = 2")
    
    print("\nYour task:")
    print("1. Compute QK^T")
    print("2. Scale by √d_k")
    print("3. Apply softmax")
    print("4. Multiply by V")
    
    # Solution
    Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
    K = torch.tensor([[1.0, 1.0], [1.0, 0.0]])
    V = torch.tensor([[2.0, 1.0], [1.0, 2.0]])
    
    print("\nSolution:")
    
    # Step 1
    QK = torch.matmul(Q, K.T)
    print(f"1. QK^T = \n{QK.numpy()}")
    
    # Step 2
    scaled = QK / np.sqrt(2)
    print(f"2. Scaled = \n{scaled.numpy()}")
    
    # Step 3
    weights = F.softmax(scaled, dim=-1)
    print(f"3. Softmax = \n{weights.numpy()}")
    
    # Step 4
    output = torch.matmul(weights, V)
    print(f"4. Output = \n{output.numpy()}")

exercise_hand_computation()

## 7. Exercise: Attention Pattern Analysis

In [None]:
def analyze_attention_patterns():
    """Compare self-attention patterns for three kinds of synthetic embeddings.

    Scenarios (5 tokens, 8 dimensions each):
      * 'uniform'       - every token embedding is all ones
      * 'sequential'    - embeddings filled with consecutive integers
      * 'similar_pairs' - random embeddings where tokens (0, 1) and (2, 3)
                          are near-duplicates

    Draws one attention heatmap per scenario, then prints the mean
    self-attention (diagonal), the maximum weight and the mean row entropy.
    Produces plots and console output only and returns None.
    """

    print("Types of Attention Patterns")
    print("=" * 35)

    seq_len, d_model = 5, 8

    # Create different scenarios
    scenarios = {
        'uniform': torch.ones(1, seq_len, d_model),  # All tokens identical
        'sequential': torch.arange(seq_len * d_model).float().view(1, seq_len, d_model),  # Sequential pattern
        'similar_pairs': torch.randn(1, seq_len, d_model)  # Will modify for similarity
    }

    # Make pairs similar in the third scenario
    scenarios['similar_pairs'][0, 1] = scenarios['similar_pairs'][0, 0] + 0.1 * torch.randn(d_model)
    scenarios['similar_pairs'][0, 3] = scenarios['similar_pairs'][0, 2] + 0.1 * torch.randn(d_model)

    # A freshly built module is in training mode, where dropout randomly
    # zeroes and rescales attention weights. eval() disables dropout so the
    # patterns shown are the true softmax distributions.
    attention = ScaledDotProductAttention(d_model)
    attention.eval()

    # Run attention once per scenario and reuse the result for both the plot
    # and the statistics, so the two views always describe the same weights.
    weights_by_name = {
        name: attention(embeddings, embeddings, embeddings)[1][0].detach()
        for name, embeddings in scenarios.items()
    }

    fig, axes = plt.subplots(1, len(weights_by_name), figsize=(15, 4))

    for ax, (name, weights) in zip(axes, weights_by_name.items()):
        # Visualize attention pattern
        sns.heatmap(weights.numpy(),
                    annot=True, fmt='.2f', cmap='Blues',
                    ax=ax)
        ax.set_title(f'{name.title()} Embeddings')
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')

    plt.tight_layout()
    plt.show()

    # Analyze patterns
    for name, weights in weights_by_name.items():
        # Compute attention statistics
        self_attention = torch.diag(weights).mean()  # How much tokens attend to themselves
        max_attention = weights.max()  # Maximum attention weight
        # Mean Shannon entropy over query rows; the 1e-9 guards log(0)
        entropy = -torch.sum(weights * torch.log(weights + 1e-9), dim=-1).mean()

        print(f"\n{name.title()} Pattern:")
        print(f"  Average self-attention: {self_attention:.3f}")
        print(f"  Maximum attention weight: {max_attention:.3f}")
        print(f"  Average entropy: {entropy:.3f}")

analyze_attention_patterns()