In [1]:
import math
import os
import pickle

import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

import EWOthello.utils.plot_helpers as plt_util
from EWOthello.data.othello import *
from EWOthello.mingpt.dataset import ProbingDataset, CharDataset # AK's mingpt data child 
from EWOthello.mingpt.model import GPT, GPTConfig, GPTforProbing, GPTforProbing_v2
from EWOthello.mingpt.utils import set_seed
set_seed(44)

# Pick the compute device. Without this guard, torch.cuda.current_device()
# raises on a machine with no visible GPU, so the notebook could not even be
# opened for inspection there. Fall back to the CPU instead.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device("cuda", torch.cuda.current_device())
    print(torch.cuda.get_device_name(device))
else:
    device = torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


True
NVIDIA GeForce RTX 3090 Ti


In [3]:
def train_GPT_Othello(game_dataset, n_layers, n_heads, batch_size=64, num_epochs=100, train_to=120000, train_ratio=0.8, val_legal_stepsize=10, num_val=500, save_at_steps=20, learning_rate=1e-4, lr_schedule=False, verbose=False):
    model_name = f"GPT_Synthetic_{n_layers}Layers_{n_heads}Heads"
    savepath = f"../EWOthello/ckpts/Dean_GPTv2_Synthetic_{n_layers}L{n_heads}H/"
    mconf = GPTConfig(vocab_size=61, block_size=59, n_layer=n_layers, n_head=n_heads, n_embd=512)
    model = GPT(mconf)
    model = model.to(device)

    training_data = game_dataset
    train_size = int(train_ratio * len(training_data))
    test_size = len(training_data) - train_size - num_val
    val_size = num_val

    train_dataset, test_dataset, val_dataset = random_split(training_data, [train_size, test_size, val_size])
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True)
    test_iter = iter(test_dataloader)

    warm_up_tokens = train_size*59
    warm_up_max = train_size*59*num_epochs
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    training_loss_history = []
    testing_loss_history = []
    perc_legal_games = []
    tokens_processed = 0
    num_steps = 0 if len(training_loss_history)==0 else len(training_loss_history)

    print(f"Training Model Name: {model_name}; training/test set size {train_size}/{test_size}; validation games {val_size}")
    if not os.path.exists(savepath):
        os.mkdir(savepath)

    if os.path.exists(savepath + model_name + ".ckpt"):
        model.load_state_dict(torch.load(savepath + model_name + ".ckpt"))
        print(f"Loaded model checkpopint from {savepath + model_name + '.ckpt'}")
        with open(savepath + model_name + ".pickle", 'rb') as fhandle:
            training_history = pickle.load(fhandle)
            training_loss_history = training_history["training_loss"]
            testing_loss_history = training_history["testing_loss"]
            perc_legal_games = training_history["val_legal_perc"]
            tokens_processed = training_history["tokens_processed"]

    num_steps = 0 if len(training_loss_history)==0 else len(training_loss_history)
    print(f"Training steps Current: {num_steps}")

    for epoch in range(num_epochs):
        i = 0
        for (x,y) in tqdm(train_dataloader):
            if num_steps > train_to:
                break

            # Run update training step SGD
            model.train()
            x = x.to(device)
            y = y.to(device)
            
            logits, loss = model(x, y)
            train_loss = loss.item()
            training_loss_history.append(train_loss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update the learning rate if using non-fixed scheduler
            if lr_schedule:
                tokens_processed += (y>=0).sum()
                if tokens_processed < warm_up_tokens:
                    lr_mult = tokens_processed / warm_up_tokens
                else:
                    progress = (tokens_processed - warm_up_tokens) / (warm_up_max-warm_up_tokens)
                    lr_mult = max(0.1, 0.5*(1 + math.cos(math.pi*progress)))
                for param_group in optimizer.param_groups:
                    param_group["lr"] = learning_rate * lr_mult

            ## Compute the error on test batch
            model.eval()
            with torch.no_grad():
                try:
                    x,y = next(test_iter)
                except:
                    test_iter = iter(test_dataloader)
                    x,y = next(test_iter)
                x = x.to(device)
                y = y.to(device)        
                logits, loss = model(x,y)
                test_loss = loss.item()
                testing_loss_history.append(test_loss)

            ## After a certain number of steps, calculate percent of legal moves the model plays
            if i % val_legal_stepsize == 0:
                legal_moves_played = 0
                for _, (x,y) in enumerate(val_dataloader):
                    x = x.to(device)
                    logits,_ = model(x)
                    moves = torch.argmax(logits, dim=2)[0]
                    moves = moves.detach().cpu().numpy()
                    x = x.detach().cpu().numpy()
                    for len_partial in range(59):
                        partial_x = list(x[0,:len_partial+1])
                        partial_x.append(moves[len_partial])
                        game_string = [training_data.itos[int(move_idx)] for move_idx in partial_x]
                        try:
                            OthelloBoardState().update(game_string, prt=False)
                        except Exception:
                            pass
                        else:
                            legal_moves_played +=1        
                perc_legal_games.append(legal_moves_played/val_size/59)
            else:
                perc_legal_games.append(None)
        
            ## Save/print
            num_steps +=1
            i = i + 1
            if (i+1) % save_at_steps == 0:
                #print("Saving Model Checkpoint")
                torch.save(model.state_dict(), savepath + model_name + ".ckpt")
                training_history = {"training_loss": training_loss_history, "testing_loss": testing_loss_history, "val_legal_perc": perc_legal_games, "tokens_processed": tokens_processed}
                with open(savepath + model_name + ".pickle", 'wb') as fhandle:
                    pickle.dump(training_history, fhandle)
        
        # Save after each epoch also
        torch.save(model.state_dict(), savepath + model_name + ".ckpt")
        training_history = {"training_loss": training_loss_history, "testing_loss": testing_loss_history, "val_legal_perc": perc_legal_games, "tokens_processed": tokens_processed}
        with open(savepath + model_name + ".pickle", 'wb') as fhandle:
            pickle.dump(training_history, fhandle)

    return

In [4]:
# Load the synthetic game corpus and wrap it as a token dataset.
# num_preload=200 caps how many of the available shard files are read
# (the saved output reports "Max num files: 230; Use_num: 200").
othello = get(
    ood_num=-1,
    data_root=None,
    num_preload=200,
)
game_dataset = CharDataset(othello)

Max num files: 230; Use_num: 200
['gen10e5__20220324_165952.pickle', 'gen10e5__20220324_154919.pickle', 'gen10e5__20220324_164123.pickle', 'gen10e5__20220324_154043.pickle', 'gen10e5__20220324_155251.pickle', 'gen10e5__20220324_160016.pickle', 'gen10e5__20220324_165748.pickle', 'gen10e5__20220324_154002.pickle', 'gen10e5__20220324_155241.pickle', 'gen10e5__20220324_165707.pickle', 'gen10e5__20220324_160046.pickle', 'gen10e5__20220324_154811.pickle', 'gen10e5__20220324_154806.pickle', 'gen10e5__20220324_162637.pickle', 'gen10e5__20220324_154048.pickle', 'gen10e5__20220324_155439.pickle', 'gen10e5__20220324_155255.pickle', 'gen10e5__20220324_154235.pickle', 'gen10e5__20220324_160049.pickle', 'gen10e5__20220324_154032.pickle', 'gen10e5__20220324_164213.pickle', 'gen10e5__20220324_155245.pickle', 'gen10e5__20220324_154722.pickle', 'gen10e5__20220324_165841.pickle', 'gen10e5__20220324_162202.pickle', 'gen10e5__20220324_154533.pickle', 'gen10e5__20220324_164648.pickle', 'gen10e5__20220324_17

Mem Used: 12.78 GB: 100%|██████████| 200/200 [00:52<00:00,  3.81it/s]


Deduplicating...
Deduplicating finished with 19996732 games left
Using 20 million for training, 0 for validation
Dataset created has 19996732 sequences, 61 unique words.


In [5]:
# Train (or resume) the 8-layer, 8-head model for at most 120k optimizer steps.
for num_layers in (8,):
    train_GPT_Othello(
        game_dataset,
        n_layers=num_layers,
        n_heads=8,
        batch_size=512,
        learning_rate=1e-4,
        lr_schedule=False,
        num_epochs=4,
        train_to=120000,
        val_legal_stepsize=1000,
        save_at_steps=1000,
    )

Training Model Name: GPT_Synthetic_8Layers_8Heads; training/test set size 15997385/3998847; validation games 500
Loaded model checkpopint from ../EWOthello/ckpts/Dean_GPTv2_Synthetic_8L8H/GPT_Synthetic_8Layers_8Heads.ckpt
Training steps Current: 84112


  0%|          | 84/31245 [00:53<3:46:01,  2.30it/s] 

In [None]:
# Depth sweep with single-head attention; each run trains (or resumes) its own
# checkpoint under identical data and optimisation settings.
# NOTE(review): the saved output below lists 8/4/2/1-layer runs and a different
# dataset size, so it predates this cell's current contents. Re-run to refresh.
for num_layers in (1, 4, 8):
    train_GPT_Othello(
        game_dataset,
        n_layers=num_layers,
        n_heads=1,
        batch_size=512,
        learning_rate=1e-4,
        lr_schedule=False,
        num_epochs=5,
        train_to=120000,
        val_legal_stepsize=1000,
        save_at_steps=1000,
    )

Training Model Name: GPT_Synthetic_8Layers_1Heads; training/test set size 7998824/1999206; validation games 500


100%|██████████| 15623/15623 [1:54:29<00:00,  2.27it/s]  


Training Model Name: GPT_Synthetic_4Layers_1Heads; training/test set size 7998824/1999206; validation games 500


100%|██████████| 15623/15623 [1:03:11<00:00,  4.12it/s]


Training Model Name: GPT_Synthetic_2Layers_1Heads; training/test set size 7998824/1999206; validation games 500


100%|██████████| 15623/15623 [38:01<00:00,  6.85it/s]  


Training Model Name: GPT_Synthetic_1Layers_1Heads; training/test set size 7998824/1999206; validation games 500


100%|██████████| 15623/15623 [25:21<00:00, 10.27it/s]  
