The following cell displays the dataframe of available SAE Bench releases and their SAEs.

In [None]:
import pandas as pd
import os
import re
import json

import utils.formatting_utils as formatting_utils

# Printable overview table of the available SAE releases.
overview_df = formatting_utils.make_available_sae_df(for_printing=True)

# Columns to show in the rendered table.
show_cols = [
    "release",
    "model",
    "unique_hook_points",
    "n_saes_per_hook",
    "has_training_checkpoints",
    "saes_map",
]

# Pandas display options: lift the column-count and width limits, and make
# cells wide enough for the longest stringified "unique_hook_points" entry.
max_hook_point_length = overview_df["unique_hook_points"].astype(str).map(len).max()
pd.options.display.max_columns = None
pd.options.display.width = None
pd.options.display.max_colwidth = int(max_hook_point_length)

# Last expression of the cell, so the notebook renders it.
overview_df.loc[:, show_cols]

Each row is a "release" containing multiple SAEs, which may have different configs and/or match different hook points in a model. There are 8 SAE Bench releases: 4 for Pythia and 4 for Gemma-2-2B.

Each release will contain an saes_map, a dict of sae_id: sae_name. The `sae_ids` are SAE Lens specific, used to load the SAEs into SAELens.

In this project, we use the `sae_names` as keys in our results dictionaries, rather than the sae_ids. This is because the names are unique, and there's no possibility of mixing data between different SAEs.

In [None]:
# Same table as above, but in its non-printing form for programmatic access.
sae_df = formatting_utils.make_available_sae_df(for_printing=False)

sae_release = "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824"

# saes_map for the release maps sae_id -> sae_name; build the inverse as well.
sae_id_to_name_map = sae_df.saes_map[sae_release]
sae_name_to_id_map = {
    sae_name: sae_id for sae_id, sae_name in sae_id_to_name_map.items()
}

# Show the first id/name pair as an example.
print(f"First sae id: {list(sae_id_to_name_map)[0]}")
print(f"First sae name: {list(sae_id_to_name_map.values())[0]}")

As an example, here's a dictionary of sae_release: all saes for a given layer for that Gemma release. This is the input format that we are using for `sparse_probing/`. Note that in this particular example we are not including checkpoints.

In [None]:
# Releases to collect SAE names for.
sae_releases = [
    "gemma-scope-2b-pt-res",
    "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824",
    "sae_bench_gemma-2-2b_sweep_standard_ctx128_ef8_0824",
]

layer = 19

# Maps release -> selected SAE names for `layer`.
selected_saes_dict = {}

for release in sae_releases:
    if "gemma-scope" not in release:
        # SAE Bench release: filter by layer, leaving checkpoints out.
        selected_saes_dict[release] = formatting_utils.filter_sae_names(
            sae_names=release,
            layers=[layer],
            include_checkpoints=False,
            trainer_ids=None,
        )
    else:
        # Gemma Scope releases go through their own lookup helper.
        selected_saes_dict[release] = formatting_utils.find_gemmascope_average_l0_sae_names(layer)

for key, selected in selected_saes_dict.items():
    print("\n\n", key, "\n\n", selected)

This cell gets all Gemma checkpoints. Notice that it also includes the final SAE, which is not included in the checkpoints folder.

In [None]:
# Releases to collect SAE names for (Gemma Scope is left disabled here).
sae_releases = [
    # "gemma-scope-2b-pt-res",
    "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824",
    "sae_bench_gemma-2-2b_sweep_standard_ctx128_ef8_0824",
]

layer = 19

# Maps release -> selected SAE names for `layer`, checkpoints included.
selected_saes_dict = {}

for release in sae_releases:
    if "gemma-scope" not in release:
        # SAE Bench release: filter by layer, this time keeping checkpoints.
        selected_saes_dict[release] = formatting_utils.filter_sae_names(
            sae_names=release,
            layers=[layer],
            include_checkpoints=True,
            trainer_ids=None,
        )
    else:
        # Gemma Scope releases go through their own lookup helper.
        selected_saes_dict[release] = formatting_utils.find_gemmascope_average_l0_sae_names(layer)

for key, selected in selected_saes_dict.items():
    print("\n\n", key, "\n\n", selected)

This cell gets all standard and topk SAEs for Pythia layer 4. 

In [None]:
# Pythia releases to collect SAE names for (gated / p-anneal left disabled).
pythia_sae_releases = [
    "sae_bench_pythia70m_sweep_standard_ctx128_0712",
    "sae_bench_pythia70m_sweep_topk_ctx128_0730",
    # "sae_bench_pythia70m_sweep_gated_ctx128_0730",
    # "sae_bench_pythia70m_sweep_panneal_ctx128_0730",
]

layer = 4

# Maps release -> selected SAE names for `layer`.
selected_saes_dict = {}

for release in pythia_sae_releases:
    if "gemma-scope" not in release:
        # SAE Bench release: filter by layer, leaving checkpoints out.
        selected_saes_dict[release] = formatting_utils.filter_sae_names(
            sae_names=release,
            layers=[layer],
            include_checkpoints=False,
            trainer_ids=None,
        )
    else:
        # Kept so the cell still works if a Gemma Scope release is added.
        selected_saes_dict[release] = formatting_utils.find_gemmascope_average_l0_sae_names(layer)

for key, selected in selected_saes_dict.items():
    print("\n\n", key, "\n\n", selected)

When testing, we may want to run only a single SAE. This cell selects only a single Pythia TopK SAE (via `trainer_ids=[10]`).

In [None]:
# Only the Pythia TopK release is enabled, to keep test runs small.
pythia_sae_releases = [
    # "sae_bench_pythia70m_sweep_standard_ctx128_0712",
    "sae_bench_pythia70m_sweep_topk_ctx128_0730",
    # "sae_bench_pythia70m_sweep_gated_ctx128_0730",
    # "sae_bench_pythia70m_sweep_panneal_ctx128_0730",
]

layer = 4

# Maps release -> selected SAE names for `layer`.
selected_saes_dict = {}

for release in pythia_sae_releases:
    if "gemma-scope" not in release:
        # Restrict the selection to trainer id 10 on this layer, no checkpoints.
        selected_saes_dict[release] = formatting_utils.filter_sae_names(
            sae_names=release,
            layers=[layer],
            include_checkpoints=False,
            trainer_ids=[10],
        )
    else:
        # Kept so the cell still works if a Gemma Scope release is added.
        selected_saes_dict[release] = formatting_utils.find_gemmascope_average_l0_sae_names(layer)

for key, selected in selected_saes_dict.items():
    print(key, selected)

Here is an example of loading a Pythia SAE. In the current version of SAE Lens, SAE Bench SAEs are loaded as Standard SAEs, not TopK SAEs. Until a new version of SAE Lens is released with a fix, there is a workaround below, where we call `fix_topk_saes()`.

In [None]:
from sae_lens import SAE
from sae_lens.sae import TopK

# Release, SAE name and target device for the SAE to load.
pythia_sae_release = "sae_bench_pythia70m_sweep_topk_ctx128_0730"
sae_name = "pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_10"
device = "cpu"

# SAE.from_pretrained is called with an sae_id, so invert the release's
# sae_id -> sae_name map to look the id up from the name.
sae_id_to_name_map = sae_df.saes_map[pythia_sae_release]
sae_name_to_id_map = {
    name: lens_id for lens_id, name in sae_id_to_name_map.items()
}
sae_id = sae_name_to_id_map[sae_name]

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=pythia_sae_release, sae_id=sae_id, device=device
)
sae = sae.to(device=device)

# Check the activation function type before and after the TopK workaround.
print(f"Is sae topk? {isinstance(sae.activation_fn, TopK)}")

sae = formatting_utils.fix_topk_saes(sae, pythia_sae_release, sae_name)

print(f"Is sae topk? {isinstance(sae.activation_fn, TopK)}")

print(cfg_dict)

This repo already contains info on all SAEs we're using at `sae_bench_data/{release_name}_data.json`. This contains the config used in the `dictionary_learning` repo, which includes training hyperparameters, SAE type, etc. It also contains the `basic_eval_results`, which include the `l0` and `frac_recovered` values obtained using the `dictionary_learning evaluate()` function. These are already computed, so we can use them when making graphs.

In [None]:
# Load the precomputed per-release SAE data that ships with the repo.
release_name = "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824"
sae_data_filename = f"sae_bench_data/{release_name}_data.json"

# JSON is UTF-8 by specification; pass the encoding explicitly so the read
# does not depend on the platform's default locale encoding.
with open(sae_data_filename, "r", encoding="utf-8") as f:
    sae_data = json.load(f)

As we can see, `sae_data` contains two keys: 'sae_config_dictionary_learning' and 'basic_eval_results'. Within each key, we have all SAE names for that release.

In [None]:
# Take the first five SAE names from the config section and use the first one
# to look up its config and its basic eval result.
example_sae_names = list(sae_data["sae_config_dictionary_learning"])[:5]
example_sae_name = example_sae_names[0]

example_basic_eval_result = sae_data["basic_eval_results"][example_sae_name]
example_sae_config = sae_data["sae_config_dictionary_learning"][example_sae_name]

# Top-level keys first, then the examples.
print(sae_data.keys())
print("\nExample evaluated SAEs:\n", example_sae_names)
print("\nFirst SAE config:\n", example_sae_config)
print("\nFirst basic eval result:\n", example_basic_eval_result)

The results file of the custom eval you're implementing will contain a `custom_eval_config` and `custom_eval_results`. 

`custom_eval_config` contains a dict of hyperparameters and config values to reproduce the results.

`custom_eval_results` contains a dict, where every key is an SAE name, and every value is another dict containing various results from the eval. This dict can be immediately loaded into `graph_sae_results.ipynb` to create various plots.

In [None]:
# Location of the example sparse-probing results file.
folder_path = "evals/sparse_probing/results"
filename = "example_gemma-2-2b_layer_19_eval_results.json"
filepath = os.path.join(folder_path, filename)

# JSON is UTF-8 by specification; pass the encoding explicitly so the read
# does not depend on the platform's default locale encoding.
with open(filepath, "r", encoding="utf-8") as f:
    custom_eval_results = json.load(f)

# Show the top-level keys, the eval config, and the results for the example
# SAE name chosen in an earlier cell (`example_sae_name`).
print(custom_eval_results.keys())
print(f'\nCustom eval config:\n{custom_eval_results["custom_eval_config"]}')
print(f'\nCustom eval results for {example_sae_name}:\n{custom_eval_results["custom_eval_results"][example_sae_name]}')