In [6]:
import torch
from EWOthello.data.othello import *

from EWOthello.mingpt.dataset import CharDataset # AK's mingpt 
from EWOthello.mingpt.model import GPT, GPTConfig
from EWOthello.mingpt.trainer import Trainer, TrainerConfig
from EWOthello.mingpt.utils import set_seed, sample

set_seed(44)

### Get some intuition about training times

In [7]:
othello = get(ood_num=-1, data_root=None)
train_dataset = CharDataset(othello) 

# Example of the training data pair (blocked results format)
# Note that values is not the board number but is the board number converted to dictionary index!
x, y = train_dataset[5]
print(x)
print(y)

Max num files: 230; Use_num: 10
['gen10e5__20220324_165952.pickle', 'gen10e5__20220324_154919.pickle', 'gen10e5__20220324_164123.pickle', 'gen10e5__20220324_154043.pickle', 'gen10e5__20220324_155251.pickle', 'gen10e5__20220324_160016.pickle', 'gen10e5__20220324_165748.pickle', 'gen10e5__20220324_154002.pickle', 'gen10e5__20220324_155241.pickle', 'gen10e5__20220324_165707.pickle']


Mem Used: 2.324 GB: 100%|██████████| 10/10 [00:01<00:00,  5.48it/s]


Deduplicating...
Deduplicating finished with 999947 games left
Using 20 million for training, 0 for validation
Dataset created has 999947 sequences, 61 unique words.
tensor([20, 19, 18, 10,  2,  1, 27,  3, 41, 26, 25, 21, 11, 42, 22, 14, 34, 17,
        13, 23, 50, 39, 33, 43, 36, 31, 28, 51, 15, 12,  9, 35, 30,  8, 47, 16,
        40, 48, 32, 46, 60, 49, 57, 55, 29,  5, 45, 38, 37, 58, 24, 59, 52, 54,
        44, 53,  6,  7,  4])
tensor([19, 18, 10,  2,  1, 27,  3, 41, 26, 25, 21, 11, 42, 22, 14, 34, 17, 13,
        23, 50, 39, 33, 43, 36, 31, 28, 51, 15, 12,  9, 35, 30,  8, 47, 16, 40,
        48, 32, 46, 60, 49, 57, 55, 29,  5, 45, 38, 37, 58, 24, 59, 52, 54, 44,
        53,  6,  7,  4, 56])


In [8]:
#mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, n_layer=8, n_head=8, n_embd=512)
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, n_layer=2, n_head=8, n_embd=512)
model = GPT(mconf)

In [10]:
# On the full synthetic dataset, we can estimate an order of 3+ hours on single gpu with settings below
# It would be reasonable to train on cluster with these settings and this architecture and we could likely train 
# the smaller real world dataset locally
max_epochs = 32
savepath = "../EWOthello/ckpts/"

t_start = time.strftime("_%Y%m%d_%H%M%S")
tconf = TrainerConfig(
    max_epochs=max_epochs, 
    batch_size=1024, 
    learning_rate=5e-4,
    lr_decay=True, 
    warmup_tokens=len(train_dataset)*train_dataset.block_size*5, 
    final_tokens=len(train_dataset)*train_dataset.block_size*max_epochs,
    num_workers=0, 
    ckpt_path=savepath + f"gpt_at{t_start}.ckpt", 
)
trainer = Trainer(model, train_dataset, None, tconf)
device = trainer.device
trainer.train()

Trainer on GPU


epoch 1 iter 976: train loss 2.71466. lr 1.000000e-04: 100%|██████████| 977/977 [02:53<00:00,  5.63it/s]
epoch 2 iter 976: train loss 2.53513. lr 2.000000e-04: 100%|██████████| 977/977 [02:55<00:00,  5.58it/s]
epoch 3 iter 976: train loss 2.43578. lr 3.000000e-04: 100%|██████████| 977/977 [02:54<00:00,  5.60it/s]
epoch 4 iter 976: train loss 2.37902. lr 4.000000e-04: 100%|██████████| 977/977 [02:53<00:00,  5.63it/s]
epoch 5 iter 976: train loss 2.35797. lr 5.000000e-04: 100%|██████████| 977/977 [02:56<00:00,  5.55it/s]
epoch 6 iter 976: train loss 2.32269. lr 4.983096e-04: 100%|██████████| 977/977 [02:56<00:00,  5.54it/s]
epoch 7 iter 976: train loss 2.30786. lr 4.932612e-04: 100%|██████████| 977/977 [02:54<00:00,  5.61it/s]
epoch 8 iter 976: train loss 2.30639. lr 4.849232e-04: 100%|██████████| 977/977 [02:56<00:00,  5.54it/s]
epoch 9 iter 976: train loss 2.29102. lr 4.734082e-04: 100%|██████████| 977/977 [02:55<00:00,  5.57it/s]
epoch 10 iter 976: train loss 2.27887. lr 4.588720e-04:

In [11]:
val_dat = othello.val
print(len(val_dat))

total_nodes=0
success_nodes = 0
bar = tqdm(val_dat[:1000])
for game in bar:
    len_game = len(game)
    for len_partial_game in range(1, len_game):
        total_nodes += 1
        context = game[:len_partial_game]
        x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None, ...].to(device)
        y = sample(model, x, 1, temperature=1.0)[0]
        completion = [train_dataset.itos[int(i)] for i in y]
        #completion = [train_dataset.itos[int(i)] for i in y if i != -1]
        try:
            OthelloBoardState().update(completion, prt=False)
        except Exception:
            pass
        else:
            success_nodes += 1
    bar.set_description(f"{success_nodes/total_nodes*100:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")
    
print(f"{success_nodes/total_nodes*100:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")
 

0


0it [00:00, ?it/s]


ZeroDivisionError: division by zero

### Probe loaded checkpoint

In [2]:
othello = get(ood_num=100, data_root=None)
train_dataset = CharDataset(othello) 

100%|██████████| 100/100 [00:00<00:00, 364.99it/s]

Dataset created has 100 sequences, 61 unique words.





In [3]:
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, n_layer=8, n_head=8, n_embd=512)
model = GPT(mconf)
load_res = model.load_state_dict(torch.load("../EWOthello/ckpts/gpt_synthetic.ckpt"))
if torch.cuda.is_available():
    device = torch.cuda.current_device()
    model = model.to(device)

In [43]:
val_dat = othello.val
print(len(val_dat))

total_nodes=0
success_nodes = 0
bar = tqdm(val_dat[:1000])
for game in bar:
    len_game = len(game)
    for len_partial_game in range(1, len_game):
        total_nodes += 1
        context = game[:len_partial_game]
        x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None, ...].to(device)
        y = sample(model, x, 1, temperature=1.0)[0]
        completion = [train_dataset.itos[int(i)] for i in y]
        #completion = [train_dataset.itos[int(i)] for i in y if i != -1]
        try:
            OthelloBoardState().update(completion, prt=False)
        except Exception:
            pass
        else:
            success_nodes += 1
    bar.set_description(f"{success_nodes/total_nodes*100:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")
    
print(f"{success_nodes/total_nodes*100:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")
 

3096133


99.99% pass rate: 58940/58947 among all searched nodes: 100%|██████████| 1000/1000 [04:23<00:00,  3.80it/s]

99.99% pass rate: 58940/58947 among all searched nodes



