# Test SAE on Custom Strings
This notebook allows you to test the trained SAE on custom text inputs to see which neurons activate most strongly and what they represent.

## Setup: Load Dependencies

In [None]:
import sys
sys.path.insert(0, "HypotheSAEs")

from hypothesaes.sae import load_model
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
import os

In [None]:
from huggingface_hub import hf_hub_download

# Configuration
SAE_REPO = "Koalacrown/llama3.1-8b-it-cognitive-actions-sae-l11"
LOCAL_SAE_DIR = "sae_checkpoint"

# Download SAE from HuggingFace
print(f"Downloading SAE from {SAE_REPO}...")
snapshot_download(
    repo_id=SAE_REPO,
    local_dir=LOCAL_SAE_DIR,
    repo_type="model"
)

# Load SAE
# os.listdir returns entries in arbitrary order, so sort the matches to make
# the selected checkpoint deterministic when more than one file is present.
sae_files = sorted(
    f for f in os.listdir(LOCAL_SAE_DIR)
    if f.startswith('SAE_') and f.endswith('.pt')
)
if not sae_files:
    # Fail with an actionable message instead of an IndexError on sae_files[0].
    raise FileNotFoundError(
        f"No 'SAE_*.pt' checkpoint found in {LOCAL_SAE_DIR!r} after downloading {SAE_REPO}"
    )
sae_path = os.path.join(LOCAL_SAE_DIR, sae_files[0])
print(f"Loading SAE from {sae_path}...")
sae = load_model(sae_path)
print(f"SAE loaded: M={sae.m_total_neurons}, K={sae.k_active_neurons}")

# Download interpretations from HuggingFace
print(f"\nDownloading interpretations from {SAE_REPO}...")
interpretations_path = hf_hub_download(
    repo_id=SAE_REPO,
    filename="neuron_interpretations.csv",
    repo_type="model"
)

# Load interpretations
print(f"Loading interpretations from {interpretations_path}...")
interpretations_df = pd.read_csv(interpretations_path)
print(f"Loaded {len(interpretations_df)} neuron interpretations")
print("\nSample interpretations:")
print(interpretations_df.head())

In [None]:
# Configuration
SAE_REPO = "Koalacrown/llama3.1-8b-it-cognitive-actions-sae-l11"
LOCAL_SAE_DIR = "sae_checkpoint"
INTERPRETATIONS_PATH = "sae_neuron_interpretations.csv"

# Download SAE if needed
if not os.path.exists(LOCAL_SAE_DIR):
    print(f"Downloading SAE from {SAE_REPO}...")
    snapshot_download(
        repo_id=SAE_REPO,
        local_dir=LOCAL_SAE_DIR,
        repo_type="model"
    )

# Load SAE
# os.listdir returns entries in arbitrary order, so sort the matches to make
# the selected checkpoint deterministic when more than one file is present.
sae_files = sorted(
    f for f in os.listdir(LOCAL_SAE_DIR)
    if f.startswith('SAE_') and f.endswith('.pt')
)
if not sae_files:
    # An existing but empty/partial directory skips the download above; report
    # that clearly instead of failing with an IndexError on sae_files[0].
    raise FileNotFoundError(
        f"No 'SAE_*.pt' checkpoint found in {LOCAL_SAE_DIR!r}; "
        f"delete the directory to re-download from {SAE_REPO}"
    )
sae_path = os.path.join(LOCAL_SAE_DIR, sae_files[0])
print(f"Loading SAE from {sae_path}...")
sae = load_model(sae_path)
print(f"SAE loaded: M={sae.m_total_neurons}, K={sae.k_active_neurons}")

# Load interpretations
# Prefer the local CSV when it exists; otherwise fall back to the copy hosted
# in the SAE repo so a fresh "Restart & Run All" does not stop on a missing file.
if os.path.exists(INTERPRETATIONS_PATH):
    interpretations_path = INTERPRETATIONS_PATH
else:
    from huggingface_hub import hf_hub_download

    print(f"\n{INTERPRETATIONS_PATH} not found locally; downloading from {SAE_REPO}...")
    interpretations_path = hf_hub_download(
        repo_id=SAE_REPO,
        filename="neuron_interpretations.csv",
        repo_type="model"
    )

print(f"\nLoading interpretations from {interpretations_path}...")
interpretations_df = pd.read_csv(interpretations_path)
print(f"Loaded {len(interpretations_df)} neuron interpretations")
print("\nSample interpretations:")
print(interpretations_df.head())

## Load Base Model (Llama 3.1 8B Instruct)

In [None]:
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
LAYER_IDX = 11  # Layer whose activations the SAE was trained on
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading base model: {MODEL_NAME}")
# NOTE(review): trust_remote_code=True allows code shipped in the hub repo to
# run locally; presumably not required for this checkpoint - confirm before removing.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Reuse EOS as the padding token when the tokenizer defines none.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
model.eval()
# Report the placement actually chosen by device_map="auto" (the same
# model.device that get_activations moves its inputs to), not the guess above.
print(f"Model loaded on {model.device}")

## Helper Functions

In [None]:
def get_activations(text, model, tokenizer, layer_idx=11, pool_method="mean"):
    """
    Extract pooled hidden-state activations from the base model for one text.

    The text is tokenized (truncated to 512 tokens, no padding), run through
    the model with hidden states enabled, and the selected layer's per-token
    activations are pooled into a single vector.

    Args:
        text: Input string
        model: Base language model; must expose ``.device`` and return
            ``hidden_states`` when called with ``output_hidden_states=True``
        tokenizer: Tokenizer matching ``model``
        layer_idx: Index into ``outputs.hidden_states`` to extract from
        pool_method: How to pool across tokens ("mean", "max", or "first")

    Returns:
        Pooled activation vector (float32 numpy array)

    Raises:
        ValueError: If ``pool_method`` is not one of the supported names.
            Raised before the forward pass so a typo does not cost a model call.
    """
    # Validate up front: the forward pass below is the expensive part.
    if pool_method not in ("mean", "max", "first"):
        raise ValueError(f"Unknown pool_method: {pool_method}")

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=False,
    ).to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True, return_dict=True)
        # NOTE(review): with Hugging Face models hidden_states[0] is usually the
        # embedding output - confirm layer_idx matches the SAE training convention.
        # squeeze(0) drops the batch axis (a single text is passed in); the cast
        # to float32 happens before pooling so reductions are not done in fp16.
        layer_acts = outputs.hidden_states[layer_idx].squeeze(0).cpu().float().numpy()

    # Pool across tokens (axis 0 after the batch axis has been removed)
    if pool_method == "mean":
        return layer_acts.mean(axis=0)
    if pool_method == "max":
        return layer_acts.max(axis=0)
    # Only "first" remains after the validation above.
    return layer_acts[0]


def analyze_text(text, sae, model, tokenizer, interpretations_df, 
                 top_k=10, layer_idx=11, pool_method="mean"):
    """
    Analyze a text and show which SAE neurons activate most strongly.

    Prints progress information (the text, activation shapes) and returns the
    top neurons together with their interpretation strings.

    Args:
        text: Input string to analyze
        sae: Trained SAE model exposing ``get_activations(array, show_progress=...)``
        model: Base language model
        tokenizer: Tokenizer
        interpretations_df: DataFrame with neuron interpretations. The neuron id
            is read from the first of 'neuron_id', 'neuron_index', 'neuron' that
            exists (otherwise the row position is used); the text is read from
            the first of 'interpretation', 'description', 'label' that exists.
        top_k: Number of top neurons to display
        layer_idx: Layer to extract from
        pool_method: How to pool activations

    Returns:
        DataFrame with columns 'neuron_id', 'activation', 'interpretation',
        ordered from highest to lowest activation. Neurons without a usable
        interpretation get "No interpretation available".
    """
    print(f"\nAnalyzing: \"{text}\"")
    print("=" * 80)
    
    # Step 1: Get base model activations
    activations = get_activations(text, model, tokenizer, layer_idx, pool_method)
    print(f"Extracted activations: shape={activations.shape}")
    
    # Step 2: Pass through SAE
    sae_activations = sae.get_activations(
        activations.reshape(1, -1),  # Add batch dimension
        show_progress=False
    ).squeeze()  # Remove batch dimension
    print(f"SAE activations: shape={sae_activations.shape}")
    
    # Step 3: Find top-k activated neurons (descending activation)
    top_indices = np.argsort(sae_activations)[::-1][:top_k]
    top_activations = sae_activations[top_indices]
    
    # Step 4: Get interpretations for top neurons
    # The column names do not depend on the neuron, so resolve them once
    # instead of re-checking inside the loop.
    available_columns = interpretations_df.columns
    id_col = next(
        (col for col in ('neuron_id', 'neuron_index', 'neuron') if col in available_columns),
        None,
    )
    interp_col = next(
        (col for col in ('interpretation', 'description', 'label') if col in available_columns),
        None,
    )
    missing_text = "No interpretation available"
    
    results = []
    for idx, activation in zip(top_indices, top_activations):
        if id_col is not None:
            neuron_row = interpretations_df[interpretations_df[id_col] == idx]
        elif idx < len(interpretations_df):
            # No id column: assume row position corresponds to neuron ID
            neuron_row = interpretations_df.iloc[[idx]]
        else:
            neuron_row = None
        
        interpretation = missing_text
        if interp_col is not None and neuron_row is not None and len(neuron_row) > 0:
            value = neuron_row[interp_col].values[0]
            # An empty CSV cell is read as NaN; report it as missing rather
            # than letting "nan" appear in the results table.
            if not pd.isna(value):
                interpretation = value
        
        results.append({
            'neuron_id': int(idx),
            'activation': float(activation),
            'interpretation': interpretation
        })
    
    results_df = pd.DataFrame(results)
    return results_df


def batch_analyze(texts, sae, model, tokenizer, interpretations_df,
                  top_k=5, layer_idx=11, pool_method="mean"):
    """
    Run analyze_text on each text in turn and print its top-neuron table.

    Args:
        texts: Iterable of input strings
        sae: Trained SAE model
        model: Base language model
        tokenizer: Tokenizer
        interpretations_df: DataFrame with neuron interpretations
        top_k: Number of top neurons to report per text
        layer_idx: Layer to extract from
        pool_method: How to pool activations

    Returns:
        Dict mapping "Text 1", "Text 2", ... (in input order) to the
        DataFrame returned by analyze_text for that text.
    """
    divider = "\n" + "=" * 80
    tables_by_label = {}
    
    # Labels are 1-based, so start the counter at 1.
    for position, sample in enumerate(texts, start=1):
        table = analyze_text(
            sample, sae, model, tokenizer, interpretations_df,
            top_k=top_k, layer_idx=layer_idx, pool_method=pool_method
        )
        tables_by_label[f"Text {position}"] = table
        print(f"\nTop {top_k} activated neurons:")
        print(table.to_string(index=False))
        print(divider)
    
    return tables_by_label

## Test on Single Text

In [None]:
# Example: Test a single string
test_text = "I need to reconsider my assumptions about this problem and think more carefully."

results = analyze_text(
    test_text,
    sae=sae,
    model=model,
    tokenizer=tokenizer,
    interpretations_df=interpretations_df,
    top_k=10,
    layer_idx=LAYER_IDX,  # use the configured layer rather than the function default
    pool_method="mean"  # Try "mean", "max", or "first"
)

# Report the number of rows actually returned so the header stays accurate
# if top_k is changed above.
print(f"\nTop {len(results)} activated neurons:")
print(results.to_string(index=False))

## Test on Multiple Texts (Batch Analysis)

In [None]:
# Example: Test multiple strings
test_texts = [
    "I realize I was wrong about my initial hypothesis and need to revise my thinking.",
    "Let me break down this complex problem into smaller, manageable steps.",
    "I'm uncertain whether this approach will work, so I should test it first.",
    "I need to check my biases before making this decision.",
]

batch_results = batch_analyze(
    test_texts,
    sae=sae,
    model=model,
    tokenizer=tokenizer,
    interpretations_df=interpretations_df,
    top_k=5,
    layer_idx=LAYER_IDX  # use the configured layer rather than the function default
)

## Interactive Testing
Run this cell repeatedly, entering a different text each time.

In [None]:
# Enter your own text here
custom_text = input("Enter text to analyze: ")

if custom_text.strip():
    results = analyze_text(
        custom_text,
        sae=sae,
        model=model,
        tokenizer=tokenizer,
        interpretations_df=interpretations_df,
        top_k=10,
        layer_idx=LAYER_IDX  # use the configured layer rather than the function default
    )
    # Report the number of rows actually returned so the header stays accurate
    # if top_k is changed above.
    print(f"\nTop {len(results)} activated neurons:")
    print(results.to_string(index=False))
else:
    print("No text entered.")

## Visualize Activation Patterns

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_activations(text, sae, model, tokenizer, top_k=20,
                          layer_idx=None, pool_method="mean"):
    """
    Visualize the strongest SAE neuron activations for a text as a bar chart.

    Args:
        text: Input string to analyze
        sae: Trained SAE model exposing ``get_activations(array, show_progress=...)``
        model: Base language model
        tokenizer: Tokenizer
        top_k: Maximum number of neurons to plot
        layer_idx: Layer to extract from; ``None`` uses the notebook-level LAYER_IDX
        pool_method: How to pool activations across tokens

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    if layer_idx is None:
        # Read the notebook-level setting at call time so later edits to
        # LAYER_IDX are picked up without redefining this function.
        layer_idx = LAYER_IDX

    activations = get_activations(
        text, model, tokenizer, layer_idx=layer_idx, pool_method=pool_method
    )
    sae_activations = sae.get_activations(
        activations.reshape(1, -1),
        show_progress=False
    ).squeeze()
    
    top_indices = np.argsort(sae_activations)[::-1][:top_k]
    top_activations = sae_activations[top_indices]

    # Slicing may return fewer than top_k entries; size the axis from what was
    # actually selected so bar positions and heights always have equal length.
    n_shown = len(top_indices)
    positions = range(n_shown)

    # Only add an ellipsis when the text was really truncated.
    snippet = text if len(text) <= 50 else f"{text[:50]}..."
    
    plt.figure(figsize=(12, 6))
    plt.bar(positions, top_activations)
    plt.xlabel('Neuron Rank')
    plt.ylabel('Activation Value')
    plt.title(f'Top {n_shown} SAE Neuron Activations\n"{snippet}"')
    plt.xticks(positions, [f'N{i}' for i in top_indices], rotation=45)
    plt.tight_layout()
    plt.show()
    
# Example usage: bar chart of the 15 strongest neurons for one sentence
visualize_activations(
    text="I need to question my initial assumptions and consider alternative explanations.",
    sae=sae,
    model=model,
    tokenizer=tokenizer,
    top_k=15,
)

## Compare Activation Heatmap Across Multiple Texts

In [None]:
def compare_activations_heatmap(texts, sae, model, tokenizer, top_k=15):
    """
    Plot a heatmap of the SAE neurons whose activation varies most across texts.

    Args:
        texts: List of input strings to compare
        sae: Trained SAE model exposing ``get_activations(array, show_progress=...)``
        model: Base language model
        tokenizer: Tokenizer
        top_k: Number of highest-variance neurons to show

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    def _sae_vector(sample):
        # Base-model activations at the notebook-level LAYER_IDX, then SAE encoding.
        base_acts = get_activations(sample, model, tokenizer, layer_idx=LAYER_IDX)
        encoded = sae.get_activations(
            base_acts.reshape(1, -1),
            show_progress=False
        )
        return encoded.squeeze()

    # One row per text - assumes each SAE vector is 1-D after squeeze; TODO confirm
    activation_matrix = np.array([_sae_vector(sample) for sample in texts])
    
    # Rank neurons by variance across texts, highest first
    per_neuron_variance = activation_matrix.var(axis=0)
    top_varying = np.argsort(per_neuron_variance)[::-1][:top_k]

    row_labels = [f'Neuron {i}' for i in top_varying]
    column_labels = [f'Text {i+1}' for i in range(len(texts))]
    # Transpose so selected neurons run down the y-axis and texts along the x-axis
    heat_values = activation_matrix[:, top_varying].T
    
    plt.figure(figsize=(14, 8))
    sns.heatmap(
        heat_values,
        cmap='viridis',
        yticklabels=row_labels,
        xticklabels=column_labels,
        cbar_kws={'label': 'Activation'}
    )
    plt.xlabel('Text')
    plt.ylabel('Neuron')
    plt.title(f'Top {top_k} Most Varying SAE Neurons Across {len(texts)} Texts')
    plt.tight_layout()
    plt.show()

# Example usage: compare four short sentences side by side
compare_texts = [
    "I need to reconsider my initial assumptions.",
    "Let me break this problem into smaller parts.",
    "I'm uncertain about the best approach here.",
    "I should check for biases in my reasoning.",
]

compare_activations_heatmap(
    texts=compare_texts,
    sae=sae,
    model=model,
    tokenizer=tokenizer,
    top_k=20,
)

## Cleanup (Optional)
Free up GPU memory when done

In [None]:
# Uncomment to free GPU memory
# del model, tokenizer, sae
# import gc
# gc.collect()
# if torch.cuda.is_available():
#     torch.cuda.empty_cache()
# print("Models unloaded and GPU memory freed")