In [None]:
import pickle
import pandas as pd
from typing import Callable, Optional
import torch
import os
import einops
import matplotlib.pyplot as plt

import circuits.eval_sae_as_classifier as eval_sae
import circuits.analysis as analysis
import circuits.eval_board_reconstruction as eval_board_reconstruction
import circuits.get_eval_results as get_eval_results
import circuits.f1_analysis as f1_analysis
import circuits.utils as utils
import circuits.pipeline_config as pipeline_config

There are some optional parameters you can change, but it will run without issue using the defaults.

We just need to pass in `autoencoder_path` and `autoencoder_group_path` and it will load all of the required information.

With a batch size of 5 and `config.analysis_on_cpu = True`, peak GPU memory usage is around 2.5 GB.

In [None]:
# Run on the first CUDA GPU when one is present; otherwise fall back to the CPU
# so the notebook can still be executed on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The group path identifies the game (chess / othello); the autoencoder path is the single SAE we evaluate
autoencoder_group_path = "../autoencoders/testing_chess/"
autoencoder_path = "../autoencoders/testing_chess/trainer4/"

othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)
config = pipeline_config.Config()

# These both significantly reduce peak GPU memory usage
config.batch_size = 5
config.analysis_on_cpu = True

# Precompute will create both datasets and save them as pickle files
# If precompute == False, it creates the dataset on the fly
# This is far slower when evaluating multiple SAEs, but for an exploratory run it is fine
config.precompute = False

# Number of inputs used by each of the three evaluation stages
config.eval_results_n_inputs = 1000
config.eval_sae_n_inputs = 1000
config.board_reconstruction_n_inputs = 1000

# Once you have run a stage, you can set its flag to False and the saved results will be loaded instead
config.run_analysis = True
config.run_board_reconstruction = True
config.run_eval_sae = True
config.run_eval_results = True

# If you want to save the results of the analysis
config.save_results = True
config.save_feature_labels = True

print(f"Is Othello: {othello}")

Now we create separate train and test datasets. By default, we don't precompute the board states, so the dictionaries will just contain encoded and decoded input strings. For chess, the decoded input strings are PGN strings (1.e4 e5 2.Nf3 ...) and the encoded strings are a list of integers, where every integer corresponds to a character.

The board states will be computed on the fly if `precompute == False`. The tensor for `board_to_piece_state` will be of shape (batch size, seq length, rows, columns, classes) or (batch size, 20, 8, 8, 13).

20: There are 20 periods (which means white's turn to move) in a PGN string of 256 characters.

8: rows / columns

13: Total number of piece types (black king, black queen, blank, ... white king)

In [None]:
# The train and test datasets are shared by aggregation and reconstruction, so size them for the larger of the two
dataset_size = max(config.eval_sae_n_inputs, config.board_reconstruction_n_inputs)

# We have plenty of data and eval_results_data doesn't use VRAM, so we can afford to make it large
# So we don't hit the end of the activation buffer
eval_results_dataset_size = config.eval_results_n_inputs * 10

indexing_functions = eval_sae.get_recommended_indexing_functions(othello)
indexing_function = indexing_functions[0]

if othello:
    custom_functions = config.othello_functions
    game_name = "othello"
else:
    custom_functions = config.chess_functions
    game_name = "chess"

train_dataset_name = f"{game_name}_train_dataset.pkl"
test_dataset_name = f"{game_name}_test_dataset.pkl"


def _load_or_construct_dataset(dataset_path: str, split: str, description: str):
    """Return the dataset for `split`, reusing the pickle at `dataset_path` when possible.

    The pickle cache is only read or written when `config.precompute` is set.
    Reads the notebook globals `config`, `othello`, `custom_functions`,
    `dataset_size` and `device` at call time.

    Args:
        dataset_path: Pickle file used as the cache for this dataset.
        split: Dataset split passed to `eval_sae.construct_dataset`.
        description: Human-readable dataset name used in the progress messages.
    """
    # NOTE: the cache file name does not encode `dataset_size`, so a pickle written by an
    # earlier run with a different size is reused as-is. Delete the file to force a rebuild.
    if config.precompute and os.path.exists(dataset_path):
        print(f"Loading {description}")
        with open(dataset_path, "rb") as f:
            return pickle.load(f)

    print(f"Constructing {description}")
    data = eval_sae.construct_dataset(
        othello,
        custom_functions,
        dataset_size,
        split=split,
        device=device,
        precompute_dataset=config.precompute,
    )
    if config.precompute:
        print(f"Saving {description}")
        with open(dataset_path, "wb") as f:
            pickle.dump(data, f)
    return data


train_data = _load_or_construct_dataset(train_dataset_name, "train", "statistics aggregation dataset")
test_data = _load_or_construct_dataset(test_dataset_name, "test", "test dataset")

# This dataset is built with an empty list of board state functions and is never cached
eval_results_data = eval_sae.construct_dataset(
    othello,
    [],
    eval_results_dataset_size,
    split="train",
    device=device,
    precompute_dataset=config.precompute,
)

Now we run an evaluation to get some standard sparse autoencoder metrics, such as L0 and loss recovered.

In [None]:
# Where get_eval_results stores (and where we reload) the results for this SAE / input count
expected_eval_results_output_location = get_eval_results.get_output_location(
    autoencoder_path, n_inputs=config.eval_results_n_inputs
)

if not config.run_eval_results:
    # Reuse the results saved by an earlier run
    with open(expected_eval_results_output_location, "rb") as results_file:
        eval_results = pickle.load(results_file)
    eval_results = utils.to_device(eval_results, device)
else:
    # Fixing the seed makes everything below reproducible, so results from one run can be
    # saved and compared against after making optimizations.
    # The determinism is only needed for getting activations from the activation buffer
    # for finding alive features.
    torch.manual_seed(0)
    eval_inputs = utils.to_device(eval_results_data.copy(), device)
    eval_results = get_eval_results.get_evals(
        autoencoder_path,
        config.eval_results_n_inputs,
        config.batch_size,
        device,
        eval_inputs,
        othello=othello,
        save_results=config.save_results,
    )

We can view the results here.

In [None]:
# Show which metrics are available, then the two headline numbers
sae_metrics = eval_results["eval_results"]
print(sae_metrics.keys())
print(f"L0: {sae_metrics['l0']}")
print(f"Loss recovered: {sae_metrics['frac_recovered']}")

Now, we do the statistics aggregation, or the "training" phase. This will take a couple minutes to run depending on GPU. I will explain what this does in future cells.

In [None]:
# Where the aggregation results are stored for this SAE / input count / indexing function
expected_aggregation_output_location = eval_sae.get_output_location(
    autoencoder_path, n_inputs=config.eval_sae_n_inputs, indexing_function=indexing_function
)

if not config.run_eval_sae:
    # Reuse the aggregation saved by an earlier run
    with open(expected_aggregation_output_location, "rb") as saved_results:
        aggregation_results = pickle.load(saved_results)
    aggregation_results = utils.to_device(aggregation_results, device)
else:
    print("Aggregating", autoencoder_path)
    aggregation_kwargs = dict(
        custom_functions=custom_functions,
        autoencoder_path=autoencoder_path,
        n_inputs=config.eval_sae_n_inputs,
        batch_size=config.batch_size,
        device=device,
        # Work on a shallow copy so the original train_data dict is left as it is
        data=utils.to_device(train_data.copy(), device),
        thresholds_T=config.f1_analysis_thresholds,
        indexing_function=indexing_function,
        othello=othello,
        save_results=config.save_results,
        precomputed=config.precompute,
    )
    aggregation_results = eval_sae.aggregate_statistics(**aggregation_kwargs)

We take the `aggregation_results` and use them to calculate the `feature_labels`.

In [None]:
# Pick the device for the analysis; when analysing on the CPU, move the aggregation results there first
analysis_device = "cpu" if config.analysis_on_cpu else device
if config.analysis_on_cpu:
    aggregation_results = utils.to_device(aggregation_results, "cpu")

torch.cuda.empty_cache()

# Feature labels are stored next to the aggregation results, under a different file name
expected_feature_labels_output_location = expected_aggregation_output_location.replace(
    "results.pkl", "feature_labels.pkl"
)

if not config.run_analysis:
    # Reuse the labels saved by an earlier run (misc_stats is not available on this path)
    with open(expected_feature_labels_output_location, "rb") as labels_file:
        feature_labels = pickle.load(labels_file)
    feature_labels = utils.to_device(feature_labels, analysis_device)
else:
    analysis_kwargs = dict(
        output_path=expected_feature_labels_output_location,
        device=analysis_device,
        high_threshold=config.analysis_high_threshold,
        low_threshold=config.analysis_low_threshold,
        significance_threshold=config.analysis_significance_threshold,
        verbose=False,
        print_results=False,
        save_results=config.save_feature_labels,
    )
    feature_labels, misc_stats = analysis.analyze_results_dict(aggregation_results, **analysis_kwargs)


Plotting / display functions

In [None]:
def rc_to_square_notation(row, col):
    """Turn zero-based (row, col) board indices into a square name, e.g. (2, 2) -> "C3"."""
    rank = row + 1  # rows are zero-based, ranks are one-based
    file_letter = "ABCDEFGH"[col]
    return f"{file_letter}{rank}"

def plot_board(board_RR: torch.Tensor, title: str = "Board", png_filename: Optional[str] = None):
    """
    Plots an 8x8 board and writes the value of the maximum square on that square,
    in red text, as a whole-number percentage.

    Args:
        board_RR (torch.Tensor): A 2D tensor of shape (8, 8) with values from 0 to 1.
        title (str): Title of the plot.
        png_filename (Optional[str]): If given, the figure is also saved to this path before it is shown.
    """
    assert board_RR.shape == (8, 8), "board_RR must be of shape 8x8"

    # matplotlib needs host memory, so detach and copy to the CPU first.
    # This is a no-op for tensors that already live on the CPU.
    board_RR = board_RR.detach().cpu()

    # Flip the board vertically
    board_RR = torch.flip(board_RR, [0])

    plt.imshow(board_RR, cmap='gray_r', interpolation='none', vmin=0, vmax=1)
    plt.colorbar()  # Adds a colorbar to help identify the values
    plt.title(title)

    # Set labels for columns (A-H)
    plt.xticks(range(8), ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])

    # Set labels for rows (8 at the top down to 1 at the bottom, matching the flip above)
    plt.yticks(range(8), range(8, 0, -1))

    # Draw gridlines on minor ticks placed between the squares
    # NOTE(review): the y offset is 0.51 rather than 0.5 - kept as found; confirm it is intentional
    plt.gca().set_xticks([x - 0.5 for x in range(1, 9)], minor=True)
    plt.gca().set_yticks([y - 0.51 for y in range(1, 9)], minor=True)
    plt.grid(True, which='minor', color='black', linewidth=1, linestyle='-', alpha=0.5)

    # Find the maximum value and its (row, column) position in the flipped board,
    # as plain Python numbers so they can be formatted and passed to matplotlib directly
    max_value = board_RR.max().item()
    max_i, max_j = divmod(int(board_RR.argmax()), 8)

    # Display the maximum value in red text at the corresponding position
    plt.text(max_j, max_i, f"{max_value:.0%}", color='red', ha='center', va='center', fontsize=12)

    if png_filename is not None:
        plt.savefig(png_filename)

    plt.show()

# Maps a class index (last axis of the piece-state tensors) to a readable piece name.
# Index 6 is the empty square; black pieces come before it and white pieces after it.
num_to_class = dict(
    enumerate(
        [
            "Black King",
            "Black Queen",
            "Black Rook",
            "Black Bishop",
            "Black Knight",
            "Black Pawn",
            "Blank",
            "White Pawn",
            "White Knight",
            "White Bishop",
            "White Rook",
            "White Queen",
            "White King",
        ]
    )
)

What is `feature_labels`? It's a dict that maps each board-to-state function name to a tensor. The board-to-state functions include `board_to_pin_state`, `board_to_piece_masked_blank_and_initial_state`, etc.

`aggregation_results` contains the average state of the board when each feature is active above each threshold (for all board to state functions). Using `aggregation_results`, we identify all features that predict with at least 95% precision when a board state is present. `feature_labels` contains binary tensors, where a 1 indicates the board state is likely to be present when the feature is active.

The feature labels tensor is of shape (num_thresholds, num_alive_features, rows, columns, classes).

If the feature labels tensor for `board_to_piece_masked_blank_and_initial_state` has a value of 1 at `threshold==5`, `feature_idx==173`, `row==2`, `column==2`, `classes==7`, then when feature 173 is active over a threshold of 50%, there is over a 95% chance that there is a white pawn on C3.

In [None]:
function_of_interest = "board_to_piece_masked_blank_and_initial_state"

board_state_feature_labels_TFRRC = feature_labels[function_of_interest]
print(f"Board state feature labels: {board_state_feature_labels_TFRRC.shape}")
threshold = 2

# Select one threshold, then count how many (row, column, class) entries are labelled for each feature
board_state_feature_labels_FRRC = board_state_feature_labels_TFRRC[threshold]
board_state_counts_F = einops.reduce(board_state_feature_labels_FRRC, "F R1 R2 C -> F", "sum")

# Only scan the first `max_features` features, clamped to the number of features that
# actually exist so a small SAE does not raise an IndexError.
max_features = 175
demo_idx = 0
for i in range(min(max_features, board_state_counts_F.shape[0])):
    if board_state_counts_F[i] > 0:
        print(f"Feature {i} has {board_state_counts_F[i]} classified squares")
        # The last scanned feature with at least one label becomes the demo feature
        demo_idx = i

demo_feature_labels_RRC = board_state_feature_labels_FRRC[demo_idx]
print(f"\nFeature {demo_idx} has {board_state_counts_F[demo_idx].sum().item()} classified squares")

# torch.where returns one index tensor per axis: (rows, columns, classes)
classified_squares = torch.where(demo_feature_labels_RRC == 1)
print(f"Classified squares as tensors: {classified_squares}")

row, column, classes = classified_squares

print(f"\nClassified squares for feature {demo_idx} at threshold {threshold}:")
for square_row, square_column, square_class in zip(row.tolist(), column.tolist(), classes.tolist()):
    print(rc_to_square_notation(square_row, square_column), num_to_class[square_class])

There is a major footgun here. `feature_labels` is of shape (num_thresholds, num_alive_features, rows, columns, classes). num_alive_features != SAE hidden dimension.

If you wish to use identified SAE features for other tasks, you must take this into account. We made this optimization because many SAEs had a significant number of dead features. By ignoring these features, we lowered compute / memory requirements. However, this is also highly confusing, and is a questionable choice in retrospect.

For the purposes of this demo notebook, we will use the indices in `feature_labels`.

To identify the "real SAE feature idx", which you would use for indexing into SAE activations, do something like the following:

In [None]:
print(feature_labels[function_of_interest].shape)

# 'alive_features' is indexed by the compact feature index used throughout feature_labels
alive_features = feature_labels['alive_features']
print("The shape of alive_features doesn't equal the SAE hidden dim of 4096!", alive_features.shape)

# Translate the compact index into the index used by the SAE's own activations
real_sae_feature_idx = alive_features[demo_idx]

print(f"For feature idx {demo_idx}, the real SAE feature idx is {real_sae_feature_idx}")

How were these feature labels computed? We start with `aggregation_results`. It contains the average state of the board for every board state function for every feature when it is active.

In [None]:
# on_count has one row per threshold and one column per alive feature
on_count_shape = aggregation_results["on_count"].shape
print(on_count_shape)

T, F = on_count_shape

print(f"For all {F} alive features over {T} thresholds, on_count is the number of times each feature is on above every threshold")

Here is how the feature label was determined for this square.

In [None]:
# Use the first classified square of the demo feature as the worked example
example_row = row[0].item()
example_column = column[0].item()
example_class = classes[0].item()

# Convert both counts to plain Python numbers so they print cleanly rather than as `tensor(...)`
example_on_count = int(aggregation_results["on_count"][threshold, demo_idx].item())
example_present_count = aggregation_results[function_of_interest]['on'][threshold, demo_idx, example_row, example_column, example_class].item()

# Guard against a feature that never fired above this threshold
if example_on_count:
    example_fraction = example_present_count / example_on_count
else:
    example_fraction = float("nan")

print(f"Feature {demo_idx} was active {example_on_count} times above threshold {threshold}")
print(f"During these activations, there was a {num_to_class[example_class]} at {rc_to_square_notation(example_row, example_column)}")
print(f"{example_present_count} times, or {example_fraction:.0%} of the time")

In [None]:
# Add the "off" tracker to the aggregation results, then normalize the "on" and "off" trackers in that order
formatted_results = analysis.add_off_tracker(aggregation_results, custom_functions, analysis_device)

for tracker_name in ("on", "off"):
    formatted_results = analysis.normalize_tracker(
        formatted_results, tracker_name, custom_functions, analysis_device
    )

# Normalized "on" statistics for the piece-state function, used by the plots below
board_results_TFRRC = formatted_results["board_to_piece_masked_blank_and_initial_state"]['on_normalized']
print(board_results_TFRRC.shape)

def plot_feature_board_states(board_results_TFRRC: torch.Tensor, feature_idx: int, threshold: int, piece_type: int):
    """
    Print and plot the 8x8 board slice for one piece type of one feature at one threshold.

    The activation count is read from the notebook-level `formatted_results['on_count']`.
    The plot is also saved as a PNG named after the feature, threshold and piece type.

    Args:
        board_results_TFRRC (torch.Tensor): Indexed as [threshold, feature, row, column, class].
        feature_idx (int): Index of the feature along the feature axis.
        threshold (int): Threshold index; reported in the output as `threshold * 10` percent.
        piece_type (int): Class index, a key of `num_to_class`.
    """
    board_RRC = board_results_TFRRC[threshold, feature_idx]
    activation_count = int(formatted_results['on_count'][threshold, feature_idx])
    threshold_pct = threshold * 10

    print(f"Feature {feature_idx} had {activation_count} activations over threshold {threshold_pct}%")

    print(board_RRC.shape)
    # Keep only the requested class, leaving an (rows, columns) board
    board_RR = board_RRC[..., piece_type]
    print(board_RR)

    plot_board(
        board_RR,
        f"Average {num_to_class[piece_type]} activation for \nfeature {feature_idx} over threshold {threshold_pct}%",
        f"feature_{feature_idx}_threshold_{threshold}_piece_{piece_type}.png",
    )

# plot_feature_board_states(board_results_TFRRC, demo_idx, 0, 8)
# plot_feature_board_states(board_results_TFRRC, demo_idx, 2, 8)
# plot_feature_board_states(board_results_TFRRC, demo_idx, 2, 5)
# plot_feature_board_states(board_results_TFRRC, demo_idx, 0, 3)
# plot_feature_board_states(board_results_TFRRC, demo_idx, 5, 3)

In [None]:
# As we increase the threshold, the precision of the prediction increases
# (class 8 is "White Knight", shown at threshold indices 0 and 2)
for threshold_to_plot in (0, 2):
    plot_feature_board_states(board_results_TFRRC, demo_idx, threshold_to_plot, 8)

In [None]:
# Classes without a feature label don't have high precision predictions
# (class 5 is "Black Pawn")
plot_feature_board_states(
    board_results_TFRRC,
    feature_idx=demo_idx,
    threshold=2,
    piece_type=5,
)

In [None]:
# As the threshold increases, a square that was not a feature label may become one as the precision increases
# (class 3 is "Black Bishop", shown at threshold indices 0 and 5)
for threshold_to_plot in (0, 5):
    plot_feature_board_states(board_results_TFRRC, demo_idx, threshold_to_plot, 3)

Now we move everything back to our device.

In [None]:
# The analysis ran on the CPU, so move its inputs and outputs back to the main device
if config.analysis_on_cpu:
    aggregation_results = utils.to_device(aggregation_results, device)
    feature_labels = utils.to_device(feature_labels, device)
    # misc_stats is only produced when the analysis was actually run; when the feature
    # labels were loaded from disk it does not exist, and touching it would raise a NameError.
    if config.run_analysis:
        misc_stats = utils.to_device(misc_stats, device)


Now, we use these feature labels to reconstruct the state of the board as measured by all board state functions. At every board state, we reconstruct it using only SAE feature activations and `feature_labels`. We measure the accuracy of the reconstructed board using F1 score. This will take a few minutes to run depending on GPU.

In [None]:
# Optionally, this can be sped up by
# config.board_reconstruction_n_inputs = 100

# Reconstruction results are stored next to the aggregation results, under a different file name
expected_reconstruction_output_location = expected_aggregation_output_location.replace(
    "results.pkl", "reconstruction.pkl"
)

if not config.run_board_reconstruction:
    # Reuse the reconstruction saved by an earlier run
    with open(expected_reconstruction_output_location, "rb") as reconstruction_file:
        board_reconstruction_results = pickle.load(reconstruction_file)
    board_reconstruction_results = utils.to_device(board_reconstruction_results, device)
else:
    print("Testing board reconstruction")
    reconstruction_kwargs = dict(
        custom_functions=custom_functions,
        autoencoder_path=autoencoder_path,
        feature_labels=feature_labels,
        output_file=expected_reconstruction_output_location,
        n_inputs=config.board_reconstruction_n_inputs,
        batch_size=config.batch_size,
        device=device,
        # Work on a shallow copy so the original test_data dict is left as it is
        data=utils.to_device(test_data.copy(), device),
        othello=othello,
        print_results=False,
        save_results=config.save_results,
        precomputed=config.precompute,
    )
    board_reconstruction_results = eval_board_reconstruction.test_board_reconstructions(
        **reconstruction_kwargs
    )


We can then view the metrics (F1 score, number of true positives, false positives, false negatives, etc.) per threshold for every function.

In [None]:
function_of_interest = "board_to_piece_masked_blank_and_initial_state"

# List the evaluated functions, then the metrics recorded for the function of interest
print(board_reconstruction_results.keys())
reconstruction_stats = board_reconstruction_results[function_of_interest]
print(reconstruction_stats.keys())

f1_scores = reconstruction_stats['f1_score_per_class']
print(f1_scores)

threshold = 2

print(f"At threshold {threshold}, this SAE reconstructed {function_of_interest} with an F1 score of {f1_scores[threshold]}")