In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformer import GPT, Config

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
text = open("./xs_data.txt", "r").read()
chars = sorted(list(set(text)))
print(chars)

['\n', ' ', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=']


In [4]:
config = Config(
    block_size=128,
    vocab_size=len(chars),
    n_head=4,
    n_embd=64,
    dropout=0.1,
    n_layer=2,
)

In [5]:
stoi = {s:i for i, s in enumerate(chars)}
itos = {i:s for s, i in stoi.items()}

In [6]:
def encode(text):
    return [stoi[c] for c in text]

def decode(toks):
    return ''.join([itos[t] for t in toks])

In [7]:
print(decode(encode("123 + 3412 = 334\n")))

123 + 3412 = 334



In [8]:
tokens = encode(text)
dataset = torch.tensor(tokens).long().to(device)
train_dataset = dataset[:int(len(tokens) * 0.95)]
test_dataset = dataset[int(len(tokens) * 0.95):]

In [9]:
new_line_locs = []
for i, t in enumerate(tokens[:int(len(tokens) * 0.95)]):
    if t == 0 and i < len(tokens) - config.block_size:
        new_line_locs.append(i)

new_line_locs = torch.tensor(new_line_locs, dtype=torch.long, device=device)
print(new_line_locs)

tensor([      0,      14,      26,  ..., 1262861, 1262873, 1262887],
       device='cuda:0')


In [10]:
def get_batch(split="train", batch_size=32):
    data = train_dataset if split == "train" else test_dataset
    starts = torch.randint(0, len(new_line_locs), (batch_size,))
    X = torch.stack([data[i:i+config.block_size] for i in starts])
    Y = torch.stack([data[i+1:i+config.block_size+1] for i in starts])
    return X, Y


In [11]:
m = GPT(config).to(device)
optimizer = torch.optim.Adam(m.parameters(), lr=3e-4)

params: 110080


In [12]:
for i in range(50000):
    xb, yb = get_batch(split="train", batch_size=64)
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if i % 100 == 0: print(i, loss.item())

0 2.7693939208984375
100 1.8852219581604004
200 1.8275845050811768
300 1.7630831003189087
400 1.6648794412612915
500 1.5907518863677979
600 1.5162897109985352
700 1.4565467834472656
800 1.4033008813858032
900 1.3433430194854736
1000 1.3127546310424805
1100 1.2774759531021118
1200 1.241422414779663
1300 1.2158740758895874
1400 1.1964032649993896
1500 1.1794254779815674
1600 1.171066164970398
1700 1.157525897026062
1800 1.1412755250930786
1900 1.1359535455703735
2000 1.1315711736679077
2100 1.1327189207077026
2200 1.120566487312317
2300 1.1229487657546997
2400 1.1196757555007935
2500 1.121739387512207
2600 1.1150109767913818
2700 1.114862084388733
2800 1.1099539995193481
2900 1.1016037464141846
3000 1.0985208749771118
3100 1.1046230792999268
3200 1.0987879037857056
3300 1.1029605865478516
3400 1.0962467193603516
3500 1.0961031913757324
3600 1.0878716707229614
3700 1.0819143056869507
3800 1.073484182357788
3900 1.077932596206665
4000 1.0715070962905884
4100 1.075303554534912
4200 1.069627

In [13]:
def solve(equation):
    m.eval()
    idx = torch.tensor(encode(equation), dtype=torch.long, device=device)
    for i in range(10):
        logits, loss = m(idx.unsqueeze(0))
        idx_next = logits[:, -1, :].argmax(dim=-1).view(1)
        if idx_next.item() == 0: break
        idx = torch.cat((idx,idx_next), dim=-1)
    m.train()

    return decode(idx.tolist())

solve("\n3 + 23 =")

'\n3 + 23 = 62'

In [14]:
import random

In [15]:
total = 0
correct = 0

def test(lo, hi):
    a = random.randint(lo, hi)
    b = random.randint(lo, hi)

    c = a + b
    # print(f"{a} + {b} = {c}")
    out = solve(f"\n{a} + {b} =")

    predicted = int(out.split("=")[1][::-1])
    print(out, predicted == c)
    return predicted == c

# for i in range(0, 1000):
#     total += 1
#     is_correct = test(0, 10000)
#     if is_correct:
#         correct += 1

# for i in range(0, 1000):
#     total += 1
#     is_correct = test(0, 1000)
#     if is_correct:
#         correct += 1

for i in range(0, 1000):
    total += 1
    is_correct = test(0, 100)
    if is_correct:
        correct += 1

for i in range(0, 100):
    total += 1
    is_correct = test(0, 10)
    if is_correct:
        correct += 1

print(f"{correct} / {total}, {correct / total}")



10 + 20 = 03 True

66 + 34 = 001 True

92 + 69 = 161 True

78 + 68 = 641 True

7 + 61 = 86 True

31 + 98 = 921 True

100 + 33 = 34 False

15 + 91 = 601 True

36 + 4 = 04 True

78 + 69 = 741 True

58 + 77 = 531 True

78 + 45 = 321 True

20 + 49 = 96 True

23 + 53 = 67 True

62 + 12 = 47 True

39 + 0 = 93 True

4 + 80 = 48 True

0 + 4 = 4 True

47 + 35 = 28 True

91 + 43 = 431 True

14 + 98 = 211 True

55 + 70 = 521 True

63 + 94 = 751 True

87 + 20 = 701 True

51 + 3 = 45 True

75 + 58 = 331 True

57 + 36 = 39 True

97 + 40 = 731 True

9 + 55 = 46 True

15 + 9 = 42 True

41 + 73 = 411 True

89 + 23 = 211 True

35 + 77 = 211 True

4 + 56 = 06 True

46 + 3 = 94 True

83 + 28 = 111 True

33 + 0 = 33 True

42 + 11 = 35 True

14 + 17 = 13 True

2 + 0 = 2 True

42 + 18 = 06 True

41 + 20 = 16 True

59 + 73 = 231 True

62 + 37 = 99 True

2 + 59 = 16 True

16 + 76 = 29 True

65 + 96 = 161 True

34 + 8 = 24 True

42 + 66 = 801 True

36 + 90 = 621 True

77 + 14 = 19 True

29 + 19 = 84 True

88 +

In [16]:
# text.find("324 + 324")
"999 + 3 =" in text

False

In [17]:
solve("\n883 + 832 =")

'\n883 + 832 = 511'

In [18]:
def add(a, b):
    out = solve(f"\n{a} + {b} =")
    return int(out.split("=")[1][::-1])


add(1520, 1001)

101

In [19]:
torch.save(m, './xs_model.pt')