## Train a character-level GPT on some text data

The inputs here are plain text files, which we chop up into individual characters and then train a GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue as well. In this example we will feed it some Shakespeare and train it to predict the next character.

In [1]:
import numpy as np
from os import listdir
from os.path import join as pathjoin
import torch
import torch.nn as nn
from torch.nn import functional as F
import tqdm

from mingpt.model import GPT, GPTConfig
from mingpt.trainer import Trainer, TrainerConfig
from mingpt.utils import sample, set_seed

# make deterministic
set_seed(42)

In [2]:
import math
from torch.utils.data import Dataset

class CharDataset(Dataset):

    def __init__(self, data, block_size):
        chars = sorted(list(set(data)))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))
        
        self.stoi = { ch:i for i,ch in enumerate(chars) }
        self.itos = { i:ch for i,ch in enumerate(chars) }
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data
    
    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        """
        arrange data and targets so that the first i elements of x
        will be asked to predict the i-th element of y. Notice that
        the eventual language model will actually make block_size
        individual predictions at the same time based on this data,
        so we are being clever and amortizing the cost of the forward
        pass of the network. So for example if block_size is 4, then
        we could e.g. sample a chunk of text "hello", the integers in
        x will correspond to "hell" and in y will be "ello". This will
        then actually "multitask" 4 separate examples at the same time
        in the language model:
        - given just "h", please predict "e" as next
        - given "he" please predict "l" next
        - given "hel" predict "l" next
        - given "hell" predict "o" next
        
        In addition, because the DataLoader will create batches of examples,
        every forward/backward pass during training will simultaneously train
        a LOT of predictions, amortizing a lot of computation. In particular,
        for a batched input of integers X (B, T) where B is batch size and
        T is block_size and Y (B, T), the network will during training be
        simultaneously training to make B*T predictions, all at once! Of course,
        at test time we can parallelize across batch B, but unlike during training
        we cannot parallelize across the time dimension T - we have to run
        a forward pass of the network to recover the next single character of the 
        sequence along each batch dimension, and repeatedly always feed in a next
        character to get the next one.
        
        So yes there is a big asymmetry between train/test time of autoregressive
        models. During training we can go B*T at a time with every forward pass,
        but during test time we can only go B at a time, T times, with T forward 
        passes.
        """
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
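
The shift between `x` and `y` described in the docstring can be checked on the "hello" example with plain Python lists. This is only a sketch: `make_example` is a hypothetical stand-in for `CharDataset.__getitem__`, and lists replace the `torch` tensors.

```python
# Plain-Python sketch of the x/y shift done in CharDataset.__getitem__.
# `make_example` is a hypothetical helper; lists stand in for torch tensors.
def make_example(data, stoi, idx, block_size):
    # grab block_size + 1 characters and encode them as integers
    chunk = data[idx:idx + block_size + 1]
    dix = [stoi[ch] for ch in chunk]
    # inputs are all but the last character, targets are shifted by one
    return dix[:-1], dix[1:]

text = "hello"
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

x, y = make_example(text, stoi, idx=0, block_size=4)
for t in range(len(x)):
    context = "".join(itos[i] for i in x[:t + 1])
    print(f"given {context!r} predict {itos[y[t]]!r}")
```

Each of the four printed lines corresponds to one of the four predictions listed in the docstring, all obtained from a single `(x, y)` pair.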


In [3]:
block_size = 128
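
As a rough sanity check, the sizes that follow from this `block_size` can be worked out by hand. The sketch below assumes the character count reported in the training output further down (807353), the default `batch_size=512` of `train_gpt_generator`, and that the final partial batch of each epoch is kept.

```python
import math

data_size = 807353    # characters, as reported in the training output
block_size = 128
batch_size = 512      # default argument of train_gpt_generator

# CharDataset.__len__: one example per starting position
num_examples = data_size - block_size
# assumption: the last, partial batch is kept rather than dropped
iters_per_epoch = math.ceil(num_examples / batch_size)
# training: every position in every sequence is a prediction target
predictions_per_step = batch_size * block_size
# sampling: one forward pass per generated character
passes_to_sample_one_block = block_size

print(num_examples, iters_per_epoch, predictions_per_step, passes_to_sample_one_block)
```

The iteration count agrees with the `1577` shown in the progress bar, and the last two numbers illustrate the train/test asymmetry from the docstring: one training pass scores `B*T` predictions, while sampling `T` characters takes `T` passes.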

In [4]:
def train_gpt_generator(train_text_file, state_dict_file, n_layer=8, n_head=8, n_embd=512,
                        max_epochs=2, batch_size=512):
    # read the whole training file; the context manager closes the handle
    with open(train_text_file, 'r') as f:
        text = f.read()
    train_dataset = CharDataset(text, block_size)
    mconf = GPTConfig(
        train_dataset.vocab_size, train_dataset.block_size,
        n_layer=n_layer, n_head=n_head, n_embd=n_embd
    )
    model = GPT(mconf)
    tconf = TrainerConfig(
        max_epochs=max_epochs, batch_size=batch_size, learning_rate=6e-4,
        # decay the learning rate over the whole run (max_epochs passes over the data)
        lr_decay=True, warmup_tokens=batch_size*20,
        final_tokens=max_epochs*len(train_dataset)*block_size,
        num_workers=4
    )
    trainer = Trainer(model, train_dataset, None, tconf)
    trainer.train()
    torch.save(model.state_dict(), state_dict_file)
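
To see what `lr_decay`, `warmup_tokens` and `final_tokens` do to the learning rate, here is a minimal sketch of the schedule. It assumes the trainer uses a linear warmup followed by a cosine decay with a floor at 10% of the base rate, driven by the number of tokens processed so far; `lr_at` is a hypothetical helper, and the default arguments are the values from this notebook for the 807353-character file.

```python
import math

def lr_at(tokens_seen, learning_rate=6e-4, warmup_tokens=512 * 20,
          final_tokens=2 * 807225 * 128):
    """Hypothetical helper: learning rate after `tokens_seen` training tokens."""
    if tokens_seen < warmup_tokens:
        # linear warmup from 0 to the base learning rate
        lr_mult = tokens_seen / max(1, warmup_tokens)
    else:
        # cosine decay, never going below 10% of the base learning rate
        progress = (tokens_seen - warmup_tokens) / max(1, final_tokens - warmup_tokens)
        lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return learning_rate * lr_mult

tokens_per_batch = 512 * 128   # batch_size * block_size
for it in range(3):
    # the token counter is assumed to be updated before the rate is computed
    print("iter %d: lr %e" % (it, lr_at((it + 1) * tokens_per_batch)))
```

Note that `warmup_tokens` (10240) is smaller than a single batch (65536 tokens), so under these assumptions the warmup is already over after the first step and the rate only decays from there.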

In [5]:
GENRE_DATA_DIR = '/home/mlepekhin/data/genre'
GPT_MODELS_DIR = '/home/mlepekhin/models/mini_gpt/'
LANG = 'en'

In [None]:
from os.path import splitext

for train_text_file in tqdm.tqdm(listdir(pathjoin(GENRE_DATA_DIR, LANG))):
    # the file name without its extension serves as the label
    label = splitext(train_text_file)[0]
    train_gpt_generator(
        pathjoin(GENRE_DATA_DIR, LANG, train_text_file),
        pathjoin(GPT_MODELS_DIR, LANG, label)
    )

  0%|          | 0/11 [00:00<?, ?it/s]

data has 807353 characters, 95 unique.

epoch 1 iter 0: train loss 4.64142. lr 5.999999e-04:   0%|          | 1/1577 [00:07<3:18:55,  7.57s/it]
epoch 1 iter 1: train loss 3.70626. lr 5.999995e-04:   0%|          | 2/1577 [00:07<2:21:53,  5.41s/it]
epoch 1 iter 2: train loss 5.45379. lr 5.999988e-04:   0%|          | 3/1577 [00:08<1:41:46,  3.88s/it]
epoch 1 iter 3: train loss 4.12409. lr 5.999978e-04:   0%|          | 4/1577 [00:08<1:14:14,  2.83s/it]
epoch 1 iter 4: train loss 3.62736. lr 5.999965e-04:   0%|          | 4/1577 [00:08<1:14:14,  2.83s/it]

epoch 1 iter 36: train loss 2.57910. lr 5.997979e-04:   2%|▏         | 37/1577 [00:19<08:33,  3.00it/s]
epoch 1 iter 37: train loss 2.58006. lr 5.997868e-04:   2%|▏         | 38/1577 [00:19<08:31,  3.01it/s]
epoch 1 iter 38: train loss 2.57380. lr 5.997753e-04:   2%|▏         | 39/1577 [00:20<08:30,  3.01it/s]
epoch 1 iter 39: train loss 2.56470. lr 5.997636e-04:   3%|▎         | 40/1577 [00:20<08:30,  3.01it/s]
epoch 1 iter 40: train loss 2.54831. lr 5.997516e-04:   3%|▎         | 41/1577 [00:20<08:28,  3.02it/s]
epoch 1 iter 41: train loss 2.55179.

epoch 1 iter 74: train loss 2.42002. lr 5.991663e-04:   5%|▍         | 75/1577 [00:31<08:03,  3.11it/s]
epoch 1 iter 75: train loss 2.41512. lr 5.991438e-04:   5%|▍         | 76/1577 [00:32<08:03,  3.11it/s]
epoch 1 iter 76: train loss 2.39270. lr 5.991211e-04:   5%|▍         | 77/1577 [00:32<08:03,  3.10it/s]
epoch 1 iter 77: train loss 2.39834. lr 5.990981e-04:   5%|▍         | 78/1577 [00:32<08:16,  3.02it/s]
epoch 1 iter 78: train loss 2.40575. lr 5.990748e-04:   5%|▌         | 79/1577 [00:33<08:12,  3.04it/s]
epoch 1 iter 79: train loss 2.40202.

epoch 1 iter 112: train loss 2.35460. lr 5.981058e-04:   7%|▋         | 113/1577 [00:44<07:54,  3.09it/s]
epoch 1 iter 113: train loss 2.36220. lr 5.980721e-04:   7%|▋         | 114/1577 [00:44<07:54,  3.09it/s]
epoch 1 iter 114: train loss 2.36341. lr 5.980382e-04:   7%|▋         | 115/1577 [00:44<07:53,  3.09it/s]
epoch 1 iter 115: train loss 2.36236. lr 5.980039e-04:   7%|▋         | 116/1577 [00:45<07:51,  3.10it/s]
epoch 1 iter 116: train loss 2.35596. lr 5.979693e-04:   7%|▋         | 117/1577 [00:45<07:51,  3.10it/s]

epoch 1 iter 150: train loss 2.29992. lr 5.966181e-04:  10%|▉         | 151/1577 [00:56<07:42,  3.08it/s]
epoch 1 iter 151: train loss 2.29312. lr 5.965732e-04:  10%|▉         | 152/1577 [00:56<07:55,  3.00it/s]
epoch 1 iter 152: train loss 2.30017. lr 5.965280e-04:  10%|▉         | 153/1577 [00:57<07:51,  3.02it/s]
epoch 1 iter 153: train loss 2.29256. lr 5.964825e-04:  10%|▉         | 154/1577 [00:57<07:47,  3.04it/s]
epoch 1 iter 154: train loss 2.29124. lr 5.964367e-04:  10%|▉         | 154/1577 [00:57<07:47,  3.04it/s]

epoch 1 iter 187: train loss 2.23489. lr 5.947610e-04:  12%|█▏        | 188/1577 [01:08<07:33,  3.06it/s]
epoch 1 iter 188: train loss 2.22899. lr 5.947052e-04:  12%|█▏        | 189/1577 [01:08<07:32,  3.07it/s]
epoch 1 iter 189: train loss 2.23048. lr 5.946492e-04:  12%|█▏        | 190/1577 [01:09<07:31,  3.07it/s]
epoch 1 iter 190: train loss 2.23484. lr 5.945928e-04:  12%|█▏        | 191/1577 [01:09<07:31,  3.07it/s]
epoch 1 iter 191: train loss 2.22491. lr 5.945362e-04:  12%|█▏        | 192/1577 [01:09<07:30,  3.07it/s]

epoch 1 iter 225: train loss 2.10997. lr 5.924368e-04:  14%|█▍        | 226/1577 [01:21<07:49,  2.88it/s]
epoch 1 iter 226: train loss 2.12034. lr 5.923699e-04:  14%|█▍        | 227/1577 [01:21<07:44,  2.90it/s]
epoch 1 iter 227: train loss 2.10756. lr 5.923028e-04:  14%|█▍        | 228/1577 [01:21<07:40,  2.93it/s]
epoch 1 iter 228: train loss 2.11758. lr 5.922354e-04:  15%|█▍        | 229/1577 [01:22<07:38,  2.94it/s]
epoch 1 iter 229: train loss 2.09640. lr 5.921677e-04:  15%|█▍        | 229/1577 [01:22<07:38,  2.94it/s]

epoch 1 iter 262: train loss 2.00852. lr 5.897709e-04:  17%|█▋        | 263/1577 [01:33<07:24,  2.96it/s]
epoch 1 iter 263: train loss 2.01187. lr 5.896934e-04:  17%|█▋        | 264/1577 [01:34<07:23,  2.96it/s]
epoch 1 iter 264: train loss 2.00984. lr 5.896156e-04:  17%|█▋        | 265/1577 [01:34<07:22,  2.96it/s]
epoch 1 iter 265: train loss 2.00567. lr 5.895375e-04:  17%|█▋        | 266/1577 [01:34<07:23,  2.96it/s]
epoch 1 iter 266: train loss 1.99145. lr 5.894591e-04:  17%|█▋        | 267/1577 [01:35<07:22,  2.96it/s]

epoch 1 iter 300: train loss 1.89722. lr 5.866232e-04:  19%|█▉        | 301/1577 [01:46<07:17,  2.91it/s]
epoch 1 iter 301: train loss 1.89727. lr 5.865348e-04:  19%|█▉        | 302/1577 [01:47<07:15,  2.93it/s]
epoch 1 iter 302: train loss 1.90420. lr 5.864461e-04:  19%|█▉        | 303/1577 [01:47<07:13,  2.94it/s]
epoch 1 iter 303: train loss 1.88084. lr 5.863571e-04:  19%|█▉        | 304/1577 [01:47<07:12,  2.95it/s]
epoch 1 iter 304: train loss 1.88391. lr 5.862679e-04:  19%|█▉        | 304/1577 [01:48<07:12,  2.95it/s]