# Visualize othello board with highlighted background

Input to the model: the "int" format (token ids), which does not include the middle pieces.
For plotting and for correct conversion to a board visualization: the "string" format (board square indices), which does include the middle pieces.

In [None]:
import os
import pickle
from IPython.display import HTML

import torch as t
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

from circuits.dictionary_learning.buffer import NNsightActivationBuffer
from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder
import circuits.othello_utils as othello_utils
from circuits.utils import (
    othello_hf_dataset_to_generator,
    get_model,
    get_submodule,
)

# Notebook-wide paths and device; swap in the commented-out repo_dir to use the alternative checkout.
repo_dir, device = '/home/can/chess-gpt-circuits', 'cuda:0'
# repo_dir = "/home/adam/chess-gpt-circuits/"

In [None]:
# download data from huggingface if needed
autoencoder_dir = f'{repo_dir}/autoencoders'
if not os.path.exists(f'{autoencoder_dir}/othello_5-21'):
    import zipfile

    zip_path = hf_hub_download(repo_id='adamkarvonen/othello_saes', filename='othello_5-21.zip', local_dir=autoencoder_dir)
    # Extract into `autoencoder_dir` (not a CWD-relative directory) so the existence
    # check above, and the SAE paths built from repo_dir, find the data.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(autoencoder_dir)

In [None]:
# load SAE
ae_type = 'standard'
trainer_id = 4

ae_path = f'{repo_dir}/autoencoders/othello_5-21/othello-{ae_type}/trainer{trainer_id}'
ae_classes = {
    'standard': AutoEncoder,
    'gated': GatedAutoEncoder,
    'standard_new': AutoEncoderNew,
}
if ae_type not in ae_classes:
    raise ValueError('Invalid ae_type')
# Use the notebook-wide `device` instead of a second hard-coded 'cuda:0'.
ae = ae_classes[ae_type].from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)

# NOTE(review): pickle.load can execute arbitrary code; these files come from the
# downloaded archive, so only load archives from a source you trust.
with open(os.path.join(ae_path, 'indexing_None_n_inputs_1000_feature_labels.pkl'), 'rb') as f:
    lookup_table = pickle.load(f)

print(lookup_table.keys())

with open(os.path.join(ae_path, 'n_inputs_1000_evals.pkl'), 'rb') as f:
    eval_results = pickle.load(f)

print(eval_results.keys())
print(f"L0: {eval_results['eval_results']['l0']}")

In [None]:
# load information about features

def to_device(d, device=device):
    """Recursively move every tensor inside ``d`` to ``device``.

    Tensors are moved with ``.to(device)``. Dicts, lists and tuples are rebuilt
    with their elements converted recursively. Any other value (ints, strings,
    None, ...) is returned unchanged instead of being dropped.
    """
    if isinstance(d, t.Tensor):
        return d.to(device)
    if isinstance(d, dict):
        return {k: to_device(v, device) for k, v in d.items()}
    if isinstance(d, list):
        return [to_device(v, device) for v in d]
    if isinstance(d, tuple):
        items = [to_device(v, device) for v in d]
        # namedtuples take positional fields; plain tuples take an iterable.
        return type(d)(*items) if hasattr(d, '_fields') else tuple(items)
    return d
    
ae = ae.to(device)
# Same pickle as `lookup_table` in the SAE-loading cell, reloaded here with its tensors moved to `device`.
feature_labels_path = os.path.join(ae_path, 'indexing_None_n_inputs_1000_feature_labels.pkl')
with open(feature_labels_path, 'rb') as f:
    feature_labels = to_device(pickle.load(f))

In [None]:
# load model and data

layer = 5
context_length = 59
activation_dim = 512  # output dimension of the layer
model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
dataset_name = "taufeeque/othellogpt"

model = get_model(model_name, device)
submodule = get_submodule(model_name, layer, model)

# NOTE(review): get_model already receives `device`; this move may be redundant — confirm.
model = model.to(device)

data = othello_hf_dataset_to_generator(
    dataset_name, context_length=context_length, split="train", streaming=True
)
buffer = NNsightActivationBuffer(
    data,
    model,
    submodule,
    n_ctxs=int(8e3),  # a number of contexts: pass an int, not the float 8000.0
    ctx_len=context_length,
    refresh_batch_size=128,
    io="out",
    d_submodule=activation_dim,
    device=device,
)

In [None]:
def convert_othello_dataset_sample_to_board(sample_i, move_idx=None):
    """Replay a tokenized game and return the board it reaches.

    Args:
        sample_i: model-input token ids (list or tensor). Each id is mapped to a
            board square via ``othello_utils.itos``.
        move_idx: if given, only moves ``0..move_idx`` (inclusive) are replayed.

    Returns:
        torch.Tensor: the last board state produced by
        ``othello_utils.games_batch_to_state_stack_mine_yours_BLRRC``, reduced
        with argmax over its last (channel) axis and shifted by -1, so channel
        ``c`` becomes value ``c - 1``. The plotting helpers expect -1/0/1.
    """
    # isinstance (rather than an exact type comparison) also converts tensor subclasses.
    if isinstance(sample_i, t.Tensor):
        sample_i = sample_i.tolist()
    context = [othello_utils.itos[s] for s in sample_i]
    if move_idx is not None:
        context = context[:move_idx + 1]
    board_state_RRC = othello_utils.games_batch_to_state_stack_mine_yours_BLRRC([context])[0][-1]
    return t.argmax(board_state_RRC, dim=-1) - 1

# Rebuild the board of one game from a fresh buffer batch, up to and including move `move_idx`.
game_idx, move_idx = 0, 20
batch_games = buffer.text_batch()
sample_game = batch_games[game_idx]  # in "model input format"
sample_board = convert_othello_dataset_sample_to_board(sample_game, move_idx=move_idx)
# sample_board

In [None]:
def plot_othello_board_highlighted(true_board_RR, bg_board_RR=None, title='', max_bg=None):
    """Plot an Othello board in a new figure, optionally over a heat-map background.

    Args:
        true_board_RR (torch.Tensor): 8x8 board. -1 is drawn as a black piece,
            1 as a cornsilk (white) piece, and 0 squares are left empty.
        bg_board_RR (torch.Tensor, optional): 8x8 values shown as the square
            background. Data with negative entries uses the diverging "RdBu" map
            centred on zero; non-negative data uses "Blues". If omitted, the
            background is plain white and no colorbar is drawn.
        title (str): plot title.
        max_bg (float, optional): fixed color-scale limit for the background,
            for example to share one scale across several plots. The range is
            ``[-max_bg, max_bg]`` for signed data and ``[0, max_bg]`` otherwise.
            Defaults to the background's own maximum. Ignored when
            ``bg_board_RR`` is None.
    """
    true_color_map = {-1: 'black', 0: 'white', 1: 'cornsilk'}
    if bg_board_RR is None:
        bg_board_RR = t.zeros_like(true_board_RR, device='cpu')
        cmap = plt.matplotlib.colors.ListedColormap(['white'])
        print_color_bar = False
        vmin = vmax = 0
    else:
        # imshow needs host memory; this also accepts backgrounds computed on the GPU.
        bg_board_RR = bg_board_RR.detach().cpu()
        if bg_board_RR.min() < 0:
            cmap = "RdBu"
            limit = max_bg if max_bg is not None else t.abs(bg_board_RR).max().item()
            vmin, vmax = -limit, limit
        else:
            cmap = "Blues"
            vmin = 0
            vmax = max_bg if max_bg is not None else bg_board_RR.max().item()
        print_color_bar = True

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(bg_board_RR, cmap=cmap, vmin=vmin, vmax=vmax)
    # Grid lines on every square, plus a black-outlined disc for each piece.
    for i in range(8):
        for j in range(8):
            ax.add_patch(plt.Rectangle((j - 0.5, i - 0.5), 1, 1, fill=False, edgecolor='black', lw=0.5))
            piece = true_board_RR[i, j].item()
            if piece == 0:
                continue
            ax.add_artist(plt.Circle((j, i), 0.3, color=true_color_map[piece], fill=True))
            ax.add_artist(plt.Circle((j, i), 0.3, color='black', fill=False))
    ax.set_xticks(range(8))
    ax.set_xticklabels(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
    ax.set_title(title)
    if print_color_bar:
        # The colorbar uses the same (zero-aligned for signed data) limits as the image.
        norm = plt.Normalize(vmin=vmin, vmax=vmax)
        fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=ax)

    # plt.savefig('othello_board_highlighted.png', dpi=300)
    plt.show()


# Test visualization
# random_bg = t.arange(-10,54).reshape(8,8)
# plot_othello_board_highlighted(sample_board, random_bg)

In [None]:
def plot_game_seq(context_i, bg_values, true_board_RR, max_act=None, prefix=''):
    """Plot a board with one value per move painted on the square that move played.

    Args:
        context_i: model-input token ids of the game so far. Each is mapped to
            a square index via ``othello_utils.itos``.
        bg_values: one value per move, for example feature activations. Pairs
            are taken with ``zip``, so extra entries on either side are ignored.
        true_board_RR: board state drawn on top of the background.
        max_act: optional. Accepted so existing callers keep working but not
            used; the color scale is derived from ``bg_values`` by the plotting
            helper.
        prefix: plot title.
    """
    context_s = [othello_utils.itos[s] for s in context_i]
    bg_board_RR = t.zeros((8, 8))
    for act, token in zip(bg_values, context_s):
        bg_board_RR[token // 8, token % 8] = act
    # Only the helper's long-standing arguments are passed (max_act is deliberately not forwarded).
    plot_othello_board_highlighted(true_board_RR, bg_board_RR=bg_board_RR, title=prefix)

In [None]:
# functions for visualizing highlighted token sequences
def shade(value, max_value):
    """Return a ``"#rrggbb"`` background color for ``value`` on a +/-``max_value`` scale.

    Positive values fade from white toward pure blue (``#0000ff`` at ``max_value``).
    Negative values fade from white toward pure red (``#ff0000`` at ``-max_value``).
    Zero, or a ``max_value`` of 0, gives white.

    Raises:
        ValueError: if ``abs(value) > max_value``.
    """
    if abs(value) > max_value:
        raise ValueError("Absolute value must be less than or equal to max_value.")
    if max_value == 0:
        return "#ffffff"

    ratio = value / max_value
    if ratio < 0:
        # Red stays saturated; green and blue fade out as ratio approaches -1.
        fade = int(255 * (1 + ratio))
        rgb = (255, fade, fade)
    else:
        # Blue stays saturated; red and green fade out as ratio approaches 1.
        fade = int(255 * (1 - ratio))
        rgb = (fade, fade, 255)

    if value == 0:
        rgb = (255, 255, 255)  # an exact zero is always plain white

    return "#" + "".join(format(channel, "02x") for channel in rgb)

def visualize_game_seq(context_i, activations, max_value, prefix=''):
    """Render a move sequence as HTML, each move's board label shaded by its value.

    Args:
        context_i: model-input token ids. Each is mapped to a square via
            ``othello_utils.itos`` and then to a label via ``othello_utils.to_board_label``.
        activations: one value per move. Colors come from ``shade(value, max_value)``.
        max_value: shading scale; every ``abs(value)`` must be <= this.
        prefix: HTML inserted before the sequence.

    Returns:
        IPython.display.HTML with space-separated ``<span>`` elements.
    """
    squares = [othello_utils.itos[tok] for tok in context_i]
    labels = [othello_utils.to_board_label(sq) for sq in squares]
    spans = []
    for label, act in zip(labels, activations):
        color = shade(act, max_value)
        text = label.replace(' ', '&nbsp;')
        spans.append(f'<span style="background-color: {color}; color: black">{text}</span>')
    return HTML(prefix + ' '.join(spans))

In [None]:
def visualize(model, ae, buffer, feat_idx, k=10):
    """Show the top-``k`` activating contexts of SAE feature ``feat_idx``.

    Runs a fresh token batch from ``buffer`` through ``model``, encodes the
    module-level ``submodule`` output with ``ae`` and picks the ``k`` largest
    activations of the feature. For each example it displays the move sequence
    shaded by activation and by embedding attribution
    (grad * (embed - mean embed), summed over the last axis), and plots both on
    the board. One shared shading scale (max |value| over all examples) is used.

    Returns:
        (context, activations, attributions) lists for the highest-activating example.
    """
    labeled_seq = buffer.token_batch()
    with model.trace(labeled_seq, scan=False, validate=False):  # use_cache=False, output_attentions=False
        embeds = model.hook_embed.output.save()
        embeds.retain_grad()
        x = submodule.output
        f = ae.encode(x).save()
    mean_embed = embeds.value.mean(dim=(0, 1))
    f = f.value[..., feat_idx]

    # get indices of top k exemplars
    def unravel_index(indices, shape):
        out = []
        for dim in reversed(shape):
            out.append(indices % dim)
            indices = indices // dim
        return tuple(reversed(out))

    top_values, top_indices_flattened = t.topk(f.flatten(), k)
    top_indices = unravel_index(top_indices_flattened, f.shape)
    top_values.sum().backward()

    # Attribution for every (context, position), computed once for the whole batch.
    attribution_all = (embeds.value.grad * (embeds.value - mean_embed)).sum(dim=-1)

    # compile top contexts and activations
    contexts, activations, attributions = [], [], []
    for context_idx, token_idx in zip(top_indices[0].tolist(), top_indices[1].tolist()):
        end = token_idx + 1
        contexts.append(labeled_seq[context_idx, :end].tolist())
        activations.append(f[context_idx, :end].tolist())
        attributions.append(attribution_all[context_idx, :end].tolist())
    max_value = max([abs(x) for act in activations for x in act] + [abs(x) for att in attributions for x in att])

    for cnt, (context_i, activation, attribution) in enumerate(zip(contexts, activations, attributions)):
        print("=" * 80)
        print(f'Top {cnt+1} example:')
        board_state_RR = convert_othello_dataset_sample_to_board(context_i)
        display(visualize_game_seq(context_i, activation, max_value, prefix='feature activations: '))
        # plot_game_seq takes the color-scale maximum as its 4th positional argument.
        plot_game_seq(context_i, activation, board_state_RR, max_value, prefix='activations on board: ')
        display(visualize_game_seq(context_i, attribution, max_value, prefix='embedding attributions:'))
        plot_game_seq(context_i, attribution, board_state_RR, max_value, prefix='mean attribution patching effects on embed tokens: ')

    return contexts[0], activations[0], attributions[0]

In [None]:
# Show the model object's repr as this cell's output.
model

In [None]:
# Shape of the token-embedding matrix (the lens helpers below index it as W_E[token, :]).
model.W_E.shape

In [None]:
# logit lens and token_embed visualization

def cossim_logit_feature_decoder(model, ae, feat_idx):
    """Cosine similarity of a feature's decoder column with each unembedding column.

    Column 0 of ``model.W_U`` is skipped (NOTE: presumably a special token; unverified),
    so entry ``i`` of the result corresponds to vocabulary index ``i + 1``.
    """
    unembed_rows = model.W_U[:, 1:].T  # one row per non-skipped token
    decoder_direction = ae.decoder.weight[:, feat_idx]
    return t.cosine_similarity(decoder_direction, unembed_rows, dim=1)

def cossim_tokenembed_feature_decoder(model, ae, feat_idx):
    """Cosine similarity of a feature's decoder column with each token-embedding row.

    Row 0 of ``model.W_E`` is skipped, so entry ``i`` of the result corresponds
    to vocabulary index ``i + 1``.
    """
    embed_rows = model.W_E[1:, :]
    decoder_direction = ae.decoder.weight[:, feat_idx]
    return t.cosine_similarity(decoder_direction, embed_rows, dim=1)

def visualize_lens(ax, model, ae, feat_idx, cossim_func, title=''):
    """Draw a feature's cosine-similarity lens as an 8x8 board heat map on ``ax``.

    ``cossim_func(model, ae, feat_idx)`` (for example
    ``cossim_tokenembed_feature_decoder`` or ``cossim_logit_feature_decoder``)
    must return one value per entry of ``othello_utils.stoi_indices``. The
    values are scattered onto a 64-square board (all other squares stay 0)
    and drawn with a zero-centred "RdBu" scale and a colorbar.
    """
    cossim = cossim_func(model, ae, feat_idx)
    # Allocate on cossim's own device/dtype: the notebook-wide `device` need not be
    # where the model/SAE weights live, and a mismatch would make the scatter fail.
    ll_board = t.zeros(64, device=cossim.device, dtype=cossim.dtype)
    ll_board[othello_utils.stoi_indices] = cossim
    ll_board = ll_board.view(8, 8)

    cmap = "RdBu"
    max_abs = cossim.abs().max().item()
    norm = plt.Normalize(vmin=-max_abs, vmax=max_abs)
    ax.imshow(ll_board.detach().cpu().numpy(), cmap=cmap, norm=norm)
    plt.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=ax)
    ax.set_xticks(range(8))
    ax.set_xticklabels(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
    ax.set_title(f'{title} #feat {feat_idx}')

In [None]:
def plot_game_seq(ax, context_i, bg_values, true_board_RR, prefix=''):
    """Draw a board on ``ax`` with one value per move painted on the square it played.

    Token ids in ``context_i`` are mapped to square indices with
    ``othello_utils.itos``. Values from ``bg_values`` are paired with them by
    ``zip``, and the resulting 8x8 background is drawn under ``true_board_RR``.
    """
    squares = [othello_utils.itos[tok] for tok in context_i]
    bg_board_RR = t.zeros((8, 8))
    for square, value in zip(squares, bg_values):
        row, col = divmod(square, 8)
        bg_board_RR[row, col] = value
    plot_othello_board_highlighted(ax, true_board_RR, bg_board_RR=bg_board_RR, title=prefix)

def plot_othello_board_highlighted(ax, true_board_RR, bg_board_RR=None, title=''):
    """Draw an Othello board on ``ax`` over an optional heat-map background.

    Pieces: -1 is drawn black, 1 cornsilk; squares with 0 get no disc. If
    ``bg_board_RR`` is None the background is plain white with no colorbar.
    Otherwise signed data uses a zero-centred "RdBu" map, non-negative data
    uses "Blues", and a matching colorbar is attached to ``ax``.
    """
    piece_colors = {-1: 'black', 0: 'white', 1: 'cornsilk'}
    show_colorbar = bg_board_RR is not None
    if not show_colorbar:
        bg_board_RR = t.zeros_like(true_board_RR)
        cmap = plt.matplotlib.colors.ListedColormap(['white'])
        vmin, vmax = 0, 0
    elif bg_board_RR.min() < 0:
        limit = t.abs(bg_board_RR).max().item()
        cmap, vmin, vmax = "RdBu", -limit, limit
    else:
        cmap, vmin, vmax = "Blues", 0, bg_board_RR.max().item()

    ax.imshow(bg_board_RR, cmap=cmap, vmin=vmin, vmax=vmax)
    for row in range(8):
        for col in range(8):
            ax.add_patch(plt.Rectangle((col - 0.5, row - 0.5), 1, 1, fill=False, edgecolor='black', lw=0.5))
            piece = true_board_RR[row, col].item()
            if piece != 0:
                ax.add_artist(plt.Circle((col, row), 0.3, color=piece_colors[piece], fill=True))
                ax.add_artist(plt.Circle((col, row), 0.3, color='black', fill=False))

    ax.set_xticks(range(8))
    ax.set_xticklabels(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
    ax.set_title(title)
    if show_colorbar:
        scale = plt.Normalize(vmin=vmin, vmax=vmax)
        plt.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=scale), ax=ax)

def unravel_index(indices, shape):
    """Convert flat (row-major) indices into a tuple of per-dimension indices.

    Works with Python ints or integer tensors (element-wise). Returns one entry
    per dimension of ``shape``, outermost first.
    """
    coords = []
    remaining = indices
    for size in shape[::-1]:
        coords.insert(0, remaining % size)
        remaining = remaining // size
    return tuple(coords)

def visualize_combined(model, ae, buffer, feat_idx, k=10):
    """Inspect the top-``k`` activating contexts of SAE feature ``feat_idx``.

    Runs a fresh token batch from ``buffer`` through ``model``, encodes the
    module-level ``submodule`` output with ``ae`` and picks the ``k`` largest
    activations of the feature. For each example it displays:

    - the move sequence shaded by activation and by embedding attribution
      (grad * (embed - mean embed), summed over the last axis, i.e.
      mean-ablation attribution patching), each on its own scale;
    - a 1x4 figure: activations on the board, attributions on the board, and
      the token-embedding and logit cosine-similarity lenses.

    Returns:
        (context, activations, attributions) lists for the highest-activating example.
    """
    labeled_seq = buffer.token_batch()
    with model.trace(labeled_seq, scan=False, validate=False):
        embeds = model.hook_embed.output.save()
        embeds.retain_grad()
        x = submodule.output
        f = ae.encode(x).save()
    mean_embed = embeds.value.mean(dim=(0, 1))
    f = f.value[..., feat_idx]

    top_values, top_indices_flattened = t.topk(f.flatten(), k)
    top_indices = unravel_index(top_indices_flattened, f.shape)
    top_values.sum().backward()

    # Attribution for every (context, position), computed once for the whole batch.
    attribution_all = (embeds.value.grad * (embeds.value - mean_embed)).sum(dim=-1)

    contexts, activations, attributions = [], [], []
    for context_idx, token_idx in zip(top_indices[0].tolist(), top_indices[1].tolist()):
        end = token_idx + 1
        contexts.append(labeled_seq[context_idx, :end].tolist())
        activations.append(f[context_idx, :end].tolist())
        attributions.append(attribution_all[context_idx, :end].tolist())

    for cnt, (context_i, activation, attribution) in enumerate(zip(contexts, activations, attributions)):
        print("=" * 80)
        print(f'Top {cnt+1} example:')
        board_state_RR = convert_othello_dataset_sample_to_board(context_i)

        # Each sequence is shaded by its own max |value|: activations are usually much
        # larger than attributions, so a shared scale would wash the attributions out.
        act_scale = max(abs(v) for v in activation)
        att_scale = max(abs(v) for v in attribution)
        display(visualize_game_seq(context_i, activation, act_scale, prefix='feature activations: <br>'))
        display(visualize_game_seq(context_i, attribution, att_scale, prefix='embedding attributions: <br>'))

        fig, axs = plt.subplots(1, 4, figsize=(20, 5))

        plot_game_seq(axs[0], context_i, activation, board_state_RR, prefix='Activations on board')
        plot_game_seq(axs[1], context_i, attribution, board_state_RR, prefix='Attribution patching on embedding\n(mean_ablation)')
        visualize_lens(axs[2], model, ae, feat_idx, cossim_tokenembed_feature_decoder, title='Cosine similarity token_embed lens')
        visualize_lens(axs[3], model, ae, feat_idx, cossim_logit_feature_decoder, title='Cosine similarity logit lens')

        plt.show()

    return contexts[0], activations[0], attributions[0]

# Here's the interactive part

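Key for mapping from the last entry of the board state tensor to pieces (NOTE: this chess-piece key appears to be left over from the chess version of this notebook and does not describe the Othello labels. For Othello, `get_feature_label_classified_squares` maps label channel `c` to the value `c - 1`, i.e. -1, 0 or 1, and marks squares with no label as `NOT_CLASSIFIED_VALUE`):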
* 0 => black king
* 1 => black queen
* 2 => black rook
* 3 => black bishop
* 4 => black knight
* 5 => black pawn
* 6 => empty
* 7 => white pawn
* 8 => white knight
* 9 => white bishop
* 10 => white rook
* 11 => white queen
* 12 => white king

In [None]:
# Initial value only: the feature-browsing cell further down assigns `idx` itself before using it.
idx = 21

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Marker for squares whose label vector is all zeros (see get_feature_label_classified_squares).
NOT_CLASSIFIED_VALUE = -9

def plot_othello_board(board):
    """Plot a board of categorical square labels, one flat color per label.

    Args:
        board (torch.Tensor | numpy.ndarray): 2D array of labels. A CPU tensor or
            any array-like is accepted. Supported values and colors: -1 black,
            0 grey (empty), 1 white, ``NOT_CLASSIFIED_VALUE`` yellow (not present
            in the one-hot vector), -3 red.

    Raises:
        KeyError: if ``board`` contains a value with no assigned color.
    """
    color_map = {-1: 'black', 0: 'grey', 1: 'white', NOT_CLASSIFIED_VALUE: 'yellow', -3: 'red'}
    board_np = np.asarray(board)

    # One colormap entry per label actually present, in sorted order.
    unique_labels = np.unique(board_np)
    colors = [color_map[label] for label in unique_labels]
    cmap = plt.matplotlib.colors.ListedColormap(colors)

    # np.unique returns sorted labels, so searchsorted gives each square's position in them.
    board_indices = np.searchsorted(unique_labels, board_np)

    fig, ax = plt.subplots()
    cax = ax.imshow(board_indices, cmap=cmap)

    # Color bar with one tick per present label, named by its color.
    cbar = fig.colorbar(cax, ticks=range(len(unique_labels)))
    cbar.ax.set_yticklabels(colors)

    # No axes are needed for a game board representation.
    ax.axis('off')
    ax.set_title('Othello Board. Grey = Empty, Yellow = Not present in one hot vector')
    plt.show()

def get_feature_label_classified_squares(feature_labels, board_state_function, threshold_idx, feature_idx) -> t.Tensor:
    """Collapse one feature's per-square label vectors into a single label per square.

    Reads ``feature_labels[board_state_function][threshold_idx][feature_idx]``,
    a tensor whose last axis holds the label channels. Each square gets the
    argmax channel minus 1. Squares whose channels are all zero get
    ``NOT_CLASSIFIED_VALUE`` instead. The input tensor is not modified.
    """
    label_RRC = feature_labels[board_state_function][threshold_idx][feature_idx]
    unlabelled_RR = (label_RRC == 0).all(dim=-1)
    board_RR = label_RRC.argmax(dim=-1) - 1
    board_RR[unlabelled_RR] = NOT_CLASSIFIED_VALUE
    return board_RR


In [None]:
# Only the last assignment takes effect; the commented lines are alternatives to switch to.
# board_state_function = 'games_batch_to_state_stack_mine_yours_BLRRC'
# board_state_function = 'games_batch_to_state_stack_mine_yours_blank_mask_BLRRC'
board_state_function = 'games_batch_to_valid_moves_BLRRC'

threshold_idx = 1 # There's 11 thresholds in the lookup table. If 0, the lookup table is constructed from every activation
# If 1, it's constructed from all activations above 10% of that feature's max activation. If 2, 20% etc.
# 1 is a reasonable default

demo_idx = None

# Scan up to the first 70 alive features, capped at the number actually labelled
# so a smaller table does not raise IndexError.
feature_label_table = feature_labels[board_state_function][threshold_idx]
for alive_feature_index in range(min(70, len(feature_label_table))):
    num_classified_squares = feature_label_table[alive_feature_index].sum()
    if num_classified_squares > 0:
        print(f'Feature {alive_feature_index} has {num_classified_squares} classified squares')
        # Keeps the last (highest-index) feature that has any classified square.
        demo_idx = alive_feature_index

if demo_idx is None:
    raise ValueError('No features have any classified squares')

In [None]:
# List the top-level keys of the feature-label table (board-state function names, 'alive_features', ...).
print(feature_labels.keys())

In [None]:
# Just rerun this cell to skip through features.
# `next_idx` remembers where the previous run stopped, so each rerun advances by one.
# To jump elsewhere (e.g. to the scan result), set `next_idx = demo_idx` before rerunning.
idx = globals().get('next_idx', 36)

threshold_idx = 1

print(f'idx: {idx}')
# `idx` indexes the alive features; `feat_idx` is the corresponding SAE dictionary index.
feat_idx = feature_labels['alive_features'][idx]

sae_feature_board_state_RR = get_feature_label_classified_squares(feature_labels, board_state_function, threshold_idx, idx)
plot_othello_board(sae_feature_board_state_RR.to('cpu'))

classified_mask = feature_labels[board_state_function][threshold_idx][idx] == 1
print("Board states that the feature classifies according to Adam's measurements:")
print(classified_mask.nonzero())
print("Number of such board states:")
print(classified_mask.sum())

contexts, activations, attributions = visualize_combined(model, ae, buffer, feat_idx, k=50)

idx += 1
next_idx = idx

In [None]:
# Decoder direction of SAE feature 0 (column 0 of the decoder weight, as in the cossim helpers).
dec_vec = ae.decoder.weight[:, 0]
dec_vec.shape

In [None]:
# Shape of the token-embedding matrix without row 0, transposed (as used in the check below).
model.W_E[1:, :].T.shape

In [None]:
# Sanity check: t.cosine_similarity vs. a manual dot-product / norms computation (token embeddings).
cossim = t.cosine_similarity(dec_vec, model.W_E[1:, :], dim=1)
dot_prod = dec_vec @ model.W_E[1:, :].T
manual_cossim = dot_prod / t.norm(dec_vec) / t.norm(model.W_E[1:, :], dim=1)
# Compare element-wise: a tolerance on the *sum* of |errors| grows with vocabulary size,
# so float32 rounding alone could trip it.
assert t.allclose(cossim, manual_cossim, atol=1e-6), (cossim - manual_cossim).abs().max()


In [None]:
# Sanity check: t.cosine_similarity vs. a manual computation (unembedding). NOTE: column 0 skipped; possibly a special token?
cossim = t.cosine_similarity(dec_vec, model.W_U[:, 1:].T, dim=1)
dot_prod = dec_vec @ model.W_U[:, 1:]
manual_cossim = dot_prod / t.norm(dec_vec) / t.norm(model.W_U[:, 1:], dim=0)
# Element-wise tolerance (see the token-embedding check): robust to float32 rounding.
assert t.allclose(cossim, manual_cossim, atol=1e-6), (cossim - manual_cossim).abs().max()