# Find lookup tables mapping high precision classifiers (HPC) to board-state properties (valid moves)
Load reconstruction results for layer 5 Othello

This notebook uses the sae_feature_index by default, rather than the alive index.

In [None]:
# Setup
# Imports
import os
import pickle

import torch as t
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import importlib
import numpy as np


from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder, IdentityDict
from circuits.utils import (
    othello_hf_dataset_to_generator,
    get_model,
    get_submodule,
)

from feature_viz_othello_utils import (
    get_acts_IEs_VN,
    plot_lenses,
    plot_mean_metrics,
    plot_top_k_games
)

import circuits.utils as utils
import circuits.analysis as analysis
import feature_viz_othello_utils as viz_utils
from circuits.othello_engine_utils import to_board_label



# Prefer the first CUDA device, but fall back to the CPU so the notebook
# still runs on machines without a GPU instead of failing on first use.
device = 'cuda:0' if t.cuda.is_available() else 'cpu'
# Root of the local checkout; autoencoders are read from (and downloaded
# into) f'{repo_dir}/autoencoders'.
repo_dir = '/home/can/chess-gpt-circuits'
# repo_dir = "/home/adam/chess-gpt-circuits"

In [None]:
layer = 5
# node_type = "sae_feature"
node_type = "mlp_neuron"


# Select the autoencoder group, its type and its on-disk location for the
# chosen node type.
if node_type == "sae_feature":
    ae_group_name = 'all_layers_othello_p_anneal_0530'
    ae_type = 'p_anneal'
    trainer_id = 0
    ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}/trainer{trainer_id}'
elif node_type == "mlp_neuron":
    ae_group_name = 'othello_mlp_acts_identity_aes'
    ae_type = 'identity'
    ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}'
else:
    raise ValueError('Invalid node_type')

# download data from huggingface if needed
if not os.path.exists(f'{repo_dir}/autoencoders/{ae_group_name}'):
    import zipfile

    hf_hub_download(repo_id='adamkarvonen/othello_saes', filename=f'{ae_group_name}.zip', local_dir=f'{repo_dir}/autoencoders')
    # unzip the data
    # Use the stdlib instead of shelling out to `unzip`: no dependency on an
    # external binary, no shell quoting issues with unusual paths, and a
    # failed extraction raises instead of being silently ignored.
    with zipfile.ZipFile(f'{repo_dir}/autoencoders/{ae_group_name}.zip') as archive:
        archive.extractall(f'{repo_dir}/autoencoders')

# Initialize the autoencoder
# Every pretrained variant is loaded the same way from `ae.pt`; only the
# class differs, so dispatch on ae_type through a table.
pretrained_ae_classes = {
    'standard': AutoEncoder,
    'p_anneal': AutoEncoder,
    'gated': GatedAutoEncoder,
    'gated_anneal': GatedAutoEncoder,
    'standard_new': AutoEncoderNew,
}
if ae_type in pretrained_ae_classes:
    # Load onto the notebook-wide `device` rather than a hard-coded 'cuda:0'.
    ae = pretrained_ae_classes[ae_type].from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)
elif ae_type == 'identity':
    ae = IdentityDict()
else:
    raise ValueError('Invalid ae_type')

In [None]:
# Identifier of the Othello-GPT checkpoint to analyse (presumably a Hugging Face repo id).
model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
# Load the model onto `device` with the project helper circuits.utils.get_model.
model = get_model(model_name, device)

In [None]:
# Load results files

# load feature analysis results
def to_device(d, device=device):
    """Recursively move every tensor inside ``d`` to ``device``.

    Args:
        d: A tensor, a (possibly nested) dict whose values are handled
            recursively, or any other object.
        device: Target device; defaults to the notebook-wide ``device``.

    Returns:
        The tensor moved to ``device``, a new dict with the same keys and
        converted values, or ``d`` itself for any other type.
    """
    if isinstance(d, t.Tensor):
        return d.to(device)
    if isinstance(d, dict):
        return {k: to_device(v, device) for k, v in d.items()}
    # Pass non-tensor leaves (ints, strings, lists, ...) through unchanged;
    # previously they fell off the end of the function and became None.
    return d


# NOTE: pickle executes arbitrary code while loading; only open result files
# that come from a trusted source.
results_path = os.path.join(ae_path, 'indexing_None_n_inputs_1000_results.pkl')
with open(results_path, 'rb') as f:
    results = utils.to_device(pickle.load(f), device)
print(results.keys())

# Turn the raw results into per-feature labels (nothing is written to disk
# and all helper output is silenced).
feature_labels, misc_stats = analysis.analyze_results_dict(
    results,
    "",
    device,
    save_results=False,
    verbose=False,
    print_results=False,
    significance_threshold=100,
)
print(feature_labels.keys())

# Evaluation metrics stored next to the autoencoder.
evals_path = os.path.join(ae_path, 'n_inputs_1000_evals.pkl')
with open(evals_path, 'rb') as f:
    eval_results = pickle.load(f)
print(eval_results.keys())
print(f"L0: {eval_results['eval_results']['l0']}")

In [None]:
# Key into feature_labels for the board-state function analysed below.
bs_function = 'games_batch_to_valid_moves_BLRRC'
T_fire = 1
# Maps each value stored in feature_labels['alive_features'] to its position
# within that sequence.
alive_to_feat_idx = {
    feature_id.item(): position
    for position, feature_id in enumerate(feature_labels['alive_features'])
}
n_features_alive = len(alive_to_feat_idx)

## Histogram: number of valid_moves per feature

In [None]:
# Choose T_fire with the maximum hpc features
# Sum out every axis except the first, leaving one count per T_fire index.
T_fire_hpc_count = feature_labels[bs_function].sum(dim=(1, 2, 3, 4))

t_fire_axis = t.arange(T_fire_hpc_count.shape[0]).cpu().detach().numpy()
hpc_counts = T_fire_hpc_count.cpu().detach().numpy()
plt.scatter(t_fire_axis, hpc_counts)
plt.xlabel('T_fire')
plt.ylabel('Number of HPC features')
plt.title(f'HPC {node_type} for valid_moves')
plt.show()

# Index along the first axis with the largest count.
T_fire_max_hpc = T_fire_hpc_count.argmax().item()
print(f'the T_fire with the maximum hpc features is {T_fire_max_hpc}')

In [None]:
# Aggregate over T_fire
# An entry becomes 1 if it is set for at least one index along the first
# (T_fire) axis, 0 otherwise.
tensor_feature_to_bs = feature_labels[bs_function].any(dim=0).int()

### Lookup: feature --> bs

In [None]:
# One row per feature with all remaining axes flattened, so the nonzero
# positions are flat square indices (the same convention lookup_bs_to_feat
# uses). Previously the flattened tensor was computed and then discarded, and
# the list was built from the unflattened tensor, yielding multi-axis index
# lists instead of flat indices. The number of entries per feature is the
# same either way.
flat_feature_to_bs = tensor_feature_to_bs.flatten(start_dim=1)
lookup_feature_to_bs = [t.nonzero(feature_row).squeeze(dim=-1).tolist() for feature_row in flat_feature_to_bs]

In [None]:
# Distribution of the number of lookup entries each feature has.
n_bs_per_feature = list(map(len, lookup_feature_to_bs))
plt.hist(n_bs_per_feature)
plt.xlabel('Number of valid_moves per feature')
plt.ylabel('Number of features')
plt.title(f'HPC {node_type} for valid_moves')

### Lookup: bs --> feature

In [None]:
# Flatten to one row per feature, then swap the two axes so that each row
# lists the features for one flattened board-state index.
feature_by_bs = tensor_feature_to_bs.squeeze().view(n_features_alive, -1)
bs_by_feature = feature_by_bs.transpose(0, 1)
lookup_bs_to_feat = [t.nonzero(bs_row).squeeze(dim=-1).tolist() for bs_row in bs_by_feature]

In [None]:
# Distribution of the number of features found per board-state index.
features_per_bs = [len(feature_ids) for feature_ids in lookup_bs_to_feat]
counts = plt.hist(features_per_bs)
plt.xlabel('Number of features per valid_move')
plt.ylabel('Count of valid_moves')
plt.title(f'HPC {node_type} for valid_moves')

# Show the histogram's return value as the cell output.
counts

In [None]:
# Reload the plotting helpers so edits to the module are picked up without
# restarting the kernel.
import feature_viz_othello_utils
importlib.reload(feature_viz_othello_utils)

# One count per board-state index, drawn on a board heatmap.
number_of_hpc_per_valid_move = t.tensor(list(map(len, lookup_bs_to_feat)))
fig, ax = plt.subplots()
viz_utils.visualize_board_from_tensor(
    ax,
    number_of_hpc_per_valid_move,
    title=f'Number of HPC {node_type} per valid move',
    vmax=18,
    cmap='inferno',
)

In [None]:
# Flat board-state index to inspect; to_board_label converts it to a board label for printing.
bs_index = 47
print(f'valid_move onto {to_board_label(bs_index)} #{bs_index}')

# Cell output: the feature indices recorded for this board-state index.
lookup_bs_to_feat[bs_index]

In [None]:
# Sentinel written to squares for which the feature is labelled for no class.
NOT_CLASSIFIED_VALUE = -9


def get_feature_label_classified_squares(feature_labels, bs_function, feature_idx, mark_idx_s=None) -> np.ndarray:
    """Collapse one feature's labels into a single categorical board.

    Args:
        feature_labels: Mapping from board-state function name to a label
            tensor indexed as (threshold, feature, row, col, class).
        bs_function: Key of the label tensor to read.
        feature_idx: Index of the feature along the tensor's second axis.
        mark_idx_s: Optional flat square index (row = idx // 8,
            col = idx % 8) to overwrite with the marker value -3.

    Returns:
        A numpy array with one entry per square: the index of the first
        class labelled at any threshold, shifted by -1 (class indices
        0, 1, 2 become -1, 0, 1), ``NOT_CLASSIFIED_VALUE`` where no class is
        labelled, and -3 at ``mark_idx_s`` if given.
    """
    # Select the feature first so the reduction over thresholds only touches
    # one feature's labels instead of the whole tensor.
    labels_for_feature = feature_labels[bs_function][:, feature_idx]
    sae_feature_board_state_RRC = t.any(labels_for_feature, dim=0).int()

    # Class index per square, shifted so the classes are centred on 0.
    sae_feature_board_state_RR = t.argmax(sae_feature_board_state_RRC, dim=-1) - 1

    # argmax yields 0 for all-zero squares too, so flag those explicitly.
    zero_positions_RR = t.all(sae_feature_board_state_RRC == 0, dim=-1)
    sae_feature_board_state_RR[zero_positions_RR] = NOT_CLASSIFIED_VALUE
    if mark_idx_s is not None:
        sae_feature_board_state_RR[mark_idx_s // 8, mark_idx_s % 8] = -3
    return sae_feature_board_state_RR.cpu().detach().numpy()

In [None]:
def plot_board_categorical(fig, axs, boards, node_idx, node_type):
    """Draw categorical boards on ``axs`` and add one shared colorbar.

    Args:
        fig: Figure that owns ``axs``; receives the colorbar.
        axs: Array of axes; boards are drawn in ``axs.flat`` order.
        boards: Iterable of 2-D arrays whose values are keys of the color
            map below (-1, 0, 1, NOT_CLASSIFIED_VALUE, -3).
        node_idx: Iterable of node indices, one per board, used in titles.
        node_type: Name of the node type, used in titles.

    Raises:
        KeyError: If a board contains a value that is not in the color map.
    """
    # Define color map
    color_map = {-1: 'black', 0: 'grey', 1: 'gold', NOT_CLASSIFIED_VALUE: 'white', -3: 'green'}
    color_map_labels = {-1: 'Mine', 0: 'Empty', 1: 'Yours', NOT_CLASSIFIED_VALUE: 'Not classified', -3: 'valid_move'}
    cmap = plt.matplotlib.colors.ListedColormap(list(color_map.values()))
    # Board values are arbitrary labels; map them to consecutive color slots.
    label_to_enumerate = {label: i for i, label in enumerate(color_map)}
    norm = plt.matplotlib.colors.Normalize(vmin=0, vmax=len(color_map) - 1)
    to_color_index = np.vectorize(lambda x: label_to_enumerate[x])

    # Plot each board
    for ax, board, feat_idx in zip(axs.flat, boards, node_idx):
        ax.imshow(to_color_index(board), cmap=cmap, norm=norm)

        # Plot labeling
        ax.set_xticks(range(8))
        ax.set_xticklabels(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
        ax.set_title(f'{node_type} #{feat_idx}', fontsize=10)

    # Build the colorbar from the colormap and norm directly instead of from
    # the last drawn image, so it no longer fails with an unbound variable
    # when `boards` is empty.
    mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    cbar = fig.colorbar(mappable, ax=axs, orientation='horizontal', ticks=range(len(color_map_labels)))
    cbar.ax.set_xticklabels(list(color_map_labels.values()))

In [None]:
bs_indices = [0,1,2,45,46,47]
# Loop-invariant settings, hoisted out of the loop.
input_bs_function = 'games_batch_to_state_stack_mine_yours_blank_mask_BLRRC'
n_cols = 4

for bs_index in bs_indices:

    valid_move_square_label = to_board_label(bs_index)
    nodes = lookup_bs_to_feat[bs_index]
    if not nodes:
        # Nothing to draw for this square; skip instead of building an empty figure.
        print(f'No HPC {node_type} for {valid_move_square_label}=valid_move, skipping plot')
        continue

    boards = [get_feature_label_classified_squares(feature_labels, input_bs_function, node_idx, mark_idx_s=bs_index) for node_idx in nodes]

    # Ceiling division: exactly as many rows as needed. `len//4 + 1` added a
    # blank row whenever the count was a multiple of n_cols.
    n_rows = (len(boards) + n_cols - 1) // n_cols

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols*2.5, n_rows*3.5))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)  # Adjust spacing between subplots
    plot_board_categorical(fig, axs, boards, nodes, node_type=node_type)

    # Remove empty subplots
    for unused_ax in axs.flatten()[len(boards):]:
        fig.delaxes(unused_ax)

    fig.suptitle(f'HPC {node_type} for board_state, given the {node_type} is HPC for {valid_move_square_label}=valid_move', fontsize=14)
    plt.show()

## DLA vs valid_move HPC

In [None]:
# # for bs_index in range(64):
# bs_index = 50

# print(f'valid_move onto {to_board_label(bs_index)} #{bs_index}')
# nodes = lookup_bs_to_feat[bs_index]
# for node_idx in nodes:
#     print(f'neuron {node_idx}')
#     plot_lenses(model, ae, node_idx, device, node_type, layer=layer)
# For each square, get all neurons whose output direction has cosine similarity > 0.25 with the unembed vector for that square

In [None]:
# Cell output: shapes of the two model weight tensors multiplied in the DLA histogram cell.
model.W_out.shape, model.W_U.shape

In [None]:
# DLA histogram
import torch.nn.functional as F

# Cosine-similarity cutoff above which a node counts as "high DLA".
dla_threshold = 0.25

# Use the notebook-wide `layer` instead of a hard-coded 5, so this cell stays
# in sync with the layer whose results were loaded above.
# Both factors are unit-normalised along the shared dimension, so the product
# holds cosine similarities.
wOut = F.normalize(model.W_out[layer], dim=1)
wU = F.normalize(model.W_U, dim=0)
dla = wOut @ wU
high_dla = (dla > dla_threshold).sum(dim=0)
high_dla = high_dla.cpu().detach().numpy()
plt.bar(np.arange(len(high_dla)), high_dla, alpha=0.5, label=f'High DLA (cos sim > {dla_threshold})')

# NOTE(review): high_dla is indexed by W_U's output dimension while
# n_feats_per_bs is indexed by lookup_bs_to_feat position; confirm both axes
# use the same square indexing before comparing the bars.
n_feats_per_bs = np.array([len(feature_ids) for feature_ids in lookup_bs_to_feat])
plt.bar(np.arange(len(n_feats_per_bs)), n_feats_per_bs, alpha=0.5, label="High precision (> 0.95) for next valid_move")
plt.xlabel('Square index')
plt.ylabel(f'Number of {node_type}s')
plt.legend()
plt.show()

## Measure recall