In [None]:
import matplotlib.pyplot as plt
import json
from typing import Optional
import os

def get_nested_folders(path: str) -> list[str]:
    """
    Walk the directory tree rooted at `path` and collect every directory
    (the starting directory included) that directly contains a file named
    config.json.

    Directories are returned in the order in which os.walk visits them.
    """
    return [
        current_dir
        for current_dir, _subdirs, filenames in os.walk(path)
        if "config.json" in filenames
    ]



In [None]:
# Lookup tables keyed by trainer class name. The plotting code iterates over
# TRAINER_MARKERS, so its insertion order is the order in which trainers are
# drawn and listed in the legend.

# Legend text shown for each trainer class.
TRAINER_LABELS = dict(
    StandardTrainer="Standard",
    JumpReluTrainer="JumpReLU",
    TopKTrainer="Top K",
    BatchTopKTrainer="Batch Top K",
    GatedSAETrainer="Gated",
    PAnnealTrainer="P-Anneal",
)

# Marker passed to the scatter call for each trainer class. "s" and "d" are
# each shared by two trainers; those pairs have different colors below.
TRAINER_MARKERS = dict(
    StandardTrainer="o",
    JumpReluTrainer="X",
    TopKTrainer="s",
    BatchTopKTrainer="d",
    GatedSAETrainer="d",
    PAnnealTrainer="s",
)

# Fill color passed to the scatter call for each trainer class.
TRAINER_COLORS = dict(
    StandardTrainer="blue",
    JumpReluTrainer="orange",
    TopKTrainer="green",
    BatchTopKTrainer="black",
    GatedSAETrainer="red",
    PAnnealTrainer="purple",
)

In [None]:
# Root directories to search for run folders.
save_dirs = ["./top_k", "./batch_top_k", "./jumprelu"]
# Alternative root:
# save_dirs = ["./run2"]

# Flat list of every folder under the roots that contains a config.json,
# in the order the roots are listed.
ae_paths = []
for save_dir in save_dirs:
    ae_paths += get_nested_folders(save_dir)

print(ae_paths)


In [None]:
# Per-folder summary used for plotting, keyed by the folder path.
plotting_results = {}

for ae_path in ae_paths:
    with open(f"{ae_path}/config.json") as f:
        config = json.load(f)

    with open(f"{ae_path}/eval_results.json") as f:
        eval_results = json.load(f)

    # Keep only the fields the plots need: evaluation metrics from
    # eval_results.json, trainer metadata from config.json.
    ae_results = {
        "l0": eval_results["l0"],
        "frac_recovered": eval_results["frac_recovered"],
        "trainer_class": config["trainer"]["trainer_class"],
        "dict_size": config["trainer"]["dict_size"],
        "frac_alive": eval_results["frac_alive"],
    }
    plotting_results[ae_path] = ae_results

print(plotting_results)

In [None]:
def plot_2var_graph(
    results: dict[str, dict[str, float]],
    custom_metric: str,
    title: str = "L0 vs Custom Metric",
    y_label: str = "Custom Metric",
    xlims: Optional[tuple[float, float]] = None,
    ylims: Optional[tuple[float, float]] = None,
    output_filename: Optional[str] = None,
    legend_location: str = "lower right",
    x_axis_key: str = "l0",
    return_fig: bool = False,
    x_label: Optional[str] = None,
):
    """
    Scatter-plot one metric against another, with one series per trainer class.

    Each entry of `results` becomes one point. Points are grouped by their
    "trainer_class" value and drawn with the marker and color registered for
    that class in TRAINER_MARKERS / TRAINER_COLORS. Entries whose trainer
    class is not a key of TRAINER_MARKERS are not plotted.

    Args:
        results: Mapping from an identifier to a per-run dict. Every per-run
            dict must contain "trainer_class"; plotted entries must also
            contain `x_axis_key` and `custom_metric`.
        custom_metric: Key of the value plotted on the y axis.
        title: Axes title.
        y_label: Y-axis label.
        xlims: Optional (min, max) limits for the x axis.
        ylims: Optional (min, max) limits for the y axis.
        output_filename: If given, the figure is saved to this path.
        legend_location: `loc` argument forwarded to the legend.
        x_axis_key: Key of the value plotted on the x axis.
        return_fig: If True, return the figure instead of showing it.
        x_label: X-axis label. When omitted, "L0 (Sparsity)" is used for the
            default `x_axis_key` of "l0", otherwise `x_axis_key` itself, so
            the label always matches the plotted quantity.

    Returns:
        The matplotlib figure when `return_fig` is True; otherwise None, after
        calling plt.show().
    """
    if x_label is None:
        x_label = "L0 (Sparsity)" if x_axis_key == "l0" else x_axis_key

    fig, ax = plt.subplots(figsize=(10, 6))

    handles, labels = [], []

    for trainer, marker in TRAINER_MARKERS.items():
        # Select the runs that were trained with this trainer class
        trainer_data = [data for data in results.values() if data["trainer_class"] == trainer]

        if not trainer_data:
            continue  # No points for this trainer, so no legend entry either

        trainer_label = TRAINER_LABELS.get(trainer, trainer.capitalize())

        scatter = ax.scatter(
            [data[x_axis_key] for data in trainer_data],
            [data[custom_metric] for data in trainer_data],
            marker=marker,
            s=100,
            label=trainer_label,
            color=TRAINER_COLORS[trainer],
            edgecolor="black",
        )

        # The scatter artist already carries marker, size and colors, so it
        # serves directly as the legend handle (no extra empty artists needed).
        handles.append(scatter)
        labels.append(trainer_label)

    # Set labels and title
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)

    ax.legend(handles, labels, loc=legend_location)

    # Set axis limits
    if xlims:
        ax.set_xlim(*xlims)
    if ylims:
        ax.set_ylim(*ylims)

    fig.tight_layout()

    # Save via the figure object so the right figure is written regardless of
    # pyplot's notion of the "current" figure.
    if output_filename:
        fig.savefig(output_filename, bbox_inches="tight")

    if return_fig:
        return fig

    plt.show()
    
# Set matplotlib's default font size before drawing the figure.
plt.rcParams["font.size"] = 20

# Fraction recovered against L0 for every loaded run; also written to a PNG.
plot_2var_graph(
    plotting_results,
    "frac_recovered",
    title="Fraction Recovered vs L0",
    y_label="Fraction Recovered",
    output_filename="frac_recovered_vs_l0.png",
)