In [None]:
import pickle

import circuits.eval_sae_as_classifier as eval_sae
import circuits.analysis as analysis
import circuits.test_board_reconstruction as test_board_reconstruction
import circuits.get_eval_results as get_eval_results

Set `autoencoder_group_paths` and the hyperparameters in the cell below, then run it. If you have already run a step (for example `eval_sae_as_classifier`) and don't want to run it again, set the corresponding flag (here `run_eval_sae`) to `False`. Note that in this case `eval_sae_n_inputs` must match the value used in the previous run, because the location of the saved results file is derived from it.

In [None]:
import importlib

# Re-import the project modules so that edits made on disk take effect
# without restarting the kernel. Order matches the original reload sequence.
for _module in (eval_sae, analysis, test_board_reconstruction, get_eval_results):
    importlib.reload(_module)

# NOTE: This script makes a major assumption: all autoencoders in a given group are trained on
# the same game (chess XOR Othello). That lets us construct the evaluation dataset once per
# group instead of once per autoencoder.
# Alternative group (Othello): ["../autoencoders/othello_layer5_ef4/"]
autoencoder_group_paths = ["../autoencoders/chess_layer5/"]

# --- Shared settings ---
eval_sae_n_inputs = 10
batch_size = 10
device = "cuda"
model_path = "../models/"
save_results = True

eval_results_n_inputs = 100
board_reconstruction_n_inputs = 10

# --- Thresholds forwarded to analysis.analyze_results_dict ---
analysis_high_threshold = 0.95
analysis_low_threshold = 0.1
analysis_significance_threshold = 10

run_eval_results = True # We don't check for this as eval_results are pretty quick to collect

# To skip any of the following steps, set the corresponding variable to False
# The results must have been saved previously
run_eval_sae = True
run_analysis = True
run_board_reconstruction = True

# The dataset is built once and shared by every step, so it must cover the largest request.
# The eval-results step is budgeted at twice `eval_results_n_inputs` (the previous code doubled
# the dataset only when that step was the largest request, which under-sized the dataset when
# another step asked for more than `eval_results_n_inputs` but less than twice that amount).
# NOTE(review): the 2x factor is presumed to be what get_eval_results.get_evals needs - confirm.
dataset_size = max(
    eval_sae_n_inputs,
    2 * eval_results_n_inputs,
    board_reconstruction_n_inputs,
)


def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE: unpickling can execute arbitrary code; only load files written by this pipeline.
    with open(path, "rb") as f:
        return pickle.load(f)


# For debugging: only autoencoder folders whose path contains this substring are processed.
autoencoder_name_filter = "ef=4_lr=1e-03_l1=1e-01_layer=5"

for autoencoder_group_path in autoencoder_group_paths:
    # Group-level setup: the game type drives the helper functions, the model name and the dataset.
    othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)

    indexing_functions = eval_sae.get_recommended_indexing_functions(othello)
    indexing_function = indexing_functions[0]
    custom_functions = eval_sae.get_recommended_custom_functions(othello)
    model_name = eval_sae.get_model_name(othello)

    # Built once per group and handed to each step as a copy.
    print("Constructing evaluation dataset")
    data = eval_sae.construct_dataset(othello, custom_functions, dataset_size, device, models_path=model_path)

    folders = eval_sae.get_nested_folders(autoencoder_group_path)

    for autoencoder_path in folders:
        if autoencoder_name_filter not in autoencoder_path:
            continue

        # Step 1: eval results. Always collected; there is no skip flag for this step.
        eval_results = get_eval_results.get_evals(
            autoencoder_path,
            eval_results_n_inputs,
            device,
            model_path,
            model_name,
            data.copy(),
            othello=othello,
            save_results=save_results,
        )

        # Output locations for the remaining steps. The feature-label and reconstruction
        # files are named by substituting the file name in the aggregation results path.
        expected_aggregation_output_location = eval_sae.get_output_location(
            autoencoder_path, n_inputs=eval_sae_n_inputs, indexing_function=indexing_function
        )
        expected_feature_labels_output_location = expected_aggregation_output_location.replace(
            "results.pkl", "feature_labels.pkl"
        )
        expected_reconstruction_output_location = expected_aggregation_output_location.replace(
            "results.pkl", "reconstruction.pkl"
        )

        # Keyword arguments that the aggregation and board reconstruction steps have in common.
        shared_kwargs = dict(
            custom_functions=custom_functions,
            autoencoder_path=autoencoder_path,
            batch_size=batch_size,
            device=device,
            model_name=model_name,
            othello=othello,
            save_results=save_results,
        )

        # Step 2: aggregate statistics, or reuse the file saved by an earlier run.
        if not run_eval_sae:
            aggregation_results = _load_pickle(expected_aggregation_output_location)
        else:
            print("Aggregating", autoencoder_path)
            aggregation_results = eval_sae.aggregate_statistics(
                n_inputs=eval_sae_n_inputs,
                model_path=model_path,
                data=data.copy(),
                indexing_function=indexing_function,
                **shared_kwargs,
            )

        # Step 3: turn the aggregated statistics into feature labels, or reuse the saved ones.
        if not run_analysis:
            feature_labels = _load_pickle(expected_feature_labels_output_location)
        else:
            feature_labels = analysis.analyze_results_dict(
                aggregation_results,
                output_path=expected_feature_labels_output_location,
                device=device,
                high_threshold=analysis_high_threshold,
                low_threshold=analysis_low_threshold,
                significance_threshold=analysis_significance_threshold,
                verbose=False,
                print_results=False,
                save_results=save_results,
            )

        # Step 4: board reconstruction from the feature labels, or reuse the saved results.
        if not run_board_reconstruction:
            board_reconstruction_results = _load_pickle(expected_reconstruction_output_location)
        else:
            print("Testing board reconstruction")
            board_reconstruction_results = test_board_reconstruction.test_board_reconstructions(
                feature_labels=feature_labels,
                output_file=expected_reconstruction_output_location,
                n_inputs=board_reconstruction_n_inputs,
                data=data.copy(),
                print_results=False,
                **shared_kwargs,
            )

        print("Finished", autoencoder_path)



        

Example of gathering top k contexts

In [None]:
import torch
import circuits.chess_interp as chess_interp

# Disables autograd globally for the rest of the kernel session.
torch.set_grad_enabled(False)

autoencoder_group_path = autoencoder_group_paths[0]

othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)

indexing_functions = eval_sae.get_recommended_indexing_functions(othello)
indexing_function = indexing_functions[0]

custom_functions = eval_sae.get_recommended_custom_functions(othello)

model_name = eval_sae.get_model_name(othello)

# Select the autoencoder explicitly instead of relying on the loop variable left behind by the
# evaluation cell. The last folder of the group is what that loop variable ends up holding when a
# single group is configured; change the index to examine a different autoencoder.
candidate_autoencoder_paths = list(eval_sae.get_nested_folders(autoencoder_group_path))
if not candidate_autoencoder_paths:
    raise FileNotFoundError(f"No autoencoder folders found under {autoencoder_group_path}")
autoencoder_path = candidate_autoencoder_paths[-1]

# x2 to make sure we have enough data for loss_recovered().
# Kept in its own variable so re-running this cell does not keep doubling `dataset_size`, and
# computed before the dataset is built so the dataset really contains the amount requested below.
top_k_dataset_size = dataset_size * 2

print("Constructing evaluation dataset")
data = eval_sae.construct_dataset(othello, custom_functions, top_k_dataset_size, device, models_path=model_path)

# The size is passed twice positionally, exactly as before.
# NOTE(review): presumably these are the batch size and the number of inputs - confirm against
# eval_sae.prep_firing_rate_data.
data, ae_bundle, pgn_strings, encoded_inputs = eval_sae.prep_firing_rate_data(
    autoencoder_path, top_k_dataset_size, model_path, model_name, data, device, top_k_dataset_size, othello
)

# NOTE(review): 10 looks like the index of the feature to examine and 100 like the number of
# contexts to gather - confirm against chess_interp.examine_dimension_chess.
dims = torch.tensor([10], device=device)
chess_interp.examine_dimension_chess(ae_bundle, 100, dims, model_path=model_path)