In [None]:
from tqdm import tqdm
import pickle
import torch
import einops
from typing import Callable, Optional
import math
import os

from circuits.utils import (
    collect_activations_batch,
    get_nested_folders,
    to_device,
)
import circuits.eval_sae_as_classifier as eval_sae
import circuits.chess_utils as chess_utils
import circuits.othello_utils as othello_utils
import circuits.test_board_reconstruction as test_board_reconstruction
import circuits.othello_engine_utils as othello_engine_utils

In [None]:
# --- Which SAE to analyse -----------------------------------------------------
autoencoder_group_path = "autoencoders/othello_layer5_ef4/"
autoencoder_folder = "ef=4_lr=1e-03_l1=6e-02_layer=5/"
# Both parts end in "/", so the joined path keeps the trailing slash that the
# `autoencoder_path + <file name>` concatenations in this notebook rely on.
autoencoder_path = os.path.join(autoencoder_group_path, autoencoder_folder)
feature_labels_file = "indexing_None_n_inputs_10000_results_feature_labels.pkl"
reconstruction_file = "indexing_None_n_inputs_10000_results_reconstruction.pkl"

# --- Runtime ------------------------------------------------------------------
# Pinned to CUDA: running on CPU currently fails with
#   RuntimeError: Unhandled FakeTensor Device Propagation for aten.bmm.default,
#   found two different devices cpu:0, cpu
device = torch.device("cuda")

n_inputs = 100
batch_size = 1

othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)
print(f"Othello: {othello}")
model_name = eval_sae.get_model_name(othello)

torch.set_printoptions(precision=2, sci_mode=False)
torch.set_grad_enabled(False)

In [None]:
# Load the precomputed feature labels (thresholds, alive features, board labels, ...)
# and move every tensor inside them to the working device.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
with open(autoencoder_path + feature_labels_file, "rb") as labels_file:
    feature_labels = to_device(pickle.load(labels_file), device)

# Board-state function(s) the SAE features are evaluated against.
custom_functions = [othello_utils.games_batch_to_state_stack_mine_yours_BLRRC]

data = eval_sae.construct_dataset(othello, custom_functions, n_inputs, device)
data, ae_bundle, pgn_strings, encoded_inputs = eval_sae.prep_firing_rate_data(
    autoencoder_path, batch_size, "", model_name, data, device, n_inputs, othello
)
# NOTE(review): presumably releases the activation buffer, which nothing below reads
# directly -- confirm that collect_activations_batch does not need it.
ae_bundle.buffer = None

In [None]:
# Saved thresholds are 4-D; the first two axes are read as T thresholds and F features.
thresholds_TF11 = feature_labels["thresholds"].to(device)
alive_features_F = feature_labels["alive_features"].to(device)
num_features = len(alive_features_F)
T, F, _, _ = thresholds_TF11.shape

# Resolve the saved indexing-function name; unsupported names (including None) give None.
indexing_function = chess_utils.supported_indexing_functions.get(
    feature_labels["indexing_function"]
)

print(f"Num alive features: {num_features}")
print(f"Indexing function: {indexing_function}")

In [None]:
# Reconstruct boards for games [start, end) from the SAE feature labels, and record,
# per threshold and per feature, how many squares that feature contributes at
# (game_of_interest, move_of_interest).
start = 0
end = 1
# The additive board returned by aggregate_feature_labels has no feature axis (it is
# indexed as T B L R R C below), so per-feature counts are only meaningful when each
# minibatch holds exactly one feature.
feature_batch_size = 1
assert feature_batch_size == 1, "feature_piece_counts_TF requires feature_batch_size == 1"
num_feature_iters = math.ceil(num_features / feature_batch_size)
game_of_interest = 0  # index into the [start, end) slice of games
move_of_interest = 30
assert 0 <= game_of_interest < end - start, "game_of_interest must index into games [start, end)"

pgn_strings_BL = pgn_strings[start:end]
encoded_inputs_BL = torch.tensor(encoded_inputs[start:end], device=device)

results = test_board_reconstruction.initialize_reconstruction_dict(
    custom_functions, thresholds_TF11.shape[0], alive_features_F, device
)

batch_data = eval_sae.get_data_batch(
    data, pgn_strings_BL, start, end, custom_functions, device
)

all_activations_FBL, encoded_token_inputs = collect_activations_batch(
    ae_bundle, encoded_inputs_BL, alive_features_F
)

if indexing_function is not None:
    all_activations_FBL, batch_data = eval_sae.apply_indexing_function(
        pgn_strings_BL, all_activations_FBL, batch_data, device, indexing_function
    )

constructed_boards = test_board_reconstruction.initialized_constructed_boards_dict(
    custom_functions, batch_data, thresholds_TF11, device
)

board_function_name = custom_functions[0].__name__
feature_piece_counts_TF = torch.zeros(T, F, device=device)

# For thousands of features, materialising every board at once would take many GB
# of memory, so features are processed in minibatches.
for feature_batch_idx in tqdm(range(num_feature_iters), desc="feature batches"):
    f_start = feature_batch_idx * feature_batch_size
    f_end = min(f_start + feature_batch_size, num_features)

    activations_FBL = all_activations_FBL[f_start:f_end]  # NOTE: here F == f_end - f_start

    results, additive_boards = test_board_reconstruction.aggregate_feature_labels(
        results,
        feature_labels,
        custom_functions,
        activations_FBL,
        thresholds_TF11[:, f_start:f_end, :, :],
        f_start,
        f_end,
        device,
    )

    additive_board_TBLRRC = additive_boards[board_function_name]

    # Total squares this feature fills in at the move of interest, per threshold.
    counts_per_threshold_T = einops.reduce(
        additive_board_TBLRRC[:, game_of_interest, move_of_interest, :, :, :],
        "T R1 R2 C -> T",
        "sum",
    )
    feature_piece_counts_TF[:, f_start] = counts_per_threshold_T

    for custom_function in constructed_boards:
        constructed_boards[custom_function] += additive_boards[custom_function]

results = test_board_reconstruction.compare_constructed_to_true_boards(
    results, custom_functions, constructed_boards, batch_data, device
)
results = test_board_reconstruction.calculate_F1_scores(results, custom_functions)

In [None]:
# F1 score of the board reconstruction at each threshold; keep the best threshold.
f1_scores = results[custom_functions[0].__name__]["f1_score"]
print(f1_scores)
# NOTE(review): assumes f1_scores is 1-D over thresholds (argmax flattens otherwise).
# If calculate_F1_scores yields NaN for some threshold (e.g. 0/0), torch.argmax would
# pick that NaN's position; mapping NaN to -1 keeps such thresholds from winning.
best_idx = int(torch.nan_to_num(torch.as_tensor(f1_scores), nan=-1.0).argmax().item())
print(f"Best threshold: {best_idx}")

In [None]:
# Rank features by how many squares they contribute at the move of interest for the
# best threshold. Indices are positions along the F axis of feature_piece_counts_TF.
top_k = min(20, feature_piece_counts_TF.shape[1])
top_20_features = torch.topk(feature_piece_counts_TF[best_idx], k=top_k).indices
print(f"Top {top_k} features: {top_20_features}")
print(f"Top {top_k} feature counts: {feature_piece_counts_TF[best_idx, top_20_features]}")

# Indexing one threshold and one feature gives a 0-d tensor, so .item() needs no .sum().
print(f"Top feature piece count: {feature_piece_counts_TF[best_idx, top_20_features[0]].item()}")

board_labels_TFRRC = feature_labels[custom_functions[0].__name__]
print(f"Top {top_k} feature labels shape: {board_labels_TFRRC[best_idx, top_20_features].shape}")

top_feature_RRC = board_labels_TFRRC[best_idx, top_20_features[0]]
print(f"Top feature label sum: {top_feature_RRC.sum().item()}")

In [None]:
# Decode the top feature's label into a signed board: argmax over the last (channel)
# axis shifted by -1, so channel indices 0/1/2 become -1/0/1. Squares whose label
# vector is all zeros carry no label and get the sentinel value 9.
zero_positions = (top_feature_RRC == 0).all(dim=-1)
top_feature_RR = top_feature_RRC.argmax(dim=-1) - 1
top_feature_RR[zero_positions] = 9

print("Top feature:")
for row in top_feature_RR.tolist():
    # Each cell is printed as "<value> ", or as two spaces for the sentinel 9.
    print("".join("  " if value == 9 else f"{value} " for value in row))

In [None]:
# Ground-truth board for the chosen game and move, decoded from the one-hot last axis
# by argmax shifted by -1.
board_state_RRC = batch_data[custom_functions[0].__name__][game_of_interest][move_of_interest]
board_state_RR = board_state_RRC.argmax(dim=-1) - 1
print(board_state_RR)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_othello_board(board):
    """
    Plots an Othello board using matplotlib with a specific color lookup for different values.

    Args:
    board (torch.Tensor): A 2D tensor representing the Othello board,
                          where 0, -1, 1, and 9 are mapped to specific colors.
    """
    # Create a color map with specific colors
    # Creating a dictionary for the color mapping
    color_map = {-1: 'black', 0: 'grey', 1: 'white', 9: 'yellow'}
    
    # Replace board values with corresponding colors using a numpy vectorized operation
    label_colors = np.vectorize(color_map.get)(board.numpy())

    # Create a figure and axis for the plot
    fig, ax = plt.subplots()

    # Create a color map based on the unique labels in the board
    unique_labels = np.unique(board)
    colors = [color_map[label] for label in unique_labels]
    cmap = plt.matplotlib.colors.ListedColormap(colors)

    # Map board values to indices in the unique labels
    board_indices = np.vectorize(lambda x: np.where(unique_labels == x)[0][0])(board.numpy())

    # Plot the board using imshow
    cax = ax.imshow(board_indices, cmap=cmap)

    # Create a color bar with the correct labels
    cbar = fig.colorbar(cax, ticks=range(len(unique_labels)))
    cbar.ax.set_yticklabels([color_map[label] for label in unique_labels])

    # Set the axis to be off since we don't need it for a game board representation
    ax.axis('off')

    # Add a title to the plot
    plt.title('Othello Board. Grey = Empty, Yellow = Not present in one hot vector')

    # Show the plot
    plt.show()

# Ground-truth board first, then the top feature's decoded label.
for board_to_plot in (board_state_RR, top_feature_RR):
    plot_othello_board(board_to_plot.cpu())