In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [168]:
# Read the whole training corpus. The context manager closes the file handle
# deterministically, and the explicit encoding keeps the character set
# independent of the platform's locale default.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# build vocab: every distinct character of the corpus, in sorted order.
# '#' is reserved as the delimiter (start, end) token with id 0, so it is kept
# out of the enumerated tokens. Otherwise a corpus containing '#' would have
# that id overwritten below, leaving a hole in the id range and making the
# largest id equal to vocab_size (out of range for an embedding table).
tokens = sorted(set(text) - {'#'})

# we need to embed these tokens into some numbers
stoi = {x: i + 1 for i, x in enumerate(tokens)}  # character -> id (1-based)
stoi['#'] = 0                                    # delimiter -> id 0
itos = {i: x for x, i in stoi.items()}           # id -> character
vocab_size = len(itos)

# character level tokenizer
# TODO: try sub-word tokenizer
def encode(s):
    """Map a string to a list of token ids (string -> vector).

    Raises KeyError for a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(x):
    """Map an iterable of token ids back to a string (vector -> string).

    Raises KeyError for an id that is not in the vocabulary.
    """
    return ''.join(itos[i] for i in x)

In [68]:
# encode the dataset: run the character-level `encode` over the whole corpus
# and store the resulting token ids as a 1-D int64 tensor
data = torch.tensor(encode(text), dtype=torch.long)
# sanity check: total number of tokens and dtype, then the first 100 ids
print(data.shape, data.dtype)
print(data[:100])

torch.Size([1115394]) torch.int64
tensor([19, 48, 57, 58, 59,  2, 16, 48, 59, 48, 65, 44, 53, 11,  1, 15, 44, 45,
        54, 57, 44,  2, 62, 44,  2, 55, 57, 54, 42, 44, 44, 43,  2, 40, 53, 64,
         2, 45, 60, 57, 59, 47, 44, 57,  7,  2, 47, 44, 40, 57,  2, 52, 44,  2,
        58, 55, 44, 40, 50,  9,  1,  1, 14, 51, 51, 11,  1, 32, 55, 44, 40, 50,
         7,  2, 58, 55, 44, 40, 50,  9,  1,  1, 19, 48, 57, 58, 59,  2, 16, 48,
        59, 48, 65, 44, 53, 11,  1, 38, 54, 60])


In [71]:
# build the dataset: the first 90% of the token stream is used for training,
# the remaining tail is held out for validation
n = int(len(data) * 0.9)
d_train, d_val = data[:n], data[n:]

# sizes of the two splits
print(d_train.shape)
print(d_val.shape)

torch.Size([1003854])
torch.Size([111540])


In [72]:
# context length: maximum number of tokens used as input for one prediction
block_size = 8
# a chunk of block_size + 1 tokens: every prefix of the first block_size
# tokens is paired with the token that follows it
d_train[:block_size + 1]

tensor([19, 48, 57, 58, 59,  2, 16, 48, 59])

In [74]:
# Walk through one chunk: x is the context window and y is the same window
# shifted one step ahead, so y[t] is the token that follows the prefix x[:t + 1].
x, y = d_train[:block_size], d_train[1:block_size + 1]
for t in range(block_size):
    ctx, target = x[:t + 1], y[t]
    print(f"input = {ctx}, output = {target}")

input = tensor([19]), output = 48
input = tensor([19, 48]), output = 57
input = tensor([19, 48, 57]), output = 58
input = tensor([19, 48, 57, 58]), output = 59
input = tensor([19, 48, 57, 58, 59]), output = 2
input = tensor([19, 48, 57, 58, 59,  2]), output = 16
input = tensor([19, 48, 57, 58, 59,  2, 16]), output = 48
input = tensor([19, 48, 57, 58, 59,  2, 16, 48]), output = 59


In [170]:
# usable datasets, batches basically -- configuration for batch sampling
block_size = 8   # context length
batch_size = 32  # independent sequences fed in one pass
# fixed seed, for reproducibility of the randomly drawn batches
torch.manual_seed(1332)

def get_batch(split):
    """Sample a random batch of (input, target) sequences.

    `split` selects the source tensor: 'train' reads from `d_train`, any other
    value reads from `d_val`. Returns `(x, y)`, each of shape
    (batch_size, block_size), where `y` is `x` shifted one position ahead in
    the source.
    """
    source = d_train if split == 'train' else d_val
    # random start offsets, chosen so the shifted target window stays in bounds
    starts = torch.randint(len(source) - block_size, (batch_size, ))
    inputs, targets = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(source[start:stop])
        targets.append(source[start + 1:stop + 1])
    return torch.stack(inputs), torch.stack(targets)

# draw one training batch: xb holds the input sequences, yb the next-token targets
xb, yb = get_batch('train')

In [171]:
# simple bigram
import torch.nn as nn

class BigramLLM(nn.Module):
    """Bigram language model backed by a single lookup table.

    The table has shape (vocab_size, vocab_size): row ``i`` holds the
    unnormalised logits for the token that follows token ``i``.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # each token reads logits for the next token from lookup table
        self.token_emb_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: (B, T) tensor of token ids; B is the batch dim, T the time
                (context) dim.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits has shape (B, T, C)
            and loss is None, where C is the embedding dim. With targets,
            logits is flattened to (B * T, C) and loss is the cross-entropy
            against the flattened targets.
        """
        # after embedding, each idx[i, j] becomes a vector of C logits -> (B, T, C)
        logits = self.token_emb_table(idx)
        # identity check: `== None` on a tensor goes through Tensor.__eq__
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # F.cross_entropy wants (N, C) logits and (N,) class indices
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        # negative log likelihood loss, i.e. cross_entropy in pytorch
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling is never back-propagated; skip graph building
    def generate(self, idx, max_new_tokens):
        """Extend ``idx`` by sampling ``max_new_tokens`` tokens one at a time.

        Args:
            idx: (B, T) tensor of current context token ids.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.
        """
        for _ in range(max_new_tokens):
            # predictions for every position; no targets, so no loss is computed
            logits, _ = self(idx)
            # focus on latest time step: (B, T, C) -> (B, C)
            logits = logits[:, -1, :]
            # softmax turns logits into a distribution over the next token
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from probabilities
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # add to running index
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

# instantiate the bigram model with one table row/column per vocabulary entry
m = BigramLLM(vocab_size)

In [173]:
# pytorch optimizer: AdamW over all parameters of the model, learning rate 1e-3
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [181]:
# number of sequences per batch; get_batch reads this global
batch_size = 32

# 10000 optimisation steps, each on a freshly sampled training batch
for steps in range(10_000):
    # clear gradients left over from the previous step
    optimizer.zero_grad(set_to_none=True)

    # forward pass: eval loss on a new random batch
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)

    # backward pass, then parameter update
    loss.backward()
    optimizer.step()

# loss of the final batch only
print(loss.item())

2.373624801635742


In [184]:
# start from a single token with id 0 (the '#' delimiter) as the context
start_ctx = torch.zeros((1, 1), dtype=torch.long)
# sample 1000 new tokens, then decode the first (and only) sequence of the batch
sampled = m.generate(start_ctx, 1000)
print(decode(sampled[0].tolist()))

#!
Gongo wicowerel-f t helby ES:
Tor, f an mawithitcthease, her e wererery s t y IUR CAl wanakikely ds buret Whethead s ht I at bures iengrto he isthe otarid bu asplo ormecotenol, alidderhie de.

Courearechegr cores atitowarthit

OT:
ICELel, m s y tedas thu y yimete,-prd E t e, O iss?
ONRLIs, culmest beest, thin VOr aro tnondourth'dinsike heanent rego ses, hea te?

Y:

My ter.
And tcameshacerstyortot?
SThay loucherestooims ou aratr
Qundo shichat-szent thoolsenkeanisulengrgrs harangsar hom.
Preno
Wid ht har, he, moono o thelin, s mee.

IZTheand;
Y m t.
I t g t US: there n.
USefend s, oow ssthesther foouns annd, wer.
TIIXIV:


he t we we RENEDu baineethue t l; e.
Gowouou, Mar, meese d myorncthel ifrm; m ss:
Andelitasthas hintlf heno atouatlit ifurere s y?
GULOMitilyomas:
TUCENEENGojo, a n fou can ishome VEYofof

QUCES:
ABo ge INote.
ONor? itchifod GUn s prenews s, nd w brusargat mm, se, BOL.

I it!
ABe shige inghilicar ce tome lltof
PSoe bred fo, p e.
pider ouran:
CEOfe hindeamy het, chi

## Mathematical trick in self-attention

device: cpu


step 0: train loss 4.2175, val loss 4.2171
step 250: train loss 3.2356, val loss 3.2395
step 500: train loss 3.1194, val loss 3.1135
step 750: train loss 3.0183, val loss 3.0104









torch.Size([1, 1])

torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 4])
torch.Size([1, 5])
torch.Size([1, 6])
torch.Size([1, 7])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([

torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([

torch.Size([1, 5])