In [None]:
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer, HookedTransformerConfig
import einops
import torch
from tqdm import tqdm
import numpy as np
from fancy_einsum import einsum
import chess
import numpy as np
import csv
from dataclasses import dataclass
from torch.nn import MSELoss, L1Loss
import pandas as pd
import pickle
import os
import logging
import plotly.graph_objects as go

import chess_utils
import train_test_chess
from train_test_chess import Config, LinearProbeData

In [None]:
# This notebook only runs inference, so turn autograd off globally for the session.
torch.set_grad_enabled(mode=False)

In [None]:
# Logging verbosity flags: debug_mode takes precedence over info_mode; with both
# off, only WARNING and above are emitted.
debug_mode = False
info_mode = True

log_level = logging.DEBUG if debug_mode else logging.INFO if info_mode else logging.WARNING

# Configure the root logger once, then grab a logger for this notebook.
logging.basicConfig(level=log_level)
logger = logging.getLogger(__name__)

In [None]:
MODEL_DIR = "models/"
DATA_DIR = "data/"
PROBE_DIR = "linear_probes/"
SAVED_PROBE_DIR = "linear_probes/saved_probes/"
SPLIT = "test"

# device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
device = "cpu"
logger.info(f"Using device: {device}")

probe_to_test = "tf_lens_lichess_16layers_ckpt_no_optimizer_chess_piece_probe_layer_12_pos_start_0.pth"

probe_file_location = f"{SAVED_PROBE_DIR}{probe_to_test}"

# Only the checkpoint read needs the open file; close it before building the
# config and probe data instead of holding it open for the rest of the cell.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
with open(probe_file_location, "rb") as f:
    state_dict = torch.load(f, map_location=torch.device(device))

# Show the saved metadata (every entry except the probe weights themselves).
print(state_dict.keys())
for key in state_dict.keys():
    if key != "linear_probe":
        print(key, state_dict[key])

# Rebuild the configuration from the metadata stored alongside the probe.
config = train_test_chess.find_config_by_name(state_dict["config_name"])
layer = state_dict["layer"]
model_name = state_dict["model_name"]
dataset_prefix = state_dict["dataset_prefix"]
process_data = state_dict["process_data"]
column_name = state_dict["column_name"]
config.pos_start = state_dict["pos_start"]
# "levels_of_interest" is optional in the checkpoint; None when absent.
levels_of_interest = state_dict.get("levels_of_interest")
config.levels_of_interest = levels_of_interest
indexing_function_name = state_dict["indexing_function_name"]
n_layers = state_dict["n_layers"]

split = SPLIT
input_dataframe_file = f"{DATA_DIR}{dataset_prefix}{split}.csv"
config = train_test_chess.set_config_min_max_vals_and_column_name(
    config, input_dataframe_file, dataset_prefix
)
misc_logging_dict = {
    "split": split,
    "dataset_prefix": dataset_prefix,
    "model_name": model_name,
    "n_layers": n_layers,
}

probe_data = train_test_chess.construct_linear_probe_data(
    input_dataframe_file,
    layer,
    dataset_prefix,
    split,
    n_layers,
    model_name,
    config,
)

In [None]:
sample_size = 1
modes = 1

# Expected length of the per-character axis of the state stacks (first game).
game_length_in_chars = len(probe_data.board_seqs_string[0])

# Ground-truth board state after every character: (modes, sample_size, chars, rows, cols).
state_stacks_all_chars = chess_utils.create_state_stacks(
    probe_data.board_seqs_string[:sample_size],
    config.custom_board_state_function,
)
logger.info(f"state_stack shape: {state_stacks_all_chars.shape}")
assert state_stacks_all_chars.shape == (
    modes, sample_size, game_length_in_chars, config.num_rows, config.num_cols
)

# Character positions at which the probe is evaluated, one row per game.
white_move_indices = probe_data.custom_indices[:sample_size]
print(white_move_indices.shape)
num_white_moves = white_move_indices.shape[1]
assert white_move_indices.shape == (sample_size, num_white_moves)

In [None]:
# Pick one evaluated move of the first game and find its character position.
move_of_interest = 12
move_of_interest_index = white_move_indices[0][move_of_interest]

# Ground-truth board (mode 0, game 0) at that character position.
move_of_interest_state = state_stacks_all_chars[0][0][move_of_interest_index]
print(move_of_interest_state.shape, move_of_interest_state, sep="\n")

In [None]:
checkpoint = torch.load(probe_file_location, map_location=torch.device(device))
linear_probe = checkpoint["linear_probe"]
print(linear_probe.shape)

# Number of one-hot classes per square (e.g. 13 when min_val=-6 and max_val=6).
one_hot_range = config.max_val - config.min_val + 1

games_int = probe_data.board_seqs_int[:sample_size]
games_dots = white_move_indices[:sample_size]
# Iterate over games, not over dim 0 of state_stacks_all_chars: that dim is `modes`.
num_games = games_dots.size(0)

# state_stacks_all_chars is (modes, batch, chars, rows, cols). Select each game's
# evaluated positions and stack along dim=1 so the result keeps the
# (modes, batch, num_white_moves, rows, cols) layout that probe_out and the
# later shape assertions use.
state_stack_white_moves = torch.stack(
    [state_stacks_all_chars[:, game_idx, games_dots[game_idx]] for game_idx in range(num_games)],
    dim=1,
)
print("state stack shapes")
print(state_stack_white_moves.shape)
print(state_stacks_all_chars.shape)

state_stack_one_hot = chess_utils.state_stack_to_one_hot(
    modes, config.num_rows, config.num_cols, config.min_val, config.max_val, device, state_stack_white_moves
).to(device)

with torch.inference_mode():
    _, cache = probe_data.model.run_with_cache(games_int.to(device)[:, :-1], return_type=None)
    resid_post = cache["resid_post", layer]

print(resid_post.shape)

# Keep only the residual stream at each game's evaluated positions:
# (batch, pos, d_model) -> (batch, num_white_moves, d_model).
resid_post = torch.stack(
    [resid_post[game_idx, games_dots[game_idx]] for game_idx in range(num_games)]
).to(device)

probe_out = einsum(
    "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options",
    resid_post,
    linear_probe,
)
print(probe_out.shape)
assert probe_out.shape == state_stack_one_hot.shape

In [None]:
# One-hot encode every character position: adds a trailing axis of size one_hot_range.
state_stacks_one_hot = chess_utils.state_stack_to_one_hot(
    modes,
    config.num_rows,
    config.num_cols,
    config.min_val,
    config.max_val,
    device,
    state_stacks_all_chars,
)
print(state_stacks_one_hot.shape)
assert state_stacks_one_hot.shape == (
    modes, sample_size, game_length_in_chars, config.num_rows, config.num_cols, one_hot_range
)

move_of_interest_state_one_hot = state_stacks_one_hot[0][0][move_of_interest_index]
print(move_of_interest_state_one_hot.shape)

In [None]:
for shown_tensor in (move_of_interest_state_one_hot, state_stacks_one_hot, probe_out):
    print(shown_tensor.shape)
assert probe_out.shape == (
    modes, sample_size, num_white_moves, config.num_rows, config.num_cols, one_hot_range
)

# Collapse the trailing options axis of the probe output back into a state stack.
state_stacks_probe_outputs = torch.tensor(
    chess_utils.one_hot_to_state_stack(probe_out, config.min_val)
)
print(state_stacks_probe_outputs.shape)
assert state_stacks_probe_outputs.shape == (
    modes, sample_size, num_white_moves, config.num_rows, config.num_cols
)
print(state_stacks_probe_outputs[0][0][move_of_interest])

In [None]:
# python-chess piece types -> integer piece codes 1..6 (pawn through king).
PIECE_TO_INT = dict(
    zip(
        (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING),
        range(1, 7),
    )
)

def _piece_letter(code: int) -> str:
    """Return the display letter for a piece code: "." for 0, upper case for
    positive codes, lower case for negative codes (|code| 1..6 = p, n, b, r, q, k)."""
    if code == 0:
        return "."
    letter = "pnbrqk"[abs(code) - 1]
    return letter.upper() if code > 0 else letter


# Piece code -> board letter, keyed in ascending order from -6 to 6.
INT_TO_CHAR = {code: _piece_letter(code) for code in range(-6, 7)}

# Piece code -> position of that code in the one-hot vector (-6 -> 0, ..., 6 -> 12).
PIECE_TO_ONE_HOT_MAPPING = {code: code + 6 for code in range(-6, 7)}

blank_index = PIECE_TO_ONE_HOT_MAPPING[0]
white_pawn_index = PIECE_TO_ONE_HOT_MAPPING[1]
black_king_index = PIECE_TO_ONE_HOT_MAPPING[-6]

def plot_board_state(board_state: torch.Tensor, clip_size: int = 200):
    """
    Build a grayscale heatmap trace for a single 2-D (rows, cols) board slice.

    :param board_state: 2-D tensor of per-square values, e.g. one channel of a
        one-hot board state or of the probe's output.
    :param clip_size: Values are clipped to ``[-clip_size, clip_size]`` before
        plotting so a few extreme values do not wash out the color scale.
    :return: A ``plotly.graph_objects.Heatmap`` trace, not yet attached to a figure.
    """
    # detach()/cpu() so tensors on an accelerator or tracking gradients can be
    # converted; calling .numpy() on those directly raises.
    board_values = np.clip(board_state.detach().cpu().numpy(), -clip_size, clip_size)
    return go.Heatmap(z=board_values, colorscale='gray')

heatmap = plot_board_state(move_of_interest_state_one_hot[:, :, white_pawn_index])

# Fixed 600x600 figure; tick marks are not drawn and at most 8 ticks per axis.
board_axis = dict(ticks='', nticks=8)
layout = go.Layout(
    title="Chess board white pawns",
    xaxis=board_axis,
    yaxis=board_axis,
    autosize=False,
    width=600,
    height=600,
)

fig = go.Figure(data=[heatmap], layout=layout)
fig.show()

In [None]:
def tensor_to_text(board_state: torch.Tensor, int_to_char=None) -> np.ndarray:
    """
    Convert a tensor of piece codes into an array of display labels.

    Each value is looked up in ``int_to_char``; values without an entry are shown
    as ``str(value)``.

    :param board_state: Tensor of piece codes (any shape; typically (rows, cols)).
    :param int_to_char: Optional mapping from value to label. Defaults to the
        module-level ``INT_TO_CHAR``.
    :return: Object-dtype ``np.ndarray`` of strings with the same shape as ``board_state``.
    """
    mapping = INT_TO_CHAR if int_to_char is None else int_to_char

    board_array = board_state.detach().cpu().numpy()

    # dtype=object rather than dtype=str: np.empty(..., dtype=str) allocates
    # fixed-width "<U1" cells, which would truncate fallback labels such as "-7" to "-".
    text_array = np.empty(board_array.shape, dtype=object)
    for index in np.ndindex(board_array.shape):
        value = board_array[index]
        text_array[index] = mapping.get(value, str(value))

    return text_array

def plot_board_state_with_text(board_state: torch.Tensor):
    """
    Build a heatmap trace of a board of piece codes, labelled with piece letters.

    Values below zero are drawn black, zero grey and values above zero white. Each
    square is annotated with the label produced by ``tensor_to_text``.

    :param board_state: 2-D (rows, cols) tensor of piece codes.
    :return: A ``plotly.graph_objects.Heatmap`` trace.
    """
    text_matrix = tensor_to_text(board_state)

    # Piecewise colorscale: black below the midpoint, grey at it, white above.
    colorscale = [
        [0, 'black'],   # Negative values
        [0.49, 'black'],
        [0.5, 'grey'],  # Zero
        [0.51, 'white'],
        [1, 'white']    # Positive values
    ]

    return go.Heatmap(
        z=board_state.detach().cpu().numpy(),
        text=text_matrix,
        showscale=False,
        colorscale=colorscale,
        # Pin the colorscale midpoint to 0. Without this, plotly maps the data's
        # own [min, max] onto [0, 1], so zero only lands on grey when min == -max
        # (e.g. a probe output missing a piece would colour blank squares white).
        zmid=0,
        texttemplate="%{text}",
    )
heatmap = plot_board_state_with_text(move_of_interest_state)

# Fixed 600x600 figure; tick marks are not drawn and at most 8 ticks per axis.
board_axis = dict(ticks='', nticks=8)
layout = go.Layout(
    title="Chess board state with text",
    xaxis=board_axis,
    yaxis=board_axis,
    autosize=False,
    width=600,
    height=600,
)

fig = go.Figure(data=[heatmap], layout=layout)
fig.show()

In [None]:
from plotly.subplots import make_subplots

# Probe output for the chosen move (mode 0, game 0): (rows, cols, options).
move_of_interest_probe_out = probe_out[0][0][move_of_interest]
print(move_of_interest_probe_out.shape)

fig_rows = 4
fig_cols = 3
fig = make_subplots(rows=fig_rows, cols=fig_cols, subplot_titles=[
    "Chess board blank squares", "Probe output blank squares clip=2", "Probe output blank squares no clipping",
    "Chess board white pawns", "Probe output white pawns clip=5", "Probe output white pawns no clipping",
    "Chess board black king", "Probe output black king clip=5", "Probe output black king no clipping",
    "Chess board state", "Probe output board state", "Redundant probe output board state"
])

# Size of each subplot in pixels (adjustable).
plot_size = 400

# Rows 1-3: one one-hot channel per row, shown as
# (ground truth, clipped probe output, unclipped probe output).
# The clip values must stay in sync with the subplot titles above.
piece_channel_rows = [
    (blank_index, 2),
    (white_pawn_index, 5),
    (black_king_index, 5),
]
for fig_row, (channel_index, channel_clip) in enumerate(piece_channel_rows, start=1):
    fig.add_trace(plot_board_state(move_of_interest_state_one_hot[:, :, channel_index]), row=fig_row, col=1)
    fig.add_trace(
        plot_board_state(move_of_interest_probe_out[:, :, channel_index], clip_size=channel_clip),
        row=fig_row,
        col=2,
    )
    fig.add_trace(plot_board_state(move_of_interest_probe_out[:, :, channel_index]), row=fig_row, col=3)

# Row 4: ground-truth board, decoded probe board, and (per its title) the decoded
# probe board repeated in column 3.
decoded_probe_board = state_stacks_probe_outputs[0][0][move_of_interest]
fig.add_trace(plot_board_state_with_text(move_of_interest_state), row=4, col=1)
fig.add_trace(plot_board_state_with_text(decoded_probe_board), row=4, col=2)
fig.add_trace(plot_board_state_with_text(decoded_probe_board), row=4, col=3)

# Adjust the overall size of the figure
fig.update_layout(height=fig_rows * plot_size, width=fig_cols * plot_size)

fig.show()

In [None]:
def calculate_matching_percentage(state_stacks: torch.Tensor, probe_outputs: torch.Tensor) -> float:
    """
    Calculate the percentage of element-wise matching cells in two tensors.

    Prints a one-line summary as a side effect.

    :param state_stacks: Reference tensor (e.g. ground-truth board states).
    :param probe_outputs: Tensor compared against ``state_stacks``; must have the same shape.
    :return: The percentage (0-100) of cells that match.
    :raises ValueError: If the shapes differ, or if the tensors are empty.
    """
    # Guard against broadcasting: mismatched shapes would be compared after
    # broadcasting, while the denominator below counts only state_stacks' elements.
    if state_stacks.shape != probe_outputs.shape:
        raise ValueError(
            f"Shape mismatch: {tuple(state_stacks.shape)} vs {tuple(probe_outputs.shape)}"
        )

    total_elements = state_stacks.numel()
    if total_elements == 0:
        raise ValueError("Cannot compute a matching percentage of empty tensors")

    num_matches = (state_stacks == probe_outputs).sum().item()

    percentage = (num_matches / total_elements) * 100
    print(f"Out of {total_elements} elements, {num_matches} matched, {percentage}%")

    return percentage
assert state_stacks_probe_outputs.shape == state_stack_white_moves.shape
print(
    "Linear probe accuracy on all board squares in sample size:",
    calculate_matching_percentage(state_stack_white_moves, state_stacks_probe_outputs),
)

# Sanity check: one-hot encoding the ground-truth states and decoding them again
# must reproduce the input exactly.
white_moves_one_hot = chess_utils.state_stack_to_one_hot(
    modes, config.num_rows, config.num_cols, config.min_val, config.max_val, device, state_stack_white_moves
)
round_trip = torch.tensor(chess_utils.one_hot_to_state_stack(white_moves_one_hot, config.min_val))
print(round_trip.shape)
print(state_stack_white_moves.shape)
assert round_trip.shape == (modes, sample_size, num_white_moves, config.num_rows, config.num_cols)
assert round_trip.shape == state_stack_white_moves.shape
matching_percentage = calculate_matching_percentage(round_trip, state_stack_white_moves)
assert matching_percentage == 100.0
print(f"Round trip matching percentage: {matching_percentage}%")