# Extract Attention Maps from NLLB Model
Explore and extract encoder/decoder attention weights for English → French translation

In [None]:
# Import all required libraries
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from datasets import load_from_disk
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

## 1. Load model and data

In [None]:
# Load the tokenizer and model from the local checkpoint directory.
# Eager attention is requested because it is required for output_attentions=True.
model_dir = "../models/nllb-600M"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, attn_implementation="eager")

# Choose the compute device: CUDA first, then Apple MPS, otherwise CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

model = model.to(device)
print(f"Model loaded on device: {device}")
print("Attention implementation: eager")

# Load the sentence-pair dataset stored on disk
dataset = load_from_disk("../data/wmt14_fr-en_validation_2000")
print(f"\nLoaded {len(dataset)} sentence pairs")

## 2. Extract attention from a single example

In [None]:
# Pull the first sentence pair out of the dataset
example = dataset[0]["translation"]
english, french = example["en"], example["fr"]

print(f"English: {english}")
print(f"French:  {french}")

# Encode the English sentence, with English declared as the source language
tokenizer.src_lang = "eng_Latn"
inputs = tokenizer(english, return_tensors="pt").to(device)

source_ids = inputs['input_ids']
print(f"\nInput shape: {source_ids.shape}")
print(f"Input tokens: {tokenizer.convert_ids_to_tokens(source_ids[0])}")

In [None]:
# Translate into French and ask generate() to hand back attention weights too
generation_kwargs = dict(
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),
    max_length=100,
    output_attentions=True,  # IMPORTANT: without this no attention is returned
    return_dict_in_generate=True,  # structured output instead of a bare tensor
)
with torch.no_grad():
    outputs = model.generate(**inputs, **generation_kwargs)

# Turn the generated ids back into text
decoded_batch = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
translation = decoded_batch[0]
print(f"Translation: {translation}")
print(f"\nOutput tokens: {tokenizer.convert_ids_to_tokens(outputs.sequences[0])}")

## 3. Understand attention structure

In [None]:
# List which attention outputs the generation result actually carries
print("Output keys:", outputs.keys())
print("\nAttention types available:")

# (attribute on outputs, label to print, what the outer length counts)
attention_kinds = (
    ('encoder_attentions', "Encoder self-attention", "layers"),
    ('decoder_attentions', "Decoder self-attention", "timesteps"),
    ('cross_attentions', "Cross-attention", "timesteps"),
)
for attr_name, kind_label, unit in attention_kinds:
    attn_value = getattr(outputs, attr_name, None)
    if attn_value is not None:
        print(f"  - {kind_label}: {len(attn_value)} {unit}")

In [None]:
# Report how the encoder self-attention is laid out
if outputs.encoder_attentions is not None:
    encoder_attn = outputs.encoder_attentions

    # Index 0 drops the batch dimension of the final layer
    last_layer_attn = encoder_attn[-1][0]

    report_lines = [
        "Encoder attention:",
        f"  Number of layers: {len(encoder_attn)}",
        f"  Shape per layer: {encoder_attn[0].shape}",  # (batch, heads, seq_len, seq_len)
        "  Format: (batch_size, num_heads, seq_length, seq_length)",
        f"\n  Last layer shape: {last_layer_attn.shape}",
        f"  Number of attention heads: {last_layer_attn.shape[0]}",
    ]
    print("\n".join(report_lines))

In [None]:
# Report the structure of the per-timestep attention outputs
# (decoder self-attention and cross-attention share the same reporting code)
def _report_stepwise_attention(header, stepwise_attn, shape_format, varies_line):
    """Print timestep/layer counts and the first-layer shape of the first few timesteps."""
    print(f"\n{header}")
    print(f"  Number of timesteps: {len(stepwise_attn)}")
    print(f"  Each timestep contains {len(stepwise_attn[0])} layers")
    print(f"  Shape format: {shape_format}")
    print(f"  {varies_line}")
    # Only the first few timesteps are shown to illustrate the pattern
    for step in range(min(3, len(stepwise_attn))):
        print(f"    Timestep {step}: {stepwise_attn[step][0].shape}")


if outputs.decoder_attentions is not None:
    _report_stepwise_attention(
        "Decoder self-attention:",
        outputs.decoder_attentions,
        "(batch_size, num_heads, query_length, key_length)",
        "Shape varies per timestep due to causal masking:",
    )

if outputs.cross_attentions is not None:
    _report_stepwise_attention(
        "Cross-attention (decoder attending to encoder):",
        outputs.cross_attentions,
        "(batch_size, num_heads, decoder_length, encoder_length)",
        "Shape varies per timestep:",
    )

## 4. Visualize encoder self-attention

In [None]:
# Last encoder layer, batch element 0, as a NumPy array: (heads, seq_len, seq_len)
encoder_attn_last = outputs.encoder_attentions[-1][0].cpu().numpy()

# Collapse the head axis by averaging -> (seq_len, seq_len)
encoder_attn_avg = encoder_attn_last.mean(axis=0)

# Token strings, used as tick labels on both axes
input_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].cpu())

# Draw the head-averaged attention as a heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(
    encoder_attn_avg,
    cmap='viridis',
    xticklabels=input_tokens,
    yticklabels=input_tokens,
    cbar_kws={'label': 'Attention Weight'},
)
plt.title('Encoder Self-Attention (Last Layer, Averaged over Heads)\nEnglish Sentence')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.tight_layout()
plt.show()

# Summary statistics of the averaged matrix
print(f"Attention matrix shape: {encoder_attn_avg.shape}")
for stat_name, stat_value in (
    ("Min", encoder_attn_avg.min()),
    ("Max", encoder_attn_avg.max()),
    ("Mean", encoder_attn_avg.mean()),
):
    print(f"{stat_name} attention: {stat_value:.4f}")

## 5. Visualize individual attention heads

In [None]:
# Show up to four heads of the last encoder layer on a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.flatten()

heads_to_show = min(4, encoder_attn_last.shape[0])
for head_idx in range(heads_to_show):
    head_ax = axes[head_idx]
    attn_head = encoder_attn_last[head_idx]

    sns.heatmap(
        attn_head,
        ax=head_ax,
        cmap='viridis',
        xticklabels=input_tokens,
        yticklabels=input_tokens,
        cbar_kws={'label': 'Weight'},
    )
    head_ax.set_title(f'Attention Head {head_idx}')
    head_ax.set_xlabel('Key')
    head_ax.set_ylabel('Query')

plt.suptitle('Encoder Self-Attention Heads (Last Layer)', y=1.02, fontsize=14)
plt.tight_layout()
plt.show()

## 5b. Compare attention across layers (early vs late)

In [None]:
# Put the first, middle and last encoder layers side by side
num_layers = len(outputs.encoder_attentions)
middle_layer = num_layers // 2
final_layer = num_layers - 1
layer_indices = [0, middle_layer, final_layer]
layer_names = ['First Layer (0)', f'Middle Layer ({middle_layer})', f'Last Layer ({final_layer})']

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for idx, layer_idx in enumerate(layer_indices):
    layer_name = layer_names[idx]
    panel = axes[idx]

    # Attention of this layer for batch element 0, then the mean over heads
    layer_attn = outputs.encoder_attentions[layer_idx][0].cpu().numpy()  # (heads, seq_len, seq_len)
    layer_attn_avg = layer_attn.mean(axis=0)

    sns.heatmap(
        layer_attn_avg,
        ax=panel,
        cmap='viridis',
        xticklabels=input_tokens,
        yticklabels=input_tokens,
        cbar_kws={'label': 'Weight'},
    )
    panel.set_title(f'{layer_name}\n(Averaged over {layer_attn.shape[0]} heads)')
    panel.set_xlabel('Key Position')
    panel.set_ylabel('Query Position')

plt.suptitle('Encoder Self-Attention: Comparing Early vs Middle vs Late Layers', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

# Per-layer statistics, printed after the figure
print("Attention statistics by layer:")
print("=" * 60)
for layer_idx, layer_name in zip(layer_indices, layer_names):
    layer_attn = outputs.encoder_attentions[layer_idx][0].cpu().numpy()
    layer_attn_avg = layer_attn.mean(axis=0)

    # Content-to-content attention only: the first token is assumed to be
    # the language tag and the last one </s>, so both are cut off
    content_attn = layer_attn_avg[1:-1, 1:-1]

    full_mean, full_max = layer_attn_avg.mean(), layer_attn_avg.max()
    print(f"\n{layer_name}:")
    print(f"  Full attention - Mean: {full_mean:.4f}, Max: {full_max:.4f}")
    if content_attn.size > 0:
        print(f"  Content-only - Mean: {content_attn.mean():.4f}, Max: {content_attn.max():.4f}")

## 5c. Filter special tokens and apply threshold

In [None]:
def filter_and_threshold_attention(attn_matrix, tokens, threshold=0.1):
    """
    Filter special tokens and apply threshold to attention matrix.

    The first and last tokens are treated as special tokens (usually the
    language tag, e.g. eng_Latn, and </s>); their rows and columns are
    dropped. Remaining weights below ``threshold`` are set to 0 and every
    row that still has weight is rescaled to sum to 1. Rows that end up
    entirely zero stay zero. The input matrix is not modified.

    Args:
        attn_matrix: (seq_len, seq_len) NumPy attention matrix
        tokens: list of token strings, one per row/column of attn_matrix
        threshold: minimum attention weight to keep (default 0.1)

    Returns:
        dict with:
            - 'filtered_attn': filtered, thresholded and row-renormalized
              attention matrix (content tokens only)
            - 'content_tokens': tokens without the first and last token
            - 'num_edges': number of non-zero entries after thresholding
            - 'sparsity': fraction of entries in 'filtered_attn' that are
              zero; 0.0 when there are no content tokens (empty matrix)
    """
    # Content positions are 1 .. len(tokens) - 2. The stop index is clamped
    # so that inputs with two or fewer tokens give an empty selection
    # instead of wrapping around through a negative index.
    content_stop = max(len(tokens) - 1, 1)
    content_tokens = list(tokens[1:content_stop])

    # Work on a copy so the caller's matrix is never modified
    thresholded_attn = attn_matrix[1:content_stop, 1:content_stop].copy()

    # Apply threshold (set values below threshold to 0)
    thresholded_attn[thresholded_attn < threshold] = 0

    # Renormalize rows to sum to 1; all-zero rows are divided by 1 and stay zero
    row_sums = thresholded_attn.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1  # Avoid division by zero
    renormalized_attn = thresholded_attn / row_sums

    # Count edges (non-zero entries)
    num_edges = (renormalized_attn > 0).sum()

    # An empty matrix has no entries to be sparse over; report 0.0 rather
    # than the nan that 0 / 0 would produce
    total_entries = renormalized_attn.size
    sparsity = 1 - (num_edges / total_entries) if total_entries else 0.0

    return {
        'filtered_attn': renormalized_attn,
        'content_tokens': content_tokens,
        'num_edges': num_edges,
        'sparsity': sparsity,
    }

# Try a few thresholds on the head-averaged attention and report the effect
thresholds = [0.0, 0.1, 0.2]
print("Testing different threshold values:")
print("=" * 70)

for thresh in thresholds:
    result = filter_and_threshold_attention(encoder_attn_avg, input_tokens, threshold=thresh)
    kept_attn = result['filtered_attn']
    summary_lines = [
        f"\nThreshold: {thresh}",
        f"  Content tokens: {len(result['content_tokens'])}",
        f"  Matrix shape: {kept_attn.shape}",
        f"  Number of edges: {result['num_edges']}",
        f"  Sparsity: {result['sparsity']:.2%}",
        f"  Mean attention: {kept_attn.mean():.4f}",
        f"  Max attention: {kept_attn.max():.4f}",
    ]
    print("\n".join(summary_lines))

In [None]:
# Side-by-side view: raw attention vs. filtered + thresholded attention
threshold = 0.1
filtered_result = filter_and_threshold_attention(encoder_attn_avg, input_tokens, threshold=threshold)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
before_ax, after_ax = axes[0], axes[1]

# Left panel: the original matrix, special tokens included
sns.heatmap(
    encoder_attn_avg,
    ax=before_ax,
    cmap='viridis',
    xticklabels=input_tokens,
    yticklabels=input_tokens,
    cbar_kws={'label': 'Weight'},
)
before_ax.set_title(f'Before Filtering\n({len(input_tokens)} tokens, including special tokens)')
before_ax.set_xlabel('Key Position')
before_ax.set_ylabel('Query Position')

# Right panel: content tokens only, weights below the threshold removed
sns.heatmap(
    filtered_result['filtered_attn'],
    ax=after_ax,
    cmap='viridis',
    xticklabels=filtered_result['content_tokens'],
    yticklabels=filtered_result['content_tokens'],
    cbar_kws={'label': 'Weight'},
)
after_ax.set_title(f'After Filtering + Threshold={threshold}\n({len(filtered_result["content_tokens"])} content tokens, {filtered_result["num_edges"]} edges, {filtered_result["sparsity"]:.1%} sparse)')
after_ax.set_xlabel('Key Position')
after_ax.set_ylabel('Query Position')

plt.suptitle('Attention Filtering: Removing Special Tokens + Thresholding', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

# Text summary of what the filter removed
print("\nFiltering summary:")
print(f"  Original shape: {encoder_attn_avg.shape}")
print(f"  Filtered shape: {filtered_result['filtered_attn'].shape}")
print(f"  Removed tokens: {[input_tokens[0], input_tokens[-1]]}")
print(f"  Content tokens: {filtered_result['content_tokens']}")

In [None]:
# Sweep small thresholds to see how much attention structure survives
print("Testing lower thresholds to preserve attention structure")
print("=" * 70)

low_thresholds = [0.0, 0.01, 0.02, 0.05, 0.10]
panel_count = len(low_thresholds)

fig, axes = plt.subplots(1, panel_count, figsize=(5*panel_count, 5))

for idx, thresh in enumerate(low_thresholds):
    result = filter_and_threshold_attention(encoder_attn_avg, input_tokens, threshold=thresh)
    kept_attn = result['filtered_attn']
    density = 1 - result['sparsity']
    panel = axes[idx]

    # vmin/vmax are pinned so colors are comparable across panels
    sns.heatmap(
        kept_attn,
        ax=panel,
        cmap='viridis',
        vmin=0,
        vmax=0.3,
        xticklabels=result['content_tokens'],
        yticklabels=result['content_tokens'],
        cbar_kws={'label': 'Weight'},
    )
    panel.set_title(f'Threshold = {thresh}\n{result["num_edges"]} edges ({density:.1%} dense)')
    panel.set_xlabel('Key')
    panel.set_ylabel('Query')

    # Matching text statistics for this threshold
    print(f"\nThreshold {thresh}:")
    print(f"  Edges: {result['num_edges']} / {kept_attn.size} ({density:.1%} dense)")
    print(f"  Mean weight: {kept_attn.mean():.6f}")
    print(f"  Max weight: {kept_attn.max():.6f}")

plt.suptitle('Effect of Threshold on Attention Graph Sparsity', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("\n" + "=" * 70)
print("Recommendation: Use threshold between 0.01-0.05 to balance sparsity and structure")

In [None]:
# Check individual attention heads for diagonal patterns.
# Averaging over heads may wash out a diagonal (token-attends-to-itself)
# pattern that single heads still show, so each head is scored on its own.
print("Checking individual attention heads for diagonal patterns")
print("=" * 70)

encoder_attn_last_layer = outputs.encoder_attentions[-1][0].cpu().numpy()  # (heads, seq_len, seq_len)
num_heads = encoder_attn_last_layer.shape[0]

# Content-only view of the head-averaged attention plus its tick labels.
# Both are defined here so this cell runs without relying on a later cell;
# the first and last tokens are dropped as special tokens.
content_attn = encoder_attn_last_layer.mean(axis=0)[1:-1, 1:-1]
content_tokens_list = input_tokens[1:-1]

# For each head, compare the mean diagonal weight with the mean off-diagonal weight
head_diagonal_strengths = []
for head_idx in range(num_heads):
    head_attn = encoder_attn_last_layer[head_idx]
    # Get content-only attention (remove special tokens)
    content_head_attn = head_attn[1:-1, 1:-1]

    # Compute diagonal vs off-diagonal mean
    diag_mean = np.diag(content_head_attn).mean()
    off_diag_mask = ~np.eye(content_head_attn.shape[0], dtype=bool)
    off_diag_mean = content_head_attn[off_diag_mask].mean()

    # A ratio above 1 means the head puts more weight on the diagonal;
    # 0 is reported when the off-diagonal mean is not positive
    ratio = diag_mean / off_diag_mean if off_diag_mean > 0 else 0
    head_diagonal_strengths.append({
        'head': head_idx,
        'diag_mean': diag_mean,
        'off_diag_mean': off_diag_mean,
        'ratio': ratio
    })

# Sort by ratio (strongest diagonal first)
head_diagonal_strengths.sort(key=lambda x: x['ratio'], reverse=True)

print(f"\nAttention heads ranked by diagonal strength (top 5):")
print(f"{'Head':<6} {'Diag Mean':<12} {'Off-Diag Mean':<15} {'Ratio':<10}")
print("-" * 50)
for stats in head_diagonal_strengths[:5]:
    print(f"{stats['head']:<6} {stats['diag_mean']:<12.6f} {stats['off_diag_mean']:<15.6f} {stats['ratio']:<10.2f}x")

# Visualize the head with strongest diagonal
best_head_idx = head_diagonal_strengths[0]['head']
best_head_attn = encoder_attn_last_layer[best_head_idx][1:-1, 1:-1]  # Remove special tokens

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot the best head (strongest diagonal)
sns.heatmap(best_head_attn,
            xticklabels=content_tokens_list,
            yticklabels=content_tokens_list,
            cmap='viridis',
            ax=axes[0],
            cbar_kws={'label': 'Weight'})
axes[0].set_title(f'Head {best_head_idx} (Strongest Diagonal)\nRatio: {head_diagonal_strengths[0]["ratio"]:.2f}x')
axes[0].set_xlabel('Key')
axes[0].set_ylabel('Query')

# Plot the averaged attention for comparison
sns.heatmap(content_attn,
            xticklabels=content_tokens_list,
            yticklabels=content_tokens_list,
            cmap='viridis',
            ax=axes[1],
            cbar_kws={'label': 'Weight'})
axes[1].set_title(f'Averaged Over All {num_heads} Heads\n(This is what we\'ve been using)')
axes[1].set_xlabel('Key')
axes[1].set_ylabel('Query')

# Plot distribution of diagonal ratios across all heads (already sorted)
ratios = [h['ratio'] for h in head_diagonal_strengths]
axes[2].bar(range(len(ratios)), ratios)
axes[2].axhline(1.0, color='red', linestyle='--', label='Equal (ratio=1)')
axes[2].set_xlabel('Attention Head (sorted by diagonal strength)')
axes[2].set_ylabel('Diagonal / Off-Diagonal Ratio')
axes[2].set_title(f'Diagonal Strength Across All {num_heads} Heads')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.suptitle('Individual Attention Heads: Is Averaging Hiding Diagonal Patterns?', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print(f"\n✓ Found {sum(1 for h in head_diagonal_strengths if h['ratio'] > 1.0)} heads with diagonal > off-diagonal")

In [None]:
# Look at how the content-only attention weights are distributed
print("Attention Value Distribution Analysis")
print("=" * 70)

# Drop the first and last (special) tokens from the averaged matrix and labels
content_attn = encoder_attn_avg[1:-1, 1:-1]
content_tokens_list = input_tokens[1:-1]

print(f"\nContent-only attention matrix shape: {content_attn.shape}")
print(f"Number of content tokens: {len(content_tokens_list)}")
print(f"Content tokens: {content_tokens_list}")

# Split the matrix into its diagonal and everything else
diagonal_values = np.diag(content_attn)
off_diagonal_mask = ~np.eye(content_attn.shape[0], dtype=bool)
off_diagonal_values = content_attn[off_diagonal_mask]

diag_mean = diagonal_values.mean()
off_diag_mean = off_diagonal_values.mean()

print("\nDiagonal (self-attention) values:")
print(f"  Min: {diagonal_values.min():.6f}")
print(f"  Max: {diagonal_values.max():.6f}")
print(f"  Mean: {diag_mean:.6f}")
print(f"  Median: {np.median(diagonal_values):.6f}")
print(f"  Values: {diagonal_values}")

print("\nOff-diagonal values:")
print(f"  Min: {off_diagonal_values.min():.6f}")
print(f"  Max: {off_diagonal_values.max():.6f}")
print(f"  Mean: {off_diag_mean:.6f}")
print(f"  Median: {np.median(off_diagonal_values):.6f}")

verdict = 'STRONGER' if diag_mean > off_diag_mean else 'WEAKER'
print("\nComparison:")
print(f"  Diagonal mean / Off-diagonal mean: {diag_mean / off_diag_mean:.2f}x")
print(f"  → Diagonal is {verdict} than off-diagonal")

# Three views of the distribution
fig, axes = plt.subplots(1, 3, figsize=(18, 4))
all_ax, split_ax, box_ax = axes[0], axes[1], axes[2]

# Panel 1: histogram of every value, with both means marked
all_ax.hist(content_attn.flatten(), bins=50, alpha=0.7, label='All values', edgecolor='black')
all_ax.axvline(diag_mean, color='red', linestyle='--', linewidth=2, label=f'Diagonal mean ({diag_mean:.4f})')
all_ax.axvline(off_diag_mean, color='blue', linestyle='--', linewidth=2, label=f'Off-diag mean ({off_diag_mean:.4f})')
all_ax.set_xlabel('Attention Weight')
all_ax.set_ylabel('Frequency')
all_ax.set_title('Distribution of All Attention Values')
all_ax.legend()
all_ax.grid(True, alpha=0.3)

# Panel 2: diagonal and off-diagonal histograms overlaid
split_ax.hist(diagonal_values, bins=20, alpha=0.7, color='red', label='Diagonal', edgecolor='black')
split_ax.hist(off_diagonal_values, bins=50, alpha=0.5, color='blue', label='Off-diagonal', edgecolor='black')
split_ax.set_xlabel('Attention Weight')
split_ax.set_ylabel('Frequency')
split_ax.set_title('Diagonal vs Off-Diagonal Values')
split_ax.legend()
split_ax.grid(True, alpha=0.3)

# Panel 3: box plot of the two groups
box_ax.boxplot([diagonal_values, off_diagonal_values], labels=['Diagonal', 'Off-diagonal'])
box_ax.set_ylabel('Attention Weight')
box_ax.set_title('Value Distribution Comparison')
box_ax.grid(True, alpha=0.3)

plt.suptitle('Content-Only Attention: Analyzing Self-Attention (Diagonal) Patterns', fontsize=14)
plt.tight_layout()
plt.show()

## 5d. Investigate attention patterns: Why no diagonal?

## 6. Extract and save attention for analysis

In [None]:
def extract_attention_maps(text, tokenizer, model, device, src_lang="eng_Latn", tgt_lang="fra_Latn"):
    """
    Extract encoder self-attention for a given text.

    Only the encoder is run, since nothing produced by the decoder is
    returned; this avoids generating a translation for every sentence.

    Args:
        text: source sentence to encode
        tokenizer: tokenizer for the model; its ``src_lang`` attribute is set
            to ``src_lang`` and stays set after the call
        model: seq2seq model providing ``get_encoder()``; it has to be able
            to return attention weights (the notebook loads it with
            attn_implementation="eager" for that reason)
        device: device the tokenized inputs are moved to
        src_lang: source language code assigned to the tokenizer
        tgt_lang: not used (no decoding happens); kept so existing calls
            that pass it keep working

    Returns:
        dict with:
            - 'tokens': list of tokens
            - 'encoder_attention': numpy array (layers, heads, seq_len, seq_len)
            - 'encoder_attention_avg': numpy array (seq_len, seq_len) - averaged over layers and heads
    """
    # Tokenize
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].cpu())

    # Run the encoder alone and ask it for its attention weights
    with torch.no_grad():
        encoder_outputs = model.get_encoder()(
            **inputs,
            output_attentions=True,
            return_dict=True,
        )

    # One tensor per layer; index 0 drops the batch dimension before stacking
    encoder_attn_all = torch.stack([layer[0] for layer in encoder_outputs.attentions]).cpu().numpy()
    # Shape: (layers, heads, seq_len, seq_len)

    # Average over layers and heads
    encoder_attn_avg = encoder_attn_all.mean(axis=(0, 1))  # (seq_len, seq_len)

    return {
        'tokens': tokens,
        'encoder_attention': encoder_attn_all,
        'encoder_attention_avg': encoder_attn_avg
    }

# Smoke-test the extractor on the current English sentence
test_result = extract_attention_maps(english, tokenizer, model, device)
print(
    f"Extracted attention for: {english}\n"
    f"Tokens: {test_result['tokens']}\n"
    f"Encoder attention shape: {test_result['encoder_attention'].shape}\n"
    f"Averaged attention shape: {test_result['encoder_attention_avg'].shape}"
)

## 7. Test on multiple examples

In [None]:
# Run the extractor over the first few dataset entries
num_examples = 5
attention_data = []

print(f"Extracting attention for {num_examples} examples...\n")

for i in range(num_examples):
    example = dataset[i]["translation"]
    english, french = example["en"], example["fr"]

    result = extract_attention_maps(english, tokenizer, model, device)

    # Keep only the tokens and the layer/head-averaged matrix per example
    attention_data.append(dict(
        index=i,
        english=english,
        french=french,
        tokens=result['tokens'],
        attention_avg=result['encoder_attention_avg'],
    ))

    print(f"[{i+1}/{num_examples}] Extracted attention for: {english[:50]}...")

print(f"\n✓ Extracted attention for {len(attention_data)} examples")

In [None]:
# Heatmaps of the averaged attention for up to three extracted examples
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for i, data in enumerate(attention_data[:3]):
    panel = axes[i]
    sns.heatmap(
        data['attention_avg'],
        ax=panel,
        cmap='viridis',
        xticklabels=data['tokens'],
        yticklabels=data['tokens'],
        cbar_kws={'label': 'Weight'},
    )
    panel.set_title(f"Example {i+1}\n{data['english'][:40]}...", fontsize=10)
    panel.set_xlabel('Key')
    panel.set_ylabel('Query')

plt.suptitle('Encoder Self-Attention (Averaged)', fontsize=14)
plt.tight_layout()
plt.show()

## Summary

**Successfully extracted:**
- Encoder self-attention maps (English sentence structure)
- Attention weights across all layers and heads
- Averaged attention for graph construction

**Next steps:**
1. Create a batch script to process all 2000 examples
2. Build attention graphs (tokens as nodes, attention weights as edges)
3. Compute persistent homology on the graphs