# Solutions: Task 8.1 - Attention from Scratch

This notebook contains solutions to the exercises from notebook 01.

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import seaborn as sns

torch.manual_seed(42)  # fix the global RNG so the random tensors, weight inits and dropout masks below are reproducible when the cells run in order

## Exercise 1: Implement Attention with Dropout

**Task:** Add dropout to the attention weights (commonly used to prevent overfitting).

In [None]:
def scaled_dot_product_attention_with_dropout(Q, K, V, dropout_p=0.1, mask=None, training=True):
    """
    Compute softmax(Q K^T / sqrt(d_k)) V, optionally applying dropout to the weights.

    Args:
        Q, K, V: Query, key and value tensors. The size of K's last
            dimension is used as d_k for the scaling factor.
        dropout_p: Probability of zeroing each attention weight.
        mask: Optional tensor broadcastable to the score shape; wherever
            mask == 0 the score is set to -inf before the softmax.
        training: Dropout is applied only when this is truthy and
            dropout_p > 0.

    Returns:
        (output, attention_weights): attention_weights are the post-dropout
        weights that were actually multiplied with V.
    """
    scale = math.sqrt(K.size(-1))
    logits = Q @ K.transpose(-2, -1) / scale

    if mask is not None:
        # Masked positions get -inf so the softmax assigns them zero weight.
        logits = logits.masked_fill(mask == 0, float('-inf'))

    weights = F.softmax(logits, dim=-1)

    # Dropout is applied AFTER the softmax and BEFORE mixing the values.
    use_dropout = training and dropout_p > 0
    if use_dropout:
        weights = F.dropout(weights, p=dropout_p, training=training)

    return weights @ V, weights

# Test: shapes with dropout active, then exact checks with dropout disabled.
Q = torch.randn(1, 4, 8)
K = torch.randn(1, 4, 8)
V = torch.randn(1, 4, 8)
p = 0.1

output, attention = scaled_dot_product_attention_with_dropout(Q, K, V, dropout_p=p)
print(f"Output shape: {output.shape}")
print(f"Attention shape: {attention.shape}")
assert output.shape == (1, 4, 8)
assert attention.shape == (1, 4, 4)

# With dropout off the weights are a plain softmax: every row sums to 1.
output_eval, attention_eval = scaled_dot_product_attention_with_dropout(Q, K, V, dropout_p=p, training=False)
assert torch.allclose(attention_eval.sum(dim=-1), torch.ones(1, 4))
assert torch.allclose(output_eval, torch.matmul(attention_eval, V), atol=1e-6)

# In training mode each weight is either dropped (0) or rescaled by 1 / (1 - p).
# Unmasked softmax weights are strictly positive, so "nonzero" means "kept".
kept = attention != 0
assert torch.allclose(attention[kept], attention_eval[kept] / (1 - p))
print("\nSolution verified!")

## Exercise 2: Attention Complexity Analysis

**Task:** What is the time and space complexity of self-attention?

In [None]:
def analyze_attention_complexity(sequence_lengths=None, dtype=torch.float32, show_plot=True):
    """
    Report how the memory of one attention matrix grows with sequence length.

    For a sequence of length n, a single (batch=1, single-head) attention
    matrix holds n * n scores, so its memory grows as O(n²). The size is
    computed from the shape and the dtype's element size; no n x n tensor
    is allocated.

    Args:
        sequence_lengths: Sequence lengths to evaluate.
            Defaults to [128, 256, 512, 1024, 2048].
        dtype: Element dtype used to size the matrix (default torch.float32,
            i.e. 4 bytes per element).
        show_plot: If True, plot memory vs. length and n² vs. length using the
            notebook-level ``plt`` (matplotlib.pyplot) import.
    """
    if sequence_lengths is None:
        sequence_lengths = [128, 256, 512, 1024, 2048]
    # Materialize once so the lengths can be reused for plotting.
    sequence_lengths = list(sequence_lengths)

    # Bytes per element for the chosen dtype (4 for float32, 2 for float16).
    bytes_per_element = torch.empty((), dtype=dtype).element_size()

    memory_usage = []
    for seq_len in sequence_lengths:
        num_elements = seq_len * seq_len
        # "MB" here is MiB: 1024 * 1024 bytes.
        memory_mb = num_elements * bytes_per_element / (1024 * 1024)
        memory_usage.append(memory_mb)

        print(f"Seq len {seq_len}: Attention matrix size = {seq_len}x{seq_len} = {num_elements:,} elements")
        print(f"  Memory: {memory_mb:.2f} MB")

    if show_plot:
        plt.figure(figsize=(10, 4))

        plt.subplot(1, 2, 1)
        plt.plot(sequence_lengths, memory_usage, 'bo-')
        plt.xlabel('Sequence Length')
        plt.ylabel('Memory (MB)')
        plt.title('Attention Memory: O(n²)')
        plt.grid(True)

        plt.subplot(1, 2, 2)
        plt.plot(sequence_lengths, [s**2 for s in sequence_lengths], 'ro-')
        plt.xlabel('Sequence Length')
        plt.ylabel('n²')
        plt.title('Quadratic Growth')
        plt.grid(True)

        plt.tight_layout()
        plt.show()

    print("\n=== Complexity Analysis ===")
    print("Time Complexity: O(n² × d) where n = sequence length, d = dimension")
    print("Space Complexity: O(n²) for the attention matrix")
    print("\nThis is why:")
    print("  - Models cap their context length (e.g. GPT-4 Turbo: 128K tokens)")
    print("  - Efficient attention variants exist (Flash Attention, Linear Attention)")
    print("  - Long sequences are expensive!")

# Run the analysis: prints the attention-matrix size and memory for each sequence length, then shows the two plots.
analyze_attention_complexity()

## Exercise 3: Relative Position Attention (Challenge)

**Task:** Implement relative position attention.

In [None]:
class RelativePositionAttention(nn.Module):
    """
    Single-head self-attention with learned relative position embeddings
    (in the style of Shaw et al., 2018).

        score(i, j) = (q_i · k_j + q_i · r_clip(j - i)) / sqrt(d_model)

    Here r_m is a learned d_model-dimensional embedding for the offset m of
    key j from query i. Offsets are clipped to
    [-max_relative_position, max_relative_position], so all tokens farther
    apart than that share the same embedding.
    """

    def __init__(self, d_model, max_relative_position=32):
        super().__init__()

        self.d_model = d_model
        self.max_relative_position = max_relative_position

        # Q, K, V projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # One embedding per clipped offset in [-max_relative_position, +max_relative_position]
        num_positions = 2 * max_relative_position + 1
        self.relative_embeddings = nn.Embedding(num_positions, d_model)

    def _get_relative_positions(self, seq_len, device=None):
        """
        Return a (seq_len, seq_len) integer tensor of embedding indices.

        Entry [i, j] is clip(j - i) + max_relative_position: the offset of
        key j from query i, shifted into [0, 2 * max_relative_position].
        ``device`` defaults to torch's default device.
        """
        positions = torch.arange(seq_len, device=device)
        relative = positions.unsqueeze(0) - positions.unsqueeze(1)  # [i, j] = j - i

        # Clip to valid range
        relative = torch.clamp(
            relative,
            -self.max_relative_position,
            self.max_relative_position
        )

        # Shift to non-negative indices
        return relative + self.max_relative_position

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor (batch, seq_len, d_model)
            mask: Optional tensor broadcastable to (batch, seq_len, seq_len);
                positions where mask == 0 are excluded from attention.

        Returns:
            output: (batch, seq_len, d_model) attention output
            attention: (batch, seq_len, seq_len) attention weights
        """
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        Q = self.W_q(x)  # (batch, seq, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        scale = math.sqrt(self.d_model)

        # Content-based attention scores: Q @ K^T
        content_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        # Position-based scores without materialising a (seq, seq, d_model)
        # tensor. First score each query against every distinct offset
        # embedding, giving (batch, seq, num_positions). Then pick, for each
        # (i, j), the column of that pair's clipped offset.
        relative_positions = self._get_relative_positions(seq_len, device=x.device)
        offset_scores = torch.matmul(Q, self.relative_embeddings.weight.t())
        index = relative_positions.unsqueeze(0).expand(batch_size, -1, -1)
        position_scores = torch.gather(offset_scores, -1, index) / scale

        # Combine scores
        scores = content_scores + position_scores

        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax and apply to values
        attention = F.softmax(scores, dim=-1)
        output = torch.matmul(attention, V)

        return output, attention

# Test (seq_len 20 > max_relative_position 16, so offset clipping is exercised)
rel_attn = RelativePositionAttention(d_model=64, max_relative_position=16)
x = torch.randn(2, 20, 64)
output, attention = rel_attn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention shape: {attention.shape}")

# Sanity checks: shapes, each query's weights form a distribution, and a
# causal mask gives zero weight to future positions.
assert output.shape == x.shape
assert attention.shape == (2, 20, 20)
assert torch.allclose(attention.sum(dim=-1), torch.ones(2, 20), atol=1e-5)
causal_mask = torch.tril(torch.ones(20, 20))
_, causal_attention = rel_attn(x, mask=causal_mask)
assert torch.all(causal_attention.masked_select(causal_mask == 0) == 0)
print("\nRelative position attention implemented successfully!")

---

## Key Takeaways

1. **Dropout in attention** is applied AFTER softmax but BEFORE multiplying with V
2. **Attention complexity** is O(n²) in the sequence length n: O(n²·d) time and O(n²) memory for the attention matrix. This is what limits context length.
3. **Relative position** attention encodes the (clipped) offset between tokens rather than their absolute positions

---