# SAE Analysis for PVA-SAE Project

This notebook implements the core SAE analysis pipeline:
1. **SAE Activation Collection** - Extract activations at final token using hooks
2. **Compute Separation Scores** - Calculate per-feature activation fractions on correct vs. incorrect code (thesis methodology)
3. **ArgMax Selection** - Find best features for correct/incorrect code
4. **PVA Latent Direction** - Identify program validity awareness directions

**Setup:** Gemma 2 2B (`google/gemma-2-2b`), Gemma Scope residual-stream SAE at layer 20 (16k width, `average_l0_71`), 10 MBPP samples.

## 1. Setup & Load Data

In [1]:
# Imports
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from pathlib import Path
import sys
# Make the project root importable ('..' is resolved against the working directory,
# i.e. the notebook's folder when launched from Jupyter). Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

# Setup: prefer Apple MPS, then CUDA, then fall back to CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")

Device: mps


In [2]:
# Load dataset
data_dir = Path("../data/datasets")
dataset_files = list(data_dir.glob("dataset_*.parquet"))

if dataset_files:
    latest_dataset = max(dataset_files, key=lambda x: x.stat().st_mtime)
    df = pd.read_parquet(latest_dataset)
    print(f"Loaded: {latest_dataset.name}")
    print(f"Shape: {df.shape}")
    
    # Add is_correct column
    if 'test_passed' in df.columns:
        df['is_correct'] = df['test_passed']
    
    print(f"\nCorrect: {df['is_correct'].sum()}, Incorrect: {(~df['is_correct']).sum()}")
else:
    print("No dataset found!")
    df = None

Loaded: dataset_gemma-2-2b_2025-06-01_14-54-19.parquet
Shape: (10, 4)

Correct: 3, Incorrect: 7


In [3]:
# Reconstruct prompts using shared template
from phase2_sae_analysis.prompt_utils import build_prompt_template
from phase1_dataset_building.dataset_manager import PromptAwareDatasetManager

if df is not None and 'prompt' not in df.columns:
    print("Reconstructing prompts...")

    # The MBPP records supply the problem statement and test asserts for each task_id.
    dataset_manager = PromptAwareDatasetManager()
    dataset_manager.load_dataset()

    def _prompt_for(task_id):
        """Rebuild the generation prompt for one MBPP task with the shared template."""
        mbpp_record = dataset_manager.get_record(task_id)
        return build_prompt_template(
            mbpp_record['text'].strip(),
            '\n'.join(mbpp_record['test_list']),
        )

    prompts = [_prompt_for(tid) for tid in df['task_id']]
    df['prompt'] = prompts
    print(f"Reconstructed {len(prompts)} prompts")

  from .autonotebook import tqdm as notebook_tqdm


Reconstructing prompts...
Reconstructed 10 prompts


## 2. Configuration

In [4]:
# Model and SAE configuration: the Gemma Scope residual-stream SAE is chosen to
# match SAE_LAYER, so changing the layer automatically changes the SAE id.
MODEL_NAME = "google/gemma-2-2b"
SAE_LAYER = 20
SAE_CONFIG = dict(
    repo_id="google/gemma-scope-2b-pt-res",
    sae_id=f"layer_{SAE_LAYER}/width_16k/average_l0_71",
)

print("Model: " + MODEL_NAME)
print(f"SAE Layer: {SAE_LAYER}")
print("SAE: {repo_id}/{sae_id}".format(**SAE_CONFIG))

Model: google/gemma-2-2b
SAE Layer: 20
SAE: google/gemma-scope-2b-pt-res/layer_20/width_16k/average_l0_71


## 3. SAE Activation Collection

In [5]:
# Import SAE analyzer functions
from phase2_sae_analysis.sae_analyzer import (
    load_gemma_scope_sae,
    ActivationExtractor,
    compute_separation_scores,
    PVALatentDirection
)
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the tokenizer for MODEL_NAME. If the checkpoint defines no padding token,
# reuse EOS so that any padded encoding still has a valid pad id.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # fallback only; existing pad tokens are kept

print("Tokenizer loaded")

Tokenizer loaded


In [6]:
# Load model and SAE
print("Loading model...")
# Gemma 2 is meant to run in bfloat16 (the model card's examples use it); float16's much
# narrower exponent range risks overflowing its activations (inf/NaN). Use bf16 on CUDA
# when the GPU supports it, otherwise full fp32 (also used on MPS and CPU).
if DEVICE == "cuda" and torch.cuda.is_bf16_supported():
    model_dtype = torch.bfloat16
else:
    model_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=model_dtype,
    device_map="auto" if DEVICE == "cuda" else None,
)
# device_map only dispatches on CUDA; on MPS the model must be moved explicitly
# (on CPU it is already where from_pretrained put it).
if DEVICE == "mps":
    model = model.to(DEVICE)
model.eval()
print(f"Model loaded on {DEVICE}")

# Load SAE
sae_model = load_gemma_scope_sae(
    repo_id=SAE_CONFIG["repo_id"],
    sae_id=SAE_CONFIG["sae_id"],
    device=DEVICE,
)
print("SAE loaded")

Loading model...


Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 20.38it/s]


Model loaded on mps
SAE loaded


In [7]:
# Collect activations
if df is not None:
    print(f"Extracting activations for {len(df)} samples...")

    # Prepare prompts (activations are taken at the last prompt token, before generation)
    prompts = df['prompt'].tolist()

    # NOTE(review): the saved run failed inside this extractor with
    # "hook_fn() missing 1 required positional argument: 'output'". That hook is defined in
    # phase2_sae_analysis.sae_analyzer (presumably its signature does not match how it is
    # registered) and has to be fixed there; this cell cannot work around it.
    extractor = ActivationExtractor(model, tokenizer, DEVICE)

    # The model may run in reduced precision on CUDA while the SAE weights may not. Cast
    # each activation to the SAE's dtype so encode() does not hit a dtype-mismatch error.
    # NOTE(review): assumes sae_model is a torch.nn.Module; if it exposes no parameters,
    # activations are passed through uncast (the previous behaviour).
    try:
        sae_dtype = next(sae_model.parameters()).dtype
    except (AttributeError, StopIteration):
        sae_dtype = None

    all_activations = []
    sae_activations = []
    # Pure inference: no autograd graph is needed for extraction or encoding.
    with torch.no_grad():
        for i, prompt in enumerate(prompts):
            print(f"\rProcessing {i+1}/{len(prompts)}", end="")
            acts = extractor.extract_activations([prompt], SAE_LAYER)
            all_activations.append(acts[0])  # Get first (only) sample

        print("\nActivations extracted")

        # Apply SAE encoding
        print("Applying SAE encoding...")
        for act in all_activations:
            if sae_dtype is not None:
                act = act.to(sae_dtype)
            sae_activations.append(sae_model.encode(act.unsqueeze(0)).squeeze(0))

    # Split by correctness; activations are in the same row order as df.
    is_correct = df['is_correct'].astype(bool).tolist()
    correct_activations = [a for a, ok in zip(sae_activations, is_correct) if ok]
    incorrect_activations = [a for a, ok in zip(sae_activations, is_correct) if not ok]

    print(f"SAE activations: {len(correct_activations)} correct, {len(incorrect_activations)} incorrect")

Extracting activations for 10 samples...
Processing 1/10

TypeError: ActivationExtractor._create_hook.<locals>.hook_fn() missing 1 required positional argument: 'output'

## 4. Compute Separation Scores

In [None]:
# Compute separation scores
if 'correct_activations' in locals() and 'incorrect_activations' in locals():
    if not correct_activations or not incorrect_activations:
        # torch.stack raises on an empty list, and with a single class present there is
        # nothing to separate; report it instead of crashing.
        print(
            "Skipping separation scores: need at least one correct and one incorrect sample "
            f"(got {len(correct_activations)} correct, {len(incorrect_activations)} incorrect)"
        )
    else:
        print("Computing separation scores...")

        scores = compute_separation_scores(
            torch.stack(correct_activations),
            torch.stack(incorrect_activations)
        )

        print("\nActivation fraction statistics:")
        print(f"  Mean f_correct: {scores.f_correct.mean():.4f}")
        print(f"  Mean f_incorrect: {scores.f_incorrect.mean():.4f}")
        print(f"  Max s_correct: {scores.s_correct.max():.4f}")
        print(f"  Max s_incorrect: {scores.s_incorrect.max():.4f}")

## 5. ArgMax Selection

In [None]:
# Find best PVA directions
def _argmax_direction(direction_type, feature_idx, separation, fractions):
    """Wrap one argmax-selected SAE feature as a PVALatentDirection at SAE_LAYER.

    `separation` is the score vector the feature was selected from; `fractions`
    supplies the per-feature f_correct / f_incorrect vectors.
    """
    return PVALatentDirection(
        direction_type=direction_type,
        layer=SAE_LAYER,
        feature_idx=feature_idx,
        separation_score=separation[feature_idx].item(),
        f_correct=fractions.f_correct[feature_idx].item(),
        f_incorrect=fractions.f_incorrect[feature_idx].item(),
    )


def _print_direction(direction):
    """Print a direction followed by how often it fires on each class."""
    print(f"\n{direction}")
    print(f"  Activates on {direction.f_correct:.1%} of correct, {direction.f_incorrect:.1%} of incorrect")


if 'scores' in locals():
    correct_direction = _argmax_direction(
        "correct", scores.best_correct_idx, scores.s_correct, scores
    )
    incorrect_direction = _argmax_direction(
        "incorrect", scores.best_incorrect_idx, scores.s_incorrect, scores
    )

    print("\n" + "=" * 50, "PVA LATENT DIRECTIONS", "=" * 50, sep="\n")
    _print_direction(correct_direction)
    _print_direction(incorrect_direction)

## 6. Visualization

In [None]:
# Simple visualization of top features
def _plot_top_features(ax, separation, k, ylabel, title, **bar_style):
    """Draw the k highest separation scores on `ax`, labelling each bar with its SAE feature index.

    Returns the (values, indices) pair produced by torch.topk.
    """
    values, indices = torch.topk(separation, k)
    ax.bar(range(k), values.cpu().numpy(), **bar_style)
    ax.set_xlabel('Rank')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(range(k))
    ax.set_xticklabels([f'{idx}' for idx in indices.cpu().numpy()], rotation=45)
    return values, indices


if 'scores' in locals():
    top_k = 10  # features shown per panel

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    top_correct_scores, top_correct_idx = _plot_top_features(
        ax1, scores.s_correct, top_k,
        'Separation Score (f_correct - f_incorrect)', 'Top Correct Code Features',
    )
    top_incorrect_scores, top_incorrect_idx = _plot_top_features(
        ax2, scores.s_incorrect, top_k,
        'Separation Score (f_incorrect - f_correct)', 'Top Incorrect Code Features',
        color='orange',
    )

    plt.tight_layout()
    plt.show()

In [None]:
# Clean up memory
import gc

if 'model' in locals():
    # `extractor` was constructed from `model` and presumably keeps a reference to it, so
    # deleting `model` alone may not release the weights. Drop it and the SAE as well.
    for _name in ("model", "extractor", "sae_model"):
        globals().pop(_name, None)
    del _name
    # Collect any reference cycles first so the device caches can actually be returned.
    gc.collect()
    if DEVICE == "mps":
        torch.mps.empty_cache()
    elif DEVICE == "cuda":
        torch.cuda.empty_cache()
    print("Memory cleaned")

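## Summary

This notebook is designed to:
1. Collect SAE activations at the final prompt token
2. Compute separation scores using activation fractions (thesis methodology)
3. Select the best features via argmax for both correct/incorrect directions
4. Identify PVA latent directions for layer 20

**Status of the saved run:** activation collection (Section 3) failed with `TypeError: ActivationExtractor._create_hook.<locals>.hook_fn() missing 1 required positional argument: 'output'`. As a result, Sections 4–6 did not execute (`In [None]`). The hook in `phase2_sae_analysis.sae_analyzer` must be fixed and the notebook re-run before these results exist.

Once computed, the identified directions can be used for model steering in Phase 3.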