In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
from typing import List, Dict, Any, Tuple
import pandas as pd

import circuits.othello_utils as othello_utils
import neuron_simulation.simulation_config as sim_config
import neuron_simulation.simulate_activations_with_dts as sim_activations

NOTE: This notebook is deprecated in favor of `graph_dt_results.ipynb`.

In [None]:
# Simulation configuration shared by the rest of the notebook.
default_config = sim_config.selected_config


def set_notebook_path(config: sim_config.SimulationConfig):
    """Adjust the path fields of `config` for running from this notebook.

    Mutates ``config`` in place: ``repo_dir`` becomes ``"../"`` (repository root
    assumed to be one level up) and ``output_location`` becomes ``""``.
    The same object is returned so the call can be chained.
    """
    config.repo_dir = "../"
    config.output_location = ""
    return config


# Optionally re-run simulate_activations_with_dts.py directly from this cell.
# default_config values can also drive filtering of the saved pickle files later
# on (see the commented-out dataset_size line in the usage cell).
#
# Example: restrict which configurations get simulated
#
# default_config.n_batches = 2
#
# for combination in default_config.combinations:
#     combination.ablate_not_selected = [True]
#
# sim_activations.run_simulations(default_config)

In [None]:
def load_results_pickle_files(directory: str, dataset_size: int) -> List[Dict[str, Any]]:
    """Load decision-tree result pickles from `directory` for a single dataset size.

    Files are visited in sorted filename order because ``os.listdir`` returns
    entries in arbitrary order. Sorting makes the returned list, and therefore
    the trainer order in downstream plots, reproducible across runs.
    Files whose name contains "ablation" are skipped, as are non-file entries
    (e.g. a directory that happens to end in ``.pkl``).

    Args:
        directory: Folder containing the ``.pkl`` result files.
        dataset_size: Only runs whose ``hyperparameters['dataset_size']`` equals
            this value are returned.

    Returns:
        The unpickled result dicts whose dataset size matches.

    Note:
        ``pickle.load`` can execute arbitrary code. Only use this on result
        files you generated yourself.
    """
    data = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.pkl') or "ablation" in filename:
            continue
        path = os.path.join(directory, filename)
        if not os.path.isfile(path):
            continue
        with open(path, 'rb') as f:
            single_data = pickle.load(f)  # trusted, locally produced files only

        if single_data['hyperparameters']['dataset_size'] == dataset_size:
            data.append(single_data)
    return data

def calculate_good_features(results: Dict[str, Any], threshold: float) -> Tuple[int, int]:
    """Count the features whose decision-tree scores exceed `threshold`.

    Args:
        results: Must provide ``results["decision_tree"]["r2"]`` and
            ``results["binary_decision_tree"]["f1"]``, each convertible by
            ``torch.tensor``.
        threshold: Scores strictly greater than this value count as good.

    Returns:
        ``(n_good_r2, n_good_f1)``: the number of regression-tree R^2 scores and
        binary-tree F1 scores above the threshold.
    """
    def _count_above(scores) -> int:
        return (torch.tensor(scores) > threshold).sum().item()

    n_good_r2 = _count_above(results["decision_tree"]["r2"])
    n_good_f1 = _count_above(results["binary_decision_tree"]["f1"])
    return n_good_r2, n_good_f1

def extract_good_features(
    data: List[Dict[str, Any]],
    threshold: float,
    func_name: str,
    metric: str = 'f1',
) -> Dict[str, Dict[str, Dict[int, int]]]:
    """Count good features for every (input_location, trainer_id, layer).

    Args:
        data: Runs as returned by ``load_results_pickle_files``. Each run has a
            ``'hyperparameters'`` dict (with ``'input_location'`` and
            ``'trainer_id'``) and a ``'results'`` dict keyed by layer
            (presumably int layer indices; verify against the pickles).
        threshold: A feature is "good" when its score is strictly above this.
        func_name: Key selecting, inside each layer's results, which
            input-construction function's decision-tree results to read.
        metric: ``'f1'`` (default) counts binary decision-tree F1 scores.
            ``'dt_r2'`` counts regression decision-tree R^2 scores.

    Returns:
        ``{input_location: {trainer_id: {layer: count}}}``. If several runs share
        the same (location, trainer, layer), the later run's count wins.

    Raises:
        ValueError: If ``metric`` is not ``'f1'`` or ``'dt_r2'``.
    """
    if metric not in ('f1', 'dt_r2'):
        raise ValueError(f"metric must be 'f1' or 'dt_r2', got {metric!r}")

    nested_good_features: Dict[str, Dict[str, Dict[int, int]]] = {}
    for run in data:
        hyperparams = run['hyperparameters']
        input_location: str = hyperparams['input_location']
        trainer_id: str = hyperparams['trainer_id']
        # Create the trainer's entry even if the run has no layers, as before.
        per_layer = nested_good_features.setdefault(input_location, {}).setdefault(trainer_id, {})

        for layer, layer_results in run['results'].items():
            good_dt_r2, good_f1 = calculate_good_features(layer_results[func_name], threshold)
            per_layer[layer] = good_f1 if metric == 'f1' else good_dt_r2

    return nested_good_features

def max_counts_per_layer(nested_counts):
    """Reduce over trainers, keeping each layer's best count.

    Args:
        nested_counts: ``{input_location: {trainer_id: {layer: count}}}``.

    Returns:
        ``{input_location: {layer: max_count}}``. ``max_count`` is the largest
        count any trainer reached for that layer, never below 0.
    """
    result = {}
    for location, per_trainer in nested_counts.items():
        layers_seen = {layer for layer_counts in per_trainer.values() for layer in layer_counts}
        result[location] = {
            layer: max([0] + [counts[layer] for counts in per_trainer.values() if layer in counts])
            for layer in layers_seen
        }
    return result

def plot_good_features(nested_good_features: Dict[str, Dict[str, Dict[int, int]]], metric: str = 'f1', threshold: float = 0.8):
    """Draw one figure per input location showing good-feature counts per layer.

    Each trainer is drawn as its own line.

    Args:
        nested_good_features: ``{input_location: {trainer_id: {layer: count}}}``,
            e.g. from ``extract_good_features``.
        metric: Name of the metric the counts were computed with. It is used
            only in the axis label and title, so it must match how the counts
            were actually produced.
        threshold: Threshold that defined a "good" feature (labels only).
    """
    for input_location, trainer_ids in nested_good_features.items():
        fig, ax = plt.subplots(figsize=(12, 6))
        for trainer_id, layers in trainer_ids.items():
            sorted_layers = sorted(layers)
            values = [layers[layer] for layer in sorted_layers]
            ax.plot(sorted_layers, values, 'o-', label=f'Trainer {trainer_id}')

        ax.set_xlabel("Layer")
        ax.set_ylabel(f"Number of good features ({metric} > {threshold})")
        ax.set_title(f"Good features per Layer ({input_location}, {metric} > {threshold})")
        if ax.get_lines():
            # Avoid matplotlib's "No artists with labels" warning for empty locations.
            ax.legend()
        plt.show()
        # Release the figure so looping over many locations does not accumulate memory.
        plt.close(fig)

def plot_max_counts(max_counts):
    """Plot the best-trainer count per layer, one line per input location.

    Args:
        max_counts: ``{input_location: {layer: max_count}}`` as returned by
            ``max_counts_per_layer``.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    for input_location, layer_results in max_counts.items():
        layers = sorted(layer_results)
        counts = [layer_results[layer] for layer in layers]
        ax.plot(layers, counts, 'o-', label=input_location)

    ax.set_xlabel("Layer")
    ax.set_ylabel("Max Count")
    ax.set_title("Max Count per Layer")
    if ax.get_lines():
        # Avoid matplotlib's "No artists with labels" warning when max_counts is empty.
        ax.legend()
    plt.show()
    plt.close(fig)


# Usage
directory: str = 'decision_trees'
threshold: float = 0.7

func_name: str = othello_utils.games_batch_to_input_tokens_flipped_bs_valid_moves_probe_classifier_input_BLC.__name__

# Hard-coded to match the saved runs. Alternatively derive it from the config:
# dataset_size = default_config.batch_size * default_config.n_batches
dataset_size = 100

data: List[Dict[str, Any]] = load_results_pickle_files(directory, dataset_size=dataset_size)
# Fail fast here rather than with a KeyError several cells later.
assert data, f"No result pickles with dataset_size={dataset_size} found in {directory!r}"

nested_good_features: Dict[str, Dict[str, Dict[int, int]]] = extract_good_features(data, threshold, func_name)

# extract_good_features(data, threshold, func_name) counts features by their
# binary decision-tree F1 score, so the plot must be labelled 'f1'.
# Labelling these counts 'dt_r2' would misreport what is shown.
plot_good_features(nested_good_features, metric='f1', threshold=threshold)

In [None]:
# Best count over all trainers, per input location and layer.
max_counts = max_counts_per_layer(nested_good_features)

# Text summary first, then the same numbers as a plot.
for location, per_layer in max_counts.items():
    print(location, per_layer)

plot_max_counts(max_counts)

In [None]:
def _collect_3var_points(nested_good_features: dict, eval_df: pd.DataFrame, layer: int, input_location: str):
    """Return ``(l0, frac_recovered, colour_values)`` lists for trainers that have results at ``layer``."""
    eval_df_filtered = eval_df[eval_df['layer_idx'] == layer]

    l0, frac_recovered, colour_values = [], [], []
    # .get: an input location with no loaded runs yields no points instead of a KeyError.
    for trainer_id, layer_results in nested_good_features.get(input_location, {}).items():
        if layer not in layer_results:
            continue
        eval_row = eval_df_filtered[eval_df_filtered['trainer_idx'] == int(trainer_id)]
        if eval_row.empty:
            continue
        # If several rows match, the first one is used.
        l0.append(eval_row['l0'].values[0])
        frac_recovered.append(eval_row['frac_recovered'].values[0])
        # np.mean also accepts a per-feature sequence; for a scalar count it is a no-op.
        colour_values.append(np.mean(layer_results[layer]))
    return l0, frac_recovered, colour_values


def graph_3var_results(nested_good_features: dict, eval_df: pd.DataFrame, layer: int, input_location: str):
    """Scatter L0 against fraction recovered per trainer, coloured by good-feature count.

    For each trainer of ``input_location`` with results for ``layer``, the row of
    ``eval_df`` with matching ``layer_idx`` and ``trainer_idx`` supplies the
    ``l0`` / ``frac_recovered`` coordinates. The colour is the value stored at
    ``nested_good_features[input_location][trainer_id][layer]``. With input
    from ``extract_good_features`` that value is a count of good features, not
    a KL divergence, and the labels say so. The figure is saved as a PNG in
    the working directory, then shown.

    Args:
        nested_good_features: ``{input_location: {trainer_id: {layer: value}}}``.
        eval_df: Evaluation table with ``layer_idx``, ``trainer_idx``, ``l0`` and
            ``frac_recovered`` columns.
        layer: Layer to plot.
        input_location: Which input location's trainers to plot.
    """
    l0, frac_recovered, colour_values = _collect_3var_points(nested_good_features, eval_df, layer, input_location)
    if not l0:
        # Nothing to draw: skip rather than raising or saving an empty figure.
        print(f"No matching results for {input_location}, layer {layer}; skipping plot.")
        return

    fig, ax = plt.subplots(figsize=(12, 8))
    scatter = ax.scatter(l0, frac_recovered, c=colour_values, cmap='viridis', s=50)

    ax.set_xlabel('L0', fontsize=12)
    ax.set_ylabel('Fraction Recovered', fontsize=12)
    ax.set_title(f'L0 vs Fraction Recovered vs Number of Good Features\n'
                 f'Layer {layer}, {input_location}', fontsize=14)

    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label(f'Layer {layer} number of good features', fontsize=12)

    ax.grid(True, linestyle='--', alpha=0.7)
    fig.tight_layout()

    fig.savefig(f"l0_vs_frac_recovered_vs_good_features_{input_location}_layer_{layer}.png")
    plt.show()
    # Called in a loop over layers and locations, so release each figure.
    plt.close(fig)

# Evaluation tables giving L0 and fraction recovered per (layer_idx, trainer_idx).
sae_mlp_out_eval = pd.read_csv("sae_eval_csvs/sae_mlp_out_feature_evaluations.csv")
transcoder_eval = pd.read_csv("sae_eval_csvs/transcoder_evaluations.csv")

# For every layer, plot the SAE (mlp_out) first, then the transcoder.
eval_by_location = (
    ('sae_mlp_out_feature', sae_mlp_out_eval),
    ('transcoder', transcoder_eval),
)
for layer in range(8):  # presumably the model's layer count — TODO confirm
    for location_name, eval_table in eval_by_location:
        graph_3var_results(nested_good_features, eval_table, layer, location_name)