In [1]:
import math, json, os, time, requests
from collections import defaultdict

import regex as re
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

In [2]:
class CfgNode(dict):
    """Simple config dictionary allowing dot notation."""
    def __getattr__(self, name): return self[name]
    def __setattr__(self, name, value): self[name] = value
    def merge_from_dict(self, d): self.update(d)

In [3]:
def bytes_to_unicode():
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    return dict(zip(bs, map(chr, cs)))

def get_pairs(word):
    return set(zip(word, word[1:]))

class Encoder:
    def __init__(self, encoder, bpe_merges):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        self.pat = re.compile(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+")

    def bpe(self, token):
        if token in self.cache: return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)
        if not pairs: return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks: break
            first, second = bigram
            new_word, i = [], 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break
                if i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1: break
            pairs = get_pairs(word)

        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_idx = []
        for token in re.findall(self.pat, text):
            token_bytes = token.encode('utf-8')
            token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
            token_merged = self.bpe(token_translated).split(' ')
            token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
            bpe_idx.extend(token_ix)
        return bpe_idx

    def decode(self, bpe_idx):
        tokens_merged = [self.decoder[token] for token in bpe_idx]
        tokens_flat = ''.join(tokens_merged)
        tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat])
        return tokens_bytes.decode('utf-8', errors='replace')


def get_encoder():
    cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'mingpt')
    os.makedirs(cache_dir, exist_ok=True)

    def get_file(local, remote):
        if not os.path.isfile(local):
            with open(local, 'wb') as f:
                f.write(requests.get(remote).content)

    enc_path = os.path.join(cache_dir, 'encoder.json')
    get_file(enc_path, 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json')
    with open(enc_path, 'r') as f: encoder = json.load(f)

    vocab_path = os.path.join(cache_dir, 'vocab.bpe')
    get_file(vocab_path, 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe')
    with open(vocab_path, 'r', encoding='utf-8') as f:
        bpe_merges = [tuple(merge_str.split()) for merge_str in f.read().split('\n')[1:-1]]

    return Encoder(encoder, bpe_merges)

In [4]:
class NewGELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            NewGELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.embd_pdrop),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        x = self.transformer.drop(self.transformer.wte(idx) + self.transformer.wpe(pos))
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:] if idx.size(1) > self.block_size else idx
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1) if do_sample else torch.topk(probs, 1, dim=-1)[1]
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [5]:
class CharDataset(Dataset):
    def __init__(self, text, tokenizer, block_size):
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.data = tokenizer.encode(text)

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        chunk = self.data[idx : idx + self.block_size + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

In [6]:
def train(model, dataset, config):
    model.train()
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    for epoch in range(config.epochs):
        total_loss = 0.0
        for step, (x, y) in enumerate(dataloader):
            x, y = x.to(config.device), y.to(config.device)
            logits, loss = model(x, y)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
            optimizer.step()
            total_loss += loss.item()
            if step % 100 == 0:
                print(f"Epoch {epoch} Step {step}: loss = {loss.item():.4f}")
        print(f"Epoch {epoch} completed with average loss: {total_loss / len(dataloader):.4f}")

In [7]:
if __name__ == '__main__':
    tokenizer = get_encoder()

    # Download Tiny Shakespeare
    data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    data_path = os.path.join(os.path.expanduser("~"), ".cache", "mingpt", "shakespeare.txt")
    if not os.path.isfile(data_path):
        print("Downloading Tiny Shakespeare dataset...")
        r = requests.get(data_url)
        with open(data_path, 'w', encoding='utf-8') as f:
            f.write(r.text)

    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()

    config = CfgNode()
    config.model_type = 'gpt-mini'
    config.vocab_size = 50257
    config.block_size = 128
    config.n_layer = 6
    config.n_head = 6
    config.n_embd = 192
    config.embd_pdrop = 0.1
    config.resid_pdrop = 0.1
    config.attn_pdrop = 0.1
    config.batch_size = 32
    config.learning_rate = 3e-4
    config.epochs = 1
    config.grad_norm_clip = 1.0
    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = GPT(config).to(config.device)
    dataset = CharDataset(text, tokenizer, config.block_size)

    print("Starting training...")
    train(model, dataset, config)

    prompt = "Once upon a time"
    input_ids = tokenizer.encode(prompt)
    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(config.device)

    model.eval()
    with torch.no_grad():
        output = model.generate(input_tensor, max_new_tokens=50, do_sample=True, temperature=0.9)
        generated = tokenizer.decode(output[0].tolist())

    print("--- PROMPT ---")
    print(prompt)
    print("--- GENERATED ---")
    print(generated)

Downloading Tiny Shakespeare dataset...
Starting training...
Epoch 0 Step 0: loss = 10.8547
Epoch 0 Step 100: loss = 5.8631
Epoch 0 Step 200: loss = 5.1259
Epoch 0 Step 300: loss = 4.7753
Epoch 0 Step 400: loss = 4.5102
Epoch 0 Step 500: loss = 4.3270
Epoch 0 Step 600: loss = 4.0366
Epoch 0 Step 700: loss = 4.0473
Epoch 0 Step 800: loss = 3.8545
Epoch 0 Step 900: loss = 3.9681
Epoch 0 Step 1000: loss = 3.6981
Epoch 0 Step 1100: loss = 3.7519
Epoch 0 Step 1200: loss = 3.5719
Epoch 0 Step 1300: loss = 3.6516
Epoch 0 Step 1400: loss = 3.5567
Epoch 0 Step 1500: loss = 3.3861
Epoch 0 Step 1600: loss = 3.3048
Epoch 0 Step 1700: loss = 3.2328
Epoch 0 Step 1800: loss = 3.2342
Epoch 0 Step 1900: loss = 3.1277
Epoch 0 Step 2000: loss = 3.2755
Epoch 0 Step 2100: loss = 3.0583
Epoch 0 Step 2200: loss = 3.0136
Epoch 0 Step 2300: loss = 2.8334
Epoch 0 Step 2400: loss = 3.0729
Epoch 0 Step 2500: loss = 2.8441
Epoch 0 Step 2600: loss = 2.8261
Epoch 0 Step 2700: loss = 2.7673
Epoch 0 Step 2800: loss = 