In [1]:
from mingpt.model import GPT, GPTConfig
import tensorflow as tf
import math
import numpy as np

In [2]:
class CharDataset:
    """Character-level dataset that yields randomly positioned training windows.

    Every yielded item is an ``(x, y)`` pair of ``tf.int32`` tensors, each
    ``block_size`` long, where ``y`` is ``x`` shifted one character to the
    right (i.e. the next-character targets).

    Instances are iterable *and* callable (``__call__`` is an alias of
    ``__iter__``), so an instance can be passed wherever a zero-argument
    generator factory is expected.

    Attributes:
        stoi: dict mapping each distinct character to its integer id.
        itos: dict mapping each integer id back to its character.
        block_size: number of characters in each input/target sequence.
        vocab_size: number of distinct characters in ``data``.
        data: the raw sequence of characters the windows are cut from.
    """

    def __init__(self, data, block_size):
        """Build the character vocabulary for ``data``.

        Args:
            data: sequence of characters (typically one long ``str``).
            block_size: length of each input/target sequence.
        """
        # Sort the vocabulary. Iterating a bare set of strings follows hash
        # order, which changes from one interpreter run to the next, so an
        # unsorted vocabulary would assign different ids to the same
        # character on every run.
        chars = sorted(set(data))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        """Return the number of windows yielded by one pass of ``__iter__``."""
        return math.ceil(len(self.data) / (self.block_size + 1))

    def __iter__(self):
        """Yield ``len(self)`` random ``(x, y)`` windows.

        Windows are drawn uniformly at random (with replacement) rather than
        by walking the text in order.

        Raises:
            ValueError: if ``data`` is non-empty but shorter than
                ``block_size + 1`` characters, so no full window exists.
        """
        window = self.block_size + 1  # block_size inputs + 1 extra char for the shifted target
        n_chars = len(self.data)
        if 0 < n_chars < window:
            raise ValueError(
                'data has %d characters but at least block_size + 1 = %d are required'
                % (n_chars, window))

        # np.random.randint excludes its upper bound, so the bound must be one
        # past the last valid start (n_chars - window); otherwise the final
        # window of the text can never be sampled.
        start_bound = n_chars - window + 1
        for _ in range(len(self)):
            i = np.random.randint(0, start_bound)
            chunk = self.data[i:i + window]
            dix = [self.stoi[s] for s in chunk]
            x = tf.convert_to_tensor(dix[:-1], dtype=tf.int32)
            y = tf.convert_to_tensor(dix[1:], dtype=tf.int32)
            yield x, y

    # Calling the instance returns a fresh generator, same as iter(instance).
    __call__ = __iter__

In [3]:
# Length, in characters, of every training input sequence. CharDataset cuts
# chunks of block_size + 1 characters so the targets can be the inputs
# shifted by one.
block_size = 128

In [4]:
# Read the whole training corpus. The context manager closes the file handle
# as soon as the text is in memory, and the explicit encoding keeps decoding
# independent of the platform's default locale.
with open('../minGPT/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level generator over the corpus; also exposes vocab_size/block_size.
train_dataset_gen = CharDataset(text, block_size)

data has 1115394 characters, 65 unique.


In [5]:
# Expose the generator as a tf.data pipeline; each element is an (x, y) pair
# of int32 tensors.
train_dataset = tf.data.Dataset.from_generator(
    train_dataset_gen,
    output_types=(tf.int32, tf.int32),
)

In [6]:
from mingpt.model import GPT, GPTConfig
# Model hyperparameters: vocabulary and context size come from the dataset.
mconf = GPTConfig(
    train_dataset_gen.vocab_size,
    train_dataset_gen.block_size,
    n_layer=8,
    n_head=8,
    n_embd=512,
)

In [7]:
from mingpt.trainer import Trainer, TrainerConfig

# Build the training configuration and the trainer instance; training itself
# is started separately by calling trainer.train().
tconf = TrainerConfig(
    max_epochs=200,
    batch_size=4,
    learning_rate=6e-4,
    lr_decay=True,
    warmup_tokens=512*20,
    # 200 epochs x samples per epoch x tokens per sample; keep the 200 in
    # sync with max_epochs.
    final_tokens=200*len(train_dataset_gen)*block_size,
    num_workers=4,
)
trainer = Trainer(
    GPT,
    mconf,
    train_dataset,
    len(train_dataset_gen),
    None,
    None,
    tconf,
)

  warn("Couldn't import ipywidgets properly, progress bar will use console behavior")


In [8]:
# Start training with the trainer configured above (max_epochs=200, so a full
# run is long; the recorded output of this cell shows a manual interrupt).
trainer.train()

█

KeyboardInterrupt: 

In [None]:
# Debug helper: uncomment to list the model's trainable variables.
# trainer.model.trainable_variables

In [None]:
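# Debug helper: uncomment to inspect the current decayed learning rate
# (assumes trainer.optimizer is a Keras optimizer; _decayed_lr is a private API).
# trainer.optimizer._decayed_lr(tf.float32)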

In [None]:
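# Debug helper: uncomment to see how many optimizer steps have run so far
# (assumes trainer.optimizer is a Keras optimizer).
# trainer.optimizer.iterations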