<a href="https://colab.research.google.com/github/R0bk/ml_replications/blob/main/07_decoder_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import requests

# Download the Tiny Shakespeare corpus used throughout the notebook.
# The timeout keeps a stalled connection from hanging the kernel, and
# raise_for_status() makes an HTTP error fail here instead of letting an
# error page silently become the training text.
_response = requests.get(
    'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt',
    timeout=30,
)
_response.raise_for_status()
tiny_shake = _response.text
print(tiny_shake[:300])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us


In [None]:
import torch

# Character-level vocabulary: every distinct character of the corpus, sorted.
vocab = sorted(set(tiny_shake))
vocab_size = len(vocab)
print(f'Len of vocab: {vocab_size}, vocab: {"".join(vocab)}')

# Lookup tables between characters and their integer ids.
id2vocab = dict(enumerate(vocab))
vocab2id = {ch: idx for idx, ch in id2vocab.items()}


def encode(x):
    """Turn a string into a 1-D tensor holding the id of each character."""
    return torch.tensor([vocab2id[ch] for ch in x])


def decode(x):
    """Turn an iterable of ids (ints or 0-d tensors) back into a string."""
    return ''.join(id2vocab[int(tok)] for tok in x)

Len of vocab: 65, vocab: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [None]:
# Character index at 90% of the corpus: everything before it is training text,
# everything from it onwards is held out for validation.
tt_split = int(0.9*len(tiny_shake))

train_shake, val_shake = tiny_shake[:tt_split], tiny_shake[tt_split:]

In [None]:
# Number of characters per training example; passed to get_batch as `seq_len`.
max_seq_len = 8


def get_batch(data: torch.Tensor, b: int, seq_len: int):
    """Sample a random batch of next-token training pairs from a token sequence.

    Args:
        data: 1-D tensor of token ids (for example the output of `encode`).
            It has to be a tensor rather than a Python list because the
            slices are combined with `torch.stack`.
        b: number of sequences in the batch.
        seq_len: length of every sequence.

    Returns:
        Tuple `(xb, yb)`, each of shape `(b, seq_len)`. `yb[j]` is the window
        of `data` that starts one position after `xb[j]`, i.e. the next-token
        target for every position of `xb[j]`.

    Raises:
        ValueError: if `data` holds fewer than `seq_len + 1` tokens.
    """
    # A window starting at i reads data[i : i + seq_len + 1], so valid starts
    # are 0 .. len(data) - seq_len - 1 (randint's upper bound is exclusive).
    n_starts = len(data) - seq_len
    if n_starts < 1:
        raise ValueError(
            f'need at least seq_len + 1 = {seq_len + 1} tokens, got {len(data)}'
        )
    starts = torch.randint(n_starts, (b,)).tolist()
    xb = torch.stack([data[i:i + seq_len] for i in starts])
    yb = torch.stack([data[i + 1:i + seq_len + 1] for i in starts])
    return xb, yb

# Draw a small batch and decode every row back to text as a sanity check of
# the sampling and of the encode/decode round trip.
x, y = get_batch(encode(train_shake), 4, max_seq_len)
print(x)
# The loop variable is not called `y`, so the targets returned by get_batch
# are still available after the loop.
for i, row in enumerate(x):
    print(i, decode(row))

# Inputs and targets must both be (batch, seq_len).
assert x.shape == y.shape == (4, max_seq_len)
def split_batch(batch):
    """Lazily yield (prefix, next-column) pairs from a (b, seqlen) batch.

    For every cut position 1 .. seqlen-1 this yields the first `cut` columns
    together with the single column that follows them, shaped (b, cut) and
    (b, 1) respectively.
    """
    n_cols = batch.shape[1]  # batch is (b, seqlen)
    yield from (
        (batch[:, :cut], batch[:, cut:cut + 1])
        for cut in range(1, n_cols)
    )


# Calling a generator function only builds the generator object (that is what
# gets displayed); no slicing happens until it is iterated.
split_batch(x)

tensor([[58,  1, 45, 56, 39, 41, 47, 53],
        [10,  0, 32, 46, 43,  1, 41, 59],
        [50, 50,  1, 46, 43,  1, 41, 53],
        [ 1, 39, 51,  1, 39, 44, 44, 47]])
0 t gracio
1 :
The cu
2 ll he co
3  am affi


<generator object split_batch at 0x7fc7345f2740>

In [None]:
import torch.nn.functional as F
from torch import nn

class BigramModel(nn.Module):
    """Bigram language model.

    A single (vocab_size x vocab_size) embedding table: the row looked up for
    a token is used directly as the logits over the next token.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, y=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            x: integer token ids of shape (B, T).
            y: optional integer targets of shape (B, T).

        Returns:
            Logits of shape (B, T, vocab_size) when `y` is None; otherwise a
            tuple `(logits, loss)` where the logits are flattened to
            (B*T, vocab_size) and `loss` is the mean cross-entropy.
        """
        logits = self.embedding(x)  # (B, T) -> (B, T, vocab_size)

        if y is None:
            return logits

        B, T, C = logits.shape
        # reshape (not view) so that non-contiguous targets are accepted too.
        logits = logits.reshape(B * T, C)
        y = y.reshape(-1)

        loss = F.cross_entropy(logits, y)

        return logits, loss

    @torch.no_grad()  # sampling never needs gradients
    def prompt(self, pmt: str, mtkn: int):
        """Sample `mtkn` further characters after the prompt string `pmt`.

        Prints the encoded prompt and returns the prompt plus the sampled
        continuation, decoded back to a string.
        """
        x = encode(pmt).unsqueeze(0)  # (1, T)
        print(x)

        for _ in range(mtkn):
            logits = self(x)  # (B, T, C)
            logits = logits[:, -1, :]  # keep the last position only -> (B, C)

            probs = F.softmax(logits, dim=-1)

            next_id = torch.multinomial(probs, num_samples=1)  # (B, 1)

            x = torch.cat([x, next_id], dim=1)
        return decode(x[0])


# One logit per vocabulary entry; vocab_size is len(vocab).
m = BigramModel(vocab_size)

In [None]:
# Sample 50 characters after the prompt 'Hi'. No training step has run on `m`
# at this point in the notebook, so random-looking text is expected.
m.prompt('Hi', 50)

tensor([[20, 47]])


"Hi LLG\nOQ$cZ3fXnORhv'skcpC.NxIs ,WpeFei.ZAvph\n !e&av"

In [None]:
# print() renders the newlines of the sample instead of showing the escaped repr.
print(m.prompt('Hi', 150))

tensor([[20, 47]])
HitcKsAqg-ovndQ?yT&v&m:hrPyNdPhwD-;J&KyiTyi..ZZxhe:huTyYXNQJ.Cmm.qUTrXHsJx,ffpC-yzMHpzIe'x!UAVlW3lFwH.-BNPYQoMUJtfZ; $?swovl$HqIbg,CQjsw &P.$ck
HbruNtQZ


In [None]:
# Forward pass on a single sequence with its targets, to inspect the loss.
x, y = get_batch(encode(train_shake), 1, max_seq_len)
print(x, y)
# NOTE: `x` is rebound here, from the input ids to the logits returned by the model.
x, l = m(x, y)
print(l)

tensor([[39, 56, 43,  1, 44, 47, 58,  1]]) tensor([[56, 43,  1, 44, 47, 58,  1, 58]])
tensor(4.0485, grad_fn=<NllLossBackward0>)


In [None]:
# Greedy decoding: start from a random training snippet and repeatedly append
# the single most likely next character.
# get_batch returns (inputs, targets); only the inputs are needed here.
x, _ = get_batch(encode(train_shake), 1, max_seq_len)
print(x)
out = decode(x[0])
print(out)

nsteps = 100
with torch.no_grad():  # inference only
    for i in range(nsteps):
        logits = m(x)  # (1, T, vocab_size)
        # Take the argmax over the vocabulary at the LAST position only; an
        # argmax over the whole tensor would give a flattened index that is
        # not a token id.
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (1, 1)
        x = torch.cat([x, next_id], dim=1)

print(decode(x[0]))

(tensor([[58, 46, 47, 52, 49,  1, 63, 53]]), tensor([[46, 47, 52, 49,  1, 63, 53, 59]]))


ValueError: ignored

In [None]:
# Learning rate for the bigram model. It has to be passed explicitly:
# AdamW otherwise falls back to its own default and `lr` is silently ignored.
lr = 1e-4
optim = torch.optim.AdamW(m.parameters(), lr=lr)

In [None]:
steps = 100

# Encode the training text once; doing it inside the loop would rebuild the
# full id tensor on every step.
train_data = encode(train_shake)

for t in range(steps):
    xb, yb = get_batch(train_data, 32, max_seq_len)

    output, loss = m(xb, yb)
    optim.zero_grad()  # clear gradients left over from the previous step
    loss.backward()
    print(loss)

    optim.step()


tensor(4.5281, grad_fn=<NllLossBackward0>)
tensor(4.6831, grad_fn=<NllLossBackward0>)
tensor(4.5577, grad_fn=<NllLossBackward0>)
tensor(4.5618, grad_fn=<NllLossBackward0>)
tensor(4.4680, grad_fn=<NllLossBackward0>)
tensor(4.5053, grad_fn=<NllLossBackward0>)
tensor(4.5063, grad_fn=<NllLossBackward0>)
tensor(4.4562, grad_fn=<NllLossBackward0>)
tensor(4.5525, grad_fn=<NllLossBackward0>)
tensor(4.4821, grad_fn=<NllLossBackward0>)
tensor(4.5869, grad_fn=<NllLossBackward0>)
tensor(4.6048, grad_fn=<NllLossBackward0>)
tensor(4.5572, grad_fn=<NllLossBackward0>)
tensor(4.5717, grad_fn=<NllLossBackward0>)
tensor(4.4437, grad_fn=<NllLossBackward0>)
tensor(4.4187, grad_fn=<NllLossBackward0>)
tensor(4.5367, grad_fn=<NllLossBackward0>)
tensor(4.5730, grad_fn=<NllLossBackward0>)
tensor(4.5213, grad_fn=<NllLossBackward0>)
tensor(4.4505, grad_fn=<NllLossBackward0>)
tensor(4.4537, grad_fn=<NllLossBackward0>)
tensor(4.5617, grad_fn=<NllLossBackward0>)
tensor(4.4751, grad_fn=<NllLossBackward0>)
tensor(4.51

In [None]:
B, S, E = 2, 4, 6  # batch size, sequence length, embedding width of the toy tensors

# Version 1: explicit loop building a running (cumulative) sum along the
# sequence axis.
x = torch.rand((B, S, E))

xsum = x.clone()
for tokeni in range(1, S):
    xsum[:, tokeni, :] += xsum[:, tokeni-1, :]


# Version 2: causal running mean as a single matmul. x is laid out (B, E, S)
# here, so the weights are upper-triangular and every column is normalised to
# sum to 1.
x = torch.randint(10, (B,S,E), dtype=torch.float).transpose(1,2)
wei = torch.triu(torch.ones((S,S)))
wei = wei / wei.sum(dim=0)
xavg = x @ wei  # (B, E, S) @ (S, S) -> (B, E, S)


# Version 3: the same running mean with x as (B, S, E). Future positions are
# masked with -inf, so the softmax of the remaining zeros gives uniform weights
# over the current and earlier positions.
x = torch.randint(10, (B,S,E), dtype=torch.float)
print(x)
wei = torch.zeros((S,S))
wei = wei.masked_fill(torch.tril(torch.ones(S,S))==0, float('-inf'))
wei = F.softmax(wei, dim=-1)
print(wei)
print(wei.shape, x.shape)
print(wei @ x)

# Version 4: self-attention, single head
head_size = 16

x = torch.rand((B,S,E))

key = nn.Linear(E, head_size, bias=False)
query = nn.Linear(E, head_size, bias=False)
value = nn.Linear(E, head_size, bias=False)

k = key(x) # (B,S,E) -> (B,S,head_size)
q = query(x) # (B,S,E) -> (B,S,head_size)

wei = q @ k.transpose(1,2) # (B,S,head_size) @ (B,head_size,S) -> (B,S,S)
# Scale the scores by 1/sqrt(head_size). Dividing by head_size**-0.5 would
# multiply them by sqrt(head_size) instead.
wei = wei * head_size**-0.5

wei = wei.masked_fill(torch.tril(torch.ones(S,S))==0, float('-inf'))
wei = F.softmax(wei, dim=-1)
print(wei)
print(wei.shape, x.shape)

v = value(x) # (B,S,E) -> (B,S,head_size)

out = wei @ v # (B,S,S) @ (B,S,head_size) -> (B,S,head_size)

tensor([[[7., 7., 2., 5., 5., 7.],
         [9., 6., 1., 3., 1., 1.],
         [8., 4., 8., 9., 9., 4.],
         [9., 7., 4., 9., 9., 2.]],

        [[4., 6., 6., 1., 7., 2.],
         [0., 8., 7., 8., 3., 5.],
         [6., 3., 7., 7., 8., 5.],
         [4., 5., 2., 0., 2., 4.]]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500]])
torch.Size([4, 4]) torch.Size([2, 4, 6])
tensor([[[7.0000, 7.0000, 2.0000, 5.0000, 5.0000, 7.0000],
         [8.0000, 6.5000, 1.5000, 4.0000, 3.0000, 4.0000],
         [8.0000, 5.6667, 3.6667, 5.6667, 5.0000, 4.0000],
         [8.2500, 6.0000, 3.7500, 6.5000, 6.0000, 3.5000]],

        [[4.0000, 6.0000, 6.0000, 1.0000, 7.0000, 2.0000],
         [2.0000, 7.0000, 6.5000, 4.5000, 5.0000, 3.5000],
         [3.3333, 5.6667, 6.6667, 5.3333, 6.0000, 4.0000],
         [3.5000, 5.5000, 5.5000, 4.0000, 5.0000, 4.0000]]])
tensor([[[1.0000, 0.0000, 0.00

In [None]:
# Hyperparameters read as module-level globals by the classes in this cell.
n_embed = 32      # width of the token and position embeddings
block_size = 32   # number of position embeddings and side length of the causal mask
max_seq_len = 32  # sequence length passed to get_batch by the training loop
# head_size = 16

class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Reads the module-level `n_embed` (input width) and `block_size` (side
    length of the causal mask) when it is constructed.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # Lower-triangular mask. Registered as a buffer so that it is part of
        # the module's state without being a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Attend causally over `x`.

        Args:
            x: tensor of shape (B, T, n_embed); T must not exceed the size of
                the mask built in `__init__`.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        B, T, E = x.shape

        k = self.key(x) # (B,T,head_size)
        q = self.query(x) # (B,T,head_size)
        v = self.value(x) # (B,T,head_size)

        wei = q @ k.transpose(1, 2) # (B,T,T)
        # Scale by this head's own width, taken from the projection output,
        # instead of whatever a global named `head_size` currently holds.
        wei = wei * k.shape[-1]**-0.5
        # Hide future positions so that each token only attends backwards.
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf'))
        wei = F.softmax(wei, dim=-1)

        return wei @ v # (B,T,T) @ (B,T,head_size) -> (B, T, head_size)


class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, then projected to `n_embed`.

    Args:
        head_size: output width of every individual head.
        head_count: number of heads.
    """

    def __init__(self, head_size, head_count):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(head_count)])
        # The projection consumes the concatenated head outputs, so its input
        # width is head_size * head_count. That equals n_embed when the heads
        # split the embedding evenly, but is not assumed to.
        self.proj = nn.Linear(head_size * head_count, n_embed)

    def forward(self, x):
        """Map (B, T, n_embed) to (B, T, n_embed) through all heads."""
        out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, head_size*head_count)
        return self.proj(out)


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear, with a hidden width of 4 * n_embed."""

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed
        self.l1 = nn.Linear(n_embed, hidden)
        self.proj = nn.Linear(hidden, n_embed)

    def forward(self, x):
        """Expand to the hidden width, apply ReLU, and project back to n_embed."""
        return self.proj(F.relu(self.l1(x)))

class Block(nn.Module):
    """Transformer block: self-attention then feed-forward.

    Each sub-layer is applied to a layer-normalised copy of its input and the
    result is added back onto that input (residual connection).
    """

    def __init__(self, n_embed, n_head):
        super().__init__()
        # Each head gets an equal (floor-divided) share of the embedding width.
        self.sa_head = MultiHeadAttention(n_embed // n_head, n_head)
        self.ffwd = FeedForward(n_embed)
        self.lnorm1, self.lnorm2 = nn.LayerNorm(n_embed), nn.LayerNorm(n_embed)

    def forward(self, x):
        """Apply both residual sub-layers to `x`; the shape is unchanged."""
        attended = self.sa_head(self.lnorm1(x))  # (B, T, n_embed)
        x = x + attended
        transformed = self.ffwd(self.lnorm2(x))
        return x + transformed

class BigramModel(nn.Module):
    """Small decoder-only transformer language model.

    The class name is kept so existing calls keep working, but this is not a
    bigram model: token and position embeddings go through a stack of
    attention blocks before the output head. Reads the module-level
    `vocab_size`, `n_embed` and `block_size` when it is constructed.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embed)
        self.pos_embedding = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            Block(n_embed, 4),
            Block(n_embed, 4),
            Block(n_embed, 4),
            nn.LayerNorm(n_embed)
        )
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, x, y=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            x: integer token ids of shape (B, T); T may not exceed the number
                of position embeddings (`block_size`).
            y: optional integer targets of shape (B, T).

        Returns:
            Logits of shape (B, T, vocab_size) when `y` is None; otherwise a
            tuple `(logits, loss)` where the logits are flattened to
            (B*T, vocab_size) and `loss` is the mean cross-entropy.
        """
        B, T = x.shape

        token_embed = self.token_embedding(x) # (B, T) -> (B, T, n_embed)
        # Build the position ids on the same device as the input, so that the
        # lookup also works once the model and data are moved off the CPU.
        positions = torch.arange(T, device=x.device)
        pos_embed = self.pos_embedding(positions) # (T, n_embed), broadcast over B
        x = token_embed + pos_embed
        x = self.blocks(x)

        logits = self.lm_head(x) # (B, T, n_embed) -> (B, T, vocab_size)

        if y is None:
            return logits

        B, T, C = logits.shape
        # reshape (not view) so that non-contiguous targets are accepted too.
        logits = logits.reshape(B * T, C)
        y = y.reshape(-1)

        loss = F.cross_entropy(logits, y)

        return logits, loss

    @torch.no_grad()  # sampling never needs gradients
    def prompt(self, pmt: str, mtkn: int):
        """Sample `mtkn` further characters after the prompt string `pmt`.

        Prints the encoded prompt and returns the prompt plus the sampled
        continuation, decoded back to a string.
        """
        x = encode(pmt).unsqueeze(0)  # (1, T)
        print(x)

        for _ in range(mtkn):
            # Only the most recent block_size tokens fit the position table.
            x_cond = x[:, -block_size:]

            logits = self(x_cond) # (B, T, C)

            logits = logits[:, -1, :] # keep the last position only -> (B, C)

            probs = F.softmax(logits, dim=-1)

            next_id = torch.multinomial(probs, num_samples=1) # (B, 1)

            x = torch.cat([x, next_id], dim=1)
        return decode(x[0])


# Rebinds `m`: from here on it is the attention-based model defined in this
# cell, created with freshly initialised weights.
m = BigramModel()

In [None]:
# Learning rate for the transformer. It has to be passed explicitly:
# AdamW otherwise falls back to its own default and `lr` is silently ignored.
lr = 3e-4
optim = torch.optim.AdamW(m.parameters(), lr=lr)

steps = 2500

# Encode the training text once; doing it inside the loop would rebuild the
# full id tensor on every step.
train_data = encode(train_shake)

for t in range(steps):
    xb, yb = get_batch(train_data, 32, max_seq_len)

    output, loss = m(xb, yb)
    optim.zero_grad()  # clear gradients left over from the previous step
    loss.backward()
    print(t, loss)

    optim.step()


0 tensor(4.3447, grad_fn=<NllLossBackward0>)
1 tensor(4.2665, grad_fn=<NllLossBackward0>)
2 tensor(4.2109, grad_fn=<NllLossBackward0>)
3 tensor(4.1300, grad_fn=<NllLossBackward0>)
4 tensor(4.1077, grad_fn=<NllLossBackward0>)
5 tensor(4.0284, grad_fn=<NllLossBackward0>)
6 tensor(3.9783, grad_fn=<NllLossBackward0>)
7 tensor(3.9293, grad_fn=<NllLossBackward0>)
8 tensor(3.8856, grad_fn=<NllLossBackward0>)
9 tensor(3.8378, grad_fn=<NllLossBackward0>)
10 tensor(3.7974, grad_fn=<NllLossBackward0>)
11 tensor(3.8055, grad_fn=<NllLossBackward0>)
12 tensor(3.7279, grad_fn=<NllLossBackward0>)
13 tensor(3.7594, grad_fn=<NllLossBackward0>)
14 tensor(3.6441, grad_fn=<NllLossBackward0>)
15 tensor(3.6660, grad_fn=<NllLossBackward0>)
16 tensor(3.6011, grad_fn=<NllLossBackward0>)
17 tensor(3.5662, grad_fn=<NllLossBackward0>)
18 tensor(3.5600, grad_fn=<NllLossBackward0>)
19 tensor(3.5942, grad_fn=<NllLossBackward0>)
20 tensor(3.5567, grad_fn=<NllLossBackward0>)
21 tensor(3.5351, grad_fn=<NllLossBackward0>

In [None]:
# Sample 250 characters from the model after the training loop, starting from 't'.
print(m.prompt('t', 250))

tensor([[58]])
to warrruck of wheman come. Gor too, they of whe with ulintase;
Thisch somw whis'd you.
Thorewe hat morre, leger, alese ret of waltus,
The, you tid knore, him kine.

QUCUTUTOLO:
Jhat him! i'l luff ne thein atre han do;
And not bether benfuterbr to hou
