In [None]:
import csv
import logging
import os
import pickle
from dataclasses import dataclass
from typing import Callable, Optional

import chess
import einops
import numpy as np
import pandas as pd
import torch
import transformer_lens.utils as utils
from fancy_einsum import einsum
from torch.nn import MSELoss, L1Loss
from tqdm import tqdm
from transformer_lens import HookedTransformer, HookedTransformerConfig

import chess_utils

In [None]:
# Flags to control logging verbosity: debug_mode takes precedence over info_mode;
# with both off only warnings and above are shown.
debug_mode = True
info_mode = True

if debug_mode:
    log_level = logging.DEBUG
elif info_mode:
    log_level = logging.INFO
else:
    log_level = logging.WARNING

# Configure logging. force=True replaces any handlers already attached to the root logger.
# Without it basicConfig silently does nothing once the root logger has a handler, so
# re-running this cell after changing the flags above would not change the level.
logging.basicConfig(level=log_level, force=True)
logger = logging.getLogger(__name__)

In [None]:
MODEL_DIR = "models/"
DATA_DIR = "data/"

@dataclass
class Config:
    """Settings for one probing target.

    `custom_function` builds the per-game target state; its values are expected to lie in
    [min_val, max_val], which sizes the one-hot encoding used by the cross-entropy probe.
    """

    min_val: int  # lower bound of the target value range
    max_val: int  # upper bound of the target value range
    custom_function: Callable  # chess_utils function that builds the probe targets
    linear_probe_name: str  # used in saved probe filenames and wandb run names
    num_rows: int = 8  # rows of the probe's output grid (1 for the scalar skill target)
    num_cols: int = 8  # columns of the probe's output grid (1 for the scalar skill target)
    board_seqs_int_file: str = f"{DATA_DIR}train_board_seqs_int.npy"  # token-encoded games
    board_seqs_str_file: str = f"{DATA_DIR}train_board_seqs_str.csv"  # raw transcripts, one per line
    df_file: str = f"{DATA_DIR}train.csv"  # input dataset
    skill_file: Optional[str] = None  # per-game Stockfish level array; only set for the skill probe

piece_config = Config(
    min_val = -6,
    max_val = 6,
    custom_function = chess_utils.board_to_piece_state,
    linear_probe_name = "chess_piece_probe",
)

color_config = Config(
    min_val = -1,
    max_val = 1,
    custom_function=chess_utils.board_to_piece_color_state,
    linear_probe_name="chess_color_probe",
)

random_config = Config(
    min_val = -1,
    max_val = 1,
    custom_function=chess_utils.board_to_random_state,
    linear_probe_name="chess_random_probe",
)

skill_config = Config(
    min_val = -2,
    max_val = 20,
    custom_function=chess_utils.board_to_skill_state,
    linear_probe_name="chess_skill_probe",
    num_rows = 1,
    num_cols= 1,
    board_seqs_int_file = f"{DATA_DIR}skill_train_board_seqs_int.npy",
    board_seqs_str_file = f"{DATA_DIR}skill_train_board_seqs_str.csv",
    df_file = f"{DATA_DIR}skill_train.csv",
    skill_file = f"{DATA_DIR}skill_train_skill_level.npy",
)

# Select exactly one probing target (uncomment the one you want).
# config = piece_config
# config = color_config
# config = random_config
config = skill_config

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# device = "cpu" # For debugging

n_layers = 16
n_heads = 8

PROCESS_DATA = True  # filter the dataset to the Stockfish levels below and use them as the probe classes
levels_of_interest = [0, 5, 10, 20] # NOTE: This is only used if PROCESS_DATA is True
TRAIN_WITH_MSE = True  # regression probe (MSE) instead of a one-hot cross-entropy probe
NORMALIZE_SKILL_FOR_MSE = True  # z-score the skill targets before MSE training
wandb_logging = False
os.environ["WANDB_MODE"] = "online"

In [None]:
# Architecture of the chess GPT being probed; must match the saved state dict.
cfg = HookedTransformerConfig(
    n_layers = n_layers,
    d_model = 512,
    d_head = 64,
    n_heads = n_heads,
    d_mlp = 2048,
    d_vocab = 32,
    n_ctx = 1023,
    act_fn="gelu",
    normalization_type="LNPre"
)
model = HookedTransformer(cfg)
model_name = f"tf_lens_{n_layers}"
# map_location lets a checkpoint saved on one device (e.g. CUDA) load on whichever
# `device` was selected here (mps/cpu), instead of failing during deserialization.
model.load_state_dict(torch.load(f'{MODEL_DIR}{model_name}.pth', map_location=device))
model.to(device)

In [None]:
# `meta` holds the character vocabulary used to turn PGN strings into integer token sequences.
# NOTE(review): pickle.load can execute arbitrary code; only load a meta.pkl you produced yourself.
with open("nanogpt/out/meta.pkl", "rb") as meta_file:
    meta = pickle.load(meta_file)

logger.info(meta)

stoi = meta["stoi"]
itos = meta["itos"]


def encode(s):
    """Map each character of `s` to its token id."""
    return [stoi[ch] for ch in s]


def decode(l):
    """Map a sequence of token ids back to the corresponding string."""
    return "".join(itos[token] for token in l)


sample_pgn = ";1.e4 "
logger.info(encode(sample_pgn))
logger.info(decode(encode(sample_pgn)))

In [None]:
input_df_filename = config.df_file
# Working copy of the (optionally filtered) dataset. A later cell computes the probe indices
# from this file, so it must stay row-aligned with every array written below.
processing_df_filename = f"{DATA_DIR}temporary_in_processing.csv"

df = pd.read_csv(input_df_filename)
user_state_dict_one_hot_mapping = None

if PROCESS_DATA:
    # Map each Stockfish level of interest to a contiguous class index 0..len-1.
    user_state_dict_one_hot_mapping = {level: idx for idx, level in enumerate(levels_of_interest)}

    matches = {f'Stockfish {number}' for number in levels_of_interest}
    logger.info(f"Stockfish levels to be used in probe dataset: {matches}")

    # Filter the DataFrame based on these matches
    filtered_df = df[df['player_two'].isin(matches)]
    logger.info(f"Number of games in filtered dataset: {len(filtered_df)}")
    df = filtered_df

# Write once and read back once so `df` is exactly what later cells read from processing_df_filename.
df.to_csv(processing_df_filename, index=False)
df = pd.read_csv(processing_df_filename)

split = "test_" if "test" in input_df_filename else "train_"
prefix = "skill_" if "skill" in input_df_filename else ""

num_games = len(df)
assert num_games > 0, f"No games left in {processing_df_filename}; check levels_of_interest and the input data"
row_length = len(df["transcript"].iloc[0])

# The games are stacked into one rectangular array below, so every transcript must have the same length.
assert df["transcript"].apply(len).eq(row_length).all(), "Not all transcripts are of length {}".format(row_length)

df["transcript"].to_csv(
    config.board_seqs_str_file, index=False, header=False
)

logger.info(f'Number of games: {len(df)},length of a game in chars: {len(df["transcript"].iloc[0])}')

encoded_df = df["transcript"].apply(encode)
logger.info(encoded_df.head())
board_seqs_int = np.array(encoded_df.tolist())
logger.info(f"board_seqs_int shape: {board_seqs_int.shape}")
assert board_seqs_int.shape == (num_games, row_length)

np.save(config.board_seqs_int_file, board_seqs_int)

# Write the per-game skill levels whenever the config expects them; the skill-loading cell
# uses the same `config.skill_file is not None` condition.
if config.skill_file is not None:
    # player_two is expected to look like "Stockfish <level>"; keep the integer level.
    skill_levels_list = [int(x.split()[1]) for x in df["player_two"]]
    skill_level = np.array(skill_levels_list)
    logger.info(f"skill_level shape: {skill_level.shape}")
    assert skill_level.shape == (num_games,)
    np.save(config.skill_file, skill_level)

In [None]:
# Positions in each game at which the probe is trained (computed from the processed dataset).
custom_indexing_function = chess_utils.find_even_spaces_indices
indexing_function_name = custom_indexing_function.__name__

board_seqs_int = torch.tensor(np.load(config.board_seqs_int_file)).long()
logger.info(f"board_seqs_int shape: {board_seqs_int.shape}")

# One transcript per line, written without a header by the processing cell.
with open(config.board_seqs_str_file, newline='') as csvfile:
    board_seqs_string = [row[0] for row in csv.reader(csvfile, delimiter=',')]
logger.info(f"Number of games in board_seqs_string: {len(board_seqs_string)}, length of game in chars: {len(board_seqs_string[0])}")

custom_indices = chess_utils.find_custom_indices(processing_df_filename, custom_indexing_function)
custom_indices = torch.tensor(custom_indices).long()
logger.info(f"custom_indices shape: {custom_indices.shape}")

num_games, game_length = board_seqs_int.shape
num_indexed_games, shortest_length = custom_indices.shape
assert num_games == len(board_seqs_string)
assert game_length == len(board_seqs_string[0])
# Every game needs its own row of probe indices. (The previous `assert a, b == c` form was an
# assert-with-message that always passed.)
assert num_indexed_games == num_games, (
    f"custom_indices has {num_indexed_games} rows but there are {num_games} games"
)


In [None]:
custom_function = config.custom_function
skill_stack = None
test_skill = None
std = 1 # Default value if NORMALIZE_SKILL_FOR_MSE is False; also used to de-normalize the logged MAE

if config.skill_file is not None:
    skill_stack = torch.tensor(np.load(config.skill_file)).long()
    # skill_stack = (skill_stack >= 10).long() # This line can be used for binning the skill levels
    logger.info(f"Unique values in skill_stack: {skill_stack.unique()}")
    logger.info(f"skill_stack shape: {skill_stack.shape}")
    test_skill = skill_stack[0]

    if TRAIN_WITH_MSE and NORMALIZE_SKILL_FOR_MSE:
        skill_stack = skill_stack.to(dtype=torch.float32) # necessary for mean and std to be float32
        mean = skill_stack.mean()
        std = skill_stack.std()
        # A single skill level gives std == 0 and a single game gives NaN (unbiased std);
        # dividing by either would turn every target into NaN.
        if not torch.isfinite(std) or std == 0:
            logger.warning(f"skill_stack std is {std}; skipping scaling and only centering targets")
            std = torch.tensor(1.0)

        # Normalize the target values
        skill_stack = (skill_stack - mean) / std


state_stack = torch.tensor(chess_utils.create_state_stack(board_seqs_string[0], custom_function, test_skill)).long()
logger.info(f"A single state_stack shape: {state_stack.shape}")

# Sample stack over the first 50 games, used only for the sanity-check logs in the next cell.
state_stacks = chess_utils.create_state_stacks(board_seqs_string[:50], custom_function, skill_stack)
logger.info(f"state_stack shape: {state_stacks.shape}")

In [None]:
batch_size = 10
wd = 0.01
pos_start = 25 # indexes into custom_indexing_function. Example: for find_dots_indices, selects everything after the first 25 moves
one_hot_range = config.max_val - config.min_val + 1
if PROCESS_DATA:
    # With PROCESS_DATA the classes are the selected Stockfish levels, not the raw value range.
    one_hot_range = len(levels_of_interest)
num_epochs = 6
# The training loops slice fixed-size batches, so drop the trailing partial batch. TODO: support a partial last batch
num_games = (len(board_seqs_int) // batch_size) * batch_size
# Zero usable games would leave nothing to train on and make get_lr divide by max_iters == 0.
assert num_games > 0, f"Need at least batch_size={batch_size} games to train, found {len(board_seqs_int)}"
modes = 1

max_lr = 3e-4
min_lr = max_lr / 10
max_iters = num_games * num_epochs  # the LR schedule is measured in games seen
decay_lr = True


state_stack_one_hot = chess_utils.state_stack_to_one_hot(modes, config.num_rows, config.num_cols, config.min_val, config.max_val, device, state_stacks, user_state_dict_one_hot_mapping)
logger.info(f"state_stack_one_hot shape: {state_stack_one_hot.shape}\n")

# Spot-check one (game, position) slice of the sample stacks (dims: modes, game, pos, row, col);
# skipped when the sample is too small to contain it rather than raising IndexError.
SAMPLE_GAME, SAMPLE_POS = 1, 170
if state_stacks.shape[1] > SAMPLE_GAME and state_stacks.shape[2] > SAMPLE_POS:
    logger.info(f"Note: This will only be meaningful if training on board state: \n{state_stack_one_hot[:, SAMPLE_GAME, SAMPLE_POS, 4:9, 2:5]}")
    logger.info(f"Note: This will only be meaningful if training on board state: \n{state_stacks[:, SAMPLE_GAME, SAMPLE_POS, 4:9, 2:5]}")

In [None]:
if not TRAIN_WITH_MSE:

    def get_lr(current_iter: int, max_iters: int, max_lr: float, min_lr: float) -> float:
        """
        Calculate the learning rate using linear decay.

        Args:
        - current_iter (int): The current iteration (this notebook counts games seen).
        - max_iters (int): The total number of iterations for decay.
        - max_lr (float): The initial learning rate.
        - min_lr (float): The minimum learning rate after decay.

        Returns:
        - float: The calculated learning rate (max_lr if max_iters is not positive).
        """
        if max_iters <= 0:
            return max_lr
        # Ensure current_iter does not exceed max_iters
        current_iter = min(current_iter, max_iters)

        # Calculate the linearly decayed learning rate
        return max_lr - (max_lr - min_lr) * (current_iter / max_iters)

    def _select_state_positions(state_stack, games_dots):
        """Pick each game's probed positions from a (modes, batch, pos, row, col) state stack.

        Returns a (modes, batch, n_probe_pos, row, col) tensor.
        """
        per_game = [state_stack[:, batch_idx, games_dots[batch_idx], :, :] for batch_idx in range(games_dots.size(0))]
        return einops.rearrange(torch.stack(per_game), 'batch modes pos row col -> modes batch pos row col')

    def _select_resid_positions(resid_post, games_dots):
        """Pick each game's probed positions from the residual stream -> (batch, n_probe_pos, d_model)."""
        return torch.stack([resid_post[batch_idx, games_dots[batch_idx]] for batch_idx in range(games_dots.size(0))])

    def train_linear_probe(layer: int):
        """Train a one-hot (cross-entropy) linear probe on the residual stream after `layer` and save it.

        The probe maps d_model to (num_rows, num_cols, one_hot_range) logits per mode. The checkpoint
        is written to f"{MODEL_DIR}{model_name}_{config.linear_probe_name}_layer_{layer}.pth".
        """
        linear_probe_name = f"{MODEL_DIR}{model_name}_{config.linear_probe_name}_layer_{layer}.pth"
        linear_probe = torch.randn(
            modes, model.cfg.d_model, config.num_rows, config.num_cols, one_hot_range, requires_grad=False, device=device
        )/np.sqrt(model.cfg.d_model)
        linear_probe.requires_grad = True
        logger.info(f"linear_probe shape: {linear_probe.shape}")
        lr = max_lr
        optimiser = torch.optim.AdamW([linear_probe], lr=lr, betas=(0.9, 0.99), weight_decay=wd)

        logger.info(f"custom_indices shape: {custom_indices.shape}")

        # Value range handed to state_stack_to_one_hot. With a PROCESS_DATA mapping the classes are
        # the levels_of_interest (one_hot_range == len(levels_of_interest)); otherwise use the
        # config's range, exactly as one_hot_range was derived, so probe_out and the targets agree.
        if user_state_dict_one_hot_mapping is not None:
            one_hot_min_val, one_hot_max_val = 1, len(levels_of_interest)
        else:
            one_hot_min_val, one_hot_max_val = config.min_val, config.max_val

        # Only this layer's residual stream is read, so don't cache every layer's activations.
        hook_name = utils.get_act_name("resid_post", layer)

        wandb_project = "chess_linear_probes"
        wandb_run_name = f"{config.linear_probe_name}_{model_name}_layer_{layer}_indexing_{indexing_function_name}"

        if wandb_logging:
            import wandb
            logging_dict = {"linear_probe_name": config.linear_probe_name, "model_name": model_name, "layer": layer,
                            "indexing_function_name": indexing_function_name,
                            "batch_size": batch_size, "max_lr": max_lr, "wd": wd, "pos_start": pos_start,
                            "num_epochs": num_epochs, "num_games": num_games, "modes": modes,
                            "one_hot_range": one_hot_range, "wandb_project": wandb_project, "wandb_run_name": wandb_run_name}
            wandb.init(project=wandb_project, name=wandb_run_name, config=logging_dict)

        current_iter = 0
        loss = 0
        accuracy = 0
        for epoch in range(num_epochs):
            full_train_indices = torch.randperm(num_games)
            for i in tqdm(range(0, num_games, batch_size)):

                lr = get_lr(current_iter, max_iters, max_lr, min_lr) if decay_lr else lr
                for param_group in optimiser.param_groups:
                    param_group['lr'] = lr

                indices = full_train_indices[i:i+batch_size]
                games_int = board_seqs_int[indices]
                games_str = [board_seqs_string[idx] for idx in indices.tolist()]
                # Skip the first pos_start probe positions of every game.
                games_dots = custom_indices[indices][:, pos_start:]
                games_skill = skill_stack[indices] if config.skill_file is not None else None

                state_stack = chess_utils.create_state_stacks(games_str, custom_function, games_skill)
                state_stack = _select_state_positions(state_stack, games_dots)
                state_stack_one_hot = chess_utils.state_stack_to_one_hot(
                    modes, config.num_rows, config.num_cols, one_hot_min_val, one_hot_max_val,
                    device, state_stack, user_state_dict_one_hot_mapping,
                ).to(device)

                # The transformer is frozen; only the probe receives gradients.
                with torch.inference_mode():
                    _, cache = model.run_with_cache(games_int.to(device)[:, :-1], return_type=None, names_filter=hook_name)
                    resid_post = cache["resid_post", layer]
                resid_post = _select_resid_positions(resid_post, games_dots)

                probe_out = einsum(
                    "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options",
                    resid_post,
                    linear_probe,
                )

                assert probe_out.shape == state_stack_one_hot.shape

                accuracy = (probe_out[0].argmax(-1) == state_stack_one_hot[0].argmax(-1)).float().mean()

                probe_log_probs = probe_out.log_softmax(-1)
                probe_correct_log_probs = einops.reduce(
                    probe_log_probs * state_stack_one_hot,
                    "modes batch pos rows cols options -> modes pos rows cols",
                    "mean"
                ) * one_hot_range # Multiply to correct for the mean over one_hot_range
                loss = -probe_correct_log_probs[0, :].mean(0).sum()

                loss.backward()
                if i % 100 == 0:
                    logger.info(f"epoch {epoch}, batch {i}, acc {accuracy}, loss {loss}, lr {lr}")
                    if wandb_logging:
                        wandb.log({"acc": accuracy,
                                "loss": loss,
                                "lr": lr,
                                "epoch": epoch,
                                "iter": current_iter})

                optimiser.step()
                optimiser.zero_grad()
                current_iter += batch_size

        checkpoint = {
            "linear_probe": linear_probe,
            "layer": layer,
            "config_name": config.linear_probe_name,
            "final_loss": loss,
            "model_name": model_name,
            "iters": current_iter,
            "epochs": epoch,
            "acc": accuracy,
        }
        torch.save(checkpoint, linear_probe_name)

    # Probe every other layer.
    for layer in range(0, n_layers, 2):
        train_linear_probe(layer)
    # train_linear_probe(12)

In [None]:
if TRAIN_WITH_MSE:

    mse_loss_function = MSELoss()
    mae_loss_function = L1Loss()

    def get_lr(current_iter: int, max_iters: int, max_lr: float, min_lr: float) -> float:
        """
        Calculate the learning rate using linear decay.

        Args:
        - current_iter (int): The current iteration (this notebook counts games seen).
        - max_iters (int): The total number of iterations for decay.
        - max_lr (float): The initial learning rate.
        - min_lr (float): The minimum learning rate after decay.

        Returns:
        - float: The calculated learning rate (max_lr if max_iters is not positive).
        """
        if max_iters <= 0:
            return max_lr
        # Ensure current_iter does not exceed max_iters
        current_iter = min(current_iter, max_iters)

        # Calculate the linearly decayed learning rate
        return max_lr - (max_lr - min_lr) * (current_iter / max_iters)

    def _select_state_positions(state_stack, games_dots):
        """Pick each game's probed positions from a (modes, batch, pos, row, col) state stack.

        Returns a (modes, batch, n_probe_pos, row, col) tensor.
        """
        per_game = [state_stack[:, batch_idx, games_dots[batch_idx], :, :] for batch_idx in range(games_dots.size(0))]
        return einops.rearrange(torch.stack(per_game), 'batch modes pos row col -> modes batch pos row col')

    def _select_resid_positions(resid_post, games_dots):
        """Pick each game's probed positions from the residual stream -> (batch, n_probe_pos, d_model)."""
        return torch.stack([resid_post[batch_idx, games_dots[batch_idx]] for batch_idx in range(games_dots.size(0))])

    def train_linear_probe(layer: int):
        """Train a regression (MSE) linear probe on the residual stream after `layer` and save it.

        The probe maps d_model to one value per (row, col) cell per mode. The checkpoint is written to
        f"{MODEL_DIR}{model_name}_{config.linear_probe_name}_layer_{layer}.pth".
        """
        linear_probe_name = f"{MODEL_DIR}{model_name}_{config.linear_probe_name}_layer_{layer}.pth"
        linear_probe = torch.randn(
            modes, model.cfg.d_model, config.num_rows, config.num_cols, requires_grad=False, device=device
        )/np.sqrt(model.cfg.d_model)
        linear_probe.requires_grad = True
        logger.info(f"linear_probe shape: {linear_probe.shape}")
        lr = max_lr
        optimiser = torch.optim.AdamW([linear_probe], lr=lr, betas=(0.9, 0.99), weight_decay=wd)

        logger.info(f"custom_indices shape: {custom_indices.shape}")

        # Only this layer's residual stream is read, so don't cache every layer's activations.
        hook_name = utils.get_act_name("resid_post", layer)

        wandb_project = "chess_linear_probes_mse"
        wandb_run_name = f"{config.linear_probe_name}_{model_name}_layer_{layer}_indexing_{indexing_function_name}"

        if wandb_logging:
            import wandb
            logging_dict = {"linear_probe_name": config.linear_probe_name, "model_name": model_name, "layer": layer,
                            "indexing_function_name": indexing_function_name,
                            "batch_size": batch_size, "max_lr": max_lr, "wd": wd, "pos_start": pos_start,
                            "num_epochs": num_epochs, "num_games": num_games, "modes": modes,
                            "one_hot_range": one_hot_range, "wandb_project": wandb_project, "wandb_run_name": wandb_run_name}
            wandb.init(project=wandb_project, name=wandb_run_name, config=logging_dict)

        current_iter = 0
        loss = 0
        acc_blank = 0  # kept for checkpoint compatibility; no accuracy is computed for a regression probe
        final_mae = None  # MAE (in normalized units) of the last training batch
        for epoch in range(num_epochs):
            full_train_indices = torch.randperm(num_games)
            for i in tqdm(range(0, num_games, batch_size)):

                lr = get_lr(current_iter, max_iters, max_lr, min_lr) if decay_lr else lr
                for param_group in optimiser.param_groups:
                    param_group['lr'] = lr

                indices = full_train_indices[i:i+batch_size]
                games_int = board_seqs_int[indices]
                games_str = [board_seqs_string[idx] for idx in indices.tolist()]
                # Skip the first pos_start probe positions of every game.
                games_dots = custom_indices[indices][:, pos_start:]
                games_skill = skill_stack[indices] if config.skill_file is not None else None

                state_stack = chess_utils.create_state_stacks(games_str, custom_function, games_skill)
                # Regression targets: float32 on the same device as the probe output.
                state_stack = _select_state_positions(state_stack, games_dots).to(device, dtype=torch.float32)

                # The transformer is frozen; only the probe receives gradients.
                with torch.inference_mode():
                    _, cache = model.run_with_cache(games_int.to(device)[:, :-1], return_type=None, names_filter=hook_name)
                    resid_post = cache["resid_post", layer]
                resid_post = _select_resid_positions(resid_post, games_dots)

                probe_out = einsum(
                    "batch pos d_model, modes d_model rows cols -> modes batch pos rows cols",
                    resid_post,
                    linear_probe,
                )

                assert probe_out.shape == state_stack.shape

                mse_loss = mse_loss_function(probe_out, state_stack)
                mae_loss = mae_loss_function(probe_out, state_stack)
                # Undo target normalization so the MAE reads in original skill units (std is 1 if not normalized).
                mae_loss_denormalized = mae_loss * std
                final_mae = mae_loss.detach()

                if i % 200 == 0:
                    logger.info(f"epoch {epoch}, batch {i}, mae loss {mae_loss.item()}, mae denorm loss {mae_loss_denormalized.item()}, mse loss {mse_loss.item()}, lr {lr}")
                    if wandb_logging:
                        wandb.log({"acc": mae_loss.item(),
                                "loss": mse_loss.item(),
                                "mae_denorm_loss": mae_loss_denormalized.item(),
                                "lr": lr,
                                "epoch": epoch,
                                "iter": current_iter})

                # Only the MSE term is optimized; MAE is tracked for reporting.
                loss = mse_loss
                loss.backward()

                optimiser.step()
                optimiser.zero_grad()
                current_iter += batch_size

        checkpoint = {
            "linear_probe": linear_probe,
            "layer": layer,
            "config_name": config.linear_probe_name,
            "final_loss": loss,
            "model_name": model_name,
            "iters": current_iter,
            "epochs": epoch,
            "acc": acc_blank,
            "final_mae_loss": final_mae.item() if final_mae is not None else None,
        }
        torch.save(checkpoint, linear_probe_name)

    # Probe every other layer.
    for layer in range(0, n_layers, 2):
        train_linear_probe(layer)
    # train_linear_probe(10)