# Generative Model from Scratch
---

In [101]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [26]:
SEED = 1337

## 1. Data

In [2]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-01-28 11:43:26--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8003::154, 2606:50c0:8000::154, 2606:50c0:8001::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8003::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.2’


2024-01-28 11:43:26 (8.27 MB/s) - ‘input.txt.2’ saved [1115394/1115394]



### 1.1 Read data

In [4]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print("length of dataset in characters: ", len(text))
print("first 1000 characters: ", text[:1000])

length of dataset in characters:  1115394
first 1000 characters:  First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods k

### 1.2 Tokenizer

In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [7]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


### 1.3 Tokenize Data

In [14]:
data = torch.tensor(encode(text), dtype=torch.long)
data.shape, data.dtype

(torch.Size([1115394]), torch.int64)

In [12]:
data

tensor([18, 47, 56,  ..., 45,  8,  0])

In [25]:
n = int(len(data)*0.9)
train_data, val_data = data[:n], data[n:]

train_data.shape, val_data.shape

(torch.Size([1003854]), torch.Size([111540]))

### 1.4 Batch + Chunk

In [81]:
block_size = 8  # chunk size
batch_size = 4

In [23]:
for t in range(1, block_size+1):
    inputs = data[:t]
    targets = data[t]
    print(inputs, "->", targets)

tensor([18]) -> tensor(47)
tensor([18, 47]) -> tensor(56)
tensor([18, 47, 56]) -> tensor(57)
tensor([18, 47, 56, 57]) -> tensor(58)
tensor([18, 47, 56, 57, 58]) -> tensor(1)
tensor([18, 47, 56, 57, 58,  1]) -> tensor(15)
tensor([18, 47, 56, 57, 58,  1, 15]) -> tensor(47)
tensor([18, 47, 56, 57, 58,  1, 15, 47]) -> tensor(58)


In [341]:
torch.manual_seed(SEED)

def random_chunks(data, block_size, batch_size):
    idxs = torch.randint(len(data) - block_size, (batch_size, ))
    return (
        torch.stack([data[i: i+block_size] for i in idxs]), 
        torch.stack([data[i+1: i+1+block_size] for i in idxs]),
    )

x, y = random_chunks(data=val_data, block_size=block_size, batch_size=batch_size)
print(x.shape, y.shape)

for b in range(batch_size):
    for t in range(block_size):
        ctx = x[b, :t+1]
        tgt = y[b, t]
        print(ctx, "->", tgt)

torch.Size([32, 8]) torch.Size([32, 8])
tensor([6]) -> tensor(1)
tensor([6, 1]) -> tensor(52)
tensor([ 6,  1, 52]) -> tensor(53)
tensor([ 6,  1, 52, 53]) -> tensor(58)
tensor([ 6,  1, 52, 53, 58]) -> tensor(1)
tensor([ 6,  1, 52, 53, 58,  1]) -> tensor(58)
tensor([ 6,  1, 52, 53, 58,  1, 58]) -> tensor(47)
tensor([ 6,  1, 52, 53, 58,  1, 58, 47]) -> tensor(50)
tensor([6]) -> tensor(1)
tensor([6, 1]) -> tensor(54)
tensor([ 6,  1, 54]) -> tensor(50)
tensor([ 6,  1, 54, 50]) -> tensor(39)
tensor([ 6,  1, 54, 50, 39]) -> tensor(52)
tensor([ 6,  1, 54, 50, 39, 52]) -> tensor(58)
tensor([ 6,  1, 54, 50, 39, 52, 58]) -> tensor(43)
tensor([ 6,  1, 54, 50, 39, 52, 58, 43]) -> tensor(58)
tensor([1]) -> tensor(58)
tensor([ 1, 58]) -> tensor(46)
tensor([ 1, 58, 46]) -> tensor(47)
tensor([ 1, 58, 46, 47]) -> tensor(57)
tensor([ 1, 58, 46, 47, 57]) -> tensor(1)
tensor([ 1, 58, 46, 47, 57,  1]) -> tensor(50)
tensor([ 1, 58, 46, 47, 57,  1, 50]) -> tensor(47)
tensor([ 1, 58, 46, 47, 57,  1, 50, 47]) -

## 2. Model

### 2.1 BiGramLanguage Model

In [343]:
torch.manual_seed(SEED)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_size)

    @property
    def device(self):
        return self.embedding.weight.device

    def forward(self, idx, targets=None):
        logits = self.embedding(idx) # B, T, C, (4, 8) --> (4, 8, 65)

        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            return logits, loss
        
        return logits, None
    
    def generate(self, idx, max_new_tokens):
        
        for _ in range(max_new_tokens):
            logits, _ = self(idx) # B, T, C
            last_logits = logits[:, -1, :]  # B, -1, C --> B, C
           
            probs = F.softmax(last_logits, dim=-1)  # B, C
            next_idx = torch.multinomial(probs, num_samples=1) # B, 1
            idx = torch.cat([idx, next_idx], dim=1)
        return idx
    

m = BigramLanguageModel(vocab_size=vocab_size)
logits, loss = m(x, y)  
print(logits.shape) 
print(loss)

start_token = torch.zeros((1, 1), dtype=torch.long)
generated_tokens = m.generate(idx=start_token, max_new_tokens=500)
print(generated_tokens.shape)

print(decode(generated_tokens[0].tolist()))

torch.Size([256, 65])
tensor(4.5696, grad_fn=<NllLossBackward0>)
torch.Size([1, 501])

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3!dcbf?pGXepydZJSrF$Jrqt!:wwWSzPNxbjPiD&Q!a;yNt$Kr$o-gC$WSjJqfBKBySKtSKpwNNfyl&w:q-jluBatD$Lj;?yzyUca!UQ!vrpxZQgC-hlkq,ptKqHoiX-jjeLJ &slERj KUsBOL!mpJO!zLg'wNfqHAMgq'hZCWhu.W.IBcP 
RFJ&DEs,nw?pxE?xjNHHVxJ&D&vWWToiERJFuszPyZaNw$
EQJMgzaveDDIoiMl&sMHkzdRptRCPVjwW.RSVMjs-bgRkzrBTEa!!oP fRSxq.PLboTMkX'DUYepxIBFAYuxKXe.jeh
sa!3MGFrSjuM:wX!?BTMl!.?,M:bQzPHpYfN!Cbo'MmtDxBkDD3SBjyFdmY'DOqkWeRjlxyJB-bVbfd&


In [179]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [287]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [346]:
m = BigramLanguageModel(vocab_size=vocab_size)
m = m.to(device)
print(m.device)

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

cuda:0


In [347]:
batch_size = 32

for steps in range(10_000):

    x, y = random_chunks(data=train_data, block_size=block_size, batch_size=batch_size)
    x, y = x.to(device=device), y.to(device=device)
    logits, loss = m(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # print(loss.item())
print(loss.item())

2.456115484237671


In [352]:
gen_text = decode(m.generate(idx=torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=500)[0].tolist())
print(gen_text)


Thins; s ookesthouk bl,-mer, s, es s;
RICld t olk verescee Clll!
Y theee isheresed we h, st ar
ENCELONcedrveiryo derk ht kng
NG thin.
To'd!RWANERUSTRI she. ar?
EEYOR:
IA:
Whthathe couf mhir, byoass ooung, hathese hawaye hernd end,
My wiflouth se ma ar h ccr ces cilt, ofamongl,
Youlateald weat r theat bef fowalo'dooorrt bur IULENRerower lifl he punint gur outo theayoourer:
Anow.
Whur woury po my,
The,
IALEseng,
Jag. d bV:
EDofld
Lor s PR pat ingr brifo s men'ed.
Whathme s his ote h henike o an we


## Maths behind self-attention

### Toy example

In [394]:
torch.manual_seed(SEED)

B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [405]:
# Version 1: Get average of context vectors
# Brute force method

xbow = torch.zeros(B, T, C)     # Dummy output
for b in range(B):              # For each batch    
    for t in range(T):          # For each token
        ctx = x[b, :t+1]        # Get context
        avg = ctx.mean(dim=0)   # Average context tokens
        xbow[b, t] = avg        # Store in output dummy tensor
        
xbow.shape, xbow[0]

(torch.Size([4, 8, 2]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

In [413]:
# Version 2: using triangular matrix for context

wei = torch.tril(torch.ones((T, T)))    # Lower Triangular matrix
wei /= wei.sum(-1, keepdim=True)        # Normalize weights along each row to sum as 1.
xbow2 = wei @ x                         # multiply wei (weights) with x, to get weights average. B is broadcasted


print(f"{wei=}")
xbow2.shape, xbow2[0]

wei=tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


(torch.Size([4, 8, 2]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

In [414]:
# Version 3 : Softmax
# Same as version 2, but using softmax instead of manual normalization,
# This helps in version 4's intution to calculate wei matrix as a dot product of two vectors. (similarity of key, query)

tril = torch.tril(torch.ones((T, T)))               # Lower Triangular matrix
wei = torch.zeros((T, T))                           # Dummy weights matrix
wei = wei.masked_fill(tril == 0, float('-inf'))     # Mask out upper triangular matrix to -inf for softmax to return 0
wei = F.softmax(wei, dim=-1)                        # Softmax along rows, gives out same wei as version 2.
xbow3 = wei @ x

print(f"{wei=}")
xbow3.shape, xbow3[0]

wei=tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


(torch.Size([4, 8, 2]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

In [419]:
# Version 4 : Dot Product or Self Attention 
# Using dot product of query and key to calculate weights (wei)

n_head_dim = 16

key = nn.Linear(C, n_head_dim, bias=False)      # Query
query = nn.Linear(C, n_head_dim, bias=False)    # Key
value = nn.Linear(C, n_head_dim, bias=False)    # Value

k = key(x)                                      # B, T, C --> B, T, n_head_dim
q = query(x)                                    # B, T, C --> B, T, n_head_dim
v = value(x)                                    # B, T, C --> B, T, n_head_dim

wei = q @ k.transpose(-2, -1)                   # B, T, n_head_dim @ B, n_head_dim, T --> B, T, T

tril = torch.tril(torch.ones((T, T)))               # Lower Triangular matrix
# wei = torch.zeros((T, T))                           # Dummy weights matrix
wei = wei.masked_fill(tril == 0, float('-inf'))     # Mask out upper triangular matrix to -inf for softmax to return 0
wei = F.softmax(wei, dim=-1)                        # Softmax along rows, gives out same wei as version 2.
xbow4 = wei @ v

print(f"{wei[0]=}")  # Note here wei is not braodcasted, as we have B here. 
xbow4.shape, xbow4[0]

wei[0]=tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6148, 0.3852, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3517, 0.2045, 0.4438, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2126, 0.0950, 0.3036, 0.3887, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1686, 0.3346, 0.2068, 0.2505, 0.0395, 0.0000, 0.0000, 0.0000],
        [0.0900, 0.1956, 0.0530, 0.0361, 0.0583, 0.5671, 0.0000, 0.0000],
        [0.1370, 0.0900, 0.1473, 0.1532, 0.2437, 0.1015, 0.1273, 0.0000],
        [0.0628, 0.5481, 0.0491, 0.0444, 0.0026, 0.1951, 0.0970, 0.0010]],
       grad_fn=<SelectBackward0>)


(torch.Size([4, 8, 16]),
 tensor([[-0.1027,  0.1132, -0.0518,  0.1425, -0.1165,  0.0059, -0.1514,  0.0051,
           0.1137, -0.1668,  0.0495,  0.0394,  0.0073, -0.0748,  0.0068, -0.1065],
         [-0.1870,  0.0150, -0.2778,  0.0643, -0.1843, -0.0262, -0.2306,  0.0869,
          -0.0337, -0.2108, -0.2747, -0.2018,  0.2746,  0.0031,  0.2642,  0.0464],
         [-0.2004,  0.1753, -0.1448,  0.2315, -0.2206,  0.0027, -0.2846,  0.0285,
           0.1646, -0.3032,  0.0097,  0.0118,  0.0765, -0.1127,  0.0733, -0.1506],
         [-0.2465,  0.3421, -0.0568,  0.4139, -0.2897,  0.0278, -0.3799, -0.0163,
           0.3617, -0.4345,  0.2533,  0.1955, -0.0788, -0.2307, -0.0765, -0.3443],
         [-0.2348,  0.1943, -0.1804,  0.2599, -0.2569,  0.0010, -0.3309,  0.0379,
           0.1788, -0.3499, -0.0100, -0.0022,  0.1049, -0.1239,  0.1005, -0.1624],
         [ 0.0733, -0.4034, -0.2727, -0.4311,  0.1299, -0.0666,  0.1839,  0.1273,
          -0.4877,  0.2757, -0.6514, -0.4901,  0.4360,  0.2884,  0.4