# Extract Attention Maps from NLLB-1.3B Model

**For Google Colab:**
1. In the cell below, set `ROOT_DIR` to your project folder path (the folder that contains this notebook)
2. Run that cell to mount Google Drive and change into `ROOT_DIR`
3. Run the rest of the notebook

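**For local execution:** Start the kernel in the folder containing this notebook, because all model and data paths are relative (e.g. `../models/nllb-1.3B`). You can skip the Google Drive cell and run from "Import Libraries". The Drive cell only prints a notice when not on Colab, so running all cells is also safe.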

In [None]:
# Mount Google Drive (only needed for Google Colab).
# Outside Colab the `google.colab` import fails and this cell only prints a notice.
import os

try:
    from google.colab import drive
except ImportError:
    print("Not running on Colab, using local environment")
else:
    drive.mount('/content/drive')

    # IMPORTANT: Set this to your code_fr_en directory path
    # (it should point to where THIS notebook is located).
    # It can also be overridden without editing the notebook by setting the
    # CODE_FR_EN_ROOT environment variable before running this cell.
    ROOT_DIR = os.environ.get(
        "CODE_FR_EN_ROOT",
        "/content/drive/MyDrive/UofT/CSC2517/term_paper/code_fr_en",
    )

    # Fail with an actionable message instead of a bare chdir error, because a
    # wrong ROOT_DIR makes every relative path below ("../models", "../data") break.
    if not os.path.isdir(ROOT_DIR):
        raise FileNotFoundError(
            f"ROOT_DIR does not exist: {ROOT_DIR!r}. Edit ROOT_DIR in this cell "
            "(or set CODE_FR_EN_ROOT) to point at your code_fr_en folder."
        )

    os.chdir(ROOT_DIR)
    print(f"✓ Changed to: {os.getcwd()}")

In [None]:
# Sanity-check the working directory and confirm that the model and data folders
# are reachable through the relative paths used throughout this notebook.
import os

print(f"Current directory: {os.getcwd()}")

model_path = "../models/nllb-1.3B"
data_path = "../data/wmt14_fr_en_validation_2000"

for label, path in (("Model", model_path), ("Data", data_path)):
    if os.path.exists(path):
        print(f"✓ {label} directory exists: {path}")
    else:
        print(f"✗ {label} directory NOT found: {path}")

## Import Libraries

In [None]:
# Import all required libraries
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from datasets import load_from_disk
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

## 1. Load model and data

In [None]:
# Load tokenizer + model. The eager attention implementation is requested because
# it is required for output_attentions=True in the cells below.
model_dir = "../models/nllb-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, attn_implementation="eager")

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model = model.to(device)
print(f"Model loaded on device: {device}")
print("Attention implementation: eager")

# Load the pre-saved WMT14 fr-en validation subset from disk.
dataset = load_from_disk("../data/wmt14_fr_en_validation_2000")
print(f"\nLoaded {len(dataset)} sentence pairs")

## 2. Extract attention from a single example

In [None]:
# Work with the first validation pair; the English side is the translation source.
example = dataset[0]["translation"]
english, french = example["en"], example["fr"]

for prefix, sentence in (("English:", english), ("French: ", french)):
    print(f"{prefix} {sentence}")

# The NLLB tokenizer uses `src_lang` to choose the language-code token it adds,
# so it has to be set before the sentence is encoded.
tokenizer.src_lang = "eng_Latn"
inputs = tokenizer(english, return_tensors="pt").to(device)
source_ids = inputs['input_ids']

print(f"\nInput shape: {source_ids.shape}")
print(f"Input tokens: {tokenizer.convert_ids_to_tokens(source_ids[0])}")

In [None]:
# Translate EN -> FR and ask generate() to also hand back the attention tensors.
# forced_bos_token_id pins the first generated token to the French language code.
generation_kwargs = dict(
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),
    max_length=100,
    output_attentions=True,        # IMPORTANT: populate the *_attentions fields
    return_dict_in_generate=True,  # structured output instead of a bare id tensor
)
with torch.no_grad():
    outputs = model.generate(**inputs, **generation_kwargs)

generated_ids = outputs.sequences
translation = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(f"Translation: {translation}")
print(f"\nOutput tokens: {tokenizer.convert_ids_to_tokens(generated_ids[0])}")

## 3. Understand attention structure

In [None]:
# Report which attention families generate() populated: encoder attention is
# counted per layer, decoder and cross attention per generated timestep.
print("Output keys:", outputs.keys())
print("\nAttention types available:")

for field, label, unit in (
    ("encoder_attentions", "Encoder self-attention", "layers"),
    ("decoder_attentions", "Decoder self-attention", "timesteps"),
    ("cross_attentions", "Cross-attention", "timesteps"),
):
    value = getattr(outputs, field, None)
    if value is not None:
        print(f"  - {label}: {len(value)} {unit}")

In [None]:
# Encoder self-attention is a tuple holding one tensor per encoder layer.
if outputs.encoder_attentions is not None:
    encoder_attn = outputs.encoder_attentions
    first_layer_shape = encoder_attn[0].shape  # (batch, heads, seq_len, seq_len)

    print("Encoder attention:")
    print(f"  Number of layers: {len(encoder_attn)}")
    print(f"  Shape per layer: {first_layer_shape}")
    print("  Format: (batch_size, num_heads, seq_length, seq_length)")

    # Final layer with the batch axis indexed away -> (num_heads, seq_len, seq_len)
    last_layer_attn = encoder_attn[-1][0]
    num_heads = last_layer_attn.shape[0]
    print(f"\n  Last layer shape: {last_layer_attn.shape}")
    print(f"  Number of attention heads: {num_heads}")

In [None]:
# Examine decoder and cross-attention structure.
#
# generate() returns one entry per generated timestep, and each entry is a tuple
# with one tensor per decoder layer. Whether the per-step shape changes depends on
# how generation feeds the decoder (e.g. with a KV cache each step may process only
# the newest token), so the message below is derived from the observed shapes
# rather than assumed.
for title, steps, shape_format in (
    ("Decoder self-attention",
     outputs.decoder_attentions,
     "(batch_size, num_heads, query_length, key_length)"),
    ("Cross-attention (decoder attending to encoder)",
     outputs.cross_attentions,
     "(batch_size, num_heads, query_length, encoder_length)"),
):
    if not steps:  # None or an empty tuple: nothing to describe
        continue

    print(f"\n{title}:")
    print(f"  Number of timesteps: {len(steps)}")
    print(f"  Each timestep contains {len(steps[0])} layers")
    print(f"  Shape format: {shape_format}")

    # Shape of the first layer at every timestep (all layers share a shape per step).
    step_shapes = [step[0].shape for step in steps]
    if len(set(step_shapes)) > 1:
        print("  Shape varies per timestep:")
    else:
        print("  Shape is the same at every timestep:")

    # Show the first few timesteps to illustrate the pattern.
    for t, shape in enumerate(step_shapes[:3]):
        print(f"    Timestep {t}: {shape}")

## 4. Visualize encoder self-attention

In [None]:
# Last encoder layer, batch item 0, moved to NumPy -> (heads, seq_len, seq_len).
encoder_attn_last = outputs.encoder_attentions[-1][0].cpu().numpy()

# Collapse the head axis into a single (seq_len, seq_len) map.
encoder_attn_avg = np.mean(encoder_attn_last, axis=0)

# Source tokens label both axes (queries on rows, keys on columns).
input_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].cpu())

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    encoder_attn_avg,
    xticklabels=input_tokens,
    yticklabels=input_tokens,
    cmap='viridis',
    cbar_kws={'label': 'Attention Weight'},
    ax=ax,
)
ax.set_title('Encoder Self-Attention (Last Layer, Averaged over Heads)\nEnglish Sentence')
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
fig.tight_layout()
plt.show()

print(f"Attention matrix shape: {encoder_attn_avg.shape}")
for stat_name, reducer in (("Min", np.min), ("Max", np.max), ("Mean", np.mean)):
    print(f"{stat_name} attention: {reducer(encoder_attn_avg):.4f}")

In [None]:
# Filter out special tokens and renormalize attention.
#
# Special tokens (BOS, EOS, language codes, padding) are removed so the map only
# relates content tokens to each other; each remaining row is then rescaled to sum to 1.

# The hard-coded set covers the tokens used in this notebook. The tokenizer's
# additional special tokens are unioned in so that other language codes are also
# filtered if src_lang/tgt_lang are changed. NOTE(review): this assumes the NLLB
# tokenizer registers its language codes as additional special tokens; confirm for
# your transformers version.
special_tokens = {'</s>', '<s>', '<pad>', 'eng_Latn', 'fra_Latn', 'zho_Hans'}
special_tokens |= set(getattr(tokenizer, 'additional_special_tokens', None) or [])

# Keep-mask over the source tokens and the positions that survive it.
mask = [token not in special_tokens for token in input_tokens]
filtered_indices = [i for i, keep in enumerate(mask) if keep]
filtered_tokens = [input_tokens[i] for i in filtered_indices]
removed_tokens = [token for token, keep in zip(input_tokens, mask) if not keep]

print(f"Original tokens ({len(input_tokens)}): {input_tokens}")
print(f"Filtered tokens ({len(filtered_indices)}): {filtered_tokens}")
print(f"Removed tokens: {removed_tokens}")

if not filtered_indices:
    raise ValueError("All tokens were filtered out as special tokens; nothing to plot.")

# Filter attention matrix (remove special-token rows and columns together).
encoder_attn_filtered = encoder_attn_avg[np.ix_(filtered_indices, filtered_indices)]

# Renormalize so each row sums to 1. A row whose remaining mass is zero stays
# all-zero instead of turning into NaN from a 0/0 division.
row_sums = encoder_attn_filtered.sum(axis=1, keepdims=True)
encoder_attn_normalized = np.divide(
    encoder_attn_filtered,
    row_sums,
    out=np.zeros_like(encoder_attn_filtered),
    where=row_sums > 0,
)

# Plot filtered and normalized heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(encoder_attn_normalized,
            xticklabels=filtered_tokens,
            yticklabels=filtered_tokens,
            cmap='viridis',
            cbar_kws={'label': 'Attention Weight'})
plt.title('Encoder Self-Attention (Filtered & Renormalized)\nSpecial Tokens Removed')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.tight_layout()
plt.show()

print(f"\nFiltered attention matrix shape: {encoder_attn_normalized.shape}")
print(f"Min attention: {encoder_attn_normalized.min():.4f}")
print(f"Max attention: {encoder_attn_normalized.max():.4f}")
print(f"Mean attention: {encoder_attn_normalized.mean():.4f}")
print(f"Row sum check (should be 1.0): {encoder_attn_normalized.sum(axis=1)[:3]}")  # Check first 3 rows

## 5. Visualize individual attention heads

In [None]:
# Plot up to four individual heads of the last encoder layer on a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.flatten()

n_heads_to_plot = min(4, encoder_attn_last.shape[0])
for head_idx, ax in zip(range(n_heads_to_plot), axes):
    sns.heatmap(
        encoder_attn_last[head_idx],
        xticklabels=input_tokens,
        yticklabels=input_tokens,
        cmap='viridis',
        ax=ax,
        cbar_kws={'label': 'Weight'},
    )
    ax.set_title(f'Attention Head {head_idx}')
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')

plt.suptitle('Encoder Self-Attention Heads (Last Layer)', y=1.02, fontsize=14)
plt.tight_layout()
plt.show()

## 6. Extract attention for multiple examples

In [None]:
def extract_attention_maps(text, tokenizer, model, device, src_lang="eng_Latn", tgt_lang="fra_Latn"):
    """
    Extract encoder self-attention for a given text.

    Only the encoder is run. Encoder self-attention depends solely on the source
    sentence, so a full ``model.generate`` call (up to 100 autoregressive decoder
    steps in the previous implementation) is unnecessary work here.

    Args:
        text: Source sentence to encode.
        tokenizer: Tokenizer for the model. Its ``src_lang`` attribute is set to
            ``src_lang`` as a side effect.
        model: Seq2seq model exposing ``get_encoder()``. It must be loaded with an
            attention implementation that returns weights (e.g. ``"eager"``).
        device: Device the tokenized inputs are moved to (should match the model).
        src_lang: Source language code assigned to ``tokenizer.src_lang``.
        tgt_lang: Unused. Kept so existing callers that pass it keep working; the
            target language only affects decoding, not encoder attention.

    Returns:
        dict with:
            - 'tokens': list of source tokens (special tokens included)
            - 'encoder_attention': numpy array (layers, heads, seq_len, seq_len)
            - 'encoder_attention_avg': numpy array (seq_len, seq_len) - averaged over layers and heads

    Raises:
        RuntimeError: if the encoder returns no attention weights.
    """
    # Tokenize
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].cpu())

    # Encoder forward pass only, asking for per-layer attention weights.
    with torch.no_grad():
        encoder_outputs = model.get_encoder()(
            **inputs,
            output_attentions=True,
            return_dict=True,
        )

    if encoder_outputs.attentions is None:
        raise RuntimeError(
            "Encoder returned no attentions; load the model with attn_implementation='eager'."
        )

    # One (batch, heads, seq_len, seq_len) tensor per layer: take batch item 0 and
    # stack along a new layer axis. `.float()` keeps `.numpy()` working if the model
    # was loaded in a reduced-precision dtype.
    encoder_attn_all = torch.stack(
        [layer[0] for layer in encoder_outputs.attentions]
    ).float().cpu().numpy()
    # Shape: (layers, heads, seq_len, seq_len)

    # Average over layers and heads
    encoder_attn_avg = encoder_attn_all.mean(axis=(0, 1))  # (seq_len, seq_len)

    return {
        'tokens': tokens,
        'encoder_attention': encoder_attn_all,
        'encoder_attention_avg': encoder_attn_avg
    }

# Quick smoke test of the helper on the current `english` sentence.
test_result = extract_attention_maps(english, tokenizer, model, device)
print(f"Extracted attention for: {english}")
print(f"Tokens: {test_result['tokens']}")
for label, key in (
    ("Encoder attention shape", 'encoder_attention'),
    ("Averaged attention shape", 'encoder_attention_avg'),
):
    print(f"{label}: {test_result[key].shape}")

In [None]:
# Extract encoder attention for the first few validation pairs.
#
# The loop uses its own variable names (pair / src_en / ref_fr) so running this
# cell does not overwrite the `example`, `english` and `french` globals that
# earlier cells set and use (re-running them afterwards would otherwise silently
# operate on a different sentence).
num_examples = min(5, len(dataset))  # never request more pairs than the dataset holds
attention_data = []

print(f"Extracting attention for {num_examples} examples...\n")

for i in range(num_examples):
    pair = dataset[i]["translation"]
    src_en, ref_fr = pair["en"], pair["fr"]

    result = extract_attention_maps(src_en, tokenizer, model, device)

    attention_data.append({
        'index': i,
        'english': src_en,
        'french': ref_fr,
        'tokens': result['tokens'],
        'attention_avg': result['encoder_attention_avg'],
    })

    print(f"[{i+1}/{num_examples}] Extracted attention for: {src_en[:50]}...")

print(f"\n✓ Extracted attention for {len(attention_data)} examples")

In [None]:
# Side-by-side layer- and head-averaged encoder maps for up to three examples.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for example_no, (ax, data) in enumerate(zip(axes, attention_data[:3]), start=1):
    sns.heatmap(
        data['attention_avg'],
        xticklabels=data['tokens'],
        yticklabels=data['tokens'],
        cmap='viridis',
        ax=ax,
        cbar_kws={'label': 'Weight'},
    )
    ax.set_title(f"Example {example_no}\n{data['english'][:40]}...", fontsize=10)
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')

fig.suptitle('Encoder Self-Attention (Averaged)', fontsize=14)
fig.tight_layout()
plt.show()

## Summary

**Successfully extracted:**
- Encoder self-attention maps (English sentence structure)
- Attention weights across all layers and heads
- Averaged attention for graph construction

**Next steps:**
1. Build attention graphs (tokens as nodes, attention weights as edges)
2. Compute persistent homology on the graphs
3. Compare English vs French topological structures