# Gemma2-27B MLP Interpretation

This notebook analyzes how principal-component (PC) vectors are transformed by the full MLP forward pass, including its nonlinear activations.

**Key Finding from Original Analysis:**
- **Layer 18 is the "smoking gun"** - shows strongest PC1 transformation effects
- Full MLP pass reveals semantic decomposition that linear analysis misses

**Approach:**
1. Run PC vectors through complete MLP (pre_ln → gate/up → down → post_ln)
2. Compare base vs instruct outputs
3. Decompose outputs onto role/trait semantic vectors
4. Focus on layer 18 for detailed interpretation

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoConfig

from chatspace.analysis import (
    load_pca_data,
    extract_pc_components,
    load_individual_role_vectors,
    load_individual_trait_vectors,
    normalize_vector,
    full_mlp_forward_batch,
    compute_z_scores
)

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Use seaborn's 'whitegrid' style for the plots in this notebook.
sns.set_style('whitegrid')

## 1. Load Models and PC Vectors

In [None]:
# Load the base and instruction-tuned checkpoints (reuse from
# basic_weight_susceptibility or reload), then snapshot their weights on CPU.
base_model_id = "google/gemma-2-27b"
instruct_model_id = "google/gemma-2-27b-it"

# Keyword arguments shared by both checkpoint loads.
_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)

print("Loading models...")
config = AutoConfig.from_pretrained(base_model_id)
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, **_load_kwargs)
instruct_model = AutoModelForCausalLM.from_pretrained(instruct_model_id, **_load_kwargs)

# CPU copies of every parameter, keyed by parameter name; the later cells index
# these dicts directly instead of going through the model objects.
# NOTE(review): both model objects stay alive after this cell, so the weights
# are held twice (model + CPU copy) — consider freeing the models if nothing
# else needs them.
base_state_dict = {
    param_name: tensor.cpu()
    for param_name, tensor in base_model.state_dict().items()
}
instruct_state_dict = {
    param_name: tensor.cpu()
    for param_name, tensor in instruct_model.state_dict().items()
}
print("✓ Models loaded")

In [None]:
# Load PCA data
persona_data_root = Path("/workspace/persona-data")
_gemma_data_dir = persona_data_root / "gemma-2-27b"
roles_pca_dir = _gemma_data_dir / "roles_240" / "pca"
traits_pca_dir = _gemma_data_dir / "traits_240" / "pca"

# Only the roles PCA is loaded here; the second return value is discarded.
pca_data, _ = load_pca_data(roles_pca_dir)

# Keep the first three components and use the leading one (PC1) as the probe.
pcs, variance_explained = extract_pc_components(pca_data, n_components=3)
pc1 = pcs[0]

print(
    f"✓ Loaded PC1 from layer {pca_data['layer']}",
    f"  Variance explained: {variance_explained[0]:.4f}",
    sep="\n",
)

## 2. Full MLP Forward Pass Analysis

In [None]:
# Run MLP analysis across layers
target_layers = range(15, 25)  # Focus on middle layers (15..24 inclusive)
mlp_results = []

# State-dict key suffixes, listed in the positional order expected by the
# full_mlp_forward_batch calls below: gate, up, down, pre-FF norm, post-FF norm.
_MLP_PARAM_SUFFIXES = (
    "mlp.gate_proj.weight",
    "mlp.up_proj.weight",
    "mlp.down_proj.weight",
    "pre_feedforward_layernorm.weight",
    "post_feedforward_layernorm.weight",
)


def _mlp_params(state_dict, layer_num):
    """Look up one layer's MLP and feed-forward layernorm weights.

    Args:
        state_dict: Mapping from parameter name to tensor
            (here ``base_state_dict`` or ``instruct_state_dict``).
        layer_num: Index of the decoder layer to read.

    Returns:
        Tuple ``(gate, up, down, pre_ln, post_ln)`` of weight tensors.

    Raises:
        KeyError: If any expected parameter name is missing from ``state_dict``.
    """
    return tuple(
        state_dict[f"model.layers.{layer_num}.{suffix}"]
        for suffix in _MLP_PARAM_SUFFIXES
    )


# The probe inputs do not depend on the layer, so build them once.
vectors = torch.stack([pc1, -pc1])  # PC1 and -PC1
vector_names = ['PC1', '-PC1']

print(f"Running MLP forward pass for layers {list(target_layers)}...")

for layer_num in tqdm(target_layers):
    # Forward pass through base and instruct
    with torch.inference_mode():
        output_base = full_mlp_forward_batch(vectors, *_mlp_params(base_state_dict, layer_num))
        output_inst = full_mlp_forward_batch(vectors, *_mlp_params(instruct_state_dict, layer_num))

    # Upcast BEFORE subtracting. The checkpoints are loaded in bfloat16, so if
    # the MLP outputs are also bfloat16 (TODO confirm against
    # full_mlp_forward_batch), subtracting two nearly equal low-precision
    # tensors would round away much of the instruct-vs-base difference.
    base_f32 = output_base.float()
    inst_f32 = output_inst.float()

    # Compute norms and delta
    for i, vec_name in enumerate(vector_names):
        mlp_results.append({
            'layer': layer_num,
            'input_vector': vec_name,
            'base_output_norm': base_f32[i].norm().item(),
            'instruct_output_norm': inst_f32[i].norm().item(),
            'delta_norm': (inst_f32[i] - base_f32[i]).norm().item()
        })

mlp_df = pd.DataFrame(mlp_results)
print(f"\n✓ MLP analysis complete: {len(mlp_df)} results")

In [None]:
# Visualize results: left panel = instruct-vs-base delta for both probe
# directions, right panel = raw output norms for PC1 only.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
delta_ax, norm_ax = axes

# Delta norm by layer
for vec_name in ['PC1', '-PC1']:
    subset = mlp_df.loc[mlp_df['input_vector'] == vec_name]
    delta_ax.plot(subset['layer'], subset['delta_norm'], marker='o', label=vec_name)
delta_ax.set_title('MLP Output Delta (Instruct - Base) by Layer')
delta_ax.set_xlabel('Layer')
delta_ax.set_ylabel('Delta Norm')
delta_ax.legend()
delta_ax.grid(True, alpha=0.3)

# Compare base vs instruct norms
pc1_data = mlp_df.loc[mlp_df['input_vector'] == 'PC1']
for column, marker, label in (
    ('base_output_norm', 'o', 'Base'),
    ('instruct_output_norm', 's', 'Instruct'),
):
    norm_ax.plot(pc1_data['layer'], pc1_data[column], marker=marker, label=label)
norm_ax.set_title('MLP Output Norms for PC1')
norm_ax.set_xlabel('Layer')
norm_ax.set_ylabel('Output Norm')
norm_ax.legend()
norm_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Find peak layer. idxmax() yields the DataFrame index *label* of the row with
# the largest PC1 delta (not a layer number), so it is mapped back via .loc.
peak_layer = pc1_data['delta_norm'].idxmax()
peak_layer_num = mlp_df.loc[peak_layer, 'layer']
print(f"\n🎯 Peak delta at layer {peak_layer_num}")

## 3. Layer 18 Semantic Decomposition

The original analysis identified layer 18 as having the strongest effect. Here we decompose what instruction tuning adds to the MLP's transformation of PC1 at that layer.

In [None]:
# Focus on layer 18
target_layer = 18

# Load individual role and trait vectors for this layer
persona_data_root = Path("/workspace/persona-data")
roles_vectors_dir = persona_data_root / "gemma-2-27b" / "roles_240" / "vectors"
traits_vectors_dir = persona_data_root / "gemma-2-27b" / "traits_240" / "vectors"

role_vectors = load_individual_role_vectors(roles_vectors_dir, target_layer, vector_type='pos_all')
trait_vectors = load_individual_trait_vectors(traits_vectors_dir, target_layer, vector_type='pos_all')

print(f"Loaded semantic vectors for layer {target_layer}:")
print(f"  Roles: {len(role_vectors)}")
print(f"  Traits: {len(trait_vectors)}")

# Merging with {**a, **b} keeps only the LAST value for a repeated key, so a
# trait sharing a name with a role would silently replace it and the counts
# printed above would no longer match the number of rows in semantic_matrix.
# Surface that case explicitly instead of hiding it.
shared_names = sorted(set(role_vectors) & set(trait_vectors))
if shared_names:
    print(f"  ⚠ {len(shared_names)} name(s) appear in both roles and traits; "
          f"the trait vector is used for: {shared_names}")

# Combine all semantic vectors (roles first, then traits)
all_semantic = {**role_vectors, **trait_vectors}
semantic_names = list(all_semantic.keys())
# One row per entry of semantic_names, in the same order.
semantic_matrix = torch.stack([all_semantic[name] for name in semantic_names])

In [None]:
# Run PC1 through the target layer's MLP for both checkpoints and measure what
# instruction tuning changes in the output.
_layer_prefix = f"model.layers.{target_layer}"

# Base-model weights for this layer.
gate_base = base_state_dict[f"{_layer_prefix}.mlp.gate_proj.weight"]
up_base = base_state_dict[f"{_layer_prefix}.mlp.up_proj.weight"]
down_base = base_state_dict[f"{_layer_prefix}.mlp.down_proj.weight"]
pre_ln_base = base_state_dict[f"{_layer_prefix}.pre_feedforward_layernorm.weight"]
post_ln_base = base_state_dict[f"{_layer_prefix}.post_feedforward_layernorm.weight"]

# Instruction-tuned weights for the same layer.
gate_inst = instruct_state_dict[f"{_layer_prefix}.mlp.gate_proj.weight"]
up_inst = instruct_state_dict[f"{_layer_prefix}.mlp.up_proj.weight"]
down_inst = instruct_state_dict[f"{_layer_prefix}.mlp.down_proj.weight"]
pre_ln_inst = instruct_state_dict[f"{_layer_prefix}.pre_feedforward_layernorm.weight"]
post_ln_inst = instruct_state_dict[f"{_layer_prefix}.post_feedforward_layernorm.weight"]

with torch.inference_mode():
    # The helper is batched, so wrap PC1 in a batch of one and unwrap with [0].
    pc1_batch = pc1.unsqueeze(0)
    output_base = full_mlp_forward_batch(pc1_batch, gate_base, up_base, down_base, pre_ln_base, post_ln_base)[0]
    output_inst = full_mlp_forward_batch(pc1_batch, gate_inst, up_inst, down_inst, pre_ln_inst, post_ln_inst)[0]

# Compute difference: what instruction tuning ADDS.
# Upcast to float32 BEFORE subtracting: the checkpoints are loaded in bfloat16,
# so if the outputs are bfloat16 too (TODO confirm against
# full_mlp_forward_batch), subtracting first would round away part of the
# signal that the projection cell then ranks.
difference = output_inst.float() - output_base.float()

print(f"\nLayer {target_layer} PC1 MLP transformation:")
print(f"  Base output norm: {output_base.float().norm().item():.4f}")
print(f"  Instruct output norm: {output_inst.float().norm().item():.4f}")
print(f"  Difference norm: {difference.float().norm().item():.4f}")

In [None]:
# Project difference onto semantic vectors: one raw dot product per named
# role/trait vector, computed in float32.
# NOTE(review): the semantic vectors are not normalized here, so the ranking
# reflects vector length as well as alignment — confirm that this is intended.
projections = {
    name: (difference.float() @ all_semantic[name].float()).item()
    for name in semantic_names
}

# Sort by strength (absolute value, strongest first), ignoring sign.
sorted_projs = sorted(projections.items(), key=lambda item: abs(item[1]), reverse=True)


def _print_projection_table(header, rows):
    """Print a header, a rule, and one signed ``value  name`` line per row."""
    print(header)
    print("="*80)
    for name, proj in rows:
        print(f"{proj:+.4f}  {name}")


_print_projection_table(
    "\n🔍 Top 15 semantic projections (what instruction tuning adds):",
    sorted_projs[:15],
)

# "Bottom" means smallest in magnitude (closest to zero), not most negative.
_print_projection_table(
    "\n🔍 Bottom 15 semantic projections:",
    sorted_projs[-15:],
)

## 4. Summary

This analysis reveals how instruction tuning modifies the MLP transformation of PC vectors, with layer 18 showing particularly strong semantic effects. The projections onto role/trait vectors reveal which semantic directions are amplified or suppressed by instruction tuning.