The following cell displays a dataframe of the SAEBench releases and their SAEs.

In [None]:
import pandas as pd
import os
import json
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# Build a table with one row per pretrained SAE release, indexed by release name.
df = pd.DataFrame.from_records(
    {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}
).T
# Drop columns that are not needed for browsing releases.
df.drop(
    columns=["expected_var_explained", "expected_l0", "config_overrides", "conversion_func"],
    inplace=True,
)
# Each row is a "release", which contains multiple SAEs that may have different
# configs or attach to different hook points in the model.
filtered_df = df[df.release.str.contains("bench")]

print(filtered_df.head())

There are 8 SAEBench releases: 4 for Pythia and 4 for Gemma-2-2B.

In [None]:
print(filtered_df.release)
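To check the per-model split without reading the printout by eye, the releases can be grouped by model. This is a small sketch: `count_releases_by_model` is a hypothetical helper, and it assumes the table has a `model` column (recent SAELens `PretrainedSAELookup` entries include one, but check your installed version).

```python
import pandas as pd


def count_releases_by_model(releases_df: pd.DataFrame) -> dict[str, int]:
    """Return {model_name: number_of_distinct_releases}.

    Hypothetical helper; assumes `releases_df` has `model` and `release` columns.
    """
    counts = releases_df.groupby("model")["release"].nunique()
    return {str(model): int(n) for model, n in counts.items()}
```

In the notebook, `count_releases_by_model(filtered_df)` should show how the SAEBench releases are distributed across Pythia and Gemma-2-2B.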

Each release contains a `saes_map` dict mapping `sae_id` to `sae_name`. The ids are what SAELens uses to load an SAE.

We use the SAE names, rather than the SAE ids, as keys in our results dictionaries. This is because the names are unique, so there is no possibility of mixing up data from different SAEs.

In [None]:
sae_release = "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824"
sae_id_to_name_map = filtered_df.saes_map[sae_release]
sae_name_to_id_map = {v: k for k, v in sae_id_to_name_map.items()}

print(f"First sae id: {list(sae_id_to_name_map.keys())[0]}")
print(f"First sae name: {list(sae_id_to_name_map.values())[0]}")
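The dict comprehension that builds `sae_name_to_id_map` relies on names being unique; if two ids ever shared a name, the later one would silently overwrite the earlier one. A defensive variant makes that assumption explicit. `invert_sae_map` is a hypothetical helper, not part of SAELens.

```python
def invert_sae_map(sae_id_to_name: dict[str, str]) -> dict[str, str]:
    """Invert an sae_id -> sae_name map, raising instead of dropping duplicates.

    Hypothetical helper for illustration only.
    """
    name_to_id: dict[str, str] = {}
    for sae_id, sae_name in sae_id_to_name.items():
        if sae_name in name_to_id:
            # Two ids share one name: a plain comprehension would keep only the last.
            raise ValueError(
                f"SAE name {sae_name!r} maps to both "
                f"{name_to_id[sae_name]!r} and {sae_id!r}"
            )
        name_to_id[sae_name] = sae_id
    return name_to_id
```

Usage: `sae_name_to_id_map = invert_sae_map(sae_id_to_name_map)` gives the same result as the comprehension when names are unique, and fails loudly otherwise.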

Here is an example of loading a Pythia SAE.

In [None]:
from sae_lens import SAE

device = "cpu"
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="sae_bench_pythia70m_sweep_topk_ctx128_0730",
    sae_id="blocks.4.hook_resid_post__trainer_10",
    device=device,
)
sae = sae.to(device=device)

The following cell shows example values of `sae_release` and `sae_names` to pass to the `sparse_probing` eval function. I'm using SAE names rather than ids to match the convention of our results dictionaries, but you could also pass in SAE ids if you prefer.

In [None]:
sae_release = "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824"

sae_names = [
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_0",
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_1",
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_2",
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_3",
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_4",
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824/resid_post_layer_19/trainer_5",
]
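If you need the ids instead (for example, to load these SAEs with `SAE.from_pretrained`), the names can be mapped back through `sae_name_to_id_map`, which was built earlier for this same release. A minimal sketch; `sae_names_to_ids` is a hypothetical helper, and the commented loading loop mirrors the Pythia example above, whose return signature may differ across SAELens versions.

```python
def sae_names_to_ids(sae_names: list[str], name_to_id: dict[str, str]) -> list[str]:
    """Map SAE names to SAELens ids, failing if any name is not in the release.

    Hypothetical helper for illustration only.
    """
    missing = [name for name in sae_names if name not in name_to_id]
    if missing:
        raise KeyError(f"SAE names not found in this release: {missing}")
    return [name_to_id[name] for name in sae_names]


# In the notebook:
# sae_ids = sae_names_to_ids(sae_names, sae_name_to_id_map)
# for sae_id in sae_ids:
#     sae, cfg_dict, sparsity = SAE.from_pretrained(
#         release=sae_release, sae_id=sae_id, device=device
#     )
```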


We have already committed `sae_bench_data/{release_name}_data.json`. This file contains the config used in the `dictionary_learning` repo, including training hyperparameters, SAE type, etc. It also contains `basic_eval_results`, which includes `l0` and `frac_recovered` as computed by the `evaluate()` function in `dictionary_learning`. Because these metrics are precomputed, we can use them directly when making graphs.

In [None]:
release_name = "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824"

folder_path = "sparse_probing/src/sparse_probing_results"
filename = f"example_results_{release_name}_eval_results.json"

filepath = os.path.join(folder_path, filename)

with open(filepath, "r") as f:
    custom_eval_results = json.load(f)

sae_data_filename = f"sae_bench_data/{release_name}_data.json"

with open(sae_data_filename, "r") as f:
    sae_data = json.load(f)
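Since `l0` and `frac_recovered` are precomputed, they can be pulled out as plot-ready points. This sketch assumes each entry under `basic_eval_results` is a dict with `l0` and `frac_recovered` keys, as described above; `l0_frac_recovered_points` is a hypothetical helper.

```python
def l0_frac_recovered_points(sae_data: dict) -> list[tuple[str, float, float]]:
    """Return (sae_name, l0, frac_recovered) tuples sorted by L0.

    Hypothetical helper; assumes sae_data["basic_eval_results"][sae_name]
    holds "l0" and "frac_recovered" entries.
    """
    points = [
        (name, float(results["l0"]), float(results["frac_recovered"]))
        for name, results in sae_data["basic_eval_results"].items()
    ]
    # Sorting by L0 makes the sparsity/fidelity trade-off easy to plot as a line.
    return sorted(points, key=lambda point: point[1])
```

For example, `names, l0s, fracs = zip(*l0_frac_recovered_points(sae_data))` yields sequences ready for a scatter plot.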

As the next cell shows, `sae_data` contains two top-level keys: `'sae_config_dictionary_learning'` and `'basic_eval_results'`. Each one maps every SAE name in the release to that SAE's training config or basic eval results, respectively.

In [None]:
print(sae_data.keys())
first_key = list(sae_data.keys())[0]
print(list(sae_data[first_key].keys())[:5])
first_sae_key = list(sae_data[first_key].keys())[0]
print("\n", sae_data[first_key][first_sae_key])
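Because both top-level keys should list every SAE name in the release, a quick consistency check can catch a partially updated data file. `mismatched_sae_names` is a hypothetical helper for illustration.

```python
def mismatched_sae_names(sae_data: dict) -> set[str]:
    """Return SAE names present under one top-level key but not the other.

    Hypothetical helper; an empty set means both sections cover the same SAEs.
    """
    config_names = set(sae_data["sae_config_dictionary_learning"])
    eval_names = set(sae_data["basic_eval_results"])
    # Symmetric difference: names in exactly one of the two sections.
    return config_names ^ eval_names
```

Usage: `assert not mismatched_sae_names(sae_data)` before plotting.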

Our custom eval results should look like the following. The JSON file should contain a `custom_eval_config` key holding all hyperparameters and config values needed to reproduce the results.

It should also contain a `custom_eval_results` key, whose value is a dict mapping every SAE name to another dict of results from the eval. These results can be loaded directly into `graph_sae_results.ipynb` to create various plots.

In [None]:
print(custom_eval_results.keys())
print(custom_eval_results["custom_eval_config"])
print(custom_eval_results["custom_eval_results"][first_sae_key])
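When writing a new eval, its output can be saved in the same two-key layout. A minimal sketch: `save_custom_eval_results` is a hypothetical helper, and the contents of the config and per-SAE result dicts are whatever your eval produces.

```python
import json
import os


def save_custom_eval_results(
    filepath: str,
    eval_config: dict,
    results_by_sae_name: dict[str, dict],
) -> None:
    """Write results as {"custom_eval_config": ..., "custom_eval_results": ...}.

    Hypothetical helper; keys of `results_by_sae_name` should be SAE names.
    """
    payload = {
        "custom_eval_config": eval_config,
        "custom_eval_results": results_by_sae_name,
    }
    # Create the output folder if needed (dirname is "" for a bare filename).
    os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
    with open(filepath, "w") as f:
        json.dump(payload, f, indent=2)
```

Keeping SAE names as the inner keys matches the convention above, so the saved file can be read back with the same `json.load` pattern used earlier in this notebook.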