# Attention Head Analysis

This notebook analyzes the attention heads in the attn-only-1l and attn-only-2l models.

## Methodology

### OV and QK Matrices
- **OV Matrix**: W_V @ W_O - maps from source token embedding to output effect on logits
- **QK Matrix**: W_Q @ W_K^T - determines attention patterns (which queries attend to which keys)
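
As a minimal sketch of the two definitions above, the following builds random per-head weights and forms both products. The shapes follow the convention used later in this notebook (`[n_heads, d_model, d_head]` for W_Q, W_K, W_V and `[n_heads, d_head, d_model]` for W_O); the sizes are arbitrary placeholders, not those of the real models.

```python
import torch

torch.manual_seed(0)
n_heads, d_model, d_head = 2, 8, 4  # toy sizes, chosen only for illustration

W_Q = torch.randn(n_heads, d_model, d_head)
W_K = torch.randn(n_heads, d_model, d_head)
W_V = torch.randn(n_heads, d_model, d_head)
W_O = torch.randn(n_heads, d_head, d_model)

# QK circuit: W_Q @ W_K^T, one [d_model, d_model] matrix per head
W_QK = W_Q @ W_K.transpose(-1, -2)
# OV circuit: W_V @ W_O, one [d_model, d_model] matrix per head
W_OV = W_V @ W_O

print(W_QK.shape, W_OV.shape)

# Both products factor through d_head, so their rank is at most d_head
qk_rank = torch.linalg.matrix_rank(W_QK[0]).item()
ov_rank = torch.linalg.matrix_rank(W_OV[0]).item()
print(qk_rank, ov_rank)
```

Because each product passes through the `d_head`-dimensional bottleneck, every head's QK and OV matrix is low rank even though it acts on the full residual stream.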

### Heuristic Decisions

1. **Pivot on Key**: We pivot the table on the key (source) token. This means each row shows one key and its corresponding queries and outputs. The key is shared between QK and OV matrices.

2. **QK Normalization**: For each query, we subtract its QK value for the special BOS token (token 50256) from all of its QK values. Models use BOS as a default key to attend to, so it serves as a natural zero point and makes QK values comparable across different queries.

3. **OV Normalization**: For each source token, we subtract the mean OV value over all output tokens, which makes values comparable across different source tokens.

4. **Key Selection**: Keys are ranked by the heuristic `QK.max(0) * OV.max(1) * token_prob**0.1`, where QK is indexed `[query, key]` and OV is indexed `[source, output]`, so both maxima yield one value per key. This favors:
   - Keys that some query strongly prefers (high QK)
   - Keys with a large effect on the output (high OV)
   - Keys that are more probable, but only slightly, since they occur more frequently

5. **Query Selection**: Queries are selected by `QK[:, src] * token_prob**0.1` - upweighting probable tokens since they're more likely to occur.

6. **Output Selection**: Output tokens are simply those with the largest OV values.
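
The heuristics above can be sketched end to end on a hand-written toy head. Everything here is made up for illustration: a four-token vocabulary, token 0 standing in for BOS, and QK/OV values chosen so the result is easy to follow by hand. The OV maximum is taken over output tokens, giving one value per source token.

```python
import torch

BOS = 0  # stand-in BOS index for this toy vocabulary (an assumption)

# qk[query, key] and ov[source, output] for a single hypothetical head
qk = torch.tensor([
    [0.0, 1.0, 0.0, 0.0],
    [1.0, 5.0, 1.0, 2.0],
    [2.0, 2.0, 2.0, 9.0],
    [0.0, 0.0, 1.0, 0.0],
])
ov = torch.tensor([
    [0.0, 0.0, 0.0, 0.0],
    [4.0, 0.0, 0.0, 0.0],
    [1.0, 1.0, 1.0, 1.0],
    [0.0, 0.0, 8.0, 0.0],
])
token_prob = torch.tensor([0.4, 0.3, 0.2, 0.1])

qk_norm = qk - qk[:, BOS:BOS + 1]            # subtract each query's BOS score
ov_norm = ov - ov.mean(dim=1, keepdim=True)  # centre each source row

# Key score: best query preference * best output effect * damped frequency
key_score = qk_norm.max(dim=0).values * ov_norm.max(dim=1).values * token_prob ** 0.1
best_key = key_score.argmax().item()

# For the chosen key: rank queries by damped QK, outputs by normalized OV
best_query = (qk_norm[:, best_key] * token_prob ** 0.1).argmax().item()
best_output = ov_norm[best_key].argmax().item()
print(best_key, best_query, best_output)
```

Key 3 wins here despite being the least probable token: the `**0.1` exponent only mildly rewards frequency, so a strong QK and OV signal dominates.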

In [None]:
import torch
import torch.nn.functional as F
from transformer_lens import HookedTransformer
import tiktoken
import numpy as np
from typing import List, Tuple
from collections import defaultdict

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

In [None]:
# Load models
print("Loading models...")
model_1l = HookedTransformer.from_pretrained("attn-only-1l", device=DEVICE).eval()
model_2l = HookedTransformer.from_pretrained("attn-only-2l", device=DEVICE).eval()

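# Load tokenizer
# NOTE: assumes the models share the GPT-2 vocabulary; if VOCAB_SIZE differs from
# model_1l.cfg.d_vocab, use the model's own tokenizer instead.
enc = tiktoken.get_encoding("gpt2")
VOCAB_SIZE = enc.n_vocab
BOS_TOKEN_ID = enc.eot_token  # 50256: <|endoftext|> doubles as BOS (enc.n_vocab would be out of range)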

print(f"1L Model: {model_1l.cfg.n_layers} layers, {model_1l.cfg.n_heads} heads, d_model={model_1l.cfg.d_model}")
print(f"2L Model: {model_2l.cfg.n_layers} layers, {model_2l.cfg.n_heads} heads, d_model={model_2l.cfg.d_model}")

In [None]:
# Load word frequency data for token probabilities
def load_token_frequencies():
    """Load token frequencies from words.txt to estimate token probabilities."""
    try:
        with open('words.txt', 'r', encoding='utf-8') as f:
            content = f.read()
            token_counts = defaultdict(int)
            total_tokens = 0
            
            for line in content.split('\n'):
                line = line.strip()
                if not line:
                    continue
                ids = enc.encode(line)
                for tid in ids:
                    token_counts[tid] += 1
                    total_tokens += 1
            
            # Convert to probabilities
            token_probs = torch.zeros(VOCAB_SIZE)
            for tid, count in token_counts.items():
                token_probs[tid] = count / total_tokens
            
            # Add small smoothing for unseen tokens
            token_probs = token_probs + 1e-10
            token_probs = token_probs / token_probs.sum()
            
            print(f"Loaded frequencies for {len(token_counts)} unique tokens from {total_tokens} total tokens")
            return token_probs
    except FileNotFoundError:
        print("words.txt not found, using uniform distribution")
        return torch.ones(VOCAB_SIZE) / VOCAB_SIZE

token_probs = load_token_frequencies().to(DEVICE)
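
To see how gently the `token_prob**0.1` factor reweights scores, here is a small illustration. The probabilities are hypothetical round numbers, with `1e-10` standing in for the smoothing floor given to unseen tokens.

```python
import torch

# Hypothetical unigram probabilities: common, typical, rare, and unseen (smoothing floor)
p = torch.tensor([1e-2, 1e-4, 1e-6, 1e-10])

weights = p ** 0.1
raw_ratio = (p[0] / p[2]).item()                  # common vs rare before damping
damped_ratio = (weights[0] / weights[2]).item()   # common vs rare after damping

print(weights)
print(raw_ratio, damped_ratio)
```

A token that is 10,000 times more frequent is upweighted by only a factor of about 2.5, so frequency acts as a tie-breaker rather than the main driver of the ranking.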

In [None]:
@torch.no_grad()
def compute_qk_ov_matrices(model: HookedTransformer, layer: int):
    """
    Compute QK and OV matrices for all heads in a layer.
    
    Returns:
        qk_matrices: [n_heads, d_vocab, d_vocab] - QK effect via embedding
        ov_matrices: [n_heads, d_vocab, d_vocab] - OV effect via unembedding
    """
    n_heads = model.cfg.n_heads
    d_vocab = model.cfg.d_vocab
    
    # Get weight matrices
    W_Q = model.W_Q[layer]  # [n_heads, d_model, d_head]
    W_K = model.W_K[layer]  # [n_heads, d_model, d_head]
    W_V = model.W_V[layer]  # [n_heads, d_model, d_head]
    W_O = model.W_O[layer]  # [n_heads, d_head, d_model]
    W_E = model.W_E         # [d_vocab, d_model]
    W_U = model.W_U         # [d_model, d_vocab]
    
    # Compute QK circuit: W_Q @ W_K^T
    # For each head: [d_model, d_head] @ [d_head, d_model] = [d_model, d_model]
    W_QK = torch.einsum('hqi,hki->hqk', W_Q, W_K)  # [n_heads, d_model, d_model]
    
    # Compute OV circuit: W_V @ W_O
    # For each head: [d_model, d_head] @ [d_head, d_model] = [d_model, d_model]
    W_OV = torch.einsum('hvi,hio->hvo', W_V, W_O)  # [n_heads, d_model, d_model]
    
    # Project through embedding and unembedding
    # QK in token space: [d_vocab, d_model] @ [d_model, d_model] @ [d_model, d_vocab]
    qk_matrices = torch.einsum('vq,hqk,kw->hvw', W_E, W_QK, W_E.T)  # [n_heads, d_vocab, d_vocab]
    
    # OV in token space: [d_vocab, d_model] @ [d_model, d_model] @ [d_model, d_vocab]
    ov_matrices = torch.einsum('vs,hso,ow->hvw', W_E, W_OV, W_U)  # [n_heads, d_vocab, d_vocab]
    
    return qk_matrices, ov_matrices
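
The dense `[n_heads, d_vocab, d_vocab]` tensors above are large: with a vocabulary of about 50k tokens, one head's matrix takes roughly 10 GB in float32. When only one key needs inspecting, a single QK column and OV row can be computed without materializing the full matrix. The helper below is a hypothetical sketch for one head, checked against the dense product on toy sizes; its name and signature are not part of this notebook or of TransformerLens.

```python
import torch

def qk_column_and_ov_row(W_E, W_U, W_Q, W_K, W_V, W_O, key_idx):
    """QK scores of every query against one key, and that key's OV effect
    on every output token, for a single head.

    W_E: [d_vocab, d_model], W_U: [d_model, d_vocab]
    W_Q, W_K, W_V: [d_model, d_head], W_O: [d_head, d_model]
    """
    k = W_E[key_idx] @ W_K      # [d_head] key vector for this token
    qk_col = (W_E @ W_Q) @ k    # [d_vocab] score of each query token
    v = W_E[key_idx] @ W_V      # [d_head] value vector for this token
    ov_row = (v @ W_O) @ W_U    # [d_vocab] effect on each output logit
    return qk_col, ov_row

# Check against the dense computation on toy sizes
torch.manual_seed(0)
d_vocab, d_model, d_head = 12, 8, 4
W_E = torch.randn(d_vocab, d_model)
W_U = torch.randn(d_model, d_vocab)
W_Q, W_K, W_V = (torch.randn(d_model, d_head) for _ in range(3))
W_O = torch.randn(d_head, d_model)

dense_qk = W_E @ W_Q @ W_K.T @ W_E.T  # [query, key]
dense_ov = W_E @ W_V @ W_O @ W_U      # [source, output]
qk_col, ov_row = qk_column_and_ov_row(W_E, W_U, W_Q, W_K, W_V, W_O, key_idx=5)

qk_matches = torch.allclose(qk_col, dense_qk[:, 5], atol=1e-3, rtol=1e-4)
ov_matches = torch.allclose(ov_row, dense_ov[5], atol=1e-3, rtol=1e-4)
print(qk_matches, ov_matches)
```

This only helps when drilling into a known key. The key-selection heuristic still needs a per-key maximum over all queries, which would have to be accumulated in chunks if the dense matrix does not fit in memory.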

In [None]:
@torch.no_grad()
def normalize_qk_ov(qk_matrices, ov_matrices, bos_token_id: int):
    """
    Normalize QK and OV matrices according to the heuristics.
    
    - QK: Subtract the QK value for BOS token for each query
    - OV: Subtract the mean for each source token
    """
    # Normalize QK: subtract BOS baseline for each query
    # qk_matrices is [n_heads, query, key]
    # For each query, subtract qk_matrices[:, query, bos_token_id]
    bos_baseline = qk_matrices[:, :, bos_token_id:bos_token_id+1]  # [n_heads, d_vocab, 1]
    qk_normalized = qk_matrices - bos_baseline
    
    # Normalize OV: subtract mean for each source token
    # ov_matrices is [n_heads, source, output]
    ov_mean = ov_matrices.mean(dim=-1, keepdim=True)  # [n_heads, d_vocab, 1]
    ov_normalized = ov_matrices - ov_mean
    
    return qk_normalized, ov_normalized
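
A quick sanity check of what the two normalizations do, sketched on random toy tensors (the sizes and the `bos` index are placeholders). It also illustrates why the QK shift is harmless: subtracting a constant per query does not change the softmax over keys.

```python
import torch

torch.manual_seed(0)
n_heads, d_vocab, bos = 2, 6, 5  # toy sizes; `bos` is a stand-in index
qk = torch.randn(n_heads, d_vocab, d_vocab)  # [head, query, key]
ov = torch.randn(n_heads, d_vocab, d_vocab)  # [head, source, output]

qk_norm = qk - qk[:, :, bos:bos + 1]
ov_norm = ov - ov.mean(dim=-1, keepdim=True)

# Every query now scores BOS at exactly zero
bos_residual = qk_norm[:, :, bos].abs().max().item()
# Every source row of OV now has (numerically) zero mean
mean_residual = ov_norm.mean(dim=-1).abs().max().item()
# A per-query shift leaves the softmax over keys unchanged
same_softmax = torch.allclose(qk.softmax(dim=-1), qk_norm.softmax(dim=-1), atol=1e-6)

print(bos_residual, mean_residual, same_softmax)
```

The same argument applies to the OV shift: adding a constant to every output logit for a given source leaves the resulting next-token distribution unchanged, so both normalizations only fix a reference point for reading the tables.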

In [None]:
def select_interesting_keys(
    qk_normalized: torch.Tensor,
    ov_normalized: torch.Tensor,
    token_probs: torch.Tensor,
    head_idx: int,
    n_keys: int = 20,
) -> List[int]:
    """
    Select interesting keys using the heuristic:
    QK.max(0) * OV.max(1) * token_prob**0.1
    
    Args:
        qk_normalized: [n_heads, query, key]
        ov_normalized: [n_heads, source, output]
        token_probs: [d_vocab]
        head_idx: Which head to analyze
        n_keys: Number of top keys to return
    """
    # Get max QK value across all queries for each key
    max_qk_per_key = qk_normalized[head_idx].max(dim=0).values  # [d_vocab]
    
    # Get max OV value across all outputs for each source
    max_ov_per_source = ov_normalized[head_idx].max(dim=1).values  # [d_vocab]
    
    # Compute importance score
    importance = max_qk_per_key * max_ov_per_source * (token_probs ** 0.1)
    
    # Get top keys
    top_key_indices = torch.topk(importance, k=min(n_keys, len(importance))).indices
    
    return top_key_indices.tolist()

In [None]:
def get_top_queries_and_outputs(
    qk_normalized: torch.Tensor,
    ov_normalized: torch.Tensor,
    token_probs: torch.Tensor,
    head_idx: int,
    key_idx: int,
    n_queries: int = 20,
    n_outputs: int = 20,
) -> Tuple[List[Tuple[int, float]], List[Tuple[int, float]]]:
    """
    Get top queries and outputs for a given key.
    
    Queries: sorted by QK[:, key] * token_prob**0.1
    Outputs: sorted by OV[key, :]
    
    Returns:
        top_queries: List of (token_id, score)
        top_outputs: List of (token_id, score)
    """
    # Get QK scores for all queries attending to this key
    qk_scores = qk_normalized[head_idx, :, key_idx]  # [d_vocab]
    query_importance = qk_scores * (token_probs ** 0.1)
    
    # Get top queries
    top_query_indices = torch.topk(query_importance, k=min(n_queries, len(query_importance))).indices
    top_queries = [(idx.item(), qk_scores[idx].item()) for idx in top_query_indices]
    
    # Get OV scores for all output effects from this source
    ov_scores = ov_normalized[head_idx, key_idx, :]  # [d_vocab]
    
    # Get top outputs
    top_output_indices = torch.topk(ov_scores, k=min(n_outputs, len(ov_scores))).indices
    top_outputs = [(idx.item(), ov_scores[idx].item()) for idx in top_output_indices]
    
    return top_queries, top_outputs

In [None]:
def format_token(token_id: int) -> str:
    """Format a token for display, handling special characters."""
    token_str = enc.decode([token_id])
    # Escape special characters for display
    token_str = repr(token_str)[1:-1]  # Remove outer quotes from repr
    return f"'{token_str}'"

def print_head_analysis(
    model_name: str,
    layer: int,
    head: int,
    qk_normalized: torch.Tensor,
    ov_normalized: torch.Tensor,
    token_probs: torch.Tensor,
    n_keys: int = 5,
    n_queries: int = 20,
    n_outputs: int = 20,
):
    """
    Print formatted analysis for a single attention head.
    """
    print(f"\n{'='*80}")
    print(f"Head {layer}:{head} ({model_name})")
    print(f"{'='*80}")
    
    # Select interesting keys
    interesting_keys = select_interesting_keys(
        qk_normalized, ov_normalized, token_probs, head, n_keys=n_keys
    )
    
    for key_idx in interesting_keys:
        key_token = format_token(key_idx)
        
        # Get top queries and outputs
        top_queries, top_outputs = get_top_queries_and_outputs(
            qk_normalized, ov_normalized, token_probs,
            head, key_idx, n_queries, n_outputs
        )
        
        print(f"\n{'─'*80}")
        print(f"Key: {key_token}")
        print(f"{'─'*80}")
        
        # Print queries that prefer this key
        print("\nQueries that prefer key:")
        query_strs = [f"{format_token(tid)} ({score:.2f})" for tid, score in top_queries]
        # Print in rows of ~80 chars
        current_line = ""
        for qs in query_strs:
            if len(current_line) + len(qs) + 2 > 78:
                print(current_line)
                current_line = qs
            else:
                if current_line:
                    current_line += " " + qs
                else:
                    current_line = qs
        if current_line:
            print(current_line)
        
        # Print effect on logits
        print("\nEffect on logits:")
        output_strs = [f"{format_token(tid)} ({score:.2f})" for tid, score in top_outputs]
        current_line = ""
        for out_str in output_strs:
            if len(current_line) + len(out_str) + 2 > 78:
                print(current_line)
                current_line = out_str
            else:
                if current_line:
                    current_line += " " + out_str
                else:
                    current_line = out_str
        if current_line:
            print(current_line)

## Analysis: 1-Layer Model (attn-only-1l)

In [None]:
# Compute QK and OV matrices for 1L model
print("Computing QK and OV matrices for 1-layer model...")
qk_1l_L0, ov_1l_L0 = compute_qk_ov_matrices(model_1l, layer=0)

# Normalize
qk_1l_L0_norm, ov_1l_L0_norm = normalize_qk_ov(qk_1l_L0, ov_1l_L0, BOS_TOKEN_ID)

print(f"QK matrix shape: {qk_1l_L0_norm.shape}")
print(f"OV matrix shape: {ov_1l_L0_norm.shape}")
print(f"Number of heads: {model_1l.cfg.n_heads}")

In [None]:
# Analyze each head in layer 0
for head in range(model_1l.cfg.n_heads):
    print_head_analysis(
        "attn-only-1l",
        layer=0,
        head=head,
        qk_normalized=qk_1l_L0_norm,
        ov_normalized=ov_1l_L0_norm,
        token_probs=token_probs,
        n_keys=5,  # Show 5 interesting keys per head
        n_queries=20,
        n_outputs=20,
    )

## Analysis: 2-Layer Model (attn-only-2l)

In [None]:
# Compute QK and OV matrices for 2L model
print("Computing QK and OV matrices for 2-layer model...")
qk_2l_L0, ov_2l_L0 = compute_qk_ov_matrices(model_2l, layer=0)
qk_2l_L1, ov_2l_L1 = compute_qk_ov_matrices(model_2l, layer=1)

# Normalize
qk_2l_L0_norm, ov_2l_L0_norm = normalize_qk_ov(qk_2l_L0, ov_2l_L0, BOS_TOKEN_ID)
qk_2l_L1_norm, ov_2l_L1_norm = normalize_qk_ov(qk_2l_L1, ov_2l_L1, BOS_TOKEN_ID)

print(f"Layer 0 - QK matrix shape: {qk_2l_L0_norm.shape}")
print(f"Layer 0 - OV matrix shape: {ov_2l_L0_norm.shape}")
print(f"Layer 1 - QK matrix shape: {qk_2l_L1_norm.shape}")
print(f"Layer 1 - OV matrix shape: {ov_2l_L1_norm.shape}")
print(f"Number of heads per layer: {model_2l.cfg.n_heads}")

In [None]:
# Analyze each head in layer 0
print("\n" + "#"*80)
print("# LAYER 0")
print("#"*80)
for head in range(model_2l.cfg.n_heads):
    print_head_analysis(
        "attn-only-2l",
        layer=0,
        head=head,
        qk_normalized=qk_2l_L0_norm,
        ov_normalized=ov_2l_L0_norm,
        token_probs=token_probs,
        n_keys=5,
        n_queries=20,
        n_outputs=20,
    )

In [None]:
# Analyze each head in layer 1
print("\n" + "#"*80)
print("# LAYER 1")
print("#"*80)
for head in range(model_2l.cfg.n_heads):
    print_head_analysis(
        "attn-only-2l",
        layer=1,
        head=head,
        qk_normalized=qk_2l_L1_norm,
        ov_normalized=ov_2l_L1_norm,
        token_probs=token_probs,
        n_keys=5,
        n_queries=20,
        n_outputs=20,
    )

## Interpretation Guide

For each attention head, we show:

1. **Key**: The source token being analyzed
2. **Queries that prefer key**: Query tokens that would strongly attend to this key (with normalized QK scores)
3. **Effect on logits**: Output tokens that are promoted when attending to this key (with normalized OV values)

### What to look for:

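- **Induction heads**: Look for heads where queries after token X prefer keys at token X, and the output promotes the token that followed X previously. Induction relies on composition with an earlier layer, so these token-to-token tables give only indirect evidence of it.
- **Previous token heads**: Heads that attend to the previous token (queries prefer keys at position -1). The tables are built from token embeddings only, so position-based behavior must be confirmed from attention patterns on real text.
- **Positional heads**: Heads with consistent attention patterns based on relative positions. As with previous token heads, this does not show up directly in the token-to-token tables.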
- **Syntactic heads**: Heads that attend to specific syntactic patterns (e.g., matching brackets, quotes)
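
Position-based behavior is easiest to check on actual attention patterns. The sketch below scores how much each head attends to the immediately preceding position, using synthetic patterns so it runs without a model. `prev_token_score` is a hypothetical helper; on a real model the pattern could presumably be taken from the activation cache returned by `run_with_cache`.

```python
import torch

def prev_token_score(pattern: torch.Tensor) -> torch.Tensor:
    """Mean attention each head puts on the immediately preceding position.

    pattern: [n_heads, seq, seq] attention probabilities, indexed (query, key).
    """
    # Entries (i, i-1) lie on the first sub-diagonal
    return pattern.diagonal(offset=-1, dim1=-2, dim2=-1).mean(dim=-1)

# Synthetic patterns: head 0 always looks one step back, head 1 attends uniformly
seq = 5
prev_head = torch.zeros(seq, seq)
prev_head[0, 0] = 1.0  # the first position can only attend to itself
prev_head[torch.arange(1, seq), torch.arange(seq - 1)] = 1.0
uniform_head = torch.tril(torch.ones(seq, seq))
uniform_head = uniform_head / uniform_head.sum(dim=-1, keepdim=True)

scores = prev_token_score(torch.stack([prev_head, uniform_head]))
print(scores)
```

A score near 1 indicates a previous token head, while a head that spreads attention evenly over the causal window scores much lower.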

### Scores:

- **QK scores**: Higher = query strongly prefers this key
- **OV scores**: Higher = this source token strongly promotes this output token

The normalization makes these scores comparable across different queries and keys.