In [None]:
import torch
import pickle
import einops
import importlib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
from matplotlib.colors import Normalize

import circuits.analysis as analysis
import circuits.eval_sae_as_classifier as eval_sae
import circuits.chess_utils as chess_utils
import circuits.utils as utils

In [None]:
# import torch
# # for testing purposes

# # Define a sample 3D tensor with random values
# # Dimensions: T x F x C (let's use 2 x 3 x 4 for simplicity)
# f1_TFC = torch.randn(2, 3, 4)
# print("Original Tensor (T x F x C):")
# print(f1_TFC)

# def best_f1_average(f1_TFC: torch.Tensor) -> torch.Tensor:
#     # Apply torch.max along the last dimension (dimension 2)
#     # Select only the values, ignoring the indices
#     f1_TF, _ = torch.max(f1_TFC, dim=1)
#     return f1_TF

# # Compute the maximum along the 'C' dimension and reduce to a 2D tensor
# f1_TF = best_f1_average(f1_TFC)
# print("\nReduced Tensor (T x F) with max values from 'C':")
# print(f1_TF)

In [None]:
def mask_all_blanks(results: dict, device) -> dict:
    """Apply the initial-board-state mask to the piece-state entries of ``results``.

    Only the entries belonging to ``chess_utils.board_to_piece_state`` and
    ``chess_utils.board_to_piece_color_state`` are touched: their ``'on'`` and
    ``'off'`` values are replaced by whatever
    ``analysis.mask_initial_board_state`` returns for them.

    Args:
        results: Results dictionary keyed by custom function name.
        device: Device forwarded to ``analysis.mask_initial_board_state``.

    Returns:
        The same ``results`` object, which is also modified in place.
    """
    maskable_functions = (
        chess_utils.board_to_piece_state,
        chess_utils.board_to_piece_color_state,
    )

    for function in analysis.get_all_custom_functions(results):
        if function not in maskable_functions:
            continue

        entry = results[function.__name__]
        for key in ("on", "off"):
            entry[key] = analysis.mask_initial_board_state(entry[key], function, device)

    return results

def best_f1_average(f1_TFRRC: torch.Tensor) -> torch.Tensor:
    """For every threshold, for every square, find the best F1 score across all features. Then average across all squares.
    NOTE: If the function is binary, num_squares == 1. If it is board to piece state, num_squares == 8 * 8 * 12

    Args:
        f1_TFRRC: 5-D tensor of F1 scores. Axis 0 indexes thresholds, axis 1 indexes
            features, and the three trailing axes together enumerate the squares.

    Returns:
        1-D tensor with one entry per threshold: the mean over all squares of the
        per-square maximum F1.

    Raises:
        ValueError: If ``f1_TFRRC`` is not 5-dimensional.
    """
    if f1_TFRRC.ndim != 5:
        raise ValueError(
            f"Expected a 5-D tensor (T, F, R, R, C), got shape {tuple(f1_TFRRC.shape)}"
        )

    # Best feature per (threshold, square); the argmax indices are not needed.
    best_f1_TRRC, _ = torch.max(f1_TFRRC, dim=1)

    _, R1, R2, C = best_f1_TRRC.shape
    num_squares = R1 * R2 * C

    # Sum over the three square axes and normalise by the number of squares.
    return best_f1_TRRC.sum(dim=(1, 2, 3)) / num_squares
    

def f1s_above_threshold(f1_TFRRC: torch.Tensor, threshold: float) -> torch.Tensor:
    """For every threshold, for every square, find the best F1 score across all features. Then, find the number of squares that have a F1 score above the threshold.
    If the function is binary, num_squares == 1. If it is board to piece state, num_squares == 8 * 8 * 12
    NOTE: This will probably be most useful for features with 8x8xn options.

    Args:
        f1_TFRRC: 5-D tensor of F1 scores. Axis 0 indexes thresholds, axis 1 indexes
            features, and the three trailing axes together enumerate the squares.
        threshold: F1 cutoff; a square counts only if its best F1 is strictly greater.

    Returns:
        1-D integer tensor with one entry per threshold: the number of squares whose
        best F1 exceeds ``threshold``.

    Raises:
        ValueError: If ``f1_TFRRC`` is not 5-dimensional.
    """
    if f1_TFRRC.ndim != 5:
        raise ValueError(
            f"Expected a 5-D tensor (T, F, R, R, C), got shape {tuple(f1_TFRRC.shape)}"
        )

    # Best feature per (threshold, square); the argmax indices are not needed.
    best_f1_TRRC, _ = torch.max(f1_TFRRC, dim=1)

    above_threshold_TRRC = best_f1_TRRC > threshold

    # Summing a boolean tensor counts its True entries.
    return above_threshold_TRRC.sum(dim=(1, 2, 3))


Define more custom metric functions in the cell above. Then, near the bottom of the next cell (at the `NOTE` comment), select the custom function you are interested in.

In [None]:
# Reload so that edits to circuits.analysis are picked up without restarting the kernel.
importlib.reload(analysis)

device = "cpu"
# device = "cuda"
mask = False

# Autoencoder group(s) to analyse. Exactly one assignment is active; the
# alternatives are kept commented out so the active choice is unambiguous.
# autoencoder_group_paths = ["../autoencoders/chess_layer5_large_sweep/"]
autoencoder_group_paths = ["../autoencoders/group-2024-05-14_chess/"]
# autoencoder_group_paths = ["../autoencoders/chess_layer0/"]

# Both are (re)populated by the loading loop below.
custom_functions = []
custom_function_names = []

# Results CSV matching the active autoencoder group.
# csv_results_file = "../autoencoders/chess_layer5_large_sweep/results.csv"
# csv_results_file = "../autoencoders/chess_layer0/results.csv"
csv_results_file = "../autoencoders/group-2024-05-14_chess/results.csv"

df = pd.read_csv(csv_results_file)

# Holds two kinds of keys: custom function names and autoencoder group paths.
all_sae_results = {}

results_filename_filter = "1000" # This is only necessary if you have multiple files with multiple n_inputs
# e.g. indexing_find_dots_indices_n_inputs_1000_results.pkl and indexing_find_dots_indices_n_inputs_5000_results.pkl
# In this case, if you want to view the results for n_inputs = 1000, you would set filter = "1000"

# For every autoencoder in every group: load its pickled results, compute the F1
# tensors, reduce them with the custom metric, and accumulate per-function totals.
for autoencoder_group_path in autoencoder_group_paths:

    folders = eval_sae.get_nested_folders(autoencoder_group_path)
    sae_results = {}

    for autoencoder_path in folders:

        print(f"Processing {autoencoder_path}")

        assert autoencoder_path in df["autoencoder_path"].values, f"{autoencoder_path} not in csv file"

        # NOTE: this entry stays empty if the autoencoder is skipped below.
        sae_results[autoencoder_path] = {}

        results_filenames = analysis.get_all_results_file_names(autoencoder_path, results_filename_filter)
        if len(results_filenames) != 1:
            print(f"Skipping {autoencoder_path} because it has {len(results_filenames)} results files")
            print("This is most likely because there are results files from different n_inputs")
            continue
        results_filename = results_filenames[0]

        # SECURITY: pickle.load can execute arbitrary code. Only load results files
        # that you produced yourself or otherwise trust.
        with open(autoencoder_path + results_filename, "rb") as f:
            results = pickle.load(f)

        results = utils.to_device(results, device)

        custom_functions = analysis.get_all_custom_functions(results)
        for function in custom_functions:
            function_name = function.__name__
            # Record each name once; the list is shared across all autoencoders.
            if function_name not in custom_function_names:
                custom_function_names.append(function_name)

        results = analysis.add_off_tracker(results, custom_functions, device)
        f1_dict_TFRRC = analysis.get_all_f1s(results, device)

        # NOTE(review): feature_labels is not used in this cell. The call is kept in
        # case analyze_results_dict has side effects on `results` -- confirm.
        feature_labels = analysis.analyze_results_dict(results, output_path="", device=device, high_threshold=0.95, low_threshold=0.1, significance_threshold=10, save_results=False, mask=mask, verbose=False, print_results=False)

        correct_row = df["autoencoder_path"] == autoencoder_path
        sae_results[autoencoder_path]["l0"] = df[correct_row]["l0"].values[0]
        sae_results[autoencoder_path]["frac_variance_explained"] = df[correct_row]["frac_variance_explained"].values[0]

        # Create a zeroed per-threshold accumulator the first time a function is seen.
        for func_name in custom_function_names:
            if func_name in all_sae_results:
                continue

            T = f1_dict_TFRRC[func_name].shape[0]
            f1_counter_T = torch.zeros(T, device=device)
            all_sae_results[func_name] = {"f1_counter": f1_counter_T}

        for func_name in f1_dict_TFRRC:
            config = chess_utils.config_lookup[func_name]
            custom_function = config.custom_board_state_function
            assert custom_function in custom_functions, f"Key {custom_function} not in custom_functions"
            f1_TFRRC = f1_dict_TFRRC[func_name]


            # NOTE: Set your function of interest here
            f1_T = best_f1_average(f1_TFRRC)
            # f1_T = f1s_above_threshold(f1_TFRRC, 0.5)

            # Running total over all autoencoders; used later to pick the best threshold.
            all_sae_results[func_name]["f1_counter"] += f1_T

            sae_results[autoencoder_path][func_name] = f1_T

        # torch.cuda.empty_cache()
    all_sae_results[autoencoder_group_path] = sae_results





By default, the next cell uses `best_idx`, the threshold index with the highest metric summed over all autoencoders. If you want to look at a particular threshold instead, set `best_idx` manually (e.g. `best_idx = 3`).

In [None]:
# For every custom function, pick the threshold index that maximises the metric
# summed over all autoencoders, then store each autoencoder's metric at that
# threshold (and the index itself) in df.
# dict.fromkeys drops repeated names while keeping first-seen order.
for func_name in dict.fromkeys(custom_function_names):

    new_column_name = f"{func_name}_best_custom_metric"
    second_column_name = f"{func_name}_best_custom_metric_idx"
    for metric_column in (new_column_name, second_column_name):
        if metric_column not in df.columns:
            df[metric_column] = np.nan

    f1_counter_T = all_sae_results[func_name]["f1_counter"]
    # Plain int so that a manual override such as `best_idx = 3` works as well.
    best_idx = int(torch.argmax(f1_counter_T))

    for autoencoder_group_path in autoencoder_group_paths:
        # Iterate over what was stored while loading instead of re-scanning the disk.
        group_results = all_sae_results[autoencoder_group_path]

        for autoencoder_path, sae_result in group_results.items():
            if func_name not in sae_result:
                # No metric stored for this autoencoder (e.g. it was skipped while
                # loading); leave its cells as NaN.
                continue

            best_f1 = sae_result[func_name][best_idx]
            row_mask = df["autoencoder_path"] == autoencoder_path
            df.loc[row_mask, new_column_name] = best_f1.item()
            df.loc[row_mask, second_column_name] = best_idx

In [None]:
# Collect the per-function custom metric columns and their matching "_idx" columns.
custom_metric_columns = []
custom_metric_idx_columns = []
for col in df.columns:
    # Skip the "<metric_type>_average" columns that add_average_metric_over_functions
    # appends to df further down; otherwise re-running this cell would treat those
    # averages as per-function metrics.
    if col.endswith("_average"):
        continue

    if "custom_metric" in col and "idx" not in col:
        custom_metric_columns.append(col)
        print(col)

        idx_column_name = col + "_idx"
        custom_metric_idx_columns.append(idx_column_name)

In [None]:
# Collect the per-function best-F1 columns and derive the names of their index columns.
best_f1_columns = []
best_f1_idx_columns = []
for col in df.columns:
    # Skip the "<metric_type>_average" / "<metric_type>_average_idx" columns that
    # add_average_metric_over_functions appends to df further down; otherwise
    # re-running this cell would treat those averages as per-function scores.
    if col.endswith(("_average", "_average_idx")):
        continue

    if "best_f1_score_per_square" in col:
        best_f1_columns.append(col)
        print(col)
        f1_idx = col.replace("best_f1_score_per_square", "best_idx")
        best_f1_idx_columns.append(f1_idx)

The next 2 cells compute the average F1 score and the average custom metric score over all functions, over the 8x8 board state functions only, and over the binary functions only. The averages are stored in `df`, and the new column names are collected in `average_metric_columns` and `average_metric_idx_columns`. It's pretty verbose, but it works.

In [None]:
# Name fragments identifying the functions treated as 8x8 board state functions.
board_state_8x8_columns = [
    "board_to_piece_state",
    "board_to_piece_color_state",
    "board_to_threat_state",
    "board_to_legal_moves_state",
    "board_to_pseudo_legal_moves_state",
]

def get_board_state_columns(board_state_columns: list[str], columns: list[str], include: bool = True) -> list[str]:
    """Split ``columns`` by whether they mention one of the board state names.

    A column matches when any entry of ``board_state_columns`` occurs in it as a
    substring.

    Args:
        board_state_columns: Substrings that mark a column as a board state column.
        columns: Column names to filter.
        include: If true, return the matching columns; otherwise return the
            non-matching ones.

    Returns:
        The selected column names, in their original order.
    """
    want_match = bool(include)

    def is_board_state(column: str) -> bool:
        return any(name in column for name in board_state_columns)

    return [column for column in columns if is_board_state(column) == want_match]

# 8x8 board state functions (include=True): best-F1 columns and their index columns.
best_f1_board_state_columns, best_f1_board_state_idx_columns = [
    get_board_state_columns(board_state_8x8_columns, candidates, include=True)
    for candidates in (best_f1_columns, best_f1_idx_columns)
]

# 8x8 board state functions: custom metric columns and their index columns.
best_custom_metric_board_state_columns, best_custom_metric_board_state_idx_columns = [
    get_board_state_columns(board_state_8x8_columns, candidates, include=True)
    for candidates in (custom_metric_columns, custom_metric_idx_columns)
]

# Everything else (include=False) is treated as a binary function.
best_f1_binary_columns, best_f1_binary_idx_columns = [
    get_board_state_columns(board_state_8x8_columns, candidates, include=False)
    for candidates in (best_f1_columns, best_f1_idx_columns)
]

best_custom_metric_binary_columns, best_custom_metric_binary_idx_columns = [
    get_board_state_columns(board_state_8x8_columns, candidates, include=False)
    for candidates in (custom_metric_columns, custom_metric_idx_columns)
]


In [None]:
def add_average_metric_over_functions(
    df: pd.DataFrame,
    metric_type: str,
    average_metric_columns: list[str],
    average_metric_idx_columns: list[str],
    custom_metric_columns: list[str],
    custom_metric_idx_columns: list[str],
) -> tuple[pd.DataFrame, list[str], list[str]]:
    """Add row-wise averages of a set of metric columns (and of their index columns) to ``df``.

    Two columns are written: ``"{metric_type}_average"`` holds the row-wise mean of
    ``custom_metric_columns`` and ``"{metric_type}_average_idx"`` holds the row-wise
    mean of ``custom_metric_idx_columns``. Their names are appended to
    ``average_metric_columns`` and ``average_metric_idx_columns`` respectively.

    Args:
        df: Table to read from; it is modified in place.
        metric_type: Prefix used to name the two new columns.
        average_metric_columns: List that receives the new average column name.
        average_metric_idx_columns: List that receives the new average index column name.
        custom_metric_columns: Columns of ``df`` to average.
        custom_metric_idx_columns: Index columns of ``df`` to average.

    Returns:
        The same ``df`` and the same two list objects that were passed in.

    Raises:
        KeyError: If a requested column is missing from ``df``. In that case neither
            ``df`` nor the two lists are modified.
    """
    average_metric_column = f"{metric_type}_average"
    average_metric_idx_column = f"{metric_type}_average_idx"

    # Compute both averages before mutating anything, so that a missing column
    # cannot leave df and the lists half-updated.
    average_metric = df[custom_metric_columns].mean(axis=1)
    average_metric_idx = df[custom_metric_idx_columns].mean(axis=1)

    df[average_metric_column] = average_metric
    df[average_metric_idx_column] = average_metric_idx

    average_metric_columns.append(average_metric_column)
    average_metric_idx_columns.append(average_metric_idx_column)

    return df, average_metric_columns, average_metric_idx_columns


average_metric_columns = []
average_metric_idx_columns = []

# (metric_type, columns to average, matching index columns to average).
# The first two entries average over ALL functions; the remaining entries restrict
# the average to the 8x8 board state functions or to the binary functions.
average_metric_specs = [
    ("best_f1_score_per_square", best_f1_columns, best_f1_idx_columns),
    ("best_custom_metric", custom_metric_columns, custom_metric_idx_columns),
    (
        "best_f1_score_per_square_only_board_state",
        best_f1_board_state_columns,
        best_f1_board_state_idx_columns,
    ),
    (
        "best_custom_metric_only_board_state",
        best_custom_metric_board_state_columns,
        best_custom_metric_board_state_idx_columns,
    ),
    (
        "best_f1_score_per_square_only_binary",
        best_f1_binary_columns,
        best_f1_binary_idx_columns,
    ),
    (
        "best_custom_metric_only_binary",
        best_custom_metric_binary_columns,
        best_custom_metric_binary_idx_columns,
    ),
]

for metric_type, metric_columns, metric_idx_columns in average_metric_specs:
    df, average_metric_columns, average_metric_idx_columns = add_average_metric_over_functions(
        df,
        metric_type,
        average_metric_columns,
        average_metric_idx_columns,
        metric_columns,
        metric_idx_columns,
    )

In [None]:
# Persist the enriched table. The DataFrame is the last expression of the cell so
# that Jupyter actually renders it (a bare `df` before another statement shows nothing).
df.to_csv("processed_results.csv", index=False)
df

In [None]:
# Restrict the table to its float64 / int64 columns.
numerical_columns = df.select_dtypes(include=["float64", "int64"]).columns
numerical_data = df[numerical_columns]

# Pairwise correlation between every pair of numeric columns.
correlation_matrix = numerical_data.corr()

# Render the matrix as a plotly heatmap on a fixed [-1, 1] colour scale.
fig = px.imshow(
    correlation_matrix,
    labels={"x": "Columns", "y": "Columns", "color": "Correlation"},
    x=correlation_matrix.columns,
    y=correlation_matrix.columns,
    color_continuous_scale="RdBu_r",
    zmin=-1,
    zmax=1,
)

# Large square canvas so the column labels stay readable.
fig.update_layout(title="Correlation Matrix", width=2000, height=2000)

fig.show()

Here are all the new custom metric columns.

In [None]:
# One custom metric column name per line; prints nothing when the list is empty.
if custom_metric_columns:
    print("\n".join(custom_metric_columns))

In [None]:
# One best-F1 column name per line; prints nothing when the list is empty.
if best_f1_columns:
    print("\n".join(best_f1_columns))

In [None]:
# Print the average metric columns first, then their index columns.
for col in [*average_metric_columns, *average_metric_idx_columns]:
    print(col)

In [None]:
# get unique trainer types
unique_trainers = df['trainer_class'].unique()

# create a dictionary mapping trainer types to marker shapes.
# The first four shapes are the original ones; the list is cycled so that every
# trainer type gets a marker even when there are more trainer types than shapes
# (zip() would silently drop the extra trainer types).
marker_shapes = ['o', 's', '^', 'D', 'v', 'P', 'X', '*']
trainer_markers = {
    trainer: marker_shapes[position % len(marker_shapes)]
    for position, trainer in enumerate(unique_trainers)
}

def plot_custom_metric(
    color_column: str,
    idx_column_name: str,
    metric_1: str = "l0",
    metric_2: str = "frac_recovered",
    xlim=(0, 600),
    ylim=(0.9825, 1.001),
):
    """Scatter ``metric_1`` against ``metric_2`` for every row of ``df``, coloured by ``color_column``.

    Reads the module-level ``df`` and ``trainer_markers``. Each trainer type is drawn
    with its own marker, and all points share one colour scale spanning the min/max
    of ``color_column``.

    Args:
        color_column: Column of ``df`` that determines the point colour.
        idx_column_name: Column of ``df`` whose first value is shown in the title as
            the threshold index.
        metric_1: Column plotted on the x axis.
        metric_2: Column plotted on the y axis.
        xlim: ``(low, high)`` x-axis limits, or ``None`` to let matplotlib autoscale.
        ylim: ``(low, high)`` y-axis limits, or ``None`` to let matplotlib autoscale.
    """
    # create the scatter plot
    fig, ax = plt.subplots(figsize=(10, 6))

    # one shared normalization so colours are comparable across trainer types
    norm = Normalize(vmin=df[color_column].min(), vmax=df[color_column].max())

    # only the first row is consulted for the threshold index shown in the title
    idx = df[idx_column_name].values[0]

    # plot data points for each trainer type separately
    for trainer, marker in trainer_markers.items():
        # Exact match: a substring/regex match (str.contains) would also pick up rows
        # of any other trainer whose name merely contains this one.
        trainer_data = df[df['trainer_class'] == trainer]
        ax.scatter(
            trainer_data[metric_1],
            trainer_data[metric_2],
            c=trainer_data[color_column],
            cmap='viridis',
            marker=marker,
            s=100,
            label=trainer,
            norm=norm,
        )

    # add colorbar (every collection shares `norm`, so the first one is representative)
    cbar = fig.colorbar(ax.collections[0], ax=ax)
    cbar.set_label(color_column)

    # set labels and title
    ax.set_xlabel(metric_1)
    ax.set_ylabel(metric_2)
    ax.set_title(f'{metric_1} vs. {metric_2} at threshold {idx} for {color_column}')

    # add legend
    ax.legend(title='Trainer Type', loc='upper right')

    # set axis ranges
    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)

    # display the plot
    plt.show()

# One plot per custom metric column, paired with its index column.
for column_name, idx_column_name in zip(custom_metric_columns, custom_metric_idx_columns):
    plot_custom_metric(column_name, idx_column_name)

In [None]:
# One plot per average metric column, paired with its index column.
for column_name, idx_column_name in zip(average_metric_columns, average_metric_idx_columns):
    plot_custom_metric(column_name, idx_column_name)

In [None]:
# One plot per best-F1 column, paired with its index column.
for column_name, idx_column_name in zip(best_f1_columns, best_f1_idx_columns):
    plot_custom_metric(column_name, idx_column_name)

In [None]:
# for func_name in custom_function_names:

#     new_column_name = f"{func_name}_best_custom_metric"
#     if new_column_name not in df.columns:
#         df[new_column_name] = np.nan
    
#     second_column_name = f"{func_name}_best_custom_metric_idx"

#     f1_counter_T = all_sae_results[func_name]["f1_counter"]
#     best_idx = torch.argmax(f1_counter_T)

#     for autoencoder_group_path in autoencoder_group_paths:
#         folders = eval_sae.get_nested_folders(autoencoder_group_path)
        
#         for autoencoder_path in folders:
#             f1_T = all_sae_results[autoencoder_group_path][autoencoder_path][func_name]
#             best_f1 = f1_T[best_idx]
#             df.loc[df["autoencoder_path"] == autoencoder_path, new_column_name] = best_f1.item()
#             df.loc[df["autoencoder_path"] == autoencoder_path, second_column_name] = best_idx.item()