# Find lookup tables mapping high precision classifier (HPC) features to board states (bs), and back
Load reconstruction results for layer 5 Othello

This notebook uses the sae_feature_index by default, instead of the alive index.

In [None]:
# Setup
# Imports
import os
import pickle

import torch as t
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import importlib


from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder, IdentityDict
from circuits.utils import (
    othello_hf_dataset_to_generator,
    get_model,
    get_submodule,
)

from feature_viz_othello_utils import (
    get_acts_IEs_VN,
    plot_lenses,
    plot_mean_metrics,
    plot_top_k_games
)

import circuits.utils as utils
import circuits.analysis as analysis
import feature_viz_othello_utils as viz_utils



# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines without a usable CUDA device.
device = 'cuda:0' if t.cuda.is_available() else 'cpu'

In [None]:
# Location of the repository checkout and the autoencoder group to analyse.
repo_dir = '/home/can/chess-gpt-circuits'
# repo_dir = "/home/adam/chess-gpt-circuits"
ae_group_name = 'othello_mlp_acts_identity_aes'

# load SAE
ae_type = 'identity'
trainer_id = 4
layer = 5

# download data from huggingface if needed
autoencoders_dir = f'{repo_dir}/autoencoders'
if not os.path.exists(f'{autoencoders_dir}/{ae_group_name}'):
    import zipfile

    zip_path = hf_hub_download(repo_id='adamkarvonen/othello_saes', filename=f'{ae_group_name}.zip', local_dir=autoencoders_dir)
    # Extract next to the archive, i.e. into the directory that the existence
    # check above and ae_path below look at, independent of the current
    # working directory.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(autoencoders_dir)

# Initialize the autoencoder
# NOTE: a per-trainer directory was previously assigned here and then
# immediately overwritten by the per-layer directory; kept for reference:
# ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/othello-{ae_type}/trainer{trainer_id}'
ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}'
ae_weights_path = os.path.join(ae_path, 'ae.pt')
# Load onto the notebook-level `device` rather than a second hard-coded string.
if ae_type in ('standard', 'p_anneal'):
    ae = AutoEncoder.from_pretrained(ae_weights_path, device=device)
elif ae_type in ('gated', 'gated_anneal'):
    ae = GatedAutoEncoder.from_pretrained(ae_weights_path, device=device)
elif ae_type == 'standard_new':
    ae = AutoEncoderNew.from_pretrained(ae_weights_path, device=device)
elif ae_type == 'identity':
    # The identity dictionary is constructed directly; no weights are loaded.
    ae = IdentityDict()
else:
    raise ValueError('Invalid ae_type')

In [None]:
# Load results files

# load feature analysis results
def to_device(d, device=device):
    """Recursively move the tensors contained in ``d`` onto ``device``.

    Args:
        d: A tensor, a (possibly nested) dict whose values are handled
            recursively, or any other object.
        device: Target passed to ``Tensor.to``; defaults to the notebook-level
            ``device``.

    Returns:
        The moved tensor, a new dict with the same keys and converted values,
        or ``d`` itself for every other type.
    """
    if isinstance(d, t.Tensor):
        return d.to(device)
    if isinstance(d, dict):
        return {k: to_device(v, device) for k, v in d.items()}
    # Pass non-tensor leaves (ints, strings, lists, None, ...) through
    # unchanged instead of silently replacing them with None.
    return d


# NOTE: pickle.load can execute arbitrary code; only open result files that
# come from a trusted source.

# Feature-analysis results: unpickle, then move onto the working device.
results_path = os.path.join(ae_path, 'indexing_None_n_inputs_1000_results.pkl')
with open(results_path, 'rb') as f:
    results = pickle.load(f)
results = utils.to_device(results, device)
print(results.keys())

# Derive the feature labels from the raw results without saving or printing.
feature_labels, misc_stats = analysis.analyze_results_dict(
    results,
    "",
    device,
    save_results=False,
    verbose=False,
    print_results=False,
    significance_threshold=100,
)
print(feature_labels.keys())

# Evaluation results stored alongside the analysis results.
evals_path = os.path.join(ae_path, 'n_inputs_1000_evals.pkl')
with open(evals_path, 'rb') as f:
    eval_results = pickle.load(f)
print(eval_results.keys())
print(f"L0: {eval_results['eval_results']['l0']}")

In [None]:
# Key into feature_labels selecting which labels the cells below analyse.
bs_function = 'games_batch_to_valid_moves_BLRRC'
# NOTE(review): T_fire is not read by any cell visible in this notebook chunk.
T_fire = 1

# Map every entry of feature_labels['alive_features'] (as a Python scalar)
# to its position within that sequence.
alive_to_feat_idx = {}
for alive_pos, feat_idx in enumerate(feature_labels['alive_features']):
    alive_to_feat_idx[feat_idx.item()] = alive_pos
n_features_alive = len(alive_to_feat_idx)

## Histogram of the number of valid_moves per feature

In [None]:
# Pick the index along dim 0 (T_fire) that yields the most HPC labels.
# Summing over dims 1-4 leaves one total per T_fire index.
T_fire_hpc_count = feature_labels[bs_function].sum(dim=(1, 2, 3, 4))

# Convert both axes to numpy arrays for plotting.
t_fire_axis = t.arange(T_fire_hpc_count.size(0)).cpu().detach().numpy()
hpc_totals = T_fire_hpc_count.cpu().detach().numpy()
plt.scatter(t_fire_axis, hpc_totals)
plt.xlabel('T_fire')
plt.ylabel('Number of HPC features')
plt.show()

T_fire_max_hpc = t.argmax(T_fire_hpc_count).item()
print(f'the T_fire with the maximum hpc features is {T_fire_max_hpc}')

### Lookup: feature --> bs

In [None]:
# Labels at the selected T_fire index, with size-1 dims dropped and the rest
# flattened into one row per alive feature.
hpc_labels_per_feature = (
    feature_labels[bs_function][T_fire_max_hpc]
    .squeeze()
    .view(n_features_alive, -1)
)
# For every feature row, collect the column indices holding non-zero labels.
lookup_feature_to_bs = [
    t.nonzero(feature_row).squeeze(dim=-1).tolist()
    for feature_row in hpc_labels_per_feature
]

In [None]:
# Histogram of how many lookup entries each feature has in lookup_feature_to_bs.
n_bs_per_feature = [len(x) for x in lookup_feature_to_bs]
# Grow the bin edges to cover the largest count: plt.hist does not tally values
# outside the given edges, so fixed edges of range(0, 9) would drop every
# feature with more than 8 entries and merge counts of 7 and 8 into the last
# bin. When all counts are <= 7 the edges are exactly range(0, 9) as before.
max_bin_edge = max(8, max(n_bs_per_feature, default=0) + 1)
plt.xlabel('Number of valid_moves per feature')
plt.ylabel('Count of features')
# hist stays the last expression so the notebook still echoes its return value.
plt.hist(n_bs_per_feature, bins=range(0, max_bin_edge + 1))

### Lookup: bs --> feature

In [None]:
# Labels at the selected T_fire index, flattened to one row per alive feature
# and then transposed so that each row corresponds to one flattened bs entry.
hpc_labels_per_bs = (
    feature_labels[bs_function][T_fire_max_hpc]
    .squeeze()
    .view(n_features_alive, -1)
    .transpose(0, 1)
)
# For every bs row, collect the feature positions holding non-zero labels.
lookup_bs_to_feat = [
    t.nonzero(bs_row).squeeze(dim=-1).tolist()
    for bs_row in hpc_labels_per_bs
]

In [None]:
# Histogram of how many features each entry of lookup_bs_to_feat lists.
n_feats_per_valid_move = [len(x) for x in lookup_bs_to_feat]
# Grow the bin edges to cover the largest count: plt.hist does not tally values
# outside the given edges, so fixed edges of range(0, 9) would drop every
# valid_move with more than 8 features and merge counts of 7 and 8 into the
# last bin. When all counts are <= 7 the edges are exactly range(0, 9) as before.
max_bin_edge = max(8, max(n_feats_per_valid_move, default=0) + 1)
counts = plt.hist(n_feats_per_valid_move, bins=range(0, max_bin_edge + 1))
plt.xlabel('Number of features per valid_move')
plt.ylabel('Count of valid_moves')

# Echo the (bin counts, bin edges, patches) tuple returned by plt.hist.
counts

In [None]:
# Reload the visualisation helpers so that edits made to the module on disk
# take effect without restarting the kernel.
import feature_viz_othello_utils
importlib.reload(feature_viz_othello_utils)

# One count per entry of lookup_bs_to_feat, drawn on a fresh axes.
hpc_counts_per_move = list(map(len, lookup_bs_to_feat))
number_of_hpc_per_valid_move = t.tensor(hpc_counts_per_move)
fig, ax = plt.subplots()
viz_utils.visualize_board_from_tensor(
    ax,
    number_of_hpc_per_valid_move,
    title='Number of HPC features per valid move',
)

In [None]:
# Display the feature positions listed for entry 3 of the bs --> feature lookup.
lookup_bs_to_feat[3]