In [None]:
import csv
from dataclasses import dataclass
from typing import Callable, Optional

import chess
import einops
import numpy as np
import torch
from fancy_einsum import einsum
from torch.nn import MSELoss, L1Loss
from tqdm import tqdm
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer, HookedTransformerConfig
# from mech_interp_othello_utils import OthelloBoardState

import chess_utils

# Pick the best available accelerator instead of hard-coding one: CUDA first,
# then Apple-silicon MPS, falling back to CPU so the notebook runs anywhere.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Architecture of the trained chess model; must match the saved checkpoint.
n_layers = 16
n_heads = 8
MODEL_DIR = "models/"
DATA_DIR = "data/"
cfg = HookedTransformerConfig(
    n_layers=n_layers,
    d_model=512,
    d_head=64,
    n_heads=n_heads,
    d_mlp=2048,
    d_vocab=32,
    n_ctx=1023,
    act_fn="gelu",
    normalization_type="LNPre",
)
model = HookedTransformer(cfg)
model_name = f"tf_lens_{n_layers}"
# map_location remaps tensors onto `device`, so a checkpoint saved on another
# device (e.g. CUDA) still loads on this machine's selected device.
model.load_state_dict(torch.load(f'{MODEL_DIR}{model_name}.pth', map_location=device))
model.to(device)

In [None]:
# layer = 12

@dataclass
class Config:
    """Describes one probing target: how targets are built and where the data lives.

    Fields:
    - min_val / max_val: inclusive range of target values (used to derive
      ``one_hot_range``).
    - custom_function: callable handed to ``chess_utils.create_state_stack(s)``;
      presumably maps a board (plus optional skill) to the probe's target grid.
    - linear_probe_name: name fragment used for the saved probe file and wandb run.
    - num_rows / num_cols: spatial shape of the probe output (8x8 for board
      probes, 1x1 for a scalar target such as skill).
    - board_seqs_int_file / board_seqs_str_file / dots_indices_file: paths to
      the token-id games, raw game strings, and probed-position indices.
    - skill_file: optional per-game skill labels; when set, targets are the
      (normalized) skill values.
    """

    min_val: int
    max_val: int
    custom_function: Callable
    linear_probe_name: str
    num_rows: int = 8
    num_cols: int = 8
    board_seqs_int_file: str = f"{DATA_DIR}train_board_seqs_int.npy"
    board_seqs_str_file: str = f"{DATA_DIR}train_board_seqs_str.csv"
    dots_indices_file: str = f"{DATA_DIR}train_dots_indices.npy"
    skill_file: Optional[str] = None

piece_config = Config(
    min_val=-6,
    max_val=6,
    custom_function=chess_utils.board_to_piece_state,
    linear_probe_name="chess_piece_probe",
)

color_config = Config(
    min_val=-1,
    max_val=1,
    custom_function=chess_utils.board_to_piece_color_state,
    linear_probe_name="chess_color_probe",
)

random_config = Config(
    min_val=-1,
    max_val=1,
    custom_function=chess_utils.board_to_random_state,
    linear_probe_name="chess_random_probe",
)

skill_config = Config(
    min_val=-2,
    max_val=20,
    custom_function=chess_utils.board_to_skill_state,
    linear_probe_name="chess_skill_probe",
    num_rows=1,
    num_cols=1,
    board_seqs_int_file=f"{DATA_DIR}skill_train_board_seqs_int.npy",
    board_seqs_str_file=f"{DATA_DIR}skill_train_board_seqs_str.csv",
    dots_indices_file=f"{DATA_DIR}skill_train_dots_indices.npy",
    skill_file=f"{DATA_DIR}skill_train_skill_level.npy",
)

# Active probing target; alternatives: piece_config, color_config, random_config.
config = skill_config

In [None]:
# Token-id games (model input) and, per game, the sequence positions the probe reads.
board_seqs_int = torch.tensor(np.load(config.board_seqs_int_file)).long()
print(board_seqs_int.shape)
dots_indices = torch.tensor(np.load(config.dots_indices_file)).long()
print(dots_indices.shape)

# Raw game strings: the first CSV column of every row, kept in file order so
# index i lines up with board_seqs_int[i].
with open(config.board_seqs_str_file, newline='') as csvfile:
    board_seqs_string = [record[0] for record in csv.reader(csvfile, delimiter=',')]
print(len(board_seqs_string), len(board_seqs_string[0]))


In [None]:
custom_function = config.custom_function
skill_stack = None
test_skill = None
# Identity normalization by default, so `mean` / `std` always exist (training
# rescales its MAE by `std`); they are overwritten only for the skill probe.
mean = 0.0
std = 1.0

if config.skill_file is not None:
    skill_stack = torch.tensor(np.load(config.skill_file)).long()
    print(skill_stack.unique())
    print(skill_stack.shape)
    skill_stack = skill_stack.to(dtype=torch.float32)
    # Raw (un-normalized) skill of game 0, used only by the single-game preview below.
    test_skill = skill_stack[0]

    mean = skill_stack.mean()
    std = skill_stack.std()

    # Z-score the targets so the regression probe fits a unit-scale signal.
    skill_stack = (skill_stack - mean) / std

    print(skill_stack.unique())

# Sanity-check previews: targets for one game, then for the first 50 games.
state_stack = torch.tensor(chess_utils.create_state_stack(board_seqs_string[0], custom_function, test_skill)).long()
print(state_stack.shape)

state_stacks = chess_utils.create_state_stacks(board_seqs_string[:50], custom_function, skill_stack)
print(state_stacks.shape)

In [None]:
# Probe-training hyperparameters. The learning-rate schedule (max_lr / min_lr)
# is defined together with the LR decay settings further down.
batch_size = 10
wd = 0.01  # AdamW weight decay
pos_start = 30  # drop the first 30 entries of each game's dots_indices
one_hot_range = config.max_val - config.min_val + 1  # only recorded in the wandb config
num_epochs = 2
num_games = 30000
# x and y are not used by the training loop; they are only recorded in the wandb config.
x = 0
y = 2
# Size of the probe's leading "modes" dimension (number of independent probe heads).
modes = 1

# Preview the targets of sample game 1 at sequence position 170.
print(state_stacks[:, 1, 170, :, :])

In [None]:
import os

# Keep Weights & Biases offline; it is only imported when wandb logging is enabled.
os.environ["WANDB_MODE"] = "offline"

# Learning-rate schedule: linear decay from max_lr down to a tenth of it,
# measured in games seen across all epochs.
max_lr = 3e-4
min_lr = max_lr / 10
max_iters = num_games * num_epochs
decay_lr = True

# Training objective (MSE) and the reported error metric (MAE).
mse_loss_function, mae_loss_function = MSELoss(), L1Loss()

games_skill = None

def get_lr(current_iter: int, max_iters: int, max_lr: float, min_lr: float) -> float:
    """
    Calculate the learning rate using linear decay from ``max_lr`` to ``min_lr``.

    Args:
    - current_iter (int): The current iteration.
    - max_iters (int): The iteration at which decay ends; any later iteration
      gets ``min_lr``. A non-positive value means "no decay horizon", and
      ``min_lr`` is returned.
    - max_lr (float): The initial learning rate (returned at ``current_iter == 0``).
    - min_lr (float): The minimum learning rate after decay.

    Returns:
    - float: The calculated learning rate.
    """
    # Guard the division below: with no decay horizon we are already "fully decayed".
    if max_iters <= 0:
        return min_lr

    # Clamp so the rate never drops below min_lr once decay has finished.
    progress = min(current_iter, max_iters) / max_iters
    return max_lr - (max_lr - min_lr) * progress

def train_linear_probe(layer: int):
    """
    Train a linear probe on the ``resid_post`` activations of ``layer`` and save it.

    The probe is a tensor of shape (modes, d_model, num_rows, num_cols). For
    every game in a batch, the residual stream at the positions listed in
    ``dots_indices`` (from ``pos_start`` onwards) is projected through the probe.
    The result is regressed with MSE onto the targets that
    ``chess_utils.create_state_stacks`` builds with ``config.custom_function``.
    Training runs ``num_epochs`` passes over the first ``num_games`` games, using
    AdamW and an optionally linearly decaying learning rate.

    Relies on notebook globals: ``model``, ``config``, ``custom_function``,
    ``board_seqs_int``, ``board_seqs_string``, ``dots_indices``, ``skill_stack``
    (plus ``std`` for the skill probe), the hyperparameters and the loss functions.

    Args:
    - layer (int): Transformer layer whose ``resid_post`` activations are probed.

    Raises:
    - ValueError: if ``num_games`` exceeds the number of loaded games.

    Side effects:
    - Saves a checkpoint dict to
      ``{MODEL_DIR}{model_name}_{config.linear_probe_name}_layer_{layer}.pth``.
    """
    linear_probe_name = f"{MODEL_DIR}{model_name}_{config.linear_probe_name}_layer_{layer}.pth"

    if num_games > len(board_seqs_string):
        raise ValueError(
            f"num_games={num_games} exceeds the {len(board_seqs_string)} games loaded"
        )

    # Skill targets are z-scored, so their MAE is rescaled by `std` for reporting.
    # `std` only exists for the skill probe; other probes report MAE unscaled.
    target_std = std if config.skill_file is not None else 1.0

    linear_probe = torch.randn(
        modes, model.cfg.d_model, config.num_rows, config.num_cols, device=device
    ) / np.sqrt(model.cfg.d_model)
    # Enable grad after the division so the probe is a leaf tensor the optimiser updates.
    linear_probe.requires_grad = True
    print(linear_probe.shape)
    lr = max_lr
    optimiser = torch.optim.AdamW([linear_probe], lr=lr, betas=(0.9, 0.99), weight_decay=wd)

    wandb_logging = False
    wandb_project = "chess_linear_probes"
    wandb_run_name = f"{config.linear_probe_name}_{model_name}_layer_{layer}"

    if wandb_logging:
        import wandb
        logging_dict = {"linear_probe_name": config.linear_probe_name, "model_name": model_name, "layer": layer,
                        "batch_size": batch_size, "max_lr": max_lr, "wd": wd, "pos_start": pos_start,
                        "num_epochs": num_epochs, "num_games": num_games, "x": x, "y": y, "modes": modes,
                        "one_hot_range": one_hot_range, "wandb_project": wandb_project, "wandb_run_name": wandb_run_name}
        # Log the plain hyperparameter dict; the Config dataclass holds a function object.
        wandb.init(project=wandb_project, name=wandb_run_name, config=logging_dict)

    current_iter = 0
    loss = torch.zeros(())
    for epoch in range(num_epochs):
        full_train_indices = torch.randperm(num_games)
        for i in tqdm(range(0, num_games, batch_size)):
            if decay_lr:
                lr = get_lr(current_iter, max_iters, max_lr, min_lr)
            for param_group in optimiser.param_groups:
                param_group['lr'] = lr

            indices = full_train_indices[i:i + batch_size]
            n_batch = len(indices)  # the final batch may be smaller than batch_size
            games_int = board_seqs_int[indices]
            games_str = [board_seqs_string[idx] for idx in indices.tolist()]
            games_dots = dots_indices[indices][:, pos_start:]
            games_skill = skill_stack[indices] if config.skill_file is not None else None

            # A (batch, 1) row index broadcast against games_dots (batch, pos)
            # selects each game's own probed positions via advanced indexing.
            batch_rows = torch.arange(n_batch).unsqueeze(1)

            # Targets: (modes, batch, seq, rows, cols) -> (modes, batch, pos, rows, cols).
            state_stack = chess_utils.create_state_stacks(games_str, custom_function, games_skill)
            state_stack = state_stack[:, batch_rows, games_dots].to(torch.float32).to(device)

            with torch.inference_mode():
                _, cache = model.run_with_cache(games_int.to(device)[:, :-1], return_type=None)
                resid_post = cache["resid_post", layer]
            # Gather outside inference_mode: einsum saves this tensor for the backward
            # pass, and inference tensors cannot be saved by autograd.
            # (batch, seq, d_model) -> (batch, pos, d_model)
            resid_post = resid_post[batch_rows.to(resid_post.device), games_dots.to(resid_post.device)]

            probe_out = einsum(
                "batch pos d_model, modes d_model rows cols -> modes batch pos rows cols",
                resid_post,
                linear_probe,
            )
            assert probe_out.shape == state_stack.shape

            mse_loss = mse_loss_function(probe_out, state_stack)
            mae_loss = mae_loss_function(probe_out, state_stack)
            mae_loss_denormalized = mae_loss * target_std

            if i % 100 == 0:
                print(f"epoch {epoch}, batch {i}, mae loss {mae_loss.item()}, mae denorm loss {mae_loss_denormalized.item()}, mse loss {mse_loss.item()}, lr {lr}")
                if wandb_logging:
                    wandb.log({"mae_loss": mae_loss.item(),
                               "loss": mse_loss.item(),
                               "mae_denorm_loss": mae_loss_denormalized.item(),
                               "lr": lr,
                               "epoch": epoch,
                               "iter": current_iter})

            loss = mse_loss
            loss.backward()
            optimiser.step()
            optimiser.zero_grad()
            current_iter += n_batch

    if wandb_logging:
        wandb.finish()

    checkpoint = {
        "linear_probe": linear_probe,
        "layer": layer,
        "config_name": config.linear_probe_name,
        # Detached so the saved value carries no autograd graph.
        "final_loss": loss.detach(),
        "model_name": model_name,
        "iters": current_iter,
        "epochs": num_epochs,
        # No accuracy is computed for this MSE-trained probe; 0 keeps the "acc" key present.
        "acc": 0,
    }
    torch.save(checkpoint, linear_probe_name)

# Probe a single layer. To sweep every other layer instead, use:
#     for layer in range(0, n_layers + 1, 2): train_linear_probe(layer)
train_linear_probe(layer=12)