In [1]:
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer, HookedTransformerConfig
# from mech_interp_othello_utils import OthelloBoardState
import einops
import torch
from tqdm import tqdm
import numpy as np
from fancy_einsum import einsum
import chess
import numpy as np
import csv
import chess_utils
import plotly.graph_objects as go

# Compute device for the model and for tensors moved alongside it. Exactly one
# assignment should be active; the others are kept as ready-made alternatives
# (previously "cuda" was assigned and then immediately overwritten by "cpu").
# device = "cuda"
device = "cpu"
# device = "mps"

n_layers = 16
n_heads = 8
MODEL_DIR = "models/"
DATA_DIR = "data/"

# Architecture hyperparameters for the HookedTransformer. load_state_dict below
# copies the checkpoint's tensors into this model, so these values are expected
# to describe the same architecture the checkpoint was saved from.
cfg = HookedTransformerConfig(
    n_layers = n_layers,
    d_model = 512,
    d_head = 64,
    n_heads = n_heads,
    d_mlp = 2048,
    d_vocab = 32,
    n_ctx = 1023,
    act_fn="gelu",
    normalization_type="LNPre"
)
model = HookedTransformer(cfg)

# map_location remaps the checkpoint's storages onto `device` while loading, so
# a checkpoint that was saved from a GPU can still be opened on a machine that
# only has the selected device available.
model.load_state_dict(torch.load(f'{MODEL_DIR}tf_lens_16.pth', map_location=device))
model.to(device)

# The notebook only runs inference; switch autograd off globally.
torch.set_grad_enabled(False)

  from .autonotebook import tqdm as notebook_tqdm


Moving model to device:  cpu


<torch.autograd.grad_mode.set_grad_enabled at 0x17fb252d0>

In [2]:


# Integer-encoded games, read from .npy and cast to int64.
board_seqs_int = torch.tensor(np.load(f"{DATA_DIR}test_board_seqs_int.npy")).long()

num_games = board_seqs_int.shape[0]
game_length = board_seqs_int.shape[1]
print(f"Number of games: {num_games}, game length in chars: {game_length}")

print(board_seqs_int.shape)

# Per-game index table; cell 3 slices it into `white_move_indices`.
dots_indices = torch.tensor(np.load(f"{DATA_DIR}test_dots_indices.npy")).long()
print(dots_indices.shape)

num_white_moves = dots_indices.shape[1]
print(f"Number of white moves: {num_white_moves}")

# The same games as strings: only the first CSV column of every row is kept.
with open(f"{DATA_DIR}test_board_seqs_string.csv", newline='') as csvfile:
    board_seqs_string = [row[0] for row in csv.reader(csvfile, delimiter=',')]
print(len(board_seqs_string), len(board_seqs_string[0]))

# Sanity checks: the three data sources must describe the same set of games.
assert board_seqs_int.shape == (num_games, game_length)
assert len(board_seqs_string) == num_games
assert len(board_seqs_string[0]) == game_length
assert dots_indices.shape == (num_games, num_white_moves)


Number of games: 8241, game length in chars: 680
torch.Size([8241, 680])
torch.Size([8241, 61])
Number of white moves: 61
8241 680


In [3]:
# Number of games (taken from the front of the dataset) analysed below.
sample_size = 1
# Function handed to create_state_stacks to turn a board into a per-square state.
custom_function = chess_utils.board_to_piece_color_state

# One 8x8 state per character position of every sampled game.
state_stacks = chess_utils.create_state_stacks(board_seqs_string[:sample_size], custom_function)
print(state_stacks.shape)
assert state_stacks.shape == (sample_size, game_length, 8, 8)

# Matching slice of the index table for the sampled games.
white_move_indices = dots_indices[:sample_size]
print(white_move_indices.shape)
assert white_move_indices.shape == (sample_size, num_white_moves)

torch.Size([1, 680, 8, 8])
torch.Size([1, 61])


In [4]:
# Character position stored at slot 13 of the first game's index table.
move_13_index = white_move_indices[0, 13]
# Board state of the first game at that position.
move_13_state = state_stacks[0, move_13_index]
print(move_13_state.shape)
print(move_13_state)
chess_utils.pretty_print_state_stack(move_13_state.numpy())

torch.Size([8, 8])
tensor([[ 0,  0,  1,  1,  0,  1,  0,  1],
        [ 1,  1,  1,  1,  1,  1,  0,  0],
        [ 0,  0,  1,  0,  0,  1,  0,  0],
        [ 0, -1,  0,  0,  1,  0,  1,  0],
        [-1,  0,  0,  0,  0,  0,  0,  1],
        [ 0,  0, -1, -1, -1, -1,  0, -1],
        [-1,  0,  0,  0,  0,  0, -1, -1],
        [-1, -1,  0,  0, -1, -1,  0, -1]])
B B . . B B . B
B . . . . . B B
. . B B B B . B
B . . . . . . W
. B . . W . W .
. . W . . W . .
W W W W W W . .
. . W W . W . W


In [5]:
# Layer whose residual stream (post) is fed to the probe.
layer = 12
# The probe is deliberately kept on the CPU: the plotting cells call `.numpy()`
# on tensors derived from it.
linear_probe = torch.load(f"{MODEL_DIR}main_chess_linear_probe.pth", map_location=torch.device('cpu'))
print(linear_probe.shape)

games_int = board_seqs_int[:sample_size]
games_dots = white_move_indices[:sample_size]

with torch.inference_mode():
    # The last token of every game is dropped before the forward pass, so the
    # cached activations have one position fewer than the games themselves.
    _, cache = model.run_with_cache(games_int.to(device)[:, :-1], return_type=None)
    resid_post = cache["resid_post", layer][:, :]

print(resid_post.shape)

# For every game keep only the residual vectors at that game's positions listed
# in `games_dots`. A (batch, 1) column of game numbers broadcast against the
# (batch, n_indices) index table gathers them all at once, giving a tensor of
# shape (batch, n_indices, d_model).
game_numbers = torch.arange(games_dots.size(0), device=resid_post.device).unsqueeze(-1)
resid_post = resid_post[game_numbers, games_dots.to(resid_post.device)]

# The activations live on `device` while the probe lives on the CPU; bring them
# onto the probe's device so the einsum below also works when `device` is not
# "cpu" (it used to fail with a device-mismatch error in that case).
resid_post = resid_post.to(linear_probe.device)

probe_out = einsum(
    "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options",
    resid_post,
    linear_probe,
)
print(probe_out.shape)

torch.Size([1, 512, 8, 8, 3])
torch.Size([1, 679, 512])
torch.Size([1, 1, 61, 8, 8, 3])


In [6]:
def plot_board_state(board_state: torch.Tensor, title: str):
    """
    Display a 2-D tensor as a 600x600 plotly heatmap.

    :param board_state: Tensor holding one value per board square. It may live
        on any device and may be attached to an autograd graph; it is detached
        and copied to the CPU before being handed to plotly.
    :param title: Title shown above the figure.
    :return: None. The figure is rendered with ``fig.show()``.
    """
    # Named plotly colorscale. No value range is passed to the heatmap, so the
    # colour limits are whatever plotly chooses for each figure.
    colorscale = 'RdBu'

    # A bare `.numpy()` raises for tensors on an accelerator or tensors that
    # require grad; detach + cpu makes the conversion valid for both.
    z_values = board_state.detach().cpu().numpy()

    # Create heatmap
    heatmap = go.Heatmap(z=z_values, colorscale=colorscale)

    # Fixed-size square canvas without tick marks.
    layout = go.Layout(
        title=title,
        xaxis=dict(ticks='', nticks=8),
        yaxis=dict(ticks='', nticks=8),
        autosize=False,
        width=600,
        height=600
    )

    # Create figure and plot
    fig = go.Figure(data=[heatmap], layout=layout)
    fig.show()
# Heatmap of the ground-truth state selected in cell 4.
plot_board_state(board_state=move_13_state, title="Chess board state")

In [7]:
# Size of the one-hot axis: one channel each for black, white and empty.
# NOTE(review): which channel index means which option is decided inside
# chess_utils.state_stack_to_one_hot and is not visible here — confirm it
# against the titles used when the channels are plotted.
num_options = 3
state_stacks_one_hot = chess_utils.state_stack_to_one_hot(1, 8, 8, num_options, device, state_stacks)
print(state_stacks_one_hot.shape)
assert state_stacks_one_hot.shape == (1, sample_size, game_length, 8, 8, num_options)

# One-hot board of the first game at the position selected in cell 4.
move_13_state_one_hot = state_stacks_one_hot[0, 0, move_13_index]
print(move_13_state_one_hot.shape)

torch.Size([1, 1, 680, 8, 8, 3])
torch.Size([8, 8, 3])


In [8]:
# One heatmap per one-hot channel, followed by the combined state for reference.
# NOTE(review): the channel -> label pairing below is taken as written in the
# notebook; verify it against the encoding used by chess_utils.
for channel, caption in enumerate((
    "Chess board blank squares",
    "Chess board white pieces",
    "Chess board black pieces",
)):
    plot_board_state(move_13_state_one_hot[:, :, channel], caption)
plot_board_state(move_13_state, "Chess board state")

In [9]:
# Probe output for mode 0, game 0, slot 13 of the gathered positions.
move_13_probe_out = probe_out[0, 0, 13]
print(move_13_probe_out.shape)

# Same layout as the ground-truth plots: one heatmap per channel of the probe
# output, then the true state for comparison.
for channel, caption in enumerate((
    "Chess board blank squares",
    "Chess board white pieces",
    "Chess board black pieces",
)):
    plot_board_state(move_13_probe_out[:, :, channel], caption)
plot_board_state(move_13_state, "Chess board state")

torch.Size([8, 8, 3])


In [16]:
# Convert the one-hot encoding back into a state stack.
# NOTE(review): despite its name, this variable is built from
# `state_stacks_one_hot` (the ground-truth encoding), not from `probe_out`, so
# it is a round trip of the true states rather than a probe prediction.
state_stacks_probe_outputs = torch.tensor(
    chess_utils.one_hot_to_state_stack(state_stacks_one_hot)
)
print(state_stacks_probe_outputs.shape)
print(state_stacks_probe_outputs[0, 0, move_13_index])
plot_board_state(state_stacks_probe_outputs[0, 0, move_13_index], "Chess board state")

torch.Size([1, 1, 680, 8, 8])
tensor([[ 0,  0,  1,  1,  0,  1,  0,  1],
        [ 1,  1,  1,  1,  1,  1,  0,  0],
        [ 0,  0,  1,  0,  0,  1,  0,  0],
        [ 0, -1,  0,  0,  1,  0,  1,  0],
        [-1,  0,  0,  0,  0,  0,  0,  1],
        [ 0,  0, -1, -1, -1, -1,  0, -1],
        [-1,  0,  0,  0,  0,  0, -1, -1],
        [-1, -1,  0,  0, -1, -1,  0, -1]])


In [17]:
def calculate_matching_percentage(state_stacks: torch.Tensor, probe_outputs: torch.Tensor) -> float:
    """
    Calculate the percentage of matching cells in two tensors.

    The tensors are compared element-wise. Their shapes may be identical or
    merely broadcast-compatible (for example ``[1, 680, 8, 8]`` against
    ``[1, 1, 680, 8, 8]``); the percentage is taken over every cell of the
    broadcast comparison, so the result always lies in ``[0, 100]``.

    :param state_stacks: Reference tensor.
    :param probe_outputs: Tensor to compare against the reference.
    :return: The percentage of compared cells that are equal.
    :raises ZeroDivisionError: If the comparison contains no cells.
    """
    # Element-wise comparison (broadcasts when the shapes differ)
    matches = state_stacks == probe_outputs

    # Use the size of the comparison itself as the denominator. Taking it from
    # only one of the inputs would disagree with the numerator whenever
    # broadcasting enlarges the result, giving percentages above 100.
    total_elements = matches.numel()
    if total_elements == 0:
        raise ZeroDivisionError("cannot compute a matching percentage for empty tensors")

    # Count the number of matches
    num_matches = matches.sum().item()

    return (num_matches / total_elements) * 100
# Agreement between the original state stacks and their one-hot round trip.
print(calculate_matching_percentage(state_stacks=state_stacks, probe_outputs=state_stacks_probe_outputs))

100.0
