## Train a byte-pair-level GPT on some text data

The inputs here are simple text files, which we chop up to byte pair chunks and then train GPT on. So you could say this is a byte-pair-transformer instead of a byte-pair-rnn. Doesn't quite roll off the tongue as well or what not. In this example we will feed it some shakespear, which we'll get it to predict.

(Forked from Andrej karpathy's minGPT)

In [1]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [2]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [3]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [4]:
import math
import random
from torch.utils.data import Dataset

def split_word(word, length=2):
    if len(word) > length:
        return (word[n:n+length] for n in range(0, len(word), length))
    else:
        if len(word) > 0:
            return (word[n:n+len(word)] for n in range(0, len(word), len(word)))
        else:
            return ""

# Byte-Pair enabled
# Multiple token per bytepair
class BytePairDataset(Dataset):

    def __init__(self, fp, block_size, is_charlevel=True, bMultipleIns=False, bMultipleOuts=False, maxlines=10000):
        
        line = fp.readline()
        self.chars = set([])
        data = []
        self.bMultipleIns = bMultipleIns
        self.bMultipleOuts = bMultipleOuts
        totlinec = 0
        
        while line and totlinec<maxlines:
            lines = line
            newline = "dummy"
            linec = 0
            while newline and linec<10000 and totlinec<maxlines:
                newline = fp.readline()
                if newline:
                    lines += newline
                linec += 1
            
            totlinec += linec
            print(f"Line {totlinec}")
            
            if is_charlevel:
                if bMultipleIns or bMultipleOuts:
                    ch = list(set(lines))
                    ch1 = list(set(lines))

                    ch2 = [item + "§2§" for item in ch]
                    ch3 = [item + "§3§" for item in ch]
                    ch4 = [item + "§4§" for item in ch]
                    ch5 = [item + "§5§" for item in ch]
                    self.chars = self.chars.union(set(ch1+ch2+ch3+ch4+ch5))
                else:
                    self.chars = self.chars.union(list(set(lines)))
            else:
                ch1 = [item for word in lines.split("[NL]") for item in split_word(word)]
                if bMultipleIns or bMultipleOuts:
                    ch1 += [item + "§2§" for word in lines.split("[NL]") for item in split_word(word)]
                    ch1 += [item + "§3§" for word in lines.split("[NL]") for item in split_word(word)]
                    ch1 += [item + "§4§" for word in lines.split("[NL]") for item in split_word(word)]
                    ch1 += [item + "§5§" for word in lines.split("[NL]") for item in split_word(word)]
                    self.chars = self.chars.union(set(ch1))
                else:
                    self.chars = self.chars.union(set(ch1))
            
            if is_charlevel:
                data += lines
            else:
                data += [item for word in lines.split("[NL]") for item in split_word(word)]
            line = fp.readline()
            
        self.chars = sorted(self.chars)
            
        data_size, vocab_size = len(data), len(self.chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))
        
        self.stoi = { ch:i for i,ch in enumerate(self.chars) }
        self.itos = { i:ch for i,ch in enumerate(self.chars) }
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data
        
            
    def stoimul(self, chunk):
        counters = {}
        imul = []
        for s in chunk:
            idx = self.stoi[s]
            if idx in counters:
                counters[idx] = min(5, counters[idx] + 1)
            else:
                counters[idx] = 1

            imul.append(idx + counters[idx] - 1)
        return imul
    
    def stoiouts(self, chunk):
        iouts = []
        for s in chunk:
            idx = self.stoi[s]
            #iouts.append(idx + random.randint(0, 4))
            iouts.append(idx)
        iouts[len(iouts)-1] = idx
        return iouts
    
    def stoimulOuts(self, alldx):
        iout = []
        idx = alldx[len(alldx)-1]
        for s in range(0, 5):
            irow = []
            nidx = idx + s
            for s2 in range(1, len(alldx)-1):
                iidx = alldx[s2]
                irow.append(iidx)
             
            irow.append(nidx)
            iout.append(torch.tensor(irow, dtype=torch.long))
                
        return iout
    
    def __len__(self):
        #return math.ceil(len(self.data) / (self.block_size*2 + 1))
        return math.ceil(len(self.data) - (self.block_size + 1))

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        
        z = None
        if self.bMultipleIns:
            dix = self.stoimul(chunk)
        else:
            if self.bMultipleOuts:
                dix = self.stoiouts(chunk)
            else:
                dix = [self.stoi[s] for s in chunk]
            
        if self.bMultipleOuts:
            z = self.stoimulOuts(dix)
        
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)

        return x, y, z


In [5]:
def toSE(strSE): # Function to remove multiple token instances markup from output
    strSE = strSE.replace("§2§", "")
    strSE = strSE.replace("§3§", "")
    strSE = strSE.replace("§4§", "")
    strSE = strSE.replace("§5§", "")
    return strSE

In [6]:
block_size = 32 # spatial extent of the model for its context
is_charlevel = False # True = tokens is per character, False = tokens is corpus split by bytepair chunks
bMultipleIns = False # Multiple incremental instances for each token?
bMultipleOuts = True # Multiple output instances for each token?
bDecisionTreeLayers = True
modelcount = 4 # Amount of cooperative models to train in parallell
internalModelCount = 1 # Amount of parallell models internal for each model

In [7]:
# you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
fp = open('input.txt', encoding="utf8", mode='r') # don't worry we won't run out of file handles
maxlines = 100000 # Limit, read the first 100.000 lines only
train_dataset = BytePairDataset(fp, block_size, is_charlevel, bMultipleIns, bMultipleOuts, maxlines) # one line of poem is roughly 50 characters


Line 10000
Line 20000
Line 30000
Line 39997
data has 557697 characters, 6670 unique.


In [8]:
from mingpt.model import GPT, GPTConfig
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, True,
                  n_layer=8, n_head=8, n_embd=256, internalModelCount=internalModelCount)

decmconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, False,
                  n_layer=8, n_head=8, n_embd=256, internalModelCount=1)

models=[]
decmodel = GPT(decmconf)
for k in range(modelcount):
    models.append(GPT(mconf))
    


10/07/2020 22:45:06 - INFO - mingpt.model -   number of parameters: 9.741824e+06
10/07/2020 22:45:06 - INFO - mingpt.model -   number of parameters: 9.741824e+06
10/07/2020 22:45:06 - INFO - mingpt.model -   number of parameters: 9.741824e+06
10/07/2020 22:45:06 - INFO - mingpt.model -   number of parameters: 9.741824e+06
10/07/2020 22:45:06 - INFO - mingpt.model -   number of parameters: 9.741824e+06


In [9]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
# We can't preload batches because the input tokens for the models depends on the output from the branching model
tconf = TrainerConfig(max_epochs=1, batch_size=32, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*block_size,
                      num_workers=0, block_size = block_size, bMultipleOuts = bMultipleOuts)
trainer = Trainer(models, decmodel, modelcount, internalModelCount, train_dataset, None, tconf)
trainer.train()

epoch 1 iter 17426: tl 4.8571. brtl 4.4935. bestm: 3. bim: 515. lr 3.001352e-04: 100%|██████████| 17427/17427 [5:00:39<00:00,  1.04s/it]   

Model: 0, trained iter: 5652
Model: 1, trained iter: 3844
Model: 2, trained iter: 4055
Model: 3, trained iter: 3876





In [10]:
# alright, let's sample some bytepair-level shakespear
from mingpt.utils import sampleMultiModelProb

context = "O God,"
if is_charlevel==False:
    context = [item for word in context.split("[NL]") for item in split_word(word)]

x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)
y = sampleMultiModelProb(models, decmodel, modelcount, internalModelCount, x, 2000, temperature=1.0, sample=True, top_k=10, bMultipleIns=bMultipleIns, bMultipleOuts=bMultipleOuts)[0]
completionMultiModel = ''.join([train_dataset.itos[int(i)] for i in y])


Model: 0, sampled iter: 544, train bias: 1000.0
Model: 1, sampled iter: 478, train bias: 1000.0
Model: 2, sampled iter: 512, train bias: 1000.0
Model: 3, sampled iter: 466, train bias: 1000.0


In [11]:
print(toSE(completionMultiModel))

O God,
And bath not and, and by make sir we mane, not make of not the ther.

ANTONONONONIO:
And mave as not thy not to a mut, nod the not his costen should ther or not not me.

ALLO:
The he and make hour mane,
Thes, our that the ban mine.

ANSTPEBAPTANANIO:
I cang trance com in the hath your as of shat and ing in am and of heess as thincent.

ANTO:
Not, you ind than and thou are hall won I was thim ing main the mar,
When.

PONSTTISTO:
Shood, I be will not, bear thour the coster and ou make the wer that,
And my marroonth with wit this you in the so ing of so the herest thich him say, the of so to sir, of have, will here thich of the to has, of bere to me ine, that's the he pine ing of and that not hich hand ance iu and of shat not my hear ine ing make
Thicher an of he the now the thand to from is were there his make in me, ing ince withis cone,
This or mond, ther, my from thour the and his maine the withe th thou theen hast on is of the in on of wor me mand din to me,
And of an make, th

In [12]:
# well that was fun