In [None]:
import csv
from dataclasses import dataclass
from typing import Callable

import chess
import einops
import numpy as np
import numpy as np
import plotly.graph_objects as go
import torch
from fancy_einsum import einsum
from tqdm import tqdm
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer, HookedTransformerConfig
# from mech_interp_othello_utils import OthelloBoardState

import chess_utils

# Compute device for the model and activations. Alternatives: "cuda", "mps".
device = "cpu"

n_layers = 16
n_heads = 8
MODEL_DIR = "models/"
DATA_DIR = "data/"
cfg = HookedTransformerConfig(
    n_layers = n_layers,
    d_model = 512,
    d_head = 64,
    n_heads = n_heads,
    d_mlp = 2048,
    d_vocab = 32,
    n_ctx = 1023,
    act_fn="gelu",
    normalization_type="LNPre"
)
model_name = f"tf_lens_{n_layers}"
model = HookedTransformer(cfg)
# map_location remaps tensors saved on another device (e.g. a GPU) onto
# `device`, so the checkpoint also loads on machines without that device.
model.load_state_dict(torch.load(f'{MODEL_DIR}{model_name}.pth', map_location=device))
model.to(device)
# This notebook only runs inference; disable autograd globally.
torch.set_grad_enabled(False)

In [None]:
# Residual-stream layer whose activations are probed.
layer = 14

@dataclass
class Config:
    """Settings for one linear-probe experiment.

    Fields:
        min_val: smallest integer state value a square can take (inclusive).
        max_val: largest integer state value a square can take (inclusive).
            The one-hot encoding has max_val - min_val + 1 classes.
        custom_function: board -> state function, passed to
            chess_utils.create_state_stacks to build the ground-truth states.
        linear_probe_name: path of the saved probe checkpoint.
    """
    min_val: int
    max_val: int
    # `callable` is a builtin predicate, not a type; annotate with typing.Callable.
    custom_function: Callable
    linear_probe_name: str

# One Config per probe experiment. Checkpoints live under MODEL_DIR and are
# named after the model (and, for trained probes, the probed layer).
piece_config = Config(
    min_val=-6,
    max_val=6,
    custom_function=chess_utils.board_to_piece_state,
    linear_probe_name=f"{MODEL_DIR}{model_name}_chess_piece_probe_layer_{layer}.pth",
)

color_config = Config(
    min_val=-1,
    max_val=1,
    custom_function=chess_utils.board_to_piece_color_state,
    linear_probe_name=f"{MODEL_DIR}{model_name}_chess_color_probe_layer_{layer}.pth",
)

# Same targets as color_config, but loads a randomly initialized probe.
random_weights_config = Config(
    min_val=-1,
    max_val=1,
    custom_function=chess_utils.board_to_piece_color_state,
    linear_probe_name=f"{MODEL_DIR}{model_name}_randomly_initialized_probe.pth",
)

# Active experiment; substitute color_config or random_weights_config to switch.
config = piece_config

# Expected state-stack dimensions: leading "modes" axis, then an 8x8 board.
modes, rows, cols = 1, 8, 8
# Number of leading white-move positions dropped from dots_indices.
start_pos_index = 5

In [None]:


# Integer-encoded test games: one row per game, one column per character.
board_seqs_int = torch.tensor(np.load(f"{DATA_DIR}test_board_seqs_int.npy")).long()

num_games = board_seqs_int.size(0)
game_length = board_seqs_int.size(1)
print(f"Number of games: {num_games}, game length in chars: {game_length}")
print(board_seqs_int.shape)

# Per-game character indices of the white-move positions; the first
# start_pos_index of them are dropped.
dots_indices = torch.tensor(np.load(f"{DATA_DIR}test_dots_indices.npy")).long()
print(dots_indices.shape)
dots_indices = dots_indices[:, start_pos_index:]

num_white_moves = dots_indices.size(1)
print(f"Number of white moves: {num_white_moves}")

# The same games as raw strings (first CSV column of each row).
with open(f"{DATA_DIR}test_board_seqs_string.csv", newline='') as csvfile:
    board_seqs_string = [row[0] for row in csv.reader(csvfile, delimiter=',')]
print(len(board_seqs_string), len(board_seqs_string[0]))

assert board_seqs_int.shape == (num_games, game_length)
assert len(board_seqs_string) == num_games
assert len(board_seqs_string[0]) == game_length
assert dots_indices.shape == (num_games, num_white_moves)


In [None]:
# Number of games to analyse.
sample_size = 1

# Ground-truth board state at every character of the sampled games.
state_stacks_all_chars = chess_utils.create_state_stacks(
    board_seqs_string[:sample_size],
    config.custom_function,
)
print(state_stacks_all_chars.shape)
assert state_stacks_all_chars.shape == (modes, sample_size, game_length, rows, cols)

# White-move character positions for the same games.
white_move_indices = dots_indices[:sample_size]
print(white_move_indices.shape)
assert white_move_indices.shape == (sample_size, num_white_moves)

In [None]:
# Ground-truth board of the first game at the white-move position with index 13
# in white_move_indices.
move_13_index = white_move_indices[0, 13]
move_13_state = state_stacks_all_chars[0, 0, move_13_index]
print(move_13_state.shape)
print(move_13_state)

In [None]:

# Load the probe onto the same device as the model activations; loading it to
# CPU unconditionally makes the einsum below mix devices when device != "cpu".
checkpoint = torch.load(config.linear_probe_name, map_location=device)
linear_probe = checkpoint["linear_probe"]
print(linear_probe.shape)

# Number of one-hot classes per square for the active config.
one_hot_range = config.max_val - config.min_val + 1

games_int = board_seqs_int[:sample_size]
games_dots = white_move_indices[:sample_size]
num_sampled_games = games_dots.size(0)

# Keep only the board states at each game's white-move positions.
# state_stacks_all_chars is (modes, games, chars, rows, cols), so iterate over
# the games axis (dim 1); dim 0 is the modes axis.
indexed_state_stacks = [
    state_stacks_all_chars[:, game_idx, games_dots[game_idx], :, :]
    for game_idx in range(num_sampled_games)
]
# Each entry is (modes, white_moves, rows, cols). Stacking on dim=1 yields
# (modes, games, white_moves, rows, cols), the layout probe_out uses below.
state_stack_white_moves = torch.stack(indexed_state_stacks, dim=1)
print("state stack shapes")
print(state_stack_white_moves.shape)
print(state_stacks_all_chars.shape)

state_stack_one_hot = chess_utils.state_stack_to_one_hot(
    modes, rows, cols, config.min_val, config.max_val, device, state_stack_white_moves
).to(device)

with torch.inference_mode():
    # Drop the final character, matching the original input construction.
    _, cache = model.run_with_cache(games_int.to(device)[:, :-1], return_type=None)
    resid_post = cache["resid_post", layer]

print(resid_post.shape)

# Keep only the residual stream at each game's white-move positions:
# (games, white_moves, d_model).
resid_post = torch.stack(
    [resid_post[game_idx, games_dots[game_idx]] for game_idx in range(num_sampled_games)]
)
probe_out = einsum(
    "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options",
    resid_post,
    linear_probe,
)
print(probe_out.shape)
assert probe_out.shape == state_stack_one_hot.shape
assert(probe_out.shape) == (state_stack_one_hot.shape)

In [None]:
def plot_board_state(board_state: torch.Tensor, title: str):
    """Show a 2-D grid of values (e.g. one 8x8 board slice) as a plotly heatmap.

    Args:
        board_state: 2-D tensor or numpy array of per-square values. Tensors may
            live on any device and may carry autograd history.
        title: figure title.
    """
    # Diverging red/blue colorscale.
    colorscale = 'RdBu'

    # plotly needs host data; `.numpy()` alone fails for GPU or grad tensors
    # (e.g. probe outputs computed on `device`).
    if isinstance(board_state, torch.Tensor):
        board_state = board_state.detach().cpu().numpy()

    heatmap = go.Heatmap(z=board_state, colorscale=colorscale)

    layout = go.Layout(
        title=title,
        xaxis=dict(ticks='', nticks=8),
        yaxis=dict(ticks='', nticks=8),
        autosize=False,
        width=600,
        height=600
    )

    fig = go.Figure(data=[heatmap], layout=layout)
    fig.show()
plot_board_state(move_13_state, "Chess board state")

In [None]:
# One-hot classes per square for the active config. This is not a fixed 3
# (black/white/empty): the piece config spans -6..6, i.e. 13 classes.
num_options = config.max_val - config.min_val + 1
state_stacks_one_hot = chess_utils.state_stack_to_one_hot(
    modes, rows, cols, config.min_val, config.max_val, device, state_stacks_all_chars
)
print(state_stacks_one_hot.shape)
assert state_stacks_one_hot.shape == (modes, sample_size, game_length, rows, cols, num_options)
move_13_state_one_hot = state_stacks_one_hot[0][0][move_13_index]
print(move_13_state_one_hot.shape)

In [None]:
def state_stack_to_one_hot(
    num_modes: int,
    num_rows: int,
    num_cols: int,
    min_val: int,
    max_val: int,
    device: torch.device,
    state_stack: "np.ndarray | torch.Tensor",
) -> torch.Tensor:
    """One-hot encode integer board states.

    Input shape: (modes, games, positions, num_rows, num_cols), integer values.
    Output shape: (modes, games, positions, num_rows, num_cols,
    max_val - min_val + 1), dtype torch.long, on `device`.

    Channel k is 1 where the input equals min_val + k. A value outside
    [min_val, max_val] yields an all-zero vector for that square.

    num_modes is kept for signature compatibility and is not used; the leading
    three dimensions are taken from state_stack itself.
    """
    range_size = max_val - min_val + 1

    # Accept numpy arrays as well as tensors, and move the data to the target
    # device once so every comparison below is a same-device operation.
    states = torch.as_tensor(state_stack, device=device)

    one_hot = torch.zeros(
        states.shape[0],  # num modes
        states.shape[1],  # num games
        states.shape[2],  # num positions
        num_rows,
        num_cols,
        range_size,
        device=device,
        dtype=torch.long,
    )

    for channel, val in enumerate(range(min_val, max_val + 1)):
        one_hot[..., channel] = states == val

    return one_hot

# Re-encode the full per-character state stacks with the local encoder above.
print(state_stacks_all_chars.shape)
state_stacks_one_hot = state_stack_to_one_hot(
    num_modes=modes,
    num_rows=rows,
    num_cols=cols,
    min_val=config.min_val,
    max_val=config.max_val,
    device=device,
    state_stack=state_stacks_all_chars,
)
print(state_stacks_one_hot.shape)

In [None]:
def one_hot_to_state_stack(one_hot: torch.Tensor, min_val: int) -> np.ndarray:
    """Decode one-hot vectors (or per-class scores) back to integer state values.

    Takes the argmax over the last (class) dimension and shifts it by min_val,
    inverting the one-hot encoding where channel k means value min_val + k.

    Args:
        one_hot: tensor whose last dimension holds the per-class values.
        min_val: state value represented by channel 0.

    Returns:
        numpy array with the class dimension removed.
    """
    indices = torch.argmax(one_hot, dim=-1)
    # `.numpy()` only works on CPU tensors without autograd history; probe
    # outputs may live on an accelerator.
    return indices.detach().cpu().numpy() + min_val
# Decode the probe outputs into integer board states.
print(probe_out.shape)
test_output = one_hot_to_state_stack(one_hot=probe_out, min_val=config.min_val)
print(test_output.shape)

In [None]:
# Channel k of the one-hot encoding is 1 where a square's state value equals
# config.min_val + k. Label each plot by the value it encodes: fixed names
# ("black"/"blank"/"white") are only right for a -1..1 config.
for channel in range(3):
    plot_board_state(
        move_13_state_one_hot[:, :, channel],
        f"Ground-truth one-hot channel {channel} (state value {channel + config.min_val})",
    )
plot_board_state(move_13_state, "Chess board state")

In [None]:
# Probe output at white-move index 13 of the first game, i.e. the same position
# as move_13_state. Moved to CPU so plotting works when `device` is a GPU.
move_13_probe_out = probe_out[0][0][13].cpu()
print(move_13_probe_out.shape)
# Channel k corresponds to state value config.min_val + k. Label plots by value,
# consistent with the ground-truth channel plots.
for channel in range(3):
    plot_board_state(
        move_13_probe_out[:, :, channel],
        f"Probe output channel {channel} (state value {channel + config.min_val})",
    )
plot_board_state(move_13_state, "Chess board state")

In [None]:
# Decode the probe outputs with chess_utils and inspect the decoded board at
# white-move index 13.
print(move_13_state_one_hot.shape)
print(state_stacks_one_hot.shape)
print(probe_out.shape)

expected_probe_shape = (modes, sample_size, num_white_moves, rows, cols, one_hot_range)
assert probe_out.shape == expected_probe_shape

state_stacks_probe_outputs = torch.tensor(
    chess_utils.one_hot_to_state_stack(probe_out, config.min_val)
)
print(state_stacks_probe_outputs.shape)
# Decoding removes the trailing class dimension.
assert state_stacks_probe_outputs.shape == expected_probe_shape[:-1]

print(state_stacks_probe_outputs[0][0][13])
plot_board_state(state_stacks_probe_outputs[0][0][13], "Chess board state")

In [None]:
def calculate_matching_percentage(state_stacks: torch.Tensor, probe_outputs: torch.Tensor) -> float:
    """
    Calculate the percentage of element-wise equal cells in two tensors.

    :param state_stacks: reference tensor (e.g. ground-truth board states).
    :param probe_outputs: tensor to compare; must have the same shape.
    :return: The percentage (0-100) of cells that match.
    :raises ValueError: if the shapes differ or the tensors are empty.
    """
    # Without this check `==` would broadcast, and dividing by
    # state_stacks.numel() could report more than 100%.
    if state_stacks.shape != probe_outputs.shape:
        raise ValueError(
            f"shape mismatch: {tuple(state_stacks.shape)} vs {tuple(probe_outputs.shape)}"
        )

    total_elements = state_stacks.numel()
    if total_elements == 0:
        raise ValueError("cannot compute a matching percentage of empty tensors")

    # Compare on a common device so GPU/CPU inputs can be mixed.
    matches = state_stacks == probe_outputs.to(state_stacks.device)
    num_matches = matches.sum().item()

    percentage = (num_matches / total_elements) * 100
    print(f"Out of {total_elements} elements, {num_matches} matched, {percentage}%")

    return percentage
# Probe accuracy: fraction of squares where the decoded probe output equals the
# ground-truth state.
assert state_stacks_probe_outputs.shape == state_stack_white_moves.shape
print(calculate_matching_percentage(state_stack_white_moves, state_stacks_probe_outputs))

# Round trip: one-hot encoding followed by decoding must reproduce the
# ground-truth states exactly.
round_trip_one_hot = chess_utils.state_stack_to_one_hot(
    modes, rows, cols, config.min_val, config.max_val, device, state_stack_white_moves
)
round_trip = torch.tensor(chess_utils.one_hot_to_state_stack(round_trip_one_hot, config.min_val))
print(round_trip.shape)
print(state_stack_white_moves.shape)
assert round_trip.shape == (modes, sample_size, num_white_moves, rows, cols)
assert round_trip.shape == state_stack_white_moves.shape

matching_percentage = calculate_matching_percentage(round_trip, state_stack_white_moves)
assert matching_percentage == 100.0
print(f"Round trip matching percentage: {matching_percentage}%")