In [None]:
import pickle
import pandas as pd
from typing import Callable

import circuits.eval_sae_as_classifier as eval_sae
import circuits.analysis as analysis
import circuits.test_board_reconstruction as test_board_reconstruction
import circuits.get_eval_results as get_eval_results

In [None]:
def initialize_dataframe() -> pd.DataFrame:
    """Create the empty results table that `append_results` fills in.

    Returns:
        A DataFrame with zero rows whose columns are, in order: the run
        identifiers, the autoencoder evaluation metrics, the board
        reconstruction summary, and finally the "zero"/"best" pairs of
        per-threshold board reconstruction metrics.
    """
    run_identifiers = [
        "autoencoder_group_path",
        "autoencoder_path",
        "reconstruction_file",
    ]
    sae_metrics = [
        "eval_results_n_inputs",
        "l0",
        "l1_loss",
        "l2_loss",
        "frac_alive",
        "frac_variance_explained",
        "cossim",
        "l2_ratio",
    ]
    reconstruction_summary = [
        "num_alive_features",
        "board_reconstruction_board_count",
        "best_idx",
        "zero_L0",
        "zero_f1_score",
        "best_L0",
        "best_f1_score",
    ]
    # Each of these metrics is reported twice: at index 0 ("zero") and at the
    # index with the highest F1 score ("best"), zero first.
    paired_metrics = [
        "num_true_positive_squares",
        "num_false_positive_squares",
        "percent_active_classifiers",
        "classifiers_per_token",
        "classified_per_token",
    ]
    paired_columns = [
        f"{prefix}_{metric}" for metric in paired_metrics for prefix in ("zero", "best")
    ]

    column_names = run_identifiers + sae_metrics + reconstruction_summary + paired_columns
    return pd.DataFrame(columns=column_names)

def append_results(
    eval_results: dict,
    board_reconstruction_results: dict,
    custom_functions: list[Callable],
    df: pd.DataFrame,
    autoencoder_group_path: str,
    autoencoder_path: str,
    reconstruction_file: str,
) -> pd.DataFrame:
    """Return `df` extended with one summary row per custom function.

    Args:
        eval_results: Must contain `["hyperparameters"]["n_inputs"]` and an
            `"eval_results"` dict with the keys `l0`, `l1_loss`, `l2_loss`,
            `frac_alive`, `frac_variance_explained`, `cossim` and `l2_ratio`.
        board_reconstruction_results: Must contain `"alive_features"` (only
            `.shape[0]` is read), `"active_per_token"`, and one entry per
            custom function name holding `f1_score`, `num_boards`,
            `num_true_positive_squares`, `num_false_positive_squares`,
            `classifiers_per_token` and `classified_per_token`. The indexed
            values must support `.argmax()` / `.item()`.
        custom_functions: Functions whose `__name__` selects the entries of
            `board_reconstruction_results` to summarise.
        df: Existing results table; it is not modified.
        autoencoder_group_path: Stored verbatim in every new row.
        autoencoder_path: Stored verbatim in every new row.
        reconstruction_file: Stored verbatim in every new row.

    Returns:
        A new DataFrame holding the rows of `df` followed by the new rows,
        with a fresh 0..n-1 index. If `custom_functions` is empty, `df` itself
        is returned.

    Raises:
        KeyError: If any of the keys listed above is missing.
    """
    # NOTE(review): rows do not record which custom function produced them, so
    # rows from different custom functions can only be told apart by position.
    if not custom_functions:
        return df

    sae_metrics = eval_results["eval_results"]
    active_per_token = board_reconstruction_results["active_per_token"]

    # Fields that are identical for every custom function of this autoencoder.
    shared_fields = {
        "autoencoder_group_path": autoencoder_group_path,
        "autoencoder_path": autoencoder_path,
        "reconstruction_file": reconstruction_file,
        "eval_results_n_inputs": eval_results["hyperparameters"]["n_inputs"],
        "l0": sae_metrics["l0"],
        "l1_loss": sae_metrics["l1_loss"],
        "l2_loss": sae_metrics["l2_loss"],
        "frac_alive": sae_metrics["frac_alive"],
        "frac_variance_explained": sae_metrics["frac_variance_explained"],
        "cossim": sae_metrics["cossim"],
        "l2_ratio": sae_metrics["l2_ratio"],
        "num_alive_features": board_reconstruction_results["alive_features"].shape[0],
    }

    new_rows = []
    for custom_function in custom_functions:
        function_name = custom_function.__name__
        function_results = board_reconstruction_results[function_name]

        f1_scores = function_results["f1_score"]
        true_positives = function_results["num_true_positive_squares"]
        false_positives = function_results["num_false_positive_squares"]
        classifiers_per_token = function_results["classifiers_per_token"]
        classified_per_token = function_results["classified_per_token"]

        # "best" columns are read at the index with the highest F1 score,
        # "zero" columns at index 0.
        best_idx = f1_scores.argmax()

        new_rows.append(
            {
                **shared_fields,
                "board_reconstruction_board_count": function_results["num_boards"],
                "best_idx": best_idx.item(),
                "zero_L0": active_per_token[0].item(),
                "best_L0": active_per_token[best_idx].item(),
                "zero_f1_score": f1_scores[0].item(),
                "best_f1_score": f1_scores[best_idx].item(),
                "zero_num_true_positive_squares": true_positives[0].item(),
                "best_num_true_positive_squares": true_positives[best_idx].item(),
                "zero_num_false_positive_squares": false_positives[0].item(),
                "best_num_false_positive_squares": false_positives[best_idx].item(),
                "zero_percent_active_classifiers": (
                    classifiers_per_token[0] / active_per_token[0]
                ).item(),
                "best_percent_active_classifiers": (
                    classifiers_per_token[best_idx] / active_per_token[best_idx]
                ).item(),
                "zero_classifiers_per_token": classifiers_per_token[0].item(),
                "best_classifiers_per_token": classifiers_per_token[best_idx].item(),
                "zero_classified_per_token": classified_per_token[0].item(),
                "best_classified_per_token": classified_per_token[best_idx].item(),
            }
        )

    # Build all new rows in one go instead of concatenating once per row.
    new_rows_df = pd.DataFrame(new_rows)

    if len(df) == 0:
        # Concatenating onto a frame without rows triggers pandas' "empty or
        # all-NA entries" FutureWarning and may change the result dtypes in
        # newer versions. Keep the existing column order and append any
        # columns the table did not have yet, exactly as concat would.
        extra_columns = [name for name in new_rows_df.columns if name not in df.columns]
        return new_rows_df.reindex(columns=[*df.columns, *extra_columns])

    return pd.concat([df, new_rows_df], ignore_index=True)

Basically, just set `autoencoder_group_paths` and the various hyperparameters, then run it. If you already ran, for example, `eval_sae_as_classifier` and don't want to run it again, set `run_eval_sae` to False. Note that in this case, `eval_sae_n_inputs` must match the value used in the previous run, because it is part of the path of the saved file that gets loaded.

When `save_results` is True, each of the 4 functions saves a `.pkl` file (the configuration cell below currently sets `save_results = False`, so enable it if you want to reload results later). We also aggregate and format some of the results into a csv `output_file`. If you already have results `.pkl` files and want a csv, you can set all `run_...` to False, and it will load the results and put them into a csv.

In [None]:
import importlib

import circuits.chess_utils as chess_utils

# Reload the project modules so that local edits are picked up without
# restarting the kernel.
importlib.reload(eval_sae)
importlib.reload(analysis)
importlib.reload(test_board_reconstruction)
importlib.reload(get_eval_results)

# NOTE: This script makes a major assumption here: That all autoencoders in a given group are trained on chess XOR Othello
# We do this so we don't have to reconstruct the dataset for each autoencoder in the group
autoencoder_group_paths = ["../autoencoders/othello_layer5_ef4/"]
# autoencoder_group_paths = ["../autoencoders/chess_layer5/"]

eval_sae_n_inputs = 1000
batch_size = 10
device = "cuda"
model_path = "../models/"
save_results = False

eval_results_n_inputs = 1000
board_reconstruction_n_inputs = 1000

analysis_high_threshold = 0.95
analysis_low_threshold = 0.1
analysis_significance_threshold = 10

run_eval_results = True  # We don't check for this as eval_results are pretty quick to collect

# To skip any of the following steps, set the corresponding variable to False
# The results must have been saved previously
run_eval_sae = True
run_analysis = True
run_board_reconstruction = True

# One shared dataset is built per group, so it has to be large enough for the
# most demanding step.
dataset_size = max(eval_sae_n_inputs, eval_results_n_inputs, board_reconstruction_n_inputs)

# NOTE(review): presumably get_evals consumes more samples than
# `eval_results_n_inputs`, hence the doubling when it sets the size -- confirm.
if dataset_size == eval_results_n_inputs:
    dataset_size *= 2

output_file = "results.csv"

df = initialize_dataframe()


for autoencoder_group_path in autoencoder_group_paths:
    othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)

    indexing_function = eval_sae.get_recommended_indexing_functions(othello)[0]
    if othello:
        # Othello groups use the even-index function instead of the
        # recommended one. Previously this override was applied to every
        # group, which would hand an Othello indexing function to chess
        # autoencoders as well.
        indexing_function = chess_utils.get_othello_even_list_indices

    custom_functions = eval_sae.get_recommended_custom_functions(othello)

    model_name = eval_sae.get_model_name(othello)

    print("Constructing evaluation dataset")
    data = eval_sae.construct_dataset(
        othello, custom_functions, dataset_size, device, models_path=model_path
    )

    folders = eval_sae.get_nested_folders(autoencoder_group_path)

    for autoencoder_path in folders:

        # For debugging
        # if "ef=4_lr=1e-03_l1=1e-01_layer=5" not in autoencoder_path:
        #     continue

        # Every step receives its own copy of the shared dataset.
        eval_results = get_eval_results.get_evals(
            autoencoder_path,
            eval_results_n_inputs,
            device,
            model_path,
            model_name,
            data.copy(),
            othello=othello,
            save_results=save_results,
        )

        # The feature-label and reconstruction file names below are derived
        # from this location.
        expected_aggregation_output_location = eval_sae.get_output_location(
            autoencoder_path, n_inputs=eval_sae_n_inputs, indexing_function=indexing_function
        )

        if run_eval_sae:
            print("Aggregating", autoencoder_path)
            aggregation_results = eval_sae.aggregate_statistics(
                custom_functions=custom_functions,
                autoencoder_path=autoencoder_path,
                n_inputs=eval_sae_n_inputs,
                batch_size=batch_size,
                device=device,
                model_path=model_path,
                model_name=model_name,
                data=data.copy(),
                indexing_function=indexing_function,
                othello=othello,
                save_results=save_results,
            )
        else:
            # SECURITY: pickle.load can execute arbitrary code. Only load files
            # that were written by a previous run of this notebook.
            with open(expected_aggregation_output_location, "rb") as f:
                aggregation_results = pickle.load(f)

        expected_feature_labels_output_location = expected_aggregation_output_location.replace(
            "results.pkl", "feature_labels.pkl"
        )
        if run_analysis:
            feature_labels = analysis.analyze_results_dict(
                aggregation_results,
                output_path=expected_feature_labels_output_location,
                device=device,
                high_threshold=analysis_high_threshold,
                low_threshold=analysis_low_threshold,
                significance_threshold=analysis_significance_threshold,
                verbose=False,
                print_results=False,
                save_results=save_results,
            )
        else:
            # SECURITY: see the pickle note above.
            with open(expected_feature_labels_output_location, "rb") as f:
                feature_labels = pickle.load(f)

        expected_reconstruction_output_location = expected_aggregation_output_location.replace(
            "results.pkl", "reconstruction.pkl"
        )

        if run_board_reconstruction:
            print("Testing board reconstruction")
            board_reconstruction_results = test_board_reconstruction.test_board_reconstructions(
                custom_functions=custom_functions,
                autoencoder_path=autoencoder_path,
                feature_labels=feature_labels,
                output_file=expected_reconstruction_output_location,
                n_inputs=board_reconstruction_n_inputs,
                batch_size=batch_size,
                device=device,
                model_name=model_name,
                data=data.copy(),
                othello=othello,
                print_results=False,
                save_results=save_results,
            )
        else:
            # SECURITY: see the pickle note above.
            with open(expected_reconstruction_output_location, "rb") as f:
                board_reconstruction_results = pickle.load(f)

        df = append_results(
            eval_results,
            board_reconstruction_results,
            custom_functions,
            df,
            autoencoder_group_path,
            autoencoder_path,
            expected_reconstruction_output_location,
        )

        print("Finished", autoencoder_path)

        # Save the dataframe after each autoencoder so we don't lose data if the script crashes
        df.to_csv(output_file)

        

Example of gathering top k contexts

In [None]:
import importlib

import torch
import circuits.chess_interp as chess_interp

# Reload so that local edits to chess_interp are picked up without a restart.
importlib.reload(chess_interp)

# Gradients are not needed from here on.
torch.set_grad_enabled(False)

autoencoder_group_path = autoencoder_group_paths[0]

# Select the autoencoder explicitly rather than relying on the loop variable
# left over from the evaluation cell. That leftover is the last folder of the
# *last* group, which need not belong to `autoencoder_group_path`.
autoencoder_path = list(eval_sae.get_nested_folders(autoencoder_group_path))[-1]

othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)

indexing_functions = eval_sae.get_recommended_indexing_functions(othello)
indexing_function = indexing_functions[0]

custom_functions = eval_sae.get_recommended_custom_functions(othello)

model_name = eval_sae.get_model_name(othello)

print("Constructing evaluation dataset")
data = eval_sae.construct_dataset(othello, custom_functions, dataset_size, device, models_path=model_path)

# x2 to make sure we have enough data for loss_recovered().
# Kept in its own name so that re-running this cell does not keep doubling the
# global `dataset_size`.
# NOTE(review): `data` above was constructed with `dataset_size`, yet twice
# that value is passed below -- confirm this is intended.
doubled_dataset_size = dataset_size * 2

data, ae_bundle, pgn_strings, encoded_inputs = eval_sae.prep_firing_rate_data(
    autoencoder_path,
    doubled_dataset_size,
    model_path,
    model_name,
    data,
    device,
    doubled_dataset_size,
    othello,
)

dims = torch.tensor([10], device=device)
chess_interp.examine_dimension_chess(ae_bundle, 100, dims)