In [1]:
import torch

device = 'cuda'

In [64]:
import torch
from torch import nn

def generate_simple_expression(max_num=100_000_000):
    """Generate a random expression"""
    num1 = random.randint(0, max_num)
    num2 = random.randint(0, max_num)
    op = '+'  # random.choice(['+', '-', '*', '/'])
    return f'{num1}{op}{num2}={eval(f"{num1} {op} {num2}")}'


VOCAB = "0123456789+()= "
VOCAB_SIZE = len("0123456789+()= ")
CONTEXT_SIZE = 32
EMBEDDING_SIZE = 16


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, n_embd, head_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size

        self.query = nn.Linear(n_embd, head_size * num_heads)
        self.key = nn.Linear(n_embd, head_size * num_heads)
        self.value = nn.Linear(n_embd, head_size * num_heads)
        self.register_buffer('mask', torch.tril(torch.ones(CONTEXT_SIZE, CONTEXT_SIZE)))

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        q = q.view(-1, self.num_heads, x.shape[1], self.head_size)
        k = k.view(-1, self.num_heads, x.shape[1], self.head_size)
        v = v.view(-1, self.num_heads, x.shape[1], self.head_size)

        ws = q @ k.transpose(-2, -1) * self.head_size**-0.5  # relation between nodes
        ws.masked_fill(self.mask[:x.shape[1], :x.shape[1]] == 0, float('-inf'))  # mask out the future
        ws = torch.softmax(ws, dim=-1)  # (bs, num_heads, x.shape[1], x.shape[1])
        ret = ws @ v  # (bs, num_heads, x.shape[1], head_size)

        return ret.transpose(1, 2).contiguous().view(-1, x.shape[1], self.num_heads * self.head_size)


class DecoderBlock(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.attn = MultiHeadSelfAttention(n_embd, 8, 2)
        self.mlp = nn.Sequential(
            nn.Linear(16, 4 * 16),
            nn.GELU(),
            nn.Linear(4 * 16, n_embd),
        )
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class CalculatorTransformer(nn.Module):
    def __init__(self, num_layers):
        super().__init__()
        self.token_embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_SIZE)
        self.position_embedding = nn.Embedding(CONTEXT_SIZE, EMBEDDING_SIZE)
        self.blocks = nn.Sequential(*[DecoderBlock(EMBEDDING_SIZE) for _ in range(num_layers)])
        self.ln = nn.LayerNorm(EMBEDDING_SIZE)
        self.head = nn.Linear(EMBEDDING_SIZE, VOCAB_SIZE)

    def forward(self, x):
        x = self.token_embedding(x) + self.position_embedding(torch.arange(x.shape[1]))
        x = self.blocks(x)
        x = self.ln(x)
        x = self.head(x)
        return x

    def generate(self, idx):
        self.eval()
        for i in range(32):
            logits = self(idx[0, -CONTEXT_SIZE:].view(1, -1))
            #print(logits[:, -1, :].shape)
            idx_next = logits[:, -1, :].argmax(-1)
            idx = torch.cat((idx, idx_next.view(1, 1)), dim=1)
        self.train()
        return idx


model = CalculatorTransformer(2)


def get_batch(size=32):
    x = []
    for _ in range(size):
        t = [VOCAB.find(a) for a in generate_simple_expression()]
        while len(t) < CONTEXT_SIZE + 1:
            t.append(VOCAB.find(' '))
        x.append(t)
    x = torch.tensor(x)
    return x[:, :-1], x[:, 1:]


optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for i in range(20000):
    x, y = get_batch()
    out = model(x)
    loss = nn.functional.cross_entropy(out.reshape(-1, VOCAB_SIZE), y.reshape(-1))
    model.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 1000 == 0:
        print(f'loss: {loss.item()}')


loss: 2.8373327255249023
loss: 0.028803199529647827
loss: 0.00251967366784811
loss: 0.0010418533347547054
loss: 0.0003799532714765519
loss: 0.0004831521073356271
loss: 0.00037596895708702505
loss: 7.950369035825133e-05
loss: 0.00017593210213817656
loss: 4.390700632939115e-05
loss: 3.0849099857732654e-05
loss: 2.5491368432994932e-05
loss: 1.1139887647004798e-05
loss: 1.7092868802137673e-05
loss: 8.40282427816419e-06
loss: 5.48824027646333e-06
loss: 1.8239163182443008e-05
loss: 4.922851076116785e-06
loss: 2.455703724990599e-05
loss: 6.804723852837924e-06


In [65]:
encode = lambda t: torch.tensor([VOCAB.find(a) for a in t]).view(1, -1)
decode = lambda t: ''.join([VOCAB[a] for a in t])

decode(model.generate(encode('2001231+312131=')).view(-1).tolist())

'2001231+312131=366636636663                    '