In [1]:
from tinygrad.tensor import Tensor
import numpy as np
from core.model import GPTConfig
from core.model_tiny import GPT
from tinygrad.state import load_state_dict, get_parameters
from tinygrad.helpers import dtypes
from tinygrad.nn.optim import AdamW

In [2]:
from transformers import GPTNeoXTokenizerFast
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/pythia-70m")

config = GPTConfig.from_pretrained("EleutherAI/pythia-70m")

tiny_model = GPT(config)
tiny_state_dict = GPT.state_dict_from_huggingface("EleutherAI/pythia-70m")
load_state_dict(tiny_model, tiny_state_dict, strict=False)

ram used:  0.31 GB, cos                                               : 100%|██████████| 78/78 [00:00<00:00, 1016.01it/s]

loaded weights in 99.03 ms, 0.31 GB loaded at 3.10 GB/s





In [3]:
import random
from operator import add, sub, mul, floordiv

def generate_equation():
    ops = {
        "+": add,
        "-": sub,
        "×": mul,
        "÷": floordiv,
    }

    op, operation = random.choice(list(ops.items()))

    if op == '/':
        a = random.randint(1, 10)
        b = random.randint(1, 10)
        a = a * b # ensure a/b is an integer
    else:
        a = random.randint(1, 25)
        b = random.randint(1, 25)

    equation = f"{a} {op} {b} = {operation(a,b)}"
    return equation

print(generate_equation())


9 × 2 = 18


In [4]:
def get_batch(batch_size):
    tokenizer.pad_token = tokenizer.eos_token
    batch = []
    for _ in range(batch_size):
        batch.append(generate_equation() + tokenizer.eos_token)
        
    inputs = tokenizer(batch, padding=True, return_tensors="np")['input_ids']
    targets = np.roll(inputs.copy(), -1, axis=1)
    targets[:, -1] = -100
    return inputs, targets

In [5]:
def attempt_answer(question):
    ids = Tensor([tokenizer.encode(question)], dtype=dtypes.int32)
    for i in range(2):
        output = tiny_model(ids)
        output_token = Tensor([[output.softmax(axis=-1)[0,-1].numpy().argmax()]], dtype=dtypes.int32)
        
        ids = ids.cat(output_token, dim=1)
    
    return tokenizer.decode(ids.numpy()[0], skip_special_tokens=True)


def test_model(num_tests=100):
    correct = 0
    for _ in range(num_tests):
        question = generate_equation()
        incomplete_question = question.split("=")[0] + "="
        answer = attempt_answer(incomplete_question)

        if answer.strip() == question.strip():
            print(f"\033[92m {answer} \033[0m")  # output in green
            correct += 1
        else:
            print(f"\033[91m {answer} \033[0m")  # output in red
            
    print(f"Accuracy: {correct} out of {num_tests} = {correct/num_tests:.2f}")


In [6]:
Tensor.training = True
BS = 128
optim = AdamW(get_parameters(tiny_model), lr=1e-4)

In [7]:
def loss_fn(logits, targets):
    num_classes = logits.shape[-1]
    targets_onehot = targets.reshape(list(targets.shape)+[1]).repeat([1]*len(targets.shape)+[num_classes]).eq(Tensor.arange(num_classes, dtype=targets.dtype))
    return -1 * logits.log_softmax().mul(targets_onehot).sum() / targets_onehot.sum()


In [8]:
# single pass

rolling_loss = 5

for i in range(1000):
    inputs, targets = get_batch(BS)
    logits = tiny_model(Tensor(inputs))
    loss = loss_fn(logits, Tensor(targets))

    loss.backward()
    optim.step()
    optim.zero_grad()
    
    rolling_loss = rolling_loss * 0.95 + loss.numpy() * 0.05
    print(f"step {i} rolling loss {rolling_loss:.3f} loss {loss.numpy():.3f}", end="\r")
    
    if i % 100 == 0:
        print()
        test_model()
        print()

step 0 rolling loss 5.059 loss 6.189
[91m 16 - 9 = 9 [0m
[91m 3 - 9 = 9 [0m
[91m 6 + 2 = 5 [0m
[91m 18 - 24 = - [0m
[91m 20 × 11 = 5 [0m
[91m 17 × 1 = 0 [0m
[91m 16 + 7 = 5 [0m
[91m 24 - 1 = 0 [0m
[91m 13 + 6 = 5 [0m
[92m 16 ÷ 25 = 0 [0m
[91m 7 - 3 = 5 [0m
[91m 8 - 9 = 9 [0m
[91m 18 + 24 = 5 [0m
[91m 24 - 17 = [0m
[91m 21 + 22 = - [0m
[92m 1 × 25 = 25 [0m
[91m 9 × 25 = 0 [0m
[91m 14 + 24 = 5 [0m
[91m 4 + 24 = 5 [0m
[91m 12 ÷ 7 = 0 [0m
[91m 21 × 20 = 20 [0m
[91m 14 ÷ 4 = 5 [0m
[92m 1 × 16 = 16 [0m
[91m 23 + 23 = - [0m
[91m 12 - 25 = [0m
[91m 5 - 4 = 4 [0m
[91m 24 - 13 = 13 [0m
[91m 7 × 16 = 16 [0m
[92m 6 ÷ 22 = 0 [0m
[91m 4 × 15 = 15 [0m
[91m 11 × 1 = 1 [0m
[91m 21 ÷ 12 = 12 [0m
[91m 16 - 16 = 16 [0m
[91m 19 ÷ 15 = 0 [0m
[91m 17 + 8 = 8 [0m
[91m 4 × 1 = 0 [0m
[92m 5 ÷ 15 = 0 [0m
[91m 17 × 16 = 16 [0m
[91m 10 + 22 = 0 [0m
[92m 15 ÷ 19 = 0 [0m
[91m 5 × 23 = 0 [0m
[91m 8 × 24 = 0 [0m
[92m 4 ÷ 21 = 0 [0m


KeyboardInterrupt: 