# Inspect SAE activations for OthelloGPT

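Rico Angell trained a set of SAEs on the layer 5 residual stream (resid_post) of OthelloGPT. Use this notebook to retrieve SAE feature activations. Setting `othello = False` in the SAE-loading cell switches to the ChessGPT layer 5 SAEs instead.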

Running this notebook requires cloning **THE COLLAB BRANCH** of Sam Marks' dictionary learning repo: https://github.com/saprmarks/dictionary_learning

In [None]:
import os

from huggingface_hub import hf_hub_download
from datasets import load_dataset
from transformer_lens import HookedTransformer
from nnsight import NNsight
import pickle
import torch

import sys
sys.path.append('/home/can/chess-gpt-circuits/circuits/')
# sys.path.append('your-path-to-dictionary-learning-repo')
from dictionary_learning.buffer import NNsightActivationBuffer
from dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder
import circuits.analysis as analysis
import circuits.utils as utils

# Root of the experiments repo checkout; point this at your own copy.
repo_dir = "/home/can/chess-gpt-circuits/"

# Default device. NOTE: the SAE-loading cell reassigns device = 'cpu'.
device = 'cuda:0'

In [None]:
# download data from huggingface if needed
# if not os.path.exists(f'{repo_dir}/autoencoders/othello_5-21'):
#     hf_hub_download(repo_id='adamkarvonen/othello_saes', filename='othello_5-21.zip', local_dir=f'{repo_dir}/autoencoders')
#     # unzip the data
#     os.system(f'unzip {repo_dir}/autoencoders/othello_5-21.zip -d {repo_dir}/autoencoders')

In [None]:
# load SAE

device = 'cpu'

# SAE sweep to inspect: selects the `...-{ae_type}` directory of the sweep
# (e.g. 'standard', 'p_anneal', 'gated'), and the trainer within that sweep.
ae_type = 'p_anneal'
trainer_id = 3

# True -> OthelloGPT SAEs; False -> ChessGPT SAEs.
othello = False

if othello:
    sweep_dir = f'{repo_dir}/autoencoders/othello_5-21/othello-{ae_type}'
else:
    sweep_dir = f'{repo_dir}/autoencoders/chess-trained_model-layer_5-2024-05-23/chess-trained_model-layer_5-{ae_type}'
ae_path = f'{sweep_dir}/trainer{trainer_id}'

# Map each supported ae_type to the dictionary class its checkpoint loads with.
# 'p_anneal' must be listed because it is the ae_type selected above; without
# it the cell always fell through to the ValueError.
# NOTE(review): assumes p_anneal checkpoints use the standard AutoEncoder
# architecture (as dictionary_learning's p-annealing trainer does) - verify
# if loading fails.
ae_classes = {
    'standard': AutoEncoder,
    'p_anneal': AutoEncoder,
    'gated': GatedAutoEncoder,
    'gated_anneal': GatedAutoEncoder,
    'standard_new': AutoEncoderNew,
}
if ae_type not in ae_classes:
    raise ValueError(f'Invalid ae_type: {ae_type!r} (expected one of {sorted(ae_classes)})')

# Load onto the `device` configured above rather than a hard-coded 'cuda:0',
# so the notebook also runs on machines without a GPU.
ae = ae_classes[ae_type].from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)

# The pickled feature-analysis results use a different filename for the
# Othello and chess runs; everything after loading is shared.
results_filename = (
    'indexing_None_n_inputs_1000_results.pkl'
    if othello
    else 'indexing_find_dots_indices_n_inputs_1000_results.pkl'
)
with open(os.path.join(ae_path, results_filename), 'rb') as f:
    results = pickle.load(f)

results = utils.to_device(results, device)
print(results.keys())

lookup_table, misc_stats = analysis.analyze_results_dict(
    results, "", device, save_results=False, verbose=False, print_results=False
)

# Evaluation metrics (e.g. L0) saved alongside the SAE.
with open(os.path.join(ae_path, 'n_inputs_1000_evals.pkl'), 'rb') as f:
    eval_results = pickle.load(f)

print(eval_results.keys())
print(f"L0: {eval_results['eval_results']['l0']}")

In [None]:
# load model
activation_dim = 512  # output dimension of the layer the SAEs were trained on
layer = 5

if othello:
    context_length = 59
    model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
    dataset_name = "taufeeque/othellogpt"
    data = utils.othello_hf_dataset_to_generator(
        dataset_name, context_length=context_length, split="train", streaming=True
    )
    model_type = "othello"
else:
    # The chess data generator needs the pickled metadata stored under models/.
    with open("models/meta.pkl", "rb") as f:
        meta = pickle.load(f)

    context_length = 256
    model_name = "adamkarvonen/8LayerChessGPT2"
    dataset_name = "adamkarvonen/chess_sae_text"
    data = utils.chess_hf_dataset_to_generator(
        dataset_name, meta, context_length=context_length, split="train", streaming=True
    )
    model_type = "chess"

model = utils.get_model(model_name, device)
submodule = utils.get_submodule(model_name, layer, model)

# Draw a fixed number of games from the streaming generator and convert them
# to a tensor on `device`.
games_batch_size = 500
games_batch = torch.tensor(
    [next(data) for _ in range(games_batch_size)],
    device=device,
)
print(f'game_batch: {len(games_batch)}')


There are a couple of things to watch out for. First, the lookup table only contains entries for alive SAE features, so use `alive_features_F` to map a lookup-table index to the corresponding SAE feature index, as shown below.

In [None]:
# List every table available in the lookup table (one name per line).
for table_name in lookup_table:
    print(table_name)

In [None]:
# Lookup-table rows are indexed by position among the alive features;
# alive_features_F converts that position back into the SAE's feature index.
alive_features_F = lookup_table['alive_features']
print(alive_features_F.shape)

feature_index = 10
sae_feature_index = alive_features_F[feature_index]
print(sae_feature_index)

# Per the _DF suffix, the decoder weight is laid out as (D, F), so an SAE
# feature's decoder direction is a column.
decoder_weights_DF = ae.decoder.weight.data
print(decoder_weights_DF.shape)

sae_feature_decoder_vector_D = decoder_weights_DF[:, sae_feature_index]
print(sae_feature_decoder_vector_D.shape)

For Othello, I recommend using `games_batch_to_state_stack_mine_yours_blank_mask_BLRRC`. With this function, blank squares aren't included in the lookup table; otherwise, many early-game features have 50+ matches, many of which appear spurious.

In [None]:
# Board-state table to inspect. Alternatives:
#   Othello: 'games_batch_to_state_stack_mine_yours_BLRRC'
#   chess:   'board_to_pin_state'
if othello:
    board_state_function = 'games_batch_to_state_stack_mine_yours_blank_mask_BLRRC'
else:
    board_state_function = 'board_to_piece_state'

# The lookup table holds 11 thresholds. Index 0 builds the table from every
# activation; index k (k >= 1) uses only activations above k * 10% of that
# feature's max activation. 1 is a reasonable default; 5 (above 50%) is used here.
threshold_idx = 5

# Report every alive feature whose lookup-table entry classifies at least one
# square, remembering the last such feature as the default feature to plot.
demo_idx = None

# Hoist the table lookup: it is the same for every feature.
board_state_table = lookup_table[board_state_function][threshold_idx]

for alive_feature_index, sae_feature in enumerate(alive_features_F):
    num_classified_squares = board_state_table[alive_feature_index].sum()
    if num_classified_squares > 0:
        print(f'Alive Feature {alive_feature_index} (or SAE feature {sae_feature}) has {num_classified_squares} classified squares')

        demo_idx = alive_feature_index

if demo_idx is None:
    raise ValueError('No features have any classified squares')

In [None]:
# Manually pin the alive-feature index to visualize. This overrides the
# demo_idx picked by the search cell; remove it to use that result instead.
demo_idx = 13

In [None]:
sae_feature_board_state_RRC = lookup_table[board_state_function][threshold_idx][demo_idx]
print(sae_feature_board_state_RRC.shape)

# Collapse the one-hot channel axis into one label per square: channel index
# c becomes label c - 1.
sae_feature_board_state_RR = sae_feature_board_state_RRC.argmax(dim=-1)
print(sae_feature_board_state_RR.shape)
sae_feature_board_state_RR = sae_feature_board_state_RR - 1

not_classified_value = -9

# Squares whose one-hot vector is all zeros would otherwise read as channel 0;
# mark them with a dedicated "not classified" label instead.
zero_positions_RR = (sae_feature_board_state_RRC == 0).all(dim=-1)
print(zero_positions_RR.shape)

sae_feature_board_state_RR = sae_feature_board_state_RR.masked_fill(
    zero_positions_RR, not_classified_value
)
print(sae_feature_board_state_RR)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_othello_board(board, not_classified_value=-9):
    """
    Plot a 2D grid of integer square labels as a colored board.

    Each distinct value in ``board`` is drawn in a fixed color, and a color bar
    lists the colors present on the board.

    Args:
    board (torch.Tensor or array-like): A 2D grid of integer labels. Supported
                          values are -1 (black), 0 (grey, empty), 1 (white),
                          -3 (red) and ``not_classified_value`` (yellow).
    not_classified_value (int): Label marking squares whose one hot vector is
                          all zeros. Defaults to -9, matching the value used
                          when building ``sae_feature_board_state_RR``.

    Raises:
    ValueError: If ``board`` contains a value that has no assigned color.
    """
    # Detach and move to CPU so CUDA tensors or tensors tracking gradients
    # can be converted; plain arrays/lists are accepted as-is.
    if isinstance(board, torch.Tensor):
        board_np = board.detach().cpu().numpy()
    else:
        board_np = np.asarray(board)

    color_map = {-1: 'black', 0: 'grey', 1: 'white', not_classified_value: 'yellow', -3: 'red'}

    # unique_labels is sorted; `inverse` gives each square's index into it.
    unique_labels, inverse = np.unique(board_np, return_inverse=True)
    # NumPy 1.x returns a flat inverse while 2.x keeps the input shape; normalize.
    board_indices = inverse.reshape(board_np.shape)

    labels = unique_labels.tolist()
    unsupported = [label for label in labels if label not in color_map]
    if unsupported:
        raise ValueError(
            f'No color assigned for board values {unsupported}; '
            f'supported values are {sorted(color_map)}'
        )
    colors = [color_map[label] for label in labels]
    cmap = plt.matplotlib.colors.ListedColormap(colors)

    fig, ax = plt.subplots()

    # Pin the color range to half-integer bounds so index i maps exactly to
    # colors[i] and each color-bar tick sits in the middle of its color band.
    cax = ax.imshow(board_indices, cmap=cmap, vmin=-0.5, vmax=len(labels) - 0.5)

    cbar = fig.colorbar(cax, ticks=range(len(labels)))
    cbar.ax.set_yticklabels(colors)

    # Axes ticks carry no meaning for a game board.
    ax.axis('off')
    ax.set_title('Othello Board. Grey = Empty, Yellow = Not present in one hot vector')

    plt.show()

# Render the label grid for the selected feature.
plot_othello_board(board=sae_feature_board_state_RR)