# Find lookup tables mapping high-precision and high-recall classifiers to board squares and lines
Load the feature-analysis and evaluation (reconstruction) results for layer 5 of the Othello model.

By default, this notebook indexes features by `sae_feature_index`, not by their position among the alive features.

In [None]:
# Setup
# Imports
import os
import pickle

import torch as t
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import importlib
import numpy as np
import einops

from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder, IdentityDict
from circuits.utils import (
    othello_hf_dataset_to_generator,
    get_model,
    get_submodule,
)

from feature_viz_othello_utils import (
    get_acts_IEs_VN,
    plot_lenses,
    plot_mean_metrics,
    plot_top_k_games
)

import circuits.utils as utils
import circuits.analysis as analysis
import feature_viz_othello_utils as viz_utils
from circuits.othello_engine_utils import to_board_label



device = 'cuda:0'
# Root of the chess-gpt-circuits checkout. Set CHESS_GPT_CIRCUITS_DIR to point at a
# different checkout instead of editing this cell on every machine.
repo_dir = os.environ.get('CHESS_GPT_CIRCUITS_DIR', '/share/u/can/chess-gpt-circuits')
# repo_dir = "/home/adam/chess-gpt-circuits"

In [None]:
layer = 5
# node_type = "sae_feature"
node_type = "mlp_neuron"

# Pick the dictionary to analyse: a trained SAE ("sae_feature") or an identity
# dictionary over raw MLP activations ("mlp_neuron").
if node_type not in ("sae_feature", "mlp_neuron"):
    raise ValueError('Invalid node_type')

if node_type == "mlp_neuron":
    ae_group_name = 'othello_mlp_acts_identity_aes' # with_lines
    ae_type = 'identity'
    ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}'
else:  # node_type == "sae_feature"
    ae_group_name = 'all_layers_othello_p_anneal_0530_with_lines'
    ae_type = 'p_anneal'
    trainer_id = 0
    ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}/trainer{trainer_id}'

# Download and unpack the autoencoder group from the Hugging Face Hub if it is not present locally.
if not os.path.exists(f'{repo_dir}/autoencoders/{ae_group_name}'):
    import zipfile

    # hf_hub_download returns the local path of the downloaded file.
    zip_path = hf_hub_download(repo_id='adamkarvonen/othello_saes', filename=f'{ae_group_name}.zip', local_dir=f'{repo_dir}/autoencoders')
    # Extract with the stdlib instead of shelling out to `unzip`. This needs no external
    # binary, does not pass paths through a shell, and raises on failure; the old
    # os.system return code was silently ignored.
    with zipfile.ZipFile(zip_path) as zip_file:
        zip_file.extractall(f'{repo_dir}/autoencoders')

# Initialize the autoencoder / dictionary selected by ae_type. Weights are loaded
# onto the notebook-wide `device` rather than a separately hard-coded GPU string.
ae_weights_path = os.path.join(ae_path, 'ae.pt')
if ae_type in ('standard', 'p_anneal'):
    ae = AutoEncoder.from_pretrained(ae_weights_path, device=device)
elif ae_type in ('gated', 'gated_anneal'):
    ae = GatedAutoEncoder.from_pretrained(ae_weights_path, device=device)
elif ae_type == 'standard_new':
    ae = AutoEncoderNew.from_pretrained(ae_weights_path, device=device)
elif ae_type == 'identity':
    # No weights: the identity dictionary exposes the raw activations as "features".
    ae = IdentityDict()
else:
    raise ValueError('Invalid ae_type')

In [None]:
# Load the pretrained Othello-GPT model onto `device`.
model_name = 'Baidicoot/Othello-GPT-Transformer-Lens'
model = get_model(model_name, device)

In [None]:
# Load results files

# load feature analysis results
def to_device(d, device=device):
    """Recursively move every tensor in a (possibly nested) dict to ``device``.

    Tensors are moved with ``.to(device)``; dicts are rebuilt with their values
    converted recursively. Any other value (int, str, list, ...) is returned
    unchanged rather than being replaced by ``None``.

    Note: the cells below use ``utils.to_device``; this local helper is kept for
    interactive use.
    """
    if isinstance(d, t.Tensor):
        return d.to(device)
    if isinstance(d, dict):
        return {k: to_device(v, device) for k, v in d.items()}
    return d


# NOTE(review): pickle can execute arbitrary code on load; only open results files you produced or trust.
# Feature-analysis results: a dict keyed by board-state function name (plus entries such as 'on_count').
with open(os.path.join(ae_path, 'indexing_None_n_inputs_1000_results.pkl'), 'rb') as results_file:
    results = pickle.load(results_file)
results = utils.to_device(results, device)
print(results.keys())

feature_labels, misc_stats = analysis.analyze_results_dict(
    results,
    "",
    device,
    save_results=False,
    verbose=False,
    print_results=False,
    significance_threshold=100,
)
print(feature_labels.keys())

# Evaluation metrics for the same dictionary (includes the L0 sparsity).
with open(os.path.join(ae_path, 'n_inputs_1000_evals.pkl'), 'rb') as eval_file:
    eval_results = pickle.load(eval_file)
print(eval_results.keys())
print(f"L0: {eval_results['eval_results']['l0']}")

In [None]:
# Board-state function whose classifier statistics are analysed below. Alternatives:
# bs_function = 'games_batch_to_valid_moves_BLRRC'
# bs_function = 'games_batch_to_state_stack_mine_yours_blank_mask_BLRRC'
# bs_function = 'games_batch_to_state_stack_lines_mine_BLRCC'
# bs_function = 'games_batch_to_state_stack_opponent_length_lines_mine_BLRCC'
bs_function = 'games_batch_to_state_stack_length_lines_mine_BLRCC'

# Map each alive feature id to its position in feature_labels['alive_features'].
alive_to_feat_idx = dict(
    (feat.item(), position) for position, feat in enumerate(feature_labels['alive_features'])
)
n_features_alive = len(alive_to_feat_idx)

In [None]:
# Inspect which count tensors are stored for the selected board-state function.
results[bs_function].keys()

In [None]:
# Shapes of the per-(threshold, feature) 'on' counts and of the 'all' counts
# (presumably per board-state cell only -- confirm from the printed shapes).
results[bs_function]['on'].shape, results[bs_function]['all'].shape

In [None]:
# Number of board-state cells whose 'all' count exceeds the maximum 'on' count over every
# threshold and feature, i.e. cells that no (threshold, feature) pair recalls perfectly.
(results[bs_function]['on'].max(dim=0).values.max(dim=0).values < results[bs_function]['all']).sum()

In [None]:
# Sanity check: for threshold indices 1..9 and every alive feature, the 'on' and 'off'
# counts of each board-state cell must add up to its 'all' count.
on_and_off = results[bs_function]['on'] + results[bs_function]['off']
# One vectorized comparison instead of a Python double loop; 'all' broadcasts over the
# leading (threshold, feature) dimensions.
on_off_matches_all = on_and_off[1:10, :n_features_alive] == results[bs_function]['all']
assert t.all(on_off_matches_all), "'on' + 'off' counts do not match 'all' counts"

## Recall

recall = TP / (TP + FN), computed as `on / all`; precision = TP / (TP + FP), computed as `on / on_count`.

In [None]:
epsilon = 1e-6  # keeps the ratios finite where a denominator is zero

# recall    = TP / positives          -> 'on' / 'all'
# precision = TP / feature firings    -> 'on' / 'on_count'
# ('on_count' appears to be indexed by (threshold, feature) -- broadcast over the board axes)
recall_TFRRC = results[bs_function]['on'] / (epsilon + results[bs_function]['all'])
precision_TFRRC = (
    results[bs_function]['on']
    / (epsilon + results['on_count'][:, :, None, None, None])
)

# Possible alternatives:
# f1_TFRRC = 2 * recall_TFRRC * precision_TFRRC / (recall_TFRRC + precision_TFRRC + epsilon)
# recall_TFRRC = precision_TFRRC

In [None]:
# Flatten the board into squares and split the last axis into (line length, direction):
# 6 lengths x 8 directions.
recall_TFSqLenDir = einops.rearrange(
    recall_TFRRC, 'T F R1 R2 (Len Dir) -> T F (R1 R2) Len Dir', Len=6, Dir=8
)
# Optional: drop values with incorrect initialization in the first evaluation run
# recall_TFSqLenDir = recall_TFSqLenDir[:, :, :, 1:-1, :]

# Boolean lookup table over (threshold, feature, square, length, direction):
# True where recall exceeds T_recall.
T_recall = 0.95
high_recall_TFSqLenDir = recall_TFSqLenDir > T_recall

In [None]:
# Expected layout: (T, F, Sq, Len=6, Dir=8).
high_recall_TFSqLenDir.shape

In [None]:
# Choose the firing threshold T_fire with the most high-recall (feature, line) pairs.
# NOTE(review): this may be dominated by noise -- compare against a random baseline.
T_fire_hrc_count = high_recall_TFSqLenDir.sum(dim=(1,2,3,4))
plt.bar(t.arange(T_fire_hrc_count.shape[0]).cpu().detach().numpy(), T_fire_hrc_count.cpu().detach().numpy())
plt.xlabel('T_fire')
# The y axis is log-scaled; the plotted values themselves are raw counts.
plt.ylabel('Number of HRC (feature, line) pairs (log scale)')
# Name the board-state function actually analysed instead of a hard-coded "valid_moves".
plt.title(f'HRC {node_type} for {bs_function}')
plt.yscale('log')
plt.show()

T_fire_max_hrc = T_fire_hrc_count.argmax().item()
print(f'the T_fire with the maximum hrc features is {T_fire_max_hrc}')

In [None]:
# Collapse the threshold axis: a (feature, line) pair is HRC if it is HRC at any T_fire.
high_recall_FSqLenDir = high_recall_TFSqLenDir.any(dim=0).int()

## Lookup tables

### Lookup: feature --> board state (bs)

In [None]:
# Number of high-recall (Sq, Len, Dir) lines per feature, aggregated over T_fire.
lines_per_feature = high_recall_FSqLenDir.sum(dim=(1,2,3))

counts = plt.hist(lines_per_feature.cpu().detach().numpy())
plt.xlabel('Number of lines per feature')
# The y axis is log-scaled; the plotted values are raw counts, not log-transformed.
plt.ylabel('Count of features (log scale)')
plt.title(f'HRC {node_type}')
plt.yscale('log')
counts

In [None]:
# Keep only features whose high-recall lines all start from exactly one square.
high_recall_FSq = high_recall_FSqLenDir.any(dim=-1).any(dim=-1).int()
high_recall_F = high_recall_FSq.sum(dim=-1)
feat_idx_single_square_any_number_of_lines = (high_recall_F == 1).nonzero(as_tuple=True)[0]
print(f'Number of features with high recall for a single square: {feat_idx_single_square_any_number_of_lines.shape[0]}')
high_recall_filtered_FSqLenDir = high_recall_FSqLenDir[feat_idx_single_square_any_number_of_lines]

## DLA vs valid_move HRC

### Lookup: board state (bs) --> feature

In [None]:
# Number of single-square HRC features per square. Reduce over (Len, Dir) with `any` first,
# so a feature with several high-recall lines at its square is counted once. Summing over
# those axes would count (feature, line) pairs, which the "features per square" plots do not show.
features_per_square = high_recall_filtered_FSqLenDir.any(dim=-1).any(dim=-1).sum(dim=0)

In [None]:
# Histogram of `features_per_square` over the board squares.
counts = plt.hist(features_per_square.cpu().detach().numpy())
plt.xlabel('Number of features per square, any line')
plt.ylabel('Count of squares')
plt.title(f'HRC {node_type} for Line (Sq * Len * Dir)\n filtered for HRC corresponding to a single square')

# plt.hist returns (bin counts, bin edges, patches); shown as the cell output.
counts

In [None]:
import feature_viz_othello_utils
# Re-execute the helper module in place so source edits are picked up without a kernel restart.
# `viz_utils` is the same module object, so it sees the reloaded functions; names imported
# via `from feature_viz_othello_utils import ...` at the top are NOT refreshed.
importlib.reload(feature_viz_othello_utils) 

# Board heatmap of `features_per_square`.
fig, ax = plt.subplots()
viz_utils.visualize_board_from_tensor(ax, features_per_square, title=f'Number of HRC {node_type}s per valid move', cmap='Blues')

In [None]:
# we already filter for features with high recall for exactly 1 square
# for a single square, print all the board configurations that have a high recall feature

In [None]:
# (row, col) step for each direction index: up, up-right, right, down-right, down,
# down-left, left, up-left (row 0 is drawn at the top by imshow).
eights = [[-1, 0], [-1, 1], [0, 1], [1, 1], [1, 0], [1, -1], [0, -1], [-1, -1]]
# Board label value -> display color.
color_map = {-1: 'black', 0: 'grey', 1: 'gold', -9: 'white', -3: 'green'}
# Human-readable name -> board label value.
color_lbl = {'Mine': -1, 'Empty': 0, 'Yours': 1, 'Not classified': -9, 'Valid move': -3}

# Function mapping (Sq, Len, Dir) line indices to (8, 8) board tensors
def to_board_tensor(square_idxs, lengths, directions, device, opponent_only=True):
    """Render line configurations as categorical 8x8 boards.

    Args:
        square_idxs: flat square index (0..63) where each line starts; row = idx // 8, col = idx % 8.
        lengths: length index per line; the line covers ``len_idx + 1`` 'Yours' squares.
        directions: index into ``eights`` giving the (row, col) step of each line.
        device: torch device of the returned tensor.
        opponent_only: if True, mark the start square 'Empty' and draw only the 'Yours' run.
            If False, mark the start square 'Valid move' and cap the run with a 'Mine' square.

    Returns:
        Float tensor of shape (len(square_idxs), 8, 8) holding ``color_lbl`` values; squares
        not on a line stay 'Not classified'. Parts of a line that would leave the board are
        skipped (with an 'Out of bounds' message) instead of wrapping around.
    """
    board_tensor = t.ones((len(square_idxs), 8, 8), device=device) * color_lbl['Not classified']
    for i, (square_idx, len_idx, dir_idx) in enumerate(zip(square_idxs, lengths, directions)):
        x, y = square_idx // 8, square_idx % 8
        dx, dy = eights[dir_idx]

        if opponent_only is False:
            board_tensor[i, x, y] = color_lbl['Valid move']
        else:
            board_tensor[i, x, y] = color_lbl['Empty']

        run_in_bounds = True
        for _ in range(1, len_idx + 2):
            x += dx
            y += dy
            if not (0 <= x < 8 and 0 <= y < 8):
                print('Out of bounds')
                run_in_bounds = False
                break
            board_tensor[i, x, y] = color_lbl['Yours']

        if opponent_only is False and run_in_bounds:
            x += dx
            y += dy
            # The capping 'Mine' square can itself fall off the board. A negative index would
            # silently paint the opposite edge, and an index >= 8 would raise.
            if 0 <= x < 8 and 0 <= y < 8:
                board_tensor[i, x, y] = color_lbl['Mine']
            else:
                print('Out of bounds')
    return board_tensor

def plot_board_categorical(fig, axs, boards, node_idxs, node_type):
    """Draw categorical boards on a grid of axes with one shared, labelled colorbar.

    Args:
        fig: figure owning ``axs``; the colorbar is attached to it.
        axs: array of Axes; boards are drawn in ``axs.flat`` order and surplus axes are left untouched.
        boards: iterable of 8x8 arrays whose values are keys of ``color_map``.
        node_idxs: per-board node index (only used by the commented-out per-board title).
        node_type: node type label (only used by the commented-out per-board title).

    Does nothing (no colorbar) when ``boards`` is empty.
    """
    # Categorical colormap: one color per label, in color_map order.
    colors = list(color_map.values())
    cmap = plt.matplotlib.colors.ListedColormap(colors)
    label_to_enumerate = {label: i for i, label in enumerate(color_map.keys())}
    vmin = 0
    vmax = len(color_map) - 1
    norm = plt.matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)

    # Derive each tick's name from its label value, so names stay aligned with the
    # colors even if color_map and color_lbl list their entries in different orders.
    value_to_name = {value: name for name, value in color_lbl.items()}
    tick_names = [value_to_name.get(label, str(label)) for label in color_map]

    # Plot each board
    image = None
    for ax, board, feat_idx in zip(axs.flat, boards, node_idxs):
        board_indices = np.vectorize(lambda x: label_to_enumerate[x])(board)
        image = ax.imshow(board_indices, cmap=cmap, norm=norm)

        # Plot labeling
        ax.set_xticks(range(8))
        ax.set_xticklabels(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
        # ax.set_title(f'{node_type} #{feat_idx}', fontsize=10)

    if image is None:
        # No boards were drawn, so there is no mappable for the colorbar.
        return
    cbar = fig.colorbar(image, ax=axs, norm=norm, orientation='vertical', ticks=range(len(color_map)))
    cbar.ax.set_yticklabels(tick_names)

# Test
# fig, axs = plt.subplots(1, 3, figsize=(15, 5))
# btensor = to_board_tensor([0,1,2], [0,1,2], [4, 4, 4], device, opponent_only=False).cpu().detach().numpy()
# plot_board_categorical(fig, axs, btensor, [0,0,0], node_type)

In [None]:
# Number of features kept by the single-square high-recall filter.
len(feat_idx_single_square_any_number_of_lines)

In [None]:
# For each single-square HRC feature, draw every (Sq, Len, Dir) line it recalls as a small board.
# Rows of HRC_nonzero are (filtered feature idx, square, length, direction).
HRC_nonzero = high_recall_filtered_FSqLenDir.nonzero()
HRC_features = HRC_nonzero[:, 0].unique()

for feat_idx in HRC_features:
    feature_idxs, square_idxs, lengths, directions = HRC_nonzero[HRC_nonzero[:, 0] == feat_idx].T
    boards = to_board_tensor(square_idxs, lengths, directions, device, opponent_only=False)

    # Grid with plot_cols boards per row; ceil-divide for the number of rows.
    plot_cols = 6
    plot_rows = (len(boards) + plot_cols-1) // plot_cols
    fig, axs = plt.subplots(plot_rows, plot_cols, figsize=(12, plot_rows+1))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)  # Adjust spacing between subplots
    # feat_idx indexes the filtered feature set; map it back to the original node index for the title.
    mlp_idx = feat_idx_single_square_any_number_of_lines[feat_idx]
    fig.suptitle(f'{node_type} #{mlp_idx}')
    
    plot_board_categorical(fig, axs, boards.cpu().detach().numpy(), feature_idxs.cpu().detach().numpy(), node_type)
    # Remove empty subplots
    n_empty = plot_cols*plot_rows - len(boards)
    for i in range(n_empty):
        fig.delaxes(axs.flatten()[-i-1])

    plt.show()

    # Stop once the filtered feature index passes 20 to keep the output manageable.
    if feat_idx > 20:
        break

In [None]:
# Enumerate every (square, length, direction) line that can make a valid move. Starting from
# the move square, at least one opponent piece followed by an own piece must fit on the board,
# so the capping own piece sits at distance >= 2. `length` = number of opponent pieces - 1.
# The direction table gets its own name instead of rebinding the global `eights`, which
# `to_board_tensor` reads at call time.
# NOTE(review): this ordering differs from `eights`; confirm it matches the direction axis
# of `bs_function`.
line_directions = [(0, 1), (1, 0), (1, 1), (1, -1), (0, -1), (-1, 0), (-1, -1), (-1, 1)]
valid_rows = []
valid_board = t.zeros((8, 8), dtype=t.int)
for square_idx in range(64):
    r, c = divmod(square_idx, 8)
    for direction_idx, (dx, dy) in enumerate(line_directions):
        # (x, y) is the capping own-piece square for the current length.
        x, y = r + 2*dx, c + 2*dy
        length = 0
        while 0 <= x < 8 and 0 <= y < 8:
            valid_rows.append((square_idx, length, direction_idx))
            valid_board[r, c] += 1
            # Update for next iteration
            x += dx
            y += dy
            length += 1
# Build the index tensor once rather than re-concatenating in the loop (quadratic copying).
valid_idxs_SqLenDir = t.tensor(valid_rows, dtype=t.int).reshape(-1, 3)

print(f'Total number of valid lines: {valid_board.sum()}')

# Plotting the valid_board: heatmap of how many candidate lines start at each square
plt.imshow(valid_board.cpu().detach().numpy(), cmap='viridis')
plt.xticks(range(8))
plt.yticks(range(8))
plt.gca().set_xticklabels(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])

# Annotate each square with its value (text is placed at (col, row) in imshow coordinates)
for i in range(valid_board.size(0)):
    for j in range(valid_board.size(1)):
        plt.text(j, i, f'{valid_board[i, j].item()}', ha='center', va='center', color='white')
plt.title('Number of line configurations that make a valid move')

plt.show()

## Threshold sensitivity

In [None]:
square_idxs, length_idxs, direction_idxs = valid_idxs_SqLenDir.T

# Denominator: all candidate lines except those starting on the 4 central squares. Each of
# those squares has 19 candidate lines, and they are presumably excluded because the centre
# squares are occupied from the start, so they are never valid moves.
n_reachable_lines = valid_board.sum() - 4*19

T_recall_space = t.cat((t.linspace(0, 0.95, 50), t.linspace(0.95, 1, 50)))
frac_valid_lines_classified = t.zeros(T_recall_space.shape)
# The loop variable has its own name so the global T_recall is not overwritten.
for i, T_recall_i in enumerate(T_recall_space):
    # Apply T_recall
    hrc_TFSqLenDir = (recall_TFSqLenDir > T_recall_i).int().cpu()

    # HRC for any T_fire
    hrc_FSqLenDir = t.any(hrc_TFSqLenDir, dim=0)

    # Only features with high recall for a single square
    hrc_FSq = t.any(hrc_FSqLenDir, dim=(-2, -1))
    hrc_F = t.sum(hrc_FSq, dim=-1)
    hrc_feat_idx_single_square_any_number_of_lines = t.where(hrc_F == 1)[0]
    hrc_FSqLenDir = hrc_FSqLenDir[hrc_feat_idx_single_square_any_number_of_lines]

    # Select valid lines. A line counts once if ANY retained feature classifies it, matching
    # the precision sweep. Summing over features would double-count lines covered by several
    # features and could push the fraction above 1.
    valid_LenDirs = hrc_FSqLenDir[:, square_idxs, length_idxs, direction_idxs]
    frac_valid_lines_classified[i] = valid_LenDirs.any(dim=0).sum() / n_reachable_lines

In [None]:
# Fraction of candidate lines classified with high recall, as a function of T_recall.
plt.scatter(T_recall_space.cpu().detach().numpy(), frac_valid_lines_classified.cpu().detach().numpy(), zorder=10)
plt.xlabel('T_recall')
plt.ylabel(f'Fraction of valid lines classified with high recall')
plt.grid(alpha=0.5, zorder=1)
# plt.ylim(-0.05, 1)

In [None]:
# Apply T_recall
T_recall = 0.95
hrc_TFSqLenDir = (recall_TFSqLenDir > T_recall).int().cpu()

# HRC for any T_fire
hrc_FSqLenDir = t.any(hrc_TFSqLenDir, dim=0)

# Only features with high recall for a single square
hrc_FSq = t.any(hrc_FSqLenDir, dim=(-2, -1))
hrc_F = t.sum(hrc_FSq, dim=-1)
hrc_feat_idx_single_square_any_number_of_lines = t.where(hrc_F == 1)[0]
hrc_single_FSqLenDir = hrc_FSqLenDir[hrc_feat_idx_single_square_any_number_of_lines]
hrc_single_F = hrc_single_FSqLenDir.any(dim=(-3, -2, -1))
print(f'Number of features with high recall for a single square: {hrc_single_F.sum()}')

# Zero out every feature that is NOT single-square. `~` must act on a boolean mask.
# Applied to the integer index tensor, it is bitwise NOT (-i - 1) and selects unrelated
# features counted from the end.
hrc_single_mask_F = t.zeros(hrc_FSqLenDir.shape[0], dtype=t.bool)
hrc_single_mask_F[hrc_feat_idx_single_square_any_number_of_lines] = True
hrc_FSqLenDir[~hrc_single_mask_F] = 0
# (feature, square) pairs of the remaining single-square HRC features
hrc_FSqLenDir.nonzero()[:, :2]

## Precision

In [None]:
# Precision of each (threshold, feature) pair as a classifier of the valid-moves board state:
# 'on' count divided by the feature's firing count ('on_count').
bs_function_VM = 'games_batch_to_valid_moves_BLRRC'
precision_VM_TFRRC = results[bs_function_VM]['on'] / (results['on_count'][:, :, None, None, None] + epsilon)
# The valid-moves state has a single channel (the trailing 1); flatten the board into squares.
precision_VM_TFSq = einops.rearrange(precision_VM_TFRRC, 'T F R1 R2 1 -> T F (R1 R2 1)')
precision_VM_TFSq.shape

In [None]:
T_precision_space = t.cat((t.linspace(0, 0.95, 50), t.linspace(0.95, 1, 50)))
frac_valid_moves_classified = t.zeros(T_precision_space.shape)
precision_F_to_Sq = []
# The loop variable has its own name so the global T_precision is not overwritten.
for i, T_precision_i in enumerate(T_precision_space):
    # Apply T_precision
    hpc_TFSq = (precision_VM_TFSq > T_precision_i).int().cpu()

    # HPC for any T_fire
    hpc_FSq = t.any(hpc_TFSq, dim=0)

    # Only features with high precision for a single square
    hpc_F = t.sum(hpc_FSq, dim=-1)
    hpc_feat_idx_single_square_any_number_of_lines = t.where(hpc_F == 1)[0]
    hpc_single_FSq = hpc_FSq[hpc_feat_idx_single_square_any_number_of_lines]
    hpc_Sq = hpc_single_FSq.any(dim=0)
    # 60 = squares that can ever be a valid move (presumably 64 minus the 4 centre squares)
    frac_valid_moves_classified[i] = hpc_Sq.sum() / 60

    # Zero out non-single-square features via a boolean mask. `~` on the integer index
    # tensor would be bitwise NOT (-i - 1) and select the wrong features.
    hpc_single_mask_F = t.zeros(hpc_FSq.shape[0], dtype=t.bool)
    hpc_single_mask_F[hpc_feat_idx_single_square_any_number_of_lines] = True
    hpc_FSq[~hpc_single_mask_F] = 0
    precision_F_to_Sq.append(t.nonzero(hpc_FSq))

In [None]:
# Fraction of valid-move squares covered by a single-square high-precision feature, vs T_precision.
plt.scatter(T_precision_space.cpu().detach().numpy(), frac_valid_moves_classified.cpu().detach().numpy(), zorder=10)
plt.xlabel('T_precision')
plt.ylabel(f'Fraction of valid moves classified with high precision')
plt.ylim(-0.05, 1.05)
plt.grid(alpha=0.5, zorder=1)

In [None]:
# Apply T_precision 
T_precision = 0.95
hpc_TFSq = (precision_VM_TFSq > T_precision).int().cpu()

# HPC for any T_fire
hpc_FSq = t.any(hpc_TFSq, dim=0)

# Only features with high precision for a single square
hpc_F = t.sum(hpc_FSq, dim=-1)
hpc_feat_idx_single_square_any_number_of_lines = t.where(hpc_F == 1)[0]
hpc_single_FSq = hpc_FSq[hpc_feat_idx_single_square_any_number_of_lines]
# Count from the single-square subset, mirroring hrc_single_F. Counting over the full
# hpc_FSq would include features that are high-precision for several squares.
hpc_single_F = hpc_single_FSq.any(dim=1)
print(f'Number of features with high precision for a single square: {hpc_single_F.sum()}')

# Zero out non-single-square features via a boolean mask. `~` on the integer index
# tensor would be bitwise NOT (-i - 1) and select the wrong features.
hpc_single_mask_F = t.zeros(hpc_FSq.shape[0], dtype=t.bool)
hpc_single_mask_F[hpc_feat_idx_single_square_any_number_of_lines] = True
hpc_FSq[~hpc_single_mask_F] = 0
hpc_F_to_Sq = t.nonzero(hpc_FSq)

In [None]:
# (feature index, square index) pairs where hpc_FSq is nonzero.
hpc_F_to_Sq

In [None]:
# Intersection of HRC and HPC: 0/1 indicator vectors over all features, multiplied elementwise.
# They are sized from the data's feature axis rather than a hard-coded 2048, which only
# worked when there were exactly 2048 features.
n_features = hpc_FSq.shape[0]
hrc_single_expanded_F = t.zeros(n_features, dtype=t.int)
hpc_single_expanded_F = t.zeros(n_features, dtype=t.int)
hrc_single_expanded_F[hrc_feat_idx_single_square_any_number_of_lines] = 1
hpc_single_expanded_F[hpc_feat_idx_single_square_any_number_of_lines] = 1
hrc_hpc_single_F = hrc_single_expanded_F * hpc_single_expanded_F
print(f'Number of features with high recall and high precision for a single square: {hrc_hpc_single_F.sum()}')

In [None]:
# Indices (an (N, 1) tensor each, from nonzero on a 1-D tensor) of the features flagged
# in each indicator vector.
hrc_indexes_single_F, hpc_indexes_single_F, hrc_hpc_indexes_single_F = map(
    t.nonzero, (hrc_single_expanded_F, hpc_single_expanded_F, hrc_hpc_single_F)
)

In [None]:
# Export the feature indices of each classifier group as flat lists of ints.
# `.reshape(-1)` (rather than `.squeeze()`) keeps a single match as a one-element list
# instead of collapsing it to a bare int, so every JSON value has the same type.
indexes_dict = {
    'high_precision': hpc_indexes_single_F.reshape(-1).tolist(),
    'high_recall': hrc_indexes_single_F.reshape(-1).tolist(),
    'high_precision_and_recall': hrc_hpc_indexes_single_F.reshape(-1).tolist()
    }

# export the indexes with json
import json
with open(os.path.join(ae_path, 'hpc_hrc_indexes_dict.json'), 'w') as f:
    json.dump(indexes_dict, f)

In [None]:
# Directory the JSON lookup table was written to.
ae_path