# Analyzing SAE Activations for Perfect vs Non-Perfect Matches

This notebook analyzes the layer 2 SAE activations from Llama 3.1 8B for perfect matches vs non-perfect matches in the memorization results.

In [1]:
import pandas as pd
import torch
import numpy as np
import plotly.express as px
from sae_lens import SAE
from transformer_lens import HookedTransformer
import random
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load model using TransformerLens
model = HookedTransformer.from_pretrained("meta-llama/Llama-3.1-8B", device=device, torch_dtype=torch.bfloat16)

# No need for separate tokenizer as it's included in the HookedTransformer

Using device: cuda


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 42.62it/s]


Loaded pretrained model meta-llama/Llama-3.1-8B into HookedTransformer


In [123]:
# Load the SAE
LAYER_NUMBER = 7
release = "llama_scope_lxr_8x"
sae_id = f"l{LAYER_NUMBER}r_8x"
sae = SAE.from_pretrained(release, sae_id)[0]
sae = sae.to(device)

### Intermediate Stuff

In [124]:
# Load the CSV file
df = pd.read_csv('memorization_results_L50_O50.csv')

# Split into perfect and non-perfect matches
perfect_matches = df[df['perfect_match'] == True]
non_perfect_matches = df[(df['perfect_match'] == False) & (df['matching_tokens'] == 0)]

# Sample 10 from each
random.seed(42)  # For reproducibility
perfect_sample = perfect_matches.sample(n=10)
non_perfect_sample = non_perfect_matches.sample(n=10)

print(f"Total perfect matches: {len(perfect_matches)}")
print(f"Total non-perfect matches: {len(non_perfect_matches)}")

Total perfect matches: 26
Total non-perfect matches: 83


In [125]:
def get_activations(text_input, text_output):
    """Get SAE activations for a given input-output pair using TransformerLens"""
    # Concatenate input and output
    full_text = text_input + text_output
    
    # Tokenize and get model activations using run_with_cache
    tokens = model.to_tokens(full_text, prepend_bos=False)
    _, cache = model.run_with_cache(tokens)
    
    # Get layer activations
    layer_acts = cache['resid_pre', LAYER_NUMBER]
    
    # Get SAE activations
    sae_acts = sae.encode(layer_acts)
    
    # Only keep activations for output tokens (last 50)
    output_acts = sae_acts[:, -50:, :]
    
    return output_acts.squeeze(0)  # Remove batch dimension

In [126]:
# Get activations for perfect matches
perfect_activations = []
print("Processing perfect matches...")
for _, row in tqdm(perfect_sample.iterrows()):
    acts = get_activations(row['input_text'], row['generated_continuation_text'])
    perfect_activations.append(acts)

# Get activations for non-perfect matches
non_perfect_activations = []
print("\nProcessing non-perfect matches...")
for _, row in tqdm(non_perfect_sample.iterrows()):
    acts = get_activations(row['input_text'], row['generated_continuation_text'])
    non_perfect_activations.append(acts)

# Stack all activations
perfect_acts_stacked = torch.cat(perfect_activations, dim=0)
non_perfect_acts_stacked = torch.cat(non_perfect_activations, dim=0)

Processing perfect matches...


10it [00:00, 10.36it/s]



Processing non-perfect matches...


10it [00:00, 12.83it/s]


In [127]:
# Calculate statistics
def compute_stats(activations):
    """Compute various statistics for the activations"""
    # L0 (number of active features)
    l0 = (activations > 0).float().sum(-1).mean().item()
    
    # Mean activation when active
    mean_active = activations[activations > 0].mean().item()
    
    # Max activation
    max_act = activations.max().item()
    
    # Feature sparsity (fraction of features that never activate)
    feature_sparsity = ((activations > 0).sum(0) == 0).float().mean().item()
    
    return {
        'L0 (avg active features)': l0,
        'Mean activation when active': mean_active,
        'Max activation': max_act,
        'Feature sparsity': feature_sparsity
    }

perfect_stats = compute_stats(perfect_acts_stacked)
non_perfect_stats = compute_stats(non_perfect_acts_stacked)

# Print statistics
print("Statistics for perfect matches:")
for k, v in perfect_stats.items():
    print(f"{k}: {v:.4f}")

print("\nStatistics for non-perfect matches:")
for k, v in non_perfect_stats.items():
    print(f"{k}: {v:.4f}")

Statistics for perfect matches:
L0 (avg active features): 18.1860
Mean activation when active: 3.8438
Max activation: 22.2500
Feature sparsity: 0.9162

Statistics for non-perfect matches:
L0 (avg active features): 17.7680
Mean activation when active: 3.7812
Max activation: 22.6250
Feature sparsity: 0.9233


In [128]:
# Plot L0 histograms
l0_perfect = (perfect_acts_stacked > 0).float().sum(-1).cpu().numpy()
l0_non_perfect = (non_perfect_acts_stacked > 0).float().sum(-1).cpu().numpy()

# Create a DataFrame for plotting
l0_data = pd.DataFrame({
    'L0': np.concatenate([l0_perfect, l0_non_perfect]),
    'Type': ['Perfect Match'] * len(l0_perfect) + ['Non-Perfect Match'] * len(l0_non_perfect)
})

# Plot histogram
fig = px.histogram(l0_data, x='L0', color='Type', barmode='overlay',
                  title='Distribution of Active Features (L0)',
                  labels={'L0': 'Number of Active Features', 'count': 'Frequency'})
fig.show()

In [129]:
# Plot activation magnitude distributions
act_perfect = perfect_acts_stacked[perfect_acts_stacked > 0].float().detach().cpu().numpy()
act_non_perfect = non_perfect_acts_stacked[non_perfect_acts_stacked > 0].float().detach().cpu().numpy()

# Create a DataFrame for plotting
act_data = pd.DataFrame({
    'Activation': np.concatenate([act_perfect, act_non_perfect]),
    'Type': ['Perfect Match'] * len(act_perfect) + ['Non-Perfect Match'] * len(act_non_perfect)
})

# Plot histogram
fig = px.histogram(act_data, x='Activation', color='Type', barmode='overlay',
                  title='Distribution of Activation Magnitudes',
                  labels={'Activation': 'Activation Value', 'count': 'Frequency'})
fig.show()

In [130]:
num_top_features = 10

# Analyze feature usage patterns
def get_top_features(activations, n=num_top_features):
    """Get the most frequently activated features"""
    feature_counts = (activations > 0).float().sum((0))
    top_features = torch.topk(feature_counts, n)
    return top_features.indices.cpu().numpy(), top_features.values.cpu().numpy()

# Get top features for both types
perfect_top_idx, perfect_top_counts = get_top_features(perfect_acts_stacked)
non_perfect_top_idx, non_perfect_top_counts = get_top_features(non_perfect_acts_stacked)

print(f"Top {num_top_features} most active features for perfect matches:")
for idx, count in zip(perfect_top_idx, perfect_top_counts):
    print(f"Feature {idx}: activated {count:.0f} times")

print(f"\nTop {num_top_features} most active features for non-perfect matches:")
for idx, count in zip(non_perfect_top_idx, non_perfect_top_counts):
    print(f"Feature {idx}: activated {count:.0f} times")

# Check overlap in top features
overlap = set(perfect_top_idx) & set(non_perfect_top_idx)
print(f"\nNumber of overlapping features in top {num_top_features}: {len(overlap)}")

Top 10 most active features for perfect matches:
Feature 27514: activated 455 times
Feature 27281: activated 309 times
Feature 24580: activated 79 times
Feature 28968: activated 73 times
Feature 6288: activated 73 times
Feature 7582: activated 67 times
Feature 10554: activated 61 times
Feature 12146: activated 52 times
Feature 3893: activated 51 times
Feature 17423: activated 48 times

Top 10 most active features for non-perfect matches:
Feature 27514: activated 432 times
Feature 27281: activated 404 times
Feature 24580: activated 164 times
Feature 7582: activated 74 times
Feature 12146: activated 67 times
Feature 6288: activated 67 times
Feature 15432: activated 53 times
Feature 32356: activated 50 times
Feature 6411: activated 45 times
Feature 10514: activated 44 times

Number of overlapping features in top 10: 6


### What I Care About

In [131]:
# Create a set of all features that were activated in either type
overlap_features = set(perfect_top_idx) | set(non_perfect_top_idx)

# For each feature, see how many times it was activated in each type
feature_counts = (perfect_acts_stacked > 0).float().sum((0)).int().cpu().numpy()
feature_counts_non_perfect = (non_perfect_acts_stacked > 0).float().sum((0)).int().cpu().numpy()

# Create a table to show all the features (even some that we only activated by one type)
feature_counts_table = pd.DataFrame({
    'Feature': list(overlap_features),
    'Perfect': feature_counts[list(overlap_features)],
    'Non-Perfect': feature_counts_non_perfect[list(overlap_features)]
})

# Sort the table by the number of activations in each type
feature_counts_table = feature_counts_table.sort_values(by=['Perfect', 'Non-Perfect'], ascending=False)

# Print the table
print(feature_counts_table)


    Feature  Perfect  Non-Perfect
13    27514      455          432
5     27281      309          404
0     24580       79          164
4      6288       73           67
9     28968       73           23
7      7582       67           74
12    10554       61            2
10    12146       52           67
11     3893       51           21
3     17423       48            0
8     32356       34           50
1     15432       19           53
6     10514        7           44
2      6411        6           45
