# Plotting Custom Metric Results


In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell execution and edits to the helper modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import json
import torch
import pickle
from typing import Optional
from matplotlib.colors import Normalize
import numpy as np
import os

from sae_bench_utils.graphing_utils import (
    plot_2var_graph,
    plot_3var_graph,
    plot_interactive_3var_graph,
    plot_training_steps,
    plot_correlation_heatmap,
    plot_correlation_scatter,
)

from sae_bench_utils.formatting_utils import (
    get_sparsity_penalty,
    extract_saes_unique_info,
    ae_config_results,
    add_custom_metric_results,
    filter_by_l0_threshold,
)

## Load data


In [None]:
# Root folder of the sparse-probing evaluation. Result files are read from
# `results/` and generated figures are written to `images/`.
eval_path = "./evals/sparse_probing"
image_path = os.path.join(eval_path, "images")
results_path = os.path.join(eval_path, "results")

# exist_ok=True keeps the cell safe to re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(image_path, exist_ok=True)

In [None]:
# Select the results file to plot: leave exactly one `filename` assignment
# uncommented. (Previously the Pythia assignment was silently overwritten by
# the Gemma one further down.)

## Example results for Pythia (does not contain training checkpoints)
# filename = "example_pythia-70m-deduped_layer_4_eval_results.json"

## Example results for Gemma (does not contain training checkpoints)
# filename = "example_gemma-2-2b_layer_19_eval_results.json"

## Example results for Gemma (does contain training checkpoints)
filename = "example_gemma-2-2b_layer_19_with_checkpoints_eval_results.json"


filepath = os.path.join(results_path, filename)

# JSON is UTF-8 by specification; pass the encoding explicitly instead of
# relying on the platform's default locale encoding.
with open(filepath, "r", encoding="utf-8") as f:
    eval_results = json.load(f)

In [None]:
# Names of every SAE that has an entry under "custom_eval_results".
sae_names = [*eval_results["custom_eval_results"]]

# Show the top-level sections of the results file, then the SAE names,
# then the metric keys stored for the first SAE.
print(eval_results.keys())
print("\nAvailable SAEs:\n", eval_results["custom_eval_results"].keys())
first_sae_custom_metrics = eval_results["custom_eval_results"][sae_names[0]]
print("\nAvailable custom metrics:\n", first_sae_custom_metrics.keys())

The next cell finds all of the `sae_releases` referenced by the results file and
aggregates their data into `sae_data`. `sae_data` contains basic metrics like L0
and Loss Recovered, in addition to trainer parameters like dict size, sparsity
penalty, SAE type, etc.


In [None]:
# Releases whose SAEs were evaluated, as recorded in the results file.
sae_releases = eval_results["custom_eval_config"]["sae_releases"]

# Merge the per-release data files into one lookup with two sections.
sae_data = {"basic_eval_results": {}, "sae_config_dictionary_learning": {}}

for release_name in sae_releases:
    sae_data_filename = f"sae_bench_data/{release_name}_data.json"

    # JSON is UTF-8 by specification; pass the encoding explicitly instead of
    # relying on the platform's default locale encoding.
    with open(sae_data_filename, "r", encoding="utf-8") as f:
        sae_release_data = json.load(f)

    # Update each section in place; a release file missing a section still
    # raises KeyError, as before.
    for section_name, merged_section in sae_data.items():
        merged_section.update(sae_release_data[section_name])

In [None]:
# Sections gathered from the release data files.
print(sae_data.keys())
# print('\nAvailable SAEs:\n', sae_data["basic_eval_results"].keys())

# Use the first SAE entry to list which basic metrics are stored.
basic_eval_results = sae_data["basic_eval_results"]
first_sae_name = next(iter(basic_eval_results))
print("\nAvailable basic metrics:\n", basic_eval_results[first_sae_name].keys())

In [None]:
# Use the first SAE entry to list which trainer config fields are stored.
sae_configs = sae_data["sae_config_dictionary_learning"]
first_sae_name = next(iter(sae_configs))
first_trainer_config = sae_configs[first_sae_name]["trainer"]
print("\nAvailable config info:\n", first_trainer_config.keys())

In [None]:
# Gather all values in one dict for plotting.
#
# A fresh dict is built per SAE instead of aliasing
# eval_results["custom_eval_results"], so the loaded results are not mutated
# and re-running this or earlier cells gives the same output.
plotting_results = {}

for sae_name, custom_metrics in eval_results["custom_eval_results"].items():
    basic_metrics = sae_data["basic_eval_results"][sae_name]
    sae_config = sae_data["sae_config_dictionary_learning"][sae_name]

    plotting_results[sae_name] = {
        # Custom metrics first; later entries override same-named keys.
        **custom_metrics,
        "l0": basic_metrics["l0"],
        "sparsity_penalty": get_sparsity_penalty(sae_config),
        "frac_recovered": basic_metrics["frac_recovered"],
        # Add all trainer info
        **sae_config["trainer"],
        "buffer": sae_config["buffer"],
    }

## Plot custom metric above unsupervised metrics


In [None]:
# `k` selects which `sae_top_{k}_test_accuracy` metric is plotted.
k = 2
custom_metric = "sae_top_{}_test_accuracy".format(k)
custom_metric_name = "k={}-Sparse Probe Accuracy".format(k)

# Figure titles and the shared path prefix for the images of this metric.
title_2var = "L0 vs " + custom_metric_name
title_3var = "L0 vs Loss Recovered vs " + custom_metric_name
image_base_name = os.path.join(image_path, custom_metric)

plot_3var_graph(
    plotting_results,
    title_3var,
    custom_metric,
    output_filename=image_base_name + "_3var.png",
    colorbar_label="Custom Metric",
)
plot_2var_graph(
    plotting_results,
    custom_metric,
    output_filename=image_base_name + "_2var.png",
    title=title_2var,
)
# plot_interactive_3var_graph(plotting_results, custom_metric)

# At this point, if there are any additional .json files located alongside the ae.pt and eval_results.json,
# you can adapt them to be included in the plotting_results dictionary by using something similar to add_ae_config_results()

### ...with interactive hovering


In [None]:
# Same 3-variable plot in interactive form, written out as an HTML file.
plot_interactive_3var_graph(
    plotting_results,
    custom_metric,
    output_filename=image_base_name + "_3var_interactive.html",
    title=title_3var,
)

## Plot metric over training checkpoints


In [None]:
# Check which SAEs with checkpoints are actually available.
# The call is the last expression in the cell, so its return value is shown
# as the cell output.
extract_saes_unique_info(sae_names, checkpoint_only=True)

Note: Some SAE checkpoints are taken at initialization (step 0), which cannot be
placed on a log scale (log(0) = -inf). We visualize this with a cut in the graph.

Note: If the list above is empty, there are no checkpoints available. The plot
below will only show values for the final training step.


In [None]:
# Plot the selected custom metric against training steps and save the figure
# under the shared image prefix.
plot_training_steps(
    plotting_results,
    custom_metric,
    output_filename=image_base_name + "_steps_vs_diff.png",
    title="Steps vs " + custom_metric_name,
)

## Plot metric correlations


In [None]:
# To correlate a different probe size, uncomment and adjust:
# k=100
# custom_metric = f'sae_top_{k}_test_accuracy'

# Metrics to correlate: the two unsupervised metrics plus the selected custom metric.
metric_keys = ["l0", "frac_recovered"]
metric_keys.append(custom_metric)

plot_correlation_heatmap(plotting_results, ae_names=None, metric_names=metric_keys)

In [None]:
# Simple example usage:
# plot_correlation_scatter(plotting_results, metric_x="l0", metric_y="frac_recovered", title="L0 vs Fraction Recovered")

# The two probe sizes (k) whose accuracies are compared against each other.
threshold_x = 50
threshold_y = 100

metric_x = f"sae_top_{threshold_x}_test_accuracy"
metric_y = f"sae_top_{threshold_y}_test_accuracy"

title = ""
# Derive labels and the file name from the thresholds so they cannot drift
# from the metrics actually plotted. (The x label previously said "k=1" and
# the file name used threshold_y twice.)
x_label = f"k={threshold_x} Sparse Probe Accuracy"
y_label = f"k={threshold_y} Sparse Probe Accuracy"
output_filename = os.path.join(
    image_path,
    f"sparse_probing_result_correlation_for_thresholds_{threshold_x}_{threshold_y}.png",
)

plot_correlation_scatter(
    plotting_results,
    metric_x=metric_x,
    metric_y=metric_y,
    title=title,
    x_label=x_label,
    y_label=y_label,
    output_filename=output_filename,
)