# Visualizing Transformer Attention Patterns

This notebook demonstrates how to visualize attention patterns in our custom transformer implementation.

In [None]:
import sys
import os
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from src.transformer import Transformer
from src.utils import visualize_attention

## Create a Toy Example

Let's create a small transformer model and generate some attention patterns.

In [None]:
# Seed torch's global RNG so that everything drawn from it afterwards
# is identical on every Restart & Run All.
torch.manual_seed(42)

# Hyperparameters of a deliberately tiny model (kept as notebook-level
# names because later cells read them, e.g. vocab_size).
vocab_size, embed_dim = 1000, 64
num_heads, num_layers = 4, 2
ff_dim = 128

# Source and target share the same vocabulary size in this toy setup.
model_kwargs = dict(
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    ff_dim=ff_dim,
)
model = Transformer(**model_kwargs)

In [None]:
# Create a toy input sequence
batch_size = 1
seq_len = 10

# torch.randint excludes the upper bound, so token ids fall in [1, vocab_size).
# NOTE(review): id 0 is never sampled -- presumably it is reserved for padding;
# confirm against Transformer.create_masks.
src = torch.randint(1, vocab_size, (batch_size, seq_len))
tgt = torch.randint(1, vocab_size, (batch_size, seq_len))

# Create masks
src_mask, tgt_mask, src_tgt_mask = Transformer.create_masks(src, tgt)

# Forward pass, used only for inspection: switch to evaluation mode so that
# any train-only layers (e.g. dropout) do not perturb the attention weights,
# and skip autograd bookkeeping since no gradients are needed.
# Assumes Transformer is a torch.nn.Module (it is called like one).
model.eval()
with torch.no_grad():
    output, attention_maps = model(src, tgt, src_mask, tgt_mask, src_tgt_mask)

## Visualize Encoder Self-Attention

In [None]:
# Take the last layer's encoder attention and select index [0, 0]
# (first head, per the notebook's convention; the exact tensor layout is
# defined in src.transformer), then convert it to a NumPy array for plotting.
encoder_last_layer = attention_maps['encoder_attention'][-1]
encoder_attention = encoder_last_layer[0, 0].detach().numpy()

# Draw the annotated heatmap on an explicit Axes object.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(encoder_attention, annot=True, cmap='viridis', ax=ax)
ax.set_title("Encoder Self-Attention (Layer 2, Head 1)")
ax.set_xlabel("Key Position")
ax.set_ylabel("Query Position")
plt.show()

## Visualize Decoder Self-Attention

In [None]:
# Take the last layer's decoder self-attention and select index [0, 0]
# (first head, per the notebook's convention; the exact tensor layout is
# defined in src.transformer), then convert it to a NumPy array for plotting.
decoder_self_last_layer = attention_maps['decoder_self_attention'][-1]
decoder_self_attention = decoder_self_last_layer[0, 0].detach().numpy()

# Draw the annotated heatmap on an explicit Axes object.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(decoder_self_attention, annot=True, cmap='viridis', ax=ax)
ax.set_title("Decoder Self-Attention (Layer 2, Head 1)")
ax.set_xlabel("Key Position")
ax.set_ylabel("Query Position")
plt.show()

## Visualize Decoder Cross-Attention

In [None]:
# Take the last layer's decoder cross-attention and select index [0, 0]
# (first head, per the notebook's convention; the exact tensor layout is
# defined in src.transformer), then convert it to a NumPy array for plotting.
decoder_cross_last_layer = attention_maps['decoder_cross_attention'][-1]
decoder_cross_attention = decoder_cross_last_layer[0, 0].detach().numpy()

# Draw the annotated heatmap on an explicit Axes object; rows are decoder
# queries and columns are encoder keys, as the axis labels state.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(decoder_cross_attention, annot=True, cmap='viridis', ax=ax)
ax.set_title("Decoder Cross-Attention (Layer 2, Head 1)")
ax.set_xlabel("Encoder Key Position")
ax.set_ylabel("Decoder Query Position")
plt.show()

## Attention Pattern Analysis

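Let's analyze what patterns emerge in the attention weights. Keep in mind that the model has not been trained in this notebook, so these statistics describe an untrained model rather than learned behaviour.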

In [None]:
# Calculate statistics for each attention type.
# Each entry pairs a display name with one of the 2-D NumPy attention maps
# extracted in the visualisation cells above.
for name, attention_type in [
    ("Encoder Self-Attention", encoder_attention),
    ("Decoder Self-Attention", decoder_self_attention),
    ("Decoder Cross-Attention", decoder_cross_attention)
]:
    print(f"\n{name} Statistics:")
    print(f"Mean: {attention_type.mean():.4f}")
    print(f"Max: {attention_type.max():.4f}")
    print(f"Min: {attention_type.min():.4f}")
    print(f"Standard Deviation: {attention_type.std():.4f}")

    # Check for diagonal dominance in self-attention (query i attending to
    # key i). Cross-attention is skipped because its name lacks "Self".
    if "Self" in name:
        diagonal = np.diag(attention_type)
        # Boolean mask that is True everywhere except on the main diagonal.
        off_diagonal = attention_type[~np.eye(attention_type.shape[0], dtype=bool)]
        # NOTE(review): if the decoder self-attention is causally masked, the
        # masked-out entries are counted in the off-diagonal mean and pull it
        # down -- confirm whether that is intended.

        diagonal_mean = diagonal.mean()
        # A 1x1 map has no off-diagonal entries; report nan instead of
        # triggering NumPy's "mean of empty slice" warning.
        has_off_diagonal = off_diagonal.size > 0
        off_diagonal_mean = off_diagonal.mean() if has_off_diagonal else float("nan")

        # The ratio is undefined without off-diagonal mass; report nan instead
        # of dividing by zero.
        if has_off_diagonal and off_diagonal_mean != 0:
            ratio = diagonal_mean / off_diagonal_mean
        else:
            ratio = float("nan")

        print(f"Diagonal Mean: {diagonal_mean:.4f}")
        print(f"Off-Diagonal Mean: {off_diagonal_mean:.4f}")
        print(f"Diagonal/Off-Diagonal Ratio: {ratio:.4f}")