# Gemma2-27B MLP Interpretation

This notebook analyzes how principal-component (PC) vectors are transformed by the full MLP forward pass, including the nonlinear activations.

**Key Findings from the Original Analysis:**
- **Layer 18 is the "smoking gun"** - shows strongest PC1 transformation effects
- Full MLP pass reveals semantic decomposition that linear analysis misses

**Approach:**
1. Run PC vectors through complete MLP (pre_ln → gate/up → down → post_ln)
2. Compare base vs instruct outputs
3. Decompose outputs onto role/trait semantic vectors
4. Focus on layer 18 for detailed interpretation

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoConfig

from chatspace.analysis import (
    load_pca_data,
    extract_pc_components,
    load_individual_role_vectors,
    load_individual_trait_vectors,
    normalize_vector,
    full_mlp_forward_batch,
    compute_z_scores
)

%matplotlib inline
sns.set_style('whitegrid')

## 1. Load Models and PC Vectors

In [None]:
# Checkpoints under comparison: the pretrained base model and its
# instruction-tuned counterpart. (Original note: these may be reused from the
# basic_weight_susceptibility notebook instead of being reloaded.)
base_model_id = "google/gemma-2-27b"
instruct_model_id = "google/gemma-2-27b-it"

# Loading options shared by both checkpoints.
model_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)

print("Loading models...")
config = AutoConfig.from_pretrained(base_model_id)
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, **model_load_kwargs)
instruct_model = AutoModelForCausalLM.from_pretrained(instruct_model_id, **model_load_kwargs)

# Name -> tensor lookups with every entry on CPU; later cells index these
# dicts directly by weight name.
# NOTE(review): both full models stay referenced alongside these CPU copies,
# so peak memory is high — confirm the models are still needed after this cell.
base_state_dict = {name: tensor.cpu() for name, tensor in base_model.state_dict().items()}
instruct_state_dict = {name: tensor.cpu() for name, tensor in instruct_model.state_dict().items()}
print("✓ Models loaded")

In [None]:
# Locate the persona PCA outputs and the per-role / per-trait vectors on disk.
persona_data_root = Path("/workspace/persona-data")
model_data_dir = persona_data_root / "gemma-2-27b"
roles_pca_dir = model_data_dir / "roles_240" / "pca"
traits_pca_dir = model_data_dir / "traits_240" / "pca"
roles_vectors_dir = model_data_dir / "roles_240" / "vectors"
traits_vectors_dir = model_data_dir / "traits_240" / "vectors"

# The roles PCA result records which layer it was computed at.
pca_data, _ = load_pca_data(roles_pca_dir)
pca_layer = pca_data['layer']

# Pull out more PCs than the plots need (10) so later cells can pick any of them.
n_pcs_total = 10
pcs_all, variance_all = extract_pc_components(pca_data, n_components=n_pcs_total)

print(
    f"✓ Loaded PCA data from layer {pca_layer}",
    f"  Extracted {n_pcs_total} PCs",
    f"  Variance explained by first 5: {variance_all[:5]}",
    sep="\n",
)

# Semantic reference vectors, loaded at the same layer the PCA was computed at.
role_vectors = load_individual_role_vectors(roles_vectors_dir, pca_layer)
trait_vectors = load_individual_trait_vectors(traits_vectors_dir, pca_layer)

# Merge into one lookup; on a name collision the trait entry replaces the role entry.
all_semantic = dict(role_vectors)
all_semantic.update(trait_vectors)

print(
    f"\n✓ Loaded semantic vectors:",
    f"  {len(role_vectors)} role difference vectors",
    f"  {len(trait_vectors)} trait contrast vectors",
    f"  {len(all_semantic)} total semantic vectors",
    sep="\n",
)

In [None]:
# ----------------------------------------------------------------------------
# ANALYSIS SETTINGS: edit the values in this cell to reconfigure the notebook
# ----------------------------------------------------------------------------

# Layers swept by the MLP analysis: 15 through 24 inclusive (the middle layers,
# where effects are strongest).
analysis_layers = [*range(15, 25)]

# Layer examined in detail by the semantic decomposition; the original
# analysis singled out layer 18 as the "smoking gun".
focus_layer = 18

# Input directions to run and plot. A leading "-" selects the negated PC.
# Use ["PC1"] instead to restrict the analysis to PC1 only.
plot_pcs = ["PC1", "-PC1"]

# How many entries of the ranked projection list to print.
n_top_projections = 15

print(
    f"📋 Analysis Configuration:",
    f"  Analyzing layers: {analysis_layers[0]}-{analysis_layers[-1]}",
    f"  Focus layer for semantic decomposition: {focus_layer}",
    f"  Visualizing PCs: {plot_pcs}",
    f"  Top projections to show: {n_top_projections}",
    sep="\n",
)

## 2. Full MLP Forward Pass Analysis

In [None]:
# Sweep the configured layers: send each configured PC direction through the
# MLP block of the base and the instruct model, then record the output norms
# and the norm of the instruct-minus-base difference.
mlp_results = []

print(f"Running MLP forward pass for layers {analysis_layers[0]}-{analysis_layers[-1]}...")

# Build the input batch from the configured PC names ("PC<k>" or "-PC<k>").
input_vectors = []
vector_names = []

for pc_name in plot_pcs:
    negated = pc_name.startswith('-')
    # Names are 1-based ("PC1" is the first component); pcs_all is 0-based.
    pc_idx = int(pc_name[3:] if negated else pc_name[2:]) - 1
    if not 0 <= pc_idx < len(pcs_all):
        # Without this check "PC0" would index -1 and silently pick the last PC.
        raise ValueError(
            f"{pc_name!r} is out of range: only PC1..PC{len(pcs_all)} are loaded"
        )

    # Cast on BOTH branches so the stacked batch has a single dtype (float32)
    # regardless of which mix of positive/negative PCs is configured.
    vec = pcs_all[pc_idx].float()
    input_vectors.append(-vec if negated else vec)
    vector_names.append(pc_name)

vectors_batch = torch.stack(input_vectors)

# State-dict key suffixes for one layer's MLP block, in the positional order
# passed to full_mlp_forward_batch below: gate, up, down, pre-LN, post-LN.
MLP_WEIGHT_KEYS = (
    "mlp.gate_proj.weight",
    "mlp.up_proj.weight",
    "mlp.down_proj.weight",
    "pre_feedforward_layernorm.weight",
    "post_feedforward_layernorm.weight",
)

for layer_num in tqdm(analysis_layers):
    layer_prefix = f"model.layers.{layer_num}."

    # Same five weights from each checkpoint.
    gate_base, up_base, down_base, pre_ln_base, post_ln_base = (
        base_state_dict[layer_prefix + key] for key in MLP_WEIGHT_KEYS
    )
    gate_inst, up_inst, down_inst, pre_ln_inst, post_ln_inst = (
        instruct_state_dict[layer_prefix + key] for key in MLP_WEIGHT_KEYS
    )

    # Forward pass through base and instruct
    with torch.inference_mode():
        output_base = full_mlp_forward_batch(vectors_batch, gate_base, up_base, down_base, pre_ln_base, post_ln_base)
        output_inst = full_mlp_forward_batch(vectors_batch, gate_inst, up_inst, down_inst, pre_ln_inst, post_ln_inst)

        # Upcast BEFORE subtracting so the (typically small) instruct-base
        # difference is not rounded in a lower-precision output dtype.
        base_out_f32 = output_base.float()
        inst_out_f32 = output_inst.float()
        delta_f32 = inst_out_f32 - base_out_f32

    # One result row per (layer, input vector).
    for i, vec_name in enumerate(vector_names):
        mlp_results.append({
            'layer': layer_num,
            'input_vector': vec_name,
            'base_output_norm': base_out_f32[i].norm().item(),
            'instruct_output_norm': inst_out_f32[i].norm().item(),
            'delta_norm': delta_f32[i].norm().item(),
        })

mlp_df = pd.DataFrame(mlp_results)
print(f"\n✓ MLP analysis complete: {len(mlp_df)} results")

In [None]:
# Two side-by-side panels: instruct-vs-base divergence per layer (left) and
# raw base/instruct output norms for the first configured PC (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_delta, ax_norms = axes
focus_label = f'Focus layer ({focus_layer})'

# Left panel: one delta-norm curve per configured input vector.
for vec_name in plot_pcs:
    rows = mlp_df.loc[mlp_df['input_vector'] == vec_name]
    ax_delta.plot(rows['layer'], rows['delta_norm'], marker='o', label=vec_name, linewidth=2)

ax_delta.axvline(focus_layer, color='red', linestyle=':', alpha=0.5, label=focus_label)
ax_delta.set(
    title='MLP Output Delta (Instruct - Base) by Layer',
    xlabel='Layer',
    ylabel='Delta Norm',
)
ax_delta.legend()
ax_delta.grid(True, alpha=0.3)

# Right panel: base vs instruct norms, restricted to the first PC for clarity.
first_pc = plot_pcs[0]
pc_data = mlp_df.loc[mlp_df['input_vector'] == first_pc]
for column, marker, label in (
    ('base_output_norm', 'o', 'Base'),
    ('instruct_output_norm', 's', 'Instruct'),
):
    ax_norms.plot(pc_data['layer'], pc_data[column], marker=marker, label=label, linewidth=2)

ax_norms.axvline(focus_layer, color='red', linestyle=':', alpha=0.5, label=focus_label)
ax_norms.set(
    title=f'MLP Output Norms for {first_pc}',
    xlabel='Layer',
    ylabel='Output Norm',
)
ax_norms.legend()
ax_norms.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Layer with the largest instruct-vs-base delta for the first configured PC.
peak_idx = pc_data['delta_norm'].idxmax()
peak_layer = pc_data.loc[peak_idx, 'layer']
print(f"\n🎯 Peak delta at layer {peak_layer} (focus layer: {focus_layer})")

## 3. Semantic Decomposition at Focus Layer

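Project the instruct-minus-base difference in the focus-layer MLP output onto the role/trait semantic vectors to see what instruction tuning adds to the PC transformation.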

In [None]:
# Focus-layer analysis for the first configured PC: run it through the base
# and instruct MLP blocks, take the instruct-minus-base output difference, and
# rank the semantic vectors by their dot product with that difference.
first_pc_name = plot_pcs[0]
negated = first_pc_name.startswith('-')
# Names are 1-based ("PC1" is the first component); pcs_all is 0-based.
pc_idx = int(first_pc_name[3:] if negated else first_pc_name[2:]) - 1
if not 0 <= pc_idx < len(pcs_all):
    # Without this check "PC0" would index -1 and silently pick the last PC.
    raise ValueError(
        f"{first_pc_name!r} is out of range: only PC1..PC{len(pcs_all)} are loaded"
    )

# Cast on BOTH branches so "PC<k>" and "-PC<k>" enter the MLP with the same
# dtype (float32); previously only the negated branch was cast.
pc_vec = pcs_all[pc_idx].float()
if negated:
    pc_vec = -pc_vec

print(f"Analyzing {first_pc_name} transformation at layer {focus_layer}...")

# State-dict key suffixes for one layer's MLP block, in the positional order
# passed to full_mlp_forward_batch below: gate, up, down, pre-LN, post-LN.
MLP_WEIGHT_KEYS = (
    "mlp.gate_proj.weight",
    "mlp.up_proj.weight",
    "mlp.down_proj.weight",
    "pre_feedforward_layernorm.weight",
    "post_feedforward_layernorm.weight",
)
layer_prefix = f"model.layers.{focus_layer}."

# Same five focus-layer weights from each checkpoint.
gate_base, up_base, down_base, pre_ln_base, post_ln_base = (
    base_state_dict[layer_prefix + key] for key in MLP_WEIGHT_KEYS
)
gate_inst, up_inst, down_inst, pre_ln_inst, post_ln_inst = (
    instruct_state_dict[layer_prefix + key] for key in MLP_WEIGHT_KEYS
)

with torch.inference_mode():
    pc_batch = pc_vec.unsqueeze(0)  # batch of one vector
    output_base = full_mlp_forward_batch(pc_batch, gate_base, up_base, down_base, pre_ln_base, post_ln_base)[0]
    output_inst = full_mlp_forward_batch(pc_batch, gate_inst, up_inst, down_inst, pre_ln_inst, post_ln_inst)[0]

    # What instruction tuning ADDS. Upcast before subtracting so the difference
    # is not rounded in a lower-precision output dtype.
    difference = output_inst.float() - output_base.float()

print(f"\nLayer {focus_layer} {first_pc_name} MLP transformation:")
print(f"  Base output norm: {output_base.float().norm().item():.4f}")
print(f"  Instruct output norm: {output_inst.float().norm().item():.4f}")
print(f"  Difference norm: {difference.norm().item():.4f}")

# Raw dot product of the difference with each semantic vector.
# NOTE(review): the semantic vectors are not normalized here, so if the loaders
# do not return unit vectors the ranking also reflects each vector's norm —
# confirm against load_individual_role_vectors / load_individual_trait_vectors.
semantic_names = list(all_semantic.keys())
projections = {
    name: (difference @ all_semantic[name].float()).item()
    for name in semantic_names
}

# Rank by magnitude (sign ignored), strongest first.
sorted_projs = sorted(projections.items(), key=lambda item: abs(item[1]), reverse=True)

print(f"\n🔍 Top {n_top_projections} semantic projections (what instruction tuning adds):")
print("="*80)
for name, proj in sorted_projs[:n_top_projections]:
    print(f"{proj:+.4f}  {name}")

# Because the ranking is by |projection|, the tail of the list holds the
# WEAKEST (closest to zero) projections, not the most negative ones; strongly
# negative projections appear in the "Top" list above with a minus sign.
print(f"\n🔍 Bottom {n_top_projections} semantic projections:")
print("="*80)
for name, proj in sorted_projs[-n_top_projections:]:
    print(f"{proj:+.4f}  {name}")

## 4. Summary

This analysis reveals how instruction tuning modifies MLP transformations of PC vectors:

**Key Insights:**
- **Full MLP pass** reveals nonlinear transformation effects missed by linear analysis
- **Layer-specific effects** show where instruction tuning has strongest impact
- **Semantic decomposition** reveals which semantic directions are amplified/suppressed
- **Focus layer** (originally layer 18) shows particularly strong semantic effects

The projections onto role/trait vectors reveal the semantic meaning of instruction tuning's modifications.