In [None]:
import matplotlib.pyplot as plt
import json
import torch
import pickle
from typing import Optional
from matplotlib.colors import Normalize
import numpy as np
import os

from graphing_utils import plot_2var_graph, plot_3var_graph, plot_interactive_3var_graph, plot_steps_vs_average_diff
from formatting_utils import get_sparsity_penalty, ae_config_results, add_custom_metric_results, filter_by_l0_threshold

In [None]:
# Directory that receives every figure saved by this notebook.
image_folder_name = "images"

# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(image_folder_name, exist_ok=True)

In [None]:
# Release whose sparse-probing evaluation results are loaded below.
release_name = "sae_bench_pythia70m_sweep_topk_ctx128_0730"

folder_path = "sparse_probing/src/sparse_probing_results"
filename = f"example_results_{release_name}_eval_results.json"

filepath = os.path.join(folder_path, filename)

# Pin the encoding so that loading does not depend on the platform's default locale
# (open() otherwise falls back to the locale encoding, e.g. cp1252 on Windows).
with open(filepath, "r", encoding="utf-8") as f:
    eval_results = json.load(f)

In [None]:
# Inspect the loaded results: top-level sections, then the SAE names,
# then the metric names recorded for one example SAE.
print(eval_results.keys())

custom_eval_results = eval_results["custom_eval_results"]
print(custom_eval_results.keys())

example_sae_name = "pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_0"
print(custom_eval_results[example_sae_name].keys())

In [None]:
# Built with os.path.join for consistency with how `filepath` is assembled for the eval results.
sae_data_filename = os.path.join("sae_bench_data", f"{release_name}_data.json")

# Pin the encoding so that loading does not depend on the platform's default locale.
with open(sae_data_filename, "r", encoding="utf-8") as f:
    sae_data = json.load(f)

In [None]:
# Inspect the SAE data file: top-level sections, then the entries under the basic evals.
print(sae_data.keys())

basic_eval_results = sae_data["basic_eval_results"]
print(basic_eval_results.keys())

In [None]:
# Gather, for every evaluated SAE, the trainer class, L0, frac_recovered and
# the top-k sparse-probing test accuracy into one flat dict per SAE.
k = 100
custom_metric = f"sae_top_{k}_test_accuracy"

plotting_results = {}
for sae_name, probe_results in eval_results["custom_eval_results"].items():
    trainer_config = sae_data["sae_config_dictionary_learning"][sae_name]["trainer"]
    basic_evals = sae_data["basic_eval_results"][sae_name]

    plotting_results[sae_name] = {
        "trainer_class": trainer_config["trainer_class"],
        "l0": basic_evals["l0"],
        "frac_recovered": basic_evals["frac_recovered"],
        custom_metric: probe_results[custom_metric],
    }

In [None]:
custom_metric_name = f"{k}-Sparse Probe Accuracy"
title = f"L0 vs Loss Recovered vs {custom_metric_name}"
image_base_name = os.path.join(image_folder_name, custom_metric)

# Both static figures share the metric key as their filename stem.
three_var_png = f"{image_base_name}_3var.png"
two_var_png = f"{image_base_name}_2var.png"

plot_3var_graph(
    plotting_results,
    title,
    custom_metric,
    colorbar_label="Custom Metric",
    output_filename=three_var_png,
)
plot_2var_graph(plotting_results, custom_metric, title=title, output_filename=two_var_png)
# plot_interactive_3var_graph(plotting_results, custom_metric)

# At this point, if there are any additional .json files located alongside ae.pt and eval_results.json,
# you can easily adapt them to be included in the plotting_results dictionary by using something similar to add_ae_config_results()

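TODO: Fix the code below. The remaining cells expect TPP/SCR metrics (and variables such as `threshold` and `model_name`) that the cells above do not load or define.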

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


def plot_correlation_heatmap(
    plotting_results: dict[str, dict[str, float]],
    metric_names: list[str],
    ae_names: Optional[list[str]] = None,
    title: str = "Metric Correlation Heatmap",
    output_filename: Optional[str] = None,
    figsize: tuple = (12, 10),
    cmap: str = "coolwarm",
    annot: bool = True,
):
    """Plot a heatmap of pairwise correlations between metrics across autoencoders.

    One row is built per autoencoder and one column per metric; the correlation
    matrix comes from ``DataFrame.corr()`` with its default settings.

    Args:
        plotting_results: Mapping from autoencoder name to its ``{metric: value}`` dict.
        metric_names: Metrics to correlate. A metric missing for an autoencoder is
            filled with NaN for that row.
        ae_names: Autoencoders to include; defaults to every key of ``plotting_results``.
            Every name must be a key of ``plotting_results`` (otherwise ``KeyError``).
        title: Figure title.
        output_filename: If given (and non-empty), the figure is also saved to this path.
        figsize: Figure size passed to ``plt.figure``.
        cmap: Colormap name passed to ``sns.heatmap``.
        annot: Whether to write the correlation value inside each cell.

    Returns:
        None. The figure is shown (and optionally saved) as a side effect.
    """
    # If ae_names is not provided, use all ae_names from plotting_results
    if ae_names is None:
        ae_names = list(plotting_results.keys())

    # Nothing to correlate: report it and return, mirroring plot_metric_scatter,
    # instead of handing an empty matrix to the plotting layer.
    if not ae_names or not metric_names:
        print("No autoencoders or metrics to correlate.")
        return

    # One row per autoencoder, one column per metric (NaN where a metric is absent).
    data = [
        [plotting_results[ae].get(metric, np.nan) for metric in metric_names]
        for ae in ae_names
    ]
    df = pd.DataFrame(data, index=ae_names, columns=metric_names)

    # Calculate the correlation matrix
    corr_matrix = df.corr()

    # Create the heatmap; the color scale is fixed to [-1, 1] and centered on 0
    # so that colors are comparable between figures.
    plt.figure(figsize=figsize)
    sns.heatmap(corr_matrix, annot=annot, cmap=cmap, vmin=-1, vmax=1, center=0)

    plt.title(title)
    plt.tight_layout()

    # Save the plot if output_filename is provided
    if output_filename:
        plt.savefig(output_filename, bbox_inches="tight")

    plt.show()


# Metrics to correlate, grouped by family and concatenated in display order.
baseline_metric_keys = ["l0", "frac_recovered"]
tpp_attrib_keys = [
    "tpp_attrib_threshold_10_total_metric",
    "tpp_attrib_threshold_50_total_metric",
    "tpp_attrib_threshold_500_total_metric",
]
tpp_auto_interp_keys = [
    "tpp_auto_interp_threshold_10_total_metric",
    "tpp_auto_interp_threshold_50_total_metric",
]
metric_keys = baseline_metric_keys + tpp_attrib_keys + tpp_auto_interp_keys

# ae_names is left at its default (None), i.e. every entry of plotting_results is used.
plot_correlation_heatmap(plotting_results, metric_names=metric_keys)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import stats
from typing import Optional


def plot_metric_scatter(
    plotting_results: dict[str, dict[str, float]],
    metric_x: str,
    metric_y: str,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    ae_names: Optional[list[str]] = None,
    title: str = "Metric Comparison Scatter Plot",
    output_filename: Optional[str] = None,
    figsize: tuple = (10, 8),
):
    """Scatter two metrics against each other and report their Pearson correlation.

    Autoencoders for which either metric is missing or NaN are dropped. If fewer
    than two points remain, a message is printed and nothing is plotted.

    Args:
        plotting_results: Mapping from autoencoder name to its ``{metric: value}`` dict.
        metric_x: Metric key plotted on the x axis.
        metric_y: Metric key plotted on the y axis.
        x_label: X-axis label; defaults to ``metric_x``.
        y_label: Y-axis label; defaults to ``metric_y``.
        ae_names: Autoencoders to include; defaults to every key of ``plotting_results``.
            Every name must be a key of ``plotting_results`` (otherwise ``KeyError``).
        title: Figure title.
        output_filename: If given (and non-empty), the figure is also saved to this path.
        figsize: Figure size passed to ``plt.figure``.

    Returns:
        None. The figure is shown (and optionally saved) and the correlation
        statistics are printed as side effects.
    """
    # If ae_names is not provided, use all ae_names from plotting_results
    if ae_names is None:
        ae_names = list(plotting_results.keys())

    # Pair up the two metrics per autoencoder; a missing metric counts as NaN.
    nan = float("nan")
    points = [
        (plotting_results[ae].get(metric_x, nan), plotting_results[ae].get(metric_y, nan))
        for ae in ae_names
    ]

    # Remove any pair containing a NaN value
    points = [(x, y) for x, y in points if not (np.isnan(x) or np.isnan(y))]
    if not points:
        print("No valid data points after removing NaN values.")
        return
    # stats.pearsonr raises ValueError for fewer than two observations.
    if len(points) < 2:
        print("Need at least two valid data points to compute a correlation.")
        return

    # Convert to numpy arrays
    x_values, y_values = (np.array(axis_values) for axis_values in zip(*points))

    # Calculate correlation coefficients
    r, p_value = stats.pearsonr(x_values, y_values)
    r_squared = r**2

    # Create the scatter plot
    plt.figure(figsize=figsize)
    sns.scatterplot(x=x_values, y=y_values, label="SAE", color="blue")

    if x_label is None:
        x_label = metric_x
    if y_label is None:
        y_label = metric_y

    # Add labels and title
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)

    # Add a trend line; its legend entry carries the correlation coefficient
    sns.regplot(x=x_values, y=y_values, scatter=False, color="red", label=f"r = {r:.4f}")

    plt.legend()

    plt.tight_layout()

    # Save the plot if output_filename is provided
    if output_filename:
        plt.savefig(output_filename, bbox_inches="tight")

    plt.show()

    # Print correlation coefficients
    print(f"Pearson correlation coefficient (r): {r:.4f}")
    print(f"Coefficient of determination (r²): {r_squared:.4f}")
    print(f"P-value: {p_value:.4f}")

# Example usage:
# plot_metric_scatter(plotting_results, metric_x="l0", metric_y="frac_recovered", title="L0 vs Fraction Recovered")

# NOTE(review): `threshold` and `model_name` were used below without being defined in any
# visible cell, so this cell raised NameError on a fresh kernel. The values here are
# placeholders (50 matches the threshold used in the heatmap metric keys) — confirm them
# against the TPP/SCR evaluation that is being plotted.
threshold = 50
model_name = "pythia70m"

metric1 = f"tpp_auto_interp_threshold_{threshold}_total_metric"
metric2 = f"tpp_attrib_threshold_{threshold}_total_metric"
# The figure title is left empty.
title = ""

x_label = "TPP with LLM Judge"
y_label = "TPP without LLM Judge"
output_filename = os.path.join(image_folder_name, f"tpp_comparison_{threshold}_{model_name}.png")

plot_metric_scatter(
    plotting_results,
    metric_x=metric1,
    metric_y=metric2,
    title=title,
    x_label=x_label,
    y_label=y_label,
    output_filename=output_filename,
)


In [None]:
# Show which metric keys are available, using the first SAE as a sample.
first_key = next(iter(plotting_results))
print(plotting_results[first_key].keys())

# Wider catalogue of TPP / SCR metric keys, kept for reference when choosing what to
# correlate. It used to be assigned to `metric_keys` and was overwritten right away,
# so it never reached the heatmap; it now has its own name.
all_metric_keys = [
    "l0",
    "frac_recovered",
    "tpp_attrib_threshold_20_total_metric",
    "tpp_attrib_threshold_50_total_metric",
    "tpp_attrib_threshold_500_total_metric",
    "tpp_auto_interp_threshold_20_total_metric",
    "tpp_auto_interp_threshold_50_total_metric",
    "scr_bias_shift_dir2_threshold_20",
    "scr_bias_shift_dir2_threshold_50",
    "scr_bias_shift_dir1_threshold_20",
    "scr_bias_shift_dir1_threshold_50",
    "scr_attrib_dir2_threshold_20",
    "scr_attrib_dir2_threshold_50",
    "scr_attrib_dir1_threshold_20",
    "scr_attrib_dir1_threshold_50",
]

# Subset that is actually plotted in the heatmap.
metric_keys = [
    "l0",
    "frac_recovered",
    "tpp_attrib_threshold_50_total_metric",
    "tpp_auto_interp_threshold_50_total_metric",
    "scr_bias_shift_dir2_threshold_50",
    # "scr_bias_shift_dir1_threshold_50",
    "scr_attrib_dir2_threshold_50",
    # "scr_attrib_dir1_threshold_50",
]

plot_correlation_heatmap(plotting_results, metric_names=metric_keys, ae_names=None, annot=True)

In [None]:
# Metric to visualize; relies on `threshold` having been set by an earlier cell.
# The former first assignment was overwritten immediately, so it is kept only as a
# documented alternative: f"scr_bias_shift_dir1_threshold_{threshold}"
custom_metric = f"scr_attrib_dir2_threshold_{threshold}"

title = f"L0 vs Loss Recovered vs {custom_metric}"

plot_3var_graph(plotting_results, title, custom_metric)
plot_2var_graph(plotting_results, custom_metric, title=title, y_label="Custom Metric")
plot_interactive_3var_graph(plotting_results, custom_metric)

In [None]:
# List the metric keys stored for the first SAE (`first_key` comes from an earlier cell).
first_entry = plotting_results[first_key]
print(first_entry.keys())

In [None]:
# Compare the two SCR variants at the same threshold; relies on `threshold` and
# `model_name` having been set by an earlier cell.
metric1 = f"scr_bias_shift_dir1_threshold_{threshold}"
metric2 = f"scr_attrib_dir1_threshold_{threshold}"
# The title is left empty. The former f"{metric1} vs {metric2}" assignment was
# overwritten on the next line and never took effect.
title = ""

output_filename = os.path.join(image_folder_name, f"scr_comparison_{threshold}_{model_name}.png")

plot_metric_scatter(
    plotting_results,
    metric_x=metric1,
    metric_y=metric2,
    title=title,
    x_label="SHIFT with LLM Judge",
    y_label="SHIFT without LLM Judge",
    output_filename=output_filename,
)

In [None]:
# Cross-compare a TPP metric with an SCR metric at the same threshold;
# relies on `threshold` having been set by an earlier cell.
metric1 = f"tpp_auto_interp_threshold_{threshold}_total_metric"
metric2 = f"scr_attrib_dir1_threshold_{threshold}"
title = f"{metric1} vs {metric2}"

plot_metric_scatter(
    plotting_results,
    metric_x=metric1,
    metric_y=metric2,
    title=title,
)