# Gemma2-27B Attention Analysis

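This notebook analyzes how instruction tuning modifies attention patterns for principal-component (PC) and semantic vectors, comparing the base model (`google/gemma-2-27b`) with its instruction-tuned counterpart (`google/gemma-2-27b-it`).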

**Key Analyses:**
1. **QK Affinity**: Raw attention logits (before softmax) reveal semantic affinities
2. **VO Decomposition**: What semantic content flows through when attending to a vector
3. **Base vs Instruct Comparison**: How instruction tuning changes routing

**Approach:**
- Compute QK affinity matrices for PC and semantic vectors
- Compute VO decomposition (value-output transformation)
- Convert the raw QK and VO values to z-scores relative to a baseline of random vectors, computed separately for each model and layer
- Identify layers and patterns where instruction tuning has strongest effect

In [2]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoConfig
from collections import OrderedDict

from chatspace.analysis import (
    load_pca_data,
    extract_pc_components,
    load_individual_role_vectors,
    load_individual_trait_vectors,
    normalize_vector,
    compute_qk_affinity_matrix,
    compute_vo_decomposition,
    compute_z_score_matrices,
    get_top_interactions,
    analyze_pc_pattern
)

%matplotlib inline
sns.set_style('whitegrid')

## 1. Load Models and Prepare Vectors

In [3]:
# Load the base and instruction-tuned Gemma2-27B checkpoints on CPU in bfloat16.
base_model_id = "google/gemma-2-27b"
instruct_model_id = "google/gemma-2-27b-it"

print("Loading models...")
# Only the base model's config is loaded; later cells read `hidden_size` and
# `num_hidden_layers` from it.
config = AutoConfig.from_pretrained(base_model_id)

# `dtype` replaces the `torch_dtype` keyword, which the recorded output of this
# cell reports as deprecated ("`torch_dtype` is deprecated! Use `dtype` instead!").
load_kwargs = {"dtype": torch.bfloat16, "device_map": "cpu"}
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, **load_kwargs)
instruct_model = AutoModelForCausalLM.from_pretrained(instruct_model_id, **load_kwargs)
print("✓ Models loaded")

`torch_dtype` is deprecated! Use `dtype` instead!


Loading models...


Loading checkpoint shards:   0%|          | 0/24 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

✓ Models loaded


In [None]:
# Load the role PCA results and build the set of probe vectors used by every
# later cell: +/- each principal component, plus a random baseline.
persona_data_root = Path("/workspace/persona-data")
roles_pca_dir = persona_data_root / "gemma-2-27b" / "roles_240" / "pca"
traits_pca_dir = persona_data_root / "gemma-2-27b" / "traits_240" / "pca"  # not read in this cell

pca_data, _ = load_pca_data(roles_pca_dir)
pca_layer = pca_data['layer']

# Extract ALL PCs we'll use (load 10 for flexibility)
n_pcs_total = 10
pcs_all, variance_all = extract_pc_components(pca_data, n_components=n_pcs_total)

print(f"✓ Loaded PCA data from layer {pca_layer}")
print(f"  Extracted {n_pcs_total} PCs")
print(f"  Variance explained by first 5 PCs: {variance_all[:5]}")

# Build test vectors: ALL PCs + random baseline.
# Insertion order matters: it fixes the row/column order of every matrix
# computed from `vecs_tensor_full`.
test_vecs_full = OrderedDict()

# Add all PCs (positive and negative)
for pc_idx in range(n_pcs_total):
    pc_name = f"PC{pc_idx+1}"
    # Cast once and reuse for both signs. Previously only the negated copy was
    # cast to float32, so "PCk" and "-PCk" could be built from different dtypes
    # (and differ from the float32 random vectors stacked alongside them).
    pc_vec = pcs_all[pc_idx].float()
    test_vecs_full[pc_name] = normalize_vector(pc_vec)
    test_vecs_full[f"-{pc_name}"] = normalize_vector(-pc_vec)

# Add random baseline vectors (fixed seed so the baseline is reproducible)
torch.manual_seed(42)
n_random = 20
for i in range(n_random):
    rand_vec = torch.randn(config.hidden_size, dtype=torch.float32)
    test_vecs_full[f"Random{i+1}"] = normalize_vector(rand_vec)

vec_names_full = list(test_vecs_full.keys())
vecs_tensor_full = torch.stack([test_vecs_full[name] for name in vec_names_full])

print(f"\n✓ Prepared {len(vec_names_full)} vectors:")
print(f"  {n_pcs_total*2} PCs (positive + negative)")
print(f"  {n_random} random baseline vectors")

In [None]:
# ----------------------------------------------------------------------------
# Analysis configuration -- every tunable knob for the cells below lives here.
# ----------------------------------------------------------------------------

# Layers for which QK/VO patterns are computed: every layer the config reports.
analysis_layers = [*range(config.num_hidden_layers)]

# Layers averaged over in the PC-number comparison (17 through 27 inclusive).
comparison_layers = [*range(17, 28)]

# PCs drawn in the layer-wise plots. To focus on PC1 only, use ["PC1"] instead.
plot_pcs = ["PC1", "PC2", "PC3"]

# Number of leading PCs included in the PC-number comparison.
n_pcs_compare = 10

# Echo the configuration so the recorded output documents the run.
print("\n".join([
    "📋 Analysis Configuration:",
    f"  Computing attention for {len(analysis_layers)} layers (all)",
    f"  Averaging over layers {comparison_layers[0]}-{comparison_layers[-1]} for PC comparison",
    f"  Visualizing {len(plot_pcs)} PCs in layer-wise plots: {plot_pcs}",
    f"  Comparing {n_pcs_compare} PCs in PC number analysis",
]))

## 2. Compute QK and VO Patterns for All Layers

Compute attention patterns once for all PCs and all layers.

In [None]:
# Fill one (layer, vector, vector) array per model and per quantity (QK, VO).
print(f"Computing QK and VO for {len(analysis_layers)} layers...")

n_vecs_full = len(vec_names_full)
stack_shape = (len(analysis_layers), n_vecs_full, n_vecs_full)
qk_base_full, qk_inst_full, vo_base_full, vo_inst_full = (
    np.zeros(stack_shape) for _ in range(4)
)

# Destination arrays paired with the model that produces them.
qk_targets = ((base_model, qk_base_full), (instruct_model, qk_inst_full))
vo_targets = ((base_model, vo_base_full), (instruct_model, vo_inst_full))

with torch.inference_mode():
    for row, layer_idx in enumerate(tqdm(analysis_layers, desc="Computing attention patterns")):
        # QK affinity: all probe vectors against each other
        for model, dest in qk_targets:
            dest[row] = compute_qk_affinity_matrix(vecs_tensor_full, layer_idx, model).cpu().numpy()

        # VO decomposition: the same vector set is passed for both arguments
        for model, dest in vo_targets:
            dest[row] = compute_vo_decomposition(
                vecs_tensor_full, vecs_tensor_full, layer_idx, model
            ).cpu().numpy()

print(f"\n✓ Computed QK and VO for {len(analysis_layers)} layers")
print(f"  Arrays shape: {qk_base_full.shape}")

In [None]:
# Turn raw QK/VO values into z-scores against a per-layer random baseline.


def _self_baseline(stack, indices):
    """Return (samples, mean, std) of the diagonal entries for `indices`.

    `samples` has one row per index and one column per layer; mean and std
    are taken across the indices, giving one value per layer.
    """
    samples = np.array([stack[:, j, j] for j in indices])
    return samples, samples.mean(axis=0), samples.std(axis=0)


def _zscore(values, mean, std):
    """Standardise `values`; the 1e-8 keeps the division finite when std is 0."""
    return (values - mean) / (std + 1e-8)


# Positions of the random baseline vectors within the probe set.
random_indices_full = [pos for pos, label in enumerate(vec_names_full) if "Random" in label]

print("Computing random baselines for base and instruct models...")

# The baseline is each random vector's interaction with itself.
# NOTE(review): the same self-interaction baseline is also used to normalise
# the "opposite" (PC -> -PC) entries below -- confirm that is intended.
qk_random_base, qk_base_mean, qk_base_std = _self_baseline(qk_base_full, random_indices_full)
vo_random_base, vo_base_mean, vo_base_std = _self_baseline(vo_base_full, random_indices_full)
qk_random_inst, qk_inst_mean, qk_inst_std = _self_baseline(qk_inst_full, random_indices_full)
vo_random_inst, vo_inst_mean, vo_inst_std = _self_baseline(vo_inst_full, random_indices_full)

# One dict per (quantity, model): PC name -> {"self": ..., "opposite": ...},
# each value being an array with one z-score per analysed layer.
qk_zscores_base, qk_zscores_inst, vo_zscores_base, vo_zscores_inst = {}, {}, {}, {}

zscore_sources = (
    (qk_zscores_base, qk_base_full, qk_base_mean, qk_base_std),
    (qk_zscores_inst, qk_inst_full, qk_inst_mean, qk_inst_std),
    (vo_zscores_base, vo_base_full, vo_base_mean, vo_base_std),
    (vo_zscores_inst, vo_inst_full, vo_inst_mean, vo_inst_std),
)

for pc_number in range(1, n_pcs_total + 1):
    pc_name = f"PC{pc_number}"
    pc_neg_name = f"-{pc_name}"

    pc_pos_idx = vec_names_full.index(pc_name)
    pc_neg_idx = vec_names_full.index(pc_neg_name)

    for table, stack, mean, std in zscore_sources:
        table[pc_name] = {
            "self": _zscore(stack[:, pc_pos_idx, pc_pos_idx], mean, std),
            "opposite": _zscore(stack[:, pc_pos_idx, pc_neg_idx], mean, std),
        }

print(f"✓ Computed z-scores for {n_pcs_total} PCs")
print(f"  Each model normalized by {len(random_indices_full)} random vectors")

## 3. Layer-wise PC Attention Patterns

Visualize how PCs attend to themselves and their opposites across layers.

In [None]:
# Instruction-tuning effect per plotted PC: instruct z-score minus base z-score.


def _zscore_delta(inst_entry, base_entry):
    """Subtract base from instruct for both the "self" and "opposite" series."""
    return {kind: inst_entry[kind] - base_entry[kind] for kind in ("self", "opposite")}


qk_delta_zscores = {
    name: _zscore_delta(qk_zscores_inst[name], qk_zscores_base[name]) for name in plot_pcs
}
vo_delta_zscores = {
    name: _zscore_delta(vo_zscores_inst[name], vo_zscores_base[name]) for name in plot_pcs
}

print(f"✓ Computed deltas for {len(plot_pcs)} PCs: {plot_pcs}")

In [None]:
# Base vs. instruct z-score curves: one column per PC, QK on top, VO below.
n_cols = len(plot_pcs)
fig, axes = plt.subplots(2, n_cols, figsize=(20, 10))

# subplots() hands back a 1-D array when there is a single column; view it as
# a (2, n_cols) grid so every panel is addressed the same way.
axes_grid = np.reshape(axes, (2, n_cols))

# Per row: (base z-scores, instruct z-scores, self label, opposite label, title prefix)
row_specs = (
    (qk_zscores_base, qk_zscores_inst, "{who}: {pc}→{pc}", "{who}: {pc}→-{pc}", "QK Affinity"),
    (vo_zscores_base, vo_zscores_inst, "{who}: {pc} self-bias", "{who}: {pc} opp-bias", "VO Decomposition"),
)

for row, (base_z, inst_z, self_tpl, opp_tpl, title_prefix) in enumerate(row_specs):
    for col, pc in enumerate(plot_pcs):
        ax = axes_grid[row, col]

        # Base in C0, instruct in C1; solid circles = self, dashed squares = opposite.
        for who, zscores, colour in (("Base", base_z, "C0"), ("Instruct", inst_z, "C1")):
            ax.plot(analysis_layers, zscores[pc]["self"], "o-",
                    label=self_tpl.format(who=who, pc=pc), linewidth=2, alpha=0.7, color=colour)
            ax.plot(analysis_layers, zscores[pc]["opposite"], "s--",
                    label=opp_tpl.format(who=who, pc=pc), linewidth=1.5, alpha=0.7, color=colour)

        # Reference lines: z = 0 and the layer the PCA data was loaded from.
        ax.axhline(0, color="gray", linestyle=":", alpha=0.5)
        ax.axvline(pca_layer, color="red", linestyle=":", alpha=0.3, label=f"PCA layer ({pca_layer})")
        ax.set_xlabel("Layer", fontsize=11)
        ax.set_ylabel("Z-score (vs random)", fontsize=11)
        ax.set_title(f"{title_prefix}: {pc} (Base vs Instruct)", fontsize=12, fontweight="bold")
        ax.legend(fontsize=9, loc="best")
        ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Instruct-minus-base z-score deltas: one column per PC, QK on top, VO below.
n_cols = len(plot_pcs)
fig, axes = plt.subplots(2, n_cols, figsize=(20, 10))

# Normalise the axes array to a (2, n_cols) grid (it is 1-D for a single column).
axes_grid = np.reshape(axes, (2, n_cols))

# Per row: (delta table, self label, opposite label, title suffix)
row_specs = (
    (qk_delta_zscores, "{pc}→{pc} (self)", "{pc}→-{pc} (opposite)", "QK"),
    (vo_delta_zscores, "{pc} self-bias", "{pc} opposite-bias", "VO"),
)

for row, (deltas, self_tpl, opp_tpl, title_suffix) in enumerate(row_specs):
    for col, pc in enumerate(plot_pcs):
        ax = axes_grid[row, col]
        self_delta = deltas[pc]["self"]
        opp_delta = deltas[pc]["opposite"]

        ax.plot(analysis_layers, self_delta, "o-",
                label=self_tpl.format(pc=pc), linewidth=2.5, color="C2")
        ax.plot(analysis_layers, opp_delta, "s-",
                label=opp_tpl.format(pc=pc), linewidth=2.5, color="C3")

        # Zero line plus shading between each curve and zero.
        ax.axhline(0, color="black", linestyle="-", alpha=0.5, linewidth=1)
        ax.fill_between(analysis_layers, 0, self_delta, alpha=0.2, color="C2")
        ax.fill_between(analysis_layers, 0, opp_delta, alpha=0.2, color="C3")

        ax.set_xlabel("Layer", fontsize=11)
        ax.set_ylabel("Δ Z-score (Instruct - Base)", fontsize=11)
        ax.set_title(f"Instruction Tuning Effect: {pc} {title_suffix}", fontsize=12, fontweight="bold")
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. PC Number Comparison

Compare instruction tuning effects across PC numbers to see whether PC1, the component explaining the most variance, is also the one most affected by instruction tuning.

In [None]:
# PC-number comparison: mean |Δ z-score| per PC over the comparison layers,
# followed by a two-panel plot and a printed summary table.
print(f"Analyzing instruction tuning effects for PC1-PC{n_pcs_compare}...")
print(f"  Averaging over layers {comparison_layers[0]}-{comparison_layers[-1]}")

# Positions within `analysis_layers` that belong to `comparison_layers`.
layer_mask = np.isin(analysis_layers, comparison_layers)
comparison_layer_indices = np.flatnonzero(layer_mask)

print(f"  Using layer indices: {comparison_layer_indices}")


def _mean_abs_z_shift(base_stack, inst_stack, base_mean, base_std,
                      inst_mean, inst_std, query_idx, key_idx):
    """Mean absolute (instruct - base) z-score of one matrix entry.

    Each model's entry is standardised with that model's own random-baseline
    mean/std, restricted to the comparison layers, before differencing.
    """
    rows = comparison_layer_indices
    z_base = (base_stack[rows, query_idx, key_idx] - base_mean[rows]) / (base_std[rows] + 1e-8)
    z_inst = (inst_stack[rows, query_idx, key_idx] - inst_mean[rows]) / (inst_std[rows] + 1e-8)
    return np.mean(np.abs(z_inst - z_base))


qk_inputs = (qk_base_full, qk_inst_full, qk_base_mean, qk_base_std, qk_inst_mean, qk_inst_std)
vo_inputs = (vo_base_full, vo_inst_full, vo_base_mean, vo_base_std, vo_inst_mean, vo_inst_std)

# One record per PC
pc_effects = []
for pc_number in range(1, n_pcs_compare + 1):
    pc_name = f"PC{pc_number}"
    pc_neg_name = f"-{pc_name}"

    pc_pos_idx = vec_names_full.index(pc_name)
    pc_neg_idx = vec_names_full.index(pc_neg_name)

    pc_effects.append({
        'pc_num': pc_number,
        'variance_explained': float(variance_all[pc_number - 1]),
        'qk_self': _mean_abs_z_shift(*qk_inputs, pc_pos_idx, pc_pos_idx),
        'qk_opposite': _mean_abs_z_shift(*qk_inputs, pc_pos_idx, pc_neg_idx),
        'vo_self': _mean_abs_z_shift(*vo_inputs, pc_pos_idx, pc_pos_idx),
        'vo_opposite': _mean_abs_z_shift(*vo_inputs, pc_pos_idx, pc_neg_idx),
    })

print(f"✓ Computed mean effects across {len(comparison_layer_indices)} layers")

# Column-wise views of the records for plotting
pc_nums, qk_self, qk_opp, vo_self, vo_opp = (
    [record[field] for record in pc_effects]
    for field in ('pc_num', 'qk_self', 'qk_opposite', 'vo_self', 'vo_opposite')
)

# Visualize: QK on the left, VO on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
y_label = f'Mean |Δ Z-score| (layers {comparison_layers[0]}-{comparison_layers[-1]})'

panels = (
    (axes[0], qk_self, qk_opp, 'Self-attention', 'Opposite-attention',
     'QK Affinity: Instruction Tuning Effect by PC'),
    (axes[1], vo_self, vo_opp, 'Self-bias', 'Opposite-bias',
     'VO Decomposition: Instruction Tuning Effect by PC'),
)
for ax, self_vals, opp_vals, self_label, opp_label, title in panels:
    ax.plot(pc_nums, self_vals, 'o-', label=self_label, linewidth=2.5, markersize=8, color='C2')
    ax.plot(pc_nums, opp_vals, 's-', label=opp_label, linewidth=2.5, markersize=8, color='C3')
    ax.set_xlabel('PC Number', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(pc_nums)

plt.tight_layout()
plt.show()

# Summary table
print(f"\n📊 Instruction Tuning Effects by PC (mean |Δ Z-score| over layers {comparison_layers[0]}-{comparison_layers[-1]}):")
print(f"\n{'PC':<5} {'Var%':<8} {'QK Self':<10} {'QK Opp':<10} {'VO Self':<10} {'VO Opp':<10}")
print("=" * 65)
for record in pc_effects:
    print(f"PC{record['pc_num']:<3} {record['variance_explained']*100:>6.2f}%  "
          f"{record['qk_self']:>8.3f}σ  {record['qk_opposite']:>8.3f}σ  "
          f"{record['vo_self']:>8.3f}σ  {record['vo_opposite']:>8.3f}σ")

# Headline: which PC has the largest self effect, and is it PC1 for both QK and VO?
print("\n🎯 Key Finding:")
max_qk = max(pc_effects, key=lambda record: record['qk_self'])
max_vo = max(pc_effects, key=lambda record: record['vo_self'])
print(f"  Strongest QK self-attention effect: PC{max_qk['pc_num']} ({max_qk['qk_self']:.3f}σ)")
print(f"  Strongest VO self-bias effect: PC{max_vo['pc_num']} ({max_vo['vo_self']:.3f}σ)")
pc1_leads_both = max_qk['pc_num'] == 1 and max_vo['pc_num'] == 1
if pc1_leads_both:
    print("  ✓ PC1 (dominant variance) also shows strongest instruction tuning effects!")
else:
    print("  ✗ Instruction tuning effects do NOT align with variance ranking")

## 5. Summary

This analysis reveals how instruction tuning modifies attention routing in the persona subspace:

- **Layer-wise patterns**: Shows how PC attention changes across all layers
- **Base vs Instruct**: Compares attention patterns between base and instruction-tuned models
- **PC sensitivity**: Determines whether PC1 (dominant variance component) is also most affected by instruction tuning
- **Configurable**: Easy to adjust layers and PCs via config cell

**Key Insights:**
- QK affinity reveals which semantic directions attend to each other
- VO decomposition shows what semantic content flows through attention
- Instruction tuning modifies these patterns, potentially aligning them with task objectives