In [1]:
import torch
import numpy as np
from fancy_einsum import einsum
import chess
import numpy as np
import pickle
import logging
import plotly.graph_objects as go
from functools import partial

import chess_utils
import train_test_chess
import importlib

  from .autonotebook import tqdm as notebook_tqdm
2024-12-10 15:34:18,412 - probe_training_utils - INFO - {'vocab_size': 32, 'itos': {0: ' ', 1: '#', 2: '+', 3: '-', 4: '.', 5: '0', 6: '1', 7: '2', 8: '3', 9: '4', 10: '5', 11: '6', 12: '7', 13: '8', 14: '9', 15: ';', 16: '=', 17: 'B', 18: 'K', 19: 'N', 20: 'O', 21: 'Q', 22: 'R', 23: 'a', 24: 'b', 25: 'c', 26: 'd', 27: 'e', 28: 'f', 29: 'g', 30: 'h', 31: 'x'}, 'stoi': {' ': 0, '#': 1, '+': 2, '-': 3, '.': 4, '0': 5, '1': 6, '2': 7, '3': 8, '4': 9, '5': 10, '6': 11, '7': 12, '8': 13, '9': 14, ';': 15, '=': 16, 'B': 17, 'K': 18, 'N': 19, 'O': 20, 'Q': 21, 'R': 22, 'a': 23, 'b': 24, 'c': 25, 'd': 26, 'e': 27, 'f': 28, 'g': 29, 'h': 30, 'x': 31}}
2024-12-10 15:34:18,417 - probe_training_utils - INFO - Using device: cuda
2024-12-10 15:34:18,417 - probe_training_utils - INFO - [6, 4, 27, 9, 0, 27, 11, 0, 7, 4, 19, 28, 8]
2024-12-10 15:34:18,418 - probe_training_utils - INFO - Performing round trip test on meta


There's a bunch of setup below to get some data in some tensors that we can feed to our model.

In [2]:
# Turn autograd off globally; the cells visible below only run forward passes
# and probe evaluation, none of them call backward().
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7a8e8f39b8c0>

In [3]:
# Verbosity switches: debug_mode wins over info_mode; with both off only
# warnings and above are emitted.
debug_mode = False
info_mode = True

log_level = (
    logging.DEBUG if debug_mode
    else logging.INFO if info_mode
    else logging.WARNING
)

# Configure the root logger at the chosen level and create this notebook's logger.
logging.basicConfig(level=log_level)
logger = logging.getLogger(__name__)

Here you can select which probe and model to use. By default, the model_setup.py downloads a lichess 8 layer model. We can then select a probe from saved_probes/. Ideally, this should also be a lichess probe. Then this code should auto populate parameters according to the probe's state dict.

To reproduce paper / blog post figures, set USE_16_LAYER to True and run model_setup.py on the lichess 16 layer model.

In [4]:
import copy  # cell-local: keeps control_config independent of config (see below)

# --- Paths and run configuration ---------------------------------------------
MODEL_DIR = "models/"
DATA_DIR = "data/"
PROBE_DIR = "linear_probes/"
SAVED_PROBE_DIR = "linear_probes/saved_probes/"
SPLIT = "test"

DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")
probe_type = 'random'
control_type = "lichess"
base_probe_name = f"tf_lens_{probe_type}_8layers_ckpt_no_optimizer_chess_piece_probe_layer_0.pth"
# NOTE(review): control_type is defined above but never used; this name is built
# from probe_type, so with LAYER = 5 the "control" probe is the same file as the
# probe under test. Presumably {control_type} was intended -- confirm before changing.
base_control_probe_name = f"tf_lens_{probe_type}_8layers_ckpt_no_optimizer_chess_piece_probe_layer_5.pth"

LAYER = 5

USE_16_LAYER = False

if USE_16_LAYER:
    LAYER = 11
    base_probe_name = "tf_lens_lichess_16layers_ckpt_no_optimizer_chess_piece_probe_layer_0.pth"

# The base name always ends in "layer_0"; swap in the layer we actually want.
probe_to_test = base_probe_name.replace("layer_0", f"layer_{LAYER}")

num_games = 10
sample_size = 1
modes = 1

probe_file_location = f"{SAVED_PROBE_DIR}{probe_to_test}"
control_probe_file_location = f"{SAVED_PROBE_DIR}{base_control_probe_name}"


def _load_probe_state_dict(path):
    """Load a saved probe checkpoint onto DEVICE and print its metadata.

    Every entry except the "linear_probe" weights is printed so the training
    settings stored alongside the probe are visible in the cell output.

    :param path: Path to the ``.pth`` checkpoint file.
    :return: The loaded checkpoint dictionary.
    """
    with open(path, "rb") as f:
        loaded = torch.load(f, map_location=torch.device(DEVICE))
    print(loaded.keys())
    for key in loaded.keys():
        if key != "linear_probe":
            print(key, loaded[key])
    return loaded


# --- Probe under test ----------------------------------------------------------
state_dict = _load_probe_state_dict(probe_file_location)

config = chess_utils.find_config_by_name(state_dict["config_name"])
layer = state_dict["layer"]
model_name = state_dict["model_name"]
dataset_prefix = state_dict["dataset_prefix"]
column_name = state_dict["column_name"]
config.pos_start = state_dict["pos_start"]
# Older checkpoints may not carry this key; None means "no level filtering".
levels_of_interest = state_dict.get("levels_of_interest")
config.levels_of_interest = levels_of_interest
indexing_function_name = state_dict["indexing_function_name"]
n_layers = state_dict["n_layers"]

split = SPLIT
input_dataframe_file = f"{DATA_DIR}{dataset_prefix}_{split}.csv"
config = chess_utils.set_config_min_max_vals_and_column_name(
    config, input_dataframe_file, dataset_prefix
)
misc_logging_dict = {
    "split": split,
    "dataset_prefix": dataset_prefix,
    "model_name": model_name,
    "n_layers": n_layers,
}

# --- Control probe ---------------------------------------------------------------
# Fix: every control_* value is now read from control_state_dict. Previously they
# were copied from state_dict (and from the main probe's dataset variables), so
# the control metadata silently mirrored the probe under test.
print(f"This is the control probe: {control_probe_file_location}")
control_state_dict = _load_probe_state_dict(control_probe_file_location)

# Shallow-copy so that setting control fields cannot overwrite `config` should
# both config names resolve to the same object (whether find_config_by_name
# returns a shared instance is not visible from this notebook).
control_config = copy.copy(
    chess_utils.find_config_by_name(control_state_dict["config_name"])
)
control_layer = control_state_dict["layer"]
control_model_name = control_state_dict["model_name"]
control_dataset_prefix = control_state_dict["dataset_prefix"]
control_column_name = control_state_dict["column_name"]
control_config.pos_start = control_state_dict["pos_start"]
control_levels_of_interest = control_state_dict.get("levels_of_interest")
control_config.levels_of_interest = control_levels_of_interest
control_indexing_function_name = control_state_dict["indexing_function_name"]
control_n_layers = control_state_dict["n_layers"]

control_split = SPLIT
control_input_dataframe_file = f"{DATA_DIR}{control_dataset_prefix}_{control_split}.csv"
control_config = chess_utils.set_config_min_max_vals_and_column_name(
    control_config, control_input_dataframe_file, control_dataset_prefix
)
control_misc_logging_dict = {
    "control_split": control_split,
    "control_dataset_prefix": control_dataset_prefix,
    "control_model_name": control_model_name,
    "control_n_layers": control_n_layers,
}
# Previously the control entries replaced misc_logging_dict outright, dropping the
# main probe's keys. Merge instead so both key sets remain available.
misc_logging_dict.update(control_misc_logging_dict)

INFO:__main__:Using device: cpu


dict_keys(['linear_probe', 'final_loss', 'iters', 'epochs', 'acc', 'linear_probe_name', 'layer', 'indexing_function_name', 'batch_size', 'lr', 'wd', 'pos_start', 'num_epochs', 'num_games', 'modes', 'wandb_project', 'config_name', 'column_name', 'levels_of_interest', 'split', 'dataset_prefix', 'model_name', 'n_layers', 'wandb_run_name', 'player_color'])
final_loss tensor(56.8934, requires_grad=True)
iters 50000
epochs 4
acc tensor(0.6581)
linear_probe_name chess_piece_probe
layer 5
indexing_function_name find_dots_indices
batch_size 2
lr 0.001
wd 0.01
pos_start 0
num_epochs 5
num_games 10000
modes 1
wandb_project chess_linear_probes
config_name chess_piece_probe
column_name None
levels_of_interest None
split train
dataset_prefix random
model_name tf_lens_random_8layers_ckpt_no_optimizer
n_layers 8
wandb_run_name chess_piece_probe_tf_lens_random_8layers_ckpt_no_optimizer_layer_5_indexing_find_dots_indices_max_games_10000
player_color White
This is the control probe: linear_probes/saved_p

  state_dict = torch.load(f, map_location=torch.device(DEVICE))
  control_state_dict = torch.load(f, map_location=torch.device(DEVICE))


In the cells that follow the one below, we select just one of the num_games games. We do this because, with a large number of games, storing all the resid_posts and state_stacks quickly grows to many gigabytes of VRAM.

In [5]:
# Pick up any edits made to chess_utils.py without restarting the kernel.
importlib.reload(chess_utils)

torch.set_printoptions(
    edgeitems=100,    # Show up to 100 items at each edge before eliding (default: 3)
    linewidth=100,    # Characters per line before wrapping
    profile="default"  # Reset to torch's default profile first; the explicit values above then override it
)

# Load the model and the first `num_games` games (value from the config cell).
probe_data = train_test_chess.construct_linear_probe_data(
    input_dataframe_file,
    dataset_prefix,
    n_layers,
    model_name,
    config,
    num_games,
    DEVICE,
)
if DEVICE == "cpu":
    probe_data.model.cpu()
print(f"Input dataframe file {input_dataframe_file}")

#print(probe_data.board_seqs_string)
print("AND NOW:")
# NOTE(review): num_games is reassigned here after being passed to
# construct_linear_probe_data above, so re-running this cell without re-running
# the config cell loads 4 games instead of 10.
num_games = 4
# `game` is only referenced by the commented-out debug prints below.
game = 1
games_to_run = probe_data.board_seqs_string[:num_games]
# Length is taken from the first game; per the explanatory cell further down,
# all pgn strings have the same length.
game_length_in_chars = len(games_to_run[0])
state_stacks_all_chars_MBlRR = chess_utils.create_state_stacks(games_to_run, config.custom_board_state_function)
#print(state_stacks_all_chars_MBlRR.shape)
logger.info(f"state_stack shape: {state_stacks_all_chars_MBlRR.shape}")
# Expected layout: (modes, games, chars per game, board rows, board cols).
assert(state_stacks_all_chars_MBlRR.shape) == (modes, num_games, game_length_in_chars, config.num_rows, config.num_cols)
# Indices of the periods right before each white move, one row per game.
white_move_indices_BL = probe_data.custom_indices[:num_games]
#print(f"2nd transcript {games_to_run[game]}")
#print(f"and the white moves are {white_move_indices_BL[game]}")
#print(white_move_indices_BL.shape)
num_white_moves = white_move_indices_BL.shape[1]
#print(f"Num white moves {num_white_moves}")
# Open question from the author: what if a game is very compact and runs a move over? Ignored for now.
# num_white_moves is read from this same tensor, so the assert below only checks
# that it is 2-D with num_games rows.
assert(white_move_indices_BL.shape) == (num_games, num_white_moves)


  model.load_state_dict(torch.load(f"{MODEL_DIR}{model_name}.pth"))
2024-12-10 15:34:28,480 - probe_training_utils - INFO - Number of games: 10,length of a game in chars: 365
INFO:probe_training_utils:Number of games: 10,length of a game in chars: 365
2024-12-10 15:34:28,481 - probe_training_utils - INFO - 0    [15, 6, 4, 19, 23, 8, 0, 19, 30, 11, 0, 7, 4, ...
1    [15, 6, 4, 19, 30, 8, 0, 19, 28, 11, 0, 7, 4, ...
2    [15, 6, 4, 27, 9, 0, 19, 25, 11, 0, 7, 4, 25, ...
3    [15, 6, 4, 28, 9, 0, 26, 11, 0, 7, 4, 29, 8, 0...
4    [15, 6, 4, 25, 9, 0, 26, 10, 0, 7, 4, 30, 9, 0...
Name: transcript, dtype: object
INFO:probe_training_utils:0    [15, 6, 4, 19, 23, 8, 0, 19, 30, 11, 0, 7, 4, ...
1    [15, 6, 4, 19, 30, 8, 0, 19, 28, 11, 0, 7, 4, ...
2    [15, 6, 4, 27, 9, 0, 19, 25, 11, 0, 7, 4, 25, ...
3    [15, 6, 4, 28, 9, 0, 26, 11, 0, 7, 4, 29, 8, 0...
4    [15, 6, 4, 25, 9, 0, 26, 10, 0, 7, 4, 30, 9, 0...
Name: transcript, dtype: object
2024-12-10 15:34:28,484 - probe_training_utils - INF

Moving model to device:  cpu
user state dict one hot None
seqs string 0    ;1.Na3 Nh6 2.e4 d6 3.Be2 Nd7 4.b3 Ne5 5.Rb1 Kd...
1    ;1.Nh3 Nf6 2.Rg1 h6 3.Na3 c6 4.Ng5 Ne4 5.Nb1 a...
2    ;1.e4 Nc6 2.c3 a5 3.Qg4 Nd4 4.Bd3 Nc6 5.Bb5 d5...
3    ;1.f4 d6 2.g3 c6 3.a4 e5 4.c3 h5 5.Kf2 Rh6 6.R...
4    ;1.c4 d5 2.h4 Na6 3.Nh3 e5 4.Nc3 Bb4 5.Ne4 Nf6...
5    ;1.e3 g6 2.e4 a5 3.Bb5 Nc6 4.Qe2 Na7 5.Qc4 e5 ...
6    ;1.g3 g5 2.Nc3 Bh6 3.e4 Bg7 4.d3 Bh6 5.Qg4 d5 ...
7    ;1.b4 h6 2.f4 d6 3.Ba3 Bg4 4.g3 Bd7 5.Kf2 c5 6...
8    ;1.c3 b6 2.c4 g5 3.f3 f5 4.Kf2 d5 5.Qc2 g4 6.e...
9    ;1.c4 a6 2.e4 d6 3.g3 Nd7 4.Bg2 Nb6 5.g4 Be6 6...
Name: transcript, dtype: object
seqs int tensor([[15,  6,  4, 19, 23,  8,  0, 19, 30, 11,  0,  7,  4, 27,  9,  0, 26, 11,  0,  8,  4, 17, 27,
          7,  0, 19, 26, 12,  0,  9,  4, 24,  8,  0, 19, 27, 10,  0, 10,  4, 22, 24,  6,  0, 18, 26,
         12,  0, 11,  4, 26,  9,  0, 18, 27, 11,  0, 12,  4, 17, 28,  6,  0, 19, 30, 29,  9,  0, 13,
          4, 30,  9,  0, 19, 31, 28,

In [6]:
# Narrow probe_data down to a single game so later cells only process one game.
# NOTE(review): this cell mutates probe_data in place and is not idempotent --
# after one run board_seqs_int has a single row, so re-running it indexes row
# `game_of_interest` of a 1-row tensor and fails. Re-run the data-construction
# cell first.
print("\nSelecting the game of interest")
print(probe_data.board_seqs_int.shape)
print(state_stacks_all_chars_MBlRR.shape)
print(white_move_indices_BL.shape)
# The raw pgn strings, one per loaded game.
bss = probe_data.board_seqs_string

print(len(bss), len(bss[0]))
#print(bss[0])

# Must be smaller than the number of rows kept in white_move_indices_BL,
# which is indexed with it below.
game_of_interest = 2
# The same games as integer token sequences.
bsi = probe_data.board_seqs_int
bsi_interest = bsi[game_of_interest]
#print(f"before unsqueeze the shape is {bsi_interest.shape}, and it looks like {bsi_interest}")
# unsqueeze(0) restores the leading batch dimension (size 1).
probe_data.board_seqs_int = bsi_interest.unsqueeze(0)
#print(f"after unsqueeze the shape is {bsi_interest.shape}, and it looks like {bsi_interest}")

# Keep only the string of the game of interest (as a one-element list).
probe_data.board_seqs_string = [probe_data.board_seqs_string[game_of_interest]]

# Keep only that game's white-move indices, again with a batch dimension of 1.
probe_data.custom_indices = white_move_indices_BL[game_of_interest].unsqueeze(0)


Selecting the game of interest
torch.Size([10, 365])
torch.Size([1, 4, 365, 8, 8])
torch.Size([4, 34])
10 365


In [7]:
# Select the state stacks of the game of interest; unsqueeze(1) puts the
# (now size-1) game dimension back in position 1.
# NOTE(review): both assignments overwrite their own input, so this cell must
# run exactly once after the data-construction cell.
state_stacks_all_chars_MBlRR = state_stacks_all_chars_MBlRR[:, game_of_interest, :, :, :].unsqueeze(1)
print(f"shape of selected state stacks {state_stacks_all_chars_MBlRR.shape}")
# Same selection that the previous cell stored in probe_data.custom_indices,
# applied to the local variable so the two stay consistent.
white_move_indices_BL = white_move_indices_BL[game_of_interest].unsqueeze(0)



shape of selected state stacks torch.Size([1, 1, 365, 8, 8])


In [8]:
# Summary after selecting one game: every game/batch dimension should now be 1.
print(
    probe_data.board_seqs_int.shape,
    state_stacks_all_chars_MBlRR.shape,
    white_move_indices_BL.shape,
    sep="\n",
)
print(len(probe_data.board_seqs_string), len(probe_data.board_seqs_string[0]))

torch.Size([1, 365])
torch.Size([1, 1, 365, 8, 8])
torch.Size([1, 34])
1 365


Here is an explanation of all the data we just generated:

In [9]:
# Narrated dump of the tensors built above for the single selected game
# (index 0 after selection).
print(f"All pgn strings are of length {game_length_in_chars}")
print(f"For game {game_of_interest}, the pgn string is {probe_data.board_seqs_string[0]}")
print(f"Using our encode functions, it's represented as ints that are fed as input to the GPT model with shape {probe_data.board_seqs_int.shape}")
print(f"The first 30 characters of board_seqs_ints looks like this: {probe_data.board_seqs_int[:, :30]}")
print(f"state_stacks_all_chars contains the board state at every char index in the pgn string with shape {state_stacks_all_chars_MBlRR.shape}")
print(f"white_move_indices contains the index of every white move in the pgn string with shape {white_move_indices_BL.shape}")
print(f"That means there are {num_white_moves} white moves in the game")
# Only the first 14 characters and the first two indices are shown as an example.
print(f"For example, in {probe_data.board_seqs_string[0][:14]}, the white move indices are {white_move_indices_BL[:, :2]} (the indices of each period)")



All pgn strings are of length 365
For game 2, the pgn string is ;1.e4 Nc6 2.c3 a5 3.Qg4 Nd4 4.Bd3 Nc6 5.Bb5 d5 6.Be2 Nf6 7.h4 Nh5 8.Qf5 Be6 9.Bg4 Na7 10.Qf3 c6 11.b4 Nf6 12.Qh3 a4 13.Bxe6 b5 14.Bc8 Ng4 15.h5 g6 16.Ba6 e5 17.Qxg4 Qd6 18.Bb2 Qe7 19.h6 Qb7 20.exd5 Rd8 21.Qg5 Bd6 22.Qf4 Rd7 23.Qc4 Bb8 24.Bc1 Bd6 25.dxc6 Bxb4 26.Bxb5 Rc7 27.Qb3 e4 28.Ba6 a3 29.Rh3 Qa8 30.Rh4 O-O 31.Bb2 Qe8 32.Nxa3 Be7 33.Bc8 Bc5 34.g4 Qe7 35.Qxf7+
Using our encode functions, it's represented as ints that are fed as input to the GPT model with shape torch.Size([1, 365])
The first 30 characters of board_seqs_ints looks like this: tensor([[15,  6,  4, 27,  9,  0, 19, 25, 11,  0,  7,  4, 25,  8,  0, 23, 10,  0,  8,  4, 21, 29,  9,
          0, 19, 26,  9,  0,  9,  4]])
state_stacks_all_chars contains the board state at every char index in the pgn string with shape torch.Size([1, 1, 365, 8, 8])
white_move_indices contains the index of every white move in the pgn string with shape torch.Size([1, 34])
That means t

In [10]:
##Loading the model and the probe
checkpoint = torch.load(probe_file_location, map_location=torch.device(DEVICE))
linear_probe_MDRRC = checkpoint["linear_probe"]
##this is the straight-up probe matrix (this means no activation I believe)
print(linear_probe_MDRRC.shape)

##Apparently max and min val in config are the range of the one hot encoding
##For other games I suppose? Look into this
one_hot_range = config.max_val - config.min_val + 1

board_seqs_int_BL = probe_data.board_seqs_int[:].to(DEVICE)

assert(board_seqs_int_BL.shape) == (1, game_length_in_chars)

indexed_state_stacks_MBLRR = []

torch.Size([1, 512, 8, 8, 13])


  checkpoint = torch.load(probe_file_location, map_location=torch.device(DEVICE))


In [11]:
# Collect, for each game in the sample, the board states at the white-move
# positions. sample_size lets the same code handle more than one game.

# (Re)create the accumulator here so that re-running this cell does not append
# duplicates to a list that was initialised in an earlier cell.
indexed_state_stacks_MBLRR = []

for batch_idx in range(sample_size):
    # Character positions of this game's white moves.
    dots_indices_for_batch_L = white_move_indices_BL[batch_idx]

    # Slicing with batch_idx:batch_idx+1 (rather than a plain index) keeps the
    # game dimension, so the slices can later be concatenated along dim=1.
    indexed_state_stack_MLRR = state_stacks_all_chars_MBlRR[:, batch_idx:batch_idx+1, dots_indices_for_batch_L, :, :]

    # With sample_size == 1 the list ends up with a single entry.
    indexed_state_stacks_MBLRR.append(indexed_state_stack_MLRR)



Important note: In the cell that computes probe_out below, I currently apply log_softmax so the probe outputs are viewed as log-probabilities. You can comment that out to view raw logits instead.

In the following cells, we input the board_seqs_int to the GPT to obtain resid_post, the intermediate activations after our layer of interest. We index into resid_post using white_move_indices. These indexed resid_posts are then input to the linear probe, which outputs probe_out, a (log-)probability distribution for the state of every square on the board.

In [12]:
# Join the per-game slices along the game/batch axis (dim=1). Each slice kept
# its size-1 batch dimension when it was indexed, so nothing needs to be added
# back before concatenating.
print(indexed_state_stacks_MBLRR[0].shape)
state_stack_white_moves_MBLRR = torch.cat(indexed_state_stacks_MBLRR, dim=1)

print(
    "state stack shapes",
    state_stack_white_moves_MBLRR.shape,
    state_stacks_all_chars_MBlRR.shape,
    sep="\n",
)

torch.Size([1, 1, 34, 8, 8])
state stack shapes
torch.Size([1, 1, 34, 8, 8])
torch.Size([1, 1, 365, 8, 8])


In [13]:
# Run the model once and cache activations. The last token is dropped
# ([:, :-1]), so the cached sequence is one shorter than the pgn string.
with torch.inference_mode():
    # return_type=None: presumably tells the model not to return logits, since
    # only the cache is used -- TODO confirm against the model's API.
    _, cache = probe_data.model.run_with_cache(board_seqs_int_BL[:, :-1], return_type=None)
    
    # Residual stream after the probe's layer; the trailing [:, :] is a no-op slice.
    resid_post_BlD = cache["resid_post", layer][:, :]

#print(resid_post_BlD.shape)
# (games, chars - 1, d_model) where d_model is read from the probe's second dimension.
assert(resid_post_BlD.shape) == (sample_size, game_length_in_chars - 1, linear_probe_MDRRC.shape[1])

In [14]:
# Collect the residual-stream activations at the white-move positions, one
# entry per game in the sample.
indexed_resid_posts_BLD = []

for batch_idx in range(sample_size):
    # Character positions of this game's white moves.
    dots_indices_for_batch_L = white_move_indices_BL[batch_idx]

    # Pick this game's activations at those positions -> (moves, d_model).
    indexed_resid_post_LD = resid_post_BlD[batch_idx, dots_indices_for_batch_L]

    # Append the result to the list
    indexed_resid_posts_BLD.append(indexed_resid_post_LD)

# Stack along a new leading game dimension -> (games, moves, d_model).
# torch.stack requires every game to have the same number of indexed moves.
resid_post_BLD = torch.stack(indexed_resid_posts_BLD)
resid_post_BLD = resid_post_BLD.to(DEVICE)
print("Resid post", resid_post_BLD.shape)

Resid post torch.Size([1, 34, 512])


In [15]:
# Apply the linear probe to every indexed activation (Einstein summation over
# d_model), then normalise the class dimension to log-probabilities.
probe_out_MBLRRC = torch.log_softmax(
    einsum(
        "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options",
        resid_post_BLD,
        linear_probe_MDRRC,
    ),
    dim=-1,
)

print(f"Probe out shape: {probe_out_MBLRRC.shape}")
# (modes, games, white moves, rows, cols, one-hot classes)
assert probe_out_MBLRRC.shape == (
    modes,
    sample_size,
    white_move_indices_BL.shape[1],
    config.num_rows,
    config.num_cols,
    one_hot_range,
)

Probe out shape: torch.Size([1, 1, 34, 8, 8, 13])


Here you can select which move you want to visualize (move_of_interest).

In [16]:
# Which white move (0-based position among the white-move indices) to inspect.
move_of_interest = 11
# Only one game is kept after the selection cells, so its index is always 0.
GAME_IDX = 0
# Character index of that move inside the pgn string.
move_of_interest_index = white_move_indices_BL[GAME_IDX, move_of_interest]
# Ground-truth board at that move (mode 0).
move_of_interest_state_RR = state_stack_white_moves_MBLRR[0, GAME_IDX, move_of_interest]
print(move_of_interest_state_RR.shape)
print(move_of_interest_state_RR)

torch.Size([8, 8])
tensor([[ 4,  2,  3,  0,  6,  0,  2,  4],
        [ 1,  0,  0,  1,  0,  1,  1,  0],
        [ 0,  0,  1,  0,  0,  5,  0,  0],
        [ 0,  1,  0,  0,  1,  0,  3,  1],
        [-1,  0,  0, -1,  0,  0,  0,  0],
        [ 0,  0, -1,  0, -3, -2,  0,  0],
        [-2, -1,  0,  0, -1, -1, -1, -1],
        [-4,  0,  0, -5, -6, -3,  0, -4]], dtype=torch.int8)


Now we one hot encode our move_of_interest and store it in move_of_interest_state_one_hot.

In [17]:
# One-hot encode the ground-truth boards: each square gets a vector of length
# one_hot_range with a single 1 at the class of the piece standing on it.
state_stacks_one_hot_MBLRRC = chess_utils.state_stack_to_one_hot(
    modes,
    config.num_rows,
    config.num_cols,
    config.min_val,
    config.max_val,
    DEVICE,
    state_stack_white_moves_MBLRR,
)
print(state_stacks_one_hot_MBLRRC.shape)
assert state_stacks_one_hot_MBLRRC.shape == (
    modes,
    sample_size,
    num_white_moves,
    config.num_rows,
    config.num_cols,
    one_hot_range,
)
# One-hot board for the selected move only.
move_of_interest_state_one_hot_RRC = state_stacks_one_hot_MBLRRC[0, GAME_IDX, move_of_interest]
print(move_of_interest_state_one_hot_RRC.shape)

torch.Size([1, 1, 34, 8, 8, 13])
torch.Size([8, 8, 13])


We get the argmax of each square's probe probability distribution and store it in state_stacks_probe_outputs for easy graphing.

In [18]:
# Collapse the probe's per-square class scores back to one signed piece integer
# per square, so they can be compared with the ground-truth state stack.
print(move_of_interest_state_one_hot_RRC.shape)
print(state_stacks_one_hot_MBLRRC.shape)
state_stacks_probe_outputs_MBLRR = chess_utils.one_hot_to_state_stack(probe_out_MBLRRC, config.min_val)
# torch.as_tensor passes an existing tensor through unchanged and converts
# anything else; torch.tensor() on a tensor emits a copy-construct UserWarning.
state_stacks_probe_outputs_MBLRR = torch.as_tensor(state_stacks_probe_outputs_MBLRR)
print(state_stacks_probe_outputs_MBLRR.shape)
assert(state_stacks_probe_outputs_MBLRR.shape) == (modes, sample_size, num_white_moves, config.num_rows, config.num_cols)
print(state_stacks_probe_outputs_MBLRR[0][GAME_IDX][move_of_interest])

torch.Size([8, 8, 13])
torch.Size([1, 1, 34, 8, 8, 13])
torch.Size([1, 1, 34, 8, 8])
tensor([[ 4,  2,  0,  0,  6,  3,  2,  4],
        [ 1,  1,  1,  0,  0,  1,  1,  1],
        [ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  1,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0, -2,  0,  0, -2,  0,  0],
        [-1, -1, -1,  0,  0, -1, -1, -1],
        [-4,  0, -3, -5, -6,  0,  0, -4]])


  state_stacks_probe_outputs_MBLRR = torch.tensor(state_stacks_probe_outputs_MBLRR)


Change BLANK_INDEX, black_king_index, or white_pawn_index if you want to visualize the probe's view of other pieces. For example, to see the black queen, I could set BLANK_INDEX = PIECE_TO_ONE_HOT_MAPPING[-5] (refer to INT_TO_CHAR for the mapping from piece integers to pieces).

In [23]:

# Signed piece integer -> display glyph. Keys run from -6 (black king) up to
# 6 (white king); 0 is an empty square and is shown as ".".
INT_TO_CHAR = dict(zip(
    range(-6, 7),
    "\u265a\u265b\u265c\u265d\u265e\u265f.\u2659\u2658\u2657\u2656\u2655\u2654",
))

# Signed piece integer -> one-hot class index (piece + 6, so -6 -> 0 and 6 -> 12).
# Duplicated from chess_utils.py for easy reference.
PIECE_TO_ONE_HOT_MAPPING = {piece: piece + 6 for piece in range(-6, 7)}

# python-chess piece types -> unsigned piece integers (pawn = 1 ... king = 6).
PIECE_TO_INT = dict(zip(
    (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING),
    range(1, 7),
))

# Reverse lookup: piece integer -> python-chess piece type.
INT_TO_PIECE = dict(zip(PIECE_TO_INT.values(), PIECE_TO_INT.keys()))

# One-hot class indices of the three piece kinds visualised below:
# empty square (0), white pawn (1) and black king (-6).
BLANK_INDEX, white_pawn_index, black_king_index = (
    PIECE_TO_ONE_HOT_MAPPING[piece] for piece in (0, 1, -6)
)

def plot_board_state(board_state: torch.Tensor, clip_size: int = 200, show_scale: bool = False):
    """Build a grayscale plotly heatmap trace for one board-shaped tensor.

    :param board_state: Tensor of per-square values (for example one class
        slice of a one-hot board or of the probe output). May live on any
        device and may require grad.
    :param clip_size: Values are clipped to ``[-clip_size, clip_size]`` before
        plotting so that outliers do not wash out the colour scale.
    :param show_scale: Whether the trace displays its colour bar.
    :return: A ``go.Heatmap`` trace; the caller adds it to a figure.
    """
    colorscale = 'gray'
    # detach().cpu() makes the NumPy conversion work for tensors on any
    # accelerator and for tensors that track gradients; the previous check only
    # covered CUDA tensors.
    board_values = np.clip(board_state.detach().cpu().numpy(), -clip_size, clip_size)

    return go.Heatmap(z=board_values, colorscale=colorscale, showscale=show_scale)

# Ground-truth white-pawn squares for the selected move.
print(move_of_interest_state_one_hot_RRC[:, :, white_pawn_index])

# Probe output (rows, cols, classes) for the selected move of the only game.
move_of_interest_probe_out = probe_out_MBLRRC[0, 0, move_of_interest]
print(move_of_interest_probe_out.shape)

# Heatmap of the probe's white-pawn scores, with the colour bar shown.
heatmap = plot_board_state(
    move_of_interest_probe_out[:, :, white_pawn_index],
    show_scale=True,
)

# Square 600x600 canvas with eight unlabeled ticks per axis.
layout = go.Layout(
    title="Chess board white pawns",
    xaxis={"ticks": '', "nticks": 8},
    yaxis={"ticks": '', "nticks": 8},
    autosize=False,
    width=600,
    height=600,
)

fig = go.Figure(data=[heatmap], layout=layout)
fig.show()

tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 1, 0, 1, 1, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int8)
torch.Size([8, 8, 13])


In [24]:
def tensor_to_text(board_state: torch.Tensor, int_to_char=None) -> np.ndarray:
    """Convert a 2-D board tensor into an array of display strings.

    :param board_state: 2-D tensor of per-square values. May live on any device.
    :param int_to_char: Optional mapping from square value to label. Defaults
        to the notebook-level ``INT_TO_CHAR``. Values missing from the mapping
        are rendered with ``str(value)``.
    :return: A NumPy string array with the same shape as ``board_state``.
    """
    if int_to_char is None:
        int_to_char = INT_TO_CHAR

    # Move to host memory first so the NumPy conversion works on any device.
    board_array = board_state.detach().cpu().numpy()

    # Iterating a 2-D array yields rows, and iterating a row yields NumPy
    # scalars, which hash and compare like the equivalent Python numbers.
    labels = [
        [int_to_char.get(value, str(value)) for value in row]
        for row in board_array
    ]

    # Building the array from the finished labels lets NumPy size the string
    # dtype to the longest label. np.empty(..., dtype=str) allocates
    # one-character cells, which silently truncated fallbacks such as "-7" to "-".
    return np.array(labels, dtype=str).reshape(board_array.shape)

def plot_board_state_with_text(board_state: torch.Tensor):
    """Build a plotly heatmap trace that draws a piece glyph on every square.

    :param board_state: 2-D tensor of signed piece integers. May live on any
        device.
    :return: A ``go.Heatmap`` trace whose cell text comes from ``tensor_to_text``.
    """
    # Convert once up front so both the glyph lookup and the heatmap values work
    # for tensors on any device, matching what plot_board_state accepts.
    board_state = board_state.detach().cpu()

    text_matrix = tensor_to_text(board_state)

    # White everywhere except a narrow grey band around the middle of the colour
    # range. NOTE(review): the band sits at the midpoint of the plotted values,
    # which is zero only when the board's min and max are symmetric -- confirm
    # this is acceptable for boards that are not.
    colorscale = [
        [0, 'white'],
        [0.49, 'white'],
        [0.5, 'grey'],
        [0.51, 'white'],
        [1, 'white'],
    ]

    return go.Heatmap(
        z=board_state.numpy(),
        text=text_matrix,
        showscale=False,
        colorscale=colorscale,
        texttemplate="%{text}",
        textfont=dict(size=48),
    )
# Ground-truth board for the selected move, drawn with piece glyphs.
heatmap = plot_board_state_with_text(move_of_interest_state_RR)

# Square 600x600 canvas with eight unlabeled ticks per axis.
layout = go.Layout(
    title="Chess board state with text",
    xaxis={"ticks": '', "nticks": 8},
    yaxis={"ticks": '', "nticks": 8},
    autosize=False,
    width=600,
    height=600,
)

fig = go.Figure(data=[heatmap], layout=layout)
fig.show()

In [20]:
# Probe output for the selected move (same slice the heatmap cell computes).
move_of_interest_probe_out = probe_out_MBLRRC[0][0][move_of_interest]
print(move_of_interest_probe_out.shape)
# Class index 1 is PIECE_TO_ONE_HOT_MAPPING[-5], i.e. the black queen.
# Values are log-probabilities because log_softmax was applied to the probe output.
print(move_of_interest_probe_out[:,:,1])

torch.Size([8, 8, 13])
tensor([[-11.8638,  -8.7759,  -8.5766, -10.3639, -11.9331,  -8.8207,  -9.0213, -10.1887],
        [ -9.5913,  -8.6168,  -8.2171,  -9.6126, -11.9665,  -8.5630, -10.0502,  -7.7284],
        [ -6.8680,  -7.0490,  -8.1512,  -7.1776,  -8.1446,  -6.6237,  -8.3881,  -8.5811],
        [ -9.0261,  -6.1371,  -6.2827,  -8.1011,  -7.4854,  -4.0360,  -7.0112,  -6.4512],
        [ -4.5671,  -6.1766,  -4.6725,  -6.3546,  -6.6435,  -2.9665,  -4.6150,  -5.7747],
        [ -5.5416,  -4.3494,  -5.7295,  -4.1441,  -3.1769,  -2.0063,  -5.8228,  -5.4695],
        [ -7.5855,  -5.6352,  -3.2520,  -3.5098,  -4.4906,  -4.4931,  -6.5019,  -8.2606],
        [ -7.0716,  -6.4105,  -5.8733,  -0.3633,  -5.4494,  -5.0350,  -9.3927,  -8.3793]])


In [25]:
# Probe's predicted piece integer for every square at the selected move.
probe_prediction_squares_RR = state_stacks_probe_outputs_MBLRR[0, GAME_IDX, move_of_interest]
print(probe_prediction_squares_RR)
def get_predictied(board_prediction_RR, piece_idx):
    """Return a 0/1 int32 mask of the squares predicted to hold ``piece_idx``.

    :param board_prediction_RR: Tensor of predicted piece integers per square.
    :param piece_idx: The signed piece integer to look for (not a one-hot index).
    :return: Tensor of the same shape with 1 where the prediction matches.
    """
    is_piece_RR = torch.eq(board_prediction_RR, piece_idx)
    return is_piece_RR.to(torch.int32)
# Example: piece integer 2 is the white knight in INT_TO_CHAR / PIECE_TO_INT.
print(get_predictied(probe_prediction_squares_RR, 2))

tensor([[ 4,  2,  0,  0,  6,  3,  2,  4],
        [ 1,  1,  1,  0,  0,  1,  1,  1],
        [ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  1,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0, -2,  0,  0, -2,  0,  0],
        [-1, -1, -1,  0,  0, -1, -1, -1],
        [-4,  0, -3, -5, -6,  0,  0, -4]])
tensor([[0, 1, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)


In [26]:
from plotly.subplots import make_subplots

# 4x3 grid: rows 1-3 compare ground truth, probe argmax and probe confidence
# for one piece class each; row 4 shows full boards and the mismatches.
fig_rows = 4
fig_cols = 3
fig = make_subplots(rows=fig_rows, cols=fig_cols, subplot_titles=[
    "Ground truth blank squares", "Predicted blank squares", "Confidence gradient blank squares",
    "Ground truth white pawn positions", "Predicted white pawn positions", "Confidence gradient white pawn positions",
    "Ground truth black king position", "Predicted black king position", "Confidence gradient black king position",
    "Ground truth state", "Predicted board state", "Missing pieces"
])


# Specify the size of each plot
plot_size = 400  # You can adjust this size

# One-hot class index -> signed piece integer, derived from the mapping the
# indices came from. This replaces the "- 6" that was hard-coded on every row.
ONE_HOT_TO_PIECE = {one_hot: piece for piece, one_hot in PIECE_TO_ONE_HOT_MAPPING.items()}

# (one-hot class index, show the colour bar on the confidence plot?) per row.
# Only the white-pawn row shows a colour bar, as before.
piece_rows = [
    (BLANK_INDEX, False),
    (white_pawn_index, True),
    (black_king_index, False),
]

for fig_row, (piece_index, show_confidence_scale) in enumerate(piece_rows, start=1):
    # Column 1: ground-truth squares holding this piece.
    fig.add_trace(plot_board_state(move_of_interest_state_one_hot_RRC[:, :, piece_index]), row=fig_row, col=1)
    # Column 2: squares where the probe's argmax is this piece.
    fig.add_trace(
        plot_board_state(get_predictied(probe_prediction_squares_RR, ONE_HOT_TO_PIECE[piece_index])),
        row=fig_row, col=2,
    )
    # Column 3: the probe's raw per-square score for this piece class.
    fig.add_trace(
        plot_board_state(move_of_interest_probe_out[:, :, piece_index], show_scale=show_confidence_scale),
        row=fig_row, col=3,
    )

# Row 4: full ground-truth board, full predicted board, and the ground-truth
# pieces on every square the probe got wrong.
predicted_state_RR = state_stacks_probe_outputs_MBLRR[0][0][move_of_interest]
fig.add_trace(plot_board_state_with_text(move_of_interest_state_RR), row=4, col=1)
fig.add_trace(plot_board_state_with_text(predicted_state_RR), row=4, col=2)
missing_pieces_positions_RR = (move_of_interest_state_RR != predicted_state_RR).int()
# Multiplying by the 0/1 mismatch mask keeps the true piece only where the
# prediction differs; correctly predicted squares become 0 (drawn as empty).
missing_pieces_RR = (missing_pieces_positions_RR * move_of_interest_state_RR)
fig.add_trace(plot_board_state_with_text(missing_pieces_RR), row=4, col=3)

# Adjust the overall size of the figure
fig.update_layout(height=fig_rows * plot_size * 1.3, width=fig_cols * plot_size)
fig.update_annotations(dict(font=dict(size=18)))

# Show the figure
fig.show()

This will check the percentage of squares in the sample (sample_size defaults to 1 game) where the ground truth matches the probe output.
I also include a round trip through all the transformations, which should match 100%; it is currently commented out in the cell below.

In [25]:
def calculate_matching_percentage(state_stacks: torch.Tensor, probe_outputs: torch.Tensor) -> float:
    """
    Calculate the percentage of matching cells in two tensors.

    Also prints a one-line summary of the counts.

    :param state_stacks: Ground-truth tensor of any shape.
    :param probe_outputs: Predicted tensor with exactly the same shape.
    :return: The percentage (0-100) of cells that match.
    :raises ValueError: If the shapes differ or the tensors are empty.
    """
    # Without this check a shape mismatch would either broadcast silently,
    # counting matches over a different number of cells than the denominator,
    # or fail with an opaque error.
    if state_stacks.shape != probe_outputs.shape:
        raise ValueError(
            f"Shape mismatch: {tuple(state_stacks.shape)} vs {tuple(probe_outputs.shape)}"
        )

    total_elements = state_stacks.numel()
    if total_elements == 0:
        raise ValueError("Cannot compute a matching percentage for empty tensors.")

    # Element-wise comparison, then count the True entries.
    num_matches = (state_stacks == probe_outputs).sum().item()

    percentage = (num_matches / total_elements) * 100
    print(f"Out of {total_elements} elements, {num_matches} matched, {percentage}%")

    return percentage
# Both stacks must have the same shape for the element-wise comparison.
assert(state_stacks_probe_outputs_MBLRR.shape) == (state_stack_white_moves_MBLRR.shape)
# Dividing by 100 turns the returned percentage into a 0-1 accuracy.
print("Linear probe accuracy on all board squares in sample size:", calculate_matching_percentage(state_stack_white_moves_MBLRR, state_stacks_probe_outputs_MBLRR)/100)

# Round-trip test (state stack -> one-hot -> state stack), currently disabled.
# It is expected to match 100% when enabled.
#round_trip = chess_utils.one_hot_to_state_stack(chess_utils.state_stack_to_one_hot(modes, config.num_rows, config.num_cols, config.min_val,config.max_val, DEVICE, state_stack_white_moves_MBLRR), config.min_val)
#round_trip = torch.tensor(round_trip)
#print(round_trip.shape)
#print(state_stack_white_moves_MBLRR.shape)
#assert(round_trip.shape) == (modes, sample_size, num_white_moves, config.num_rows, config.num_cols)
#assert(round_trip.shape) == state_stack_white_moves_MBLRR.shape
#matching_percentage = calculate_matching_percentage(round_trip, state_stack_white_moves_MBLRR)
#assert(matching_percentage == 100.0)
#print(f"Round trip matching percentage: {matching_percentage}%")

Out of 2048 elements, 1534 matched, 74.90234375%
Linear probe accuracy on all board squares in sample size: 0.7490234375


Now, we can perform interventions on the model's internals and view the modified probe outputs. We can also verify the model produces legal moves under the modified state of the board.

First, we perform a sanity check to ensure that our interventions on model activations are working correctly. In this case, diff should roughly equal flip_dir.

Note that I'm only intervening on one layer here. By modifying the first for loop and training additional probes, we can easily intervene on an arbitrary number of layers. If we were to intervene on multiple layers, we could only check that torch.allclose(diff, flip_dirs[layer], atol=1e-6) holds for the first layer that we intervene on.

In [26]:
# Sanity check for the intervention machinery: subtract a known direction from
# the residual stream at one layer and confirm that the cached activations
# change by exactly that direction.
probe_data.model.reset_hooks()

# Baseline (unmodified) residual stream at the probed layer.
_, cache = probe_data.model.run_with_cache(board_seqs_int_BL.to(DEVICE)[:, :-1], return_type=None)
resid_post_BlD = cache["resid_post", layer][:, :]

# Board square whose probe directions are used for this test.
r = 0
c = 0

# Map every layer we intervene on to the file name of its probe checkpoint.
probe_names = {}
for i in range(layer, layer + 1):
    probe_names[i] = base_probe_name.replace("layer_0", f"layer_{i}")

probes = {}

# Use this to intervene on multiple layers
# The loop variable is deliberately not called `layer`: rebinding the notebook-level
# `layer` here would make the cache lookups below read the *last* intervened layer,
# while the baseline above was cached at the first one.
for probe_layer, probe_name in probe_names.items():
    probe_file_location = f"{SAVED_PROBE_DIR}{probe_name}"
    checkpoint = torch.load(probe_file_location, map_location=torch.device(DEVICE))
    linear_probe_MDRRC = checkpoint["linear_probe"]
    probes[probe_layer] = linear_probe_MDRRC


flip_dirs = {}

piece1 = BLANK_INDEX
piece2 = black_king_index

for probe_layer, linear_probe_MDRRC in probes.items():
    piece1_probe = linear_probe_MDRRC[:, :, r, c, piece1].squeeze()
    piece2_probe = linear_probe_MDRRC[:, :, r, c, piece2].squeeze()
    flip_dir = piece2_probe - piece1_probe
    # Tensor.to returns a new tensor rather than moving in place, so keep the result.
    flip_dir = flip_dir.to(DEVICE)
    flip_dirs[probe_layer] = flip_dir


def flip_hook(resid, hook, flip_dir: torch.Tensor):
    """Subtract ``flip_dir`` in place from the residual stream of game ``GAME_IDX``.

    :param resid: Activation tensor passed to the hook; modified in place.
    :param hook: The hook point object supplied by the model (unused).
    :param flip_dir: Direction to subtract; it is broadcast over every sequence position.
    """
    resid[GAME_IDX, :] -= flip_dir # NOTE: We could only intervene on a single position in the sequence, but there's no harm in intervening on all of them


probe_data.model.reset_hooks()

for probe_layer, flip_dir in flip_dirs.items():
    temp_hook_fn = partial(flip_hook, flip_dir=flip_dir)
    hook_name = f"blocks.{probe_layer}.hook_resid_post"
    probe_data.model.add_hook(hook_name, temp_hook_fn)

# Print the architecture without moving the model: calling .cpu() here would relocate
# the model between the two runs being compared, while the inputs and flip
# directions are placed on DEVICE.
print(probe_data.model)
_, modified_cache = probe_data.model.run_with_cache(board_seqs_int_BL.to(DEVICE)[:, :-1], return_type=None)
probe_data.model.reset_hooks()
modified_resid_post = modified_cache["resid_post", layer][:, :]

print(resid_post_BlD.shape)
print(modified_resid_post.shape)

# The hook shifts every position of game GAME_IDX, so any one position can be
# compared; position 10 is used here.
diff = resid_post_BlD[GAME_IDX, 10, :] - modified_resid_post[GAME_IDX, 10, :]

assert torch.allclose(diff, flip_dirs[layer], atol=1e-6)
print("Flip hook test passed")

Moving model to device:  cpu
HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-7): 8 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp

Next, we load the model's vocab.

In [27]:
# Load the vocabulary metadata saved with the model; its "stoi" / "itos" entries are
# the lookup tables used by encode_string and decode_list below.
with open("models/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode_string(s: str) -> list[int]:
    """Look up every character of ``s`` in ``stoi`` and return the resulting ids in order."""
    token_ids = []
    for char in s:
        token_ids.append(stoi[char])
    return token_ids


def decode_list(l: list[int]) -> str:
    """Look up every id of ``l`` in ``itos`` and concatenate the results into one string."""
    return "".join(itos[token_id] for token_id in l)

Next, we generate 10 characters using the model to determine the model's next move. Note that we are using argmax instead of a temperature-based approach, so this will always return the most likely move.

One annoying problem we deal with: In chess, the 0th row is at the bottom, which is how print(chess_board) displays everything. But, for our state stack (and any array), the 0th row is at the top.

Now, we get a pgn string up to the current move and convert it to a chess board. We use it to create an encoded model_input as well.

In [28]:
# Board state (as integers) at the move we are studying.
print(move_of_interest_state_RR)

# Cut the game transcript off after index move_of_interest_index and encode it
# as a batch containing that single sequence.
pgn_string = probe_data.board_seqs_string[GAME_IDX][: move_of_interest_index + 1]
model_input = torch.tensor(encode_string(pgn_string)).unsqueeze(0).to(DEVICE)
print(model_input.shape)

# Rebuild a python-chess board from the same transcript.
board = chess_utils.pgn_string_to_board(pgn_string)
print(board)
print(board.legal_moves)

tensor([[ 4,  2,  0,  0,  6,  3,  2,  4],
        [ 0,  1,  0,  5,  0,  0,  0,  1],
        [ 3,  0,  0,  1,  0,  0,  1,  0],
        [ 1,  0,  1,  0,  0, -2,  0,  0],
        [-1,  0,  0,  0,  0, -3,  0,  0],
        [ 0,  0,  0, -1,  0,  0,  0,  0],
        [ 0, -1, -1,  0, -1,  0, -1, -1],
        [ 0, -4,  0, -5, -6, -3, -2, -4]], dtype=torch.int8)
torch.Size([1, 121])
. r . q k b n r
. p p . p . p p
. . . p . . . .
p . . . . b . .
P . P . . n . .
B . . P . . P .
. P . Q . . . P
R N . . K B N R
<LegalMoveGenerator at 0x76e0bc15b470 (Bxd6, Bc5, Bb4, Qxa5, Qxf4, Qb4, Qe3, Qc3, Qg2, Qf2, Qe2, Qc2, Qd1, Qc1, Nh3, Nf3, Ne2, Bh3, Bg2, Be2, Kf2, Kd1, Nc3, Ra2, gxf4, c5, g4, d4, h3, b3, h4, b4)>


We generate a move using the model on the original board and check that the move is legal. Next, we determine which piece was moved, and which row / column the source square of the move was.

In [29]:
# Ask the model for its move on the unmodified position and parse it against the board.
# Despite its name, model_move_san holds the parsed move object returned by parse_san.
model_move = chess_utils.get_model_move(probe_data.model, meta, model_input)
model_move_san = board.parse_san(model_move)
assert model_move_san in board.legal_moves

# Square the move starts from, by name and as (row, column).
source_square = chess.square_name(model_move_san.from_square)
r, c = chess_utils.square_to_coordinate(model_move_san.from_square)

# Piece standing on that square, in the integer and probe-channel encodings.
moved_piece = board.piece_at(model_move_san.from_square)
moved_piece_int = PIECE_TO_INT[moved_piece.piece_type]
moved_piece_probe_index = PIECE_TO_ONE_HOT_MAPPING[moved_piece_int]

print(r, c)

print(
    f"Model move: {model_move_san}, moved piece: {moved_piece}, "
    f"moved piece int: {moved_piece_int}, "
    f"moved piece probe index: {moved_piece_probe_index}, "
    f"source square: {source_square}"
)

InvalidMoveError: invalid san: ''

Now, we create a modified board where the source square of the model's original move is blank.

In [None]:
# Work on a copy of the state stack and write 0 into the source square (r, c)
# of the model's move at the move of interest.
modified_state_stack = state_stack_white_moves_MBLRR.clone()
modified_state_stack[0, GAME_IDX, move_of_interest, r, c] = 0
modified_move_of_interest_state = modified_state_stack[0, GAME_IDX, move_of_interest]

# One-hot encode the edited stack with the same settings as the original.
modified_state_stacks_one_hot = chess_utils.state_stack_to_one_hot(
    modes,
    config.num_rows,
    config.num_cols,
    config.min_val,
    config.max_val,
    DEVICE,
    modified_state_stack,
)
modified_move_of_interest_state_one_hot = modified_state_stacks_one_hot[0][GAME_IDX][move_of_interest]

# Mirror the edit on a copy of the python-chess board by clearing the same square.
modified_board = board.copy()
modified_board.set_piece_at(model_move_san.from_square, None)
print(modified_board)
print(modified_board.legal_moves)

# The edit must leave every shape unchanged.
assert modified_state_stack.shape == state_stack_white_moves_MBLRR.shape
assert modified_state_stacks_one_hot.shape == state_stacks_one_hot_MBLRRC.shape
assert modified_move_of_interest_state_one_hot.shape == move_of_interest_state_one_hot_RRC.shape

Next, we get flip_dir, which is the probe direction piece_probe * piece_coefficient - blank_probe * blank_coefficient, normalized to unit length. In practice, I find that it works best when blank_coefficient is 0. We subtract this flip_dir from the model's activations at every token. We generate 10 new characters using the model, and verify that the new move under this modified state is legal according to the modified state (the legality assertion is currently commented out in the cell below). We also save a copy of the modified activations and generate modified probe outputs.

In [None]:
# Baseline (unmodified) residual stream at the probed layer.
_, cache = probe_data.model.run_with_cache(board_seqs_int_BL.to(DEVICE)[:, :-1], return_type=None)
resid_post_BlD = cache["resid_post", layer][:, :]

# Intervention direction:
#   (piece_probe * piece_coefficient) - (blank_probe * blank_coefficient),
# normalised to unit length and subtracted from the residual stream with weight `scale`.
# An earlier experiment derived `scale` from dot products with the blank probe;
# a fixed scale of 1.0 is used instead.
piece_coefficient = 1.0
blank_coefficient = 0.0
scale = 1.0

piece1 = BLANK_INDEX
piece2 = moved_piece_probe_index

flip_dirs = {}

# The loop variable is deliberately not called `layer`, so the notebook-level
# `layer` used for the cache lookups is left untouched.
for probe_layer, linear_probe_MDRRC in probes.items():
    # Both vectors come from this layer's own probe (the blank vector used to be
    # read once, outside the loop, from whichever probe happened to be loaded last).
    piece1_probe = linear_probe_MDRRC[:, :, r, c, piece1].squeeze()
    piece2_probe = linear_probe_MDRRC[:, :, r, c, piece2].squeeze()
    flip_dir = (piece2_probe * piece_coefficient) - (piece1_probe * blank_coefficient)
    flip_dir = flip_dir / flip_dir.norm()
    # Tensor.to returns a new tensor rather than moving in place, so store the result.
    flip_dirs[probe_layer] = flip_dir.to(DEVICE)


def flip_hook(resid, hook, flip_dir: torch.Tensor, scale: float = 1.0):
    """Subtract ``scale * flip_dir`` in place from the first sequence in the batch.

    The direction is supplied by the caller (one per hooked layer) instead of
    being recomputed from notebook globals inside the hook, so each layer's hook
    applies that layer's own direction.

    :param resid: Activation tensor passed to the hook; modified in place.
    :param hook: The hook point object supplied by the model (unused).
    :param flip_dir: Direction to subtract; it is broadcast over every sequence position.
    :param scale: Multiplier applied to ``flip_dir`` before subtracting.
    """
    # NOTE(review): batch index 0 is used rather than GAME_IDX. model_input holds a
    # single sequence, so index 0 is the only valid one if this hook also runs inside
    # get_model_move; for the batched run this presumably relies on GAME_IDX == 0 - confirm.
    resid[0, :] -= scale * flip_dir # NOTE: We could only intervene on a single position in the sequence, but there's no harm in intervening on all of them


probe_data.model.reset_hooks()

for probe_layer, flip_dir in flip_dirs.items():
    temp_hook_fn = partial(flip_hook, flip_dir=flip_dir, scale=scale)
    hook_name = f"blocks.{probe_layer}.hook_resid_post"
    probe_data.model.add_hook(hook_name, temp_hook_fn)
_, modified_cache = probe_data.model.run_with_cache(board_seqs_int_BL.to(DEVICE)[:, :-1], return_type=None)
modified_board_model_move = chess_utils.get_model_move(probe_data.model, meta, model_input)
probe_data.model.reset_hooks()
modified_resid_post = modified_cache["resid_post", layer][:, :]


print(modified_board_model_move)
# modified_board_model_move_san = modified_board.parse_san(modified_board_model_move)
# assert modified_board_model_move_san in modified_board.legal_moves

In [None]:
# Shapes of the intervention direction and of the residual stream before and after it is applied.
print(
    flip_dirs[layer].shape,
    resid_post_BlD.shape,
    modified_resid_post.shape,
    sep="\n",
)

In [None]:
# For every game in the batch, keep only the residual stream rows selected by
# that game's entry in white_move_indices_BL.
indexed_modified_resid_posts = []
for batch_idx, dots_indices_for_batch_L in enumerate(white_move_indices_BL):
    indexed_modified_resid_post = modified_resid_post[batch_idx, dots_indices_for_batch_L]
    indexed_modified_resid_posts.append(indexed_modified_resid_post)

# Stack the per-game selections along a new leading (batch) dimension
stacked_modified_resid_post = torch.stack(indexed_modified_resid_posts).to(DEVICE)

expected_shape = (sample_size, num_white_moves, linear_probe_MDRRC.shape[1])
assert stacked_modified_resid_post.shape == expected_shape

# Apply the linear probe to the modified activations.
modified_probe_out = einsum(
    "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options",
    stacked_modified_resid_post,
    linear_probe_MDRRC,
)

# Collapse the per-option outputs back into a state stack and wrap it in a tensor.
modified_state_stacks_probe_outputs = torch.tensor(
    chess_utils.one_hot_to_state_stack(modified_probe_out, config.min_val)
)

Now, we can graph the original and modified board states and probe outputs.

In [None]:
from plotly.subplots import make_subplots

# Probe outputs at the move of interest, before and after the intervention.
# NOTE(review): the second index is hard-coded to 0 where other cells index with
# GAME_IDX - presumably equivalent because GAME_IDX == 0; confirm.
move_of_interest_probe_out = probe_out_MBLRRC[0][0][move_of_interest]
move_of_interest_probe_out_modified = modified_probe_out[0][0][move_of_interest]
print(move_of_interest_probe_out.shape)

fig_rows = 6
fig_cols = 3
fig = make_subplots(rows=fig_rows, cols=fig_cols, subplot_titles=[
    "Chess board blank squares", "Probe output blank squares clip=2", "Probe output blank squares no clipping",
    "Chess board original piece", "Probe output original piece clip=5", "Probe output original piece no clipping",
    "Modified chess board blank squares", "Probe output blank squares clip=2", "Probe output blank squares no clipping",
    "Modified chess board original piece", "Probe output original piece clip=5", "Probe output original piece no clipping",
    "Chess board state", "Probe output board state", "Redundant probe output board state",
    "Modified chess board state", "Probe output board state", "Redundant probe output board state"
])

# Specify the size of each plot
plot_size = 400  # You can adjust this size

# Rows 1-4: one probe channel per row, shown as the one-hot board, the clipped
# probe output and the unclipped probe output.
# Each entry: (figure row, one-hot board, probe output, channel index, clip size).
channel_rows = [
    (1, move_of_interest_state_one_hot_RRC, move_of_interest_probe_out, BLANK_INDEX, 2),
    (2, move_of_interest_state_one_hot_RRC, move_of_interest_probe_out, moved_piece_probe_index, 5),
    (3, modified_move_of_interest_state_one_hot, move_of_interest_probe_out_modified, BLANK_INDEX, 2),
    (4, modified_move_of_interest_state_one_hot, move_of_interest_probe_out_modified, moved_piece_probe_index, 5),
]
for fig_row, board_one_hot, probe_output, channel, clip_size in channel_rows:
    fig.add_trace(plot_board_state(board_one_hot[:, :, channel]), row=fig_row, col=1)
    fig.add_trace(plot_board_state(probe_output[:, :, channel], clip_size=clip_size), row=fig_row, col=2)
    fig.add_trace(plot_board_state(probe_output[:, :, channel]), row=fig_row, col=3)

# Rows 5-6: full board states with piece values as text.
# Each entry: (figure row, board state, probe-decoded board state).
board_rows = [
    (5, move_of_interest_state_RR, state_stacks_probe_outputs_MBLRR[0][0][move_of_interest]),
    (6, modified_move_of_interest_state, modified_state_stacks_probe_outputs[0][0][move_of_interest]),
]
for fig_row, board_state, probe_board_state in board_rows:
    fig.add_trace(plot_board_state_with_text(board_state), row=fig_row, col=1)
    fig.add_trace(plot_board_state_with_text(probe_board_state), row=fig_row, col=2)
    # The "Redundant probe output board state" subplot is column 3. The repeated
    # trace used to be added to column 2 a second time, leaving column 3 empty.
    fig.add_trace(plot_board_state_with_text(probe_board_state), row=fig_row, col=3)

# Adjust the overall size of the figure
fig.update_layout(height=fig_rows * plot_size, width=fig_cols * plot_size)

# Show the figure
fig.show()