In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

EX2: Train the GPT on your own dataset of choice! What other data could be fun to blabber on about? (A fun advanced suggestion if you like: train a GPT to do addition of two numbers, i.e. a+b=c. You may find it helpful to predict the digits of c in reverse order, as the typical addition algorithm (that you're hoping it learns) would proceed right to left too. You may want to modify the data loader to simply serve random problems and skip the generation of train.bin, val.bin. You may want to mask out the loss at the input positions of a+b that just specify the problem using y=-1 in the targets (see CrossEntropyLoss ignore_index). Does your Transformer learn to add? Once you have this, swole doge project: build a calculator clone in GPT, for all of +-*/. Not an easy problem. You may need Chain of Thought traces.)

Note: I only focused on teaching a Transformer how to a) add (w.o COT) and another Transformer how to b) multiply (w/ and w.o COT traces)

This Paper, Teaching Arithmetic to Small Transformers(https://www.google.com/url?sa=t&source=web&rct=j&opi=89978449&url=https://arxiv.org/abs/2307.03381&ved=2ahUKEwjtsdeK7MKHAxX_4skDHXZZPI8QFnoECAEQAQ&usg=AOvVaw2vTw_O8bIV2L-6v0GvzuZT) was
very important to solving the below excercises


Part a: Addition

One thing which can be noticed below is that I have seperated out the dataset into 
 - addition of single digit numbers
 - addition of double digit and single digit numbers
 - addition of double digit numbers
 - addition of three digit and double digit or single digit numbers
 - addition of three digit numbers


 I did so, such that when randomly sampling training instances (I.e a total of 10000, I can ensure that my training set has a good coverage of the potentially different set of problems)


In [3]:
""" Create ADD Math Dataset """
math_single_dataset = []
nums = torch.arange(0, 10)
for num1 in nums:
    for num2 in nums:
        out = num1 + num2
        math_single_dataset.append(f"{str(num1.item()).zfill(3)}+{str(num2.item()).zfill(3)}={str(out.item()).zfill(4)[::-1]}.")

math_double_dataset_mixed = []
nums = torch.arange(0, 99)
for num1 in nums:
    for num2 in nums[:10]:
        out = num1 + num2
        math_double_dataset_mixed.append(f"{str(num1.item()).zfill(3)}+{str(num2.item()).zfill(3)}={str(out.item()).zfill(4)[::-1]}.")

math_double_dataset_exclusive = []
nums = torch.arange(10, 99)
for num1 in nums:
    for num2 in nums:
        out = num1 + num2
        math_double_dataset_exclusive.append(f"{str(num1.item()).zfill(3)}+{str(num2.item()).zfill(3)}={str(out.item()).zfill(4)[::-1]}.")

math_triple_dataset_mixed = []
nums = torch.arange(0, 999)
for num1 in nums:
    for num2 in nums[:99]:
        out = num1 + num2
        math_triple_dataset_mixed.append(f"{str(num1.item()).zfill(3)}+{str(num2.item()).zfill(3)}={str(out.item()).zfill(4)[::-1]}.")

math_triple_dataset_exclusive = []
nums = torch.arange(100, 999)
for num1 in nums:
    for num2 in nums:
        out = num1 + num2
        math_triple_dataset_exclusive.append(f"{str(num1.item()).zfill(3)}+{str(num2.item()).zfill(3)}={str(out.item()).zfill(4)[::-1]}.")

In [4]:
print(len(math_single_dataset),len(math_double_dataset_mixed),len(math_double_dataset_exclusive),len(math_triple_dataset_mixed),len(math_triple_dataset_exclusive))

100 990 7921 98901 808201


In [5]:
import random
random.seed(42)
random.shuffle(math_single_dataset)
random.shuffle(math_double_dataset_mixed)
random.shuffle(math_triple_dataset_mixed)
random.shuffle(math_double_dataset_exclusive)
random.shuffle(math_triple_dataset_exclusive)

math_dataset = math_single_dataset + math_double_dataset_mixed[:300] + math_double_dataset_mixed[:600] + math_triple_dataset_mixed[:2000] +  math_triple_dataset_exclusive[:7000]
random.shuffle(math_dataset)

In [6]:
math_vocab = [".", "+", "=", "P"] + [str(num) for num in torch.arange(0, 10).tolist()]
math_stoi = {char:idx for idx, char in enumerate(math_vocab)}
math_itos = {idx:char for char, idx in math_stoi.items()}
math_encode = lambda math_eq: [math_stoi[char] for char in math_eq]
math_decode = lambda math_idx: "".join([math_itos[idx] for idx in math_idx])

In [8]:
math_dataset = [math_encode(eq) for eq in math_dataset]

In [11]:
print(math_dataset[0])
print(math_decode(math_dataset[0]))

#Remember output digits are reversed, more inline with how humans tackle addition

[12, 11, 11, 1, 10, 9, 8, 2, 5, 7, 9, 5, 0]
877+654=1351.


In [12]:
math_stoi

{'.': 0,
 '+': 1,
 '=': 2,
 'P': 3,
 '0': 4,
 '1': 5,
 '2': 6,
 '3': 7,
 '4': 8,
 '5': 9,
 '6': 10,
 '7': 11,
 '8': 12,
 '9': 13}

In [13]:
""" Global Variables """
block_size = 32
batch_size = 16
embed_dim = 64
num_heads = 4
num_blocks = 4
lr = 1e-3
epochs = 5000

In [16]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(device)

mps


In [14]:
def get_batch(split):
    if split == "train":
        idxs = torch.randint(0, int(len(math_dataset)*0.9) - block_size, (batch_size,))
    else:
        idxs = torch.randint(int(len(math_dataset)*0.9), len(math_dataset) - block_size, (batch_size,))
    X = []
    Y = []
    for idx in idxs:
        x = math_dataset[idx] + [math_stoi["P"]]*block_size #Padding to allow for batching addition examples
        X.append(x[:block_size])
        ignore1_idx = torch.argwhere(torch.tensor(x[:block_size]) == 2)[0][0].item() 
        ignore2_idx = torch.argwhere(torch.tensor(x[:block_size]) == 3)[0][0].item()
        Y.append([-1]*ignore1_idx + x[ignore1_idx+1:ignore2_idx] + [-1]*(len(x[:block_size]) - ignore2_idx + 1)) #Adding target = -1 for any thing up to or before = and anything after the stop token . (so for the padding "P")
    X = torch.tensor(X).to(torch.long)
    Y = torch.tensor(Y)
    return X, Y

In [98]:
""" Create Model """

class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, num_heads, embed_dim, block_size):
        super().__init__()
        self.num_heads = num_heads
        self.proj_dim = int(embed_dim/num_heads)
        self.w_q = nn.Linear(embed_dim, num_heads*self.proj_dim, bias = False)
        self.w_k = nn.Linear(embed_dim, num_heads*self.proj_dim, bias = False)
        self.w_v = nn.Linear(embed_dim, num_heads*self.proj_dim, bias = False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size).to(device)))
        self.dp = nn.Dropout(0.0)
        self.out_linear = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, input):
        B, T, C = input.shape #C == embed_dim
        query = self.w_q(input).view(B, T, self.num_heads, self.proj_dim) #B, T, H, proj_dim
        key = self.w_k(input).view(B, T, self.num_heads, self.proj_dim) #B, T, H proj_dim
        wei = (query.permute(0,2,1,3) @ key.permute(0,2,3,1))*(C**-0.5) #B, H, T,T
        wei = wei.masked_fill(self.tril[:T,:T] == 0, float("-inf")) #Note: subset the mask self.tril[:T,:T] in case input sequence is less then block_size; the mask needs to be broadcastable with wei
        wei = F.softmax(wei, dim = -1)

        value = self.w_v(input).view(B, T, self.num_heads, self.proj_dim) #B, T, H, proj_dim
        
        out = wei @ value.permute(0,2,1,3) #B, H, T, proj_dim
        out = out.permute(0,2,1,3).contiguous().view(B,T,C) #B, T, C
        out = self.out_linear(out)  #B, T, C
        return self.dp(out)
    
class FeedForward(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(embed_dim, embed_dim*4), nn.ReLU(), nn.Linear(embed_dim*4, embed_dim))
        self.dp = nn.Dropout(0.0)

    def forward(self, input):
        out = self.ffn(input)
        return self.dp(out)
    

class Block(nn.Module):
    def __init__(self, num_heads, embed_dim, block_size):
        super().__init__()
        self.attention = MaskedMultiHeadAttention(num_heads, embed_dim, block_size)
        self.ffn = FeedForward(embed_dim)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, input):
        x = self.attention(self.ln1(input)) + input
        out = self.ffn(self.ln2(x)) + x
        return out
    

class Transformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, block_size, num_heads, num_blocks):
        super().__init__()
        self.content_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(block_size, embed_dim)
        self.blocks = nn.ModuleList([Block(num_heads, embed_dim, block_size) for _ in range(num_blocks)])
        self.output = nn.Linear(embed_dim, vocab_size)
        self.dp = nn.Dropout(0.0)

    def forward(self, input):
        B, T = input.shape
        con_embed = self.content_embedding(input) #B,T,embed_dim
        pos_embed = self.position_embedding(torch.arange(T, device = device)) #1,T,embed_dim
        x = con_embed + pos_embed #B,T,embed_dim
        x = self.dp(x)

        for block in self.blocks:
            x = block(x) #B,T,embed_dim
        out = self.output(x)
        return out

In [18]:
model = Transformer(len(math_vocab), embed_dim, block_size, num_heads, num_blocks)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = lr)

In [19]:
@torch.no_grad()
def estimate_loss(eval_iters):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            X,Y = X.to(device), Y.to(device)
            logits  = model(X)
            loss = F.cross_entropy(logits.permute(0, 2, 1), Y, ignore_index = -1)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [20]:
for i in range(epochs):
    x,y = get_batch('train')
    x,y = x.to(device), y.to(device)
    logits = model(x)
    loss = F.cross_entropy(logits.permute(0, 2, 1), y, ignore_index = -1)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i % 100) == 0:
        estimated_losses = estimate_loss(200)
        print(f"Iteration {i}, Train Loss {estimated_losses['train']}, Val Loss {estimated_losses['val']}")

Iteration 0, Train Loss 2.662686824798584, Val Loss 2.6516246795654297
Iteration 100, Train Loss 1.387786865234375, Val Loss 1.3856680393218994
Iteration 200, Train Loss 1.361119270324707, Val Loss 1.3627911806106567
Iteration 300, Train Loss 1.3566230535507202, Val Loss 1.3616020679473877
Iteration 400, Train Loss 1.3193213939666748, Val Loss 1.3247090578079224
Iteration 500, Train Loss 1.0440106391906738, Val Loss 1.010298252105713
Iteration 600, Train Loss 0.5419133901596069, Val Loss 0.5412693619728088
Iteration 700, Train Loss 0.2846592664718628, Val Loss 0.28586316108703613
Iteration 800, Train Loss 0.1886684000492096, Val Loss 0.1880323737859726
Iteration 900, Train Loss 0.09563055634498596, Val Loss 0.0948578342795372
Iteration 1000, Train Loss 0.012534172274172306, Val Loss 0.012815193273127079
Iteration 1100, Train Loss 0.06405463069677353, Val Loss 0.05420443415641785
Iteration 1200, Train Loss 0.07592733204364777, Val Loss 0.07644709944725037
Iteration 1300, Train Loss 0.02

In [30]:
"Function to use Transformer as an addition calculator"

@torch.no_grad()
def addition_calculator(num1,num2):
    pred_idx = 2
    encoded_input = math_encode(f"{str(num1).zfill(3)}+{str(num2).zfill(3)}=")
    decoded_output = []
    while pred_idx != 0: #terminate at stop token
        transformer_input = torch.tensor(encoded_input).view(1,-1)
        model.eval()
        output = model(transformer_input.to(device))
        model.train()
        pred_idx = torch.argmax(output[0,-1,:]).item()
        encoded_input += [pred_idx]
        decoded_output += [pred_idx]
    return int(math_decode(decoded_output[::-1][1:]))

In [34]:
addition_calculator(245, 63) == 245 + 63

True

In [26]:
from tqdm import tqdm

In [27]:
"Testing Performance on 1000 trained instances"
seen_examples = math_double_dataset_mixed[:600] + math_triple_dataset_mixed[:2000] +  math_triple_dataset_exclusive[:7000]
random.seed(42)
random.shuffle(seen_examples)

amount_correct = 0
for ex in tqdm(seen_examples[:1000]):
    num1, num2 = ex.split("=")[0].split("+")
    transformer_output = int(addition_calculator(num1, num2))
    actual_output = int(num1) + int(num2)
    if actual_output == transformer_output:
        amount_correct += 1

print(amount_correct/1000)


100%|██████████| 1000/1000 [01:43<00:00,  9.68it/s]

1.0





In [28]:
"Testing Performance on 1000 Test instances"
unseen_examples = math_double_dataset_mixed[600:] + math_triple_dataset_mixed[2000:6000] +  math_triple_dataset_exclusive[7000:11000]
random.seed(42)
random.shuffle(unseen_examples)


amount_correct = 0
for ex in tqdm(unseen_examples[:1000]):
    num1, num2 = ex.split("=")[0].split("+")
    transformer_output = int(addition_calculator(num1, num2))
    actual_output = int(num1) + int(num2)
    if actual_output == transformer_output:
        amount_correct += 1

print(amount_correct/1000)

100%|██████████| 1000/1000 [01:27<00:00, 11.40it/s]

1.0





An interesting bug is that in order to get correct output need number with more digits first in the expression; due to how it was trained


Aside from that we can see that the Transformer has 100% accuracy in Addition

In [35]:
245 + 63

308

In [36]:
addition_calculator(245,63)

308

In [37]:
addition_calculator(63,245) #BUG which was higlighted above, due to hwo it was trained

208

Part b: Multiply

First lets redo exactly as above but with multiplication

In [99]:
""" Create MUL Math Dataset """
math_single_dataset = []
nums = torch.arange(0, 10)
for num1 in nums:
    for num2 in nums:
        out = num1 * num2
        math_single_dataset.append(f"{str(num1.item()).zfill(3)}*{str(num2.item()).zfill(3)}={str(out.item()).zfill(6)[::-1]}.")

math_double_dataset_mixed = []
nums = torch.arange(0, 99)
for num1 in nums:
    for num2 in nums[:10]:
        out = num1 * num2
        math_double_dataset_mixed.append(f"{str(num1.item()).zfill(3)}*{str(num2.item()).zfill(3)}={str(out.item()).zfill(6)[::-1]}.")

math_double_dataset_exclusive = []
nums = torch.arange(10, 99)
for num1 in nums:
    for num2 in nums:
        out = num1 * num2
        math_double_dataset_exclusive.append(f"{str(num1.item()).zfill(3)}*{str(num2.item()).zfill(3)}={str(out.item()).zfill(6)[::-1]}.")

math_triple_dataset_mixed = []
nums = torch.arange(0, 999)
for num1 in nums:
    for num2 in nums[:99]:
        out = num1 * num2
        math_triple_dataset_mixed.append(f"{str(num1.item()).zfill(3)}*{str(num2.item()).zfill(3)}={str(out.item()).zfill(6)[::-1]}.")

math_triple_dataset_exclusive = []
nums = torch.arange(100, 999)
for num1 in nums:
    for num2 in nums:
        out = num1 * num2
        math_triple_dataset_exclusive.append(f"{str(num1.item()).zfill(3)}*{str(num2.item()).zfill(3)}={str(out.item()).zfill(6)[::-1]}.")


import random
random.seed(42)
random.shuffle(math_single_dataset)
random.shuffle(math_double_dataset_mixed)
random.shuffle(math_triple_dataset_mixed)
random.shuffle(math_double_dataset_exclusive)
random.shuffle(math_triple_dataset_exclusive)

math_dataset = math_single_dataset + math_double_dataset_mixed[:300] + math_double_dataset_mixed[:600] + math_triple_dataset_mixed[:2000] +  math_triple_dataset_exclusive[:7000]
random.shuffle(math_dataset)

math_vocab = [".", "*", "=", "P"] + [str(num) for num in torch.arange(0, 10).tolist()]
math_stoi = {char:idx for idx, char in enumerate(math_vocab)}
math_itos = {idx:char for char, idx in math_stoi.items()}
math_encode = lambda math_eq: [math_stoi[char] for char in math_eq]
math_decode = lambda math_idx: "".join([math_itos[idx] for idx in math_idx])

math_dataset = [math_encode(eq) for eq in math_dataset]

def get_batch(split):
    if split == "train":
        idxs = torch.randint(0, int(len(math_dataset)*0.9) - block_size, (batch_size,))
    else:
        idxs = torch.randint(int(len(math_dataset)*0.9), len(math_dataset) - block_size, (batch_size,))
    X = []
    Y = []
    for idx in idxs:
        x = math_dataset[idx] + [math_stoi["P"]]*block_size 
        X.append(x[:block_size])
        ignore1_idx = torch.argwhere(torch.tensor(x[:block_size]) == 2)[0][0].item()
        ignore2_idx = torch.argwhere(torch.tensor(x[:block_size]) == 3)[0][0].item()
        Y.append([-1]*ignore1_idx + x[ignore1_idx+1:ignore2_idx] + [-1]*(len(x[:block_size]) - ignore2_idx + 1))
    X = torch.tensor(X).to(torch.long)
    Y = torch.tensor(Y)
    return X, Y

In [100]:
print(math_dataset[0])
print(math_decode(math_dataset[0]))
#digits reversed for output just like how it was done for addition

[12, 11, 11, 1, 10, 9, 8, 2, 12, 9, 9, 7, 11, 9, 0]
877*654=855375.


In [101]:
block_size = 253 #Bug created when training prior, but since model was trained for hours probs fine to keep
model = Transformer(len(math_vocab), embed_dim, block_size, num_heads, num_blocks)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = lr)

In [102]:
@torch.no_grad()
def estimate_loss(eval_iters):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            X,Y = X.to(device), Y.to(device)
            logits  = model(X)
            loss = F.cross_entropy(logits.permute(0, 2, 1), Y, ignore_index = -1)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

So I trained the model in another file for good number of hours, so I'll import those weights and show it can train here for a few epochs to higlight the code work

In [103]:
model.load_state_dict(torch.load("/Users/bhavverma/Documents/karpathy_notes_and_solutions/video_7_dependencies/rev_model_weights.pt"))

<All keys matched successfully>

In [104]:
for i in range(500):
    x,y = get_batch('train')
    x,y = x.to(device), y.to(device)
    logits = model(x)
    loss = F.cross_entropy(logits.permute(0, 2, 1), y, ignore_index = -1)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i % 100) == 0:
        estimated_losses = estimate_loss(200)
        print(f"Iteration {i}, Train Loss {estimated_losses['train']}, Val Loss {estimated_losses['val']}")

Iteration 0, Train Loss 0.6741361021995544, Val Loss 1.2719918489456177
Iteration 100, Train Loss 0.6755880117416382, Val Loss 1.2324095964431763
Iteration 200, Train Loss 0.6660304069519043, Val Loss 1.2309340238571167
Iteration 300, Train Loss 0.6705098152160645, Val Loss 1.2488014698028564
Iteration 400, Train Loss 0.681050181388855, Val Loss 1.281855821609497


In [105]:
@torch.no_grad()
def mul_calculator(num1,num2):
    pred_idx = 2
    encoded_input = math_encode(f"{str(num1).zfill(3)}*{str(num2).zfill(3)}=")
    decoded_output = []
    while pred_idx != 0:
        transformer_input = torch.tensor(encoded_input).view(1,-1)
        model.eval()
        output = model(transformer_input.to(device))
        model.train()
        pred_idx = torch.argmax(output[0,-1,:]).item()
        encoded_input += [pred_idx]
        decoded_output += [pred_idx]
    return int(math_decode(decoded_output[::-1][1:]))

In [106]:
mul_calculator(9,2)

18

In [107]:
"Testing Performance on 1000 trained instances"
seen_examples = math_double_dataset_mixed[:600] + math_triple_dataset_mixed[:2000] +  math_triple_dataset_exclusive[:7000]
random.seed(42)
random.shuffle(seen_examples)

print(seen_examples[:10])

amount_correct = 0
for ex in tqdm(seen_examples[:1000]):
    num1, num2 = ex.split("=")[0].split("*")
    transformer_output = int(mul_calculator(num1, num2))
    actual_output = int(num1) * int(num2)
    if actual_output == transformer_output:
        amount_correct += 1

print(amount_correct/1000)

['325*925=526003.', '938*943=435488.', '910*201=019281.', '906*477=261234.', '995*033=538230.', '356*342=257121.', '462*384=804771.', '722*699=876405.', '935*419=567193.', '074*004=692000.']


100%|██████████| 1000/1000 [01:15<00:00, 13.19it/s]

0.131





In [108]:
"Testing Performance on 1000 Train instances"
unseen_examples = math_double_dataset_mixed[600:] + math_triple_dataset_mixed[2000:6000] +  math_triple_dataset_exclusive[7000:11000]
random.seed(42)
random.shuffle(unseen_examples)

print(unseen_examples[:10])

amount_correct = 0
for ex in tqdm(unseen_examples[:1000]):
    num1, num2 = ex.split("=")[0].split("*")
    transformer_output = int(mul_calculator(num1, num2))
    actual_output = int(num1) * int(num2)
    if actual_output == transformer_output:
        amount_correct += 1

print(amount_correct/1000)

['097*057=925500.', '094*005=074000.', '053*077=180400.', '925*874=054808.', '472*336=295851.', '826*016=612310.', '590*593=078943.', '655*928=048706.', '826*489=419304.', '879*105=592290.']


  0%|          | 2/1000 [00:00<03:05,  5.37it/s]

100%|██████████| 1000/1000 [01:13<00:00, 13.63it/s]

0.04





We can clearly see that using the same method for addition does not work multiplication, it gives a horrible result. Instead lets aim to use chain of thought traces, where each training instance, is more like a step by step solution to a multiplication problem

In [55]:
""" Create Chain of Thought Template """

def mul_cot_format(num1, num2):
    str_num1, str_num2 = str(num1), str(num2)

    a_1 = num1*int(str_num2.zfill(3)[-1])
    b_1 = a_1
    c_1 = b_1

    a_2 = num1*int(str_num2.zfill(3)[-2])
    b_2 = a_2*10
    c_2 = c_1 + b_2

    a_3 = num1*int(str_num2.zfill(3)[-3])
    b_3 = a_3*100
    c_3 = c_2 + b_3

    a = [a_1, a_2, a_3]
    b = [b_1, b_2, b_3]
    c = [c_1, c_2, c_3]

    k = [1,10,100]

    sol = str(num1*num2)

    input_str = f"In:{str_num1}*{str_num2}\n"
    total_target = [f"Target:[{','.join(str_num1)}]has{len(str_num1)}digits.[{','.join(str_num2)}]has{len(str_num2)}digits."]
    for i in range(len(str_num2)):
        past_c = 0 if i == 0 else c[i-1]
        total_target.append(f"[{','.join(str_num1)}]*{str_num2[::-1][i]},A=[{','.join(str(a[i]).zfill(len(str_num1)))}],k={k[i]},B=[{','.join(str(b[i]).zfill(len(str_num1)+i))}],C={str(past_c)}+{str(b[i])}={str(c[i])}" + '\n')
    total_target.append(f",END{''.join(sol)};")
    target = "".join(total_target)

    return input_str + target

In [56]:
mul_cot_format(12, 360)

'In:12*360\nTarget:[1,2]has2digits.[3,6,0]has3digits.[1,2]*0,A=[0,0],k=1,B=[0,0],C=0+0=0\n[1,2]*6,A=[7,2],k=10,B=[7,2,0],C=0+720=720\n[1,2]*3,A=[3,6],k=100,B=[3,6,0,0],C=720+3600=4320\n,END4320;'

In [57]:
""" Need to recreate MUL Math Dataset so that formatting works with COT template"""
math_single_dataset = []
nums = torch.arange(0, 10)
for num1 in nums:
    for num2 in nums:
        math_single_dataset.append((num1,num2))

math_double_dataset_mixed = []
nums = torch.arange(0, 99)
for num1 in nums:
    for num2 in nums[:10]:
        math_double_dataset_mixed.append((num1,num2))

math_double_dataset_exclusive = []
nums = torch.arange(10, 99)
for num1 in nums:
    for num2 in nums:
        math_double_dataset_exclusive.append((num1,num2))
        

math_triple_dataset_mixed = []
nums = torch.arange(0, 999)
for num1 in nums:
    for num2 in nums[:99]:
        out = num1 * num2
        math_triple_dataset_mixed.append((num1,num2))

math_triple_dataset_exclusive = []
nums = torch.arange(100, 999)
for num1 in nums:
    for num2 in nums:
        out = num1 * num2
        math_triple_dataset_exclusive.append((num1,num2))


import random
random.seed(42)
random.shuffle(math_single_dataset)
random.shuffle(math_double_dataset_mixed)
random.shuffle(math_triple_dataset_mixed)
random.shuffle(math_double_dataset_exclusive)
random.shuffle(math_triple_dataset_exclusive)

math_dataset = math_single_dataset + math_double_dataset_mixed[:300] + math_double_dataset_mixed[:600] + math_triple_dataset_mixed[:2000] +  math_triple_dataset_exclusive[:7000]
random.shuffle(math_dataset)

In [58]:
math_vocab = list(sorted(set([".", "P"] + [str(num) for num in torch.arange(0, 10).tolist()] + list(set(mul_cot_format(12,36))))))
math_stoi = {char:idx for idx, char in enumerate(math_vocab)}
math_itos = {idx:char for char, idx in math_stoi.items()}
math_encode = lambda cot_input: [math_stoi[char] for char in cot_input]
math_decode = lambda m_output: "".join([math_itos[idx] for idx in m_output])

In [59]:
print(math_stoi)

{'\n': 0, '*': 1, '+': 2, ',': 3, '.': 4, '0': 5, '1': 6, '2': 7, '3': 8, '4': 9, '5': 10, '6': 11, '7': 12, '8': 13, '9': 14, ':': 15, ';': 16, '=': 17, 'A': 18, 'B': 19, 'C': 20, 'D': 21, 'E': 22, 'I': 23, 'N': 24, 'P': 25, 'T': 26, '[': 27, ']': 28, 'a': 29, 'd': 30, 'e': 31, 'g': 32, 'h': 33, 'i': 34, 'k': 35, 'n': 36, 'r': 37, 's': 38, 't': 39}


Important indices: 
15 -> ":", this is delimiter we will use during evaluation, in which only the chars before will be given
16 -> ";", this is stop token
25 -> "P", this is padding token

In [61]:
math_decode(math_encode((mul_cot_format(12,36))))

'In:12*36\nTarget:[1,2]has2digits.[3,6]has2digits.[1,2]*6,A=[7,2],k=1,B=[7,2],C=0+72=72\n[1,2]*3,A=[3,6],k=10,B=[3,6,0],C=72+360=432\n,END432;'

In [62]:
math_dataset = [math_encode(mul_cot_format(num_pairs[0].item(),num_pairs[1].item())) for num_pairs in math_dataset]

In [63]:
#Setting block size to longest COT template in training data + 10 (give breathing room)
longest_length = 0 
for i in math_dataset:
    if len(i) >= longest_length:
        longest_length = len(i)
print(longest_length)
block_size = longest_length + 10

243


In [64]:
def get_batch(split):
    if split == "train":
        idxs = torch.randint(0, int(len(math_dataset)*0.9), (batch_size,))
    else:
        idxs = torch.randint(int(len(math_dataset)*0.9), len(math_dataset), (batch_size,))
    X = []
    Y = []
    for idx in idxs:
        x = math_dataset[idx] + [math_stoi["P"]]*block_size
        ignore_idx1 = torch.argwhere(torch.tensor(x[:block_size]) == 15)[1][0].item()
        ignore_idx2 = torch.argwhere(torch.tensor(x[:block_size]) == 25)[0][0].item()
        X.append(x[:block_size])
        Y.append([-1]*(len(x[:ignore_idx1])) + x[ignore_idx1 + 1 : ignore_idx2] + [-1]*(len(x[:block_size]) - ignore_idx2 + 1)) #-1 targets for anything up to and including target: and -1 for all pad tokens
    X = torch.tensor(X).to(torch.long)
    Y = torch.tensor(Y)
    return X, Y

In [87]:
model = Transformer(len(math_vocab), embed_dim, block_size, num_heads, num_blocks)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = lr)

In [88]:
@torch.no_grad()
def estimate_loss(eval_iters):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            X,Y = X.to(device), Y.to(device)
            logits  = model(X)
            loss = F.cross_entropy(logits.permute(0, 2, 1), Y, ignore_index = -1)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

Once again, I trained the model in another file for good number of hours, so I'll import those weights and show it can train here for a few epochs to higlight the code work

In [89]:
model.load_state_dict(torch.load("/Users/bhavverma/Documents/karpathy_notes_and_solutions/video_7_dependencies/cot_model_weights.pt"))

<All keys matched successfully>

In [90]:
for i in range(500):
    x,y = get_batch('train')
    x,y = x.to(device), y.to(device)
    logits = model(x)
    loss = F.cross_entropy(logits.permute(0, 2, 1), y, ignore_index = -1)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i % 100) == 0:
        estimated_losses = estimate_loss(200)
        print(f"Iteration {i}, Train Loss {estimated_losses['train']}, Val Loss {estimated_losses['val']}")

Iteration 0, Train Loss 0.005612213164567947, Val Loss 0.010330012068152428
Iteration 100, Train Loss 0.004963007289916277, Val Loss 0.010579774156212807
Iteration 200, Train Loss 0.005102494265884161, Val Loss 0.009807861410081387
Iteration 300, Train Loss 0.004419269040226936, Val Loss 0.008675486780703068
Iteration 400, Train Loss 0.0068982369266450405, Val Loss 0.010805675759911537


In [91]:
@torch.no_grad()
def mul_calculator(num1,num2):
    encoded_input = math_encode(mul_cot_format(num1, num2))
    ignore_idx1 = torch.argwhere(torch.tensor(encoded_input) == 15)[1][0].item()
    encoded_input = encoded_input[:ignore_idx1+1]
    print(math_decode(encoded_input))
    pred_idx = None
    decoded_output = []
    while pred_idx != 16:
        transformer_input = torch.tensor(encoded_input).view(1,-1)
        model.eval()
        output = model(transformer_input.to(device))
        model.train()
        pred_idx = torch.argmax(output[0,-1,:]).item()
        encoded_input += [pred_idx]
        decoded_output += [pred_idx]
    return math_decode(decoded_output)

In [94]:
#Showing with print statement the initial prompt, will redefine the fuc without print statement
mul_calculator(460, 205)

In:460*205
Target:


'[4,6,0]has3digits.[2,0,5]has3digits.[4,6,0]*5,A=[2,3,0,0],k=1,B=[2,3,0,0],C=0+2300=2300\n[4,6,0]*0,A=[0,0,0],k=10,B=[0,0,0,0],C=2300+0=2300\n[4,6,0]*2,A=[9,2,0],k=100,B=[9,2,0,0,0],C=2300+92000=94300\n,END94300;'

In [93]:
460*205

94300

In [95]:
@torch.no_grad()
def mul_calculator(num1,num2):
    encoded_input = math_encode(mul_cot_format(num1, num2))
    ignore_idx1 = torch.argwhere(torch.tensor(encoded_input) == 15)[1][0].item()
    encoded_input = encoded_input[:ignore_idx1+1]
    pred_idx = None
    decoded_output = []
    while pred_idx != 16:
        transformer_input = torch.tensor(encoded_input).view(1,-1)
        model.eval()
        output = model(transformer_input.to(device))
        model.train()
        pred_idx = torch.argmax(output[0,-1,:]).item()
        encoded_input += [pred_idx]
        decoded_output += [pred_idx]
    return math_decode(decoded_output)

In [96]:
"Testing Performance on 1000 trained instances"
seen_examples = math_double_dataset_mixed[:600] + math_triple_dataset_mixed[:2000] +  math_triple_dataset_exclusive[:7000]
random.seed(42)
random.shuffle(seen_examples)

print(seen_examples[:10])

amount_correct = 0
for ex in tqdm(seen_examples[:1000]):
    num1, num2 = ex[0].item(), ex[1].item()
    transformer_output = mul_calculator(num1, num2)
    if not("END" in transformer_output):
        continue
    else:
        transformer_output = transformer_output.split("END")[1][:-1]
    actual_output = num1 * num2
    if str(actual_output) == str(transformer_output):
        amount_correct += 1

print(amount_correct/1000)

[(tensor(325), tensor(925)), (tensor(938), tensor(943)), (tensor(910), tensor(201)), (tensor(906), tensor(477)), (tensor(995), tensor(33)), (tensor(356), tensor(342)), (tensor(462), tensor(384)), (tensor(722), tensor(699)), (tensor(935), tensor(419)), (tensor(74), tensor(4))]


100%|██████████| 1000/1000 [56:53<00:00,  3.41s/it] 

0.764





In [97]:
"Testing Performance on 1000 Train instances"
unseen_examples = math_double_dataset_mixed[600:] + math_triple_dataset_mixed[2000:6000] +  math_triple_dataset_exclusive[7000:11000]
random.seed(42)
random.shuffle(unseen_examples)

print(unseen_examples[:10])

amount_correct = 0
for ex in tqdm(unseen_examples[:1000]):
    num1, num2 = ex[0].item(), ex[1].item()
    transformer_output = mul_calculator(num1, num2)
    if not("END" in transformer_output):
        continue
    else:
        transformer_output = transformer_output.split("END")[1][:-1]
    actual_output = num1 * num2
    if str(actual_output) == str(transformer_output):
        amount_correct += 1

print(amount_correct/1000)

[(tensor(97), tensor(57)), (tensor(94), tensor(5)), (tensor(53), tensor(77)), (tensor(925), tensor(874)), (tensor(472), tensor(336)), (tensor(826), tensor(16)), (tensor(590), tensor(593)), (tensor(655), tensor(928)), (tensor(826), tensor(489)), (tensor(879), tensor(105))]


100%|██████████| 1000/1000 [30:26<00:00,  1.83s/it]

0.635





As we can clearly see COT tracing provided substantial benefits to allowing the transformer to learn multiplication, with it increasing from an aprox 4% accuracy to 64% accuracy