# Attention Mechanism Demo

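This notebook demonstrates how scaled dot-product attention works in practice: it implements the mechanism from scratch in PyTorch, applies it to random query, key, and value tensors, and visualizes the resulting attention weights as a heatmap.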


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set random seeds for reproducibility: seed every RNG library the notebook
# imports so "Restart Kernel -> Run All" regenerates the same tensors.
SEED = 42
np.random.seed(SEED)
# torch.manual_seed returns a torch.Generator; the trailing `;` keeps Jupyter
# from echoing its repr as this cell's output.
torch.manual_seed(SEED);


In [None]:
# Set up course imports. The cell tries two approaches automatically:
#
# APPROACH 1 (recommended): the course is installed as a package. Run once
#   from the course root:  pip install -e .
# APPROACH 2 (fallback): add the course root to sys.path and import from the
#   source checkout.
import sys
from pathlib import Path

try:
    from lmcourse.utils import plot_attention_heatmap
    print("✅ Using installed package imports")
except ImportError:
    # NOTE(review): assumes the kernel's working directory is three levels
    # below the course root; adjust the number of `.parent` hops if the
    # notebook is moved.
    course_root = Path.cwd().parent.parent.parent
    if str(course_root) not in sys.path:
        sys.path.append(str(course_root))

    try:
        from lmcourse.utils import plot_attention_heatmap
    except ImportError as exc:
        # Fail with an actionable message instead of a bare ImportError.
        raise ImportError(
            "Could not import lmcourse.utils from the installed package or "
            f"from the course root {course_root}. Run `pip install -e .` from "
            "the course root, or start Jupyter from this notebook's directory."
        ) from exc
    print("✅ Using path-based imports")


In [None]:
def simple_attention(query, key, value, mask=None):
    """Simple scaled dot-product attention.

    Computes ``softmax(query @ key^T / sqrt(d_k)) @ value`` over the last
    two dimensions; any leading dimensions (e.g. batch) are broadcast by
    ``torch.matmul``.

    Args:
        query: Tensor of shape ``(..., seq_q, d_k)``.
        key: Tensor of shape ``(..., seq_k, d_k)``.
        value: Tensor of shape ``(..., seq_k, d_v)``.
        mask: Optional tensor broadcastable to ``(..., seq_q, seq_k)``.
            Positions where ``mask`` is False (or 0) are excluded from
            attention, e.g. a lower-triangular mask gives causal attention.
            A query row whose keys are all masked yields NaN weights.

    Returns:
        A tuple ``(output, attention_weights)``: ``output`` has shape
        ``(..., seq_q, d_v)`` and ``attention_weights`` has shape
        ``(..., seq_q, seq_k)``, with each row summing to 1.
    """
    d_k = key.size(-1)
    # Dividing by a Python scalar avoids building a throwaway tensor just to
    # take sqrt(d_k). Scaling keeps dot products from growing with d_k,
    # which would otherwise saturate the softmax.
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        # exp(-inf) == 0, so masked keys receive exactly zero weight.
        scores = scores.masked_fill(~mask.bool(), float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, value)
    return output, attention_weights


In [None]:
# Create sample data: one sequence (batch size 1) of `seq_len` positions,
# each represented by a `d_model`-dimensional vector.
seq_len = 5
d_model = 4

# Random query, key, value matrices of shape (1, seq_len, d_model)
query = torch.randn(1, seq_len, d_model)
key = torch.randn(1, seq_len, d_model)
value = torch.randn(1, seq_len, d_model)

# Apply attention
output, weights = simple_attention(query, key, value)

# Sanity checks: one output vector per query position, and each row of the
# attention weights is a probability distribution over the key positions.
assert output.shape == (1, seq_len, d_model)
assert weights.shape == (1, seq_len, seq_len)
assert torch.allclose(weights.sum(dim=-1), torch.ones(1, seq_len))

print("Input shapes:")
print(f"Query: {query.shape}")
print(f"Key: {key.shape}")
print(f"Value: {value.shape}")
print("\nOutput shapes:")
print(f"Output: {output.shape}")
print(f"Attention weights: {weights.shape}")


In [None]:
# Visualize the attention weights with the course's heatmap helper.
# NOTE(review): `weights` still carries its batch dimension here, i.e. shape
# (1, seq_len, seq_len); confirm plot_attention_heatmap accepts 3-D input,
# otherwise pass weights[0].
heatmap_options = {"title": "Attention Weights Example", "figsize": (8, 6)}
plot_attention_heatmap(weights, **heatmap_options)
