# Testing Attention Mechanisms

This notebook demonstrates and visualizes the implementation of:
1. DilatedAttention
2. MultiheadDilatedAttention

We'll create some sample inputs and visualize how the attention patterns work.

In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from einops import rearrange

from model import DilatedAttention, MultiheadDilatedAttention, GPTConfig

# Seed both RNGs the notebook draws from so Restart & Run All reproduces the same tensors and plots.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## 1. Testing DilatedAttention

Let's create a small example to visualize how dilated attention behaves with a fixed segment length (4) and increasing dilation rates (1, 2, 4, 8).

In [None]:
# Toy problem size: batch, sequence length, heads, per-head feature dim.
B, T, H, D = 1, 16, 4, 8
# Drawn in order query -> key -> value from the seeded global torch RNG.
query, key, value = (torch.randn(B, T, H, D) for _ in range(3))

# Paired lists: entry i of segment_lengths goes with entry i of dilation_rates.
# NOTE(review): presumably each pair configures one head group (rather than meaning
# "4 segments of length 4"), and T is expected to be a multiple of every segment
# length — confirm against model.DilatedAttention.
segment_lengths = [4] * 4
dilation_rates = [2 ** i for i in range(4)]  # [1, 2, 4, 8]
attention = DilatedAttention(segment_lengths, dilation_rates)

output = attention(query, key, value)

print(f"Input shape: {query.shape}")
print(f"Output shape: {output.shape}")

### Visualizing Attention Patterns

Let's create a function to visualize the attention patterns for each head and dilation rate.

In [None]:
def _dilated_mask(seq_len, segment_length, dilation_rate, offset=0):
    """Return a boolean (seq_len, seq_len) mask; True where query i may attend to key j.

    Positions are split into consecutive segments of ``segment_length``. Inside each
    segment only positions whose in-segment index is congruent to ``offset`` modulo
    ``dilation_rate`` take part, and they attend only to kept positions of the same
    segment. Rows for positions that are not kept are entirely False.
    """
    pos = torch.arange(seq_len)
    kept = (pos % segment_length) % dilation_rate == offset
    same_segment = (pos[:, None] // segment_length) == (pos[None, :] // segment_length)
    return same_segment & kept[:, None] & kept[None, :]


def visualize_attention_patterns(query, key, value, attention, title="Attention Patterns"):
    """Plot one softmax(q k^T / sqrt(D)) heatmap per head for batch element 0.

    When ``attention`` exposes ``segment_lengths``, each head's scores are restricted
    to the dilated sparsity pattern of its group before the softmax. The heatmap then
    reflects the dilation rate printed in the subplot title rather than dense attention.
    Query rows that take no part in the pattern are shown as all zeros.

    NOTE(review): the pattern is reconstructed here from (segment length, dilation rate,
    offset = group % rate) following the LongNet formulation. It is not read out of
    ``attention`` itself — confirm it matches model.DilatedAttention.

    Args:
        query: tensor unpacked as (B, T, H, D).
        key: tensor indexed the same way as ``query``.
        value: unused; accepted so the call mirrors ``attention(query, key, value)``.
        attention: module exposing ``dilation_rates`` (and optionally ``segment_lengths``).
        title: figure-level title.
    """
    B, T, H, D = query.shape
    dilation_rates = list(attention.dilation_rates)
    num_groups = len(dilation_rates)
    # Not every implementation stores segment lengths; fall back to dense scores if absent.
    segment_lengths = getattr(attention, "segment_lengths", None)

    # squeeze=False keeps `axes` 2-D, so H == 1 needs no special case.
    fig, axes = plt.subplots(H, 1, figsize=(15, 4 * H), squeeze=False)
    fig.suptitle(title)

    with torch.no_grad():
        for h in range(H):
            # NOTE(review): head -> group mapping kept from the original (h % num_groups);
            # the model may instead assign contiguous blocks of heads to each group.
            group = h % num_groups
            rate = dilation_rates[group]

            q = query[0, :, h, :]  # (T, D)
            k = key[0, :, h, :]    # (T, D)
            scores = (q @ k.transpose(-2, -1)) / D ** 0.5

            if segment_lengths is not None:
                seg_len = segment_lengths[group]
                mask = _dilated_mask(T, seg_len, rate, offset=group % rate)
                scores = scores.masked_fill(~mask, float("-inf"))
                subtitle = f"Head {h} (Dilation Rate: {rate}, Segment Length: {seg_len})"
            else:
                subtitle = f"Head {h} (Dilation Rate: {rate}, dense scores - no segment_lengths)"

            # Fully-masked rows softmax to NaN; show them as zero attention instead.
            attn = torch.nan_to_num(torch.softmax(scores, dim=-1), nan=0.0)

            ax = axes[h, 0]
            sns.heatmap(attn.detach().cpu().numpy(),
                        ax=ax,
                        cmap='viridis',
                        xticklabels=True,
                        yticklabels=True)
            ax.set(title=subtitle, xlabel="Key position", ylabel="Query position")

    plt.tight_layout()
    plt.show()

# Draw one heatmap per head for the DilatedAttention inputs built above.
visualize_attention_patterns(query, key, value, attention=attention)

## 2. Testing MultiheadDilatedAttention

Now let's test `MultiheadDilatedAttention`, which combines dilated attention with multi-head attention.

In [None]:
# Tiny GPT-style config: 4 heads over a 32-dim embedding (32 / 4 = 8 dims per head), no dropout.
config = GPTConfig(block_size=16, n_head=4, n_embd=32, dropout=0.0)

multihead_attention = MultiheadDilatedAttention(config)

# One sequence filling the whole context window: (batch, seq_len, n_embd).
x = torch.randn(1, config.block_size, config.n_embd)

output = multihead_attention(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

### Visualizing Multihead Attention Patterns

Let's visualize how the attention is distributed across different heads in the multihead attention.

In [None]:
def visualize_multihead_attention(x, multihead_attention, title="Multihead Attention Patterns"):
    """Plot softmax(q k^T / sqrt(head_dim)) for every head of ``multihead_attention``.

    ``x`` is projected with the module's ``c_attn`` layer. The result is split along the
    last dim into q, k, v chunks of width ``n_embd``, q and k are reshaped per head, and
    one heatmap per head is drawn for batch element 0.

    NOTE(review): these are dense scores of the projected q/k. No dilation, segmentation
    or causal masking that the module's own forward pass may apply is reproduced here —
    confirm against model.MultiheadDilatedAttention before reading them as its pattern.

    Args:
        x: input tensor unpacked as (B, T, n_embd).
        multihead_attention: module exposing ``c_attn`` and ``n_head``.
        title: figure-level title.
    """
    B, T, C = x.shape
    n_head = multihead_attention.n_head
    head_dim = C // n_head

    with torch.no_grad():
        # c_attn is expected to emit 3 * C features: q | k | v. v is not needed for scores.
        q, k, _ = multihead_attention.c_attn(x).split(C, dim=2)
        q = q.view(B, T, n_head, head_dim)
        k = k.view(B, T, n_head, head_dim)

    # squeeze=False keeps `axes` 2-D, so n_head == 1 needs no special case.
    fig, axes = plt.subplots(n_head, 1, figsize=(15, 4 * n_head), squeeze=False)
    fig.suptitle(title)

    for h in range(n_head):
        q_h = q[0, :, h, :]  # (T, head_dim)
        k_h = k[0, :, h, :]  # (T, head_dim)

        scores = (q_h @ k_h.transpose(-2, -1)) / np.sqrt(head_dim)
        attn = torch.softmax(scores, dim=-1)

        ax = axes[h, 0]
        sns.heatmap(attn.detach().cpu().numpy(),
                    ax=ax,
                    cmap='viridis',
                    xticklabels=True,
                    yticklabels=True)
        ax.set(title=f"Head {h}", xlabel="Key position", ylabel="Query position")

    plt.tight_layout()
    plt.show()

# Draw one heatmap per head of the multi-head module on the sample input `x`.
visualize_multihead_attention(x, multihead_attention=multihead_attention)

## 3. Comparison with Standard Attention

Let's compare the attention patterns between standard attention and our dilated attention implementation.

In [None]:
def standard_attention(query, key, value, scale=None):
    """Standard scaled dot-product attention along dim -2 of the inputs.

    Computes ``softmax(query @ key^T * scale) @ value``. The last dim is the feature
    dim and dim -2 is the attended (sequence) axis; leading dims are batch-like.

    Args:
        query, key, value: tensors of shape (..., L, d_k) / (..., S, d_k) / (..., S, d_v).
        scale: score multiplier; defaults to 1 / sqrt(d_k).

    Returns:
        Tuple ``(output, attention_weights)``.
    """
    if scale is None:
        scale = 1.0 / math.sqrt(query.size(-1))

    weights = ((query @ key.transpose(-2, -1)) * scale).softmax(dim=-1)
    return weights @ value, weights

# Fresh inputs laid out as (batch, seq_len, heads, dim), the layout DilatedAttention takes.
B, T, H, D = 1, 16, 4, 8
query = torch.randn(B, T, H, D)
key = torch.randn(B, T, H, D)
value = torch.randn(B, T, H, D)

# standard_attention attends along dim -2, so move heads in front of the sequence axis.
# Feeding (B, T, H, D) directly would attend across *heads* and produce an H x H map.
q_bhtd, k_bhtd, v_bhtd = (
    rearrange(tensor, "b t h d -> b h t d") for tensor in (query, key, value)
)

head = 0
_, standard_attn = standard_attention(q_bhtd, k_bhtd, v_bhtd)  # (B, H, T, T)

# Dilated panel: the same scaled scores for `head`, restricted to its segment/dilation pattern.
# NOTE(review): the pattern (head -> group via h % num_groups, offset = group % rate, attention
# only among kept positions of the same segment) follows the LongNet formulation and is
# reconstructed here, not read out of `attention` — confirm against model.DilatedAttention.
group = head % len(dilation_rates)
seg_len, rate = segment_lengths[group], dilation_rates[group]
pos = torch.arange(T)
kept = (pos % seg_len) % rate == group % rate
dilated_mask = (pos[:, None] // seg_len == pos[None, :] // seg_len) & kept[:, None] & kept[None, :]

scores = (q_bhtd[0, head] @ k_bhtd[0, head].transpose(-2, -1)) / D ** 0.5
# Rows with no kept keys softmax to NaN; display them as zero attention.
dilated_attn = torch.nan_to_num(
    torch.softmax(scores.masked_fill(~dilated_mask, float("-inf")), dim=-1), nan=0.0
)

# Visualize comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
panels = (
    (ax1, standard_attn[0, head], f"Standard Attention (Head {head})"),
    (ax2, dilated_attn, f"Dilated Attention (Head {head}, segment={seg_len}, rate={rate})"),
)
for ax, attn_map, label in panels:
    sns.heatmap(attn_map.detach().numpy(),
                ax=ax,
                cmap='viridis',
                xticklabels=True,
                yticklabels=True)
    ax.set(title=label, xlabel="Key position", ylabel="Query position")

plt.tight_layout()
plt.show()