# Aggregate SAE Features to Compare Memorized vs NonMemorized Sequences

In [1]:
import pandas as pd
import torch
import numpy as np
import plotly.express as px
from sae_lens import SAE
from transformer_lens import HookedTransformer
import random
from tqdm import tqdm

random_seed = 42
random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load model using TransformerLens
model = HookedTransformer.from_pretrained("meta-llama/Llama-3.1-8B", device=device, torch_dtype=torch.bfloat16)

# No need for separate tokenizer as it's included in the HookedTransformer

Using device: cuda


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 119.97it/s]


Loaded pretrained model meta-llama/Llama-3.1-8B into HookedTransformer


In [3]:
# Load the SAEs for all layers
sae_layers = range(28, 32)
release = "llama_scope_lxr_8x"
saes = []
print("Loading SAEs for all layers...")
for layer in tqdm(sae_layers):
    sae_id = f"l{layer}r_8x"
    sae = SAE.from_pretrained(release, sae_id)[0]
    sae = sae.to(device)
    saes.append(sae)
print("Finished loading SAEs")

Loading SAEs for all layers...


100%|██████████| 4/4 [00:52<00:00, 13.13s/it]

Finished loading SAEs





### Load Memorization Data

In [4]:
# Load the CSV file
df = pd.read_csv('memorization_results_L50_O50.csv')

# Split into perfect and non-perfect matches
perfect_matches = df[df['perfect_match'] == True]
non_perfect_matches = df[(df['perfect_match'] == False) & (df['matching_tokens'] == 0)]

# Sample from each
num_samples = 20
perfect_sample = perfect_matches.sample(n=num_samples, random_state=random_seed)
non_perfect_sample = non_perfect_matches.sample(n=num_samples, random_state=random_seed)

print(f"Total perfect matches: {len(perfect_matches)}")
print(f"Total non-perfect matches: {len(non_perfect_matches)}")

Total perfect matches: 26
Total non-perfect matches: 83


### Get SAE Activations

In [5]:
def get_activations(text_input, text_output):
    """Get SAE activations for a given input-output pair using TransformerLens"""
    # Concatenate input and output
    full_text = text_input + text_output
    
    # Tokenize and get model activations using run_with_cache
    tokens = model.to_tokens(full_text, prepend_bos=False)
    _, cache = model.run_with_cache(tokens)
    
    # Get layer activations for all layers
    all_layer_acts = {}
    for layer, sae in zip(sae_layers, saes):
        # Get layer activations
        layer_acts = cache['resid_pre', layer]
        
        # Get SAE activations
        sae_acts = sae.encode(layer_acts)
        
        # Only keep activations for output tokens (last 50)
        output_acts = sae_acts[:, -50:, :]
        
        all_layer_acts[layer] = output_acts.squeeze(0)  # Remove batch dimension
        
    return all_layer_acts

In [6]:
if 'perfect_acts_stacked' in locals():
    del perfect_acts_stacked
if 'non_perfect_acts_stacked' in locals():
    del non_perfect_acts_stacked

# Get activations for perfect matches
perfect_activations_by_layer = {layer: [] for layer in sae_layers}
print("Processing perfect matches...")
for _, row in tqdm(perfect_sample.iterrows()):
    acts_dict = get_activations(row['input_text'], row['generated_continuation_text'])
    for layer in sae_layers:
        perfect_activations_by_layer[layer].append(acts_dict[layer])

# Get activations for non-perfect matches
non_perfect_activations_by_layer = {layer: [] for layer in sae_layers}
print("\nProcessing non-perfect matches...")
for _, row in tqdm(non_perfect_sample.iterrows()):
    acts_dict = get_activations(row['input_text'], row['generated_continuation_text'])
    for layer in sae_layers:
        non_perfect_activations_by_layer[layer].append(acts_dict[layer])

# Stack all activations for each layer
perfect_acts_stacked = {
    layer: torch.stack(acts, dim=0) 
    for layer, acts in perfect_activations_by_layer.items()
}
non_perfect_acts_stacked = {
    layer: torch.stack(acts, dim=0)
    for layer, acts in non_perfect_activations_by_layer.items()
}

Processing perfect matches...


20it [00:02,  9.98it/s]



Processing non-perfect matches...


20it [00:01, 11.49it/s]


### Compute Activation Statistics

In [7]:
# # Calculate statistics
# def compute_stats(activations):
#     """Compute various statistics for the activations"""
#     # L0 (number of active features)
#     l0 = (activations > 0).float().sum(-1).mean().item()
    
#     # Mean activation when active
#     mean_active = activations[activations > 0].mean().item()
    
#     # Max activation
#     max_act = activations.max().item()
    
#     # Feature sparsity (fraction of features that never activate)
#     feature_sparsity = ((activations > 0).sum(0) == 0).float().mean().item()
    
#     return {
#         'L0 (avg active features)': l0,
#         'Mean activation when active': mean_active,
#         'Max activation': max_act,
#         'Feature sparsity': feature_sparsity
#     }

# perfect_stats = compute_stats(perfect_acts_stacked)
# non_perfect_stats = compute_stats(non_perfect_acts_stacked)

# # Print statistics
# print("Statistics for perfect matches:")
# for k, v in perfect_stats.items():
#     print(f"{k}: {v:.4f}")

# print("\nStatistics for non-perfect matches:")
# for k, v in non_perfect_stats.items():
#     print(f"{k}: {v:.4f}")

In [8]:
# # Plot L0 histograms
# l0_perfect = (perfect_acts_stacked > 0).float().sum(-1).cpu().numpy()
# l0_non_perfect = (non_perfect_acts_stacked > 0).float().sum(-1).cpu().numpy()

# # Create a DataFrame for plotting
# l0_data = pd.DataFrame({
#     'L0': np.concatenate([l0_perfect, l0_non_perfect]),
#     'Type': ['Perfect Match'] * len(l0_perfect) + ['Non-Perfect Match'] * len(l0_non_perfect)
# })

# # Plot histogram
# fig = px.histogram(l0_data, x='L0', color='Type', barmode='overlay',
#                   title='Distribution of Active Features (L0)',
#                   labels={'L0': 'Number of Active Features', 'count': 'Frequency'})
# fig.show()

In [9]:
# # Plot activation magnitude distributions
# act_perfect = perfect_acts_stacked[perfect_acts_stacked > 0].float().detach().cpu().numpy()
# act_non_perfect = non_perfect_acts_stacked[non_perfect_acts_stacked > 0].float().detach().cpu().numpy()

# # Create a DataFrame for plotting
# act_data = pd.DataFrame({
#     'Activation': np.concatenate([act_perfect, act_non_perfect]),
#     'Type': ['Perfect Match'] * len(act_perfect) + ['Non-Perfect Match'] * len(act_non_perfect)
# })

# # Plot histogram
# fig = px.histogram(act_data, x='Activation', color='Type', barmode='overlay',
#                   title='Distribution of Activation Magnitudes',
#                   labels={'Activation': 'Activation Value', 'count': 'Frequency'})
# fig.show()

### Compare SAE Features

In [10]:
num_top_features = 50

# Analyze feature usage patterns
def get_top_features(activations, n=num_top_features):
    """Get the most frequently activated features"""
    # Compress activations across all samples first
    compressed_acts = activations.reshape(-1, activations.shape[-1])
    feature_counts = (compressed_acts > 0).float().sum(0)
    top_features = torch.topk(feature_counts, n)
    return top_features.indices.cpu().numpy(), top_features.values.cpu().numpy()

# Since we're working with a dict now, we need to handle one layer at a time
aggregate_feature_comparisons_tables = {}
for layer in sae_layers:
    print(f"\nAnalyzing layer {layer}:")
    
    # Get the activation tensors for this layer
    perfect_acts = perfect_acts_stacked[layer]
    non_perfect_acts = non_perfect_acts_stacked[layer]

    # Get top features for both types
    perfect_top_idx, perfect_top_counts = get_top_features(perfect_acts)
    non_perfect_top_idx, non_perfect_top_counts = get_top_features(non_perfect_acts)

    # Create a set of all features that were activated in either type
    overlap_features = set(perfect_top_idx.tolist()) | set(non_perfect_top_idx.tolist())

    # For each feature, see how many times it was activated in each type
    perfect_acts_flat = perfect_acts.reshape(-1, perfect_acts.shape[-1])
    non_perfect_acts_flat = non_perfect_acts.reshape(-1, non_perfect_acts.shape[-1])
    
    feature_counts = (perfect_acts_flat > 0).float().sum(0).int().cpu().numpy()
    feature_counts_non_perfect = (non_perfect_acts_flat > 0).float().sum(0).int().cpu().numpy()
    
    # Create a table to show all the features (even some that we only activated by one type)
    feature_counts_table = pd.DataFrame({
        'Feature': list(overlap_features),
        'Perfect': feature_counts[list(overlap_features)],
        'Non-Perfect': feature_counts_non_perfect[list(overlap_features)]
    })
    
    # Add difference columns
    feature_counts_table['Abs_Difference'] = abs(feature_counts_table['Perfect'] - feature_counts_table['Non-Perfect'])
    feature_counts_table['Percent_Difference'] = (feature_counts_table['Abs_Difference'] / 
                                                feature_counts_table[['Perfect', 'Non-Perfect']].max(axis=1) * 100)
    
    # Filter for absolute difference >= 100
    feature_counts_table = feature_counts_table[feature_counts_table['Abs_Difference'] >= 100]

    # Sort the table by the number of activations in each type
    feature_counts_table = feature_counts_table.sort_values(by=['Perfect', 'Non-Perfect'], ascending=False)

    # Print the table
    print(feature_counts_table)
    aggregate_feature_comparisons_tables[layer] = feature_counts_table


Analyzing layer 28:
    Feature  Perfect  Non-Perfect  Abs_Difference  Percent_Difference
64     9966      511          795             284           35.723270
35     4245      445          103             342           76.853933
12    21044      432          323             109           25.231481
29    24190      123            0             123          100.000000
33    24462       63          290             227           78.275862
58    11993       46          172             126           73.255814
14    32063       11          192             181           94.270833

Analyzing layer 29:
    Feature  Perfect  Non-Perfect  Abs_Difference  Percent_Difference
73    22526      485          820             335           40.853659
21    16698      465          624             159           25.480769
54    14008      277           74             203           73.285199
24     8008      224          359             135           37.604457
43     4238      189          295             10

In [11]:
# Create directory for feature comparison tables if it doesn't exist
import os
output_dir = "aggregate_feature_comparisons"
os.makedirs(output_dir, exist_ok=True)

for layer, feature_counts_table in aggregate_feature_comparisons_tables.items():
    # Save the feature counts table for this layer
    output_path = os.path.join(output_dir, f"feature_counts_layer_{layer}.csv")
    feature_counts_table.to_csv(output_path, index=False)
    print(f"Saved feature counts for layer {layer} to {output_path}")

Saved feature counts for layer 28 to aggregate_feature_comparisons/feature_counts_layer_28.csv
Saved feature counts for layer 29 to aggregate_feature_comparisons/feature_counts_layer_29.csv
Saved feature counts for layer 30 to aggregate_feature_comparisons/feature_counts_layer_30.csv
Saved feature counts for layer 31 to aggregate_feature_comparisons/feature_counts_layer_31.csv
