In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformer import GPT, Config

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
text = open("./data.txt", "r").read()
chars = sorted(list(set(text)))
print(chars)

['\n', ' ', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=']


In [4]:
config = Config(
    block_size=128,
    vocab_size=len(chars),
    n_head=4,
    n_embd=256,
    dropout=0.1,
    n_layer=4,
)

In [5]:
stoi = {s:i for i, s in enumerate(chars)}
itos = {i:s for s, i in stoi.items()}

In [6]:
def encode(text):
    return [stoi[c] for c in text]

def decode(toks):
    return ''.join([itos[t] for t in toks])

In [7]:
print(decode(encode("123 + 3412 = 334\n")))

123 + 3412 = 334



In [8]:
tokens = encode(text)
dataset = torch.tensor(tokens).long().to(device)
train_dataset = dataset[:int(len(tokens) * 0.95)]
test_dataset = dataset[int(len(tokens) * 0.95):]

In [9]:
new_line_locs = []
for i, t in enumerate(tokens[:int(len(tokens) * 0.95)]):
    if t == 0 and i < len(tokens) - config.block_size:
        new_line_locs.append(i)

new_line_locs = torch.tensor(new_line_locs, dtype=torch.long, device=device)
print(new_line_locs)

tensor([        0,        19,        38,  ..., 107132647, 107132664,
        107132681], device='cuda:0')


In [10]:
def get_batch(split="train", batch_size=32):
    data = train_dataset if split == "train" else test_dataset
    starts = torch.randint(0, len(new_line_locs), (batch_size,))
    X = torch.stack([data[i:i+config.block_size] for i in starts])
    Y = torch.stack([data[i+1:i+config.block_size+1] for i in starts])
    return X, Y


In [11]:
m = GPT(config).to(device)
optimizer = torch.optim.Adam(m.parameters(), lr=3e-4)

params: 3199488


In [20]:
for i in range(100000):
    xb, yb = get_batch(split="train", batch_size=64)
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if i % 100 == 0: print(i, loss.item())

0 1.028536081314087
100 1.03047776222229
200 1.023223876953125
300 1.0292056798934937
400 1.026403784751892
500 1.0227800607681274
600 1.0231775045394897
700 1.0267149209976196
800 1.0288524627685547
900 1.0263512134552002
1000 1.0250824689865112
1100 1.025925874710083
1200 1.0294182300567627
1300 1.0229589939117432
1400 1.0197416543960571
1500 1.021450161933899
1600 1.0280606746673584
1700 1.0262962579727173
1800 1.029707908630371
1900 1.026080846786499
2000 1.0347639322280884
2100 1.02982759475708
2200 1.023695468902588
2300 1.0219638347625732
2400 1.02284574508667
2500 1.025407314300537
2600 1.0264673233032227
2700 1.0230599641799927
2800 1.0344970226287842
2900 1.0232372283935547
3000 1.0251270532608032
3100 1.0237553119659424
3200 1.0205727815628052
3300 1.0254640579223633
3400 1.023802638053894
3500 1.025191068649292
3600 1.0280824899673462
3700 1.0205121040344238
3800 1.0233972072601318
3900 1.0288875102996826
4000 1.0256034135818481
4100 1.0280306339263916
4200 1.02626180648803

In [13]:
def solve(equation):
    m.eval()
    idx = torch.tensor(encode(equation), dtype=torch.long, device=device)
    for i in range(10):
        logits, loss = m(idx.unsqueeze(0))
        idx_next = logits[:, -1, :].argmax(dim=-1).view(1)
        if idx_next.item() == 0: break
        idx = torch.cat((idx,idx_next), dim=-1)
    m.train()

    return decode(idx.tolist())

solve("3 + 23 =")

'3 + 23 = 6978'

In [14]:
import random

In [30]:
total = 0
correct = 0

def test(lo, hi):
    a = random.randint(lo, hi)
    b = random.randint(lo, hi)

    c = a + b
    # print(f"{a} + {b} = {c}")
    out = solve(f"\n{a} + {b} =")

    predicted = int(out.split("=")[1][::-1])
    print(out, predicted == c)
    return predicted == c

for i in range(0, 1000):
    total += 1
    is_correct = test(0, 10000)
    if is_correct:
        correct += 1

for i in range(0, 1000):
    total += 1
    is_correct = test(0, 1000)
    if is_correct:
        correct += 1

for i in range(0, 1000):
    total += 1
    is_correct = test(0, 100)
    if is_correct:
        correct += 1

for i in range(0, 100):
    total += 1
    is_correct = test(0, 10)
    if is_correct:
        correct += 1

print(f"{correct} / {total}, {correct / total}")



6198 + 4755 = 35901 True

1013 + 1391 = 4042 True

8093 + 4352 = 54421 True

7478 + 9485 = 36961 True

8562 + 6259 = 12841 True

5931 + 5535 = 66411 True

5161 + 4954 = 51101 True

4176 + 726 = 2094 True

8321 + 8644 = 56961 True

8250 + 446 = 6968 True

5002 + 6555 = 75511 True

4067 + 2417 = 4846 True

9338 + 4684 = 22041 True

41 + 8338 = 9738 True

3688 + 5747 = 5349 True

6495 + 8966 = 16451 True

5899 + 5439 = 83311 True

9909 + 7240 = 94171 True

2252 + 2161 = 3144 True

3646 + 3433 = 9707 True

4755 + 4447 = 2029 True

6085 + 3162 = 7429 True

1313 + 2492 = 5083 True

7532 + 9700 = 23271 True

503 + 5227 = 0375 True

2365 + 7113 = 8749 True

6000 + 4930 = 03901 True

9657 + 6158 = 51851 True

753 + 3389 = 2414 True

6272 + 5715 = 78911 True

922 + 2097 = 9103 True

7329 + 1751 = 0809 True

1651 + 3094 = 5474 True

9692 + 6683 = 57361 True

7615 + 2599 = 41201 True

1980 + 3681 = 1665 True

7577 + 7841 = 81451 True

196 + 1077 = 3721 True

3535 + 580 = 5114 True

6257 + 619 = 6

In [16]:
# text.find("324 + 324")
"999 + 3 =" in text

True

In [29]:
solve("\n883 + 832 =")

'\n883 + 832 = 5171'

In [46]:
def add(a, b):
    out = solve(f"\n{a} + {b} =")
    return int(out.split("=")[1][::-1])


add(1520, 1001)

2521