# Transformer Attention Visualization

This notebook visualizes the attention patterns in a trained Transformer model. It shows how different attention heads focus on different parts of the input sequence, which helps reveal the model's internal representations.

In [None]:
import torch
import torch.nn as nn
import altair as alt
import pandas as pd
import numpy as np
import warnings
import time
import gc
from pathlib import Path
from IPython.display import display

from model import Transformer
from config import get_config, get_weights_file_path, latest_weights_file_path
from train import get_model, get_ds, greedy_decode

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

In [3]:
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, otherwise CPU
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)
print("Using device:", device)

# On CUDA: release cached allocator blocks and report the starting memory footprint
if device.type == 'cuda':
    torch.cuda.empty_cache()
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    _mb_allocated = torch.cuda.memory_allocated(0) / 1024**2
    print(f"Memory allocated: {_mb_allocated:.2f} MB")

Using device: cuda


## Model Loading and Setup

Load the trained model and prepare it for visualization.

In [None]:
# Load configuration, datasets/tokenizers, and the pretrained model weights.
# Errors are printed and then re-raised: every later cell depends on `config`,
# `model` and the tokenizers, so carrying on after a failure would only surface
# as confusing NameErrors further down the notebook.
PRELOAD_EPOCH = "29"  # checkpoint epoch to visualize; falls back to the latest weights if missing

start_time = time.time()
print("Loading configuration and datasets...")

try:
    config = get_config()

    # Datasets and tokenizers (the tokenizers provide token labels for the plots)
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    print(f"Loaded datasets in {time.time() - start_time:.2f}s")
    print(f"Source vocabulary size: {tokenizer_src.get_vocab_size():,}")
    print(f"Target vocabulary size: {tokenizer_tgt.get_vocab_size():,}")

    # Build the model on the selected device
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model has {total_params:,} parameters")

    # Resolve the checkpoint path, falling back to the most recent one
    model_filename = get_weights_file_path(config, PRELOAD_EPOCH)
    if not Path(model_filename).exists():
        print(f"Warning: Weight file not found at {model_filename}")
        print("Trying to find latest weights instead...")
        model_filename = latest_weights_file_path(config)
        if not model_filename:
            raise FileNotFoundError("No model weights found!")

    print(f"Loading weights from {model_filename}")
    # map_location loads tensors straight onto the target device.
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    state = torch.load(model_filename, map_location=device)
    # Accept both a training checkpoint ({'model_state_dict': ..., 'epoch': ...})
    # and a bare state_dict saved with torch.save(model.state_dict(), ...).
    model.load_state_dict(state.get('model_state_dict', state))
    print(f"Loaded weights from epoch {state.get('epoch', 'unknown')}")

    # Evaluation mode turns off training-only behavior (e.g. dropout) for visualization
    model.eval()

    print(f"Model preparation completed in {time.time() - start_time:.2f}s")
except Exception as e:
    print(f"Error during model setup: {e}")
    raise

## Data Preparation for Visualization

Prepare sample data for the attention visualization.

In [None]:
def _ids_to_tokens(ids, tokenizer):
    """Convert a sequence of token ids to token strings, one entry per id.

    Ids the tokenizer cannot map (``id_to_token`` raises, or returns a falsy
    value such as ``None``) are rendered as "[UNK]".
    """
    tokens = []
    for idx in ids:
        try:
            token = tokenizer.id_to_token(int(idx))
        except Exception:  # unlike a bare except, lets KeyboardInterrupt/SystemExit through
            token = None
        tokens.append(token if token else "[UNK]")
    return tokens


def load_next_batch():
    """Load a validation batch, run the model on it, and return token labels.

    Reads the notebook globals ``val_dataloader``, ``model``, ``config``,
    ``device``, ``tokenizer_src`` and ``tokenizer_tgt`` from the setup cells.

    Returns:
        tuple: (batch data, encoder tokens, decoder tokens)

    Raises:
        ValueError: if the batch does not contain exactly one sample.
        Exception: any loading/decoding error is printed and re-raised.
    """
    try:
        # A fresh iterator per call yields the first batch in the loader's
        # current order (a different one each call only if the loader shuffles).
        batch = next(iter(val_dataloader))
        encoder_input = batch["encoder_input"].to(device)
        encoder_mask = batch["encoder_mask"].to(device)

        # Plots only use sample 0; explicit check instead of `assert` (stripped under -O)
        if encoder_input.size(0) != 1:
            raise ValueError("Batch size must be 1 for visualization")

        encoder_input_tokens = _ids_to_tokens(encoder_input[0].tolist(), tokenizer_src)
        decoder_input_tokens = _ids_to_tokens(batch["decoder_input"][0].tolist(), tokenizer_tgt)

        # The decoded output itself is unused; the call is made for its forward pass,
        # whose attention scores get_attn_map reads afterwards (assumes the attention
        # blocks store `attention_scores` during forward — confirm in model.py).
        # NOTE(review): decoder/cross-attention then reflect the *generated* sequence,
        # while the decoder labels above come from the ground-truth `decoder_input`.
        with torch.no_grad():
            greedy_decode(
                model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt,
                config['seq_len'], device
            )

        return batch, encoder_input_tokens, decoder_input_tokens
    except Exception as e:
        print(f"Error loading batch: {e}")
        raise

In [None]:
# Visualization functions for attention maps
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten the top-left corner of an attention matrix into long-form data for Altair.

    Args:
        m: 2-D attention matrix — a torch tensor (any device) or anything
            ``np.asarray`` accepts.
        max_row: Maximum number of rows to include (clipped to ``m.shape[0]``).
        max_col: Maximum number of columns to include (clipped to ``m.shape[1]``).
        row_tokens: Tokens for rows; rows past its end are labelled "<blank>".
        col_tokens: Tokens for columns; columns past its end are labelled "<blank>".

    Returns:
        pd.DataFrame: one record per (row, column) cell with columns
        ``row``, ``column``, ``value``, ``row_token``, ``col_token``.
        Labels carry a zero-padded index prefix so that Altair's alphabetical
        ordering of nominal axes follows sequence order.
    """
    # Convert once up front instead of indexing the tensor for every cell
    if isinstance(m, torch.Tensor):
        m = m.detach().cpu().numpy()
    values = np.asarray(m)

    n_rows = min(max_row, values.shape[0])
    n_cols = min(max_col, values.shape[1])

    def _label(i, tokens):
        token = tokens[i] if i < len(tokens) else "<blank>"
        return f"{i:03d} {token}"

    # Labels depend only on the row/column index, so build them once each
    row_labels = [_label(r, row_tokens) for r in range(n_rows)]
    col_labels = [_label(c, col_tokens) for c in range(n_cols)]

    data = [
        (r, c, float(values[r, c]), row_labels[r], col_labels[c])  # Python float for Altair/JSON
        for r in range(n_rows)
        for c in range(n_cols)
    ]

    return pd.DataFrame(
        data,
        columns=["row", "column", "value", "row_token", "col_token"]
    )

def get_attn_map(attn_type: str, layer: int, head: int):
    """Extract one head's attention scores from the notebook-global ``model``.

    Args:
        attn_type: Type of attention ("encoder", "decoder", or "encoder-decoder")
        layer: Layer index
        head: Attention head index

    Returns:
        torch.Tensor: ``attention_scores[0, head]`` copied to the CPU, or a
        1x1 zero tensor if the scores cannot be retrieved (error is printed).

    Raises:
        ValueError: if ``attn_type`` is not one of the three supported types.
    """
    # attn_type -> (layer stack on the model, attention block inside each layer)
    locations = {
        "encoder": ("encoder", "self_attention_block"),
        "decoder": ("decoder", "self_attention_block"),
        "encoder-decoder": ("decoder", "cross_attention_block"),
    }
    # Validate outside the try: an unknown type is a caller bug and must not be
    # masked by the zero-matrix fallback below.
    if attn_type not in locations:
        raise ValueError(f"Unknown attention type: {attn_type}")
    stack_name, block_name = locations[attn_type]

    try:
        block = getattr(getattr(model, stack_name).layers[layer], block_name)
        attn = block.attention_scores
        # Detach first (no autograd history is copied), then copy so the result
        # does not alias the model's stored scores.
        return attn[0, head].detach().cpu().clone()
    except Exception as e:
        # Best-effort: a missing layer/head or unpopulated scores yields an empty plot
        print(f"Error getting attention map for {attn_type}, layer {layer}, head {head}: {e}")
        return torch.zeros((1, 1))

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    """Render a single attention head as an interactive Altair heatmap.

    Args:
        attn_type: Type of attention
        layer: Layer index
        head: Head index
        row_tokens: Tokens for rows
        col_tokens: Tokens for columns
        max_sentence_len: Maximum number of rows and columns to draw

    Returns:
        alt.Chart: 400x400 heatmap of attention weights, NaN cells removed
    """
    scores = get_attn_map(attn_type, layer, head)
    # NaN cells can break rendering, so drop them before charting
    cells = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens).dropna()

    heading = f"{attn_type.capitalize()} Layer {layer} Head {head}"
    encoding = dict(
        x=alt.X("col_token:N", axis=alt.Axis(title="", labelAngle=-45)),
        y=alt.Y("row_token:N", axis=alt.Axis(title="")),
        color=alt.Color("value:Q", scale=alt.Scale(scheme="viridis")),
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    return (
        alt.Chart(data=cells)
        .mark_rect()
        .encode(**encoding)
        .properties(height=400, width=400, title=heading)
        .interactive()
    )

def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], 
                           row_tokens: list, col_tokens: list, max_sentence_len: int):
    """Lay out one heatmap per (layer, head): a row of charts per layer, a column per head.

    Args:
        attn_type: Type of attention
        layers: Layer indices, one chart row each
        heads: Head indices, one chart column each
        row_tokens: Tokens for rows
        col_tokens: Tokens for columns
        max_sentence_len: Maximum sentence length to display

    Returns:
        alt.VConcatChart: Grid of attention visualizations
    """
    print(f"Generating {len(layers) * len(heads)} attention maps for {attn_type}...")
    t0 = time.time()

    layer_rows = []
    for done, layer in enumerate(layers, start=1):
        head_charts = [
            attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len)
            for head in heads
        ]
        layer_rows.append(alt.hconcat(*head_charts, spacing=5))

        # Progress report after every second layer
        if done % 2 == 0:
            print(f"Processed {done}/{len(layers)} layers...")

    grid = alt.vconcat(*layer_rows, spacing=10)
    print(f"Completed in {time.time() - t0:.2f}s")
    return grid

In [None]:
# Load and prepare sample data for visualization.
# Errors are printed and re-raised: the visualization cells below need `batch`,
# the token lists and `sentence_len`, which would otherwise be left undefined.
FALLBACK_SENTENCE_LEN = 20  # tokens to show when no [PAD] marks the end of the content

try:
    batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()

    # Display source and target text
    print(f"Source: {batch['src_text'][0]}")
    print(f"Target: {batch['tgt_text'][0]}")

    # The real content ends where padding begins
    try:
        sentence_len = encoder_input_tokens.index("[PAD]")
        print(f"Sentence length (before padding): {sentence_len}")
    except ValueError:
        sentence_len = min(FALLBACK_SENTENCE_LEN, len(encoder_input_tokens))
        print(f"[PAD] token not found. Using length: {sentence_len}")

    # Display token sequences for debugging (slicing already clips to the list length)
    print("\nSource tokens (first 20):")
    print(encoder_input_tokens[:20])
except Exception as e:
    print(f"Error preparing sample: {e}")
    raise

## Attention Visualizations

Generate visualizations for the three types of attention in the Transformer: encoder self-attention, decoder self-attention, and encoder-decoder (cross) attention.

In [None]:
# Which layers and attention heads to plot (extend the ranges if needed)
layers = list(range(3))
heads = list(range(8))  # 8 attention heads

# Cap the number of tokens shown so each heatmap stays legible
vis_len = min(20, sentence_len)

print(f"Visualizing attention for first {vis_len} tokens across {len(layers)} layers and {len(heads)} heads")

# 1. Encoder self-attention: how source tokens attend to one another
get_all_attention_maps(
    "encoder",
    layers,
    heads,
    row_tokens=encoder_input_tokens,
    col_tokens=encoder_input_tokens,
    max_sentence_len=vis_len,
)

In [None]:
# 2. Decoder self-attention: how each target token attends to itself and earlier target tokens
get_all_attention_maps(
    "decoder",
    layers,
    heads,
    row_tokens=decoder_input_tokens,
    col_tokens=decoder_input_tokens,
    max_sentence_len=vis_len,
)

In [None]:
# 3. Cross-attention: rows are decoder (target) tokens, columns are encoder (source) tokens
get_all_attention_maps(
    "encoder-decoder",
    layers,
    heads,
    row_tokens=decoder_input_tokens,
    col_tokens=encoder_input_tokens,
    max_sentence_len=vis_len,
)

## Attention Pattern Analysis

The visualizations above show different patterns in the attention mechanisms:

1. **Encoder Self-Attention**: Shows how words in the source sentence relate to each other. Look for patterns where:
   - Words attend to related words or context
   - Some heads focus on local relationships (nearby words)
   - Others capture long-range dependencies

2. **Decoder Self-Attention**: Shows how words in the target attend to themselves and to earlier words (causal attention). Notice:
   - The triangular pattern (each position can attend only to itself and earlier positions, never to later ones)
   - Different roles for different heads

3. **Cross-Attention**: Shows how target words look at source words. This reveals:
   - Word alignment between languages
   - Which source words are most important for translating each target word

In [None]:
# Clean up memory - important for notebook environments, especially with GPU.
# Idempotent: names that are already gone (e.g. when this cell is re-run) are
# skipped instead of raising NameError. `gc` is imported in the first cell.
# `state` is included because the checkpoint was loaded onto the device
# (map_location=device) and would otherwise keep that memory occupied.
for _name in ("model", "state", "train_dataloader", "val_dataloader"):
    globals().pop(_name, None)

# Collect first so tensors held only by reference cycles are actually freed;
# empty_cache can return only blocks that are no longer occupied.
gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Memory cleaned up")