In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from fastcore.foundation import *
from fastcore.basics import *

EX2: Train the GPT on your own dataset of choice! What other data could be fun to blabber on about? (A fun advanced suggestion if you like: train a GPT to do addition of two numbers, i.e. a+b=c. You may find it helpful to predict the digits of c in reverse order, as the typical addition algorithm (that you're hoping it learns) would proceed right to left too. You may want to modify the data loader to simply serve random problems and skip the generation of train.bin, val.bin. You may want to mask out the loss at the input positions of a+b that just specify the problem using y=-1 in the targets (see CrossEntropyLoss ignore_index). Does your Transformer learn to add? Once you have this, swole doge project: build a calculator clone in GPT, for all of +-*/. Not an easy problem. You may need Chain of Thought traces.)

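Example data. Each sequence is `a b <digits of a+b, reversed> .`, right-padded with `?` (`.` is the end token, `?` the padding token). A plain next-token shift of the whole sequence gives:

x: 9 8 7 1 . <br>
y: 8 7 1 . ?

x: 2 3 5 . ? <br>
y: 3 5 . ? ?

Here 9 + 8 = 17 is stored reversed as `7 1`, and 2 + 3 = 5. In the implementation below, `x` drops the last token and the first target (the operand `b`) is replaced by `?` so the loss ignores it. The pairs above therefore become `x: 9 8 7 1`, `y: ? 7 1 .` and `x: 2 3 5 .`, `y: ? 5 . ?`.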

In [2]:
# Sample two single-digit operands. `torch.randint`'s upper bound is exclusive,
# so it must be 10 (not 9) for the digit 9 to ever be drawn.
a, b = torch.randint(0, 10, (1, )).item(), torch.randint(0, 10, (1, )).item()
a, b

(6, 2)

In [3]:
c = sum((a, b))  # ground-truth answer for the sampled problem
c

8

In [4]:
f"{c}"

'8'

In [5]:
str(c)[::-1]  # digits of the sum, least-significant first

'8'

In [6]:
# operands, reversed sum digits, then the '.' end marker
data_pt = f"{a}{b}{str(c)[::-1]}."
data_pt

'628.'

In [7]:
# right-pad with '?' up to a total length of 5 (no-op if already long enough)
padded_dtpt = data_pt.ljust(5, '?')

padded_dtpt

'628.?'

In [8]:
# Same encoding as the string version, but with integer token ids:
# 10 -> '.' (end token), 11 -> '?' (pad token).
# randint's upper bound is exclusive: 10 so that the digit 9 can be drawn.
a, b = torch.randint(0, 10, (1, )), torch.randint(0, 10, (1, ))
print(a, b)
c = a+b
rev_c = list(reversed(str(c.item())))
rev_c = torch.tensor([int(o) for o in rev_c])
print(rev_c)
out = torch.cat([a, b, rev_c, torch.tensor([10])], dim=-1)

while len(out) < 5:
    out = torch.cat([out, torch.tensor([11])], dim=-1)
out

tensor([5]) tensor([0])
tensor([5])


tensor([ 5,  0,  5, 10, 11])

In [9]:
# 10 -> '.' (end token)
# 11 -> '?' (pad token)
def _data_pt(blk_sz=5):
    a, b = torch.randint(0, 9, (1, )), torch.randint(0, 9, (1, ))
    c = a+b
    rev_c = list(reversed(str(c.item())))
    rev_c = torch.tensor([int(o) for o in rev_c])
    out = torch.cat([a, b, rev_c, torch.tensor([10])])
    while len(out) < blk_sz:
        out = torch.cat([out, torch.tensor([11])])
    x = out[:-1]
    # copy to avoid changing the original tensor
    y = out[1:].clone()
    # first token should be pad token because we do not want to predict the input
    y[0] = 11
    return x, y

data_pt = _data_pt(blk_sz=5)  # one (x, y) training pair
data_pt

(tensor([ 0,  5,  5, 10]), tensor([11,  5, 10, 11]))

`nn.functional.cross_entropy` takes targets as class indices.

Positions whose target is an input token must be masked out of the loss. Since we are doing single-digit addition, the only such position is the first one: after reading `a`, its target is the operand `b`. `_data_pt` therefore sets `y[0]` to the pad token `11`, which the loss ignores. ~~Now that I see it, we can also remove the need for a '?' padding token and just use 0 instead when we have single digit as the sum.~~

In [10]:
# Walk the (input, target) pairs position by position. zip avoids hard-coding the
# sequence length (4 == blk_sz - 1), so this keeps working if blk_sz changes.
for ctx, tgt in zip(data_pt[0], data_pt[1]):
    print(f"Input: {ctx}; Target: {tgt}")

Input: 0; Target: 11
Input: 5; Target: 5
Input: 5; Target: 10
Input: 10; Target: 11


In [11]:
x, y = data_pt[0], data_pt[1]
(x, y)

(tensor([ 0,  5,  5, 10]), tensor([11,  5, 10, 11]))

In [12]:
batch_sz = 8

# preview a batch's worth of independent samples before stacking them
list(_data_pt() for _ in range(batch_sz))

[(tensor([8, 7, 5, 1]), tensor([11,  5,  1, 10])),
 (tensor([ 4,  3,  7, 10]), tensor([11,  7, 10, 11])),
 (tensor([ 7,  1,  8, 10]), tensor([11,  8, 10, 11])),
 (tensor([ 5,  3,  8, 10]), tensor([11,  8, 10, 11])),
 (tensor([ 1,  5,  6, 10]), tensor([11,  6, 10, 11])),
 (tensor([ 1,  0,  1, 10]), tensor([11,  1, 10, 11])),
 (tensor([ 2,  6,  8, 10]), tensor([11,  8, 10, 11])),
 (tensor([ 2,  6,  8, 10]), tensor([11,  8, 10, 11]))]

In [13]:
batch_data = [_data_pt() for _ in range(batch_sz)]
# transpose the list of (x, y) pairs into (all xs, all ys), then stack each to (batch_sz, T)
xb, yb = (torch.stack(col) for col in zip(*batch_data))
xb, yb

(tensor([[ 2,  7,  9, 10],
         [ 3,  6,  9, 10],
         [ 1,  6,  7, 10],
         [ 0,  7,  7, 10],
         [ 8,  2,  0,  1],
         [ 1,  8,  9, 10],
         [ 7,  6,  3,  1],
         [ 4,  1,  5, 10]]),
 tensor([[11,  9, 10, 11],
         [11,  9, 10, 11],
         [11,  7, 10, 11],
         [11,  7, 10, 11],
         [11,  0,  1, 10],
         [11,  9, 10, 11],
         [11,  3,  1, 10],
         [11,  5, 10, 11]]))

In [14]:
def get_batch():
    """Return one training batch `(xb, yb)` of `batch_sz` freshly sampled problems,
    each stacked along a new leading batch dimension."""
    samples = [_data_pt() for _ in range(batch_sz)]
    xs, ys = zip(*samples)
    return torch.stack(xs), torch.stack(ys)

get_batch()

(tensor([[ 2,  2,  4, 10],
         [ 8,  0,  8, 10],
         [ 7,  2,  9, 10],
         [ 0,  5,  5, 10],
         [ 5,  8,  3,  1],
         [ 3,  7,  0,  1],
         [ 4,  8,  2,  1],
         [ 2,  2,  4, 10]]),
 tensor([[11,  4, 10, 11],
         [11,  8, 10, 11],
         [11,  9, 10, 11],
         [11,  5, 10, 11],
         [11,  3,  1, 10],
         [11,  0,  1, 10],
         [11,  2,  1, 10],
         [11,  4, 10, 11]]))

All we need is for the LLM to take `7 2` and predict `9`, followed by the end token `.` (10). We put an ignored index at the first position of `y` because its target is the input digit `b`, and we do not want to penalize the model for failing to predict the problem's own input.

After flattening, `targets` has shape `(B*T,)`, so masked positions cannot simply be dropped. Instead, each one is replaced by a predetermined dummy index (the padding token `11`), which is passed to `F.cross_entropy` as `ignore_index`.

In [15]:
class CausalAttn(nn.Module):
    """Multi-head causal self-attention.

    Args:
        head_sz: size of each head; `n_heads * head_sz` must equal `n_embd`.
        n_heads: number of attention heads.
        n_embd: channel size of the input and output.
        blk_sz: maximum sequence length (size of the causal mask buffer).

    `forward` maps (B, T, n_embd) -> (B, T, n_embd) with T <= blk_sz; position t
    attends only to positions <= t.

    Raises:
        ValueError: at construction if `n_heads * head_sz != n_embd`, and in
            `forward` if the sequence is longer than `blk_sz`.
    """
    def __init__(self, head_sz, n_heads, n_embd, blk_sz):
        super().__init__()
        # Fail early with a clear message instead of an opaque `view` error in forward.
        if n_heads * head_sz != n_embd:
            raise ValueError(
                f"n_heads * head_sz must equal n_embd, got {n_heads} * {head_sz} != {n_embd}")
        # explicit attributes instead of fastcore's `store_attr()`, so the class only needs torch
        self.head_sz = head_sz
        self.n_heads = n_heads
        self.n_embd = n_embd
        self.blk_sz = blk_sz
        # a single projection produces k, q and v for all heads at once
        self.attn = nn.Linear(n_embd, 3*n_embd)
        # lower-triangular ones: entry (i, j) is 1 iff j <= i, i.e. i may attend to j
        self.register_buffer('tril', torch.tril(torch.ones((blk_sz, blk_sz))))
        self.proj = nn.Linear(n_embd, n_embd, bias=True)

    def forward(self, x):
        B,T,C = x.shape
        # the mask buffer only covers blk_sz positions
        if T > self.blk_sz:
            raise ValueError(f"sequence length {T} exceeds blk_sz={self.blk_sz}")
        # compute k,q,v at once and destructure it to have `n_embd` size
        k,q,v = self.attn(x).split(self.n_embd, dim=-1)
        # first view the tensors as (B, T, n_heads, head_sz), then transpose the middle dimensions to get (B, n_heads, T, head_sz).
        # think about [T, n_heads] to be a separate matrix and think about transposing it.
        # initially, you'll have T number of n_heads (have n_heads heads at each timestep) (T, n_heads)
        # after transposing, you'll have, at each head, T "blocks" or timestep elements    (n_heads, T)
        k = k.view(B, T, self.n_heads, self.head_sz).transpose(1, 2) # (B, n_heads, T, head_sz)
        q = q.view(B, T, self.n_heads, self.head_sz).transpose(1, 2) # (B, n_heads, T, head_sz)
        v = v.view(B, T, self.n_heads, self.head_sz).transpose(1, 2) # (B, n_heads, T, head_sz)

        # raw weights based on q, k affinity --> scaled dot product attn
        wei = q @ k.transpose(-2, -1) * self.head_sz**-0.5 # (B, n_heads, T, head_sz) @ (B, n_heads, head_sz, T) --> (B, n_heads, T, T)
        # mask future tokens and get a normalized distribution for affinities
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf')) # (B, n_heads, T, T)
        wei = wei.softmax(dim=-1) # (B, n_heads, T, T)
        # scale value vector with affinities
        out = wei @ v # (B, n_heads, T, T) @ (B, n_heads, T, head_sz) --> (B, n_heads, T, head_sz)
        # transpose(1, 2) --> (B, T, n_heads, head_sz)
        # contiguous --> transpose operations make the underlying memory non-contiguous. operations like view require contiguous memory representations.
        # view --> (B, T, n_heads, head_sz) -> (B, T, n_embd) (n_embd = n_heads * head_sz)
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_embd)
        out = self.proj(out)
        return out

In [16]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(n_embd -> 4*n_embd), ReLU, Linear(4*n_embd -> n_embd)."""
    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd  # 4x hidden expansion
        layers = [nn.Linear(n_embd, hidden), nn.ReLU(), nn.Linear(hidden, n_embd)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

In [17]:
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies causal self-attention and then the MLP. Each sublayer sees a
    LayerNorm'd input and is added back to its input (residual connection).
    """
    def __init__(self, n_embd, n_heads, head_sz, blk_sz):
        super().__init__()
        store_attr()
        self.causal_attn = CausalAttn(head_sz, n_heads, n_embd, blk_sz)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        attn_out = self.causal_attn(self.ln1(x))
        h = x + attn_out
        mlp_out = self.ffwd(self.ln2(h))
        return h + mlp_out  # (B, T, n_embd)

In [18]:
# 0-9 + end_token(10) + padding_token(11) = 12
vocab_sz = 12
n_embd = 8
# longest sequence: a, b, two sum digits, end token
blk_sz = 5
n_heads = 2
# CausalAttn splits n_embd evenly across heads
assert n_embd % n_heads == 0, "n_embd must be divisible by n_heads"
head_sz = n_embd // n_heads
# seed so model init and the sampled training batches are reproducible from here on
torch.manual_seed(1337)
head_sz

4

In [19]:
class AddGPT(nn.Module):
    """Tiny decoder-only transformer for single-digit addition.

    Hyperparameters are read from the notebook-level config (`vocab_sz`, `n_embd`,
    `blk_sz`, `n_heads`, `head_sz`). The model is token embeddings plus learned
    position embeddings, two pre-LN `Block`s, a final LayerNorm, and a linear head
    over the vocabulary.
    """
    def __init__(self):
        super().__init__()
        self.tok_emb_table = nn.Embedding(vocab_sz, n_embd)
        self.pos_emb_table = nn.Embedding(blk_sz, n_embd)

        # the final LayerNorm lives inside `blocks` (keeps state_dict keys unchanged)
        self.blocks = nn.Sequential(
            Block(n_embd=n_embd, n_heads=n_heads, head_sz=head_sz, blk_sz=blk_sz),
            Block(n_embd=n_embd, n_heads=n_heads, head_sz=head_sz, blk_sz=blk_sz),
            nn.LayerNorm(n_embd)
        )
        self.lm_head = nn.Linear(n_embd, vocab_sz)

    def forward(self, idx, targets=None):
        """Run the model on token ids `idx` of shape (B, T).

        Returns `(logits, loss)`.
        - Without `targets`: logits are (B, T, vocab_sz) and loss is None.
        - With `targets` of shape (B, T): logits are flattened to (B*T, vocab_sz),
          and loss is the mean cross-entropy over positions whose target is not 11.

        Raises:
            ValueError: if T exceeds the number of learned positions (`blk_sz`).
        """
        B, T = idx.shape
        n_pos = self.pos_emb_table.num_embeddings
        if T > n_pos:
            raise ValueError(f"sequence length {T} exceeds blk_sz={n_pos}")
        tok_emb = self.tok_emb_table(idx)  # (B, T, n_embd)
        # build positions on idx's device so the model also works off-CPU
        pos_emb = self.pos_emb_table(torch.arange(T, device=idx.device))  # (T, n_embd)

        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)  # (B, T, vocab_sz)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        # 11 is the pad token; it also marks masked (operand) targets, so skip it
        loss = F.cross_entropy(logits, targets.reshape(B*T), ignore_index=11)
        return logits, loss


In [21]:
m = AddGPT()  # throwaway, untrained instance used only for shape checks
xb, yb = get_batch()
(xb.size(), yb.size())

(torch.Size([8, 4]), torch.Size([8, 4]))

In [22]:
logits = m(xb, yb)[0]
logits.size()  # flattened to (B*T, vocab_sz) because targets were passed

torch.Size([32, 12])

In [23]:
# fresh model instance that is actually trained below
model: nn.Module = AddGPT()

In [24]:
optim = torch.optim.AdamW(params=model.parameters(), lr=0.001)

In [25]:
for step in range(7000):
    # a fresh random batch every step; there is no fixed training set
    xb, yb = get_batch()
    logits, loss = model(xb, yb)

    # clear old grads, backprop, update
    optim.zero_grad(set_to_none=True)
    loss.backward()
    optim.step()

    if step % 100 == 0:
        print(f"step {step}: train loss {loss}")

step 0: train loss 2.768364429473877
step 100: train loss 1.9739805459976196
step 200: train loss 1.649142861366272
step 300: train loss 1.3180294036865234
step 400: train loss 1.06648588180542
step 500: train loss 0.8726248741149902
step 600: train loss 0.9190966486930847
step 700: train loss 1.018977403640747
step 800: train loss 0.8676979541778564
step 900: train loss 0.8712950348854065
step 1000: train loss 0.49270910024642944
step 1100: train loss 0.5775040984153748
step 1200: train loss 0.45224449038505554
step 1300: train loss 0.5545138716697693
step 1400: train loss 0.5130059719085693
step 1500: train loss 0.420223593711853
step 1600: train loss 0.4801110029220581
step 1700: train loss 0.3629035949707031
step 1800: train loss 0.31049254536628723
step 1900: train loss 0.253854900598526
step 2000: train loss 0.11121837049722672
step 2100: train loss 0.10780270397663116
step 2200: train loss 0.03641210123896599
step 2300: train loss 0.04282999783754349
step 2400: train loss 0.1906

In [26]:
inp = torch.tensor([2, 3]).unsqueeze(0)  # batch of one prompt: (1, 2)
inp

tensor([[2, 3]])

In [27]:
# no targets are passed, so the returned loss is None
with torch.no_grad():
    logits, loss = model(inp)

logits.size()

torch.Size([1, 2, 12])

In [28]:
logits[:, -1].size()

torch.Size([1, 12])

In [29]:
logits = logits[:, -1]  # NOTE: not idempotent; re-running this cell slices again

In [30]:
with torch.no_grad():
    probs = logits.softmax(dim=-1)  # next-token distribution over the vocabulary

probs

tensor([[2.0652e-04, 3.7571e-09, 6.6792e-05, 3.3451e-06, 3.0085e-04, 9.9867e-01,
         1.3629e-04, 4.6625e-05, 4.9033e-05, 1.9042e-06, 4.9086e-04, 3.2574e-05]])

In [31]:
probs.argmax()

tensor(5)

In [39]:
ctx = torch.tensor([9, 9]).unsqueeze(0)  # prompt "9 9" as a batch of one
ctx

tensor([[9, 9]])

In [40]:
# keep only the last position's logits: (1, T, vocab_sz) -> (1, vocab_sz)
with torch.no_grad():
    logits = model(ctx)[0][:, -1]

logits.size()

torch.Size([1, 12])

In [41]:
with torch.no_grad():
    probs = logits.softmax(dim=-1)  # normalize last-position logits into probabilities

probs

tensor([[3.0043e-05, 8.9165e-02, 1.6099e-04, 4.8151e-07, 1.2751e-05, 6.6755e-08,
         1.4713e-05, 1.1228e-08, 1.9014e-06, 3.7309e-06, 9.1059e-01, 1.6245e-05]])

In [42]:
probs.argmax()

tensor(10)

In [38]:
@torch.no_grad()
def add(a, b, gpt=None, max_len=None):
    """Greedily decode the answer to `a + b` with the trained GPT.

    Starting from the prompt `[[a, b]]`, repeatedly append the arg-max next token
    until the end token (10) is produced or the sequence reaches `max_len` tokens.

    Args:
        a, b: single-digit operands (token ids 0-9).
        gpt: model to decode with; defaults to the notebook-global `model`.
        max_len: maximum total sequence length; defaults to the global `blk_sz`,
            the number of positions AddGPT's position embedding table holds.

    Returns:
        A (1, L) int64 tensor `[a, b, <sum digits reversed>..., 10]`. The trailing
        end token is missing if the length cap was hit first.
    """
    gpt = model if gpt is None else gpt
    max_len = blk_sz if max_len is None else max_len
    ctx = torch.tensor([[a, b]])
    # Without the length cap, a model that never emits the end token would loop
    # forever (and crash once ctx outgrows the position embedding table).
    while ctx.shape[1] < max_len and ctx[0, -1].item() != 10:
        logits, _ = gpt(ctx)
        # argmax over logits == argmax over softmax(logits); keepdim gives shape (1, 1)
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ctx = torch.cat([ctx, next_tok], dim=-1)

    return ctx

for lhs, rhs in [(8, 8), (1, 2), (4, 5), (6, 7), (9, 0), (9, 9)]:
    print(add(lhs, rhs))

tensor([[ 8,  8,  6,  1, 10]])
tensor([[ 1,  2,  3, 10]])
tensor([[ 4,  5,  9, 10]])
tensor([[ 6,  7,  3,  1, 10]])
tensor([[ 9,  0,  1, 10]])
tensor([[ 9,  9, 10]])


In [61]:
# every pair i < j of operands 0..8 (pairs involving 9 are excluded)
for i in range(9):
    for j in range(i + 1, 9):
        print(f"{i} + {j} = {add(i, j)[0][2:-1]}")

0 + 1 = tensor([1])
0 + 2 = tensor([2])
0 + 3 = tensor([3])
0 + 4 = tensor([4])
0 + 5 = tensor([5])
0 + 6 = tensor([6])
0 + 7 = tensor([7])
0 + 8 = tensor([8])
1 + 2 = tensor([3])
1 + 3 = tensor([4])
1 + 4 = tensor([5])
1 + 5 = tensor([6])
1 + 6 = tensor([7])
1 + 7 = tensor([8])
1 + 8 = tensor([9])
2 + 3 = tensor([5])
2 + 4 = tensor([6])
2 + 5 = tensor([7])
2 + 6 = tensor([8])
2 + 7 = tensor([9])
2 + 8 = tensor([0, 1])
3 + 4 = tensor([7])
3 + 5 = tensor([8])
3 + 6 = tensor([9])
3 + 7 = tensor([0, 1])
3 + 8 = tensor([1, 1])
4 + 5 = tensor([9])
4 + 6 = tensor([0, 1])
4 + 7 = tensor([1, 1])
4 + 8 = tensor([2, 1])
5 + 6 = tensor([1, 1])
5 + 7 = tensor([2, 1])
5 + 8 = tensor([3, 1])
6 + 7 = tensor([3, 1])
6 + 8 = tensor([4, 1])
7 + 8 = tensor([5, 1])


In [65]:
# same pairs as above with the operands swapped (j + i, j > i), still excluding 9
for i in range(9):
    for j in range(i + 1, 9):
        print(f"{j} + {i} = {add(j, i)[0][2:-1]}")

1 + 0 = tensor([1])
2 + 0 = tensor([2])
3 + 0 = tensor([3])
4 + 0 = tensor([4])
5 + 0 = tensor([5])
6 + 0 = tensor([6])
7 + 0 = tensor([7])
8 + 0 = tensor([8])
2 + 1 = tensor([3])
3 + 1 = tensor([4])
4 + 1 = tensor([5])
5 + 1 = tensor([6])
6 + 1 = tensor([7])
7 + 1 = tensor([8])
8 + 1 = tensor([9])
3 + 2 = tensor([5])
4 + 2 = tensor([6])
5 + 2 = tensor([7])
6 + 2 = tensor([8])
7 + 2 = tensor([9])
8 + 2 = tensor([0, 1])
4 + 3 = tensor([7])
5 + 3 = tensor([8])
6 + 3 = tensor([9])
7 + 3 = tensor([0, 1])
8 + 3 = tensor([1, 1])
5 + 4 = tensor([9])
6 + 4 = tensor([0, 1])
7 + 4 = tensor([1, 1])
8 + 4 = tensor([2, 1])
6 + 5 = tensor([1, 1])
7 + 5 = tensor([2, 1])
8 + 5 = tensor([3, 1])
7 + 6 = tensor([3, 1])
8 + 6 = tensor([4, 1])
8 + 7 = tensor([5, 1])


In [64]:
# only the pairs that involve 9: i + 9 for i in 0..8
for i in range(9):
    j = 9
    print(f"{i} + {j} = {add(i, j)[0]}")

0 + 9 = tensor([ 0,  9, 10])
1 + 9 = tensor([ 1,  9, 10])
2 + 9 = tensor([ 2,  9, 10])
3 + 9 = tensor([ 3,  9,  3, 10])
4 + 9 = tensor([ 4,  9,  1, 10])
5 + 9 = tensor([ 5,  9,  3, 10])
6 + 9 = tensor([ 6,  9,  1, 10])
7 + 9 = tensor([ 7,  9,  8, 10])
8 + 9 = tensor([ 8,  9, 10])


In [59]:
add(a=3, b=9)

tensor([[ 3,  9,  3, 10]])