# Plotting Custom Metric Results


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import json
import torch
import pickle
from typing import Optional
from matplotlib.colors import Normalize
import numpy as np
import os

from sae_bench_utils.graphing_utils import (
    sae_name_to_info,
    plot_2var_graph,
    plot_2var_graph_dict_size,
    plot_3var_graph,
    plot_interactive_3var_graph,
    plot_training_steps,
    plot_correlation_heatmap,
    plot_correlation_scatter,
)

from sae_bench_utils.sae_selection_utils import select_saes_multiple_patterns

This cell is for the following purpose:

Currently, we have a handful of folders, like `absorption` or `core`. In each folder, we have a bunch of tar.gz files. We want each folder (e.g. `absorption`) to be flat: a single level that contains the .json results for all SAEs.

To run this, create a folder called `eval_results/` and move `absorption/`, `core/`, etc. into this folder.

In [None]:
import os
import tarfile
import shutil

# Names of the eval sub-folders to process; each one is looked up under `eval_results/`.
# Full set, for reference:
# folders = ["absorption", "core", "scr", "tpp", "autointerp"]
folders = ["core"]


# Function to extract tar.gz files and move JSON files to the parent folder
def extract_and_move_json_files(base_folder):
    """Unpack every ``.tar.gz`` in ``base_folder`` and flatten its JSON files.

    For each archive directly inside ``base_folder`` the contents are extracted
    into a scratch ``temp_extract`` sub-folder, every ``.json`` file found at
    any depth is moved to the top level of ``base_folder`` (replacing a
    same-named file, with a printed notice), and the archive is deleted.

    The archive is deleted only after all of its JSON files have been moved,
    and the scratch folder is removed even when extraction or moving fails, so
    a failure neither loses results nor leaves extracted files behind.

    Args:
        base_folder: Directory that contains the ``.tar.gz`` archives.

    Raises:
        OSError: If ``base_folder`` cannot be listed or a file operation fails.
        tarfile.TarError: If an archive cannot be read or a member is rejected.
    """
    temp_extract_folder = os.path.join(base_folder, "temp_extract")

    # Sorted so that the processing (and therefore overwrite) order is deterministic.
    for filename in sorted(os.listdir(base_folder)):
        # Process only .tar.gz files
        if not filename.endswith(".tar.gz"):
            continue
        file_path = os.path.join(base_folder, filename)

        # Extract to a temporary subfolder to avoid conflicts
        os.makedirs(temp_extract_folder, exist_ok=True)
        try:
            with tarfile.open(file_path, "r:gz") as tar:
                if hasattr(tarfile, "data_filter"):
                    # Refuse members that would be written outside the scratch folder.
                    tar.extractall(path=temp_extract_folder, filter="data")
                else:
                    # Older Python without extraction filters.
                    tar.extractall(path=temp_extract_folder)

            # Move all extracted .json files from the temp folder to the base folder
            for root, _, files in os.walk(temp_extract_folder):
                for file in files:
                    if not file.endswith(".json"):
                        continue
                    json_file_path = os.path.join(root, file)
                    destination_path = os.path.join(base_folder, file)

                    # Check if the file already exists and handle overwriting
                    if os.path.exists(destination_path):
                        print(f"Overwriting: {destination_path}")
                        os.remove(destination_path)

                    shutil.move(json_file_path, destination_path)
        finally:
            # Always clean up; ignore_errors keeps a cleanup problem from
            # masking the exception that is already propagating.
            shutil.rmtree(temp_extract_folder, ignore_errors=True)

        # Only now is it safe to drop the archive: its results are in place.
        os.remove(file_path)


# Iterate over each folder and process its contents
# Flatten the archives of every configured eval folder under `eval_results/`.
for folder in folders:
    extract_and_move_json_files(os.path.join("eval_results", folder))

print("Extraction and file moving completed.")

## Load data

Select one of the following `eval_path`, or add your own.

In [None]:
# Directory holding the custom-eval JSON results. Keep exactly one line
# uncommented; previously all three were assigned and only the last took effect.
# eval_path = "./eval_results/scr"
# eval_path = "./eval_results/tpp"
eval_path = "./evals/autointerp/11_12_24_autointerp_results"

core_results_path = "./eval_results/core"
image_path = "./images"

# exist_ok makes this a no-op when the folder is already there and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(image_path, exist_ok=True)

Now select SAEs using the regex patterns. Using a list of sae regex patterns allows selecting SAEs using multiple patterns.

This cell stores both the custom eval (e.g. SCR or sparse probing) and the core evals (L0 / Loss Recovered) for every SAE identified by the regex patterns.

In [None]:
# Layer whose SAEs are selected; also used in plot titles further down.
layer = 19

# Release name patterns. Uncomment entries to add more releases.
# NOTE(review): presumably `sae_block_pattern[i]` is applied to the release
# matched by `sae_regex_patterns[i]` (their lengths are asserted equal below) —
# confirm against select_saes_multiple_patterns.
sae_regex_patterns = [
    r"sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824",
    r"sae_bench_gemma-2-2b_sweep_standard_ctx128_ef8_0824",
    # r"sae_bench_gemma-2-2b_sweep_topk_ctx128_ef2_0824",
    # r"sae_bench_gemma-2-2b_sweep_standard_ctx128_ef2_0824",
    # r"(gemma-scope-2b-pt-res)",
]

# SAE id patterns built from `layer`, so the selection can never drift away
# from the layer shown in the plot titles. The negative lookahead rejects ids
# that have "step" after the block number.
sae_block_pattern = [
    rf".*blocks\.{layer}(?!.*step).*",
    rf".*blocks\.{layer}(?!.*step).*",
    # rf".*layer_({layer}).*(16k).*",  # For Gemma-Scope
]

# Include checkpoints: use these patterns instead of the ones above.
# sae_block_pattern = [
#     rf".*blocks\.{layer}.*",
#     rf".*blocks\.{layer}.*",
# ]

assert len(sae_regex_patterns) == len(sae_block_pattern)

selected_saes = select_saes_multiple_patterns(sae_regex_patterns, sae_block_pattern)


def get_eval_results(selected_saes: list[tuple[str, str]], results_path: str) -> dict:
    """Load the custom-eval metrics of every selected SAE.

    For each ``(sae_release, sae_id)`` pair the file
    ``{sae_release}_{sae_id}_eval_results.json`` (with ``/`` replaced by ``_``)
    is read from ``results_path``. The metrics sub-dict that is returned depends
    on which eval name occurs in ``results_path``; the names are checked in the
    order tpp, scr, absorption, autointerp.

    Returns:
        Dict mapping ``"{sae_release}_{sae_id}"`` to that SAE's metrics.

    Raises:
        ValueError: If ``results_path`` contains none of the known eval names
            (raised while processing the first SAE, after its file was read).
    """
    # (substring of results_path, key below "eval_result_metrics"); first hit wins.
    metrics_key_by_marker = (
        ("tpp", "tpp_metrics"),
        ("scr", "scr_metrics"),
        ("absorption", "mean"),
        ("autointerp", "autointerp"),
    )

    collected = {}
    for release, identifier in selected_saes:
        sae_key = f"{release}_{identifier}"
        json_name = f"{sae_key}_eval_results.json".replace("/", "_")

        with open(os.path.join(results_path, json_name), "r") as handle:
            payload = json.load(handle)

        metrics_key = next(
            (key for marker, key in metrics_key_by_marker if marker in results_path),
            None,
        )
        if metrics_key is None:
            raise ValueError("Please add the correct key for the eval results")

        collected[sae_key] = payload["eval_result_metrics"][metrics_key]
    return collected


def get_core_results(selected_saes: list[tuple[str, str]], core_path: str) -> dict:
    """Load L0 and CE-loss score from the core eval results of every selected SAE.

    For each ``(sae_release, sae_id)`` pair the file
    ``{sae_release}-{sae_id}_128_Skylion007_openwebtext.json`` (with ``/``
    replaced by ``_``) is read from ``core_path``.

    Returns:
        Dict mapping ``"{sae_release}_{sae_id}"`` to
        ``{"l0": ..., "frac_recovered": ...}``, where ``frac_recovered`` is the
        file's ``ce_loss_score``.
    """
    summary = {}
    for release, identifier in selected_saes:
        json_name = f"{release}-{identifier}_128_Skylion007_openwebtext.json".replace("/", "_")

        with open(os.path.join(core_path, json_name), "r") as handle:
            metrics = json.load(handle)["eval_result_metrics"]

        summary[f"{release}_{identifier}"] = {
            "l0": metrics["sparsity"]["l0"],
            "frac_recovered": metrics["model_performance_preservation"]["ce_loss_score"],
        }
    return summary


eval_results = get_eval_results(selected_saes, eval_path)
core_results = get_core_results(selected_saes, core_results_path)

# Fold the core metrics (l0, frac_recovered) into each SAE's custom-eval dict.
for sae, sae_metrics in eval_results.items():
    sae_metrics.update(core_results[sae])

In [None]:
# Iterating a dict yields its keys, i.e. the "{release}_{id}" names of the loaded SAEs.
sae_names = list(eval_results)

print(eval_results.keys())
print("\nAvailable SAEs:\n", eval_results.keys())

For plotting purposes we also want dictionary size, sae type, and number of training steps. The following cell populates this information.

In [None]:
# Gather all values in one dict for plotting
plotting_results = eval_results

for sae_name in eval_results:
    sae_config = sae_name_to_info(sae_name)
    plotting_results[sae_name].update(sae_config)

## Plot custom metric above unsupervised metrics


In [None]:
# Show the metric keys stored for the first selected SAE.
print("\nAvailable custom metrics:\n", eval_results[sae_names[0]].keys())

In [None]:
k = 100

# (marker looked for in `eval_path`, metric key, human-readable name).
# Checked in order; the first marker contained in `eval_path` is used.
metric_options = (
    ("tpp", f"tpp_threshold_{k}_total_metric", f"TPP Top {k} Metric"),
    ("scr", f"scr_metric_threshold_{k}", f"SCR Top {k} Metric"),
    ("absorption", "mean_absorption_score", "Mean Absorption Score"),
    ("autointerp", "autointerp_score", "Autointerp Score"),
)
for marker, metric_key, metric_label in metric_options:
    if marker in eval_path:
        custom_metric, custom_metric_name = metric_key, metric_label
        break
else:
    raise ValueError("Please add the correct key for the custom metric")

# All images of this metric share one filename stem inside `image_path`.
image_base_name = os.path.join(image_path, custom_metric)
title_2var = f"L0 vs {custom_metric_name}"
title_3var = f"L0 vs Loss Recovered vs {custom_metric_name}"

plot_3var_graph(
    plotting_results,
    title_3var,
    custom_metric,
    colorbar_label="Custom Metric",
    output_filename=f"{image_base_name}_3var.png",
)
plot_2var_graph(
    plotting_results,
    custom_metric,
    y_label=custom_metric_name,
    title=title_2var,
    output_filename=f"{image_base_name}_2var.png",
)
# plot_interactive_3var_graph(plotting_results, custom_metric)

# At this point, if there are any additional .json files located alongside the
# ae.pt and eval_results.json, you can adapt them to be included in the
# plotting_results dictionary by using something similar to add_ae_config_results().

### ...with interactive hovering


In [None]:
# Interactive variant of the 3-variable plot; the result is written as an
# .html file next to the static images (same `image_base_name` stem).
plot_interactive_3var_graph(
    plotting_results,
    custom_metric,
    title=title_3var,
    output_filename=f"{image_base_name}_3var_interactive.html",
)

In [None]:
# Saved under its own `_2var_dict_size.png` suffix (the same one plot_results
# uses) so it does not overwrite the `_2var.png` image written by the
# plot_2var_graph call.
plot_2var_graph_dict_size(
    plotting_results,
    custom_metric,
    y_label=custom_metric_name,
    title=title_2var,
    output_filename=f"{image_base_name}_2var_dict_size.png",
)

## Plot metric over training checkpoints


Note: We have SAE checkpoints at initialization (step 0), which do not fit on
a log scale (log(0) = -inf). We visualize this with a cut in the graph.

In [None]:
# Plot `custom_metric` against training steps for the SAEs in `plotting_results`.
# NOTE(review): presumably this needs checkpoint SAEs to be selected, i.e. block
# patterns without the `(?!.*step)` exclusion — confirm before running.
plot_training_steps(
    plotting_results,
    custom_metric,
    title=f"Steps vs {custom_metric_name} Gemma Layer {layer}",
    output_filename=f"{image_base_name}_steps_vs_diff.png",
)

This cell combines all of the above steps into a single function so we can plot results from multiple runs.

In [None]:
def plot_results(
    eval_path: str,
    core_results_path: str,
    sae_regex_patterns: list[str],
    sae_block_pattern: list[str],
    k: int,
):
    """Select SAEs, load their results, and save the standard set of plots.

    Combines the steps of the cells above: SAE selection, loading custom-eval
    and core results, adding SAE info, choosing the custom metric from
    ``eval_path``, and writing the 3-variable, 2-variable, dictionary-size and
    training-step plots.

    Besides its arguments the function reads two module-level names:
    ``image_path`` (output directory) and ``layer`` (used only in the title of
    the training-steps plot). Output filenames depend only on ``image_path``
    and the metric name, so a later call for the same metric overwrites the
    images of an earlier one.

    Args:
        eval_path: Directory with the custom-eval JSON results; it must contain
            one of "tpp", "scr", "absorption" or "autointerp".
        core_results_path: Directory with the core eval JSON results.
        sae_regex_patterns: Release patterns passed to
            ``select_saes_multiple_patterns``.
        sae_block_pattern: SAE id patterns, same length as ``sae_regex_patterns``.
        k: Threshold used in the TPP / SCR metric key; ignored for absorption
            and autointerp.

    Raises:
        AssertionError: If the two pattern lists differ in length.
        ValueError: If ``eval_path`` names no known eval.
    """
    assert len(sae_regex_patterns) == len(sae_block_pattern)

    selected_saes = select_saes_multiple_patterns(sae_regex_patterns, sae_block_pattern)

    eval_results = get_eval_results(selected_saes, eval_path)
    core_results = get_core_results(selected_saes, core_results_path)

    for sae in eval_results:
        eval_results[sae].update(core_results[sae])

    # Same dict object as eval_results; the SAE info is added in place.
    plotting_results = eval_results

    for sae_name in eval_results:
        sae_config = sae_name_to_info(sae_name)
        plotting_results[sae_name].update(sae_config)

    if "tpp" in eval_path:
        custom_metric = f"tpp_threshold_{k}_total_metric"
        custom_metric_name = f"TPP Top {k} Metric"
    elif "scr" in eval_path:
        custom_metric = f"scr_metric_threshold_{k}"
        custom_metric_name = f"SCR Top {k} Metric"
    elif "absorption" in eval_path:
        custom_metric = "mean_absorption_score"
        custom_metric_name = "Mean Absorption Score"
    elif "autointerp" in eval_path:
        # Mirrors the single-run cell and get_eval_results, which both
        # support autointerp results.
        custom_metric = "autointerp_score"
        custom_metric_name = "Autointerp Score"
    else:
        raise ValueError("Please add the correct key for the custom metric")

    title_3var = f"L0 vs Loss Recovered vs {custom_metric_name}"
    title_2var = f"L0 vs {custom_metric_name}"
    image_base_name = os.path.join(image_path, custom_metric)

    plot_3var_graph(
        plotting_results,
        title_3var,
        custom_metric,
        colorbar_label="Custom Metric",
        output_filename=f"{image_base_name}_3var.png",
    )
    plot_2var_graph(
        plotting_results,
        custom_metric,
        y_label=custom_metric_name,
        title=title_2var,
        output_filename=f"{image_base_name}_2var.png",
    )
    plot_2var_graph_dict_size(
        plotting_results,
        custom_metric,
        y_label=custom_metric_name,
        title=title_2var,
        output_filename=f"{image_base_name}_2var_dict_size.png",
    )

    plot_training_steps(
        plotting_results,
        custom_metric,
        title=f"Steps vs {custom_metric_name} Gemma Layer {layer}",
        output_filename=f"{image_base_name}_steps_vs_diff.png",
    )
# Run the whole plotting pipeline once per layer for one eval type.
eval_path = "./eval_results/scr"
# eval_path = "./eval_results/tpp"
# eval_path = "./eval_results/absorption"

core_results_path = "./eval_results/core"
# The loop variable must stay named `layer`: plot_results reads the
# module-level `layer` for the title of its training-steps plot.
# NOTE: plot_results builds its output filenames from `image_path` and the
# metric name only, so the images of a later layer overwrite those of an
# earlier one.
for layer in [7, 19]:
    sae_regex_patterns = [
        r"sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824",
        # r"sae_bench_gemma-2-2b_sweep_standard_ctx128_ef8_0824",
        r"sae_bench_gemma-2-2b_sweep_topk_ctx128_ef2_0824",
        # r"sae_bench_gemma-2-2b_sweep_standard_ctx128_ef2_0824",
        # r"(gemma-scope-2b-pt-res)",
    ]

    # One block pattern per active release pattern above (plot_results asserts
    # that both lists have the same length).
    sae_block_pattern = [
        # rf".*blocks\.{layer}(?!.*step).*",
        # rf".*blocks\.{layer}(?!.*step).*",
        rf".*blocks\.{layer}(?!.*step).*",
        rf".*blocks\.{layer}(?!.*step).*",
        # rf".*layer_({layer}).*(16k).*", # For Gemma-Scope
    ]

    # Include checkpoints
    # sae_block_pattern = [
        # rf".*blocks\.{layer}.*",
        # rf".*blocks\.{layer}.*",
    # ]

    plot_results(eval_path, core_results_path, sae_regex_patterns, sae_block_pattern, k=20)

## Plot metric correlations


In [None]:
# k=100
# custom_metric = f'sae_top_{k}_test_accuracy'

# Correlate the two core metrics with the currently selected custom metric.
metric_keys = ["l0", "frac_recovered", custom_metric]

plot_correlation_heatmap(plotting_results, metric_names=metric_keys, ae_names=None)

In [None]:
# Simple example usage:
# plot_correlation_scatter(plotting_results, metric_x="l0", metric_y="frac_recovered", title="L0 vs Fraction Recovered")

# Compare the same metric at two different top-k thresholds.
# NOTE(review): the `sae_top_*_test_accuracy` keys presumably come from the
# sparse-probing eval and are only present if those results were loaded into
# `plotting_results` — confirm.
threshold_x = 50
threshold_y = 100

metric_x = f"sae_top_{threshold_x}_test_accuracy"
metric_y = f"sae_top_{threshold_y}_test_accuracy"

title = ""
# Axis labels and filename are derived from the thresholds so they cannot
# disagree with the metrics that are actually plotted.
x_label = f"k={threshold_x} Sparse Probe Accuracy"
y_label = f"k={threshold_y} Sparse Probe Accuracy"
output_filename = os.path.join(
    image_path,
    f"sparse_probing_result_correlation_for_thresholds_{threshold_x}_{threshold_y}.png",
)

plot_correlation_scatter(
    plotting_results,
    metric_x=metric_x,
    metric_y=metric_y,
    title=title,
    x_label=x_label,
    y_label=y_label,
    output_filename=output_filename,
)