In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle

import circuits.othello_utils as othello_utils

def load_results_pickle_files(directory):
    """Load every non-ablation results pickle found directly in ``directory``.

    A file is loaded when its name ends with ``.pkl`` and does not contain the
    substring ``"ablation"``. Subdirectories are not searched.

    Args:
        directory: Path of the folder to scan.

    Returns:
        A list with one unpickled object per matching file, ordered by
        filename (an empty list when nothing matches).

    Raises:
        OSError: Propagated from ``os.listdir``/``open`` (for example
            ``FileNotFoundError`` when ``directory`` does not exist).
    """
    # os.listdir yields entries in arbitrary order; sorting makes the result
    # reproducible between runs and machines.
    filenames = sorted(
        name
        for name in os.listdir(directory)
        if name.endswith('.pkl') and "ablation" not in name
    )

    data = []
    for filename in filenames:
        # NOTE: unpickling can execute arbitrary code, so only point this at
        # directories whose contents are trusted.
        with open(os.path.join(directory, filename), 'rb') as f:
            data.append(pickle.load(f))
    return data

def calculate_good_features(results, threshold):
    """Count the scores that lie strictly above ``threshold``.

    Args:
        results: Mapping that holds score sequences under
            ``results["decision_tree"]["r2"]`` and
            ``results["binary_decision_tree"]["f1"]`` (presumably one score
            per feature; confirm against the code that writes the pickles).
        threshold: Cut-off; a score is counted only when it is greater than
            this value.

    Returns:
        A ``(good_dt_r2, good_f1)`` tuple with the two counts.
    """
    def count_above(scores):
        # Boolean mask of passing scores, reduced to a plain Python number.
        passing = torch.tensor(scores) > threshold
        return passing.sum().item()

    r2_scores = results["decision_tree"]["r2"]
    n_good_r2 = count_above(r2_scores)

    f1_scores = results["binary_decision_tree"]["f1"]
    n_good_f1 = count_above(f1_scores)

    return n_good_r2, n_good_f1

def extract_good_features(data, thresholds, func_name: str):
    """Build a nested table of good-feature counts for every run in ``data``.

    Args:
        data: Iterable of run dicts. Each run must provide
            ``run['hyperparameters']`` (with ``'input_location'`` and
            ``'trainer_id'``) and ``run['results']`` (a mapping from layer to
            a dict that is indexed by ``func_name``).
        thresholds: Cut-offs passed one at a time to
            ``calculate_good_features``.
        func_name: Key selecting which entry of each layer's results to score.

    Returns:
        A dict shaped as
        ``{input_location: {trainer_id: {threshold: {layer: {func_name:
        {'dt_r2': count, 'f1': count}}}}}}``. When two runs share the same
        input location, trainer id and layer, the later run overwrites the
        earlier one.
    """
    summary = {}
    for entry in data:
        config = entry['hyperparameters']
        location = config['input_location']
        trainer = config['trainer_id']

        by_trainer = summary.setdefault(location, {})
        if trainer not in by_trainer:
            # First time this trainer is seen: one empty slot per threshold.
            by_trainer[trainer] = {cutoff: {} for cutoff in thresholds}
        by_threshold = by_trainer[trainer]

        for layer, layer_results in entry['results'].items():
            scores = layer_results[func_name]
            for cutoff in thresholds:
                n_r2, n_f1 = calculate_good_features(scores, cutoff)
                by_threshold[cutoff].setdefault(layer, {})[func_name] = {
                    'dt_r2': n_r2,
                    'f1': n_f1,
                }

    return summary

def plot_good_features(nested_good_features, metric='f1', threshold=0.8):
    """Draw one figure per input location showing good-feature counts by layer.

    Each figure holds one line per trainer id.

    Args:
        nested_good_features: Nested dict indexed as
            ``[input_location][trainer_id][threshold][layer][function][metric]``.
        metric: Key of the count to plot inside the innermost dict.
        threshold: Which threshold's counts to plot; it must be a key of every
            trainer's dict.

    Returns:
        None. Figures are displayed with ``plt.show()``.
    """
    for input_location, per_trainer in nested_good_features.items():
        plt.figure(figsize=(12, 6))

        for trainer_id, per_threshold in per_trainer.items():
            per_layer = per_threshold[threshold]
            layers = sorted(per_layer)
            # Every layer is assumed to be keyed by the same custom function,
            # so the first key of the first layer is reused for all of them.
            custom_function_name = list(per_layer[layers[0]])[0]

            values = []
            for layer in layers:
                values.append(per_layer[layer][custom_function_name][metric])

            plt.plot(layers, values, 'o-', label=f'Trainer {trainer_id}')

        y_label = f"Number of good features ({metric} > {threshold})"
        title = f"Good features per Layer ({input_location}, {metric} > {threshold})"
        plt.xlabel("Layer")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.show()

# Usage
# Folder scanned for result pickles (*.pkl files whose name lacks "ablation").
directory = 'decision_trees'
# Score cut-offs; good-feature counts are computed for each of them.
thresholds = [0.7, 0.8, 0.9]

# Name of the custom function; used as the key read from every layer's results.
func_name = othello_utils.games_batch_to_input_tokens_flipped_bs_valid_moves_classifier_input_BLC.__name__

data = load_results_pickle_files(directory)
nested_good_features = extract_good_features(data, thresholds, func_name)

# Plot for F1 score with threshold 0.8
# (the threshold passed here must be one of the values in `thresholds`).
plot_good_features(nested_good_features, metric='f1', threshold=0.8)

# Plot for Decision Tree R2 with threshold 0.8
# plot_good_features(nested_good_features, metric='dt_r2', threshold=0.8)