## Train Othello-GPT and save to `ckpts`

Use `jupyter nbconvert --execute --to notebook --allow-errors --ExecutePreprocessor.timeout=-1 train_gpt_othello.ipynb --inplace --output ckpts/checkpoint.ipynb` to run in background

In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
# make deterministic
from mingpt.utils import set_seed
set_seed(44)

In [5]:
import os
import math
import time
from tqdm import tqdm
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from torch.nn import functional as F
from data import get_othello
from data.othello import permit, start_hands, OthelloBoardState, permit_reverse
from mingpt.dataset import CharDataset
from mingpt.utils import sample
from mingpt.model import GPT, GPTConfig
from mingpt.trainer import Trainer, TrainerConfig

In [None]:
synthetic_or_championship = True  # True for training on the synthetic dataset

In [63]:
class RPSData:
    def __init__(self, data_path):
        self.sequences = self.get_sequences(data_path)
    def get_sequences(self, data_path):
        with open(data_path, "r") as f:
            data = f.readlines()
        return [s.replace("\n", "").split(" ") for s in data]
    def __len__(self):
        return len(self.sequences)
    def __getitem__(self, idx):
        return self.sequences[idx]

In [67]:
data_path = "/Users/davidsewell/Projects/open_spiel/notebook/rps_tourney.txt"

rps_data = RPSData(data_path)

train_dataset = CharDataset.load_rps_data(rps_data)
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, n_layer=8, n_head=8, n_embd=128)
model = GPT(mconf)

Dataset created has 17792 sequences, 53 unique words.


In [68]:
#othello = get_othello(ood_num=-1, data_root=None if synthetic_or_championship else "data/othello_championship", wthor=True)
#train_dataset = CharDataset(othello)
#mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, n_layer=8, n_head=8, n_embd=512)
#model = GPT(mconf)

In [None]:
max_epochs = 250
# initialize a trainer instance and kick off training
t_start = time.strftime("_%Y%m%d_%H%M%S")
tconf = TrainerConfig(
    max_epochs=max_epochs, 
    batch_size=512*1,  # assuming 8 GPU's
    learning_rate=5e-4,
    lr_decay=True, 
    warmup_tokens=len(train_dataset)*train_dataset.block_size*5, 
    final_tokens=len(train_dataset)*train_dataset.block_size*max_epochs,
    num_workers=0, 
    ckpt_path=f"./ckpts/gpt_at{t_start}.ckpt", 
)
trainer = Trainer(model, train_dataset, None, tconf)
device = trainer.device
print(t_start)
trainer.train()

_20230526_220525


epoch 1 iter 34: train loss 2.51305. lr 1.000000e-04: 100%|█████████████████████████████████████████| 35/35 [35:48<00:00, 61.38s/it]
epoch 2 iter 34: train loss 1.73731. lr 2.000000e-04: 100%|█████████████████████████████████████████| 35/35 [16:14<00:00, 27.85s/it]
epoch 3 iter 34: train loss 1.18844. lr 3.000000e-04: 100%|█████████████████████████████████████████| 35/35 [16:20<00:00, 28.01s/it]
epoch 4 iter 34: train loss 1.06038. lr 4.000000e-04: 100%|█████████████████████████████████████████| 35/35 [16:21<00:00, 28.05s/it]
epoch 5 iter 34: train loss 1.05481. lr 5.000000e-04: 100%|█████████████████████████████████████████| 35/35 [16:13<00:00, 27.82s/it]
epoch 6 iter 34: train loss 0.97667. lr 4.999794e-04: 100%|█████████████████████████████████████████| 35/35 [16:02<00:00, 27.49s/it]
epoch 7 iter 34: train loss 0.98106. lr 4.999178e-04: 100%|█████████████████████████████████████████| 35/35 [15:54<00:00, 27.27s/it]
epoch 8 iter 34: train loss 0.96858. lr 4.998150e-04: 100%|██████████

## Or load trained model from `ckpts`

In [42]:
load_res = model.load_state_dict(torch.load("./ckpts/gpt_at_20230526_170210.ckpt"))
if torch.cuda.is_available():
    device = torch.cuda.current_device()
    model = model.to(device)

In [53]:
def get_sample(model, data,idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= data.block_size else idx[:, -data.block_size:]
        # forward the model to get the logits for the index in the sequence
        logits, _ = model(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx


def generate(model, data, prompt='', num_samples=1, steps=100, do_sample=True):
        
    dix = [train_dataset.stoi[s] for s in prompt]
    
    x = torch.tensor(dix, dtype=torch.long)
    # we'll process all desired num_samples in a batch, so expand out the batch dim
    x = x.expand(num_samples, -1)

    # forward the model `steps` times to get samples, in a batch
    sample = get_sample(model, data, x, steps, do_sample=do_sample, top_k=None)

    return sample
    
    

In [49]:
prompt = ['rockbot', 'greenberg']

In [50]:
dix = [train_dataset.stoi[s] for s in prompt]

In [51]:
dix

[42, 24]

In [55]:
sample = generate(model, train_dataset, prompt)

In [62]:
len(sample[0])

102

In [60]:
[train_dataset.itos[int(s)] for s in sample[0]]

['rockbot',
 'greenberg',
 '0',
 '1',
 '0',
 '1',
 '0',
 '2',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1',
 '0',
 '1']