# Visualize othello board with highlighted background

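Input to the model: the "int" move encoding, which excludes the four center squares (already occupied at the start of every game).
For plotting and correct conversion to a board visualization: the "string" encoding, which does include the center squares.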

## Setup

In [None]:
import os
import pickle

import torch as t
from huggingface_hub import hf_hub_download
import einops

from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder
from circuits.utils import (
    othello_hf_dataset_to_generator,
    get_model,
    get_submodule,
)

from feature_viz_othello_utils import (
    get_acts_IEs_VB,
    plot_lenses,
    plot_mean_metrics,
    plot_top_k_games
)

# Root of the local chess-gpt-circuits checkout; autoencoders are downloaded to and loaded from below it.
repo_dir = '/home/can/chess-gpt-circuits'
# repo_dir = "/home/adam/chess-gpt-circuits/"
# Torch device passed to the model loader, the game tensors and the analysis helpers below.
device = 'cuda:0'

In [None]:
# download data from huggingface if needed
autoencoders_dir = f'{repo_dir}/autoencoders'
if not os.path.exists(f'{autoencoders_dir}/othello_5-21'):
    import zipfile

    zip_path = hf_hub_download(
        repo_id='adamkarvonen/othello_saes', filename='othello_5-21.zip', local_dir=autoencoders_dir
    )
    # Extract into the absolute autoencoders dir (not a cwd-relative path). That way the
    # existence check above and the SAE loader below find the files regardless of the
    # notebook's working directory. zipfile also raises on a bad archive instead of
    # failing silently like an unchecked shell command.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(autoencoders_dir)

In [None]:
# load SAE
ae_type = 'standard'
trainer_id = 4

ae_path = f'{repo_dir}/autoencoders/othello_5-21/othello-{ae_type}/trainer{trainer_id}'

# Dictionary class to use for each supported ae_type.
ae_classes = {
    'standard': AutoEncoder,
    'gated': GatedAutoEncoder,
    'standard_new': AutoEncoderNew,
}
if ae_type not in ae_classes:
    raise ValueError('Invalid ae_type')

# Load onto the notebook-wide `device` (instead of a hard-coded 'cuda:0') so the SAE
# always lives on the same device as the model and the game tensors.
ae = ae_classes[ae_type].from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)

In [None]:
# load model and data
layer = 5  # model layer index passed to get_submodule below
context_length = 59  # tokens per game; later cells reshape game batches to (-1, context_length)
activation_dim = 512  # output dimension of the layer
model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
dataset_name = "taufeeque/othellogpt"

model = get_model(model_name, device)
# Module at `layer` whose activations are analysed with the SAE.
# NOTE(review): exact hook point (e.g. residual stream) is defined in circuits.utils — confirm there.
submodule = get_submodule(model_name, layer, model)

# Streaming generator over the training split. Each next(data) is used as one game's token
# sequence (see the batch cell below) — presumably of length context_length; verify in circuits.utils.
data = othello_hf_dataset_to_generator(
    dataset_name, context_length=context_length, split="train", streaming=True
)

In [None]:
games_per_batch = 10  # B: games in each batch
n_batches = 10  # N: number of batches

# Draw N * B games from the streaming generator, grouped as N lists of B games each,
# then stack them into one (N, B, S) tensor on `device`.
games_nested = []
for _ in range(n_batches):
    games_nested.append([next(data) for _ in range(games_per_batch)])
games_batch_NBS = t.tensor(games_nested, device=device)
print(f'game_batch shape: {games_batch_NBS.shape}')

In [None]:
# load feature analysis results
def to_device(d, device=device):
    """Recursively move every tensor contained in ``d`` to ``device``.

    Tensors are moved with ``.to(device)``. Dicts, lists and tuples (including namedtuples)
    are rebuilt with the same container type and their items converted recursively. Any
    other value (ints, strings, numpy arrays, ...) is returned unchanged rather than being
    replaced by ``None``.

    Args:
        d: tensor, container of tensors, or any other value.
        device: target torch device; defaults to the notebook-wide ``device``.

    Returns:
        ``d`` with all reachable tensors moved to ``device``.
    """
    if isinstance(d, t.Tensor):
        return d.to(device)
    if isinstance(d, dict):
        return {k: to_device(v, device) for k, v in d.items()}
    if isinstance(d, list):
        return [to_device(v, device) for v in d]
    if isinstance(d, tuple):
        moved = [to_device(v, device) for v in d]
        # namedtuples take their fields positionally; plain tuples take one iterable.
        return type(d)(*moved) if hasattr(d, '_fields') else type(d)(moved)
    return d


# NOTE: pickle.load can execute arbitrary code — only load these files from a trusted source.
labels_file = os.path.join(ae_path, 'indexing_None_n_inputs_1000_feature_labels.pkl')
with open(labels_file, 'rb') as fh:
    raw_feature_labels = pickle.load(fh)
feature_labels = to_device(raw_feature_labels)
print(feature_labels.keys())


evals_file = os.path.join(ae_path, 'n_inputs_1000_evals.pkl')
with open(evals_file, 'rb') as fh:
    eval_results = pickle.load(fh)
print(eval_results.keys())
print(f"L0: {eval_results['eval_results']['l0']}")
print(f"L0: {eval_results['eval_results']['l0']}")

# Here's the interactive part

In [None]:
# Pick the idx-th alive feature. feat_idx is the value stored for it in
# feature_labels['alive_features'] (presumably its index in the SAE dictionary — confirm),
# and is what the analysis and plotting helpers below receive.
idx = 29
feat_idx = int(feature_labels['alive_features'][idx])

In [None]:
(
    feature_acts_BS,
    acts_per_tokenembed_VB,
    ie_feature_to_logits_VB,
    ie_tokenembed_to_features_VB,
) = get_acts_IEs_VB(
    model, ae, games_batch_NBS, submodule, feat_idx=feat_idx, compute_ie_embed=True, device=device
)

print(f'feat_idx: {feat_idx}')
# Report the shape of every tensor involved, one line per tensor.
for shape_label, shape_tensor in (
    ('games_batch', games_batch_NBS),
    ('feature_acts', feature_acts_BS),
    ('acts_per_tokenembed', acts_per_tokenembed_VB),
    ('ie_feature_to_logits', ie_feature_to_logits_VB),
    ('ie_tokenembed_to_features', ie_tokenembed_to_features_VB),
):
    print(f'{shape_label} shape: {shape_tensor.shape}')

## Weight space
Cosine similarity of the feature's decoder vector with:
- token_embed
- unembed

In [None]:
# Weight-space view: compare the feature's decoder vector with the token embedding and unembedding.
plot_lenses(model, ae, feat_idx, device)

## Activation space

1. Feature activation per input token
2. IE from each input token to the feature
3. IE from the feature to each output logit

Each metric is shown as
- the mean over the N games
- the top-k games out of the N games

In [None]:
# Plot each per-token metric averaged over all collected games (N batches * B games per batch).
plot_mean_metrics(acts_per_tokenembed_VB, ie_tokenembed_to_features_VB, ie_feature_to_logits_VB, feat_idx, n_games=games_per_batch*n_batches, device=device)

In [None]:
# Choose metric to retrieve top k games
k = 10
sort_metric = 'activation' # ['activation', 'ie_embed','ie_logit']

games_batch_BS = games_batch_NBS.view(-1, context_length)

plot_top_k_games(
    feat_idx,
    games_batch_BS,
    feature_acts_BS,
    acts_per_tokenembed_VB,
    ie_tokenembed_to_features_VB,
    ie_feature_to_logits_VB,
    sort_metric=sort_metric,
    k=k,
    device=device,
)

## Board reconstruction: lookup table for high-precision classifiers

In [None]:
# board_state_function = 'games_batch_to_state_stack_mine_yours_BLRRC'
board_state_function = 'games_batch_to_state_stack_mine_yours_blank_mask_BLRRC'
# board_state_function = 'games_batch_to_valid_moves_BLRRC'

# There are 11 thresholds in the lookup table. With index 0, the lookup table is constructed from
# every activation; with 1, from all activations above 10% of that feature's max activation;
# with 2, above 20%; etc.
# 1 is a reasonable default
threshold_idx = 1

# Per-alive-feature classification tables for the chosen board-state function and threshold.
classified_per_feature = feature_labels[board_state_function][threshold_idx]

demo_idx = None

# Scan every alive feature rather than a fixed count, so this works (and never indexes out
# of range) whatever the number of alive features is.
for alive_feature_index in range(len(classified_per_feature)):
    # int() so the message prints a plain number instead of a device-tagged tensor repr.
    num_classified_squares = int(classified_per_feature[alive_feature_index].sum())
    if num_classified_squares > 0:
        print(f'Feature {alive_feature_index} has {num_classified_squares} classified squares')

        # Keeps the last feature found; available as a demo choice for the cells below.
        demo_idx = alive_feature_index

if demo_idx is None:
    raise ValueError('No features have any classified squares')

In [None]:
# NOTE(review): 29 is used below as an index into the alive-feature axis of
# feature_labels[board_state_function][threshold_idx] (the same axis alive_feature_index
# scans above). In the earlier cells, feat_idx instead holds the value read from
# feature_labels['alive_features']. demo_idx from the scan is not used here — confirm
# which feature is intended.
feat_idx = 29
print(f'feat_idx: {feat_idx}')

In [None]:
# Boolean mask of the board-state entries this feature classifies (value == 1 in its table).
classified_mask = feature_labels[board_state_function][threshold_idx][feat_idx] == 1

print("Board states that the feature classifies according to Adam's measurements:")
print(classified_mask.nonzero())
print("Number of such board states:")
print(classified_mask.sum())

In [None]:
import numpy as np
import matplotlib.pyplot as plt

NOT_CLASSIFIED_VALUE = -9

def get_feature_label_classified_squares(feature_labels, board_state_function, threshold_idx, feature_idx) -> t.Tensor:
    """Collapse one feature's one-hot square classification into one integer label per square.

    Reads ``feature_labels[board_state_function][threshold_idx][feature_idx]``, whose last
    dimension holds one channel per class. Each square becomes ``argmax(channel) - 1``
    (channel 0 -> -1, channel 1 -> 0, channel 2 -> 1, ...). Squares where every channel is
    zero become ``NOT_CLASSIFIED_VALUE``. The input table is not modified.
    """
    one_hot_RRC = feature_labels[board_state_function][threshold_idx][feature_idx]
    # Squares with no channel set carry no classification from this feature.
    unclassified_RR = (one_hot_RRC == 0).all(dim=-1)
    labels_RR = one_hot_RRC.argmax(dim=-1) - 1
    labels_RR[unclassified_RR] = NOT_CLASSIFIED_VALUE
    return labels_RR

def plot_board_categorical(board, feat_idx, board_state_function):
    """Show a board of categorical square labels as a coloured grid with a colour-bar legend.

    Args:
        board: 2-D tensor (on any device) or array-like of integer labels, e.g. the output
            of get_feature_label_classified_squares. Every label must be a key of
            ``color_map`` below.
        feat_idx: feature index shown in the title.
        board_state_function: name of the board-state function, shown in the title.

    Raises:
        ValueError: if ``board`` contains a label that has no colour assigned.
    """
    from matplotlib.colors import ListedColormap

    # Define color map
    color_map = {-1: 'black', 0: 'grey', 1: 'white', NOT_CLASSIFIED_VALUE: 'yellow', -3: 'red'}

    # Accept tensors on any device as well as numpy arrays / nested lists.
    if isinstance(board, t.Tensor):
        board = board.detach().cpu().numpy()
    board = np.asarray(board)

    # unique_labels is sorted. board_indices[i, j] is the position of board[i, j] in
    # unique_labels, so every square maps onto exactly one entry of the listed colormap.
    # The reshape handles both the flat (NumPy 1.x) and input-shaped (NumPy 2.x) inverse.
    unique_labels, board_indices = np.unique(board, return_inverse=True)
    board_indices = board_indices.reshape(board.shape)
    unique_labels = unique_labels.tolist()

    unknown_labels = [label for label in unique_labels if label not in color_map]
    if unknown_labels:
        raise ValueError(f'No colour defined for board label(s): {unknown_labels}')
    colors = [color_map[label] for label in unique_labels]
    cmap = ListedColormap(colors)

    # Initialize plot
    fig, ax = plt.subplots()
    cax = ax.imshow(board_indices, cmap=cmap)

    # Create a color bar with the correct labels
    cbar = fig.colorbar(cax, ticks=range(len(unique_labels)))
    cbar.ax.set_yticklabels(colors)

    # Plot labeling
    ax.set_xticks(range(8))
    ax.set_xticklabels(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
    ax.set_title(f'Feature #{feat_idx} is high precision classifier for:\n{board_state_function}\nGrey = Empty, Yellow = Not present in one hot vector')

    plt.show()

In [None]:
# Decode the selected feature's classification table into one label per square and plot it
# (moved to CPU first, since the plotting function converts it to numpy).
sae_feature_board_state_RR = get_feature_label_classified_squares(feature_labels, board_state_function, threshold_idx, feat_idx)
plot_board_categorical(sae_feature_board_state_RR.to('cpu'), feat_idx, board_state_function)