In [1]:
import math
import urllib.request
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2

In [3]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7b6ed9f84890>

In [4]:
# Download the dataset
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = urllib.request.urlopen(url)
text = response.read().decode('utf-8')


In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

In [6]:
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [7]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


In [8]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = n_embd // num_heads

        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        H = self.num_heads
        head_size = self.head_size

        k = self.key(x).view(B, T, H, head_size).transpose(1, 2)   # (B, T, C) -> (B, H, T, head_size)
        q = self.query(x).view(B, T, H, head_size).transpose(1, 2) # (B, T, C) -> (B, H, T, head_size)
        v = self.value(x).view(B, T, H, head_size).transpose(1, 2) # (B, T, C) -> (B, H, T, head_size)

        wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, H, T, head_size) @ (B, H, head_size, T) -> (B, H, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        out = wei @ v  # (B, H, T, T) @ (B, H, T, head_size) -> (B, H, T, head_size)
        out = out.transpose(1, 2).contiguous().view(B, T, H * head_size)  # (B, H, T, head_size) -> (B, T, C)

        return self.dropout(self.proj(out))

In [10]:
class FeedForward(nn.Module):
    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [11]:
class Block(nn.Module):
    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


In [12]:
class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


In [13]:
class CustomAdamW:
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
        if not isinstance(params, (list, tuple)):
            params = [params]
        self.param_groups = [{'params': p, 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay} for p in params]
        self.state = {}

        for group in self.param_groups:
            for p in group['params']:
                if p.requires_grad:
                    self.state[p] = {'step': 0, 'exp_avg': torch.zeros_like(p), 'exp_avg_sq': torch.zeros_like(p)}

    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                step = state['step']

                # Apply weight decay
                if group['weight_decay'] != 0:
                    p.data = p.data - group['lr'] * group['weight_decay'] * p.data

                # Compute first and second moment estimates
                exp_avg.mul_(beta1).add_(1 - beta1, p.grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, p.grad, p.grad)

                # Bias correction
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step
                corrected_exp_avg = exp_avg / bias_correction1
                corrected_exp_avg_sq = exp_avg_sq / bias_correction2

                # Update parameters
                denom = (corrected_exp_avg_sq.sqrt() + group['eps'])
                step_size = group['lr'] / denom
                p.data.addcdiv_(-step_size, corrected_exp_avg)

    def zero_grad(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.detach_()
                    p.grad.zero_()


In [14]:
class CustomStepLR:
    def __init__(self, optimizer, step_size, gamma=0.1):
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma
        self.last_epoch = -1

    def step(self):
        self.last_epoch += 1
        if self.last_epoch % self.step_size == 0:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] *= self.gamma


In [16]:

model = GPTLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

10.788929 M parameters


In [20]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)

In [18]:
clip_value = 1.0

In [21]:
for iter in range(max_iters):
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()

    # Clip gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

    optimizer.step()
    scheduler.step()

step 0: train loss 4.2985, val loss 4.3038
step 500: train loss 1.6939, val loss 1.8474
step 1000: train loss 1.3716, val loss 1.5926
step 1500: train loss 1.3043, val loss 1.5437
step 2000: train loss 1.2825, val loss 1.5307
step 2500: train loss 1.2730, val loss 1.5252
step 3000: train loss 1.2715, val loss 1.5250
step 3500: train loss 1.2729, val loss 1.5260
step 4000: train loss 1.2697, val loss 1.5258
step 4500: train loss 1.2699, val loss 1.5283
step 4999: train loss 1.2696, val loss 1.5251


In [22]:

context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


Murder:
Let to our battle straiveful fall it
In make dispisit us me with this surm,
Bless to any upon of eyes. Your lad chekely, sleep,
The gransfiers thy busicine,
And would hear to aid, stirry understary,
And suched sup to curse, her usury. Who your destricke
And is beast, my sote; she hark dost shose us.
An yet my trude hearting iT in thy hand mercum,
Both gone! My tale remains iscenate;
That acting friends the sting hardly et,
As alous they broughtments Sheupts captuant.
He turn else pass no
