In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import einops
from collections import defaultdict
from typing import Callable, Optional
import os
import importlib
import pickle
import pandas as pd
from typing import List, Dict, Tuple, Optional, Union

import circuits.utils as utils
import circuits.othello_utils as othello_utils
import neuron_simulation.simulation_config as sim_config
import neuron_simulation.simulate_activations_with_dts as sim_activations

In [None]:
# Shared simulation configuration, taken from neuron_simulation.simulation_config.
default_config = sim_config.selected_config
# NOTE(review): `device` is not referenced by any other cell visible in this notebook.
device = "cpu"

# Optional: simulate_activations_with_dts.py can be run from this cell by
# uncommenting the code below.
# Values from the default config are also used later in the notebook to select
# which saved pickle files are loaded.

# Example: narrow the configuration before running the simulations.

# default_config.n_batches = 2

# for combination in default_config.combinations:
#     combination.ablate_not_selected = [True]

# sim_activations.run_simulations(default_config)

In [None]:
def load_ablation_pickle_files(
    directory: str,
    dataset_size: int,
    ablation_method: str,
    ablate_not_selected: bool,
    add_error: bool,
) -> List[Dict]:
    """Load the saved ablation runs in ``directory`` that match the given settings.

    Only files whose name ends in ``.pkl`` and contains ``"ablation"`` are
    opened. Each one must unpickle to a mapping with a ``"hyperparameters"``
    entry; the run is kept when its ``ablate_not_selected``,
    ``ablation_method`` and ``dataset_size`` hyperparameters equal the
    arguments. ``add_error`` is only compared for runs whose
    ``input_location`` is not ``"mlp_neuron"``.

    Files are visited in sorted filename order. ``os.listdir`` alone returns
    entries in arbitrary order, which would make the returned list (and
    anything order-dependent built from it, such as ``data[0]`` or a
    last-writer-wins merge) vary between machines.

    Args:
        directory: Folder that holds the pickle files.
        dataset_size: Required value of the ``dataset_size`` hyperparameter.
        ablation_method: Required value of the ``ablation_method`` hyperparameter.
        ablate_not_selected: Required value of the ``ablate_not_selected`` hyperparameter.
        add_error: Required value of the ``add_error`` hyperparameter
            (ignored for ``"mlp_neuron"`` runs).

    Returns:
        The unpickled objects of all matching runs, in sorted filename order.

    Raises:
        KeyError: If a loaded file lacks one of the hyperparameters that is inspected.
        OSError: If ``directory`` or one of its files cannot be read.
    """
    data = []
    # SECURITY: pickle.load can execute arbitrary code while loading, so only
    # point this function at directories whose contents you produced yourself.
    for filename in sorted(os.listdir(directory)):
        if not (filename.endswith(".pkl") and "ablation" in filename):
            continue

        with open(os.path.join(directory, filename), "rb") as f:
            single_data = pickle.load(f)

        hyperparams = single_data["hyperparameters"]
        if hyperparams["ablate_not_selected"] != ablate_not_selected:
            continue

        # `add_error` is deliberately not checked for raw MLP neuron runs.
        if hyperparams["input_location"] != "mlp_neuron":
            if hyperparams["add_error"] != add_error:
                continue

        if hyperparams["ablation_method"] != ablation_method:
            continue

        if hyperparams["dataset_size"] != dataset_size:
            continue

        data.append(single_data)
    return data


def extract_mean_ablation_results(
    data: List[Dict],
    desired_metric: str,
    desired_layer_tuples: Optional[List[Tuple[int]]] = None,
) -> Dict:
    """Collect one mean-ablation metric per (input location, trainer, layer tuple).

    The result is nested as
    ``"<input_location>_mean_ablate" -> trainer_id -> layer_tuple -> metric``.
    Within a run, every custom function stored for a layer tuple is asserted
    to carry the same value for ``desired_metric``; that shared value is what
    gets recorded.

    Args:
        data: Loaded runs, each with ``"hyperparameters"`` (providing
            ``"input_location"`` and ``"trainer_id"``) and ``"results"``
            (``layer_tuple -> function name -> metric name -> value``).
        desired_metric: ``"kl"`` or ``"patch_accuracy"``.
        desired_layer_tuples: If given, only these layer tuples are kept.

    Returns:
        The nested dictionary described above. Runs that share an input
        location are merged, so every trainer is retained.

    Raises:
        ValueError: If ``desired_metric`` is not an allowed metric.
        AssertionError: If the custom functions of a layer tuple disagree on the metric.
    """
    allowed_metrics = ["kl", "patch_accuracy"]
    if desired_metric not in allowed_metrics:
        raise ValueError(f"desired_metric must be one of {allowed_metrics}")

    nested_results = {}
    for run in data:
        hyperparams = run["hyperparameters"]
        input_location = hyperparams["input_location"] + "_mean_ablate"
        trainer_id = hyperparams["trainer_id"]

        # Use setdefault instead of plain assignment: rebinding the whole
        # per-location dict on every run would discard the trainers gathered
        # from earlier runs at the same input location.
        trainer_results = nested_results.setdefault(input_location, {}).setdefault(trainer_id, {})

        for layer_tuple, func_results in run["results"].items():
            if desired_layer_tuples is not None and layer_tuple not in desired_layer_tuples:
                continue

            prev_result = None
            for idx, func_name in enumerate(func_results):
                result = func_results[func_name][desired_metric]
                trainer_results[layer_tuple] = result

                # Sanity check: all custom functions must report the same value.
                if idx > 0:
                    assert prev_result == result

                prev_result = result

    return nested_results


def extract_ablation_results(
    data: List[Dict],
    custom_function_names: list[str],
    desired_metric: str,
    group_by: str,
    input_location_filter: Optional[str] = None,
    func_name_filter: Optional[str] = None,
    desired_layer_tuples: Optional[List[Tuple[int]]] = None,
) -> Dict:
    """Collect one ablation metric per (group, trainer, layer tuple).

    The result is nested as ``group -> trainer_id -> layer_tuple -> metric``,
    where the group is either the run's input location or the custom function
    name, depending on ``group_by``.

    When grouping by input location with several custom functions selected,
    the functions share one slot per layer tuple, so the last function in
    ``custom_function_names`` that has a result wins.

    Args:
        data: Loaded runs, each with ``"hyperparameters"`` (providing
            ``"input_location"`` and ``"trainer_id"``) and ``"results"``
            (``layer_tuple -> function name -> metric name -> value``).
        custom_function_names: Function names to look up in each run's results.
        desired_metric: ``"kl"`` or ``"patch_accuracy"``.
        group_by: ``"input_location"`` or ``"custom_function"``.
        input_location_filter: If given, only runs at this input location are used.
        func_name_filter: If given, only this custom function is used.
        desired_layer_tuples: If given, only these layer tuples are kept.

    Returns:
        The nested dictionary described above.

    Raises:
        ValueError: If ``desired_metric`` or ``group_by`` is not an allowed value.
    """
    allowed_metrics = ["kl", "patch_accuracy"]
    if desired_metric not in allowed_metrics:
        raise ValueError(f"desired_metric must be one of {allowed_metrics}")

    group_by_options = ["input_location", "custom_function"]
    if group_by not in group_by_options:
        raise ValueError(f"group_by must be one of {group_by_options}")

    nested_results = {}
    for run in data:
        hyperparams = run["hyperparameters"]
        input_location = hyperparams["input_location"]
        trainer_id = hyperparams["trainer_id"]

        if input_location_filter is not None and input_location != input_location_filter:
            continue

        for custom_function_name in custom_function_names:
            if func_name_filter is not None and custom_function_name != func_name_filter:
                continue

            primary_key = input_location if group_by == "input_location" else custom_function_name

            # The group/trainer entry is created even if no layer tuple ends
            # up matching, so callers always see every selected group.
            trainer_results = nested_results.setdefault(primary_key, {}).setdefault(trainer_id, {})

            for layer_tuple, func_results in run["results"].items():
                if desired_layer_tuples is not None and layer_tuple not in desired_layer_tuples:
                    continue
                if custom_function_name in func_results:
                    trainer_results[layer_tuple] = func_results[custom_function_name][desired_metric]

    return nested_results


def min_or_max_metric_per_layer(nested_data, min_or_max: str):
    """Reduce ``group -> trainer_id -> layer -> value`` to ``group -> layer -> best value``.

    For every group, each layer that appears under at least one trainer is
    mapped to the smallest (``"min"``) or largest (``"max"``) value found for
    that layer across the group's trainers.

    Args:
        nested_data: Mapping ``group -> trainer_id -> layer -> value``.
        min_or_max: ``"min"`` or ``"max"``.

    Returns:
        Mapping ``group -> layer -> extreme value``.

    Raises:
        ValueError: If ``min_or_max`` is neither ``"min"`` nor ``"max"``.
    """
    allowed_options = ["min", "max"]
    if min_or_max not in allowed_options:
        raise ValueError(f"min_or_max must be one of {allowed_options}")

    # Choose the reducer and its neutral starting value once, up front.
    want_min = min_or_max == "min"
    pick = min if want_min else max
    start = float("inf") if want_min else float("-inf")

    best_by_group = {}
    for group_key, per_trainer in nested_data.items():
        layers_seen = set()
        for per_layer in per_trainer.values():
            layers_seen.update(per_layer.keys())

        group_best = {}
        for layer in layers_seen:
            best = start
            for per_layer in per_trainer.values():
                if layer in per_layer:
                    best = pick(best, per_layer[layer])
            group_best[layer] = best

        best_by_group[group_key] = group_best

    return best_by_group


def plot_grouped_metrics(nested_metrics, ablation_method):
    """Draw one line per (group, trainer) showing the metric for every layer.

    Args:
        nested_metrics: Mapping ``group -> trainer_id -> layer -> value``.
            Layers missing for a given trainer are plotted as NaN.
        ablation_method: Name of the ablation method, shown in the title.
    """
    plt.figure(figsize=(12, 6))

    # The x axis is the sorted union of every layer key seen anywhere.
    layer_keys = set()
    for per_trainer in nested_metrics.values():
        for per_layer in per_trainer.values():
            layer_keys.update(per_layer.keys())
    ordered_layers = sorted(layer_keys)
    x_positions = range(len(ordered_layers))

    for group_name, per_trainer in nested_metrics.items():
        for trainer_id, per_layer in per_trainer.items():
            plt.plot(
                x_positions,
                [per_layer.get(layer, np.nan) for layer in ordered_layers],
                "o-",
                label=f"{group_name} - Trainer {trainer_id}",
            )

    plt.xlabel("Layer")
    plt.ylabel("Ablation Metric")
    plt.title(f"Ablation Results per Layer ({ablation_method} ablation)")
    plt.legend()

    tick_text = [str(layer) for layer in ordered_layers]
    plt.xticks(x_positions, tick_text, rotation=45, ha="right")

    plt.tight_layout()
    plt.show()


# --- Driver: load saved ablation runs, extract one metric and plot it per layer ---

# Folder that is scanned for saved ablation pickle files.
directory = "decision_trees"
# Names of the custom functions whose results may be looked up in each run.
custom_function_names = [
    # othello_utils.games_batch_to_input_tokens_flipped_bs_classifier_input_BLC.__name__,
    othello_utils.games_batch_to_board_state_classifier_input_BLC.__name__,
    othello_utils.games_batch_to_input_tokens_classifier_input_BLC.__name__,
    othello_utils.games_batch_to_input_tokens_flipped_classifier_input_BLC.__name__,
]

default_config = sim_config.selected_config
# NOTE(review): `dataset_size` is bound three times in a row; only the last
# value (80) takes effect, so the value derived from the config is unused.
dataset_size = default_config.batch_size * default_config.n_batches
dataset_size = 100
dataset_size = 80

# Filters applied to the hyperparameters of the saved runs.
ablation_method = "dt"
ablate_not_selected = True
add_error = False

# NOTE(review): this list of three-layer windows is discarded - the name is
# rebound to an empty list right below and refilled with single-layer tuples.
desired_layer_tuples = [
    (0, 1, 2),
    (1, 2, 3),
    (2, 3, 4),
    (3, 4, 5),
    (4, 5, 6),
    (5, 6, 7),
    # (0, 1, 2, 3, 4, 5, 6, 7),
]

desired_layer_tuples = []

# Effective selection: the single-layer tuples (0,), (1,), ..., (7,).
for i in range(8):
    desired_layer_tuples.append((i,))

desired_ablation_metric = "patch_accuracy"
group_by = "input_location"
# group_by = "custom_function"

# Restrict extraction to the first custom function listed above.
func_name_filter = custom_function_names[0]
# func_name_filter = None
# input_location_filter = "mlp_neuron"
input_location_filter = None

data = load_ablation_pickle_files(
    directory, dataset_size, ablation_method, ablate_not_selected, add_error
)

# Positional arguments: ablation_method="mean", ablate_not_selected=True, add_error=True.
mean_ablate_data = load_ablation_pickle_files(directory, dataset_size, "mean", True, True)

# Raises IndexError if no saved run matched the filters above.
example_dict = data[0]

# Inspect the structure of one loaded run (the second print dumps the whole run).
print(example_dict["hyperparameters"].keys())
print(example_dict)

# Current dict structure: input_location / custom function -> trainer_id -> layer_tuple -> kl_divergence or other metric

metric_per_layers = extract_ablation_results(
    data,
    custom_function_names,
    desired_ablation_metric,
    group_by,
    input_location_filter=input_location_filter,
    func_name_filter=func_name_filter,
    desired_layer_tuples=desired_layer_tuples,
)
mean_ablate_per_layers = extract_mean_ablation_results(
    mean_ablate_data, desired_ablation_metric, desired_layer_tuples
)

# Add the mean-ablation groups next to the main results. Their keys carry the
# "_mean_ablate" suffix appended by extract_mean_ablation_results; an existing
# entry under the same key would be replaced.
for key in mean_ablate_per_layers.keys():
    metric_per_layers[key] = mean_ablate_per_layers[key]

print(metric_per_layers.keys())
print(metric_per_layers)
plot_grouped_metrics(metric_per_layers, ablation_method)

In [None]:
def graph_3var_results(metric_per_layer: dict, eval_df: pd.DataFrame, layer: int, input_location: str):
    """Scatter L0 against fraction recovered for one layer, coloured by the stored metric.

    One point is drawn per trainer that has a result for the single-layer
    tuple ``(layer,)`` and a row in ``eval_df`` with the same layer and
    trainer index. The figure is saved as a PNG in the working directory and
    then shown.

    NOTE(review): the title, colour-bar label and file name hard-code
    "KL Divergence", but the colour values are whatever metric
    ``metric_per_layer`` holds - confirm the labels match the metric in use.

    Args:
        metric_per_layer: Mapping ``input_location -> trainer_id -> layer_tuple -> value``;
            each value must support ``.item()``.
        eval_df: Evaluation table with ``layer_idx``, ``trainer_idx``, ``l0``
            and ``frac_recovered`` columns.
        layer: Layer index to plot.
        input_location: Key of ``metric_per_layer`` to plot.
    """
    layer_key = (layer,)
    # Keep only the evaluation rows that belong to the requested layer.
    layer_rows = eval_df[eval_df["layer_idx"] == layer]

    # Each entry is (l0, fraction recovered, metric value) for one trainer.
    points = []
    for trainer_id, per_layer in metric_per_layer[input_location].items():
        if layer_key not in per_layer:
            continue
        trainer_rows = layer_rows[layer_rows["trainer_idx"] == int(trainer_id)]
        if trainer_rows.empty:
            continue
        points.append(
            (
                trainer_rows["l0"].values[0],
                trainer_rows["frac_recovered"].values[0],
                per_layer[layer_key].item(),
            )
        )

    l0_values = [point[0] for point in points]
    recovered_values = [point[1] for point in points]
    colour_values = [point[2] for point in points]

    plt.figure(figsize=(12, 8))
    points_plot = plt.scatter(l0_values, recovered_values, c=colour_values, cmap="viridis", s=50)

    plt.xlabel("L0", fontsize=12)
    plt.ylabel("Fraction Recovered", fontsize=12)
    plt.title(
        "L0 vs Fraction Recovered vs Neuron Simulation KL Divergence\n"
        f"Layer {layer}, {input_location}",
        fontsize=14,
    )

    colour_bar = plt.colorbar(points_plot)
    colour_bar.set_label(f"Layer {layer} KL Divergence", fontsize=12)

    plt.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()

    # Save before showing the figure.
    plt.savefig(f"l0_vs_frac_recovered_vs_kl_divergence_{input_location}_layer_{layer}.png")
    plt.show()

# Load evaluation data.
# graph_3var_results reads the 'layer_idx', 'trainer_idx', 'l0' and
# 'frac_recovered' columns from these tables.
sae_mlp_out_eval = pd.read_csv("sae_eval_csvs/sae_mlp_out_feature_evaluations.csv")
transcoder_eval = pd.read_csv("sae_eval_csvs/transcoder_evaluations.csv")


# Plot layers 4 through 7. `metric_per_layers` comes from the previous cell and
# must contain the 'sae_mlp_out_feature' and 'transcoder' keys, otherwise
# graph_3var_results raises KeyError. Each call also saves a PNG.
for layer in range(4, 8):

    # Plot for sae_mlp_out_feature
    graph_3var_results(metric_per_layers, sae_mlp_out_eval, layer, 'sae_mlp_out_feature')

    # Plot for transcoder
    graph_3var_results(metric_per_layers, transcoder_eval, layer, 'transcoder')

In [None]:
# Currently broken, just need to get the dictionary keys right
# This can be used to plot multiple thresholds for a single location / custom function combination

# colors = ['b', 'g', 'r']  # Colors for different thresholds
# markers = ['o', 's', '^']  # Markers for different thresholds

# for custom_function in custom_functions:
#     function_name = custom_function.__name__
    
#     plt.figure(figsize=(12, 7))
    
#     for i, threshold in enumerate(thresholds):
#         f1_counts = accuracy_by_layer[function_name][threshold]
        
#         plt.plot(intervention_layers, f1_counts, color=colors[i], marker=markers[i], 
#                  label=f'Threshold: {threshold}', linewidth=2, markersize=8)
        
#         # Optionally, add value labels on each point
#         for j, count in enumerate(f1_counts):
#             plt.annotate(str(count), (intervention_layers[j], count), 
#                          textcoords="offset points", xytext=(0,5), ha='center', 
#                          fontsize=8, color=colors[i])
    
#     plt.title(f'Number of Neurons with F1 > Threshold for \n{function_labels[custom_function]} by Layer \n{dataset_size} datapoints Input location: {input_location} depth: {max_depth}', fontsize=14)
#     plt.xlabel('Layer Number', fontsize=12)
#     plt.ylabel('F1 Count', fontsize=12)
#     plt.legend(loc='best', fontsize=10)
#     plt.grid(True, linestyle='--', alpha=0.7)
    
#     plt.tight_layout()
#     output_filename = f"images/{input_location}_{function_name}_inputs_{dataset_size}_depth_{max_depth}_f1_count_by_layer_all_thresholds.png"
#     plt.savefig(output_filename, dpi=300, bbox_inches='tight')
#     plt.show()

# print("All graphs have been created and saved.")

In [None]:
# Currently broken, just need to get the dictionary keys right

# layer = 1
# neuron_idx = 421
# custom_function_name = custom_functions[0].__name__

# decision_tree_filename = f"decision_trees/decision_trees_{input_location}_{dataset_size}.pkl"

# with open(decision_tree_filename, "rb") as f:
#     decision_tree = pickle.load(f)

# layer_dt = decision_tree[layer][custom_function_name]['decision_tree']['model']
# layer_dt = decision_tree[layer][custom_function_name]['binary_decision_tree']['model']

# games_BLC = train_data[custom_function_name]

# feature_names = []

# X_binary_train, X_binary_test, y_binary_train, y_binary_test = prepare_data(
#     games_BLC, binary_acts[layer]
# )
# accuracy, precision, recall, f1 = calculate_binary_metrics(
#     decision_tree[layer][custom_function_name]["binary_decision_tree"]["model"], X_binary_test, y_binary_test
# )

# print(f"Neuron 421 F1: {f1[neuron_idx]}")

# for i in range(games_BLC.shape[2]):
#     if i < 64:
#         square = idx_to_square_notation(i)
#         feature_names.append(f"Input_{square}")
#     elif i < 128:
#         j = i - 64
#         square = idx_to_square_notation(j)
#         feature_names.append(f"Occupied_{square}")
#     else:
#         feature_names.append(f"Output_{i}")

# print_decision_tree_rules(layer_dt, feature_names=feature_names, neuron_index=neuron_idx, max_depth=5)

