# Analyzing SAE Activations for Perfect vs Non-Perfect Matches

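This notebook compares layer-2 SAE activations (Llama Scope `l2r_8x`) on Llama 3.1 8B between samples that are perfect matches and samples that are not, using the memorization results in `memorization_results_L50_O50.csv`. For each group, 10 samples are drawn. The SAE activations over the 50 continuation tokens are then summarized: L0, activation magnitudes, and the most frequently active features.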

In [None]:
import pandas as pd
import torch
import numpy as np
import plotly.express as px
from sae_lens import SAE
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
from tqdm import tqdm

In [None]:
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the SAE.
# Llama Scope SAEs were trained on Llama-3.1-8B (base); per the Llama Scope
# naming scheme, "l2r_8x" reads the layer-2 residual stream (8x expansion).
release = "llama_scope_lxr_8x"
sae_id = "l2r_8x"
sae = SAE.from_pretrained(release, sae_id)[0]
sae = sae.to(device)

# Load model and tokenizer. This must be the same base model the SAE was trained
# on; feeding it activations from a different checkpoint (e.g. Llama 3 instead of
# Llama 3.1) makes the SAE features meaningless.
model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)

In [None]:
# Load the CSV file
df = pd.read_csv('memorization_results_L50_O50.csv')

# Split into perfect and non-perfect matches
perfect_matches = df[df['perfect_match'] == True]
non_perfect_matches = df[df['perfect_match'] == False]

# Sample up to N_SAMPLES rows from each group.
# DataFrame.sample draws from numpy's RNG, not the `random` module, so the seed
# must be passed as random_state for the sample to be reproducible.
SEED = 42
N_SAMPLES = 10
random.seed(SEED)
perfect_sample = perfect_matches.sample(n=min(N_SAMPLES, len(perfect_matches)), random_state=SEED)
non_perfect_sample = non_perfect_matches.sample(n=min(N_SAMPLES, len(non_perfect_matches)), random_state=SEED)

print(f"Total perfect matches: {len(perfect_matches)}")
print(f"Total non-perfect matches: {len(non_perfect_matches)}")

In [None]:
def get_activations(text_input, text_output, hidden_state_index=3, n_output_tokens=50):
    """Get SAE activations for the output tokens of an input-output pair.

    The two strings are concatenated, run through ``model``, and the hidden
    state at ``hidden_state_index`` is encoded with ``sae``. Only the trailing
    ``n_output_tokens`` positions are kept.

    Args:
        text_input: Prompt text.
        text_output: Continuation text appended directly after ``text_input``.
        hidden_state_index: Index into ``outputs.hidden_states``. Hugging Face
            stores the embedding output at index 0, so the output of decoder
            layer ``k`` (0-indexed) is at index ``k + 1``. The default 3 is the
            output of layer 2, matching the ``l2r`` SAE.
        n_output_tokens: Number of trailing positions treated as the output.

    Returns:
        Tensor of shape (min(n_output_tokens, seq_len), n_sae_features).
    """
    # Concatenate input and output
    full_text = text_input + text_output

    # Tokenize without special tokens (no BOS is prepended).
    # NOTE(review): re-tokenizing the concatenation can merge tokens at the
    # input/output boundary, so the last n_output_tokens may not align exactly
    # with the original output tokenization.
    tokens = tokenizer(full_text, return_tensors='pt', add_special_tokens=False)['input_ids'].to(device)

    with torch.no_grad():
        outputs = model(tokens, output_hidden_states=True)
        resid = outputs.hidden_states[hidden_state_index]

        # The model runs in bfloat16; match the SAE's parameter dtype before encoding.
        sae_acts = sae.encode(resid.to(sae.W_enc.dtype))

        # Keep only the trailing output positions. An explicit start index avoids
        # the `-0:` slice, which would return every position.
        start = max(sae_acts.shape[1] - n_output_tokens, 0)
        output_acts = sae_acts[:, start:, :]

    return output_acts.squeeze(0)  # Remove batch dimension

In [None]:
def _collect_activations(sample_df, desc):
    """Return one get_activations(...) tensor per row of sample_df, in row order."""
    acts_per_row = []
    # iterrows() has no len(), so pass total explicitly for a proper progress bar.
    for _, row in tqdm(sample_df.iterrows(), total=len(sample_df), desc=desc):
        acts_per_row.append(get_activations(row['input_text'], row['generated_continuation_text']))
    return acts_per_row

# Get activations for perfect matches
print("Processing perfect matches...")
perfect_activations = _collect_activations(perfect_sample, desc="perfect")

# Get activations for non-perfect matches
print("\nProcessing non-perfect matches...")
non_perfect_activations = _collect_activations(non_perfect_sample, desc="non-perfect")

# Stack along the position axis: each element is (positions, n_sae_features),
# so the result is (total_positions, n_sae_features).
perfect_acts_stacked = torch.cat(perfect_activations, dim=0)
non_perfect_acts_stacked = torch.cat(non_perfect_activations, dim=0)

In [None]:
# Calculate statistics
def compute_stats(activations):
    """Compute summary statistics for a tensor of SAE activations.

    Args:
        activations: Tensor whose last dimension indexes SAE features. All
            leading dimensions are treated as token positions, e.g.
            (tokens, n_features) or (batch, seq, n_features).

    Returns:
        dict with:
          - 'L0 (avg active features)': mean count of features > 0 per position.
          - 'Mean activation when active': mean of all strictly positive values
            (NaN when nothing is active).
          - 'Max activation': largest value in the tensor.
          - 'Feature sparsity': fraction of features that are never > 0 at any
            position.
    """
    # Collapse leading dims so per-feature reductions span every position, and
    # upcast so sums/means are not computed in reduced precision (e.g. bf16).
    flat = activations.reshape(-1, activations.shape[-1]).float()
    active = flat > 0

    # L0 (number of active features per position, averaged)
    l0 = active.sum(-1).float().mean().item()

    # Mean activation when active
    mean_active = flat[active].mean().item() if active.any() else float('nan')

    # Max activation
    max_act = flat.max().item()

    # Feature sparsity (fraction of features that never activate)
    feature_sparsity = (~active.any(0)).float().mean().item()

    return {
        'L0 (avg active features)': l0,
        'Mean activation when active': mean_active,
        'Max activation': max_act,
        'Feature sparsity': feature_sparsity
    }

perfect_stats = compute_stats(perfect_acts_stacked)
non_perfect_stats = compute_stats(non_perfect_acts_stacked)

# Print one labelled section per group; the non-perfect header carries its own
# leading blank line to separate the two sections.
stat_sections = (
    ("Statistics for perfect matches:", perfect_stats),
    ("\nStatistics for non-perfect matches:", non_perfect_stats),
)
for section_title, section_stats in stat_sections:
    print(section_title)
    for stat_name, stat_value in section_stats.items():
        print(f"{stat_name}: {stat_value:.4f}")

In [None]:
# Plot L0 histograms: number of active SAE features at each output position
l0_perfect = (perfect_acts_stacked > 0).float().sum(-1).cpu().numpy()
l0_non_perfect = (non_perfect_acts_stacked > 0).float().sum(-1).cpu().numpy()

# Create a DataFrame for plotting
l0_data = pd.DataFrame({
    'L0': np.concatenate([l0_perfect, l0_non_perfect]),
    'Type': ['Perfect Match'] * len(l0_perfect) + ['Non-Perfect Match'] * len(l0_non_perfect)
})

# Plot histogram. With barmode='overlay' the traces are drawn on top of each
# other, so partial opacity is needed for both groups to stay visible.
fig = px.histogram(l0_data, x='L0', color='Type', barmode='overlay', opacity=0.6,
                  title='Distribution of Active Features (L0)',
                  labels={'L0': 'Number of Active Features', 'count': 'Frequency'})
fig.show()

In [None]:
# Plot activation magnitude distributions (only strictly positive activations).
# .float() first: numpy has no bfloat16, so .numpy() would fail on bf16 tensors.
act_perfect = perfect_acts_stacked[perfect_acts_stacked > 0].float().cpu().numpy()
act_non_perfect = non_perfect_acts_stacked[non_perfect_acts_stacked > 0].float().cpu().numpy()

# Create a DataFrame for plotting
act_data = pd.DataFrame({
    'Activation': np.concatenate([act_perfect, act_non_perfect]),
    'Type': ['Perfect Match'] * len(act_perfect) + ['Non-Perfect Match'] * len(act_non_perfect)
})

# Plot histogram. With barmode='overlay' the traces are drawn on top of each
# other, so partial opacity is needed for both groups to stay visible.
fig = px.histogram(act_data, x='Activation', color='Type', barmode='overlay', opacity=0.6,
                  title='Distribution of Activation Magnitudes',
                  labels={'Activation': 'Activation Value', 'count': 'Frequency'})
fig.show()

In [None]:
# Analyze feature usage patterns
def get_top_features(activations, n=10):
    """Return the features that are active (> 0) at the most positions.

    Args:
        activations: Tensor whose last dimension indexes SAE features; all
            leading dimensions are treated as positions.
        n: Number of features to return; clipped to the number of features.

    Returns:
        (indices, counts) as numpy arrays, ordered by descending count.
    """
    # Flatten leading dims so counts are per feature for any input rank.
    flat = activations.reshape(-1, activations.shape[-1])
    feature_counts = (flat > 0).float().sum(0)
    top_features = torch.topk(feature_counts, min(n, feature_counts.numel()))
    return top_features.indices.cpu().numpy(), top_features.values.cpu().numpy()

# Rank features by how often they fire in each group
perfect_top_idx, perfect_top_counts = get_top_features(perfect_acts_stacked)
non_perfect_top_idx, non_perfect_top_counts = get_top_features(non_perfect_acts_stacked)

# One report section per group; the second header carries its own blank line.
top_feature_report = [
    ("Top 10 most active features for perfect matches:", perfect_top_idx, perfect_top_counts),
    ("\nTop 10 most active features for non-perfect matches:", non_perfect_top_idx, non_perfect_top_counts),
]
for heading, feature_ids, fire_counts in top_feature_report:
    print(heading)
    for feature_id, n_fires in zip(feature_ids, fire_counts):
        print(f"Feature {feature_id}: activated {n_fires:.0f} times")

# Features that rank in the top list for both groups
overlap = set(perfect_top_idx).intersection(non_perfect_top_idx)
print(f"\nNumber of overlapping features in top 10: {len(overlap)}")