In [None]:
import os
import pickle
from typing import Callable

import pandas as pd
import torch

import circuits.eval_sae_as_classifier as eval_sae
import circuits.analysis as analysis
import circuits.test_board_reconstruction as test_board_reconstruction
import circuits.get_eval_results as get_eval_results

In [None]:
# For multi-GPU evaluation
from collections import deque
from joblib import Parallel, delayed

from circuits.utils import to_device

N_GPUS = 1
# Pool of free devices. Each evaluation job pops one device, uses it, and pushes it back;
# the joblib "sharedmem" backend used below runs jobs as threads that share this deque.
RESOURCE_STACK = deque(f"cuda:{gpu_index}" for gpu_index in range(N_GPUS))

In [None]:
def initialize_dataframe(custom_functions: list[Callable]) -> pd.DataFrame:
    """Create an empty results table with one column per recorded metric.

    The table starts with the columns shared by every autoencoder (paths, input
    counts and eval metrics). These are followed by one group of
    board-reconstruction columns per custom function, each prefixed with that
    function's ``__name__``.

    Args:
        custom_functions: Functions whose names prefix the per-function columns.

    Returns:
        An empty DataFrame with the columns in the order described above.
    """
    shared_columns = (
        "autoencoder_group_path",
        "autoencoder_path",
        "reconstruction_file",
        "eval_sae_n_inputs",
        "eval_results_n_inputs",
        "board_reconstruction_n_inputs",
        "l0",
        "l1_loss",
        "l2_loss",
        "frac_alive",
        "frac_variance_explained",
        "cossim",
        "l2_ratio",
        "loss_original",
        "loss_reconstructed",
        "loss_zero",
        "frac_recovered",
        "num_alive_features",
    )

    # Per-function metrics. "zero_*" / "best_*" hold the values that append_results reads
    # at index 0 and at the argmax of the F1 score respectively.
    # The *_percent_active_classifiers, *_classifiers_per_token and *_classified_per_token
    # columns are currently disabled.
    per_function_metrics = (
        "board_reconstruction_board_count",
        "num_squares",
        "best_idx",
        "zero_L0",
        "zero_f1_score",
        "best_L0",
        "best_f1_score",
        "zero_num_true_positive_squares",
        "best_num_true_positive_squares",
        "zero_num_false_positive_squares",
        "best_num_false_positive_squares",
    )

    columns = list(shared_columns)
    for func in custom_functions:
        prefix = func.__name__
        columns.extend(f"{prefix}_{metric}" for metric in per_function_metrics)

    return pd.DataFrame(columns=columns)

def append_results(
    eval_results: dict,
    aggregate_results: dict,
    board_reconstruction_results: dict,
    custom_functions: list[Callable],
    df: pd.DataFrame,
    autoencoder_group_path: str,
    autoencoder_path: str,
    reconstruction_file: str,
) -> pd.DataFrame:
    """Append one autoencoder's results as a new row and return the updated DataFrame.

    Args:
        eval_results: Output of ``get_evals``. Reads ``["hyperparameters"]["n_inputs"]``
            and the metrics under ``["eval_results"]``.
        aggregate_results: Output of the SAE aggregation step. Only
            ``["hyperparameters"]["n_inputs"]`` is read.
        board_reconstruction_results: Output of the board-reconstruction step. Reads
            ``hyperparameters``, ``alive_features``, ``active_per_token`` and one
            sub-dict per custom function (keyed by ``__name__``).
        custom_functions: Functions whose per-function metrics are recorded.
        df: Existing results table. It may be empty, e.g. straight from
            ``initialize_dataframe``.
        autoencoder_group_path, autoencoder_path, reconstruction_file: Stored verbatim.

    Returns:
        A new DataFrame with the row appended. If ``df`` is empty, the result keeps
        ``df``'s column order, with any columns it lacks appended at the end. This
        matches how the non-empty ``pd.concat`` branch aligns columns.
    """
    metrics = eval_results["eval_results"]
    new_row = {
        "autoencoder_group_path": autoencoder_group_path,
        "autoencoder_path": autoencoder_path,
        "reconstruction_file": reconstruction_file,
        "eval_sae_n_inputs": aggregate_results["hyperparameters"]["n_inputs"],
        "eval_results_n_inputs": eval_results["hyperparameters"]["n_inputs"],
        "board_reconstruction_n_inputs": board_reconstruction_results["hyperparameters"]["n_inputs"],
    }
    for metric_name in (
        "l0",
        "l1_loss",
        "l2_loss",
        "frac_alive",
        "frac_variance_explained",
        "cossim",
        "l2_ratio",
        "loss_original",
        "loss_reconstructed",
        "loss_zero",
        "frac_recovered",
    ):
        new_row[metric_name] = metrics[metric_name]
    new_row["num_alive_features"] = board_reconstruction_results["alive_features"].shape[0]

    # Shared by all custom functions: L0 at each index of the reconstruction results.
    active_per_token = board_reconstruction_results["active_per_token"]

    for custom_function in custom_functions:
        name = custom_function.__name__
        fn_results = board_reconstruction_results[name]
        # "best" is the index with the highest F1 score; "zero" is index 0.
        best_idx = fn_results["f1_score"].argmax()

        new_row[f"{name}_board_reconstruction_board_count"] = fn_results["num_boards"]
        new_row[f"{name}_num_squares"] = fn_results["num_squares"]
        new_row[f"{name}_best_idx"] = best_idx.item()
        new_row[f"{name}_zero_L0"] = active_per_token[0].item()
        new_row[f"{name}_best_L0"] = active_per_token[best_idx].item()
        for metric_name in ("f1_score", "num_true_positive_squares", "num_false_positive_squares"):
            values = fn_results[metric_name]
            new_row[f"{name}_zero_{metric_name}"] = values[0].item()
            new_row[f"{name}_best_{metric_name}"] = values[best_idx].item()

    new_row_df = pd.DataFrame([new_row])

    if df.empty:
        # Concatenating with an empty frame is skipped, presumably to avoid pandas'
        # empty-entry concat deprecation. Keep the caller's declared column order
        # instead of silently discarding it.
        ordered_columns = list(df.columns) + [c for c in new_row_df.columns if c not in df.columns]
        return new_row_df.reindex(columns=ordered_columns)
    return pd.concat([df, new_row_df], ignore_index=True)

Set `autoencoder_group_paths` and the hyperparameters in the cell below, then run it. If you have already run a step (for example `eval_sae_as_classifier`) and don't want to run it again, set the corresponding flag (here `run_eval_sae`) to False. In this case `eval_sae_n_inputs` must match the value used in the previous run. The paths of the saved aggregation, feature-label, and reconstruction files are all derived from it, so a mismatch means the saved files cannot be found.

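By default `save_results` is True, so each of the 4 pipeline functions saves a `.pkl` file. Some of the results are also aggregated and formatted into CSV files: a `results.csv` in each autoencoder's folder and a combined `results.csv` in the group folder. If you already have the results `.pkl` files and only want the CSVs, set `run_eval_sae`, `run_analysis`, and `run_board_reconstruction` to False; the saved results will then be loaded instead of recomputed. Eval results (`get_evals`) are always recomputed, because `run_eval_results` is not checked.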

In [None]:
import importlib

# Reload the project modules so source edits are picked up without restarting the kernel.
# The order matters for modules that import each other, so chess_utils stays last.
importlib.reload(eval_sae)
importlib.reload(analysis)
importlib.reload(test_board_reconstruction)
importlib.reload(get_eval_results)
import circuits.chess_utils as chess_utils

importlib.reload(chess_utils)

# --- Autoencoders to evaluate --------------------------------------------------------
# NOTE: every autoencoder inside one group is assumed to be trained on chess XOR Othello,
# so the evaluation dataset is built once per group rather than once per autoencoder.
# autoencoder_group_paths = ["../autoencoders/othello_layer5_ef4/"]
# autoencoder_group_paths = ["../autoencoders/chess_layer5/"]
autoencoder_group_paths = ["../autoencoders/group-2024-05-07/"]

# --- General settings ----------------------------------------------------------------
model_path = "../models/"
batch_size = 100
save_results = True
mask = False

# --- Number of inputs used by each pipeline stage ------------------------------------
eval_sae_n_inputs = 1000
eval_results_n_inputs = 1000
board_reconstruction_n_inputs = 1000

# --- Thresholds forwarded to analysis.analyze_results_dict ---------------------------
analysis_high_threshold = 0.95
analysis_low_threshold = 0.1
analysis_significance_threshold = 10

# --- Which stages to (re)run ---------------------------------------------------------
# Eval results are quick to collect and are always recomputed; this flag is not checked.
run_eval_results = True
# A stage whose flag is False loads the .pkl saved by a previous run instead.
run_eval_sae = True
run_analysis = True
run_board_reconstruction = True

# One dataset is built per group, so it must be big enough for every stage.
dataset_size = max(eval_sae_n_inputs, eval_results_n_inputs, board_reconstruction_n_inputs)
# The eval-results stage needs strictly more data than eval_results_n_inputs,
# otherwise it reaches the end of the data stream.
if dataset_size == eval_results_n_inputs:
    dataset_size = dataset_size * 4

for autoencoder_group_path in autoencoder_group_paths:
    othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)

    indexing_functions = eval_sae.get_recommended_indexing_functions(othello)
    indexing_function = indexing_functions[0]

    custom_functions = eval_sae.get_recommended_custom_functions(othello)
    if not othello:
        # Example custom functions. These chess_utils functions (pins, castling rights,
        # checks) are chess-specific, so Othello groups keep the recommended functions.
        custom_functions = [
            chess_utils.board_to_piece_state,
            chess_utils.board_to_pin_state,
            chess_utils.board_to_has_castling_rights,
            chess_utils.board_to_check_state,
            chess_utils.board_to_can_check_next,
        ]

    model_name = eval_sae.get_model_name(othello)

    # If True, precompute everything and store it in VRAM. Faster, but far higher memory usage
    # If True, VRAM scales with batch size and n_inputs
    # If False, VRAM scales with batch size only
    precompute = True

    print("Constructing evaluation dataset")
    device = RESOURCE_STACK.pop()
    try:
        data = eval_sae.construct_dataset(
            othello, custom_functions, dataset_size, device, models_path=model_path, precompute_dataset=precompute
        )
    finally:
        # Return the GPU to the pool even if dataset construction raises.
        RESOURCE_STACK.append(device)
    del device

    folders = eval_sae.get_nested_folders(autoencoder_group_path)

    def _run_pipeline(autoencoder_path, device):
        """Run every evaluation stage for one autoencoder on ``device``.

        Each stage either runs or, when its ``run_*`` flag is False, loads the
        ``.pkl`` saved by a previous run. The resulting one-row DataFrame is also
        written to ``<autoencoder_path>/results.csv`` and returned.
        """
        df = initialize_dataframe(custom_functions)

        # For debugging
        # if "ef=4_lr=1e-03_l1=1e-01_layer=5" not in autoencoder_path:
        #     return df

        # If this is set, everything below should be reproducible
        # Then we can just save results from 1 run, make optimizations, and check that the results are the same
        # The determinism is only needed for getting activations from the activation buffer for finding alive features
        torch.manual_seed(0)
        eval_results = get_eval_results.get_evals(
            autoencoder_path,
            eval_results_n_inputs,
            device,
            model_path,
            model_name,
            to_device(data.copy(), device),
            othello=othello,
            save_results=save_results,
        )

        # The analysis and reconstruction file paths below are derived from this one,
        # so they all depend on eval_sae_n_inputs.
        expected_aggregation_output_location = eval_sae.get_output_location(
            autoencoder_path, n_inputs=eval_sae_n_inputs, indexing_function=indexing_function
        )

        if run_eval_sae:
            print("Aggregating", autoencoder_path)
            aggregation_results = eval_sae.aggregate_statistics(
                custom_functions=custom_functions,
                autoencoder_path=autoencoder_path,
                n_inputs=eval_sae_n_inputs,
                batch_size=batch_size,
                device=device,
                model_path=model_path,
                model_name=model_name,
                data=to_device(data.copy(), device),
                indexing_function=indexing_function,
                othello=othello,
                save_results=save_results,
                precomputed=precompute,
            )
        else:
            # NOTE: pickle.load executes arbitrary code; only load files this pipeline wrote.
            with open(expected_aggregation_output_location, "rb") as f:
                aggregation_results = pickle.load(f)

        expected_feature_labels_output_location = expected_aggregation_output_location.replace(
            "results.pkl", "feature_labels.pkl"
        )
        if run_analysis:
            feature_labels = analysis.analyze_results_dict(
                aggregation_results,
                output_path=expected_feature_labels_output_location,
                device=device,
                high_threshold=analysis_high_threshold,
                low_threshold=analysis_low_threshold,
                significance_threshold=analysis_significance_threshold,
                verbose=False,
                print_results=False,
                save_results=save_results,
                mask=mask,
            )
        else:
            with open(expected_feature_labels_output_location, "rb") as f:
                feature_labels = pickle.load(f)

        expected_reconstruction_output_location = expected_aggregation_output_location.replace(
            "results.pkl", "reconstruction.pkl"
        )

        if run_board_reconstruction:
            print("Testing board reconstruction")
            board_reconstruction_results = test_board_reconstruction.test_board_reconstructions(
                custom_functions=custom_functions,
                autoencoder_path=autoencoder_path,
                feature_labels=feature_labels,
                output_file=expected_reconstruction_output_location,
                n_inputs=board_reconstruction_n_inputs,
                batch_size=batch_size,
                device=device,
                model_name=model_name,
                data=to_device(data.copy(), device),
                othello=othello,
                print_results=False,
                save_results=save_results,
                precomputed=precompute,
                mask=mask,
            )
        else:
            with open(expected_reconstruction_output_location, "rb") as f:
                board_reconstruction_results = pickle.load(f)

        df = append_results(
            eval_results,
            aggregation_results,
            board_reconstruction_results,
            custom_functions,
            df,
            autoencoder_group_path,
            autoencoder_path,
            expected_reconstruction_output_location,
        )

        print("Finished", autoencoder_path)

        # Save the dataframe after each autoencoder so we don't lose data if the script crashes
        df.to_csv(os.path.join(autoencoder_path, "results.csv"))
        return df

    def full_eval_pipeline(autoencoder_path):
        """Borrow a device from RESOURCE_STACK, evaluate one autoencoder, and always return the device."""
        device = RESOURCE_STACK.pop()
        try:
            return _run_pipeline(autoencoder_path, device)
        finally:
            # Without this, a failing autoencoder would permanently remove its GPU from the
            # pool, and later jobs (or later runs of this cell) would pop from an empty deque.
            RESOURCE_STACK.append(device)

    dfs = Parallel(n_jobs=N_GPUS, require="sharedmem")(
        delayed(full_eval_pipeline)(autoencoder_path) for autoencoder_path in folders
    )
    if dfs:
        pd.concat(dfs, axis=0, ignore_index=True).to_csv(os.path.join(autoencoder_group_path, "results.csv"))
    else:
        # pd.concat([]) raises ValueError, so report the empty group instead of crashing.
        print("No autoencoders found in", autoencoder_group_path)

Example of gathering top k contexts

In [None]:
# import torch
# import circuits.chess_interp as chess_interp
# importlib.reload(chess_interp)

# torch.set_grad_enabled(False)

# autoencoder_group_path = autoencoder_group_paths[0]

# othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)

# indexing_functions = eval_sae.get_recommended_indexing_functions(othello)
# indexing_function = indexing_functions[0]

# custom_functions = eval_sae.get_recommended_custom_functions(othello)

# model_name = eval_sae.get_model_name(othello)

# device = RESOURCE_STACK.pop()
# print("Constructing evaluation dataset")
# data = eval_sae.construct_dataset(othello, custom_functions, dataset_size, device, models_path=model_path)


# dataset_size = dataset_size * 2  # x2 to make sure we have enough data for loss_recovered()


# # TODO: set `autoencoder_path`
# data, ae_bundle, pgn_strings, encoded_inputs = eval_sae.prep_firing_rate_data(
#     autoencoder_path, dataset_size, model_path, model_name, data, device, dataset_size, othello
# )

# dims = torch.tensor([10], device=device)
# chess_interp.examine_dimension_chess(ae_bundle, 100, dims)

# RESOURCE_STACK.append(device)
# del device