In [7]:
from dataclasses import dataclass
from pathlib import Path
from typing import List

import huggingface_hub as hf
import numpy as np
import torch as t
import wandb
from bidict import bidict
from datasets import DatasetDict
from jaxtyping import Float, Int
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from othello_gpt.data.generate import generate_dataset
from othello_gpt.data.vis import plot_game
from othello_gpt.model.nanoGPT import GPT, GPTConfig

In [8]:
hf.login()
wandb.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

[34m[1mwandb[0m: Currently logged in as: [33malfredwong[0m ([33malfredwong-university-of-cambridge[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")
device

device(type='mps')

In [3]:
root_dir = Path().cwd().parent.parent.parent
data_dir = root_dir / "data"
n_games = 10000
size = 6
PAD_TOKEN = -1

nw_middle_id = (size // 2 - 1) * size + (size // 2 - 1)
initial_squares = set([nw_middle_id, nw_middle_id + 1, nw_middle_id + size, nw_middle_id + size + 1])
all_squares = [i for i in range(size * size) if i not in initial_squares]
# id_to_token_id_map = bidict({square_id: token_id for token_id, square_id in enumerate([-1, size * size] + all_squares)})
id_to_token_id_map = bidict({square_id: token_id for token_id, square_id in enumerate([PAD_TOKEN] + all_squares)})

def tokenize(history):
    return {"input_ids": [id_to_token_id_map[i] for i in history]}

def decode(token_ids):
    return {"square_ids": [id_to_token_id_map.inverse[i] for i in token_ids]}

In [4]:
dataset_dict_path = data_dir / f"othello_{n_games}_{size}"

if dataset_dict_path.exists():
    dataset_dict = DatasetDict.load_from_disk(dataset_dict_path)
else:
    dataset = generate_dataset(n_games, size)
    dataset_dict = dataset.train_test_split(test_size=0.1)
    dataset_dict.save_to_disk(dataset_dict_path)

dataset_dict["test"] = dataset_dict["test"].map(lambda x: tokenize(x["moves"]))
dataset_dict["train"] = dataset_dict["train"].map(lambda x: tokenize(x["moves"]))

plot_game(dataset_dict["test"][0], subplot_size=180, n_cols=8)

Loading dataset from disk:   0%|          | 0/21 [00:00<?, ?it/s]

In [5]:
cfg = GPTConfig(
    # block_size=(size * size - 4) * 2 - 1,
    block_size=(size * size - 4) - 1,
    # vocab_size=size * size - 4 + 2,  # pass and pad
    vocab_size=size * size - 4 + 1,  # pad
    n_layer=8,
    n_head=8,
    n_embd=128,
    dropout= 0.0,
    bias=False, # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster,
)
display(cfg)
model = GPT(cfg).to(device)

GPTConfig(block_size=31, vocab_size=33, n_layer=8, n_head=8, n_embd=128, dropout=0.0, bias=False)

number of parameters: 1.58M


In [6]:
@dataclass
class TransformerTrainingArgs:
    batch_size = 32
    epochs = 8
    max_steps_per_epoch = 5120
    lr = 5e-4
    weight_decay = 1e-3
    wandb_project: str | None = "othello-gpt"
    wandb_name: str | None = None

args = TransformerTrainingArgs()

In [7]:
def pad_batch(batch: List[List[int]], max_len: int = cfg.block_size+1, pad_token: int = PAD_TOKEN) -> Int[Tensor, "batch max_len"]:
    padded_batch = t.full((len(batch), max_len), pad_token)
    for i, seq in enumerate(batch):
        padded_batch[i, -len(seq):] = t.tensor(seq)
    return padded_batch

In [8]:
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: GPT):
        super().__init__()
        self.model = model
        self.args = args

        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0

        self.train_loader = DataLoader(dataset_dict["train"]["input_ids"], batch_size=args.batch_size, shuffle=True, pin_memory=True, collate_fn=pad_batch)
        self.test_loader = DataLoader(dataset_dict["test"]["input_ids"], batch_size=args.batch_size, shuffle=False, pin_memory=True, collate_fn=pad_batch)

    def training_step(self, batch: Int[Tensor, "batch seq"]) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        """
        _, loss = model(batch[:, :-1], batch[:, 1:])
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        wandb.log({"train_loss": loss}, step=self.step)
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Evaluate the model on the test set and return the accuracy.
        """
        self.model.eval()
        total_correct, total_samples = 0, 0

        for batch in tqdm(self.test_loader, desc="Evaluating"):
            batch = batch.to(device)
            logits, _ = self.model(batch[:, :-1], batch[:, 1:])
            predicted_tokens = logits.argmax(dim=-1)
            total_correct += (predicted_tokens == batch[:, 1:]).sum().item()
            total_samples += batch.size(0) * (batch.size(1) - 1)

        accuracy = total_correct / total_samples
        wandb.log({"accuracy": accuracy}, step=self.step)
        return accuracy

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        for epoch in range(self.args.epochs):
            for i, batch in enumerate(self.train_loader):
                loss = self.training_step(batch.to(device))
                progress_bar.update()
                progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")
                if i >= self.args.max_steps_per_epoch:
                    break

            accuracy = self.evaluate()

        wandb.finish()

trainer = TransformerTrainer(args, model)
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33malfredwong[0m ([33malfredwong-university-of-cambridge[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.



'pin_memory' argument is set as true but not supported on MPS now, then device pinned memory won't be used.

Evaluating: 100%|██████████| 32/32 [00:00<00:00, 140.11it/s]960 [00:06<14:15, 47.55it/s]
Evaluating: 100%|██████████| 32/32 [00:00<00:00, 225.92it/s]40960 [00:12<14:04, 47.82it/s]
Evaluating: 100%|██████████| 32/32 [00:00<00:00, 216.73it/s]40960 [00:19<14:08, 47.25it/s]
Evaluating: 100%|██████████| 32/32 [00:00<00:00, 208.03it/s]/40960 [00:26<15:04, 44.03it/s]
Evaluating: 100%|██████████| 32/32 [00:00<00:00, 189.99it/s]/40960 [00:32<15:34, 42.32it/s]
Evaluating: 100%|██████████| 32/32 [00:00<00:00, 215.29it/s]/40960 [00:38<14:01, 46.68it/s]
Evaluating: 100%|██████████| 32/32 [00:00<00:00, 215.59it/s]/40960 [00:45<13:55, 46.68it/s]
Evaluating: 100%|██████████| 32/32 [00:00<00:00, 214.31it/s]/40960 [00:51<13:58, 46.16it/s]


0,1
accuracy,▁▃▄▅▆█▇█
train_loss,██▇▅▄▄▃▃▃▃▃▃▂▃▂▂▃▂▂▂▂▂▂▂▂▁▂▁▂▂▁▂▁▁▁▁▁▁▁▁

0,1
accuracy,0.224
train_loss,2.02685



Run (cwq3r4fv) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.

Epoch 8, loss: 2.027, accuracy: 0.220:   6%|▌         | 2256/40960 [00:56<16:06, 40.04it/s]


In [None]:
weights_dir = data_dir / "weights"
weights_dir.mkdir(exist_ok=True)
weights_path = weights_dir / f"othello_{n_games}_{size}"

weights = model.state_dict()
t.save(weights, weights_path)
weights = t.load(weights_dir / f"othello_{n_games}_{size}", weights_only=True)
model = GPT(cfg).load_state_dict(weights)

ModuleNotFoundError: No module named 'othello_gpt.model.nanoGPT.model'; 'othello_gpt.model.nanoGPT' is not a package

In [8]:
test_game = dataset_dict["test"][0]
tokens = t.tensor(tokenize(test_game["moves"])["input_ids"]).unsqueeze(0).to(device)
logits, loss = model(tokens[:, :-1], tokens[:, 1:])
probs = logits.softmax(-1)

n_moves = probs.shape[1]
test_probs = test_game.copy()
prob_boards = t.full((n_moves, size, size), fill_value=0.0, device=device)
for i in range(n_moves):
    prob_boards[i].flatten()[all_squares] = probs[0, i, 1:] / probs[0, i, 1:].sum()

test_pred = test_game.copy()
test_pred["boards"] = prob_boards.detach().cpu().numpy()
plot_game(test_game)
plot_game(test_pred, reversed=False, textcolor="pink", hovertext=test_pred["boards"])

KeyError: 'histories'

In [9]:
import plotly.graph_objects as go

logits = model(t.tensor([[9]], device=device))[0].detach().cpu()
board = t.zeros((size, size))
board.flatten()[all_squares] = logits[0, 0, 1:].softmax(-1)

fig = go.Figure()
fig.add_trace(
    go.Heatmap(
        z=board,
        colorscale="gray",
        x=list("ABCDEF"),
        y=list(range(1, 7)),
        xgap=0.2,
        ygap=0.2,
    )
)
fig.update_yaxes(
    showline=True,
    linecolor="black",
    linewidth=1,
    mirror=True,
    constrain="domain",
    autorange="reversed",
)

fig.update_xaxes(
    showline=True,
    linecolor="black",
    linewidth=1,
    mirror=True,
    scaleanchor="y",
    scaleratio=1,
    constrain="domain",
)

fig.update_layout(
    width=400,
    height=300,
    margin=dict(l=20, r=20, t=20, b=20),
)
fig.show()

NameError: name 'model' is not defined