In [None]:
# Automatically reload edited project modules (evals, custom_saes, sae_bench_utils) before
# each cell executes, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from typing import Any, Optional

import evals.core.main as core
import evals.scr_and_tpp.main as scr_and_tpp
import evals.sparse_probing.main as sparse_probing
import sae_bench_utils.general_utils as general_utils
import custom_saes.custom_sae_config as custom_sae_config
import custom_saes.vanilla_sae as vanilla_sae
from sae_bench_utils.sae_selection_utils import get_saes_from_regex
import custom_saes.run_all_evals_custom_saes as run_all_evals_custom_saes

RANDOM_SEED = 42

# Output directory for each eval type's results.
output_folders = {
    "absorption": "eval_results/absorption",
    "autointerp": "eval_results/autointerp",
    "core": "eval_results/core",
    "scr": "eval_results/scr",
    "tpp": "eval_results/tpp",
    "sparse_probing": "eval_results/sparse_probing",
    "unlearning": "eval_results/unlearning",
}

# Notes on choosing eval types:
# - Unlearning is not recommended for models with < 2B parameters, and we recommend an
#   instruction-tuned model. It also requires requesting permission for the WMDP dataset
#   (see unlearning/README.md).
# - Absorption is not recommended for models with < 2B parameters.
# - asyncio doesn't work well inside notebooks, so autointerp must be run using a Python script.

# Select your eval types here.
eval_types = [
    "absorption",
    # "autointerp",
    "core",
    "scr",
    "tpp",
    "sparse_probing",
    # "unlearning",
]

# An OpenAI API key is only needed for autointerp. It is read from openai_api_key.txt,
# falling back to the OPENAI_API_KEY environment variable when that file does not exist.
if "autointerp" in eval_types:
    try:
        with open("openai_api_key.txt") as f:
            api_key = f.read().strip()
    except FileNotFoundError:
        api_key = os.environ.get("OPENAI_API_KEY", "").strip()
        if not api_key:
            # FileNotFoundError subclasses Exception, so existing handlers still catch it.
            raise FileNotFoundError(
                "Please create openai_api_key.txt with your API key "
                "(or set the OPENAI_API_KEY environment variable)"
            ) from None
    # Only reachable with an empty key when the file existed but contained no key.
    if not api_key:
        raise ValueError("openai_api_key.txt is empty; please add your API key")
else:
    api_key = None

This cell loads your custom SAEs.

In [None]:
device = general_utils.setup_environment()

# Model and LLM settings shared by every eval below.
model_name = "pythia-70m-deduped"
dtype = "float32"
llm_batch_size = 512

# Set save_activations to True when evaluating several SAEs on the same layer.
# Doing so needs at least 100GB of free disk space.
save_activations = False

# Where the custom SAE checkpoint lives, and the layer / hook it is evaluated at.
repo_id = "canrager/lm_sae"
filename = "pythia70m_sweep_standard_ctx128_0712/resid_post_layer_4/trainer_8/ae.pt"
hook_layer = 4
hook_name = f"blocks.{hook_layer}.hook_resid_post"

# Load the checkpoint and move it to the target device / dtype in one step.
sae = vanilla_sae.load_vanilla_sae(repo_id, filename, hook_layer).to(
    device, dtype=general_utils.str_to_dtype(dtype)
)
print(f"sae dtype: {sae.dtype}, device: {sae.device}")

# The decoder weight's shape unpacks as (d_sae, d_in).
d_sae, d_in = sae.W_dec.data.shape
print(f"d_in: {d_in}, d_sae: {d_sae}")

sae.cfg = custom_sae_config.CustomSAEConfig(
    model_name,
    hook_layer=hook_layer,
    hook_name=hook_name,
    d_sae=d_sae,
    d_in=d_in,
)
# Core evals currently need the dtype set on the config separately, as a string.
sae.cfg.dtype = dtype

# Derive an id usable in result filenames by mapping both "/" and "." to "_".
custom_sae_id = filename.translate(str.maketrans({"/": "_", ".": "_"}))
print(f"sae_id: {custom_sae_id}")

# Each entry is a (sae_id, sae object) pair.
selected_saes = [(custom_sae_id, sae)]


Select your baseline SAEs here; see `sae_regex_selection.ipynb` for more regex patterns. Here we select a TopK SAE from the same layer as the custom SAE.

`selected_saes` is a list of tuples, each either `(sae_id, sae_object)` or `(sae_lens_release, sae_lens_id)`.

In [None]:
sae_regex_pattern = r"(sae_bench_pythia70m_sweep_topk_ctx128_0730).*"
sae_block_pattern = r".*blocks\.([4])\.hook_resid_post__trainer_(8)$"

baseline_saes = get_saes_from_regex(sae_regex_pattern, sae_block_pattern)
print(f"baseline_saes: {baseline_saes}")
# Fail with a clear message instead of an IndexError below when nothing matches.
assert baseline_saes, (
    f"No SAEs matched release pattern {sae_regex_pattern!r} "
    f"and block pattern {sae_block_pattern!r}"
)

# Only append baselines not already selected, so re-running this cell does not
# duplicate them (which would make every eval below run on them twice).
selected_saes.extend([entry for entry in baseline_saes if entry not in selected_saes])

# The first matched baseline is the one whose results are loaded for comparison.
baseline_sae_id = f"{baseline_saes[0][0]}_{baseline_saes[0][1]}".replace(".", "_")
print(f"baseline_sae_id: {baseline_sae_id}")

In [None]:
# Core evals for every selected SAE; results are written to output_folders["core"].
# NOTE: context_size and dataset also appear in the core result filenames read further
# down ("..._128_Skylion007_openwebtext.json"), so keep them in sync.
_ = core.multiple_evals(
    filtered_saes=selected_saes,
    n_eval_reconstruction_batches=200,
    n_eval_sparsity_variance_batches=200,
    eval_batch_size_prompts=32,
    compute_featurewise_density_statistics=False,
    compute_featurewise_weight_based_metrics=False,
    exclude_special_tokens_from_reconstruction=True,
    dataset="Skylion007/openwebtext",
    context_size=128,
    output_folder=output_folders["core"],
    verbose=True,
    dtype=dtype,
)

In [None]:
# We run only a subset of the sparse probing datasets here, for a shorter runtime.
dataset_names = ["LabHC/bias_in_bios_class_set1"]

# TODO: Add a verbose flag

# Results are written to output_folders["sparse_probing"], the same folder configured above.
_ = sparse_probing.run_eval(
    sparse_probing.SparseProbingEvalConfig(
        model_name=model_name,
        random_seed=RANDOM_SEED,
        llm_batch_size=llm_batch_size,
        llm_dtype=dtype,
        dataset_names=dataset_names,
    ),
    selected_saes,
    device,
    output_folders["sparse_probing"],
    force_rerun=False,
    clean_up_activations=True,
    save_activations=save_activations,
)


The cell below runs all selected evals (`eval_types`) on the full datasets. It does not run by default because it is time-consuming (~1 hour).

In [None]:
# Set to True to run every eval type in `eval_types` on the full datasets (~1 hour).
# Kept as real (guarded) code rather than commented out so it stays syntax-checked.
RUN_ALL_EVALS = False

if RUN_ALL_EVALS:
    _ = run_all_evals_custom_saes.run_evals(
        model_name,
        selected_saes,
        llm_batch_size,
        dtype,
        device,
        eval_types,
        api_key,
        force_rerun=False,
        save_activations=save_activations,
    )

In [None]:
import matplotlib.pyplot as plt
import json
import torch
import pickle
from typing import Optional
from matplotlib.colors import Normalize
import numpy as np
import os

from sae_bench_utils.graphing_utils import (
    sae_name_to_info,
    plot_2var_graph,
    plot_2var_graph_dict_size,
    plot_3var_graph,
    plot_interactive_3var_graph,
    plot_training_steps,
    plot_correlation_heatmap,
    plot_correlation_scatter,
)

from sae_bench_utils.sae_selection_utils import select_saes_multiple_patterns

In [None]:
# Read results from the same folders the eval cells above write to.
eval_path = output_folders["sparse_probing"]
core_results_path = output_folders["core"]
image_path = "./images"

# exist_ok avoids the check-then-create race and keeps this cell idempotent on re-run.
os.makedirs(image_path, exist_ok=True)

In [None]:
core_results = {}

# TODO: Come up with a better way than this janky manual id creation
sae_ids = [custom_sae_id, baseline_sae_id]

for sae_id in sae_ids:
    # Core result files for custom SAEs carry an extra "_custom_sae" infix. The
    # "_128_Skylion007_openwebtext" suffix matches the core eval's context_size / dataset.
    # A dedicated name is used here so the global SAE checkpoint `filename` is not clobbered.
    name_infix = "_custom_sae" if sae_id == custom_sae_id else ""
    core_filename = f"{sae_id}{name_infix}_128_Skylion007_openwebtext.json"
    core_filepath = os.path.join(core_results_path, core_filename)

    with open(core_filepath, "r") as f:
        single_sae_results = json.load(f)

    metrics = single_sae_results["eval_result_metrics"]
    # Keep only what the plots need: sparsity (L0) and CE-loss score ("loss recovered").
    core_results[sae_id] = {
        "l0": metrics["sparsity"]["l0"],
        "frac_recovered": metrics["model_performance_preservation"]["ce_loss_score"],
    }

In [None]:
eval_results = {}
llm_results = None
for sae_id in sae_ids:
    # Probing result files for custom SAEs carry an extra "_custom_sae" infix. A dedicated
    # name is used so the global SAE checkpoint `filename` is not clobbered.
    if sae_id == custom_sae_id:
        probing_filename = f"{sae_id}_custom_sae_eval_results.json"
    else:
        probing_filename = f"{sae_id}_eval_results.json"
    probing_filepath = os.path.join(eval_path, probing_filename)
    with open(probing_filepath, "r") as f:
        single_sae_results = json.load(f)

    metrics = single_sae_results["eval_result_metrics"]
    # Merge this SAE's probing metrics with its core metrics (l0, frac_recovered) so each
    # entry holds everything the plots below read.
    eval_results[sae_id] = {**metrics["sae"], **core_results[sae_id]}
    # Each file also stores the LLM's own probing results (presumably the same across
    # SAEs); the last one read is kept.
    llm_results = metrics["llm"]

print(eval_results[custom_sae_id].keys())
print(llm_results.keys())
print(llm_results.keys())

`trainer_markers` maps each `sae_class` name to a plot marker shape. Replace the value of `new_sae_key` with your preferred name for your SAE class.

In [None]:
new_sae_key = "Vanilla"

# Marker shape per SAE class; the new (custom) SAE class is added last with a square.
trainer_markers = {
    "Standard": "o",
    "JumpReLU": "X",
    "TopK": "^",
    "Standard w/ p-annealing": "*",
    "Gated": "d",
}
trainer_markers[new_sae_key] = "s"  # New SAE

# Tag each loaded result with its SAE class so the plots pick the matching marker.
eval_results[baseline_sae_id]["sae_class"] = "TopK"
eval_results[custom_sae_id]["sae_class"] = new_sae_key

In [None]:
# Top-1 probing test accuracy for each SAE, followed by the LLM result for reference.
for _label, _sae_key in (("Custom", custom_sae_id), ("Baseline", baseline_sae_id)):
    _accuracy = eval_results[_sae_key]["sae_top_1_test_accuracy"]
    print(f"{_label} SAE top 1 accuracy was: {_accuracy}")

_llm_accuracy = llm_results["llm_top_1_test_accuracy"]
print(f"LLM top 1 accuracy was: {_llm_accuracy}")

In [None]:
custom_metric = "sae_top_1_test_accuracy"
# Human-readable name for `custom_metric`, used for the y-axis / colorbar labels.
custom_metric_name = "Top-1 Sparse Probing Accuracy"

title_3var = "Sparse Probing vs L0 vs Loss Recovered"
title_2var = "Sparse Probing vs L0"
image_base_name = os.path.join(image_path, custom_metric)

# Per its title, this plot presumably puts L0 and loss recovered on the axes and shows
# `custom_metric` via the colorbar, so the colorbar is labelled with the metric's name.
plot_3var_graph(
    eval_results,
    title_3var,
    custom_metric,
    colorbar_label=custom_metric_name,
    output_filename=f"{image_base_name}_3var.png",
    trainer_markers=trainer_markers,
)
# `custom_metric` against L0: the y-axis shows the probing accuracy, not loss recovered.
plot_2var_graph(
    eval_results,
    custom_metric,
    y_label=custom_metric_name,
    title=title_2var,
    output_filename=f"{image_base_name}_2var.png",
    trainer_markers=trainer_markers,
)