# (Currently chess only) Dataframe comparing SAE statistics

In [1]:
# Imports

from tqdm import tqdm
import pickle
import torch
import einops
from datasets import load_dataset
from typing import Callable, Optional
import math
import os
import itertools
import json
import gc

import pandas as pd

from dataclasses import dataclass
import torch
from nnsight import NNsight
import json
from typing import Any
from datasets import load_dataset
from einops import rearrange
from jaxtyping import Int, Float, jaxtyped
from torch import Tensor
import os
from tqdm import tqdm
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer

from circuits.othello_buffer import OthelloActivationBuffer
from circuits.dictionary_learning import AutoEncoder
from circuits.chess_utils import encode_string
from circuits.dictionary_learning import ActivationBuffer
from circuits.dictionary_learning.dictionary import AutoEncoder, GatedAutoEncoder
from circuits.dictionary_learning.trainers.gated_anneal import GatedAnnealTrainer
from circuits.dictionary_learning.trainers.gdm import GatedSAETrainer
from circuits.dictionary_learning.trainers.p_anneal import PAnnealTrainer
from circuits.dictionary_learning.trainers.standard import StandardTrainer
from circuits.dictionary_learning.evaluation import evaluate
from circuits.nanogpt_to_hf_transformers import NanogptTokenizer, convert_nanogpt_model
from circuits.eval_sae_as_classifier import (
    initialize_results_dict, 
    get_data_batch, 
    apply_indexing_function,
    construct_eval_dataset,
    construct_othello_dataset,
    prep_firing_rate_data,
)
from circuits.utils import (
    get_model, 
    get_submodule,
    get_ae_bundle,
    collect_activations_batch,
    get_nested_folders,
    get_firing_features,
    to_device,
    AutoEncoderBundle,
)
import circuits.chess_utils as chess_utils
import circuits.othello_utils as othello_utils
import circuits.othello_engine_utils as othello_engine_utils

from circuits.dictionary_learning.evaluation import evaluate

from IPython import embed

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Globals

# Dimension key (from https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd):
# F = features, or the size of a feature minibatch, depending on the context
# B = batch_size
# L = seq length (context length)
# T = thresholds
# R = rows (or cols)
# C = classes for one hot encoding

# Root directory that contains the repository checkout. The default is the
# original author's location; set the CHESS_GPT_HOME environment variable to
# run the notebook elsewhere without editing this cell.
home_dir = os.environ.get('CHESS_GPT_HOME', '/share/u/can')
repo_dir = f'{home_dir}/chess-gpt-circuits'

DEVICE = 'cuda:0'
torch.set_grad_enabled(False)  # this notebook only evaluates, so autograd is switched off globally
batch_size = 32
feature_batch_size = batch_size  # number of features handled per minibatch
n_inputs = 2048 # Length of the eval dataset
GAME = "chess" # "chess" or "othello"

models_path = repo_dir + "/models/"

In [3]:
# Game-specific settings: SAE folders, board-state functions, model name and
# position-indexing functions. The eval dataset itself is built later, inside
# the SAE loop.

if GAME not in ("chess", "othello"):
    raise ValueError("Invalid game")

othello = GAME == "othello"

if othello:
    model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
    autoencoder_group_paths = ["/autoencoders/othello_layer0/"]
    custom_functions = [
        othello_utils.games_batch_to_state_stack_BLRRC,
        othello_utils.games_batch_to_state_stack_mine_yours_BLRRC,
    ]
    # No position filtering; indexing functions are still experimental for Othello
    indexing_functions = [None]
else:
    model_name = "adamkarvonen/8LayerChessGPT2"
    autoencoder_group_paths = ["/autoencoders/group1/"]
    # chess_utils.board_to_pin_state is another candidate that is currently not evaluated
    custom_functions = [chess_utils.board_to_piece_state]
    indexing_functions = [chess_utils.get_even_list_indices]

## General dataset statistics

These statistics depend only on the dataset, not on the SAE, so they can be computed once after the dataset is loaded.

In [4]:
def get_true_board_state_counts(pgn_strings):
    """Count how often every one-hot board state occurs in `pgn_strings`.

    The games are converted to `chess_utils.board_to_piece_state` state stacks,
    one-hot encoded on `DEVICE`, and summed over the first two dimensions
    (presumably games and positions -- confirm against `chess_utils`).

    Args:
        pgn_strings: the games to analyse, in the form accepted by
            `chess_utils.create_state_stacks`.

    Returns:
        The summed one-hot tensor (annotated as [RRC] by the original author).
    """
    # This could be calculated within the board_to_piece_state evaluation!
    piece_state_fn = chess_utils.board_to_piece_state
    state_stacks = chess_utils.create_state_stacks(pgn_strings, piece_state_fn)
    piece_state_config = chess_utils.config_lookup[piece_state_fn.__name__]
    one_hot_states = chess_utils.state_stack_to_one_hot(piece_state_config, DEVICE, state_stacks)
    return one_hot_states.sum(dim=(0, 1))

## SAE-specific statistics

In [5]:
# Standard evals
def do_standard_evals(results, ae_bundle):
    """Run the standard `evaluate` routine and merge its metrics into `results`.

    Args:
        results: mapping that receives every key/value pair returned by `evaluate`.
        ae_bundle: bundle providing `ae`, `buffer` and `context_length`.

    Returns:
        The same `results` object, updated in place.
    """
    eval_kwargs = dict(
        max_len=ae_bundle.context_length,
        # Capped at 512. Original author's note: min(n_eval_samples,
        # activation_buffer_out_batch_size) matters here.
        batch_size=min(512, batch_size),
        io="out",
        device=DEVICE,
        n_batches=1000,
    )
    results.update(evaluate(ae_bundle.ae, ae_bundle.buffer, **eval_kwargs))
    return results

In [6]:
# Evaluation of custom functions
def eval_custom_fn(
    results,
    n_act_threshs,
    alive_features_F,
    max_activations_F,
    ae_bundle,
    pgn_strings,
    custom_functions,
    encoded_inputs,
    firing_rate_n_inputs,
    indexing_function
):
    """Accumulate firing counts and board-state sums per threshold and alive feature.

    For every batch of games, the SAE activations of the alive features are
    compared against `n_act_threshs` thresholds (evenly spaced fractions in
    [0, 1] of each feature's maximum activation). For each (threshold, feature)
    pair the function adds to `results`:

    * `results["on_count"]` / `results["off_count"]`: number of positions where
      the feature is above / not above the threshold.
    * `results[fn.__name__]["on"]` / `["off"]` for each custom function: the
      board-state tensors summed over the positions where the feature is
      above / not above the threshold.

    Besides its arguments, this function reads the module-level names
    `n_inputs`, `batch_size`, `feature_batch_size`, `DEVICE` and `data`.
    NOTE(review): `data` is not a parameter; it must have been assigned at
    module level (by the SAE loop) before this function is called.

    Args:
        results: accumulator dict; its tensors are updated in place.
        n_act_threshs: number of activation thresholds T.
        alive_features_F: the alive features; only its length and its use in
            `collect_activations_batch` are visible here.
        max_activations_F: per-feature maximum activation, used to scale thresholds.
        ae_bundle: autoencoder bundle passed to `collect_activations_batch`.
        pgn_strings: games; at least `n_inputs` of them are required.
        custom_functions: board-state functions whose `__name__` keys `results`
            and the batch data.
        encoded_inputs: encoded games, sliced in parallel with `pgn_strings`.
        firing_rate_n_inputs: only used in the printed summary.
        indexing_function: optional position filter applied through
            `apply_indexing_function`; `None` disables it.

    Returns:
        The same `results` dict.

    Raises:
        AssertionError: if there are fewer than `n_inputs` games or `n_inputs`
            is not a multiple of `batch_size`.
    """
    num_features = len(alive_features_F)
    print(
        f"Out of {ae_bundle.dictionary_size} features, on {firing_rate_n_inputs} activations, {num_features} are alive."
    )

    assert len(pgn_strings) >= n_inputs
    assert n_inputs % batch_size == 0

    n_iters = n_inputs // batch_size
    # We round up to ensure we don't ignore the remainder of features
    num_feature_iters = math.ceil(num_features / feature_batch_size)

    # Per-feature thresholds: fractions in [0, 1] scaled by each feature's max activation
    thresholds_T = torch.linspace(0, 1, n_act_threshs).to(DEVICE)
    thresholds_TF11 = einops.repeat(thresholds_T, "T -> T F 1 1", F=num_features)
    max_activations_1F11 = einops.repeat(max_activations_F, "F -> 1 F 1 1")
    thresholds_TF11 = thresholds_TF11 * max_activations_1F11

    for i in tqdm(range(n_iters), desc="Aggregating statistics"):
        start = i * batch_size
        end = (i + 1) * batch_size
        pgn_strings_BL = pgn_strings[start:end]
        encoded_inputs_BL = encoded_inputs[start:end]
        encoded_inputs_BL = torch.tensor(encoded_inputs_BL).to(DEVICE)

        batch_data = get_data_batch(data, pgn_strings_BL, start, end, custom_functions, DEVICE)

        # The second return value (the encoded token inputs) is not needed here
        all_activations_FBL, _ = collect_activations_batch(
            ae_bundle, encoded_inputs_BL, alive_features_F
        )

        if indexing_function is not None:
            all_activations_FBL, batch_data = apply_indexing_function(
                pgn_strings_BL, all_activations_FBL, batch_data, DEVICE, indexing_function
            )
        # For thousands of features, this would be many GB of memory. So, we minibatch.
        for feature_iter in range(num_feature_iters):
            f_start = feature_iter * feature_batch_size
            f_end = min((feature_iter + 1) * feature_batch_size, num_features)
            f_batch_size = f_end - f_start

            # NOTE: from here on F is the size of the current feature minibatch
            activations_FBL = all_activations_FBL[f_start:f_end]
            thresholds_TF11_slice = thresholds_TF11[:, f_start:f_end, :, :]

            ### Aggregate batch statistics
            active_indices_TFBL = activations_FBL > thresholds_TF11_slice
            active_counts_TF = einops.reduce(active_indices_TFBL, "T F B L -> T F", "sum")
            off_counts_TF = einops.reduce(~active_indices_TFBL, "T F B L -> T F", "sum")

            results["on_count"][:, f_start:f_end] += active_counts_TF
            results["off_count"][:, f_start:f_end] += off_counts_TF

            for custom_function in custom_functions:
                on_tracker_TFRRC = results[custom_function.__name__]["on"]
                off_tracker_TFRRC = results[custom_function.__name__]["off"]

                boards_BLRRC = batch_data[custom_function.__name__]
                boards_TFBLRRC = einops.repeat(
                    boards_BLRRC,
                    "B L R1 R2 C -> T F B L R1 R2 C",
                    F=f_batch_size,
                    T=thresholds_TF11_slice.shape[0],
                )

                # TODO The next 2 operations consume almost all of the compute. I don't think it will work,
                # but maybe we can only do 1 of these operations?
                active_boards_sum_TFRRC = einops.reduce(
                    boards_TFBLRRC * active_indices_TFBL[:, :, :, :, None, None, None],
                    "T F B L R1 R2 C -> T F R1 R2 C",
                    "sum",
                )
                off_boards_sum_TFRRC = einops.reduce(
                    boards_TFBLRRC * ~active_indices_TFBL[:, :, :, :, None, None, None],
                    "T F B L R1 R2 C -> T F R1 R2 C",
                    "sum",
                )

                # Sliced `+=` mutates the tensors stored in `results`, so no
                # write-back into the dict is required.
                on_tracker_TFRRC[:, f_start:f_end, :, :, :] += active_boards_sum_TFRRC
                off_tracker_TFRRC[:, f_start:f_end, :, :, :] += off_boards_sum_TFRRC

    return results

In [7]:
# Precision, recall, and F1

def get_classification_metrics(
    results,
    true_board_states_counts,
    precision_thresh=0.9,
    recall_thresh=0.5,
    f1_thresh=0.5,
    empty_square_class=6,
):
    """Summarise how well SAE features act as classifiers of board states.

    Precision, recall and F1 are computed for every (activation threshold,
    feature, board state) triple from the accumulated counts, compared against
    their cut-offs, and reduced to four scalar fractions per metric. Each
    fraction is stored in `results` under a key embedding the metric name and
    its cut-off, e.g. `frac_any_board_per_feature_act-nonzero_precision-0.9`:

    * `frac_any_board_per_feature_act-nonzero_*`: fraction of features that
      exceed the cut-off on at least one board state, using threshold index 0.
    * `frac_any_board_per_feature_act-best_*`: same, but any threshold counts.
    * `frac_any_feature_per_board_act-nonzero_*`: fraction of board states for
      which at least one feature exceeds the cut-off, using threshold index 0.
    * `frac_any_feature_per_board_act-best_*`: same, but any threshold counts.

    Args:
        results: dict holding `results['board_to_piece_state']['on']`
            (indexed as T F R R C) and `results['on_count']` (indexed as T F).
        true_board_states_counts: total count of every board state (R R C).
        precision_thresh: a precision strictly above this counts as high.
        recall_thresh: a recall strictly above this counts as high.
        f1_thresh: an F1 strictly above this counts as high.
        empty_square_class: index on the last axis that is excluded from all
            statistics (the original code labels index 6 as the empty square).

    Returns:
        The same `results` dict with the twelve scalar entries added.
    """
    eps = 1e-8  # avoids division by zero for features / board states with zero counts
    names = ['precision', 'recall', 'f1']
    threshs = [precision_thresh, recall_thresh, f1_thresh]

    true_pos_TFRRC = results['board_to_piece_state']['on']
    pos_all_TF = results['on_count']
    true_all_RRC = true_board_states_counts

    precision_TFRRC = true_pos_TFRRC / (pos_all_TF[:, :, None, None, None] + eps)
    # NOTE(review): if `true_board_states_counts` was counted over more positions
    # than the 'on' tracker (e.g. because an indexing function filtered positions
    # during accumulation), recall is underestimated -- confirm with the caller.
    recall_TFRRC = true_pos_TFRRC / (true_all_RRC[None, None, :, :, :] + eps)
    f1_TFRRC = 2 * (precision_TFRRC * recall_TFRRC) / (precision_TFRRC + recall_TFRRC + eps)
    metrics_TFRRC = [precision_TFRRC, recall_TFRRC, f1_TFRRC]

    # Apply the cut-offs, then drop the excluded class from every mask
    high_TFRRC = [metric > thresh for metric, thresh in zip(metrics_TFRRC, threshs)]
    for mask_TFRRC in high_TFRRC:
        mask_TFRRC[..., empty_square_class] = False

    ### Fraction of features with a high metric on at least one board state
    # Flatten + single-axis `any` is used instead of `any(dim=<tuple>)`, which
    # is only available in recent PyTorch releases.
    high_any_board_TF = [mask.flatten(start_dim=-3).any(dim=-1) for mask in high_TFRRC]
    # Threshold index 0 only
    frac_any_board_nonzero = [mask_TF[0].float().mean() for mask_TF in high_any_board_TF]
    # Any threshold (i.e. the best threshold per feature)
    frac_any_board_best = [mask_TF.any(dim=0).float().mean() for mask_TF in high_any_board_TF]

    ### Fraction of board states that have at least one feature with a high metric
    # Threshold index 0 only
    high_any_feature_nonzero_RRC = [mask[0].any(dim=0) for mask in high_TFRRC]
    # Any threshold
    high_any_feature_best_RRC = [mask.flatten(start_dim=0, end_dim=1).any(dim=0) for mask in high_TFRRC]

    frac_any_feature_nonzero = [mask_RRC.float().mean() for mask_RRC in high_any_feature_nonzero_RRC]
    frac_any_feature_best = [mask_RRC.float().mean() for mask_RRC in high_any_feature_best_RRC]

    print(frac_any_board_nonzero)
    print(frac_any_board_best)
    print(frac_any_feature_nonzero)
    print(frac_any_feature_best)

    for i, (name, t) in enumerate(zip(names, threshs)):
        # `.item()` stores plain Python floats rather than device tensors
        results[f'frac_any_board_per_feature_act-nonzero_{name}-{t}'] = frac_any_board_nonzero[i].item()
        results[f'frac_any_board_per_feature_act-best_{name}-{t}'] = frac_any_board_best[i].item()
        results[f'frac_any_feature_per_board_act-nonzero_{name}-{t}'] = frac_any_feature_nonzero[i].item()
        results[f'frac_any_feature_per_board_act-best_{name}-{t}'] = frac_any_feature_best[i].item()

    return results

## Loop over SAEs

In [8]:
# Choose aes and indexing functions

# This could be computed once before the loop if adapting loading pgn_strings
# true_board_state_counts = get_true_board_state_counts(pgn_strings)

# Maps SAE folder name -> {metric name -> value}; becomes one dataframe row per SAE
sweep_results = {}
sweep_result_keys = [
    'l0',
    'frac_variance_explained',
    'cossim',
    'l2_ratio',
    'frac_any_board_per_feature_act-nonzero_precision-0.9',
    'frac_any_board_per_feature_act-best_precision-0.9',
    'frac_any_feature_per_board_act-nonzero_precision-0.9',
    'frac_any_feature_per_board_act-best_precision-0.9',
    'frac_any_board_per_feature_act-nonzero_recall-0.5',
    'frac_any_board_per_feature_act-best_recall-0.5',
    'frac_any_feature_per_board_act-nonzero_recall-0.5',
    'frac_any_feature_per_board_act-best_recall-0.5',
    'frac_any_board_per_feature_act-nonzero_f1-0.5',
    'frac_any_board_per_feature_act-best_f1-0.5',
    'frac_any_feature_per_board_act-nonzero_f1-0.5',
    'frac_any_feature_per_board_act-best_f1-0.5',
]

all_autoencoder_paths = []
for group_path in autoencoder_group_paths:
    all_autoencoder_paths += get_nested_folders(repo_dir + group_path)

# Every SAE is evaluated once per indexing function
param_combinations = list(itertools.product(all_autoencoder_paths, indexing_functions))

for ae_dir, idx_fn in param_combinations:
    print(f'ae_dir: {ae_dir}')
    print(f'idx_fn: {idx_fn}\n')

for autoencoder_path, indexing_function in tqdm(param_combinations, desc="Autoencoder loop", total=len(param_combinations)):
    # Release memory held by the previous SAE before loading the next one
    torch.cuda.empty_cache()
    gc.collect()

    indexing_function_name = "None"
    if indexing_function is not None:
        indexing_function_name = indexing_function.__name__

    print(f"Autoencoder: {autoencoder_path}")
    print(f"Indexing function: {indexing_function_name}")

    # TODO Function below manipulates the loaded data. If we change that, we can load data once and for all at the top of the file
    # NOTE: `data` is also read as a module-level name inside eval_custom_fn
    data = construct_eval_dataset(custom_functions, n_inputs, models_path=models_path, device=DEVICE)
    data, ae_bundle, pgn_strings, encoded_inputs = prep_firing_rate_data(
        autoencoder_path, batch_size, models_path, model_name, data, DEVICE, n_inputs, othello
    )

    firing_rate_n_inputs = min(int(n_inputs * 0.5), 1000) * ae_bundle.context_length
    # TODO: Custom thresholds per feature based on max activations
    alive_features_F, max_activations_F = get_firing_features(
        ae_bundle, firing_rate_n_inputs, batch_size, DEVICE
    )
    true_board_states_counts = get_true_board_state_counts(pgn_strings)
    assert true_board_states_counts is not None

    # initialize result dictionary
    n_act_threshs = 10
    results = initialize_results_dict(custom_functions, n_act_threshs, alive_features_F, DEVICE)

    # Standard evaluation metrics
    print('do_standard_evals')
    results = do_standard_evals(results, ae_bundle)
    # The buffer is dropped once the standard evals have consumed it
    del ae_bundle.buffer

    # Do custom eval metrics
    print('do custom eval metrics')
    results = eval_custom_fn(
        results,
        n_act_threshs,
        alive_features_F,
        max_activations_F,
        ae_bundle,
        pgn_strings,
        custom_functions,
        encoded_inputs,
        firing_rate_n_inputs,
        indexing_function,
    )

    torch.cuda.empty_cache()
    gc.collect()

    results = get_classification_metrics(results, true_board_states_counts)

    # The folder paths end with '/', so a plain split('/')[-1] would yield ''
    # for every SAE and each iteration would overwrite the previous entry.
    # Strip the trailing separator first to get the actual folder name.
    # NOTE(review): with more than one indexing function, runs of the same SAE
    # would still share one key -- extend the key if that case is ever used.
    ae_name = autoencoder_path.rstrip('/').split('/')[-1]
    sweep_results[ae_name] = {sweep_key: results[sweep_key] for sweep_key in sweep_result_keys}

ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=1e-03_layer=5/
idx_fn: <function get_even_list_indices at 0x7f8abad05800>

ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=1e-03_layer=5/
idx_fn: <function get_even_list_indices at 0x7f8abad05800>

ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=1e-01_layer=5/
idx_fn: <function get_even_list_indices at 0x7f8abad05800>

ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=1e-01_layer=5/
idx_fn: <function get_even_list_indices at 0x7f8abad05800>

ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=3e-01_layer=5/
idx_fn: <function get_even_list_indices at 0x7f8abad05800>

ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=3e-01_layer=5/
idx_fn: <function get_even_list_indices at 0x7f8abad05800>



Autoencoder loop:   0%|          | 0/6 [00:00<?, ?it/s]

Autoencoder: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=1e-03_layer=5/
Indexing function: get_even_list_indices






Collecting features: 100%|██████████| 8000/8000 [00:05<00:00, 1533.22it/s]


do_standard_evals
do custom eval metrics
Out of 8192 features, on 256000 activations, 4872 are alive.




In [None]:
# One row per SAE (the keys of sweep_results), one column per metric
df = pd.DataFrame.from_dict(data=sweep_results, orient="index")
# Displays a copy rounded to two decimals; `df` itself keeps full precision
df.round(decimals=2)


In [None]:
# Display the sweep table at full precision
df

Unnamed: 0,l0,frac_variance_explained,cossim,l2_ratio,frac_any_board_per_feature_act-nonzero_precision-0.9,frac_any_board_per_feature_act-best_precision-0.9,frac_any_feature_per_board_act-nonzero_precision-0.9,frac_any_feature_per_board_act-best_precision-0.9,frac_any_board_per_feature_act-nonzero_recall-0.5,frac_any_board_per_feature_act-best_recall-0.5,frac_any_feature_per_board_act-nonzero_recall-0.5,frac_any_feature_per_board_act-best_recall-0.5,frac_any_board_per_feature_act-nonzero_f1-0.5,frac_any_board_per_feature_act-best_f1-0.5,frac_any_feature_per_board_act-nonzero_f1-0.5,frac_any_feature_per_board_act-best_f1-0.5
/share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=1e-03_layer=5/,509.153748,-5.060123,0.92518,1.098365,"tensor(0.8660, device='cuda:0')","tensor(0.8660, device='cuda:0')","tensor(0.1310, device='cuda:0')","tensor(0.1983, device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')"
/share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=1e-03_layer=5/,1066.594971,0.9990734,0.9996,0.99981,"tensor(0.9692, device='cuda:0')","tensor(0.9692, device='cuda:0')","tensor(0.0733, device='cuda:0')","tensor(0.1959, device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')"
/share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=1e-01_layer=5/,24.903343,0.9563578,0.957291,0.908624,"tensor(0.9652, device='cuda:0')","tensor(0.9652, device='cuda:0')","tensor(0.1322, device='cuda:0')","tensor(0.2584, device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')"
/share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=1e-01_layer=5/,20.512125,0.9309957,0.930462,0.884988,"tensor(0.6577, device='cuda:0')","tensor(0.6577, device='cuda:0')","tensor(0.1310, device='cuda:0')","tensor(0.2091, device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')"
/share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=3e-01_layer=5/,2.149719,-1560396.0,0.498544,18.309023,"tensor(0.3735, device='cuda:0')","tensor(0.3735, device='cuda:0')","tensor(0.1118, device='cuda:0')","tensor(0.1298, device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')"
/share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=3e-01_layer=5/,3.810344,0.8801003,0.866959,0.799832,"tensor(0.8378, device='cuda:0')","tensor(0.8378, device='cuda:0')","tensor(0.1538, device='cuda:0')","tensor(0.2175, device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')","tensor(0., device='cuda:0')"
