In [22]:
import numpy as np
import torch
import torch.nn as nn
import random
from torch.nn import functional as F
from minGPT.mingpt import model
# make deterministic
from minGPT.mingpt.utils import set_seed
set_seed(42)
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [25]:
import automataBattle
from importlib import reload
reload(automataBattle)
from torch.utils.data import Dataset
class FastLearnAutomataDataset(Dataset):
    def __init__(self, nStates, nSymbols, split, sequenceLen, numSequences):
        self.nStates = nStates
        self.nSymbols = nSymbols
        self.split = split # train/test
        self.vocab_size = nSymbols*nSymbols
        # +1 due to potential carry overflow, but then -1 because very last digit doesn't plug back
        self.block_size = sequenceLen
        
        self.sequenceLen, self.numSequences = sequenceLen, numSequences
        
        '''
        # split up all addition problems into either training data or test data
        num = (10**self.ndigit)**2 # total number of possible combinations
        r = np.random.RandomState(1337) # make deterministic
        perm = r.permutation(num)
        num_test = min(int(num*0.2), 1000) # 20% of the whole dataset, or only up to 1000
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]
        '''


    def __len__(self):
        return self.numSequences

    def __getitem__(self, idx):
        
        a = automataBattle.Automata(nStates=self.nStates, symbols=range(self.nSymbols), randomConnect=True)
        a.minimize()
        while a.complexity() != self.nStates:
            a = automataBattle.Automata(nStates=self.nStates, symbols=range(self.nSymbols), randomConnect=True)
            a.minimize()
        X, Y = a.generate(self.sequenceLen, lambda: random.choice(range(self.nSymbols)))
        x = torch.tensor(X)
        y = torch.tensor(Y) # predict the output of the Automata
        previous = y[:-1]
        shiftedForwadInputsOne = x[1:]
        outputs = y[1:] # Todo: look into encoding multiple things ("tuple encodings") instead of this gross thing
        xOutput = shiftedForwadInputsOne+previous*self.nSymbols
        yOutput = outputs
        return xOutput, yOutput
        
        '''
        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx]
        nd = 10**self.ndigit
        a = idx // nd
        b = idx %  nd
        c = a + b
        render = f'%0{self.ndigit}d%0{self.ndigit}d%0{self.ndigit+1}d' % (a,b,c) # e.g. 03+25=28 becomes "0325028" 
        dix = [int(s) for s in render] # convert each character to its token index
        # x will be input to GPT and y will be the associated expected outputs
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence
        y[:self.ndigit*2-1] = -100 # we will only train in the output locations. -100 will mask loss to zero
        return x, y
        '''

In [62]:

import minGPT
from importlib import reload
from minGPT.mingpt import trainer
from minGPT.mingpt import model
reload(minGPT.mingpt.model)
reload(minGPT.mingpt.trainer)
from minGPT.mingpt.model import GPT, GPTConfig, GPT1Config
import gc
model = None
train_dataset = None
test_dataset = None
gc.collect()
torch.cuda.empty_cache()
gc.collect()
train_dataset = FastLearnAutomataDataset(nStates=3, nSymbols=2, split='train', sequenceLen=200, numSequences=600000)
test_dataset = FastLearnAutomataDataset(nStates=3, nSymbols=2, split='test', sequenceLen=200, numSequences=10)
print(train_dataset[0], train_dataset[1])
# initialize a baby GPT model
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, 
                  n_layer=8, n_head=8, n_embd=128)
model = GPT(mconf)
from minGPT.mingpt.trainer import Trainer, TrainerConfig
set_seed(27)

    

10/20/2020 01:51:10 - INFO - minGPT.mingpt.model -   number of parameters: 1.613056e+06


(tensor([0, 2, 2, 2, 3, 2, 2, 3, 3, 3, 1, 3, 1, 3, 1, 2, 0, 3, 3, 3, 0, 2, 2, 2,
        2, 3, 3, 2, 1, 2, 2, 3, 3, 2, 0, 3, 3, 3, 0, 2, 2, 2, 2, 3, 3, 3, 0, 2,
        2, 3, 3, 2, 1, 3, 3, 0, 3, 3, 3, 1, 2, 0, 2, 2, 3, 3, 2, 0, 2, 2, 2, 2,
        3, 3, 3, 1, 2, 1, 2, 2, 2, 3, 3, 2, 1, 3, 3, 1, 2, 1, 2, 3, 3, 3, 0, 3,
        3, 2, 1, 3, 3, 1, 3, 0, 2, 2, 3, 3, 2, 1, 3, 3, 1, 3, 1, 2, 1, 3, 3, 0,
        2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2,
        2, 2, 3, 3, 2, 1, 3, 3, 1, 2, 1, 2, 2, 2, 3, 2, 2, 3, 2, 3, 2, 2, 3, 2,
        2, 3, 3, 2, 0, 2, 3, 2, 3, 3, 3, 1, 2, 0, 2, 3, 2, 2, 3, 2, 3, 2, 2, 3,
        3, 3, 1, 3, 0, 3, 3]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
        1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
        1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
        1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1,
        

In [63]:

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=100, batch_size=256, learning_rate=6e-5,
                      lr_decay=True, warmup_tokens=256, final_tokens=50*len(train_dataset)*(2+1),
                      num_workers=0)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()

epoch 1 iter 2343: train loss 0.48253. lr 1.445926e-05: 100%|██████████| 2344/2344 [15:18<00:00,  2.55it/s]
10/20/2020 02:06:29 - INFO - minGPT.mingpt.trainer -   test loss: 0.493038
epoch 2 iter 2343: train loss 0.49297. lr 1.610073e-05: 100%|██████████| 2344/2344 [15:18<00:00,  2.55it/s]
10/20/2020 02:21:48 - INFO - minGPT.mingpt.trainer -   test loss: 0.465012
epoch 3 iter 774: train loss 0.48079. lr 6.000000e-06:  33%|███▎      | 775/2344 [05:03<10:14,  2.55it/s]


KeyboardInterrupt: 

In [49]:
ckpt_path = "juniper_fit_128_3_states"
raw_model = model.module if hasattr(model, "module") else model
torch.save(raw_model.state_dict(), ckpt_path)

# seems 8layer, 8head, embed32 got stuck at around 0.5, but it's possible it could have gone further
# juniper_fit fit really well, n_layer=8, n_head=8, n_embd=64