# Extract All Encoder Attention Maps

Extract encoder self-attention maps for all 2000 sentence pairs in both directions.

**For Google Colab:**
1. Mount Google Drive (run the cell below)
2. In that cell, set `ROOT_DIR` to the path of your `code_fr_en` project folder (the folder that contains this notebook)
3. Enable GPU: Runtime → Change runtime type → GPU
4. Expected runtime: ~1-2 hours on a Colab GPU (vs. ~5 hours on CPU)

**For local execution:** Skip the Google Drive cell (outside Colab it only prints a notice) and run the notebook from the working-directory check / "Import Libraries" onward.

---

In [ ]:
# Mount Google Drive -- only relevant on Google Colab. On a local machine the
# google.colab import raises ImportError and the cell just prints a notice.
try:
    from google.colab import drive
    import os

    drive.mount('/content/drive')

    # IMPORTANT: ROOT_DIR must be the code_fr_en directory that contains THIS
    # notebook; later cells resolve "../models/..." and "../data/..." from it.
    ROOT_DIR = "/content/drive/MyDrive/UofT/CSC2517/term_paper/code_fr_en"
    os.chdir(ROOT_DIR)
    print("‚úì Changed to: " + os.getcwd())
except ImportError:
    print("Not running on Colab, using local environment")

In [None]:
# Sanity-check the working directory and the two inputs this notebook needs.
import os
from pathlib import Path

print(f"Current directory: {os.getcwd()}")

model_path = "../models/nllb-1.3B"              # NLLB checkpoint directory
data_path = "../data/sentence_pairs_fr_en.pkl"  # pickled sentence pairs

# Report each required input in order: model first, then data.
for label, location in (("Model directory", model_path), ("Data file", data_path)):
    if os.path.exists(location):
        print(f"‚úì {label} exists: {location}")
    else:
        print(f"‚úó {label} NOT found: {location}")

## Import Libraries

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
from pathlib import Path
import pickle
from tqdm import tqdm
import time

# Report the torch build and which accelerators it can see.
cuda_ok = torch.cuda.is_available()
mps_ok = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()

for line in (
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {cuda_ok}",
    f"MPS available: {mps_ok}",
):
    print(line)

if cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

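## Install Packages (Colab only)

**Note:** This cell runs *after* "Import Libraries", so packages that were already imported (torch, transformers, ...) keep their old versions until the runtime is restarted. On a fresh Colab runtime, run this cell first.

**Key modification (attention extraction function, defined below):** Only the **last encoder layer** (zero-based index 23, i.e. the final one of 24 layers) is extracted to save memory.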

In [None]:
# Install required packages if running on Colab.
# NOTE: this cell runs after "Import Libraries"; modules that are already
# imported (torch, transformers, pandas, ...) keep their old versions until the
# runtime is restarted, so on a fresh Colab runtime run this cell first.
try:
    import google.colab  # noqa: F401 -- imported only to detect Colab
    print("Installing packages for Colab...")
    import subprocess
    import sys

    # Invoke pip through the kernel's own interpreter (`python -m pip`) so the
    # packages are installed into the environment this notebook runs in,
    # rather than whichever `pip` happens to be first on PATH.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-q',
         'transformers', 'datasets', 'torch', 'pandas', 'numpy', 'tqdm'],
        check=True,
    )
    print("‚úì Packages installed")
except ImportError:
    print("Not on Colab, skipping package installation")

## Select Device and Load Model

In [ ]:
# Pick the backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device, backend_note = torch.device("cuda"), "Using CUDA (NVIDIA GPU)"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device, backend_note = torch.device("mps"), "Using MPS (Apple Silicon)"
else:
    device, backend_note = torch.device("cpu"), "Using CPU"
print(backend_note)

# Half precision only on CUDA; MPS and CPU load the weights in float32.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Eager attention is used because extracting attention weights
# (output_attentions=True) requires it.
model_path = "../models/nllb-1.3B"
print(f"Loading model from {model_path}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_path,
    torch_dtype=model_dtype,
    attn_implementation="eager",
)
model = model.to(device)
model.eval()
print("‚úì Model loaded successfully!")
print()

In [None]:
def extract_encoder_attention(text, src_lang, tgt_lang, tokenizer, model, device,
                              *, max_length=128, layer=-1):
    """
    Translate `text` and return one layer of its encoder self-attention
    (the LAST encoder layer by default).

    Args:
        text: Source text string
        src_lang: Source language code (e.g., 'eng_Latn', 'fra_Latn')
        tgt_lang: Target language code (e.g., 'fra_Latn', 'eng_Latn'); must be a
            token in the tokenizer's vocabulary
        tokenizer: NLLB tokenizer. NOTE: its `src_lang` attribute is overwritten
            with `src_lang` and is not restored afterwards.
        model: NLLB model (1.3B has 24 encoder layers), loaded with eager
            attention so that attention weights can be returned
        device: torch device the model lives on
        max_length: `max_length` passed to `model.generate` for the translation
            (keyword-only, default 128)
        layer: Index into the per-layer encoder attentions (keyword-only); the
            default -1 selects the last layer

    Returns:
        dict with keys:
            - tokens: Source tokens, one per input id (special tokens included)
            - encoder_attention: float32 numpy array with the selected layer's
              encoder self-attention, shape (num_heads, seq_len, seq_len)
            - translation: Generated translation text

    Raises:
        ValueError: If `tgt_lang` is not a known token of the tokenizer.
    """
    # Select the source language on the (shared) tokenizer before encoding.
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Token id of the target language, forced as the first generated token.
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    # An unknown language code is mapped to the unk id; forcing <unk> as the
    # first decoder token would silently produce a meaningless translation.
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if tgt_lang_id is None or (unk_id is not None and tgt_lang_id == unk_id):
        raise ValueError(f"Unknown target language code: {tgt_lang!r}")

    # Generate the translation; output_attentions=True makes generate also
    # return the encoder self-attention of every layer.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tgt_lang_id,
            output_attentions=True,
            return_dict_in_generate=True,
            max_length=max_length,
        )

    # outputs.encoder_attentions is a tuple with one tensor per encoder layer,
    # each (batch_size, num_heads, seq_len, seq_len). A single string was
    # tokenized, so drop the batch dimension -> (num_heads, seq_len, seq_len).
    layer_attention = outputs.encoder_attentions[layer].squeeze(0)

    input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0].tolist())
    translation = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    return {
        'tokens': input_tokens,
        # Cast to float32 so float16 (CUDA) and float32 runs store the same dtype.
        'encoder_attention': layer_attention.cpu().numpy().astype(np.float32),
        'translation': translation,
    }


def save_checkpoint(results, output_path, checkpoint_num):
    """
    Save a checkpoint of `results` next to `output_path` to avoid losing progress.

    The checkpoint is named ``<output stem>_checkpoint_<NNNNNN>.pkl``. The
    number is zero-padded to six digits so that sorting the file names as
    strings matches numeric order. Unpadded names would sort
    "_checkpoint_1000" before "_checkpoint_200".

    The pickle is first written to a ``.pkl.tmp`` file and then moved into
    place with ``Path.replace``. If the write is interrupted, a stray ``.tmp``
    file is left behind instead of a truncated ``.pkl`` that matches the
    ``*_checkpoint_*.pkl`` pattern.

    Args:
        results: Picklable object to store (the list of per-pair result dicts).
        output_path: pathlib.Path of the final output file; the checkpoint is
            written into its parent directory.
        checkpoint_num: Integer used in the checkpoint file name.
    """
    checkpoint_path = (
        output_path.parent / f"{output_path.stem}_checkpoint_{checkpoint_num:06d}.pkl"
    )
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
    with open(tmp_path, 'wb') as f:
        pickle.dump(results, f)
    tmp_path.replace(checkpoint_path)
    print(f"  üíæ Checkpoint saved: {checkpoint_path.name}")


print("‚úì Functions defined")

## Configuration

In [None]:
rule = "=" * 80
print(rule)
print("Extracting LAST LAYER Encoder Attention Maps for All 2000 Sentence Pairs")
print(rule)

# Configuration
CHECKPOINT_INTERVAL = 100  # write a checkpoint after every N dataset rows
OUTPUT_DIR = Path("../data/attention_maps_fr_en")
OUTPUT_FILE = OUTPUT_DIR / "all_encoder_attention_last_layer.pkl"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

for message in (
    f"Output directory: {OUTPUT_DIR}",
    f"Output file: {OUTPUT_FILE.name}",
    f"Checkpoint interval: {CHECKPOINT_INTERVAL} examples",
    "",
):
    print(message)

## Load Data

In [None]:
# Load the sentence pairs and normalise the column names to 'en' / 'fr'
# (the rename is a no-op for columns that are already named that way).
column_map = {'english': 'en', 'french': 'fr'}
data_path = Path("../data/sentence_pairs_fr_en.pkl")
print(f"Loading data from {data_path}...")
data = pd.read_pickle(data_path)
df = pd.DataFrame(data).rename(columns=column_map)
print(f"‚úì Loaded {len(df)} sentence pairs")
print()

## Check for Existing Checkpoint

In [ ]:
# Check for existing checkpoint to resume from
start_idx = 0
results = []


def _checkpoint_number(path):
    """Return N from a ``<stem>_checkpoint_<N>.pkl`` path, or -1 if N is not an integer."""
    suffix = path.stem.rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


# Sort by the numeric suffix. A plain string sort ranks "..._checkpoint_1000.pkl"
# below "..._checkpoint_200.pkl" and would resume from a stale checkpoint.
checkpoint_files = sorted(
    OUTPUT_DIR.glob(f"{OUTPUT_FILE.stem}_checkpoint_*.pkl"), key=_checkpoint_number
)

# Try the newest checkpoint first. If it cannot be unpickled (e.g. a write that
# was interrupted part-way through), fall back to the next older one.
for candidate in reversed(checkpoint_files):
    print(f"Found checkpoint: {candidate.name}")
    try:
        with open(candidate, 'rb') as f:
            loaded = pickle.load(f)
    except (EOFError, pickle.UnpicklingError) as e:
        print(f"  Could not load {candidate.name} ({e!r}); trying an older checkpoint")
        continue
    latest_checkpoint = candidate
    results = loaded
    # The extraction loop skips pairs that raise, so len(results) can be smaller
    # than the number of dataset rows already visited. Resume right after the
    # last stored pair so rows already in `results` are not processed twice.
    start_idx = results[-1]['idx'] + 1 if results else 0
    print(f"‚úì Resuming from example {start_idx}")
    print()
    break
else:
    if checkpoint_files:
        print("No readable checkpoint found. Starting from the beginning.")
    else:
        print("No checkpoint found. Starting from the beginning.")
    print()

## Extract Attention Maps

**Note:** This extracts only the **last encoder layer** (zero-based index 23, i.e. the final one of 24 layers) to save memory.

**Runtime:** ~1-2 hours on GPU

In [ ]:
print(f"Extracting attention maps for {len(df)} sentence pairs...")
print(f"Progress will be saved every {CHECKPOINT_INTERVAL} examples")
print()

# If this cell is re-run in the same session after an interruption (e.g. a
# KeyboardInterrupt), `results` already holds pairs beyond `start_idx`;
# continue after the last stored pair instead of appending duplicates.
if results:
    start_idx = max(start_idx, results[-1]['idx'] + 1)

failed_indices = []  # dataset rows whose extraction raised during this run
start_time = time.time()

for idx in tqdm(range(start_idx, len(df)), desc="Processing", unit="pair"):
    row = df.iloc[idx]
    en_text = row['en']
    fr_text = row['fr']

    try:
        # Extract English encoder attention (EN -> FR)
        en_result = extract_encoder_attention(
            text=en_text,
            src_lang='eng_Latn',
            tgt_lang='fra_Latn',
            tokenizer=tokenizer,
            model=model,
            device=device
        )

        # Extract French encoder attention (FR -> EN)
        fr_result = extract_encoder_attention(
            text=fr_text,
            src_lang='fra_Latn',
            tgt_lang='eng_Latn',
            tokenizer=tokenizer,
            model=model,
            device=device
        )
    except Exception as e:
        # Best effort: log the failing pair and keep going. str() keeps the
        # preview from raising on non-string cells (e.g. NaN), which would
        # otherwise abort the whole run from inside this handler.
        failed_indices.append(idx)
        print(f"\n‚ö†Ô∏è  Error processing pair {idx}: {e}")
        print(f"   EN: {str(en_text)[:60]}...")
        print(f"   FR: {str(fr_text)[:60]}...")
    else:
        results.append({
            'idx': idx,
            'en_text': en_text,
            'fr_text': fr_text,
            'en_tokens': en_result['tokens'],
            'fr_tokens': fr_result['tokens'],
            'en_attention': en_result['encoder_attention'],  # (num_heads, seq_len, seq_len)
            'fr_attention': fr_result['encoder_attention'],  # (num_heads, seq_len, seq_len)
            'en_translation': en_result['translation'],
            'fr_translation': fr_result['translation']
        })

    # Checkpoint outside the try so a failed pair on an interval boundary does
    # not skip the checkpoint. A failing save is not swallowed as a pair error.
    if (idx + 1) % CHECKPOINT_INTERVAL == 0:
        save_checkpoint(results, OUTPUT_FILE, idx + 1)

elapsed_time = time.time() - start_time
# Rate is based on rows attempted in THIS run; `results` may also contain
# pairs loaded from a checkpoint, and may be empty.
n_attempted = max(len(df) - start_idx, 0)
rate = f"{elapsed_time / n_attempted:.2f} sec/pair" if n_attempted else "no new pairs"

print()
print("="*80)
print(f"‚úì Extraction complete! Processed {len(results)} sentence pairs")
if failed_indices:
    shown = failed_indices[:20]
    more = " ..." if len(failed_indices) > len(shown) else ""
    print(f"Warning: {len(failed_indices)} pair(s) failed: {shown}{more}")
print(f"‚è±Ô∏è  Total time: {elapsed_time / 60:.1f} minutes ({rate})")
print()
print()

## Save Final Results

In [ ]:
# Persist the complete result list as a single pickle.
print(f"Saving final results to {OUTPUT_FILE}...")
with OUTPUT_FILE.open('wb') as out_fh:
    pickle.dump(results, out_fh)
print(f"‚úì Saved to {OUTPUT_FILE}")

# Print summary statistics
rule = "=" * 80
bytes_per_mb = 1024 * 1024
file_size_mb = OUTPUT_FILE.stat().st_size / bytes_per_mb
print()
print(rule)
print("Summary Statistics")
print(rule)
print(f"Total sentence pairs: {len(results)}")
print(f"Output file size: {file_size_mb:.1f} MB")
print("Average attention matrix shape (LAST LAYER ONLY):")
if results:
    first = results[0]
    print(f"  English: {first['en_attention'].shape} (num_heads, seq_len, seq_len)")
    print(f"  French:  {first['fr_attention'].shape} (num_heads, seq_len, seq_len)")
print()

# The final file now holds everything, so the intermediate checkpoints can go.
print("Cleaning up checkpoint files...")
for stale in OUTPUT_DIR.glob(f"{OUTPUT_FILE.stem}_checkpoint_*.pkl"):
    stale.unlink()
    print(f"  üóëÔ∏è  Removed {stale.name}")

print()
print(rule)
print("‚úÖ All done!")
print(rule)

## Summary

This notebook extracts encoder self-attention maps for all 2000 sentence pairs in both directions.

**Key changes from previous version:**
- ✅ **Only extracts the last encoder layer (zero-based index 23 of 24)** to save memory (~24x less storage than keeping all layers)
- ✅ Supports both Colab and local environments
- ✅ Upgraded to NLLB-1.3B model (24 encoder layers, 16 attention heads per layer)
- ✅ GPU acceleration (CUDA/MPS) with float16 precision on CUDA
- ✅ Checkpoint system to resume from interruptions

**Output format:**
- File: `all_encoder_attention_last_layer.pkl`
- Each entry contains:
  - `en_attention`: (16, seq_len, seq_len) - 16 attention heads from last layer
  - `fr_attention`: (16, seq_len, seq_len) - 16 attention heads from last layer
  
**Next steps:**
- Use this data for TDA analysis (persistent homology)
- Compare topological structure across languages

In [None]:
# Uncomment to download the results file
# from google.colab import files
# files.download(str(OUTPUT_FILE))