# SAE Interpretation with Local LLM
This notebook interprets all neurons in a trained SAE using a local Gemma 3 12B model via vLLM.

## Setup: Install Dependencies

In [None]:
# Clone HypotheSAEs if not already present
!git clone https://github.com/rmovva/HypotheSAEs.git 2>/dev/null || echo "HypotheSAEs already exists"
!pip install -e HypotheSAEs/

# Install vLLM and other dependencies
!pip install vllm transformers accelerate huggingface_hub

## Load SAE from HuggingFace

In [None]:
import sys
sys.path.insert(0, "HypotheSAEs")

from hypothesaes.sae import load_model
import torch
import numpy as np
from huggingface_hub import snapshot_download
import os

# Configuration
SAE_REPO = "Koalacrown/llama3.1-8b-it-cognitive-actions-sae-l11"
LOCAL_SAE_DIR = "sae_checkpoint"

# Mirror the Hub repository into LOCAL_SAE_DIR.
print(f"Downloading SAE from {SAE_REPO}...")
snapshot_download(
    repo_id=SAE_REPO,
    local_dir=LOCAL_SAE_DIR,
    repo_type="model"
)

# Find the SAE checkpoint file. os.listdir() returns entries in arbitrary
# order, so sort the matches to make the selection reproducible when the
# repository contains more than one checkpoint.
sae_files = sorted(
    f for f in os.listdir(LOCAL_SAE_DIR) if f.startswith('SAE_') and f.endswith('.pt')
)
if not sae_files:
    raise FileNotFoundError(f"No SAE checkpoint found in {LOCAL_SAE_DIR}")
if len(sae_files) > 1:
    # Make the implicit choice visible instead of silently picking one.
    print(f"Found {len(sae_files)} SAE checkpoints {sae_files}; using {sae_files[0]}")

sae_path = os.path.join(LOCAL_SAE_DIR, sae_files[0])
print(f"Loading SAE from {sae_path}...")

sae = load_model(sae_path)
print(f"SAE loaded: M={sae.m_total_neurons}, K={sae.k_active_neurons}")
print(f"Input dimension: {sae.input_dim}")

## Load Dataset for Interpretation

In [None]:
import json

# Path to the cognitive actions dataset (JSON Lines). It has no directory
# component, so it is resolved relative to the current working directory.
DATASET_PATH = "cognitive_actions_7k_final_1759233061.jsonl"

def load_dataset(dataset_path: str):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        dataset_path: Path to a UTF-8 encoded file holding one JSON value per line.

    Returns:
        A list with one parsed value per non-blank line, in file order.
        An empty file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default locale encoding.
    with open(dataset_path, 'r', encoding='utf-8') as f:
        # Blank or whitespace-only lines (e.g. a trailing empty line) are
        # skipped rather than handed to json.loads, which would reject them.
        return [json.loads(line) for line in f if line.strip()]

print(f"Loading dataset from {DATASET_PATH}...")
data = load_dataset(DATASET_PATH)

# Keep only the raw text of every record; each record must provide a 'text' key.
texts = [record['text'] for record in data]

print(f"Loaded {len(texts)} examples")

## Extract Activations from Base Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm

def extract_activations_sequential(
    texts,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    layer_idx=11,
    max_length=512,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Extract hidden states from one layer of a causal LM, one text at a time.

    The model is loaded inside this function and released again before it
    returns, so the accelerator memory it occupied is available to whatever
    runs next in the same kernel.

    Args:
        texts: Iterable of strings; each one is tokenized and run on its own
            (batch size 1, no padding).
        model_name: Hugging Face model id passed to ``from_pretrained``.
        layer_idx: Index into ``outputs.hidden_states``.
            NOTE(review): per the transformers docs index 0 is the embedding
            output, so index ``i`` is the output of the i-th block - confirm
            this matches the layer the SAE was trained on.
        max_length: Token limit; longer inputs are truncated.
        device: Device the tokenized inputs are moved to. The model itself is
            placed by ``device_map="auto"``.

    Returns:
        A list with one float32 NumPy array per input text: the selected
        hidden state with its leading batch dimension squeezed away
        (presumably ``(num_tokens, hidden_size)``).
    """
    import gc

    print(f"Loading model: {model_name}")
    print(f"Extracting from layer {layer_idx}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some tokenizers ship without a pad token; reuse EOS so one is defined.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    model.eval()

    all_activations = []

    try:
        with torch.no_grad():
            for text in tqdm(texts, desc="Extracting activations"):
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=max_length,
                    padding=False,
                ).to(device)

                outputs = model(**inputs, output_hidden_states=True, return_dict=True)
                layer_activations = outputs.hidden_states[layer_idx]
                # Drop the batch dimension and move to host memory as float32.
                layer_activations = layer_activations.squeeze(0).cpu().float().numpy()
                all_activations.append(layer_activations)

                # Release per-example tensors before the next forward pass.
                del inputs, outputs, layer_activations
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    finally:
        # Drop the model weights and hand cached CUDA blocks back to the
        # driver, even if extraction failed part-way. Without this the memory
        # stays reserved by PyTorch and a later engine that sizes itself from
        # free GPU memory would see it as occupied.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return all_activations

# Base model and hidden-state index to read activations from.
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
LAYER_IDX = 11

activations_list = extract_activations_sequential(
    texts=texts, model_name=MODEL_NAME, layer_idx=LAYER_IDX, max_length=512,
)

# Stack the per-text arrays row-wise into a single 2-D matrix.
print("Flattening activations...")
flattened_activations = np.vstack(activations_list)
print(f"Flattened shape: {flattened_activations.shape}")

# Build a text list aligned with the stacked rows: every row contributed by a
# text is labelled with that whole text.
# NOTE(review): this pairs each row with the full source text, so one text
# appears once per row it produced - confirm this is the granularity the
# interpretation step expects.
flattened_texts = [
    source_text
    for source_text, text_acts in zip(texts, activations_list)
    for _ in range(text_acts.shape[0])
]
print(f"Total text segments: {len(flattened_texts)}")

## Pre-load Local Model with vLLM (Optional)

In [None]:
from hypothesaes.llm_local import get_vllm_engine

# Use Gemma 3 12B (instruction-tuned) for interpretation.
# The 12B size belongs to the Gemma 3 family; Gemma 2 has no 12B checkpoint,
# so the model id must reference gemma-3.
INTERPRETER_MODEL = "google/gemma-3-12b-it"

print(f"Pre-loading vLLM with {INTERPRETER_MODEL}...")
print("(This is optional - the model will auto-load on first use if skipped)")

# Pre-load with custom vLLM settings
engine = get_vllm_engine(
    INTERPRETER_MODEL,
    gpu_memory_utilization=0.85,  # Fraction of GPU memory vLLM may claim; adjust to your GPU
    tensor_parallel_size=1,        # Single GPU; raise to shard the model across several GPUs
)

print("vLLM engine loaded successfully!")

## Interpret All SAE Neurons

In [None]:
from hypothesaes.quickstart import interpret_sae
import pandas as pd

# Interpretation settings
N_NEURONS_TO_INTERPRET = sae.m_total_neurons  # every neuron in the SAE
N_EXAMPLES_FOR_INTERPRETATION = 20
MAX_WORDS_PER_EXAMPLE = 256

# Run summary, emitted as one write.
print("\n".join([
    "="*60,
    "Interpreting SAE Neurons with Local Model",
    "="*60,
    f"Total neurons to interpret: {N_NEURONS_TO_INTERPRET}",
    f"Examples per neuron: {N_EXAMPLES_FOR_INTERPRETATION}",
    f"Interpreter model: {INTERPRETER_MODEL}",
    "Note: HypotheSAEs will automatically use vLLM for local models",
]))

# Ask the interpreter model (INTERPRETER_MODEL) to label the SAE neurons.
# NOTE(review): nothing in this notebook validates the `n_top_neurons` keyword -
# confirm it is accepted by the installed hypothesaes.quickstart.interpret_sae.
interpretations_df = interpret_sae(
    # Data and model
    texts=flattened_texts,
    embeddings=flattened_activations,
    sae=sae,
    # Which neurons, and how each one is presented to the interpreter
    n_top_neurons=N_NEURONS_TO_INTERPRET,
    n_examples_for_interpretation=N_EXAMPLES_FOR_INTERPRETATION,
    max_words_per_example=MAX_WORDS_PER_EXAMPLE,
    # Interpreter generation settings
    interpreter_model=INTERPRETER_MODEL,
    interpret_temperature=0.7,
    max_interpretation_tokens=50,
    n_candidates=1,
    task_specific_instructions="These are activations from a model processing cognitive action descriptions. Focus on identifying specific cognitive patterns, reasoning types, or mental processes.",
    # Console preview of examples
    print_examples_n=3,
    print_examples_max_chars=1024,
)

print("\n" + "="*60)
print("Interpretation Complete!")
print("="*60)
print(f"Interpreted {len(interpretations_df)} neurons")

## Save Interpretations

In [None]:
# Destination CSV for the interpretation table (relative to the working directory).
OUTPUT_PATH = "sae_neuron_interpretations.csv"

# Write the table without the DataFrame index column.
interpretations_df.to_csv(OUTPUT_PATH, index=False)
print(f"Interpretations saved to {OUTPUT_PATH}")

# Preview the first ten rows.
print("\nSample interpretations:")
first_rows = interpretations_df.head(10)
print(first_rows)

## Upload Interpretations to HuggingFace

In [None]:
from huggingface_hub import HfApi

# Push the CSV into the same Hub repository that hosts the SAE.
# NOTE(review): presumably needs a Hugging Face login/token with write access
# to SAE_REPO - confirm before running.
api = HfApi()

print(f"Uploading interpretations to {SAE_REPO}...")
api.upload_file(
    repo_id=SAE_REPO,
    repo_type="model",
    path_in_repo="neuron_interpretations.csv",
    path_or_fileobj=OUTPUT_PATH,
)

print(f"✅ Interpretations uploaded to https://huggingface.co/{SAE_REPO}")