# Bidirectional Encoder Attention Extraction

This notebook demonstrates extracting encoder self-attention maps from **both** English and Chinese sentences by running the NLLB model in both translation directions:

1. **EN → ZH**: English source → Extract English encoder attention
2. **ZH → EN**: Chinese source → Extract Chinese encoder attention

This bidirectional extraction allows us to compare the topological structure of encoder attention patterns across languages.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import seaborn as sns
from pathlib import Path

# Configure matplotlib for Chinese font support (families are tried in the listed order)
matplotlib.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False

# Best-effort rebuild of matplotlib's font cache.
# `_rebuild` is a private font_manager helper, so it is looked up rather than assumed to exist.
# Only ordinary exceptions are caught: a bare `except:` would also swallow KeyboardInterrupt.
_rebuild_font_cache = getattr(fm, "_rebuild", None)
if callable(_rebuild_font_cache):
    try:
        _rebuild_font_cache()
    except Exception as exc:
        # Cosmetic step: report the failure instead of hiding it, but keep the notebook running
        print(f"Font cache rebuild skipped: {exc}")

# Report the runtime environment so saved outputs record where the notebook was executed
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"MPS available: {torch.backends.mps.is_available() if hasattr(torch.backends, 'mps') else False}")

## 1. Load Model and Data

In [None]:
# Choose the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device, _device_note = torch.device("cuda"), "Using CUDA (NVIDIA GPU)"
elif _mps_backend is not None and _mps_backend.is_available():
    device, _device_note = torch.device("mps"), "Using MPS (Apple Silicon)"
else:
    device, _device_note = torch.device("cpu"), "Using CPU"
print(_device_note)

# Load the tokenizer and the seq2seq model from the local checkpoint directory.
# Eager attention is requested because attention weights are extracted later on.
model_path = "../models/nllb-600M"
print(f"Loading model from {model_path}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, attn_implementation="eager")
model = model.to(device)  # move the weights to the selected device
model.eval()  # switch to evaluation mode
print("Model loaded successfully!")

In [None]:
# Load sample data.
# NOTE(security): pd.read_pickle unpickles the file, and unpickling can execute arbitrary code.
# Only load this file if it comes from a trusted source.
data_path = Path("../data/sentence_pairs_zh_en.pkl")
data = pd.read_pickle(data_path)

# Build a DataFrame from the loaded records and shorten the language column names
df = pd.DataFrame(data).rename(columns={'english': 'en', 'chinese': 'zh'})

# Fail here, next to the load, rather than with a KeyError in a later cell
missing_columns = {'en', 'zh'} - set(df.columns)
if missing_columns:
    raise KeyError(f"{data_path} is missing expected column(s): {sorted(missing_columns)}")

print(f"Loaded {len(df)} sentence pairs")
print(f"\nColumns: {df.columns.tolist()}")
df.head(3)

## 2. Extract Encoder Attention in Both Directions

For each sentence pair, we'll:
1. Run **EN → ZH** translation and extract **English encoder attention**
2. Run **ZH → EN** translation and extract **Chinese encoder attention**

In [None]:
def extract_encoder_attention(text, src_lang, tgt_lang, tokenizer, model, device):
    """
    Extract encoder self-attention for a given source text.

    The text is tokenized with `src_lang` as source language, translated with a single
    `model.generate` call whose first generated token is forced to `tgt_lang`, and the
    encoder attentions returned by that call are stacked into one array.

    Note: this sets `tokenizer.src_lang`, so the tokenizer passed in is modified.

    Args:
        text: Source text string (a single sentence; only the first batch entry is kept)
        src_lang: Source language code (e.g., 'eng_Latn', 'zho_Hans')
        tgt_lang: Target language code (e.g., 'zho_Hans', 'eng_Latn')
        tokenizer: NLLB tokenizer
        model: NLLB model
        device: torch device

    Returns:
        dict with keys:
            - tokens: List of source tokens
            - encoder_attention: Encoder self-attention (num_layers, num_heads, seq_len, seq_len)
            - translation: Generated translation text

    Raises:
        ValueError: If `src_lang` or `tgt_lang` is not a token known to the tokenizer.
        RuntimeError: If the model did not return encoder attentions.
    """
    # Hugging Face tokenizers map unknown tokens to `unk_token_id`. Without this check a
    # misspelled language code would silently become the forced BOS token.
    unk_id = getattr(tokenizer, "unk_token_id", None)
    src_lang_id = tokenizer.convert_tokens_to_ids(src_lang)
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    for role, code, lang_id in (("src_lang", src_lang, src_lang_id), ("tgt_lang", tgt_lang, tgt_lang_id)):
        if lang_id is None or (unk_id is not None and lang_id == unk_id):
            raise ValueError(f"Unknown language code for {role}: {code!r}")

    # Set source language
    tokenizer.src_lang = src_lang

    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Generate translation with attention output
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tgt_lang_id,
            output_attentions=True,
            return_dict_in_generate=True,
            max_length=128
        )

    # encoder_attentions holds one tensor per layer, each (batch_size, num_heads, seq_len, seq_len)
    layer_attentions = getattr(outputs, "encoder_attentions", None)
    if layer_attentions is None:
        raise RuntimeError(
            "model.generate returned no encoder attentions; "
            "load the model with attn_implementation='eager'"
        )

    # Index the single batch entry explicitly so the result is always 4-D
    encoder_attention = torch.stack([layer[0] for layer in layer_attentions])  # (num_layers, num_heads, seq_len, seq_len)

    # Decode tokens
    input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    translation = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    return {
        'tokens': input_tokens,
        'encoder_attention': encoder_attention.cpu().numpy(),
        'translation': translation
    }

print("Function defined: extract_encoder_attention()")

## 3. Test on Sample Sentence Pair

In [None]:
# Pick one sentence pair to walk through the pipeline by hand
idx = 10
sample_row = df.iloc[idx]
en_text = sample_row['en']
zh_text = sample_row['zh']

# One print call; the trailing empty item reproduces the blank separator line
print(
    f"Sample {idx}:",
    f"English: {en_text}",
    f"Chinese: {zh_text}",
    "",
    sep="\n",
)

In [None]:
# EN → ZH: encode the English sentence and keep its encoder self-attention
print("Extracting English encoder attention (EN → ZH)...")
en_result = extract_encoder_attention(en_text, 'eng_Latn', 'zho_Hans', tokenizer, model, device)

# Report tokens, attention shape and the generated translation, followed by a blank line
print(
    f"English tokens: {en_result['tokens']}",
    f"English encoder attention shape: {en_result['encoder_attention'].shape}",
    f"Translation to Chinese: {en_result['translation']}",
    "",
    sep="\n",
)

In [None]:
# ZH → EN: encode the Chinese sentence and keep its encoder self-attention
print("Extracting Chinese encoder attention (ZH → EN)...")
zh_result = extract_encoder_attention(zh_text, 'zho_Hans', 'eng_Latn', tokenizer, model, device)

# Report tokens, attention shape and the generated translation, followed by a blank line
print(
    f"Chinese tokens: {zh_result['tokens']}",
    f"Chinese encoder attention shape: {zh_result['encoder_attention'].shape}",
    f"Translation to English: {zh_result['translation']}",
    "",
    sep="\n",
)

## 4. Visualize Encoder Attention Maps

Compare encoder attention patterns from English and Chinese for the same sentence pair.

In [None]:
def plot_encoder_attention(attention, tokens, layer=0, head=0, title="Encoder Self-Attention", filter_special=True, special_tokens=None):
    """
    Plot encoder self-attention heatmap.

    Args:
        attention: Attention weights (num_layers, num_heads, seq_len, seq_len)
        tokens: List of token strings, one per position of the attention map
        layer: Which layer to visualize
        head: Which attention head to visualize
        title: Plot title
        filter_special: Whether to filter out special tokens
        special_tokens: Optional collection of token strings to drop when
            `filter_special` is True. Defaults to
            {'</s>', '<s>', '<pad>', 'eng_Latn', 'zho_Hans'}.

    Raises:
        ValueError: If the selected attention map is not (len(tokens), len(tokens)).
    """
    # Extract specified layer and head as a float array (seq_len, seq_len)
    attn = np.asarray(attention[layer, head], dtype=float)
    tokens = list(tokens)

    # Catch mismatched inputs here rather than drawing a mislabeled heatmap
    if attn.shape != (len(tokens), len(tokens)):
        raise ValueError(
            f"attention map has shape {attn.shape} but {len(tokens)} tokens were given"
        )

    # Filter special tokens if requested
    if filter_special:
        if special_tokens is None:
            # Default: sequence markers plus the two language tags used in this notebook
            special = {'</s>', '<s>', '<pad>', 'eng_Latn', 'zho_Hans'}
        else:
            special = set(special_tokens)
        content_mask = np.array([tok not in special for tok in tokens], dtype=bool)

        if content_mask.any():  # Only filter if there are content tokens
            attn = attn[np.ix_(content_mask, content_mask)]
            tokens = [tok for tok, keep in zip(tokens, content_mask) if keep]

            # Renormalize each row over the remaining tokens. A row whose remaining weights
            # sum to zero is left as zeros instead of becoming NaN through 0/0.
            row_sums = attn.sum(axis=-1, keepdims=True)
            attn = np.divide(attn, row_sums, out=np.zeros_like(attn), where=row_sums > 0)

    # Plot on an explicit figure/axes pair
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        attn,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='Blues',
        cbar_kws={'label': 'Attention Weight'},
        square=True,
        ax=ax
    )
    ax.set_xlabel('Key Tokens')
    ax.set_ylabel('Query Tokens')
    ax.set_title(f"{title}\nLayer {layer}, Head {head}")
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    plt.show()

print("Function defined: plot_encoder_attention()")

In [None]:
# English encoder attention: first layer, first head
plot_encoder_attention(
    en_result['encoder_attention'],
    en_result['tokens'],
    layer=0,
    head=0,
    title=f"English Encoder Attention\n'{en_text}'",
    filter_special=True,
)

In [None]:
# English encoder attention halfway up the encoder stack (first axis is the layer axis)
middle_layer = len(en_result['encoder_attention']) // 2
plot_encoder_attention(
    en_result['encoder_attention'],
    en_result['tokens'],
    layer=middle_layer,
    head=0,
    title=f"English Encoder Attention (Middle Layer)\n'{en_text}'",
    filter_special=True,
)

In [None]:
# English encoder attention at the top of the encoder stack
last_layer = len(en_result['encoder_attention']) - 1
plot_encoder_attention(
    en_result['encoder_attention'],
    en_result['tokens'],
    layer=last_layer,
    head=0,
    title=f"English Encoder Attention (Last Layer)\n'{en_text}'",
    filter_special=True,
)

In [None]:
# Chinese encoder attention: first layer, first head
plot_encoder_attention(
    zh_result['encoder_attention'],
    zh_result['tokens'],
    layer=0,
    head=0,
    title=f"Chinese Encoder Attention\n'{zh_text}'",
    filter_special=True,
)

In [None]:
# Chinese encoder attention halfway up the encoder stack (first axis is the layer axis)
middle_layer = len(zh_result['encoder_attention']) // 2
plot_encoder_attention(
    zh_result['encoder_attention'],
    zh_result['tokens'],
    layer=middle_layer,
    head=0,
    title=f"Chinese Encoder Attention (Middle Layer)\n'{zh_text}'",
    filter_special=True,
)

In [None]:
# Chinese encoder attention at the top of the encoder stack
last_layer = len(zh_result['encoder_attention']) - 1
plot_encoder_attention(
    zh_result['encoder_attention'],
    zh_result['tokens'],
    layer=last_layer,
    head=0,
    title=f"Chinese Encoder Attention (Last Layer)\n'{zh_text}'",
    filter_special=True,
)

## 5. Compare Attention Statistics

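Compute basic statistics of the attention averaged over all layers and heads to compare English and Chinese encoder attention patterns. Special tokens are included here. Because every attention row sums to 1, the mean is always 1/seq_len, so the standard deviation and maximum are the informative values.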

In [None]:
# Collapse the layer and head axes so each sentence has a single (seq_len, seq_len) map
en_avg_attention = en_result['encoder_attention'].mean(axis=(0, 1))
zh_avg_attention = zh_result['encoder_attention'].mean(axis=(0, 1))


def _print_attention_stats(header, avg_attention):
    """Print the header line followed by shape, mean, std, min and max of one averaged map."""
    print(header)
    print(f"  Shape: {avg_attention.shape}")
    print(f"  Mean:  {avg_attention.mean():.4f}")
    print(f"  Std:   {avg_attention.std():.4f}")
    print(f"  Min:   {avg_attention.min():.4f}")
    print(f"  Max:   {avg_attention.max():.4f}")


_print_attention_stats("English encoder attention statistics:", en_avg_attention)
print()  # blank line between the two reports
_print_attention_stats("Chinese encoder attention statistics:", zh_avg_attention)

## 6. Process Multiple Samples

Extract encoder attention for a few more sentence pairs to verify the pipeline.

In [None]:
# Process the first few sentence pairs
num_samples = 5
results = []

# Loop-local names (pair_*) are used on purpose. Reusing idx / en_text / zh_text / en_result /
# zh_result would overwrite the sample selected earlier in the notebook, so re-running a
# plotting cell after this one would silently show a different sentence.
# min() keeps the loop inside the dataset when it holds fewer than num_samples rows.
for pair_idx in range(min(num_samples, len(df))):
    pair_row = df.iloc[pair_idx]
    pair_en_text = pair_row['en']
    pair_zh_text = pair_row['zh']

    print(f"\nProcessing pair {pair_idx}...")
    print(f"  EN: {pair_en_text[:60]}...")
    print(f"  ZH: {pair_zh_text[:60]}...")

    # Extract encoder attention for both directions
    pair_en_result = extract_encoder_attention(pair_en_text, 'eng_Latn', 'zho_Hans', tokenizer, model, device)
    pair_zh_result = extract_encoder_attention(pair_zh_text, 'zho_Hans', 'eng_Latn', tokenizer, model, device)

    results.append({
        'idx': pair_idx,
        'en_text': pair_en_text,
        'zh_text': pair_zh_text,
        'en_tokens': pair_en_result['tokens'],
        'zh_tokens': pair_zh_result['tokens'],
        'en_attention': pair_en_result['encoder_attention'],
        'zh_attention': pair_zh_result['encoder_attention'],
        'en_translation': pair_en_result['translation'],
        'zh_translation': pair_zh_result['translation']
    })

    print(f"  EN attention shape: {pair_en_result['encoder_attention'].shape}")
    print(f"  ZH attention shape: {pair_zh_result['encoder_attention'].shape}")

print(f"\n✓ Processed {len(results)} sentence pairs")

In [None]:
# Recap: token counts and attention array shapes for every processed pair
print(
    "Summary of extracted encoder attention:",
    f"Total pairs processed: {len(results)}",
    "",
    sep="\n",
)

for result in results:
    # One print per pair; the trailing empty item yields the blank separator line
    print(
        f"Pair {result['idx']}:",
        f"  English: {len(result['en_tokens'])} tokens, attention shape {result['en_attention'].shape}",
        f"  Chinese: {len(result['zh_tokens'])} tokens, attention shape {result['zh_attention'].shape}",
        "",
        sep="\n",
    )

## Summary

This notebook demonstrates bidirectional encoder attention extraction:

1. ✅ **EN → ZH**: Extract English encoder attention
2. ✅ **ZH → EN**: Extract Chinese encoder attention
3. ✅ Visualize and compare attention patterns
4. ✅ Process multiple sentence pairs

**Next Steps:**
- Scale to all 2000 sentence pairs
- Build attention graphs (tokens as nodes, attention weights as edges)
- Compute persistent homology of the attention graphs and track their Betti numbers (β₀, β₁)
- Compare topological structure across languages