# Match SAE Features Through Training

This notebook matches up the features of several sparse autoencoders (SAEs) that were trained on the same model and dataset, but each on a different model checkpoint from throughout training.

In [1]:
import numpy as np
import wandb
import torch
import plotly.express as px
import pandas as pd
import os
import pickle as pkl

from src.feature_matching import get_full_cfg, get_sae, get_model, get_dataset, get_activations, match_features
from scripts.train_sparse_autoencoders import get_checkpoint_indices
from transformer_lens.loading_from_pretrained import PYTHIA_CHECKPOINTS

# Autograd is switched off globally: nothing below needs gradients.
torch.set_grad_enabled(False)

# Pick the compute device: Apple-silicon MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Collect metadata for every W&B run in the sweep identified by RUN_TAG.
WANDB_PROJECT = "roganinglis/pythia-sae"  # <entity/project-name>
RUN_TAG = 'pythia-70m-anthropic-l1-3'
# Directory that run.config["checkpoint_path"] is relative to. Override with the
# SAE_CHECKPOINT_ROOT environment variable when running on another machine.
SAE_CHECKPOINT_ROOT = os.environ.get(
    'SAE_CHECKPOINT_ROOT',
    '/home/rogan/git_repos/sparse-feature-circuit-development/scripts',
)
FINAL_CHECKPOINT_NAME = 'final_400003072'

api = wandb.Api()
runs = api.runs(WANDB_PROJECT)

checkpoint_data = []
for run in runs:
    if RUN_TAG not in run.tags:
        continue

    checkpoint_data.append({
        'path': f'{SAE_CHECKPOINT_ROOT}/{run.config["checkpoint_path"]}/{FINAL_CHECKPOINT_NAME}',
        'original_checkpoint_index': run.config['model_from_pretrained_kwargs']['checkpoint_index'],
        'summary': run.summary._json_dict,
        # Drop W&B-internal keys (prefixed with '_').
        'config': {k: v for k, v in run.config.items() if not k.startswith('_')},
        'name': run.name,
    })

# Latest model checkpoint first: checkpoint_data[0] is the reference SAE that every
# other SAE is matched against below.
checkpoint_data = sorted(checkpoint_data, key=lambda x: -x['original_checkpoint_index'])

In [3]:
# Settings for activation collection (passed to get_activations below).
batch_size, total_tokens, sparsity_threshold = 32, 640_000, 1e-6

Compute activations for SAEs

In [4]:
# Reload activations and dead-feature indices cached by a previous run, if present.
# NOTE: the cache is keyed only by checkpoint path. Delete activation_data.pkl after changing
# batch_size / total_tokens / sparsity_threshold, otherwise stale activations are reused.
# pickle.load can execute arbitrary code, so only load a cache this notebook wrote itself.
activation_data = {}
if os.path.isfile('activation_data.pkl'):
    with open('activation_data.pkl', 'rb') as f:
        activation_data = {
            path: {'activations': entry['activations'], 'dead_indices': entry['dead_indices']}
            for path, entry in pkl.load(f).items()
        }

In [5]:
# Compute SAE activations for every checkpoint, or reuse the cached ones.
# Cached entries only hold 'activations' and 'dead_indices', so the config, SAE and model
# are always reloaded and attached to the entry.
for checkpoint_dict in checkpoint_data:
    checkpoint_path = checkpoint_dict['path']
    full_cfg = get_full_cfg(checkpoint_path)
    sae = get_sae(checkpoint_path, full_cfg)
    model = get_model(sae, full_cfg)

    if checkpoint_path in activation_data:
        activation_data[checkpoint_path]['cfg'] = full_cfg
        activation_data[checkpoint_path]['sae'] = sae
        activation_data[checkpoint_path]['model'] = model
        continue

    dataset = get_dataset(full_cfg)
    activations = get_activations(
        sae,
        model,
        dataset,
        batch_size=batch_size,
        total_tokens=total_tokens,
        sparsity_threshold=sparsity_threshold
    )

    # Features whose summed activation over the sample is exactly zero never fired
    # (assumes activations are non-negative, e.g. post-ReLU; TODO confirm).
    # flatten() rather than squeeze(): with exactly one dead feature, squeeze() gives a
    # 0-d tensor, and len() on that fails in the plotting cells.
    dead_indices = torch.argwhere(activations.sum(0).to_dense() == 0).flatten()

    activation_data[checkpoint_path] = {
        'activations': activations,
        'dead_indices': dead_indices,
        'cfg': full_cfg,
        'sae': sae,
        'model': model
    }

  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m-deduped into HookedTransformer


In [6]:
# Write the expensive-to-compute tensors to disk so a re-run can skip activation collection.
# cfg / sae / model are not cached; the compute cell reloads them every time.
if not os.path.isfile('activation_data.pkl'):
    # Build the payload before opening the file so a failure does not leave a broken cache behind.
    cache_payload = {
        path: {'activations': entry['activations'], 'dead_indices': entry['dead_indices']}
        for path, entry in activation_data.items()
    }
    with open('activation_data.pkl', 'wb') as f:
        pkl.dump(cache_payload, f)

    del cache_payload

Match features between the SAEs

In [7]:
# Match each earlier checkpoint's SAE features against the reference SAE, checkpoint_data[0]
# (the latest checkpoint), and record which of its features were / were not matched.
for checkpoint_dict in checkpoint_data[1:]:
    entry = activation_data[checkpoint_dict['path']]
    indices_orig, indices_matched, matched_distances, dead_feature_indices = match_features(
        activation_data[checkpoint_data[0]['path']]['activations'],
        entry['activations']
    )

    # Record the matched indices and distances
    entry['matched_indices'] = indices_matched
    entry['matched_distances'] = matched_distances
    entry['dead_indices'] = dead_feature_indices
    unmatched_indices_mask = torch.ones(entry['activations'].shape[1], dtype=torch.bool)
    unmatched_indices_mask[indices_matched] = False
    # flatten() rather than squeeze(): a single unmatched feature must still be a 1-D tensor,
    # because the plotting cell calls len() on it.
    entry['unmatched_indices'] = torch.argwhere(unmatched_indices_mask).flatten()

  similarity = activations_orig.T @ activations_matched


Extract some data for plotting

In [8]:
# Per-checkpoint summary statistics, ordered earliest -> latest checkpoint.
checkpoint_indices = []
num_matched_features = []
num_unmatched_features = []
num_dead_features = []
cosine_similarities = []
num_features = activation_data[checkpoint_data[0]['path']]['activations'].shape[1]
distance_threshold = 0.5  # a match counts only if its distance is below this
distance_thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]
num_matched_features_thresholds = {x: [] for x in distance_thresholds}


def _num_indices(indices):
    """Return how many indices `indices` holds; a 0-d tensor/array counts as one index."""
    numel = getattr(indices, 'numel', None)
    if callable(numel):  # torch.Tensor
        return int(numel())
    return int(np.asarray(indices).size)


for checkpoint_dict in checkpoint_data[::-1]:
    activation_dict = activation_data[checkpoint_dict['path']]
    checkpoint_indices.append(checkpoint_dict['original_checkpoint_index'])
    if 'matched_indices' not in activation_dict:
        # Reference checkpoint: every feature trivially matches itself.
        num_matched_features.append(num_features)
        num_unmatched_features.append(0)
        num_dead_features.append(_num_indices(activation_dict['dead_indices']))
        cosine_similarities.append(torch.ones(num_features))
        for dist_thresh in distance_thresholds:
            num_matched_features_thresholds[dist_thresh].append(num_features)
        continue

    matched_indices = activation_dict['matched_indices']
    matched_distances = activation_dict['matched_distances']

    # Counts for the "fraction matched" plot. The loop uses its own mask variable so it
    # does not overwrite the main `threshold_mask` below (previously the unmatched count
    # silently used the last loop threshold, 0.9, instead of distance_threshold).
    for dist_thresh in distance_thresholds:
        per_threshold_mask = matched_distances < dist_thresh
        num_matched_features_thresholds[dist_thresh].append(len(matched_indices[per_threshold_mask]))

    threshold_mask = matched_distances < distance_threshold
    num_matched_features.append(len(matched_indices[threshold_mask]))
    # Unmatched = never matched at all + matched but too far away.
    num_unmatched_features.append(
        _num_indices(activation_dict['unmatched_indices']) + len(matched_indices[~threshold_mask])
    )
    num_dead_features.append(_num_indices(activation_dict['dead_indices']))
    cosine_similarities.append(1 - matched_distances)

In [13]:
# Heatmap of matched-feature activation cosine similarities through training:
# rows are similarity bins on [0, 1] (values outside that range are not counted),
# columns are checkpoints from earliest to latest.
import plotly.graph_objects as go

cosine_similarities_binned = [np.histogram(x, bins=np.linspace(0, 1, 50)) for x in cosine_similarities]
bin_edges = cosine_similarities_binned[0][1]
histogram_data = np.array([counts for counts, _ in cosine_similarities_binned])
checkpoint_steps = [PYTHIA_CHECKPOINTS[i] for i in checkpoint_indices]

# The last column is the reference checkpoint (similarity 1 for every feature by
# construction), so it is left out of the heatmap.
fig = go.Figure()
fig.add_heatmap(
    z=histogram_data.T[:, :-1],
    x=checkpoint_steps[:-1],
    y=bin_edges[:-1],
    reversescale=True,
    colorbar=dict(title='Number of Features'),
)
fig.update_layout(title='Activation Cosine Similarity Histograms Through Training')
fig.update_xaxes(title_text='Training Steps')
fig.update_yaxes(title_text='Cosine Similarity')

# To save the figure, uncomment:
# fig.write_html('../plots/activation_similarity_histograms.html')
# fig.write_image('../plots/activation_similarity_histograms.png')
fig

Plot histograms of feature (decoder-direction) cosine similarities

In [10]:
import plotly.graph_objects as go

# Decoder-direction similarity between each reference (latest-checkpoint) feature and the
# feature it was matched to at each checkpoint, ordered earliest -> latest.
feature_similarities = []
final_w_dec = activation_data[checkpoint_data[0]['path']]['sae'].W_dec
for checkpoint_dict in checkpoint_data[::-1]:
    activation_dict = activation_data[checkpoint_dict['path']]
    if 'matched_distances' not in activation_dict:
        # Reference checkpoint: each feature is compared with itself.
        activation_dict['matched_feature_similarities'] = np.ones(activation_dict['activations'].shape[1])
    else:
        matched_w_dec = activation_dict['sae'].W_dec[activation_dict['matched_indices']]
        # Normalise explicitly so this is a true cosine similarity even if decoder rows are
        # not unit-norm (identical to the raw row dot product when they are).
        activation_dict['matched_feature_similarities'] = (
            torch.nn.functional.cosine_similarity(final_w_dec, matched_w_dec, dim=1)
            .detach().cpu().numpy()
        )

    feature_similarities.append(activation_dict['matched_feature_similarities'])

feature_similarities_binned = [np.histogram(x, bins=np.linspace(0, 1, 50)) for x in feature_similarities]
bin_edges = feature_similarities_binned[0][1]
histogram_data = np.array([counts for counts, _ in feature_similarities_binned])

# Plot log(1 + count) so sparsely populated bins stay visible. The last column (reference
# checkpoint) is dropped; tick ranges are computed on the plotted data only.
z_data_log = np.log1p(histogram_data.T)[:, :-1]
z_min_log = z_data_log.min()
z_max_log = z_data_log.max()

# Colorbar ticks: positions in log space, labels in original counts. Positions must be the
# exact (float) values the labels were derived from; truncating them to int misplaced labels.
num_ticks = 10
tick_vals = np.linspace(z_min_log, z_max_log, num_ticks)
tick_vals_original = np.rint(np.expm1(tick_vals)).astype(int)

fig = go.Figure()
fig.add_heatmap(
    z=z_data_log,
    x=checkpoint_steps[:-1],
    y=bin_edges[:-1],
    colorbar=dict(
        title='Number of Features',
        tickvals=tick_vals.tolist(),
        ticktext=tick_vals_original.tolist(),
    ),
)
fig.update_layout(
    title='Feature Cosine Similarity Histograms Through Training'
)
fig.update_xaxes(title_text='Training Steps')
fig.update_yaxes(title_text='Cosine Similarity')

os.makedirs('../plots', exist_ok=True)
fig.write_html('../plots/feature_similarity_histograms.html')
fig.write_image('../plots/feature_similarity_histograms.png')
fig

In [18]:
# Fraction of reference features matched at each checkpoint, for several distance
# thresholds. The last entry (the reference checkpoint itself) is excluded.
df = pd.DataFrame({
    'checkpoint_index': checkpoint_indices[:-1],
    'checkpoint_steps': checkpoint_steps[:-1],
    'num_matched_features': [x/num_features for x in num_matched_features[:-1]],
    'num_unmatched_features': [x/num_features for x in num_unmatched_features[:-1]],
    'num_dead_features': [x/num_features for x in num_dead_features[:-1]]
})
threshold_columns = []
for dist_thresh, matched_counts in num_matched_features_thresholds.items():
    column = f'{dist_thresh}'
    df[column] = [x / num_features for x in matched_counts[:-1]]
    threshold_columns.append(column)

fig = px.line(
    df,
    x='checkpoint_steps',
    y=threshold_columns,
    title='Fraction of Matched Features',
    range_y=[0, 1],
    labels={
        'checkpoint_steps': 'Training Steps',  # Custom X-axis label
        'value': 'Fraction of Matched Features',  # Custom Y-axis label for all lines
        'variable': 'Distance Threshold'  # Custom label for the legend
    }
)

os.makedirs('../plots', exist_ok=True)
fig.write_html('../plots/fraction_matched_features.html')
fig.write_image('../plots/fraction_matched_features.png')
fig

In [22]:

# (training step, checkpoint index) pairs, earliest first.
print(list(zip(checkpoint_steps, checkpoint_indices)))


[(0, 0), (2, 2), (8, 4), (32, 6), (128, 8), (512, 10), (2000, 12), (4000, 14), (6000, 16), (8000, 18), (10000, 20), (16000, 26), (32000, 42), (48000, 58), (64000, 74), (79000, 89), (95000, 105), (111000, 121), (127000, 137), (143000, 153)]


In [71]:
# SAE training losses at each checkpoint.
# checkpoint_indices / checkpoint_steps run earliest -> latest (built from checkpoint_data[::-1]),
# whereas checkpoint_data is sorted latest -> earliest. The loss columns must therefore be
# read in the same reversed order, or the losses are plotted against the wrong steps.
checkpoints_ascending = checkpoint_data[::-1]
assert [c['original_checkpoint_index'] for c in checkpoints_ascending] == list(checkpoint_indices)

df = pd.DataFrame({
    'checkpoint_index': checkpoint_indices,
    'checkpoint_steps': checkpoint_steps,
    'num_matched_features': [x/num_features for x in num_matched_features],
    'num_unmatched_features': [x/num_features for x in num_unmatched_features],
    'num_dead_features': [x/num_features for x in num_dead_features],
    'Overall Loss': [c['summary']['losses/overall_loss'] for c in checkpoints_ascending],
    'L1 Loss': [c['summary']['losses/l1_loss'] for c in checkpoints_ascending],
    'MSE Loss': [c['summary']['losses/mse_loss'] for c in checkpoints_ascending]
})
# Long format so all three losses share one plot
df_long = pd.melt(df, id_vars='checkpoint_steps', value_vars=['Overall Loss', 'L1 Loss', 'MSE Loss'])
fig = px.line(
    df_long,
    x='checkpoint_steps',
    y='value',
    color='variable',
    title='Losses',
    labels={
        'checkpoint_steps': 'Training Steps',  # Custom X-axis label
        'value': 'Loss',  # Custom Y-axis label for all lines
        'variable': ''  # Custom label for the legend
    }
)

os.makedirs('../plots', exist_ok=True)
fig.write_html('../plots/losses.html')
fig.write_image('../plots/losses.png')
fig

In [14]:
# Matched distance of every reference feature (rows) at every checkpoint (columns,
# earliest -> latest). The reference checkpoint has no 'matched_distances' and gets 0 everywhere.
ordered_checkpoints = checkpoint_data[::-1]
checkpoints = [c['original_checkpoint_index'] for c in ordered_checkpoints]

distance_columns = []
for c in ordered_checkpoints:
    entry = activation_data[c['path']]
    if 'matched_distances' in entry:
        distance_columns.append(entry['matched_distances'])
    else:
        distance_columns.append(np.zeros(entry['activations'].shape[1]))

distances = np.stack(distance_columns).T
print(distances.shape)
print(len(checkpoints))

# Order features by total distance across training (most consistently matched first).
sorted_indices = np.argsort(distances.sum(1))
print(sorted_indices.shape)

px.imshow(distances[sorted_indices, :], x=checkpoints, aspect='auto', title='Matched Feature Distances')

(8192, 20)
20
(8192,)


## Find features which appear at certain points throughout training
To do this, we look for features that are absent before a given checkpoint and present from that checkpoint onwards, excluding features that appear and then disappear again. A feature counts as present at a checkpoint when its matched distance to the corresponding feature of the reference (latest-checkpoint) SAE is below a threshold. We split the distance matrix at the checkpoint of interest and threshold both halves. A feature qualifies if none of its entries on the left (earlier checkpoints) are below the threshold and all of its entries on the right (that checkpoint and later) are.

In [None]:
def get_appearing_features(dist, ckpt_ind, appearing_checkpoint_index, dist_threshold, num_features_sample, rng=None):
    """Sample features that first appear at a given checkpoint and stay present afterwards.

    A feature is "present" at a checkpoint when its matched distance is below
    ``dist_threshold``. A feature qualifies if it is absent at every checkpoint before
    ``appearing_checkpoint_index`` and present at that checkpoint and every later one.

    Args:
        dist: array of shape (num_features, num_checkpoints) of matched distances, with
            columns in the same order as ``ckpt_ind``.
        ckpt_ind: sequence of checkpoint indices labelling the columns of ``dist``.
        appearing_checkpoint_index: checkpoint index (a value in ``ckpt_ind``) to split at.
        dist_threshold: distances strictly below this count as "present".
        num_features_sample: maximum number of qualifying features to return. If fewer
            qualify, all of them are returned.
        rng: optional ``numpy.random.Generator`` / ``RandomState``. Defaults to the global
            ``numpy.random`` state.

    Returns:
        1-D integer array of sampled feature (row) indices, without repeats.

    Raises:
        ValueError: if ``appearing_checkpoint_index`` is not in ``ckpt_ind``.
    """
    sampler = np.random if rng is None else rng
    is_present = np.asarray(dist) < dist_threshold

    # Use the column labels passed in, not the notebook-global checkpoint list.
    appearing_index = list(ckpt_ind).index(appearing_checkpoint_index)
    absent_before = ~is_present[:, :appearing_index].any(axis=1)
    present_after = is_present[:, appearing_index:].all(axis=1)

    # flatnonzero always returns 1-D; squeeze() gave a 0-d array for a single hit, which
    # np.random.choice would misread as a population size.
    appearing_feature_indices = np.flatnonzero(absent_before & present_after)

    num_to_sample = min(num_features_sample, len(appearing_feature_indices))
    if num_to_sample == 0:
        return appearing_feature_indices[:0]
    return sampler.choice(appearing_feature_indices, num_to_sample, replace=False)

In [74]:
threshold = 0.5
num_features_sample = 10
SAMPLE_SEED = 0  # seeds np.random so the sampled features are reproducible across runs
print(checkpoint_indices)
early_checkpoint_index = 8
middle_checkpoint_index = 74
late_checkpoint_index = 121
print(f'Early step: {PYTHIA_CHECKPOINTS[early_checkpoint_index]}')
print(f'Middle step: {PYTHIA_CHECKPOINTS[middle_checkpoint_index]}')
print(f'Late step: {PYTHIA_CHECKPOINTS[late_checkpoint_index]}')

np.random.seed(SAMPLE_SEED)
early_feature_indices = get_appearing_features(distances, checkpoint_indices, early_checkpoint_index, threshold, num_features_sample)
middle_feature_indices = get_appearing_features(distances, checkpoint_indices, middle_checkpoint_index, threshold, num_features_sample)
late_feature_indices = get_appearing_features(distances, checkpoint_indices, late_checkpoint_index, threshold, num_features_sample)

print(early_feature_indices)
print(middle_feature_indices)
print(late_feature_indices)
# Presence (distance below threshold) of each sampled middle-appearing feature at each checkpoint.
# The title previously said "Early" although the middle-appearing features are plotted.
px.imshow(distances[middle_feature_indices, :] < threshold, x=checkpoints, aspect='auto', title='Middle Appearing Features')

[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 26, 42, 58, 74, 89, 105, 121, 137, 153]
Early step: 128
Middle step: 64000
Late step: 111000


NameError: name 'get_appearing_features' is not defined

## Visualise the features throughout training

In [None]:
from scripts.feature_visualisation import visualise_features_for_sae_path
from shutil import copytree, rmtree
import os
from safetensors import safe_open
from safetensors.torch import save_file


def update_sae_feature_order(sae_path_src, sae_path_dst, feature_indices):
    """Copy an SAE checkpoint directory and permute the copy's features.

    ``sae_path_src`` is copied to ``sae_path_dst`` with ``shutil.copytree``, which raises if
    the destination already exists. Then ``sae_weights.safetensors`` in the copy is rewritten
    so that new feature ``i`` is original feature ``feature_indices[i]``:
    rows of ``W_dec``, columns of ``W_enc`` and entries of ``b_enc`` are gathered by
    ``feature_indices``. Every other tensor in the file is written back unchanged.
    """
    copytree(sae_path_src, sae_path_dst)
    weights_file = os.path.join(sae_path_dst, 'sae_weights.safetensors')

    with safe_open(weights_file, framework='pt', device='cpu') as reader:
        tensors = {name: reader.get_tensor(name) for name in reader.keys()}

    # Which axis carries the feature dimension, per tensor (as indexed below).
    gatherers = {
        'W_dec': lambda t: t[feature_indices],
        'W_enc': lambda t: t[:, feature_indices],
        'b_enc': lambda t: t[feature_indices],
    }
    for name, gather in gatherers.items():
        tensors[name] = gather(tensors[name])

    save_file(tensors, weights_file)
    

def save_visualisations_for_features(feature_indices, out_dir):
    """Visualise ``feature_indices`` for the SAE of every checkpoint in ``checkpoint_data``.

    For checkpoints that were matched against the reference SAE (their entry has
    'matched_indices'), a copy of the SAE is first written to '<path>_updated_feature_order'
    with its features reordered by those matched indices. The same index then refers to
    the matched feature at every checkpoint. Any existing copy at that location is deleted first.
    """
    for checkpoint_dict in checkpoint_data:
        source_path = checkpoint_dict['path']
        entry = activation_data[source_path]
        if 'matched_indices' not in entry:
            # Reference checkpoint: already in the reference feature order.
            visualise_features_for_sae_path(source_path, feature_indices, out_dir=out_dir)
            continue

        base_path = source_path[:-1] if source_path.endswith('/') else source_path
        reordered_path = f'{base_path}_updated_feature_order'
        if os.path.isdir(reordered_path):
            rmtree(reordered_path)
        update_sae_feature_order(source_path, reordered_path, entry['matched_indices'])
        visualise_features_for_sae_path(reordered_path, feature_indices, out_dir=out_dir)
        
# Save visualisations for the early-, middle- and late-appearing feature samples, in that order.
for sampled_indices, visualisation_dir in (
    (early_feature_indices, f'features_during_training/early_appearing_features_{early_checkpoint_index}'),
    (middle_feature_indices, f'features_during_training/middle_appearing_features_{middle_checkpoint_index}'),
    (late_feature_indices, f'features_during_training/late_appearing_features_{late_checkpoint_index}'),
):
    save_visualisations_for_features(sampled_indices, visualisation_dir)