# Find lookup tables mapping high precision classifiers to 
Load reconstruction results for layer 5 Othello

This notebook uses the sae_feature_index by default, rather than the alive index.

In [None]:
# Setup
# Imports
import os
import pickle

import torch as t
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import importlib
import numpy as np
import einops

from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder, IdentityDict
from circuits.utils import (
    othello_hf_dataset_to_generator,
    get_model,
    get_submodule,
)

from feature_viz_othello_utils import (
    get_acts_IEs_VN,
    plot_lenses,
    plot_mean_metrics,
    plot_top_k_games
)

import circuits.utils as utils
import circuits.analysis as analysis
import feature_viz_othello_utils as viz_utils
from circuits.othello_engine_utils import to_board_label



# Device handed to get_model and utils.to_device in the cells below.
device = 'cuda:0'
# Root of the local repository checkout; the autoencoder paths are built relative to it.
repo_dir = '/home/can/chess-gpt-circuits'
# repo_dir = "/home/adam/chess-gpt-circuits"

In [None]:
# Select the layer and the kind of node to analyse (switch by (un)commenting node_type).
layer = 5
# node_type = "sae_feature"
node_type = "mlp_neuron"


# Resolve the dictionary group, its type and its on-disk location for the chosen node type.
if node_type == "sae_feature":
    ae_group_name = 'all_layers_othello_p_anneal_0530_with_lines'
    ae_type = 'p_anneal'
    trainer_id = 0
    ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}/trainer{trainer_id}'
elif node_type == "mlp_neuron":
    ae_group_name = 'othello_mlp_acts_identity_aes' # with_lines
    ae_type = 'identity'
    ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}'
else:
    raise ValueError('Invalid node_type')

# download data from huggingface if needed
if not os.path.exists(f'{repo_dir}/autoencoders/{ae_group_name}'):
    import zipfile  # only needed on the first run, when the archive has to be unpacked

    zip_path = hf_hub_download(repo_id='adamkarvonen/othello_saes', filename=f'{ae_group_name}.zip', local_dir=f'{repo_dir}/autoencoders')
    # unzip the data in-process: a failed extraction raises instead of being silently ignored,
    # and paths containing spaces are handled correctly
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(f'{repo_dir}/autoencoders')

# Initialize the autoencoder on the configured device (IdentityDict takes no checkpoint)
if ae_type == 'standard' or ae_type == 'p_anneal':
    ae = AutoEncoder.from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)
elif ae_type == 'gated' or ae_type == 'gated_anneal':
    ae = GatedAutoEncoder.from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)
elif ae_type == 'standard_new':
    ae = AutoEncoderNew.from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)
elif ae_type == 'identity':
    ae = IdentityDict()
else:
    raise ValueError('Invalid ae_type')

In [None]:
# Load the model identified by model_name onto `device` via the project helper get_model.
model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
model = get_model(model_name, device)

In [None]:
# Load results files

# load feature analysis results
def to_device(d, device=device):
    """Recursively move every tensor contained in ``d`` to ``device``.

    Args:
        d: A tensor, a (possibly nested) dict, or any other value.
        device: Target device; defaults to the notebook-level ``device``.

    Returns:
        The tensor moved to ``device``, a new dict with the same keys whose
        values were processed recursively, or ``d`` itself for any other type.
    """
    if isinstance(d, t.Tensor):
        return d.to(device)
    if isinstance(d, dict):
        return {k: to_device(v, device) for k, v in d.items()}
    # Pass non-tensor leaves (numbers, strings, lists, None, ...) through unchanged
    # instead of implicitly returning None and losing the value.
    return d


# NOTE: pickle.load can execute arbitrary code from the file; only load results files you trust.
with open (os.path.join(ae_path, 'indexing_None_n_inputs_1000_results.pkl'), 'rb') as f:
    results = pickle.load(f)
# Uses the helper from circuits.utils (not the local to_device defined above).
results = utils.to_device(results, device)
print(results.keys())

# Derive feature_labels and misc_stats from the raw results without saving or printing anything.
feature_labels, misc_stats = analysis.analyze_results_dict(results, "", device, save_results=False, verbose=False, print_results=False, significance_threshold=100)
print(feature_labels.keys())

# Evaluation metrics stored next to the dictionary; only the L0 entry is printed here.
with open (os.path.join(ae_path, 'n_inputs_1000_evals.pkl'), 'rb') as f:
    eval_results = pickle.load(f)
print(eval_results.keys())
print(f"L0: {eval_results['eval_results']['l0']}")

In [None]:
# Board-state function whose results are analysed below; exactly one option should be uncommented.
# bs_function = 'games_batch_to_valid_moves_BLRRC'
# bs_function = 'games_batch_to_state_stack_mine_yours_blank_mask_BLRRC'
# bs_function = 'games_batch_to_state_stack_lines_mine_BLRCC'
bs_function = 'games_batch_to_state_stack_length_lines_mine_BLRCC'
# bs_function = 'games_batch_to_state_stack_opponent_length_lines_mine_BLRCC'

# Maps each value stored in feature_labels['alive_features'] to its position in that tensor.
# NOTE(review): despite the name, keys are the stored values and values are the positions,
# i.e. presumably sae_feature_index -> alive index — confirm.
alive_to_feat_idx = {v.item(): i for i, v in enumerate(feature_labels['alive_features'])}
# Number of distinct alive features.
n_features_alive = len(alive_to_feat_idx)

In [None]:
# Show which entries are stored for the selected board-state function.
results[bs_function].keys()

In [None]:
# Compare the shapes of the 'on' and 'all' tensors.
results[bs_function]['on'].shape, results[bs_function]['all'].shape

In [None]:
# Take the maximum of 'on' over its first two dimensions and count the entries where that
# maximum is still below the corresponding 'all' value.
# NOTE(review): the first two dimensions are presumably (threshold, feature) — confirm.
(results[bs_function]['on'].max(dim=0).values.max(dim=0).values < results[bs_function]['all']).sum()

In [None]:
# Sanity check: for threshold indices 1..9 and the first n_features_alive feature indices,
# the 'on' and 'off' counts must add up to the 'all' counts.
bs_results = results[bs_function]
on_and_off = bs_results['on'] + bs_results['off']
for thresh in range(1, 10):
    for feat in range(n_features_alive):
        counts_match = on_and_off[thresh, feat] == bs_results['all']
        assert t.all(counts_match)

## Recall

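recall = TP / all_T, where TP (true positives) is the number of inputs on which the feature fires and the board-state property holds, and all_T is the number of all inputs on which the property holds. It is computed below as `on / (all + epsilon)`.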

In [None]:
# NOTE: in a notebook only the last expression of a cell is displayed, so this shape lookup
# produces no output here.
results[bs_function]['all'].shape
# The eight neighbouring directions as pairs of -1/0/1 offsets.
eights = [[-1, 0], [-1, 1], [0, 1], [1, 1], [1, 0], [1, -1], [0, -1], [-1, -1]]


In [None]:
# Draw 20 random integers from [0, 1000) — the upper bound is exclusive. The result is only
# displayed, not stored.
np.random.randint(0, 1000, 20)

In [None]:
# Indices 0, 8, 16, 24, 32 into the last dimension.
# NOTE(review): with the (Len Dir) layout used in the rearrange below (Len=5, Dir=8) these would
# be direction 0 at each of the five lengths — confirm.
a = t.arange(5) * 8
print(a)
# 'all' values at position [0, 0] for those indices.
results[bs_function]['all'][0,0,a]

In [None]:
# Recall per entry: 'on' divided by 'all'. epsilon keeps the division finite where 'all' is zero.
epsilon = 1e-6
recall_TFRRC = results[bs_function]['on'] / (results[bs_function]['all'] + epsilon)

In [None]:
# Toy example of splitting one flat axis into (length, direction) with einops.
length = 2
direction = 3
# Bare expression with no effect: it is not the cell's last expression, so nothing is shown.
t.tensor([11, 12, 13, 21, 22, 23])
# Result has shape (2, 3): [[11, 12, 13], [21, 22, 23]].
einops.rearrange(t.tensor([11, 12, 13, 21, 22, 23]), '(l d) -> l d', l=length, d=direction)

In [None]:
# Split the flattened last axis into (Len, Dir) and merge the two board axes into one square axis.
recall_TFSqLenDir = einops.rearrange(recall_TFRRC, 'T F R1 R2 (Len Dir) -> T F (R1 R2) Len Dir', Len=5, Dir=8)
# Disabled: would drop the first and last Len entries (values with incorrect initialization in
# the first evaluation run).
# recall_TFSqLenDir = recall_TFSqLenDir[:, :, :, 1:-1, :]

# Lookup table feature, indices with recall above T_recall
T_recall = 0.9
# Boolean mask of the entries whose recall exceeds T_recall.
high_recall_TFSqLenDir = (recall_TFSqLenDir > T_recall)

In [None]:
# Shape of the high-recall mask, in the axis order (T, F, Sq, Len, Dir) of the rearrange above.
high_recall_TFSqLenDir.shape

In [None]:
# List the entries with high recall for square Sq at any threshold; each result row is
# (feature index, length index, direction index).
Sq = 0
# NOTE: Len is assigned but not used in this cell.
Len = 0
t.any(high_recall_TFSqLenDir, dim=0)[:, Sq].nonzero()

In [None]:
# Choose T_fire with the maximum hrc features
# Looks like noise? what's the random baseline?

# Number of high-recall (feature, square, length, direction) entries per firing threshold.
T_fire_hrc_count = high_recall_TFSqLenDir.sum(dim=(1,2,3,4))
plt.bar(t.arange(T_fire_hrc_count.shape[0]).cpu().detach().numpy(), T_fire_hrc_count.cpu().detach().numpy())
plt.xlabel('T_fire')
plt.ylabel('log(Number of HRC features)')
# The title names the board-state function actually analysed (bs_function) rather than a
# hard-coded 'valid_moves', which was stale for the other board-state functions.
plt.title(f'HRC {node_type} for {bs_function}')
plt.yscale('log')
plt.show()

# Index of the threshold with the largest count (first one on ties, as returned by argmax).
T_fire_max_hrc = T_fire_hrc_count.argmax().item()
print(f'the T_fire with the maximum hrc features is {T_fire_max_hrc}')

In [None]:
# Aggregate over T_fire
# An entry is 1 if its recall exceeds T_recall at any firing threshold, otherwise 0.
high_recall_FSqLenDir = t.any(high_recall_TFSqLenDir, dim=0).int()

## Number of valid_moves per feature

## Lookup tables

### Lookup: feature --> bs

In [None]:
# Number of high-recall (square, length, direction) entries per feature.
lines_per_feature = high_recall_FSqLenDir.sum(dim=(1,2,3))

# Histogram of that count over all features; the y-axis is log-scaled.
counts = plt.hist(lines_per_feature.cpu().detach().numpy())
plt.xlabel('Number of lines per feature')
plt.ylabel('log(Count of features)')
plt.title(f'HRC {node_type}')
plt.yscale('log')
# Display the value returned by plt.hist.
counts

In [None]:
# Indices of high_recall_F where value ==1
# 1 where a feature has at least one high-recall (length, direction) entry for the square.
high_recall_FSq = t.any(high_recall_FSqLenDir, dim=(-2, -1)).int()
# Number of distinct squares each feature has high recall for.
high_recall_F = t.sum(high_recall_FSq, dim=-1)
# Keep only the features that are tied to exactly one square.
feat_idx_single_square_any_number_of_lines = t.where(high_recall_F == 1)[0]
print(f'Number of features with high recall for a single square: {feat_idx_single_square_any_number_of_lines.shape[0]}')
# NOTE: after this indexing, axis 0 enumerates the kept features; map a position back to its
# feature index via feat_idx_single_square_any_number_of_lines.
high_recall_filtered_FSqLenDir = high_recall_FSqLenDir[feat_idx_single_square_any_number_of_lines]

## DLA vs valid_move HRC

### Lookup bs --> feat

In [None]:
# Sum over the feature, length and direction axes, leaving one value per square.
# NOTE: this counts (feature, length, direction) entries, so a feature with several high-recall
# lines on the same square contributes more than once.
features_per_square = high_recall_filtered_FSqLenDir.sum(dim=(0,2,3))

In [None]:
# Histogram of the per-square counts computed in the previous cell.
counts = plt.hist(features_per_square.cpu().detach().numpy())
plt.xlabel('Number of features per square, any line')
plt.ylabel('Count of squares')
plt.title(f'HRC {node_type} for Line (Sq * Len * Dir)\n filtered for HRC corresponding to a single square')

# Display the value returned by plt.hist.
counts

In [None]:
# Re-import and reload the visualisation module, presumably so that edits to it are picked up
# without restarting the kernel.
import feature_viz_othello_utils
importlib.reload(feature_viz_othello_utils) 

# Draw the per-square counts on a board using the project helper.
fig, ax = plt.subplots()
viz_utils.visualize_board_from_tensor(ax, features_per_square, title=f'Number of HRC {node_type} per empty square', cmap='inferno')

In [None]:
# we already filter for features with high recall for exactly 1 square
# for a single square, print all the board configurations that have a high recall feature

In [None]:
# The eight directions as (row step, column step); the list position is the Dir index.
eights = [[-1, 0], [-1, 1], [0, 1], [1, 1], [1, 0], [1, -1], [0, -1], [-1, -1]]
# Plot colour for every board label value.
color_map = {-1: 'black', 0: 'grey', 1: 'gold', -9: 'white', -3: 'green'}
# Board label value for every label name (same order as color_map).
color_lbl = {'Mine': -1, 'Empty': 0, 'Yours': 1, 'Not classified': -9, 'Valid move': -3}

# Function mapping (Sq, Len, Dir) to an (8, 8) board tensor
def to_board_tensor(square_idx, len_idx, dir_idx, device, opponent_only=True):
    """Build the board pattern described by a (square, length, direction) line.

    The start square is labelled 'Empty' and the next ``len_idx + 1`` squares
    in the given direction are labelled 'Yours'. Every other square keeps the
    'Not classified' label.

    Args:
        square_idx: Flat index (0..63) of the start square; row is
            ``square_idx // 8`` and column is ``square_idx % 8``.
        len_idx: Zero-based length index; ``len_idx + 1`` squares are marked.
        dir_idx: Index into ``eights`` selecting the direction.
        device: Device on which the board tensor is created.
        opponent_only: If ``False``, the square right after the line is also
            labelled 'Mine' (only when that square is on the board).

    Returns:
        A float tensor of shape (8, 8) holding ``color_lbl`` values. If the
        line leaves the board, marking stops at the edge and no 'Mine' square
        is added.
    """
    board_tensor = t.ones(8, 8, device=device) * color_lbl['Not classified']
    x, y = square_idx // 8, square_idx % 8
    dx, dy = eights[dir_idx]

    board_tensor[x, y] = color_lbl['Empty']
    for _ in range(1, len_idx + 2):
        x += dx
        y += dy
        if not (0 <= x < 8 and 0 <= y < 8):
            # The line runs off the board: nothing more can be marked.
            return board_tensor
        board_tensor[x, y] = color_lbl['Yours']

    if opponent_only is False:
        x += dx
        y += dy
        # Bounds check prevents an IndexError past the edge and a silent wrap-around for
        # negative coordinates.
        if 0 <= x < 8 and 0 <= y < 8:
            board_tensor[x, y] = color_lbl['Mine']
    return board_tensor

def plot_board_categorical(fig, axs, boards, node_idxs, node_type):
    """Draw categorical boards on a grid of axes with one shared colorbar.

    Args:
        fig: Figure that owns ``axs``; the colorbar is attached to it.
        axs: Array of axes; boards are drawn in ``axs.flat`` order.
        boards: Iterable of arrays whose entries are keys of ``color_map``.
        node_idxs: Node index shown in each subplot title, aligned with ``boards``.
        node_type: Node type name used in the subplot titles.

    If no board is drawn (``boards`` or ``node_idxs`` is empty), the function
    returns without adding a colorbar.
    """
    # Define color map: one colour per label, in color_map order
    colors = list(color_map.values())
    cmap = plt.matplotlib.colors.ListedColormap(colors)
    label_to_enumerate = {label: i for i, label in enumerate(color_map)}
    vmin = 0
    vmax = len(color_map) - 1
    norm = plt.matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    # Translates label values (-1, 0, 1, -9, -3) into colormap positions 0..4
    to_color_index = np.vectorize(lambda x: label_to_enumerate[x])

    # Plot each board
    cax = None
    for ax, board, feat_idx in zip(axs.flat, boards, node_idxs):
        board_indices = to_color_index(board)
        cax = ax.imshow(board_indices, cmap=cmap, norm=norm)

        # Plot labeling
        ax.set_xticks(range(8))
        ax.set_xticklabels(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
        ax.set_title(f'{node_type} #{feat_idx}', fontsize=10)

    if cax is None:
        # Nothing was drawn, so there is no image to build a colorbar from.
        return

    cbar = fig.colorbar(cax, ax=axs, norm=norm, orientation='horizontal', ticks=range(len(color_lbl)))
    cbar.ax.set_xticklabels(list(color_lbl.keys()))

In [None]:
# For one square, list the high-recall (feature, length, direction) entries and build the
# corresponding board tensor.
square_idx = 0
# Each row is (position in the filtered feature axis, length index, direction index).
feature_lines = high_recall_filtered_FSqLenDir[:, square_idx].nonzero()

for filtered_pos, length, direction in feature_lines.tolist():
    # Axis 0 of the filtered tensor enumerates the kept features, so translate the position
    # back to the actual feature index before reporting it.
    feat_idx = feat_idx_single_square_any_number_of_lines[filtered_pos].item()
    print(f'Feature {feat_idx} with length {length} and direction {direction}')
    board_tensor = to_board_tensor(square_idx, length, direction, device)
    print(board_tensor)
    print('---------------------------------')
    # Only the first entry is shown.
    break
    

In [None]:
# NOTE(review): high_recall_SqFLenDir is not defined anywhere in the visible notebook, so this
# cell raises NameError on a fresh kernel. Presumably a square-first rearrangement of one of the
# high_recall_*FSqLenDir tensors was intended — confirm which one.
high_recall_SqFLenDir.nonzero().shape

In [None]:


def get_feature_label_classified_squares(feature_labels, bs_function, feature_idx, mark_idx_s=None) -> np.ndarray:
    """Return the board labels that one feature classifies, as a NumPy array.

    Args:
        feature_labels: Mapping from board-state function name to a label tensor.
        bs_function: Key of the label tensor to read.
        feature_idx: Index of the feature along the tensor's second axis.
        mark_idx_s: Optional flat square index (row ``// 8``, column ``% 8``)
            that is overwritten with the 'Valid move' label.

    Returns:
        A NumPy array holding, per square, the argmax class shifted by -1, or
        the 'Not classified' label where the feature has no label at all.

    Relies on the notebook-level ``color_lbl`` mapping for the special labels,
    so the values match what ``plot_board_categorical`` can draw.
    """
    # A label counts if it is set for any entry along the first axis.
    sae_feature_board_state_RRC = t.any(feature_labels[bs_function], dim=0).int()[feature_idx]
    # NOTE(review): presumably the last axis is ordered (mine, empty, yours), so argmax - 1
    # yields -1 / 0 / 1 as in color_lbl — confirm.
    sae_feature_board_state_RR = t.argmax(sae_feature_board_state_RRC, dim=-1)
    sae_feature_board_state_RR -= 1

    # argmax returns class 0 for all-zero rows; relabel those squares explicitly.
    zero_positions_RR = t.all(sae_feature_board_state_RRC == 0, dim=-1)
    sae_feature_board_state_RR[zero_positions_RR] = color_lbl['Not classified']
    if mark_idx_s is not None:
        sae_feature_board_state_RR[mark_idx_s//8, mark_idx_s%8] = color_lbl['Valid move']
    return sae_feature_board_state_RR.cpu().detach().numpy()

In [None]:
# For a handful of squares, plot the classified board of every node listed for that square.
# NOTE(review): lookup_bs_to_feat is not defined anywhere in the visible notebook; it presumably
# maps a square index to the node indices that are HRC for it — confirm where it is built.
bs_indices = [0,1,2,21,22,23]
# Board-state function whose labels are drawn on each board (same for every square).
input_bs_function = 'games_batch_to_state_stack_mine_yours_blank_mask_BLRRC'
n_cols = 4
for bs_index in bs_indices:

    valid_move_square_label = to_board_label(bs_index)
    nodes = lookup_bs_to_feat[bs_index]

    boards = [get_feature_label_classified_squares(feature_labels, input_bs_function, node_idx, mark_idx_s=bs_index) for node_idx in nodes]
    if not boards:
        # Nothing to draw for this square; plotting an empty grid would fail at the colorbar.
        print(f'No {node_type} found for {valid_move_square_label}')
        continue

    # Ceiling division: just enough rows, without an extra blank row when the number of boards
    # is a multiple of n_cols.
    n_rows = (len(boards) + n_cols - 1) // n_cols
    n_empty = n_rows*n_cols - len(boards)

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols*2.5, n_rows*3.5))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)  # Adjust spacing between subplots
    plot_board_categorical(fig, axs, boards, nodes, node_type=node_type)

    # Remove empty subplots (the trailing axes of the last row)
    for i in range(n_empty):
        fig.delaxes(axs.flatten()[-i-1])

    fig.suptitle(f'HPC {node_type} for board_state, given the {node_type} is HRC for {valid_move_square_label}=valid_move', fontsize=14)
    plt.show()