## Train a character-level GPT on some text data

The inputs here are simple text files, which we chop up to individual characters and then train GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue as well. In this example we will feed it some shakespear, which we'll get it to predict character-level.

In [11]:
# make deterministic
from pytorch_lightning import seed_everything
seed_everything(42)

42

In [12]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [13]:
import math
from torch.utils.data import Dataset, DataLoader

class CharDataset(Dataset):

    def __init__(self, data, block_size):
        chars = list(set(data))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))

        self.stoi = { ch:i for i,ch in enumerate(chars) }
        self.itos = { i:ch for i,ch in enumerate(chars) }
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return math.ceil(len(self.data) / (self.block_size + 1))

    def __getitem__(self, idx):
        # we're actually going to "cheat" and pick a spot in the dataset at random
        i = np.random.randint(0, len(self.data) - (self.block_size + 1))
        chunk = self.data[i:i+self.block_size+1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y


In [14]:
block_size = 128 # spatial extent of the model for its context

In [15]:
# download text
url = 'https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt'

import urllib
urllib.request.urlretrieve(url, 'input.txt')

('input.txt', <http.client.HTTPMessage at 0x7fef199d81d0>)

In [30]:
# you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
text = open('input.txt', 'r').read() # don't worry we won't run out of file handles
train_dataset = CharDataset(text, block_size) # one line of poem is roughly 50 characters
train_loader = DataLoader(train_dataset, batch_size=512, num_workers=4)

data has 81872 characters, 91 unique.


In [31]:
from mingpt.model import GPT
model = GPT(vocab_size=train_dataset.vocab_size, 
            block_size=train_dataset.block_size,
            n_layer=8, 
            n_head=8, 
            n_embd=512, 
            learning_rate=6e-4)

In [27]:
from pytorch_lightning import Trainer
from mingpt.lr_decay import LearningRateDecayCallback

# scheduler
lr_decay = LearningRateDecayCallback(learning_rate=6e-4, warmup_tokens=512*20,
                                    final_tokens=00*len(train_dataset)*block_size)

trainer = Trainer(max_epochs=200, 
                  gradient_clip_val=1.0, 
                  callbacks=[lr_decay], 
                  progress_bar_refresh_rate=1, 
                  row_log_interval=1)
trainer.fit(model, train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name    | Type       | Params
---------------------------------------
0 | tok_emb | Embedding  | 46 K  
1 | drop    | Dropout    | 0     
2 | blocks  | Sequential | 25 M  
3 | ln_f    | LayerNorm  | 1 K   
4 | head    | Linear     | 46 K  


Epoch 1:   2%|▏         | 7/318 [00:03<02:48,  1.85it/s, loss=3.920, v_num=5]


1

In [29]:
# alright, let's sample some character-level shakespear
from mingpt.utils import sample

context = "O God, O God!"
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(model.device)
y = sample(model, x, 2000, temperature=0.9, sample=True, top_k=5)[0]
completion = ''.join([train_dataset.itos[int(i)] for i in y])
print(completion)

KeyboardInterrupt: 

In [None]:
# well that was fun