# Let's build GPT: from scratch, in code, spelled out.
This notebook follows along with https://www.youtube.com/watch?v=kCc8FmEb1nY

In [1]:
# download tiny shakespeare. this command wont work on windows tho!
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

'wget' is not recognized as an internal or external command,
operable program or batch file.


In [2]:
with open("input.txt", 'r', encoding='utf-8') as f:
    text = f.read()

In [7]:
print(f"length of dataset in chars: {len(text):,}")
print(text[:1000])

length of dataset in chars: 1,115,394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger f

In [12]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [14]:
char_to_int = { char:int for int, char in enumerate(chars)}
int_to_char = { int:char for int, char in enumerate(chars)}
encode = lambda s: [char_to_int[c] for c in s] #this tokenizes i guess
decode = lambda l: ''.join([int_to_char[i] for i in l]) # and this untokenizes

[46, 43, 50, 50, 53, 1, 58, 46, 43, 56, 43]
hello there


In [15]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape)

torch.Size([1115394])


Now we split into train and validation

In [17]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [19]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [20]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"whenn input is {context}, target is {target}")

whenn input is tensor([18]), target is 47
whenn input is tensor([18, 47]), target is 56
whenn input is tensor([18, 47, 56]), target is 57
whenn input is tensor([18, 47, 56, 57]), target is 58
whenn input is tensor([18, 47, 56, 57, 58]), target is 1
whenn input is tensor([18, 47, 56, 57, 58,  1]), target is 15
whenn input is tensor([18, 47, 56, 57, 58,  1, 15]), target is 47
whenn input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), target is 58


In [22]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x,y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [47]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)
        if targets is None: 
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

m = BigramLanguageModel(vocab_size=vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)


In [48]:
idx = torch.zeros((1,2), dtype=torch.long)
print(idx)
output = m.generate(idx, max_new_tokens=100)[0].tolist()
print(output)
decoded = decode(output)
print(decoded)

tensor([[0, 0]])
[0, 0, 31, 56, 12, 55, 28, 7, 29, 35, 49, 58, 36, 53, 24, 4, 48, 24, 16, 22, 45, 27, 24, 34, 64, 5, 30, 21, 53, 16, 55, 20, 42, 46, 57, 34, 4, 60, 24, 24, 62, 39, 58, 48, 57, 41, 25, 54, 61, 24, 17, 30, 31, 28, 63, 39, 53, 8, 55, 44, 64, 57, 3, 37, 57, 3, 64, 18, 7, 61, 6, 11, 43, 17, 49, 64, 62, 48, 45, 15, 23, 18, 15, 46, 57, 2, 47, 35, 35, 8, 27, 40, 64, 16, 52, 62, 13, 1, 25, 57, 3, 9]


Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


In [50]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [54]:
batch_size = 32
for steps in range(10000):
    xb, yb = get_batch("train")
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    print(loss.item())

3.71895432472229
3.7412240505218506
3.660809278488159
3.6991920471191406
3.594259738922119
3.6137301921844482
3.6016733646392822
3.6600518226623535
3.66314435005188
3.6491527557373047
3.707293748855591
3.623913049697876
3.570007801055908
3.6437480449676514
3.6550164222717285
3.6350088119506836
3.6637725830078125
3.690631628036499
3.5968379974365234
3.5774354934692383
3.7095658779144287
3.6906473636627197
3.6565186977386475
3.6541833877563477
3.693464994430542
3.578220844268799
3.6637539863586426
3.7079970836639404
3.6110916137695312
3.5802431106567383
3.6192827224731445
3.6673645973205566
3.5278825759887695
3.605023145675659
3.595147132873535
3.5239415168762207
3.625913381576538
3.519352436065674
3.6341562271118164
3.650653123855591
3.532480001449585
3.614107847213745
3.638113021850586
3.5538551807403564
3.587869644165039
3.6878669261932373
3.5859689712524414
3.766185998916626
3.5843613147735596
3.5876691341400146
3.582958698272705
3.625875473022461
3.5831353664398193
3.630031824111938

In [56]:
idx = torch.zeros((1,2), dtype=torch.long)
print(idx)
output = m.generate(idx, max_new_tokens=500)[0].tolist()
print(output)
decoded = decode(output)
print(decoded)

tensor([[0, 0]])
[0, 0, 35, 47, 52, 42, 47, 44, 53, 61, 8, 0, 20, 53, 59, 57, 54, 39, 58, 46, 43, 1, 58, 10, 0, 25, 47, 52, 42, 1, 44, 47, 58, 8, 0, 16, 33, 23, 21, 26, 53, 41, 43, 39, 51, 63, 1, 46, 59, 52, 8, 0, 15, 23, 21, 33, 31, 46, 53, 56, 57, 58, 1, 53, 52, 56, 43, 1, 58, 1, 39, 41, 46, 43, 1, 40, 39, 56, 6, 1, 57, 47, 51, 43, 42, 12, 0, 13, 52, 42, 1, 51, 43, 1, 58, 46, 43, 50, 59, 57, 43, 1, 14, 20, 17, 26, 59, 56, 47, 52, 42, 7, 45, 5, 57, 58, 53, 1, 44, 1, 61, 1, 51, 1, 15, 23, 10, 0, 37, 15, 17, 31, 21, 1, 44, 39, 58, 39, 57, 57, 1, 51, 40, 56, 43, 1, 50, 47, 53, 59, 57, 1, 39, 60, 43, 0, 35, 43, 56, 5, 42, 53, 56, 5, 1, 61, 53, 42, 1, 63, 10, 0, 0, 20, 43, 52, 49, 52, 57, 1, 45, 43, 57, 1, 61, 47, 57, 43, 1, 61, 43, 1, 51, 43, 1, 63, 1, 58, 53, 1, 43, 50, 47, 50, 5, 42, 53, 59, 45, 1, 54, 1, 47, 52, 1, 58, 1, 46, 43, 56, 1, 57, 54, 39, 50, 47, 57, 59, 57, 47, 52, 1, 58, 1, 61, 52, 42, 39, 50, 59, 12, 37, 2, 0, 0, 15, 23, 21, 26, 17, 26, 19, 24, 27, 18, 56, 49, 43, 39, 52, 