# Attention Convergence Test Across Models

Test whether different models attend to similar "trigger tokens" in harmful prompts.
Uses HarmBench/JailbreakBench examples and multiple models to analyze attention patterns.

In [None]:
from dotenv import load_dotenv
import os
import sys
import json
from datetime import datetime

# Load environment variables through python-dotenv.
load_dotenv()

# Resolve the project root: start from the working directory and step up one
# level when the notebook is being run from inside the `notebooks` folder.
PROJECT_ROOT = os.getcwd()
if os.path.basename(PROJECT_ROOT) == 'notebooks':
    PROJECT_ROOT = os.path.dirname(PROJECT_ROOT)

# Put <root>/src first on the import path so it takes precedence.
sys.path.insert(0, os.path.join(PROJECT_ROOT, 'src'))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from collections import defaultdict
from dataclasses import dataclass
import gc

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Turn autograd off globally for everything executed after this point.
torch.set_grad_enabled(False)

import warnings
# Silence every Python warning raised after this point.
warnings.filterwarnings('ignore')

## Configuration

In [None]:
# Registry of models under test: short key -> Hugging Face repo id ("name")
# plus a "quantize" flag that load_model() reads to decide whether to apply
# 4-bit quantization at load time.
MODELS_TO_TEST = {
    "gemma3-4b": {"name": "google/gemma-3-4b-it", "quantize": False},
    "qwen3-4b": {"name": "Qwen/Qwen3-4B", "quantize": False},
    # Pre-quantized BNB checkpoint, so no additional quantization is requested.
    "llama-8b-bnb": {"name": "unsloth/Llama-3.1-8B-Instruct-bnb-4bit", "quantize": False},
}

# Experiment parameters
N_EXAMPLES = 30        # how many behaviors to run through each model
TOP_K_TOKENS = 10      # how many top-attended prompt tokens to keep per example
MAX_NEW_TOKENS = 100   # generation budget per prompt
LAYER_PERCENT = 50     # depth (in % of layers) of the layer to read; 50 = middle

print(f"Testing {len(MODELS_TO_TEST)} models on {N_EXAMPLES} examples")
print(f"Models: {list(MODELS_TO_TEST)}")

## Load JailbreakBench Data

In [None]:
from datasets import load_dataset

print("Loading JailbreakBench...")
jbb = load_dataset("JailbreakBench/JBB-Behaviors", "behaviors")
all_harmful = list(jbb["harmful"])

# Keep the first N_EXAMPLES rows of the "harmful" split (a prefix, not a
# random sample).
behaviors = all_harmful[:N_EXAMPLES]

print(f"Loaded {len(behaviors)} behaviors (from {len(all_harmful)} total)")
print("\nExample behaviors:")
# Preview the first three goals (truncated to 80 characters) with their category.
for i, b in enumerate(behaviors[:3]):
    print(f"  [{i}] {b['Goal'][:80]}...\n      Category: {b.get('Category', 'N/A')}")

## Attention Capture Utilities

In [None]:
@dataclass
class AttentionResult:
    """Stores attention capture results."""
    attention_weights: torch.Tensor  # [num_heads, seq_len, seq_len]
    prompt_tokens: list[str]
    response_tokens: list[str]
    prompt_end_idx: int
    
    def get_response_to_prompt_attention(self) -> torch.Tensor:
        """Get attention from response tokens to prompt tokens."""
        # [heads, response_len, prompt_len]
        return self.attention_weights[:, self.prompt_end_idx:, :self.prompt_end_idx]
    
    def get_prompt_attention_scores(self, head: int | None = None) -> torch.Tensor:
        """Get aggregated attention scores for each prompt token."""
        attn = self.get_response_to_prompt_attention().float()  # Convert to float32
        
        if head is not None:
            attn = attn[head:head+1]
        
        # Average over heads and response positions
        scores = attn.mean(dim=(0, 1))  # [prompt_len]
        return scores
    
    def get_top_attended_tokens(self, k: int = 10) -> list[tuple[int, str, float]]:
        """Get top-k attended prompt tokens."""
        scores = self.get_prompt_attention_scores()
        top_indices = torch.argsort(scores, descending=True)[:k]
        
        results = []
        for idx in top_indices:
            idx = idx.item()
            if idx < len(self.prompt_tokens):
                results.append((idx, self.prompt_tokens[idx], scores[idx].item()))
        return results

In [None]:
def load_model(model_config: dict, device: str = "cuda"):
    """Load a causal LM and its tokenizer from a MODELS_TO_TEST-style entry.

    Args:
        model_config: Dict with a "name" (Hugging Face repo id) and an optional
            "quantize" flag requesting 4-bit NF4 quantization via bitsandbytes.
        device: Value passed as `device_map` for the non-quantized, non-Gemma-3
            path. It is not used when quantizing or when loading Gemma 3; both
            of those paths use `device_map="auto"`.

    Returns:
        A `(model, tokenizer)` tuple; `model.eval()` has been called.
    """
    model_name = model_config["name"]
    wants_quantization = model_config.get("quantize", False)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = dict(
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",  # Need eager for attention weights
        device_map="auto" if wants_quantization else device,
    )
    if wants_quantization:
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    if "gemma-3" not in model_name.lower():
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    else:
        # Gemma 3 is loaded through its text-only class so no vision tower is built.
        from transformers import AutoConfig
        from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
        from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM

        print("Loading Gemma 3 in text-only mode (no vision tower)...")

        # Build a text-only config from the text section of the multimodal config,
        # overriding the vocabulary size and the special-token ids.
        text_cfg_dict = AutoConfig.from_pretrained(model_name).text_config.to_dict()
        text_cfg_dict["vocab_size"] = 262208
        if tokenizer.pad_token_id is not None:
            text_cfg_dict["pad_token_id"] = tokenizer.pad_token_id
        text_cfg_dict.update(
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Same loading options as above, except the device map is always "auto".
        gemma_kwargs = {**load_kwargs, "device_map": "auto"}
        model = Gemma3ForCausalLM.from_pretrained(
            model_name,
            config=Gemma3TextConfig(**text_cfg_dict),
            **gemma_kwargs,
        )

    model.eval()
    return model, tokenizer


def get_num_layers(model) -> int:
    """Return the number of transformer layers declared by `model.config`.

    Looks for `num_hidden_layers`, then `n_layer`, first on `model.config` and
    then on a nested `model.config.text_config` if one exists. Attributes that
    are missing or set to None are skipped.

    Returns:
        The first layer count found, or 32 when none is available.
    """
    config = getattr(model, "config", None)
    # Top-level config first, then the nested text config as a second chance.
    for cfg in (config, getattr(config, "text_config", None)):
        if cfg is None:
            continue
        for attr in ("num_hidden_layers", "n_layer"):
            value = getattr(cfg, attr, None)
            if value is not None:
                return value
    return 32  # fallback


def cleanup_model(model, tokenizer):
    """Release this function's references, run GC and clear the CUDA cache.

    Only the local names are unbound here; any reference the caller still
    holds keeps the objects alive.
    """
    del model, tokenizer
    gc.collect()
    torch.cuda.empty_cache()
    print("Model cleaned up")

In [None]:
def capture_attention(
    model,
    tokenizer,
    prompt: str,
    layer_idx: int,
    max_new_tokens: int = 100,
) -> AttentionResult:
    """Generate a greedy response and capture one layer's attention over it.

    The prompt is wrapped as a single user turn with the tokenizer's chat
    template, a response is generated, and the full prompt+response sequence
    is then run through the model once more with `output_attentions=True`.

    Args:
        model: Causal LM; must be able to return attention weights.
        tokenizer: Tokenizer matching `model`, with a chat template.
        prompt: Raw user message.
        layer_idx: Index into `outputs.attentions` of the layer to keep.
        max_new_tokens: Generation budget.

    Returns:
        An AttentionResult with the selected layer's attention (moved to CPU),
        the per-token decoded prompt and response, and the prompt length.
    """
    # Format prompt
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Tokenize and find prompt boundary. The chat template has already written
    # the special tokens into `formatted`, so the tokenizer must not add its
    # own; otherwise tokenizers that prepend BOS produce a duplicated BOS.
    prompt_inputs = tokenizer(
        formatted,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)
    prompt_len = prompt_inputs["input_ids"].shape[1]

    # Generate response (greedy decoding)
    with torch.no_grad():
        output_ids = model.generate(
            prompt_inputs["input_ids"],
            attention_mask=prompt_inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Full sequence = prompt tokens followed by generated tokens
    full_ids = output_ids[0]

    # Forward pass over the full sequence with attention output
    with torch.no_grad():
        outputs = model(
            input_ids=full_ids.unsqueeze(0),
            output_attentions=True,
            return_dict=True,
        )

    # Extract attention from target layer
    # Shape: [batch, heads, seq_len, seq_len]
    layer_attention = outputs.attentions[layer_idx][0]  # Remove batch dim

    # Decode one token at a time; plain ints are handed to the tokenizer
    all_tokens = [tokenizer.decode([tid]) for tid in full_ids.tolist()]
    prompt_tokens = all_tokens[:prompt_len]
    response_tokens = all_tokens[prompt_len:]

    return AttentionResult(
        attention_weights=layer_attention.cpu(),
        prompt_tokens=prompt_tokens,
        response_tokens=response_tokens,
        prompt_end_idx=prompt_len,
    )

## Run Attention Analysis

In [None]:
# Per-model list of result records, keyed by the short key from MODELS_TO_TEST.
all_results = {}

for model_key, model_config in MODELS_TO_TEST.items():
    print(f"\n{'='*60}")
    print(f"Loading: {model_key} ({model_config['name']})")
    print(f"{'='*60}")

    # Pre-bind so the `finally` clause can tell whether loading succeeded.
    model = None
    tokenizer = None

    try:
        model, tokenizer = load_model(model_config)

        num_layers = get_num_layers(model)
        # Map the percentage to a layer index, clamped to the last valid layer
        # so LAYER_PERCENT=100 does not index past the end.
        target_layer = min(int(num_layers * LAYER_PERCENT / 100), num_layers - 1)
        print(f"Layers: {num_layers}, using layer {target_layer}")

        model_results = []

        for idx, behavior in enumerate(tqdm(behaviors, desc=f"Processing {model_key}")):
            prompt = behavior["Goal"]

            try:
                result = capture_attention(
                    model, tokenizer, prompt,
                    layer_idx=target_layer,
                    max_new_tokens=MAX_NEW_TOKENS,
                )

                top_tokens = result.get_top_attended_tokens(k=TOP_K_TOKENS)
                scores = result.get_prompt_attention_scores()

                model_results.append({
                    "behavior_id": idx,
                    "category": behavior.get("Category", "unknown"),
                    "prompt": prompt,
                    "prompt_tokens": result.prompt_tokens,
                    "attention_scores": scores.float().numpy(),
                    "top_tokens": [(i, tok, float(score)) for i, tok, score in top_tokens],
                    # Re-encode the joined response text and decode its first
                    # 50 token ids without special tokens.
                    "response_preview": tokenizer.decode(
                        tokenizer.encode("".join(result.response_tokens))[:50],
                        skip_special_tokens=True
                    ),
                })

            except Exception as e:
                # Keep going: one failing example must not abort the model run.
                print(f"\nError on behavior {idx}: {e}")
                model_results.append({
                    "behavior_id": idx,
                    "prompt": prompt,
                    "error": str(e),
                })

        all_results[model_key] = model_results
        print(f"\nCompleted {len(model_results)} examples")

    except Exception as e:
        # Record the failure as a single error entry and move on to the next model.
        print(f"Failed to load {model_key}: {e}")
        all_results[model_key] = [{"error": str(e)}]

    finally:
        # Aggressive cleanup
        if model is not None:
            del model
        if tokenizer is not None:
            del tokenizer
        gc.collect()
        # Only touch the CUDA runtime when a GPU is present; an unguarded
        # synchronize() raises on CUDA-less hosts and would abort the loop
        # from inside this `finally` clause.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
        print(f"Cleaned up {model_key}")

print(f"\n{'='*60}")
print("Done!")
print(f"{'='*60}")

## Analyze Token Convergence

In [None]:
def extract_top_token_strings(results: list, top_k: int = 5) -> list:
    """Build one set of normalized top-token strings per successful result.

    Records without a "top_tokens" entry are skipped. For the others, the first
    `top_k` tokens are stripped of surrounding whitespace and lower-cased.
    """
    return [
        {tok.strip().lower() for _, tok, _ in record["top_tokens"][:top_k]}
        for record in results
        if "top_tokens" in record
    ]


def compute_token_overlap(token_sets_a: list, token_sets_b: list) -> float:
    """Mean Jaccard similarity over position-wise pairs of token sets.

    Pairs are formed with `zip`, so extra sets in the longer list are ignored.
    Pairs whose union is empty are left out of the mean. Returns 0.0 when no
    pair contributes.
    """
    similarities = []
    for left, right in zip(token_sets_a, token_sets_b):
        union = left | right
        if union:
            similarities.append(len(left & right) / len(union))
    if not similarities:
        return 0.0
    return np.mean(similarities)


model_names = list(all_results.keys())
n_models = len(model_names)

print("Token Overlap Matrix (Jaccard Similarity of Top-5 Attended Tokens)")
print("="*60)

# Index each model's successful records by behavior id. Comparing two models
# position-by-position after dropping errored records would pair different
# behaviors with each other whenever one model failed on an example.
valid_by_model = {
    name: {r["behavior_id"]: r for r in all_results[name] if "top_tokens" in r}
    for name in model_names
}

overlap_matrix = np.zeros((n_models, n_models))

for i, model_a in enumerate(model_names):
    for j, model_b in enumerate(model_names):
        # Only behaviors that both models processed successfully are compared,
        # in ascending behavior-id order so the two lists stay aligned.
        shared_ids = sorted(valid_by_model[model_a].keys() & valid_by_model[model_b].keys())
        tokens_a = extract_top_token_strings(
            [valid_by_model[model_a][bid] for bid in shared_ids], top_k=5
        )
        tokens_b = extract_top_token_strings(
            [valid_by_model[model_b][bid] for bid in shared_ids], top_k=5
        )
        overlap = compute_token_overlap(tokens_a, tokens_b)
        overlap_matrix[i, j] = overlap

overlap_df = pd.DataFrame(overlap_matrix, index=model_names, columns=model_names)
print(overlap_df.round(3))

In [None]:
# Global token frequency: for every normalized token that appears in a top-5
# list, track how often it appears, its summed score, and which models saw it.
print("\nMost Commonly Attended Tokens Across All Models")
print("="*60)

global_token_counts = defaultdict(lambda: {"count": 0, "total_score": 0.0, "models": set()})

for model_name, results in all_results.items():
    for r in results:
        if "top_tokens" not in r:
            continue
        for idx, tok, score in r["top_tokens"][:5]:
            tok_clean = tok.strip().lower()
            # Skip empty and single-character tokens.
            if len(tok_clean) <= 1:
                continue
            entry = global_token_counts[tok_clean]
            entry["count"] += 1
            entry["total_score"] += score
            entry["models"].add(model_name)

# Rank by number of distinct models first, then by raw count (both descending).
sorted_tokens = sorted(
    global_token_counts.items(),
    key=lambda item: (len(item[1]["models"]), item[1]["count"]),
    reverse=True
)

print(f"{'Token':<20} {'Count':<8} {'Avg Score':<12} {'Models'}")
print("-"*60)
for tok, stats in sorted_tokens[:20]:
    avg_score = stats["total_score"] / stats["count"]
    models_str = ", ".join(sorted(stats["models"]))
    print(f"{tok:<20} {stats['count']:<8} {avg_score:<12.4f} {models_str}")

In [None]:
# Per-example comparison: show each model's top-3 attended tokens side by side
# for the first few examples.
print("\nPer-Example Token Agreement")
print("="*60)

for example_idx in range(min(5, N_EXAMPLES)):
    print(f"\n--- Example {example_idx} ---")

    # A model's result list can be shorter than the example index (a model
    # that failed to load has a single error record), so look records up
    # defensively instead of indexing directly.
    per_model = {
        name: (all_results[name][example_idx] if example_idx < len(all_results[name]) else None)
        for name in model_names
    }

    # Take the prompt text from the first model that recorded one.
    prompt_text = next(
        (r["prompt"] for r in per_model.values() if r is not None and "prompt" in r),
        None,
    )
    if prompt_text is not None:
        print(f"Prompt: {prompt_text[:80]}...")

    for model_name in model_names:
        result = per_model[model_name]

        if result is not None and "top_tokens" in result:
            top_3 = [f"'{tok}'({score:.3f})" for _, tok, score in result["top_tokens"][:3]]
            print(f"  {model_name:<15}: {', '.join(top_3)}")
        else:
            print(f"  {model_name:<15}: [error]")

## Visualize

In [None]:
import matplotlib.pyplot as plt

# Heatmap of the pairwise overlap matrix with each cell's value written on it.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(overlap_matrix, cmap='Blues', vmin=0, vmax=1)

tick_positions = range(n_models)
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(model_names, rotation=45, ha='right')
ax.set_yticklabels(model_names)

# Row index is the y coordinate and column index the x coordinate.
for (i, j), cell_value in np.ndenumerate(overlap_matrix):
    ax.text(j, i, f"{cell_value:.2f}", ha='center', va='center', fontsize=12)

ax.set_title("Token Overlap (Jaccard Similarity)\nTop-5 Attended Tokens")
plt.colorbar(im, ax=ax, label='Jaccard Similarity')
plt.tight_layout()
plt.show()

In [None]:
# Attention score distributions: one histogram per model of the scores of its
# top-5 attended tokens, pooled over all successful examples.
n_panels = len(model_names)
fig, axes = plt.subplots(1, n_panels, figsize=(4*n_panels, 4), sharey=True)
if n_panels == 1:
    # With a single panel `axes` is not a sequence; wrap it so zip() works.
    axes = [axes]

for ax, model_name in zip(axes, model_names):
    scores = [
        score
        for r in all_results[model_name]
        if "top_tokens" in r
        for _, _, score in r["top_tokens"][:5]
    ]

    if scores:
        ax.hist(scores, bins=20, alpha=0.7, edgecolor='black')
        ax.set_xlim(0, max(scores) * 1.1)
    ax.set_xlabel('Attention Score')
    ax.set_title(model_name)

axes[0].set_ylabel('Frequency')
fig.suptitle('Distribution of Top-5 Attention Scores', y=1.02)
plt.tight_layout()
plt.show()

## Save Results

In [None]:
def make_serializable(obj):
    """Recursively convert `obj` into plain Python types that `json` can encode.

    - NumPy arrays become nested lists.
    - NumPy scalars become the matching Python scalar.
    - Dict values are converted recursively (keys are left untouched).
    - Lists, tuples, sets and frozensets become lists with converted members.
    - Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # e.g. np.float32 / np.int64, which json cannot encode directly
        return obj.item()
    if isinstance(obj, dict):
        return {k: make_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [make_serializable(v) for v in obj]
    return obj

# Results go to <PROJECT_ROOT>/results, one timestamped JSON file per run.
output_dir = os.path.join(PROJECT_ROOT, "results")
os.makedirs(output_dir, exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = os.path.join(output_dir, f"attention_convergence_{timestamp}.json")

# Payload: run configuration, the overlap matrix with its row/column labels,
# and the per-model results converted to JSON-friendly types.
save_data = dict(
    config=dict(
        n_examples=N_EXAMPLES,
        top_k_tokens=TOP_K_TOKENS,
        models={key: cfg["name"] for key, cfg in MODELS_TO_TEST.items()},
        layer_percent=LAYER_PERCENT,
    ),
    overlap_matrix=overlap_matrix.tolist(),
    model_names=model_names,
    results=make_serializable(all_results),
)

with open(output_file, "w") as f:
    f.write(json.dumps(save_data, indent=2))

print(f"Saved to: {output_file}")

In [None]:
print("\n" + "="*60)
print("SUMMARY")
print("="*60)

# Mean of the off-diagonal cells; the diagonal (self-comparison) is masked out.
mask = ~np.eye(n_models, dtype=bool)
if n_models > 1:
    avg_overlap = overlap_matrix[mask].mean()
else:
    avg_overlap = 0.0
print(f"\nAverage cross-model token overlap: {avg_overlap:.3f}")

if n_models > 1:
    # Each unordered model pair appears once in the strict upper triangle.
    upper_tri = np.triu_indices(n_models, k=1)
    overlaps = overlap_matrix[upper_tri]

    max_idx = np.argmax(overlaps)
    min_idx = np.argmin(overlaps)

    for label, pair_idx in (("Most similar", max_idx), ("Least similar", min_idx)):
        first_model = model_names[upper_tri[0][pair_idx]]
        second_model = model_names[upper_tri[1][pair_idx]]
        print(f"{label}: {first_model} <-> {second_model} ({overlaps[pair_idx]:.3f})")

print("\nPer-model stats:")
for model_name in model_names:
    results = all_results[model_name]
    valid = [r for r in results if "top_tokens" in r]

    # Models without a single successful example get no stats line.
    if not valid:
        continue

    all_scores = [score for r in valid for _, _, score in r["top_tokens"][:5]]
    print(f"  {model_name}: {len(valid)}/{len(results)} valid, avg attn={np.mean(all_scores):.4f}, max={np.max(all_scores):.4f}")