In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm

In [47]:
# Create Vocabulary
chars = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "=", "."]

# Tokenization
stoi = {s:i for i, s in enumerate(chars)}
itos = {i:s for s, i in stoi.items()}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

In [74]:
# Hyperparameters
VOCAB_SIZE = len(chars)
MAX_NUMBER = 9 # the equations will have numbers from [0 - MAX_NUMBER]
EMBEDDING_SIZE = 32
CONTEXT_SIZE = len(str(MAX_NUMBER))*2 + len(str(MAX_NUMBER+MAX_NUMBER)) + 2 # context window size just big enough to always see the whole equation
BATCH_SIZE = 64
MAX_STEPS = 5000
LEARNING_RATE = 3E-4
BLOCK_COUNT = 2
NUM_HEADS = 4
DROPOUT = 0.2
HEAD_SIZE = EMBEDDING_SIZE // NUM_HEADS # How big Query, Key and Value matrices are
device = 'cuda' if torch.cuda.is_available() else "cpu"
EVAL_INTERVAL = 500
EVAL_LOSS_BATCHES = 200

this_model_name = "model_EX1.pth"

In [80]:
# Loader that returns a batch of equations: label="a+b=", target="c"
# Start with numbers 0-9 @TODO: Increase to larger numbers
def get_batch():
    a = torch.randint(0, MAX_NUMBER+1, (BATCH_SIZE, ))
    b = torch.randint(0, MAX_NUMBER+1, (BATCH_SIZE, ))

    equations = [f"{it1}+{it2}={it1+it2}" for it1, it2 in zip(a.tolist(), b.tolist())]
    equations = [eq + "."*(CONTEXT_SIZE - len(eq)) for eq in equations] # pad with "." at the end of equation to fill CONTEXT_SIZE

    x = torch.tensor([encode(eq) for eq in equations])
    y = torch.tensor([encode(eq[1:] + ".") for eq in equations])

    x, y = x.to(device), y.to(device)
    return x, y

In [82]:
# SHOW FIRST EXAMPLE OF BATCH
xb, yb = get_batch()
print(xb.shape, yb.shape)
print(xb[0], yb[0])
for i in range(CONTEXT_SIZE):
    labels = decode(xb[0][:i+1].tolist())
    target = itos[yb[0][i].item()]
    print(f"{labels} => {target}")

torch.Size([64, 6]) torch.Size([64, 6])
tensor([ 3, 10,  6, 11,  9, 12]) tensor([10,  6, 11,  9, 12, 12])
3 => +
3+ => 6
3+6 => =
3+6= => 9
3+6=9 => .
3+6=9. => .


In [25]:
""" Multiple Heads of Self-Attention that are processed in parallel """
class CausalSelfAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()

        # Single Heads in parallel
        self.query = torch.randn([num_heads, EMBEDDING_SIZE, head_size]) * 0.02
        self.key = torch.randn([num_heads, EMBEDDING_SIZE, head_size]) * 0.02
        self.value = torch.randn([num_heads, EMBEDDING_SIZE, head_size]) * 0.02

        self.dropout1 = nn.Dropout(DROPOUT)
        self.register_buffer('tril', torch.tril(torch.ones(CONTEXT_SIZE, CONTEXT_SIZE)))
        
        # Only For Multi Head
        self.proj = nn.Linear(num_heads*head_size, EMBEDDING_SIZE) # back to original size (see 3b1b Value↑ matrix)
        self.dropout2 = nn.Dropout(DROPOUT)
    
    def forward(self, x):
        n_batch, n_context, n_emb = x.shape
        num_heads, head_size = self.query.shape[0], self.query.shape[-1]

        # (num_heads, n_batch, n_context, head_size)
        q = torch.einsum('bxy,iyk->bxik', (x, self.query)).view(num_heads, n_batch, n_context, head_size)
        k = torch.einsum('bxy,iyk->bxik', (x, self.key)).view(num_heads, n_batch, n_context, head_size)
        v = torch.einsum('bxy,iyk->bxik', (x, self.value)).view(num_heads, n_batch, n_context, head_size)
        
        wei = q @ k.transpose(-2, -1) * q.shape[-1]**-0.5 # (num_heads, n_batch, n_context, n_context)
        wei = wei.masked_fill(self.tril[:n_context, :n_context] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (num_heads, n_batch, n_context, n_context)
        wei = self.dropout1(wei)

        self.out = wei @ v # (num_heads, n_batch, n_context, head_size)
        self.out = self.out.view(n_batch, n_context, num_heads*head_size)
        self.out = self.dropout2(self.proj(self.out)) # (n_batch, n_context, EMBEDDING_SIZE)
        return self.out

In [26]:
class FeedForward(nn.Module):
    def __init__(self, in_feat):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_feat, in_feat * 4),
            nn.ReLU(),
            nn.Linear(4 * in_feat, in_feat),
            nn.Dropout(DROPOUT)
        )
    
    def forward(self, x):
        return self.net(x)

In [28]:
# Transformer Block: Communication (MultiHead Attention) followed by computation (MLP - FeedForward)
class Block(nn.Module):
    def __init__(self, n_heads, head_size):
        super().__init__()
        self.sa_heads = CausalSelfAttention(n_heads, head_size)
        self.ffwd = FeedForward(EMBEDDING_SIZE)

        self.ln1 = nn.LayerNorm(EMBEDDING_SIZE)
        self.ln2 = nn.LayerNorm(EMBEDDING_SIZE)
    
    def forward(self, x):
        # x + because their are residual connections around Masked Multi-Head Attention and Feed Forward (see Transformer Architecture)
        x = x + self.sa_heads(self.ln1(x)) # (BATCH_SIZE, CONTEXT_SIZE, num_heads*head_size)
        x = x + self.ffwd(self.ln2(x)) # (BATCH_SIZE, CONTEXT_SIZE, num_heads*head_size)
        return x

In [29]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()

        # add an Embedding Table for Character Embedding
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, EMBEDDING_SIZE)
        self.position_embedding_table = nn.Embedding(CONTEXT_SIZE, EMBEDDING_SIZE)
        self.blocks = nn.Sequential(*[Block(NUM_HEADS, HEAD_SIZE) for _ in range(BLOCK_COUNT)])
        self.ln_f = nn.LayerNorm(EMBEDDING_SIZE) # final layer norm
        self.lm_head = nn.Linear(EMBEDDING_SIZE, VOCAB_SIZE)

        # better initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, x, y=None):
        n_batch, n_context = x.shape

        tok_emb = self.token_embedding_table(x) # (BATCH_SIZE, CONTEXT_SIZE, EMBEDDING_SIZE)
        pos_emb = self.position_embedding_table(torch.arange(0, n_context, device=device)) # position embedding for each char in CONTEXT (CONTEXT_SIZE, EMBEDDING_SIZE)
        x = tok_emb + pos_emb # (BATCH_SIZE, CONTEXT_SIZE, EMBEDDING_SIZE)
        x = self.blocks(x)
        x = self.ln_f(x) # (BATCH_SIZE, CONTEXT_SIZE, EMBEDDING_SIZE)
        logits = self.lm_head(x) # (BATCH_SIZE, CONTEXT_SIZE, VOCAB_SIZE)
        
        if y is None:
            loss = None
        else:
            logits = logits.view(n_batch*n_context, VOCAB_SIZE)
            y = y.view(n_batch*CONTEXT_SIZE)
            loss = F.cross_entropy(logits, y)

        return logits, loss
    
    def generate(self, previous_text, max_new_tokens):
        output = previous_text
        for _ in tqdm(range(max_new_tokens)):
            last_tokens = torch.tensor(encode(output[-CONTEXT_SIZE:]), device=device)
            
            # add batch dimension and feed to model
            logits, _ = self(last_tokens.view(1, -1))
            probs = F.softmax(logits, dim=-1)
            probs_next_char = probs[0, -1]
            new_char = itos[torch.multinomial(probs_next_char, num_samples=1).item()]

            output += new_char

        return output

In [30]:
# calculate mean loss for {EVAL_LOSS_BATCHES}x batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(EVAL_LOSS_BATCHES, device=device)
        for i in tqdm(range(EVAL_LOSS_BATCHES)):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[i] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [31]:
model = GPT()

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

for step in tqdm(range(MAX_STEPS)):
    # calculate loss every once in a while
    if step % EVAL_INTERVAL == 0:
        losses = estimate_loss()
        print(f"Step {step}/{MAX_STEPS}) train: {losses['train']:.4f}, val: {losses['val']:.4f}")

    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

100%|██████████| 200/200 [00:04<00:00, 49.09it/s]
100%|██████████| 200/200 [00:03<00:00, 63.54it/s]
  0%|          | 4/5000 [00:07<1:58:07,  1.42s/it] 

Step 0/5000) train: 4.4307, val: 4.4299


100%|██████████| 200/200 [00:02<00:00, 66.95it/s]]
100%|██████████| 200/200 [00:03<00:00, 60.14it/s]
 10%|█         | 506/5000 [00:27<26:08,  2.86it/s]

Step 500/5000) train: 2.6391, val: 2.6246


100%|██████████| 200/200 [00:03<00:00, 55.95it/s]]
100%|██████████| 200/200 [00:03<00:00, 59.50it/s]
 20%|██        | 1005/5000 [00:49<28:28,  2.34it/s]

Step 1000/5000) train: 2.4151, val: 2.4116


100%|██████████| 200/200 [00:03<00:00, 59.41it/s]s]
100%|██████████| 200/200 [00:03<00:00, 61.39it/s]
 30%|███       | 1504/5000 [01:10<31:29,  1.85it/s]

Step 1500/5000) train: 2.3192, val: 2.2993


100%|██████████| 200/200 [00:06<00:00, 32.96it/s]s]
100%|██████████| 200/200 [00:04<00:00, 43.40it/s]
 40%|████      | 2004/5000 [01:48<58:50,  1.18s/it]  

Step 2000/5000) train: 2.2506, val: 2.2500


100%|██████████| 200/200 [00:07<00:00, 26.94it/s]s]
100%|██████████| 200/200 [00:07<00:00, 26.79it/s]s]
 50%|█████     | 2501/5000 [02:37<1:36:19,  2.31s/it]

Step 2500/5000) train: 2.2076, val: 2.2083


100%|██████████| 200/200 [00:03<00:00, 62.50it/s]s]  
100%|██████████| 200/200 [00:03<00:00, 60.29it/s]
 60%|██████    | 3006/5000 [03:01<11:54,  2.79it/s]

Step 3000/5000) train: 2.1923, val: 2.1983


100%|██████████| 200/200 [00:03<00:00, 63.92it/s]s]
100%|██████████| 200/200 [00:02<00:00, 70.99it/s]
 70%|███████   | 3504/5000 [03:24<12:42,  1.96it/s]

Step 3500/5000) train: 2.1726, val: 2.1736


100%|██████████| 200/200 [00:03<00:00, 57.80it/s]s]
100%|██████████| 200/200 [00:03<00:00, 54.01it/s]
 80%|████████  | 4004/5000 [03:48<07:11,  2.31it/s]

Step 4000/5000) train: 2.1571, val: 2.1518


100%|██████████| 200/200 [00:02<00:00, 67.03it/s]s]
100%|██████████| 200/200 [00:03<00:00, 59.63it/s]
 90%|█████████ | 4504/5000 [04:11<03:15,  2.54it/s]

Step 4500/5000) train: 2.1410, val: 2.1442


100%|██████████| 5000/5000 [04:28<00:00, 18.60it/s]


In [33]:
# Inference (Generate Harry Potter'ish text)
model.eval()
output = model.generate("\n", 1000)
print(output)


100%|██████████| 1000/1000 [00:04<00:00, 216.16it/s]


ton as oves at chat lidnize nitts be this it his. PenumnaeMe andilm cione ceeart's
Dery agearmben't mup the -fop a havet hand bardourp’t dic Verneare Andss he men the moss trely
of the whe abe a noir mugh das at havery; he spalein was bad. A blemp of Crimsed. There’s'n Harry, it she noartedortpimedfonsed, suingly one, wald Ron.
Harr; Dongefivell whis and but sils?”
"Sid nus ove
wor
bibe Prublet in. S, saigho bloplayblainbreathes; so?A pepthe boorey cumbrice wroundlou coven, jabcenuc of as closing arry. They edorich romts,beampolly and
the, leired geee Maen he mi,
sMe thand he thered's. Wealld yowad lik awed you're said notathe a
got mapt, his beds, wexpblatathand the stalfoake coking the owakerstiss,
lition at at omian. The gaming of and wharme.. Harry bund I to chiy offfotter.
This te whe; yon to Me waslos uing k.” . ., Bec, harry pead the theing it,” said Sno?”
"Dunat and Tung couldtorned Mlyowly. I
pide said
HoAI geced for wich seaimsbelted exere a He fesiled andy”
“Nast wald parr


