In [1]:
from dataclasses import dataclass
from pathlib import Path

import huggingface_hub as hf
import numpy as np
import torch as t
import wandb
from datasets import load_dataset
from jaxtyping import Float, Int
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from othello_gpt.data.vis import plot_game
from othello_gpt.model.nanoGPT import GPT, GPTConfig
from othello_gpt.util import pad_batch, get_all_squares

In [2]:
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")
device

device(type='mps')

In [3]:
root_dir = Path().cwd().parent.parent.parent
data_dir = root_dir / "data"
n_games = 1000000
size = 6

In [4]:
hf.login(token=(root_dir / "secret.txt").read_text())
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33malfredwong[0m ([33malfredwong-university-of-cambridge[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
dataset_dict = load_dataset("awonga/othello-gpt")
plot_game(dataset_dict["test"][0], subplot_size=180, n_cols=8)

Resolving data files:   0%|          | 0/42 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/42 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/42 [00:00<?, ?it/s]

In [10]:
class HubGPT(GPT, hf.PyTorchModelHubMixin):
    pass

cfg = GPTConfig(
    # block_size=(size * size - 4) * 2 - 1,
    block_size=(size * size - 4) - 1,
    # vocab_size=size * size - 4 + 2,  # pass and pad
    vocab_size=size * size - 4 + 1,  # pad
    n_layer=2,
    n_head=2,
    n_embd=128,
    dropout= 0.0,
    bias=False,
)
display(cfg)
model = HubGPT(cfg).to(device)

GPTConfig(block_size=31, vocab_size=33, n_layer=2, n_head=2, n_embd=128, dropout=0.0, bias=False)

number of parameters: 0.40M


In [11]:
@dataclass
class TransformerTrainingArgs:
    batch_size: int = 1024
    epochs: int = 32
    max_steps_per_epoch: int = 1000
    lr: int = 5e-4
    weight_decay: int = 1e-3
    wandb_project: str | None = "othello-gpt"
    wandb_name: str | None = None

args = TransformerTrainingArgs()

In [12]:
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: GPT):
        super().__init__()
        self.model = model
        self.args = args

        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0

        def collate_fn(batch):
            return pad_batch(batch, model.config.block_size + 1)

        self.train_loader = DataLoader(dataset_dict["train"]["input_ids"], batch_size=args.batch_size, shuffle=True, pin_memory=True, collate_fn=collate_fn)
        self.test_loader = DataLoader(dataset_dict["test"]["input_ids"], batch_size=args.batch_size, shuffle=False, pin_memory=True, collate_fn=collate_fn)

    def training_step(self, batch: Int[Tensor, "batch seq"]) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        """
        _, loss = model(batch[:, :-1], batch[:, 1:])
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        wandb.log({"train_loss": loss}, step=self.step)
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Evaluate the model on the test set and return the accuracy.
        """
        self.model.eval()
        total_correct, total_samples = 0, 0

        for batch in tqdm(self.test_loader, desc="Evaluating"):
            batch = batch.to(device)
            logits, _ = self.model(batch[:, :-1], batch[:, 1:])
            predicted_tokens = logits.argmax(dim=-1)
            total_correct += (predicted_tokens == batch[:, 1:]).sum().item()
            total_samples += batch.size(0) * (batch.size(1) - 1)

        accuracy = total_correct / total_samples
        wandb.log({"accuracy": accuracy}, step=self.step)
        return accuracy

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        config_dict = model.config.__dict__.copy()
        config_dict.update(args.__dict__)
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=config_dict)
        accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        for epoch in range(self.args.epochs):
            for i, batch in enumerate(self.train_loader):
                loss = self.training_step(batch.to(device))
                progress_bar.update()
                progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")
                if i >= self.args.max_steps_per_epoch:
                    break

            accuracy = self.evaluate()

        wandb.finish()

trainer = TransformerTrainer(args, model)
trainer.train()

Evaluating: 100%|██████████| 196/196 [00:07<00:00, 25.21it/s]000 [01:24<44:19, 11.66it/s]
Evaluating: 100%|██████████| 196/196 [00:06<00:00, 28.39it/s]32000 [02:57<40:55, 12.22it/s]   
Evaluating: 100%|██████████| 196/196 [00:07<00:00, 26.91it/s]32000 [04:30<41:43, 11.58it/s]  
Evaluating: 100%|██████████| 196/196 [00:07<00:00, 26.19it/s]32000 [06:05<41:06, 11.35it/s]  
Evaluating: 100%|██████████| 196/196 [00:06<00:00, 28.41it/s]32000 [07:42<37:32, 11.98it/s]  
Evaluating: 100%|██████████| 196/196 [00:06<00:00, 28.24it/s]32000 [09:13<34:58, 12.39it/s]  
Evaluating: 100%|██████████| 196/196 [00:06<00:00, 28.60it/s]32000 [10:44<34:21, 12.13it/s]  
Evaluating: 100%|██████████| 196/196 [00:06<00:00, 28.98it/s]32000 [12:14<33:30, 11.93it/s]  
Evaluating: 100%|██████████| 196/196 [00:06<00:00, 29.64it/s]32000 [13:45<31:42, 12.08it/s]  
Evaluating: 100%|██████████| 196/196 [00:06<00:00, 28.20it/s]0/32000 [15:15<30:16, 12.10it/s] 
Evaluating: 100%|██████████| 196/196 [00:06<00:00, 29.02it/s]1

0,1
accuracy,▁▂▃▄▄▄▅▅▆▆▆▇▇▇▇▇▇▇▇█████████████
train_loss,█▆▆▅▅▅▅▅▅▅▅▅▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.25983
train_loss,1.69932



Run (q2wl9r7m) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.

Epoch 32, loss: 1.699, accuracy: 0.260: : 32032it [48:09, 11.08it/s]


In [13]:
model.push_to_hub("awonga/othello-gpt")

model.safetensors:   0%|          | 0.00/1.61M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/awonga/othello-gpt/commit/3ea2249da9a197bdb44adfba3150a612034f4207', commit_message='Push model using huggingface_hub.', commit_description='', oid='3ea2249da9a197bdb44adfba3150a612034f4207', pr_url=None, repo_url=RepoUrl('https://huggingface.co/awonga/othello-gpt', endpoint='https://huggingface.co', repo_type='model', repo_id='awonga/othello-gpt'), pr_revision=None, pr_num=None)

In [14]:
n_focus = 50
focus_games = dataset_dict["test"].take(n_focus)
focus_input_ids = pad_batch(focus_games["input_ids"], max_len=cfg.block_size + 1).to(
    device
)
focus_logits, loss = model(focus_input_ids[:, :-1], focus_input_ids[:, 1:])
focus_logit_boards = t.full((n_focus, focus_logits.shape[1], size, size), 0.0)
focus_logit_boards.flatten(2)[..., get_all_squares(size)] = focus_logits[..., 1:].detach().cpu()

In [15]:
test_index = 0
test_pred_model = {
    "boards": focus_logit_boards[test_index].detach().cpu(),
    "legalities": focus_games[test_index]["legalities"],
    "moves": focus_games[test_index]["moves"],
}

plot_game(focus_games[test_index], title="Ground truth board states and legal moves")
plot_game(
    test_pred_model,
    reversed=False,
    textcolor="red",
    hovertext=test_pred_model["boards"],
    title="Model predictions for legal moves",
)