# Attention Mechanisms

Implementation of Scaled Dot-Product Attention and Multi-Head Attention.

In this notebook, you'll learn how attention works by implementing it step-by-step. Each function is broken down into its own cell, with practice cells where you can write your own implementation!

## Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

## Scaled Dot-Product Attention

The core attention mechanism: `Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V`
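To see the formula on its own before it is wrapped in a module, here is a minimal functional sketch on small random tensors. The helper name `attention` and the tensor sizes are illustrative choices, not part of this notebook's classes. Each row of the softmax output is a probability distribution over the keys.

```python
import math
import torch

torch.manual_seed(0)

def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Illustrative helper: softmax(Q K^T / sqrt(d_k)) V, with no mask or dropout."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (seq_q, seq_k)
    weights = torch.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ v, weights                         # (seq_q, d_v), (seq_q, seq_k)

q = torch.randn(3, 4)  # 3 queries, d_k = 4
k = torch.randn(5, 4)  # 5 keys,    d_k = 4
v = torch.randn(5, 2)  # 5 values,  d_v = 2
out, weights = attention(q, k, v)
print(out.shape)            # torch.Size([3, 2])
print(weights.sum(dim=-1))  # each entry is 1 up to floating-point rounding
```

The output has one row per query and one column per value dimension: each output row is a weighted average of the value rows.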

### Step 1: Initialize ScaledDotProductAttention

In [None]:
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention.
    
    Computes attention weights and applies them to values:
        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    
    The scaling by sqrt(d_k) keeps the dot products from growing with d_k: if the
    components of q and k have zero mean and unit variance, q·k has variance d_k.
    Large scores push softmax into saturated regions with very small gradients.
    """
    
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
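The variance argument in the docstring can be checked numerically. This stdlib-only sketch assumes query and key components are i.i.d. standard normal (an idealization, not a property of trained models) and estimates the standard deviation of q·k with and without the 1/sqrt(d_k) factor.

```python
import math
import random
import statistics

random.seed(0)

def dot_product_std(d_k: int, trials: int = 1000, scale: bool = False) -> float:
    """Empirical std of q·k (optionally divided by sqrt(d_k)) for standard-normal q, k."""
    samples = []
    for _ in range(trials):
        q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        s = sum(qi * ki for qi, ki in zip(q, k))
        samples.append(s / math.sqrt(d_k) if scale else s)
    return statistics.pstdev(samples)

for d_k in (4, 64, 256):
    raw = dot_product_std(d_k)
    scaled = dot_product_std(d_k, scale=True)
    print(f"d_k={d_k:4d}  std(q.k)={raw:6.2f}  std(q.k / sqrt(d_k))={scaled:4.2f}")
```

The unscaled spread should track sqrt(d_k) (roughly 2, 8, and 16 here), while the scaled spread should stay close to 1 for every d_k.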

### 🎯 Practice: Implement your own `__init__`

Try implementing the initialization yourself! What do you need to store?
- Hint: You need a dropout layer

In [None]:
# Your implementation here
# class MyScaledDotProductAttention(nn.Module):
#     def __init__(self, dropout: float = 0.1):
#         # Your code here
#         pass

### Step 2: Implement Forward Pass

In [None]:
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute scaled dot-product attention.
    
    Args:
        query: Query tensor of shape (batch, heads, seq_q, d_k)
        key: Key tensor of shape (batch, heads, seq_k, d_k)
        value: Value tensor of shape (batch, heads, seq_k, d_v)
        mask: Optional attention mask. True/1 values are MASKED (not attended to).
              Must broadcast to (batch, heads, seq_q, seq_k), e.g.
              (batch, 1, 1, seq_k) or (batch, 1, seq_q, seq_k)
              
    Returns:
        Tuple of:
            - Output tensor of shape (batch, heads, seq_q, d_v)
            - Attention weights of shape (batch, heads, seq_q, seq_k)
    """
    d_k = query.size(-1)
    
    # Compute attention scores: QK^T / sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply mask if provided: masked scores become -inf, so softmax gives them zero weight.
    # Note: a query row in which every key is masked produces NaN after softmax.
    if mask is not None:
        scores = scores.masked_fill(mask.bool(), float('-inf'))
    
    # Convert to probabilities
    attention_weights = F.softmax(scores, dim=-1)
    
    # Apply dropout
    attention_weights = self.dropout(attention_weights)
    
    # Apply attention to values
    output = torch.matmul(attention_weights, value)
    
    return output, attention_weights

# Add forward method to the class
ScaledDotProductAttention.forward = forward
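Mask semantics are easy to get backwards. This standalone sketch reproduces only the masking and softmax steps of `forward` with a causal (look-ahead) mask built by `torch.triu`; the random scores stand in for `QK^T / sqrt(d_k)`. As in this notebook, True marks positions that are hidden.

```python
import torch

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(1, 1, seq_len, seq_len)  # (batch, heads, seq_q, seq_k)

# True strictly above the diagonal = future key positions, which are MASKED
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Same two steps as forward(): fill masked scores with -inf, then softmax over keys
weights = torch.softmax(scores.masked_fill(causal_mask, float("-inf")), dim=-1)
print(weights[0, 0])  # lower-triangular; the first query attends only to itself
```

Because `causal_mask` has shape `(seq_len, seq_len)`, `masked_fill` broadcasts it across the batch and head dimensions, so the same tensor can be passed as `mask` to `ScaledDotProductAttention`.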

### 🎯 Practice: Implement your own `forward`

Try implementing the forward pass yourself! The steps are:
1. Compute attention scores: `QK^T / sqrt(d_k)`
2. Apply mask (if provided)
3. Apply softmax to get attention weights
4. Apply dropout
5. Multiply attention weights by values

In [None]:
# Your implementation here
# def my_forward(self, query, key, value, mask=None):
#     # Your code here
#     pass

## Multi-Head Attention

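Instead of performing a single attention function over the full `d_model` dimensions, we project Q, K, and V `n_heads` times with different learned linear projections (each of size `d_k = d_model / n_heads`), run attention on each projection in parallel, then concatenate the head outputs and apply a final output projection.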
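In practice, the per-head projections are usually computed as one `d_model × d_model` projection whose output is then split into heads, which is what the class below does. This sketch, with illustrative sizes, shows that the split (`view` + `transpose`) and merge (`transpose` + `contiguous` + `view`) are exact inverses, and that head `h` sees a contiguous `d_k`-wide slice of the projected features.

```python
import torch

batch, seq, d_model, n_heads = 2, 5, 16, 4
d_k = d_model // n_heads

x = torch.randn(batch, seq, d_model)  # stands in for a projected Q, K, or V

# Split: (batch, seq, d_model) -> (batch, seq, n_heads, d_k) -> (batch, n_heads, seq, d_k)
heads = x.view(batch, seq, n_heads, d_k).transpose(1, 2)

# Merge: (batch, n_heads, seq, d_k) -> (batch, seq, n_heads, d_k) -> (batch, seq, d_model)
merged = heads.transpose(1, 2).contiguous().view(batch, seq, d_model)

print(heads.shape)                                 # torch.Size([2, 4, 5, 4])
print(torch.equal(merged, x))                      # True
print(torch.equal(heads[:, 1], x[..., d_k:2 * d_k]))  # True: head 1 = features d_k..2*d_k-1
```

The `.contiguous()` call is needed because `transpose` returns a non-contiguous view, and `view` requires compatible memory layout.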

### Step 1: Initialize MultiHeadAttention

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism.
    
    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
    where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
    """
    
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()
        
        assert d_model % n_heads == 0, \
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # Dimension per head
        
        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        
        # Output projection
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        
        # Attention mechanism
        self.attention = ScaledDotProductAttention(dropout)
        
        # For storing attention weights
        self.attn_weights: Optional[torch.Tensor] = None
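Since all four projections are bias-free `d_model × d_model` layers, the module's parameter count is `4 * d_model**2`, independent of `n_heads`: the head count only changes how the `d_model` features are split. A quick standalone check with the same layer shapes (rebuilt here so the cell runs on its own):

```python
import torch.nn as nn

d_model = 512
# Same shapes as W^Q, W^K, W^V, W^O in MultiHeadAttention (no biases)
projections = nn.ModuleList([nn.Linear(d_model, d_model, bias=False) for _ in range(4)])
n_params = sum(p.numel() for p in projections.parameters())
print(n_params)  # 1048576 == 4 * 512**2
```

`ScaledDotProductAttention` contributes no parameters, since dropout has no learnable weights.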

### 🎯 Practice: Implement your own `__init__`

Try implementing the initialization yourself! What components do you need?
- Linear projections for Q, K, V
- Output projection
- Attention mechanism
- Calculate `d_k` (dimension per head)

In [None]:
# Your implementation here
# class MyMultiHeadAttention(nn.Module):
#     def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
#         # Your code here
#         pass

### Step 2: Implement Forward Pass

In [None]:
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    return_attention: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Apply multi-head attention.
    
    Args:
        query: Query tensor of shape (batch, seq_q, d_model)
        key: Key tensor of shape (batch, seq_k, d_model)
        value: Value tensor of shape (batch, seq_k, d_model)
        mask: Optional attention mask, broadcastable to (batch, n_heads, seq_q, seq_k).
              True/1 values are MASKED (not attended to).
        return_attention: Whether to return the attention weights
            (and store them in self.attn_weights)
        
    Returns:
        Tuple of:
            - Output tensor of shape (batch, seq_q, d_model)
            - Attention weights of shape (batch, n_heads, seq_q, seq_k)
              if return_attention=True, else None
    """
    batch_size = query.size(0)
    
    # 1. Linear projections
    q = self.w_q(query)
    k = self.w_k(key)
    v = self.w_v(value)
    
    # 2. Reshape to multiple heads
    # (batch, seq, d_model) -> (batch, seq, n_heads, d_k) -> (batch, n_heads, seq, d_k)
    q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
    k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
    v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
    
    # 3. Apply attention
    attn_output, attn_weights = self.attention(q, k, v, mask)
    
    # Store attention weights for visualization
    if return_attention:
        self.attn_weights = attn_weights
    
    # 4. Concatenate heads
    # (batch, n_heads, seq_q, d_k) -> (batch, seq_q, n_heads, d_k) -> (batch, seq_q, d_model)
    attn_output = attn_output.transpose(1, 2).contiguous().view(
        batch_size, -1, self.d_model
    )
    
    # 5. Final linear projection
    output = self.w_o(attn_output)
    
    if return_attention:
        return output, attn_weights
    return output, None

# Add forward method to the class
MultiHeadAttention.forward = forward
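`forward` passes `mask` to the attention module unchanged, so a padding mask of shape `(batch, 1, 1, seq_k)` is broadcast across every head and every query. This standalone sketch builds such a mask from hypothetical sequence lengths and applies it to random scores the same way `ScaledDotProductAttention` does.

```python
import torch

batch, n_heads, seq = 2, 8, 5
# Hypothetical lengths: sequence 0 has 5 real tokens, sequence 1 has 3 (+2 padding)
lengths = torch.tensor([5, 3])

# True where the key position is padding, i.e. should be MASKED
pad = torch.arange(seq).unsqueeze(0) >= lengths.unsqueeze(1)  # (batch, seq_k)
mask = pad[:, None, None, :]                                    # (batch, 1, 1, seq_k)

scores = torch.randn(batch, n_heads, seq, seq)                  # (batch, heads, seq_q, seq_k)
weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

print(mask.shape)                  # torch.Size([2, 1, 1, 5])
print(weights[1, :, :, 3:].sum())  # tensor(0.): padded keys get no weight
```

Only key positions are masked here; the padded query rows of sequence 1 still produce outputs, which are typically ignored downstream.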

### 🎯 Practice: Implement your own `forward`

Try implementing the forward pass yourself! The steps are:
1. Apply linear projections to Q, K, V
2. Reshape to multiple heads: `(batch, seq, d_model) -> (batch, n_heads, seq, d_k)`
3. Apply attention
4. Concatenate heads back: `(batch, n_heads, seq, d_k) -> (batch, seq, d_model)`
5. Apply final linear projection

In [None]:
# Your implementation here
# def my_forward(self, query, key, value, mask=None, return_attention=False):
#     # Your code here
#     pass

## Test Attention

Let's test our implementation!

In [None]:
# Test MultiHeadAttention
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 32, 512)  # (batch=2, seq=32, d_model=512)
out, attn = mha(x, x, x, return_attention=True)

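print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Attention shape: {attn.shape}")

# Check shapes: output matches input; one (seq_q, seq_k) weight matrix per head
assert out.shape == x.shape
assert attn.shape == (2, 8, 32, 32)
print("\n✅ Attention mechanism works!")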
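As a further sanity check, the manual formula can be compared with PyTorch's fused implementation. This sketch assumes PyTorch 2.0 or later, where `torch.nn.functional.scaled_dot_product_attention` exists. Its boolean `attn_mask` uses the opposite convention from this notebook (True means the position *may* be attended to), so the mask is inverted with `~`. Dropout is left out so both paths are deterministic.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 8, 6, 16)  # (batch, heads, seq_q, d_k)
k = torch.randn(2, 8, 6, 16)  # (batch, heads, seq_k, d_k)
v = torch.randn(2, 8, 6, 16)  # (batch, heads, seq_k, d_v)

# Notebook convention: True = masked out (here, a causal mask)
mask = torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1)

# Manual computation, mirroring ScaledDotProductAttention.forward without dropout
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
manual = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v

# Built-in: a boolean attn_mask marks positions that MAY attend, hence ~mask
builtin = F.scaled_dot_product_attention(q, k, v, attn_mask=~mask)

print(torch.allclose(manual, builtin, atol=1e-4))  # True
```

For a purely causal mask, `is_causal=True` can replace the explicit `attn_mask` in the built-in call.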