<a href="https://colab.research.google.com/github/aspiringastro/gpt-zero-to-hero/blob/main/gpt_hero.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import re

In [2]:
# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 1000
eval_interval = 100
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
dropout = 0.125
n_layers = 6
n_heads = 6
# ------------

In [3]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7ff861fdcbf0>

In [4]:
with open('/content/drive/MyDrive/gpt/supreme.txt', 'r', encoding='utf-8') as f:
    text = f.read()
text = re.sub('[^A-Za-z0-9\n.,!"\'\-:;_ ]+', '', text)
print('Size:', len(text))
print('Sample:\n', text[:1000],'---------')


Size: 3947033
Sample:
 j__john_g_roberts_jr:
We'll hear argument next in Case 18-877, Allen versus Cooper. Mr. Shaffer.
derek_l_shaffer:
Mr. Chief Justice, and may it please the Court: When states infringe the exclusive federal rights that Congress is charged with securing, Congress can make states pay for doing so.
That's our respectful submission today, one that follows from the Constitution's text and affords ample basis for this Court to uphold the work Congress did in enacting the CRCA. Article I, Section 8, clause 8, what we're calling the intellectual property clause, is unique within Article I in laying down an express constitutional mandate for Congress to protect specified private property rights against any and all intrusion. Consider just how pointed and clear the constitutional text is.
Congress is not only to be granting copyrights but securing them, and the resulting rights by definition are meant to be exclusive rights.
Exclusive against whom, Your Honors Exclusive agai

In [5]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string


In [6]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [7]:
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


In [8]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [9]:
class Head(nn.Module):
    """ one head of single attention """
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)  # (B,T,C)
        q = self.query(x) # (B,T,C)
        
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * C**-0.5 # sqrt of head size, (B,T,C) @ (B,T,C)^T => (B,T,C) @ (B,C,T) => (B,T,T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B,T,T)
        wei = F.softmax(wei, dim=-1) # (B,T,T)
        wei = self.dropout(wei) # 1/8th dropout

        v = self.value(x) #(B,T,C)
        out = wei @ v # (B,T,T) @ (B,T,C) = (B,T,C)
        return out


In [10]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # concat over channel dimension
        out = self.proj(out) # projection is a linear transformation of the outcome of the previous multi-head layer
        out = self.dropout(out)
        return out
    

In [11]:
class FeedForward(nn.Module):
    """ a simple linear layer of feedforward followed by non-linearity"""

    def __init__(self, n_embd):
        super().__init__()
        self.nn = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd), # projection layer in FFwd
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.nn(x)
    

In [12]:
class Block(nn.Module):
    """ Transformer Block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd : embedding dimension
        # n_head : number of heads needed for multi-head self-attention
        super().__init__()
        assert n_embd % n_head == 0, f'n_embd {n_embd}, n_head: {n_head} must be a divisor'
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size) # communication
        self.ffwd = FeedForward(n_embd) # computation
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # No residual connections
        # x = self.sa(x)
        # x = self.ffwd(x)
        # with residual connection
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))

        return x

In [13]:
# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_heads) for _ in range(n_layers)],
            nn.LayerNorm(n_embd,))
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # T,C
        x = tok_emb + pos_emb  #(B,T,C)
        x = self.blocks(x)
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


In [17]:
model = BigramLanguageModel(vocab_size)
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        context = torch.zeros((1, 1), dtype=torch.long, device=device)
        print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

10.794312 M parameters
step 0: train loss 4.4313, val loss 4.4366

IJuTAY1l BSff4t7rzHLRaE;RC.li:3AppbVcsJ689z68OxFpap8DYC1I2yKyXsKir4vMer5I1RB3nsGw8J3hv-Fzbapbrw.fIYuwlH2I CtgEyLWakayyoeUd-Z0_ejte4koBhQQphzjF-9.;Uvb:'M:zqi3Lgak' 6Y-"bitRqS91ob3r6WAZFP'T60ZkS;IwGMhscr,wJWM q8DKn20a7AYAS5hzTH5wJYcAy,gdG"c:cVQO1pZ,
U"V9h'QwX8HAa72"N9cF85Vp_0kKYwTE"t8IyHganTsp5bWshTQUb58q56dac0D6Qf80frgMx55F8fgK.80xW_SikgLrM;wf6SZOI2VKbR69zgFI
Qmmu6hM6OyBKevBB8lPK-
:'r2f6C815m5"BBSH GWplaeqG5ywfXF'Ni5hIgOsIxN'4K56QdnFXV.SUpyry5benA6FYAsFg_9XGDg7QsvdI-FVuFaCPOENsvahTkaycaBm2I;KCQ-C
step 100: train loss 2.4072, val loss 2.4269

pow herit iobio os s ing thougng pocky t.
d_nngod ct t dor:
Vo2erarrrerer o - ans ye t usulstuse izecauche. tokng_as lercanson. aryovenyeyony raklis sothPe wa. thibrng_.
Smiver relyoeralaty moth e -----
Thave e:
Ce woins t t'ns t4, y d an, d'ts, t; I outs d. g__j_lin_o_judeldse eermed.
I fewss.
My, ben y, rofindays reome s, - s_ays cul, - thitndee, t t, thatear wley inancisensorsacor

In [18]:
open('/content/drive/MyDrive/gpt/more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))

10001