# Attention Convergence Test Across Models (Colab Version)

Test whether different model families and sizes attend to similar "trigger tokens" in harmful prompts.

**Efficient approach**: Single forward pass, extract last-token-to-all attention (no generation needed).

**Models tested**: Llama, Qwen, and Gemma families at multiple sizes (Phi entries are included in the configuration but commented out by default).

In [None]:
# Colab Setup - Run this first!
# This cell relies on Colab-only imports and the IPython `!` shell escape,
# so it is meant to be run inside a Google Colab notebook.
# Mount Google Drive for saving results

from google.colab import drive
drive.mount('/content/drive')

# Install dependencies
!pip install -q transformers accelerate bitsandbytes datasets huggingface_hub

# Login to HuggingFace (needed for Llama, Gemma)
from huggingface_hub import login
login()  # Enter your HF token when prompted

# All outputs (per-model JSON backups, plots, final JSON) are written under
# this Drive folder; it is created if it does not exist yet.
import os
RESULTS_DIR = "/content/drive/MyDrive/attention_convergence_results"
os.makedirs(RESULTS_DIR, exist_ok=True)
print(f"Results will be saved to: {RESULTS_DIR}")

In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
import json
import gc

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Disable autograd globally: this notebook only runs forward passes.
torch.set_grad_enabled(False)

# Silence every Python warning for the rest of the session.
import warnings
warnings.filterwarnings('ignore')

# Report the runtime environment (library version, GPU name and total memory).
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## Configuration

In [None]:
# Models to test - organized by family and size.
# Keys follow the "<family>-<size>" convention; later cells derive the family
# from the text before the first "-" and the size from the "<number>b" suffix.
# Commented-out entries are disabled; uncomment them to include those models.
MODELS_TO_TEST = {
    # Llama family (Meta) - 3 sizes active (70B entry disabled)
    "llama-1b": "meta-llama/Llama-3.2-1B-Instruct",
    "llama-3b": "meta-llama/Llama-3.2-3B-Instruct",
    "llama-8b": "meta-llama/Llama-3.1-8B-Instruct",
    # "llama-70b": "meta-llama/Llama-3.1-70B-Instruct",
    
    # Qwen family (Alibaba) - 3 sizes active (0.5B and 72B entries disabled)
    # "qwen-0.5b": "Qwen/Qwen2.5-0.5B-Instruct",
    "qwen-3b": "Qwen/Qwen2.5-3B-Instruct",
    "qwen-7b": "Qwen/Qwen2.5-7B-Instruct",
    "qwen-14b": "Qwen/Qwen2.5-14B-Instruct",
    # "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
    
    # Gemma family (Google) - 3 sizes
    "gemma-2b": "google/gemma-2-2b-it",
    "gemma-9b": "google/gemma-2-9b-it",
    "gemma-27b": "google/gemma-2-27b-it",
    
    # Phi family (Microsoft) - 3 sizes, all currently disabled
    # "phi-mini": "microsoft/Phi-3-mini-4k-instruct",      # 3.8B
    # "phi-small": "microsoft/Phi-3-small-8k-instruct",   # 7B
    # "phi-medium": "microsoft/Phi-3-medium-4k-instruct", # 14B
}

# Experiment parameters
N_EXAMPLES = 50  # Number of harmful prompts to test
TOP_K_TOKENS = 10  # Top attended tokens to track
LAYER_PERCENT = 50  # Use middle layer (50% through)

# Quantization settings (for memory efficiency)
QUANTIZE_THRESHOLD_B = 8  # Quantize models >= 8B parameters

# Echo the active configuration; a family with no enabled entries prints an
# empty list.
print(f"Testing {len(MODELS_TO_TEST)} models on {N_EXAMPLES} examples")
print(f"\nModels by family:")
for family in ["llama", "qwen", "gemma", "phi"]:
    models = [k for k in MODELS_TO_TEST.keys() if k.startswith(family)]
    print(f"  {family.capitalize()}: {models}")

## Load JailbreakBench Data

In [None]:
from datasets import load_dataset

# Load the "behaviors" configuration of JailbreakBench and keep its
# "harmful" split as a plain list of records.
print("Loading JailbreakBench behaviors...")
jbb = load_dataset("JailbreakBench/JBB-Behaviors", "behaviors")
all_harmful = list(jbb["harmful"])

# Take the first N_EXAMPLES records (a deterministic prefix, not a random sample)
behaviors = all_harmful[:N_EXAMPLES]

# Preview a few records; the prompt text is read from the "Goal" field.
print(f"Loaded {len(behaviors)} behaviors (from {len(all_harmful)} total)")
print(f"\nExample behaviors:")
for i, b in enumerate(behaviors[:3]):
    print(f"  [{i}] {b['Goal'][:80]}...")
    print(f"      Category: {b.get('Category', 'N/A')}")

## Attention Capture Utilities

In [None]:
@dataclass
class AttentionResult:
    """Stores attention capture results for a single prompt."""
    attention_weights: torch.Tensor  # [num_heads, prompt_len] - last token attending to all
    tokens: list[str]  # Decoded prompt tokens
    token_ids: list[int]  # Raw token IDs
    
    def get_attention_scores(self, head: int | None = None) -> torch.Tensor:
        """Get attention scores for each prompt token.
        
        Args:
            head: Specific head index, or None for average across all heads.
        """
        attn = self.attention_weights.float()  # [heads, prompt_len]
        
        if head is not None:
            return attn[head]  # [prompt_len]
        else:
            return attn.mean(dim=0)  # Average over heads -> [prompt_len]
    
    def get_top_attended_tokens(self, k: int = 10) -> list[tuple[int, str, float]]:
        """Get top-k attended prompt tokens.
        
        Returns:
            List of (position, token_string, attention_score) tuples.
        """
        scores = self.get_attention_scores()
        top_indices = torch.argsort(scores, descending=True)[:k]
        
        results = []
        for idx in top_indices:
            idx = idx.item()
            if idx < len(self.tokens):
                results.append((idx, self.tokens[idx], scores[idx].item()))
        return results

In [None]:
def estimate_model_size(model_name: str) -> float:
    """Guess the parameter count (in billions) from a model identifier.

    The first "<number>b" pattern in the lowercased name wins (e.g. "70B",
    "7b", "0.5B"). Names without such a pattern fall back to the
    mini/small/medium keywords, and finally to 7.0.
    """
    import re

    lowered = model_name.lower()
    found = re.search(r'(\d+\.?\d*)b', lowered)
    if found is not None:
        return float(found.group(1))

    # Keyword fallbacks, checked in this order
    for keyword, size_b in (("mini", 3.8), ("small", 7.0), ("medium", 14.0)):
        if keyword in lowered:
            return size_b

    return 7.0  # Default assumption


def load_model(model_name: str, quantize_threshold: float = 8.0):
    """Load a causal LM and its tokenizer, using 4-bit quantization for large models.

    Args:
        model_name: HuggingFace model identifier.
        quantize_threshold: Models whose estimated size (in B params) is at
            least this value are loaded with 4-bit NF4 quantization.

    Returns:
        A (model, tokenizer) pair; the model is in eval mode.
    """
    estimated_size = estimate_model_size(model_name)
    should_quantize = estimated_size >= quantize_threshold
    print(f"  Estimated size: {estimated_size}B, quantize: {should_quantize}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Fall back to EOS as the padding token when the tokenizer defines none
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Only pass a quantization config when the model is large enough
    quantization_kwargs = {}
    if should_quantize:
        quantization_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",  # Required for output_attentions
        device_map="auto",
        **quantization_kwargs,
    )
    model.eval()

    return model, tokenizer


def get_num_layers(model) -> int:
    """Return the layer count from the model config, or 32 if it is not exposed.

    Looks for `num_hidden_layers` first, then `n_layer`, on `model.config`.
    """
    if hasattr(model, "config"):
        config = model.config
        for attr_name in ("num_hidden_layers", "n_layer"):
            if hasattr(config, attr_name):
                return getattr(config, attr_name)
    return 32  # fallback


def cleanup_model(model, tokenizer):
    """Run garbage collection and release cached GPU memory after a model is done.

    Note that `del` here only removes this function's own references. The
    caller must drop its references to the model and tokenizer as well
    (e.g. rebind them to None) before the underlying memory can be reclaimed.

    The CUDA calls are skipped when no GPU is available, so this is safe to
    call on a CPU-only runtime (for example from a `finally` block).

    Args:
        model: The model object to release (may be None).
        tokenizer: The tokenizer object to release (may be None).
    """
    del model
    del tokenizer
    gc.collect()
    # torch.cuda.synchronize() raises without a CUDA device, so guard the GPU calls
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()

In [None]:
def capture_attention(
    model,
    tokenizer,
    prompt: str,
    layer_idx: int,
) -> AttentionResult:
    """Capture last-token-to-all attention with a single forward pass.

    This is much faster than generating a response - we only need to
    process the prompt once to see what the model attends to.

    The prompt is wrapped in the tokenizer's chat template (single user turn,
    with the generation prompt appended) before the forward pass.

    Args:
        model: The language model.
        tokenizer: The tokenizer.
        prompt: The user prompt text.
        layer_idx: Which layer to extract attention from.

    Returns:
        AttentionResult with attention weights of shape [heads, seq_len-1]
        (the final position attending to every earlier position) and the
        matching decoded tokens / token ids.
    """
    # Format prompt with chat template
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize. The chat template already emits the special tokens it needs
    # (e.g. BOS), so do not let the tokenizer add them a second time -
    # a duplicated BOS would distort the attention distribution.
    inputs = tokenizer(
        formatted,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)
    input_ids = inputs["input_ids"]

    # Single forward pass with attention output
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"],
            output_attentions=True,
            return_dict=True,
        )

    # Extract attention from target layer
    # Shape: [batch, heads, seq_len, seq_len]
    layer_attention = outputs.attentions[layer_idx][0]  # Remove batch dim

    # Get last token attending to all previous tokens
    # Shape: [heads, seq_len-1] (excluding self-attention to last token)
    last_token_attention = layer_attention[:, -1, :-1]

    # Decode tokens (excluding the last token which is usually generation prompt)
    token_ids = input_ids[0, :-1].tolist()
    tokens = [tokenizer.decode([tid]) for tid in token_ids]

    return AttentionResult(
        attention_weights=last_token_attention.cpu(),
        tokens=tokens,
        token_ids=token_ids,
    )

## Run Attention Analysis

In [None]:
# Run every configured model over every behavior and collect per-prompt
# attention summaries. `all_results` maps model key -> list of records; a
# record has either "top_tokens"/"attention_scores" or an "error" field.
all_results = {}

for model_key, model_name in MODELS_TO_TEST.items():
    print(f"\n{'='*60}")
    print(f"Loading: {model_key} ({model_name})")
    print(f"{'='*60}")

    model = None
    tokenizer = None

    try:
        model, tokenizer = load_model(model_name, quantize_threshold=QUANTIZE_THRESHOLD_B)

        num_layers = get_num_layers(model)
        # Clamp so LAYER_PERCENT=100 selects the last layer instead of
        # indexing one past the end of the attention tuple.
        target_layer = min(int(num_layers * LAYER_PERCENT / 100), num_layers - 1)
        print(f"  Layers: {num_layers}, using layer {target_layer}")

        model_results = []

        for idx, behavior in enumerate(tqdm(behaviors, desc=f"{model_key}")):
            prompt = behavior["Goal"]

            try:
                result = capture_attention(
                    model, tokenizer, prompt,
                    layer_idx=target_layer,
                )

                top_tokens = result.get_top_attended_tokens(k=TOP_K_TOKENS)
                scores = result.get_attention_scores()

                model_results.append({
                    "behavior_id": idx,
                    "category": behavior.get("Category", "unknown"),
                    "prompt": prompt,
                    "tokens": result.tokens,
                    "attention_scores": scores.numpy().tolist(),
                    "top_tokens": [(i, tok, float(score)) for i, tok, score in top_tokens],
                })

            except Exception as e:
                # Keep a placeholder record so list position still matches behavior index
                print(f"\n  Error on behavior {idx}: {e}")
                model_results.append({
                    "behavior_id": idx,
                    "prompt": prompt,
                    "error": str(e),
                })

        all_results[model_key] = model_results
        print(f"  Completed {len(model_results)} examples")

    except Exception as e:
        print(f"  Failed to load {model_key}: {e}")
        all_results[model_key] = [{"error": str(e)}]

    else:
        # Save intermediate results after each model (in case of crash).
        # A failed save must not discard the results already computed above,
        # so it is handled separately from model-loading failures.
        intermediate_file = os.path.join(RESULTS_DIR, f"intermediate_{model_key}.json")
        try:
            with open(intermediate_file, "w") as f:
                json.dump({"model": model_key, "results": model_results}, f)
            print(f"  Saved intermediate: {intermediate_file}")
        except (OSError, TypeError, ValueError) as e:
            print(f"  Warning: could not save intermediate results for {model_key}: {e}")

    finally:
        if model is not None or tokenizer is not None:
            # Drop this scope's references first: while they are alive the
            # weights cannot be freed, no matter what cleanup_model deletes
            # locally. Then collect garbage and empty the CUDA cache.
            model = None
            tokenizer = None
            cleanup_model(None, None)
            print(f"  Cleaned up {model_key}")

print(f"\n{'='*60}")
print("Done processing all models!")
print(f"{'='*60}")

## Analyze Token Convergence

In [None]:
def extract_top_token_strings(results: list, top_k: int = 5) -> list:
    """Build one set of normalized top-token strings per successful record.

    Records without a "top_tokens" entry (i.e. errors) are skipped. Tokens
    are normalized by stripping whitespace and lowercasing.
    """
    return [
        {tok.strip().lower() for _, tok, _ in record["top_tokens"][:top_k]}
        for record in results
        if "top_tokens" in record
    ]


def compute_token_overlap(token_sets_a: list, token_sets_b: list) -> float:
    """Average the Jaccard similarity over positionally paired token sets.

    Pairs where both sets are empty are ignored. Returns 0.0 when no pair
    contributes a score.
    """
    jaccard_scores = []
    for left, right in zip(token_sets_a, token_sets_b):
        union = left | right
        if not union:
            continue
        jaccard_scores.append(len(left & right) / len(union))

    if not jaccard_scores:
        return 0.0
    return np.mean(jaccard_scores)


def get_model_family(model_key: str) -> str:
    """Return the part of the model key before its first "-" (the family name)."""
    family, _, _ = model_key.partition("-")
    return family


# Build overlap matrix
# Keep every model that produced at least one successful record. Checking
# only the first record would drop a model whose first prompt failed, and
# would raise IndexError on an empty result list.
model_names = [
    k for k, records in all_results.items()
    if any("top_tokens" in r for r in records)
]
n_models = len(model_names)

print("Token Overlap Matrix (Jaccard Similarity of Top-5 Attended Tokens)")
print("="*70)

# Per model: behavior_id -> top-5 token set. Keying by behavior_id keeps the
# pairs aligned even when some prompts failed for one model but not another.
top5_by_behavior = {
    name: {
        r["behavior_id"]: extract_top_token_strings([r], top_k=5)[0]
        for r in all_results[name]
        if "top_tokens" in r
    }
    for name in model_names
}

overlap_matrix = np.zeros((n_models, n_models))

for i, model_a in enumerate(model_names):
    sets_a = top5_by_behavior[model_a]
    for j, model_b in enumerate(model_names):
        sets_b = top5_by_behavior[model_b]
        # Compare only the behaviors both models processed successfully
        shared_ids = sorted(sets_a.keys() & sets_b.keys())
        overlap = compute_token_overlap(
            [sets_a[b_id] for b_id in shared_ids],
            [sets_b[b_id] for b_id in shared_ids],
        )
        overlap_matrix[i, j] = overlap

overlap_df = pd.DataFrame(overlap_matrix, index=model_names, columns=model_names)
print(overlap_df.round(3))

In [None]:
# Global token frequency analysis
# Counts how often each normalized token appears among the top-5 attended
# tokens, and in how many distinct models / families it appears.
print("\nMost Commonly Attended Tokens Across All Models")
print("="*70)

global_token_counts = defaultdict(lambda: {"count": 0, "total_score": 0.0, "models": set(), "families": set()})

for model_name, results in all_results.items():
    family = get_model_family(model_name)
    for r in results:
        if "top_tokens" in r:
            for idx, tok, score in r["top_tokens"][:5]:
                tok_clean = tok.strip().lower()
                if len(tok_clean) > 1:  # Skip single chars
                    stats = global_token_counts[tok_clean]
                    stats["count"] += 1
                    stats["total_score"] += score
                    stats["models"].add(model_name)
                    stats["families"].add(family)

# Sort by number of families, then number of models, then count
sorted_tokens = sorted(
    global_token_counts.items(),
    key=lambda x: (len(x[1]["families"]), len(x[1]["models"]), x[1]["count"]),
    reverse=True
)

print(f"{'Token':<20} {'Families':<10} {'Models':<8} {'Count':<8} {'Avg Score':<12}")
print("-"*70)
for tok, stats in sorted_tokens[:25]:
    avg_score = stats["total_score"] / stats["count"]
    # Use token-specific names here: reusing `n_models` would overwrite the
    # global model count that the heatmap and summary cells depend on.
    tok_family_count = len(stats["families"])
    tok_model_count = len(stats["models"])
    print(f"{tok:<20} {tok_family_count:<10} {tok_model_count:<8} {stats['count']:<8} {avg_score:<12.4f}")

In [None]:
# Per-example token agreement across models
# For the first few prompts, list each model's three most-attended tokens.
print("\nPer-Example Token Agreement (first 5 examples)")
print("="*70)

for example_idx in range(min(5, N_EXAMPLES)):
    print(f"\n--- Example {example_idx} ---")

    # Print the prompt once, taken from the first model that recorded it
    for mn in model_names:
        if example_idx >= len(all_results[mn]):
            continue
        result = all_results[mn][example_idx]
        if "prompt" not in result:
            continue
        print(f"Prompt: {result['prompt'][:80]}...")
        break

    # One line per model: its top-3 tokens, or an error marker
    for model_name in model_names:
        if example_idx >= len(all_results[model_name]):
            continue
        result = all_results[model_name][example_idx]

        if "top_tokens" not in result:
            print(f"  {model_name:<15}: [error]")
            continue

        top_3 = [f"'{tok}'({score:.3f})" for _, tok, score in result["top_tokens"][:3]]
        print(f"  {model_name:<15}: {', '.join(top_3)}")

## Visualize

In [None]:
import matplotlib.pyplot as plt

# Create color mapping for families
family_colors = {"llama": "C0", "qwen": "C1", "gemma": "C2", "phi": "C3"}

# Re-derive the matrix size from the model list instead of trusting a global
# that other cells may have reassigned; tick counts must match the labels.
n_models = len(model_names)

# Heatmap of pairwise Jaccard overlap between models
fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(overlap_matrix, cmap='Blues', vmin=0, vmax=1)

ax.set_xticks(range(n_models))
ax.set_yticks(range(n_models))
ax.set_xticklabels(model_names, rotation=45, ha='right', fontsize=9)
ax.set_yticklabels(model_names, fontsize=9)

# Add values to cells (white text on dark cells, black on light ones)
for i in range(n_models):
    for j in range(n_models):
        val = overlap_matrix[i, j]
        color = 'white' if val > 0.5 else 'black'
        ax.text(j, i, f"{val:.2f}", ha='center', va='center', fontsize=8, color=color)

ax.set_title("Token Overlap (Jaccard Similarity)\nTop-5 Attended Tokens, Last-Token-to-All Attention", fontsize=12)
plt.colorbar(im, ax=ax, label='Jaccard Similarity', shrink=0.8)
plt.tight_layout()

# Save to Drive
fig.savefig(os.path.join(RESULTS_DIR, "overlap_matrix.png"), dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved: {RESULTS_DIR}/overlap_matrix.png")

In [None]:
# Size-based analysis: Does attention pattern change with model size?
print("\nSize Scaling Analysis")
print("="*70)

# Extract size from model key
def get_model_size(model_key: str) -> float:
    """Parse the size in billions from a model key such as "llama-8b".

    The first "<number>b" pattern in the lowercased key wins. Otherwise the
    mini/small/medium keywords are matched against the key as given
    (case-sensitive), and 0.0 is returned when nothing matches.
    """
    import re

    found = re.search(r'(\d+\.?\d*)b', model_key.lower())
    if found is not None:
        return float(found.group(1))

    # Keyword fallbacks, checked in this order
    for keyword, size_b in (("mini", 3.8), ("small", 7.0), ("medium", 14.0)):
        if keyword in model_key:
            return size_b

    return 0.0

# List the pairwise overlap between models of the same family, ordered by size
print("\nWithin-family size correlation:")

# Row/column position of each model in overlap_matrix
matrix_pos = {}
for pos, name in enumerate(model_names):
    matrix_pos.setdefault(name, pos)

# Group models by family, preserving their order in model_names
members_by_family = {}
for name in model_names:
    members_by_family.setdefault(get_model_family(name), []).append(name)

for family in sorted(members_by_family):
    family_models = sorted(members_by_family[family], key=get_model_size)

    # A family needs at least two models to form a pair
    if len(family_models) < 2:
        continue

    print(f"\n  {family.upper()}:")
    for i, m1 in enumerate(family_models):
        size1 = get_model_size(m1)
        idx1 = matrix_pos[m1]
        for m2 in family_models[i+1:]:
            size2 = get_model_size(m2)
            idx2 = matrix_pos[m2]
            overlap = overlap_matrix[idx1, idx2]
            print(f"    {m1} ({size1}B) <-> {m2} ({size2}B): {overlap:.3f}")

In [None]:
# Within-family vs cross-family comparison
# Split every unordered model pair (upper triangle of the matrix) into
# same-family and different-family buckets.
print("\nWithin-Family vs Cross-Family Overlap")
print("="*70)

within_family = []
cross_family = []

model_families = [get_model_family(name) for name in model_names]

for i in range(len(model_names)):
    for j in range(i + 1, len(model_names)):
        same_family = model_families[i] == model_families[j]
        bucket = within_family if same_family else cross_family
        bucket.append(overlap_matrix[i, j])

print(f"Within-family overlap:  mean={np.mean(within_family):.3f}, std={np.std(within_family):.3f}, n={len(within_family)}")
print(f"Cross-family overlap:   mean={np.mean(cross_family):.3f}, std={np.std(cross_family):.3f}, n={len(cross_family)}")

# Visualization: box plot of within/cross-family overlap, plus a bar chart
# of the average within-family overlap for each family.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Box plot comparison
# Tick labels are set separately rather than through boxplot's `labels`
# keyword, which is deprecated in recent matplotlib releases.
axes[0].boxplot([within_family, cross_family])
axes[0].set_xticklabels(['Within Family', 'Cross Family'])
axes[0].set_ylabel('Jaccard Similarity')
axes[0].set_title('Token Overlap: Within vs Cross Family')

# Per-family average overlap
# Sorted so the bar order is the same on every run (set order is not).
families = sorted({get_model_family(m) for m in model_names})
family_avg_overlap = {}

for family in families:
    indices = [pos for pos, m in enumerate(model_names) if get_model_family(m) == family]
    # Only families with at least two models have a within-family pair
    if len(indices) > 1:
        overlaps = [
            overlap_matrix[a, b]
            for k, a in enumerate(indices)
            for b in indices[k + 1:]
        ]
        family_avg_overlap[family] = np.mean(overlaps) if overlaps else 0

family_colors = {"llama": "C0", "qwen": "C1", "gemma": "C2", "phi": "C3"}
bar_families = list(family_avg_overlap.keys())
axes[1].bar(
    bar_families,
    [family_avg_overlap[f] for f in bar_families],
    color=[family_colors.get(f, 'gray') for f in bar_families],
)
axes[1].set_ylabel('Avg Jaccard Similarity')
axes[1].set_title('Within-Family Average Overlap')
axes[1].set_ylim(0, 1)

plt.tight_layout()
fig.savefig(os.path.join(RESULTS_DIR, "family_comparison.png"), dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved: {RESULTS_DIR}/family_comparison.png")

## Save Results

In [None]:
# Save final results to Google Drive
# The file name carries a local-time timestamp so repeated runs do not
# overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = os.path.join(RESULTS_DIR, f"attention_convergence_{timestamp}.json")

# Prepare serializable data
def make_serializable(obj):
    """Recursively convert NumPy values and containers into JSON-friendly types.

    - NumPy arrays become nested lists.
    - Any NumPy scalar (float, int, bool, ...) becomes the matching Python scalar.
    - Dict values are converted recursively (keys are left as they are).
    - Lists, tuples, sets and frozensets become lists with converted members.
    - Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # Covers every NumPy scalar width, not just 32/64-bit floats and ints
        return obj.item()
    if isinstance(obj, dict):
        return {k: make_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set, frozenset)):
        # Recurse into all sequence-like containers so nested NumPy values
        # inside tuples or sets are converted as well
        return [make_serializable(v) for v in obj]
    return obj

def _overlap_summary(values):
    """Summarize a list of overlap scores as mean / std / raw values.

    An empty list yields None for mean and std: np.mean([]) is NaN, which
    json.dump would write as a bare NaN token that strict JSON parsers reject.
    """
    values = [float(v) for v in values]
    if not values:
        return {"mean": None, "std": None, "values": []}
    return {
        "mean": float(np.mean(values)),
        "std": float(np.std(values)),
        "values": values,
    }


# Everything needed to reproduce the analysis: configuration, the overlap
# matrix with its row/column labels, the family comparison, and raw records.
save_data = {
    "config": {
        "n_examples": N_EXAMPLES,
        "top_k_tokens": TOP_K_TOKENS,
        "layer_percent": LAYER_PERCENT,
        "quantize_threshold_b": QUANTIZE_THRESHOLD_B,
        "models": dict(MODELS_TO_TEST),
        "attention_method": "last_token_to_all",
    },
    "overlap_matrix": overlap_matrix.tolist(),
    "model_names": model_names,
    "within_family_overlap": _overlap_summary(within_family),
    "cross_family_overlap": _overlap_summary(cross_family),
    "results": make_serializable(all_results),
}

with open(output_file, "w") as f:
    json.dump(save_data, f, indent=2)

print(f"Saved final results: {output_file}")
print(f"File size: {os.path.getsize(output_file) / 1e6:.2f} MB")

In [None]:
print("\n" + "="*70)
print("SUMMARY")
print("="*70)

# Re-derive the model count from the model list rather than relying on a
# global that other cells may have reassigned; it must match overlap_matrix.
n_models = len(model_names)

# Overall statistics (mean of all off-diagonal entries)
mask = ~np.eye(n_models, dtype=bool)
avg_overlap = overlap_matrix[mask].mean() if n_models > 1 else 0.0
print(f"\nOverall average token overlap: {avg_overlap:.3f}")

# Within vs cross family
print(f"\nWithin-family avg:  {np.mean(within_family):.3f}")
print(f"Cross-family avg:   {np.mean(cross_family):.3f}")
print(f"Difference:         {np.mean(within_family) - np.mean(cross_family):.3f}")

# Most/least similar pairs (searched over the upper triangle only, so each
# unordered pair is considered once and self-pairs are excluded)
if n_models > 1:
    upper_tri = np.triu_indices(n_models, k=1)
    overlaps = overlap_matrix[upper_tri]

    max_idx = np.argmax(overlaps)
    min_idx = np.argmin(overlaps)

    print(f"\nMost similar pair:  {model_names[upper_tri[0][max_idx]]} <-> {model_names[upper_tri[1][max_idx]]} ({overlaps[max_idx]:.3f})")
    print(f"Least similar pair: {model_names[upper_tri[0][min_idx]]} <-> {model_names[upper_tri[1][min_idx]]} ({overlaps[min_idx]:.3f})")

# Per-family summary
# The family list is computed here so this cell does not depend on a
# variable defined in the plotting cell.
print("\nPer-family statistics:")
for family in sorted({get_model_family(m) for m in model_names}):
    family_models = [m for m in model_names if get_model_family(m) == family]
    valid_results = sum(1 for m in family_models for r in all_results[m] if "top_tokens" in r)
    total_results = sum(len(all_results[m]) for m in family_models)
    print(f"  {family:<8}: {len(family_models)} models, {valid_results}/{total_results} valid examples")

print(f"\nResults saved to: {RESULTS_DIR}")
print("Files:")
print(f"  - attention_convergence_{timestamp}.json (full data)")
print(f"  - overlap_matrix.png")
print(f"  - family_comparison.png")
print(f"  - intermediate_*.json (per-model backups)")