The following cell displays a dataframe of the SAE Bench releases and their SAEs.

In [None]:
import pandas as pd
import os
import re
import json
from formatting_utils import make_available_sae_df

overview_df = make_available_sae_df(for_printing=True)

# pandas display options
max_hook_point_length = overview_df['unique_hook_points'].astype(str).map(len).max()
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', int(max_hook_point_length))

# print selected columns
show_cols = ['release', 'model', 'unique_hook_points', 'n_saes_per_hook', 'has_training_checkpoints', 'saes_map']
overview_df[show_cols]

Each row is a "release", which contains multiple SAEs that may have different configs or match different hook points in a model. There are 8 SAE Bench releases: 4 for Pythia and 4 for Gemma-2-2B.

Each release contains a `saes_map`, a dict of `sae_id: sae_name`. The `sae_ids` are specific to SAELens and are used to load the SAEs into SAELens.

In this project, we use the `sae_names`, rather than the `sae_ids`, as keys in our results dictionaries. The names are unique, so there is no possibility of mixing up data from different SAEs.

In [None]:
sae_release = "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824"
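# Select the row for this release; .item() raises unless exactly one row matches
mask = overview_df['release'] == sae_release
sae_id_to_name_map = overview_df[mask].saes_map.item()
# Invert the map so that we can look up an SAE id from its name
sae_name_to_id_map = {v: k for k, v in sae_id_to_name_map.items()}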

print(f"First sae id: {list(sae_id_to_name_map.keys())[0]}")
print(f"First sae name: {list(sae_id_to_name_map.values())[0]}")
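
Inverting the map with a dict comprehension silently drops entries if two ids ever share a name. As a minimal sketch, the helper below inverts a map and raises on duplicates instead. The function `invert_unique` and the toy ids and names are hypothetical and only illustrate the idea; real maps come from the `saes_map` column.

```python
# Hypothetical toy mapping; real maps come from the `saes_map` column of overview_df.
sae_id_to_name_map = {
    "example_id_0": "example_release/example_layer/trainer_0",
    "example_id_1": "example_release/example_layer/trainer_1",
}


def invert_unique(mapping):
    """Invert a dict, raising ValueError if two keys share the same value."""
    inverted = {}
    for key, value in mapping.items():
        if value in inverted:
            raise ValueError(
                f"Duplicate value {value!r} for keys {inverted[value]!r} and {key!r}"
            )
        inverted[value] = key
    return inverted


sae_name_to_id_map = invert_unique(sae_id_to_name_map)
print(sae_name_to_id_map)
```

If the names really are unique, this behaves exactly like the comprehension above; otherwise it fails loudly rather than overwriting an entry.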

Here is an example of loading a Pythia SAE.

In [None]:
from sae_lens import SAE

device = "cpu"
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="sae_bench_pythia70m_sweep_topk_ctx128_0730",
    sae_id="blocks.4.hook_resid_post__trainer_10",
    device=device,
)
sae = sae.to(device=device)

The following is an example of the `sae_names` and `sae_release` inputs for the `sparse_probing` eval function. I'm using the names, not the ids, to match the convention of our results dictionaries, but you could also pass in SAE ids if you prefer.

In [None]:
sae_release = "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824"

sae_names = [
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_0",
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_1",
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_2",
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_3",
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_4",
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_5",
]
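
Since the six names above differ only in the trainer index, the same list could also be built with a comprehension. This is a small sketch; the variable names `release_prefix`, `layer`, and `n_trainers` are introduced here for illustration, and the name pattern is simply the one visible in the list above.

```python
# Illustrative variables; the pattern mirrors the hand-written list above.
release_prefix = "gemma-2-2b_sweep_topk_ctx128_ef8_0824"
layer = 19
n_trainers = 6

sae_names = [
    f"{release_prefix}/resid_post_layer_{layer}/trainer_{i}"
    for i in range(n_trainers)
]

for name in sae_names:
    print(name)
```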


This repo already contains info on all SAEs we're using at `sae_bench_data/{release_name}_data.json`. Each file contains the config used in the `dictionary_learning` repo, which includes training hyperparameters, SAE type, etc. It also contains the `basic_eval_results`, which include the `l0` and `frac_recovered` values obtained using the `evaluate()` function from `dictionary_learning`. These are already computed, so we can use them when making graphs.

In [None]:
release_name = "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824"
sae_data_filename = f"sae_bench_data/{release_name}_data.json"

with open(sae_data_filename, "r") as f:
    sae_data = json.load(f)

As the next cell shows, `sae_data` contains two keys: `sae_config_dictionary_learning` and `basic_eval_results`. Under each key, there is an entry for every SAE name in that release.

In [None]:
example_sae_names = list(sae_data['sae_config_dictionary_learning'].keys())[:5]
example_sae_name = example_sae_names[0]
example_sae_config = sae_data['sae_config_dictionary_learning'][example_sae_name]
example_basic_eval_result = sae_data['basic_eval_results'][example_sae_name]

print(sae_data.keys())
print('\nExample evaluated SAEs:\n', example_sae_names)
print('\nFirst SAE config:\n', example_sae_config)
print('\nFirst basic eval result:\n', example_basic_eval_result)
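
Because `l0` and `frac_recovered` are precomputed, they can be collected directly for plotting. The sketch below uses a made-up, in-memory stand-in for `sae_data` with the two top-level keys described above. The SAE names, the numeric values, the placeholder config contents, and the helper `collect_l0_vs_frac_recovered` are all hypothetical; only the key names `l0` and `frac_recovered` come from the text above.

```python
# Made-up stand-in for a loaded `{release_name}_data.json` file.
# Names, numbers, and config contents are placeholders, not real results.
sae_data = {
    "sae_config_dictionary_learning": {
        "toy_release/layer_0/trainer_0": {"placeholder": True},
        "toy_release/layer_0/trainer_1": {"placeholder": True},
    },
    "basic_eval_results": {
        "toy_release/layer_0/trainer_0": {"l0": 80.0, "frac_recovered": 0.97},
        "toy_release/layer_0/trainer_1": {"l0": 20.0, "frac_recovered": 0.91},
    },
}


def collect_l0_vs_frac_recovered(data):
    """Return (sae_name, l0, frac_recovered) tuples sorted by increasing l0."""
    points = [
        (name, result["l0"], result["frac_recovered"])
        for name, result in data["basic_eval_results"].items()
    ]
    return sorted(points, key=lambda point: point[1])


points = collect_l0_vs_frac_recovered(sae_data)
for name, l0, frac_recovered in points:
    print(f"{name}: l0={l0}, frac_recovered={frac_recovered}")
```

The resulting tuples can be handed to any plotting library to draw sparsity against reconstruction quality.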

The results file of the custom eval you're implementing will contain a `custom_eval_config` and `custom_eval_results`. 

`custom_eval_config` contains a dict of hyperparameters and config values to reproduce the results.

`custom_eval_results` contains a dict in which every key is an SAE name and every value is another dict containing various results from the eval. This dict can be loaded directly into `graph_sae_results.ipynb` to create various plots.

In [None]:
folder_path = "sparse_probing/sparse_probing_results"
filename = f"example_results_{release_name}_eval_results.json"
filepath = os.path.join(folder_path, filename)

with open(filepath, "r") as f:
    custom_eval_results = json.load(f)

print(custom_eval_results.keys())
print(f'\nCustom eval config:\n{custom_eval_results["custom_eval_config"]}')
print(f'\nCustom eval results:\n{custom_eval_results["custom_eval_results"][example_sae_name]}')
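
When writing a new eval, it may help to check the output against this layout before saving it. The following is a minimal sketch under stated assumptions: the config key `random_seed`, the metric `toy_metric`, the SAE names, and the helper `check_results_format` are hypothetical, and the file is written to a temporary directory rather than to a real results folder.

```python
import json
import os
import tempfile

# Hypothetical results for two made-up SAE names.
results = {
    "custom_eval_config": {"random_seed": 0},
    "custom_eval_results": {
        "toy_release/layer_0/trainer_0": {"toy_metric": 0.5},
        "toy_release/layer_0/trainer_1": {"toy_metric": 0.7},
    },
}


def check_results_format(results, expected_sae_names):
    """Raise ValueError unless `results` follows the layout described above."""
    for key in ("custom_eval_config", "custom_eval_results"):
        if key not in results:
            raise ValueError(f"Missing top-level key: {key}")
    missing = set(expected_sae_names) - set(results["custom_eval_results"])
    if missing:
        raise ValueError(f"No results for SAEs: {sorted(missing)}")
    for name, value in results["custom_eval_results"].items():
        if not isinstance(value, dict):
            raise ValueError(f"Results for {name} must be a dict")


check_results_format(results, list(results["custom_eval_results"].keys()))

# Round-trip through JSON to confirm that the structure is serializable.
with tempfile.TemporaryDirectory() as tmp_dir:
    filepath = os.path.join(tmp_dir, "toy_eval_results.json")
    with open(filepath, "w") as f:
        json.dump(results, f, indent=2)
    with open(filepath, "r") as f:
        loaded = json.load(f)

print(loaded.keys())
```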