In [42]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm
import re

In [43]:
# Create Vocabulary
chars = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "=", "."]

# Tokenization
stoi = {s:i for i, s in enumerate(chars)}
itos = {i:s for s, i in stoi.items()}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

In [44]:
# Hyperparameters
VOCAB_SIZE = len(chars)
MAX_NUMBER = 9 # the equations will have numbers from [0 - MAX_NUMBER]
EMBEDDING_SIZE = 16
CONTEXT_SIZE = len(str(MAX_NUMBER))*2 + len(str(MAX_NUMBER+MAX_NUMBER)) + 2 # context window size just big enough to always see the whole equation
BATCH_SIZE = 64
MAX_STEPS = 5000
LEARNING_RATE = 3E-4
BLOCK_COUNT = 5
NUM_HEADS = 6
DROPOUT = 0.2
HEAD_SIZE = EMBEDDING_SIZE // NUM_HEADS # How big Query, Key and Value matrices are
device = 'cuda' if torch.cuda.is_available() else "cpu"
EVAL_INTERVAL = 500
EVAL_LOSS_BATCHES = 200

this_model_name = "model_EX1.pth"

In [45]:
# Loader that returns a batch of equations: label="a+b=", target="c"
# Start with numbers 0-9 @TODO: Increase to larger numbers
def get_batch():
    a = torch.randint(0, MAX_NUMBER+1, (BATCH_SIZE, ))
    b = torch.randint(0, MAX_NUMBER+1, (BATCH_SIZE, ))

    equations = [f"{it1}+{it2}={it1+it2}" for it1, it2 in zip(a.tolist(), b.tolist())]
    equations = [eq + "."*(CONTEXT_SIZE - len(eq)) for eq in equations] # pad with "." at the end of equation to fill CONTEXT_SIZE

    x = torch.tensor([encode(eq) for eq in equations])
    y = torch.tensor([encode(eq[1:] + ".") for eq in equations])

    x, y = x.to(device), y.to(device)
    return x, y

In [46]:
# SHOW FIRST EXAMPLE OF BATCH
xb, yb = get_batch()
print(xb.shape, yb.shape)
print(xb[0], yb[0])
for i in range(CONTEXT_SIZE):
    labels = decode(xb[0][:i+1].tolist())
    target = itos[yb[0][i].item()]
    print(f'"{labels}" => "{target}"')

torch.Size([64, 6]) torch.Size([64, 6])
tensor([ 0, 10,  7, 11,  7, 12]) tensor([10,  7, 11,  7, 12, 12])
"0" => "+"
"0+" => "7"
"0+7" => "="
"0+7=" => "7"
"0+7=7" => "."
"0+7=7." => "."


In [47]:
""" Multiple Heads of Self-Attention that are processed in parallel """
class CausalSelfAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()

        # Single Heads in parallel
        self.query = torch.randn([num_heads, EMBEDDING_SIZE, head_size]) * 0.02
        self.key = torch.randn([num_heads, EMBEDDING_SIZE, head_size]) * 0.02
        self.value = torch.randn([num_heads, EMBEDDING_SIZE, head_size]) * 0.02

        self.dropout1 = nn.Dropout(DROPOUT)
        self.register_buffer('tril', torch.tril(torch.ones(CONTEXT_SIZE, CONTEXT_SIZE)))
        
        # Only For Multi Head
        self.proj = nn.Linear(num_heads*head_size, EMBEDDING_SIZE) # back to original size (see 3b1b Value↑ matrix)
        self.dropout2 = nn.Dropout(DROPOUT)
    
    def forward(self, x):
        n_batch, n_context, n_emb = x.shape
        num_heads, head_size = self.query.shape[0], self.query.shape[-1]

        # (num_heads, n_batch, n_context, head_size)
        q = torch.einsum('bxy,iyk->bxik', (x, self.query)).view(num_heads, n_batch, n_context, head_size)
        k = torch.einsum('bxy,iyk->bxik', (x, self.key)).view(num_heads, n_batch, n_context, head_size)
        v = torch.einsum('bxy,iyk->bxik', (x, self.value)).view(num_heads, n_batch, n_context, head_size)
        
        wei = q @ k.transpose(-2, -1) * q.shape[-1]**-0.5 # (num_heads, n_batch, n_context, n_context)
        wei = wei.masked_fill(self.tril[:n_context, :n_context] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (num_heads, n_batch, n_context, n_context)
        wei = self.dropout1(wei)

        self.out = wei @ v # (num_heads, n_batch, n_context, head_size)
        self.out = self.out.view(n_batch, n_context, num_heads*head_size)
        self.out = self.dropout2(self.proj(self.out)) # (n_batch, n_context, EMBEDDING_SIZE)
        return self.out

In [48]:
class FeedForward(nn.Module):
    def __init__(self, in_feat):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_feat, in_feat * 4),
            nn.ReLU(),
            nn.Linear(4 * in_feat, in_feat),
            nn.Dropout(DROPOUT)
        )
    
    def forward(self, x):
        return self.net(x)

In [49]:
# Transformer Block: Communication (MultiHead Attention) followed by computation (MLP - FeedForward)
class Block(nn.Module):
    def __init__(self, n_heads, head_size):
        super().__init__()
        self.sa_heads = CausalSelfAttention(n_heads, head_size)
        self.ffwd = FeedForward(EMBEDDING_SIZE)

        self.ln1 = nn.LayerNorm(EMBEDDING_SIZE)
        self.ln2 = nn.LayerNorm(EMBEDDING_SIZE)
    
    def forward(self, x):
        # x + because their are residual connections around Masked Multi-Head Attention and Feed Forward (see Transformer Architecture)
        x = x + self.sa_heads(self.ln1(x)) # (BATCH_SIZE, CONTEXT_SIZE, num_heads*head_size)
        x = x + self.ffwd(self.ln2(x)) # (BATCH_SIZE, CONTEXT_SIZE, num_heads*head_size)
        return x

In [50]:
"""
Masking out tokens from input tokens a+b= for Loss Calculation.
The model should not receive a loss for mispredicting the next token before the "="(stoi["="]=11) sign, since it can only guess here.
Achieve by setting logits to the targets for input tokens.

"8" => "+"      | IGNORE
"8+" => "2"     | IGNORE
"8+2" => "="    | IGNORE
"8+2=" => "1"   | USE
"8+2=1" => "0"  | USE
"8+2=10" => "." | USE
"""

xb, yb = get_batch()
logits = torch.randn((BATCH_SIZE, CONTEXT_SIZE, VOCAB_SIZE))

# Step 1) Create correct logits from targets
correct_logits = torch.zeros_like(logits).scatter_(2, yb.unsqueeze(2), 1)
correct_logits[correct_logits == 0] = float('-inf')

# Step 2) For each item in batch, find out at which index in Context the "=" is
equal_idx = (yb == stoi["="]).nonzero()[:, 1]

# Step 3) Replace logits up until "=" with correct_logits without using a loop
mask = torch.arange(CONTEXT_SIZE).unsqueeze(0) <= equal_idx.unsqueeze(1)
logits = torch.where(mask.unsqueeze(2), correct_logits, logits)

logits[0], yb[0]

(tensor([[   -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf,    -inf,  1.0000,    -inf,    -inf],
         [ 1.0000,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf,    -inf,    -inf,    -inf,    -inf],
         [   -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf,    -inf,    -inf,  1.0000,    -inf],
         [-2.1674,  1.9481,  0.2714,  0.0994, -0.5013, -2.3173, -0.6925,  0.1254,
           1.2330,  0.4711,  0.9592,  0.3126, -0.2099],
         [ 0.4638,  0.3835,  0.3073,  0.4802, -0.1251,  1.0787,  0.3154, -0.9656,
          -1.3152,  0.9500,  0.2395,  0.8869,  0.2441],
         [-0.3549, -1.0478, -3.2440,  1.1116,  0.1357, -1.0144, -1.3654,  0.9175,
           1.2289,  0.3988,  0.8242,  1.7663,  0.0610]]),
 tensor([10,  0, 11,  3, 12, 12]))

In [51]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()

        # add an Embedding Table for Character Embedding
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, EMBEDDING_SIZE)
        self.position_embedding_table = nn.Embedding(CONTEXT_SIZE, EMBEDDING_SIZE)
        self.blocks = nn.Sequential(*[Block(NUM_HEADS, HEAD_SIZE) for _ in range(BLOCK_COUNT)])
        self.ln_f = nn.LayerNorm(EMBEDDING_SIZE) # final layer norm
        self.lm_head = nn.Linear(EMBEDDING_SIZE, VOCAB_SIZE)

        # better initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, x, y=None):
        n_batch, n_context = x.shape

        tok_emb = self.token_embedding_table(x) # (BATCH_SIZE, CONTEXT_SIZE, EMBEDDING_SIZE)
        pos_emb = self.position_embedding_table(torch.arange(0, n_context, device=device)) # position embedding for each char in CONTEXT (CONTEXT_SIZE, EMBEDDING_SIZE)
        x = tok_emb + pos_emb # (BATCH_SIZE, CONTEXT_SIZE, EMBEDDING_SIZE)
        x = self.blocks(x)
        x = self.ln_f(x) # (BATCH_SIZE, CONTEXT_SIZE, EMBEDDING_SIZE)
        logits = self.lm_head(x) # (BATCH_SIZE, CONTEXT_SIZE, VOCAB_SIZE)
        
        if y is None:
            loss = None
        else:
            # all tokens before '=' should not have impact on loss since model can only guess here
            logits = self._mask_out_input_tokens(logits, y)

            logits = logits.view(n_batch*n_context, VOCAB_SIZE)
            y = y.view(n_batch*CONTEXT_SIZE)

            loss = F.cross_entropy(logits, y)

        return logits, loss
    
    # See Code Cell above for more detailed explanation
    def _mask_out_input_tokens(self, logits, y):
        correct_logits = torch.zeros_like(logits).scatter_(2, y.unsqueeze(2), 1)
        correct_logits[correct_logits == 0] = float('-inf')
        equal_idx = (y == stoi["="]).nonzero()[:, 1]
        mask = torch.arange(CONTEXT_SIZE).unsqueeze(0) <= equal_idx.unsqueeze(1)
        logits = torch.where(mask.unsqueeze(2), correct_logits, logits)
        return logits

    def calculate(self, equation):
        assert isinstance(equation, str), "The variable 'equation' must be a string"
        assert re.match(r'^\d+\+\d+=', equation), "Equation must be of shape 'a+b='"
        a, b = equation[:-1].split("+") 
        assert 0 <= int(a) <= MAX_NUMBER and 0 <= int(b) <= MAX_NUMBER, f"The variables must be in [0, {MAX_NUMBER}]"
        
        output = torch.tensor(encode(equation), device=device)
        while output[-1] != stoi["."] and len(output) <= CONTEXT_SIZE:
            try:
                logits, _ = self(output.view(1, -1))
            except IndexError as err:
                print(output.tolist())
                
            probs = F.softmax(logits, dim=-1)
            next_digit = torch.multinomial(probs[0, -1], num_samples=1)
            output = torch.cat((output, next_digit))

        if output[-1] != stoi["."]:
            return equation, -1

        equation_answer = decode(output[:-1].tolist())
        answer = equation_answer.replace(equation, "")
        answer = int(answer) if answer.isdigit() else -1
        
        return equation, answer

In [52]:
# calculate mean loss for {EVAL_LOSS_BATCHES}x batches
@torch.no_grad()
def estimate_loss(model):
    model.eval()
    losses = torch.zeros(EVAL_LOSS_BATCHES, device=device)
    for i in range(EVAL_LOSS_BATCHES):
        X, Y = get_batch()
        _, loss = model(X, Y)
        losses[i] = loss.item()
    out = losses.mean()
    model.train()
    return out

In [53]:
model = GPT()

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

for step in tqdm(range(MAX_STEPS)):
    # calculate loss every once in a while
    if step % EVAL_INTERVAL == 0:
        loss = estimate_loss(model)
        print(f"Step {step}/{MAX_STEPS}) Loss: {loss:.4f}")

    xb, yb = get_batch()
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

  0%|          | 8/5000 [00:01<10:34,  7.87it/s]  

Step 0/5000) Loss: 1.2810


 10%|█         | 510/5000 [00:11<03:50, 19.45it/s]

Step 500/5000) Loss: 0.5256


 20%|██        | 1006/5000 [00:23<05:28, 12.17it/s]

Step 1000/5000) Loss: 0.4830


 30%|███       | 1506/5000 [00:36<04:34, 12.74it/s]

Step 1500/5000) Loss: 0.4741


 40%|████      | 2006/5000 [00:48<03:51, 12.92it/s]

Step 2000/5000) Loss: 0.4700


 50%|█████     | 2508/5000 [01:00<03:19, 12.52it/s]

Step 2500/5000) Loss: 0.4678


 60%|██████    | 3009/5000 [01:13<02:36, 12.76it/s]

Step 3000/5000) Loss: 0.4682


 70%|███████   | 3505/5000 [01:25<02:30,  9.92it/s]

Step 3500/5000) Loss: 0.4675


 80%|████████  | 4007/5000 [01:38<01:13, 13.43it/s]

Step 4000/5000) Loss: 0.4662


 90%|█████████ | 4506/5000 [01:50<00:40, 12.31it/s]

Step 4500/5000) Loss: 0.4669


100%|██████████| 5000/5000 [02:02<00:00, 40.78it/s]


In [54]:
# Inference (solve simple Addition Equations)
model.eval()
output = model.calculate("5+0=")
output


('5+0=', 8)

In [60]:
# Test Accuaracy
correct = 0
max = 1000
for _ in tqdm(range(max)):
    a, b = torch.randint(0, MAX_NUMBER, (2, )).tolist()
    eq, c = model.calculate(f"{a}+{b}=")
        
    if c == a+b:
        correct += 1

f"Accuracy: {round(correct/max*100, 2)}%"

100%|██████████| 1000/1000 [00:08<00:00, 117.62it/s]


'Accuracy: 7.1%'

## Why is Loss so good and Accuracy so low??
The Loss only seems really good, since all the input tokens ("a+b=") are always copied from the targets. This means the model seems to have guessed most values right, although they are only copied and overwritten. In Conclusion: This model still sucks at arithmetic!