In [1]:
import torch

In [20]:
vocabulary = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, '+', '=']
itos = {i:s for i,s in enumerate(vocabulary)}
stoi = {s:i for i,s in enumerate(vocabulary)}

def get_batch(batch_size):
    sample = torch.randint(0, 5, (batch_size,2))
    y = sample.sum(1)
    plus_part = torch.ones((batch_size, 1))*stoi['+']
    equal_part = torch.ones((batch_size, 1))*stoi['=']
    x = torch.cat((sample[:,:1], plus_part, sample[:,1:], equal_part), dim=1)
    return x.int(),y

In [21]:
x,y = get_batch(10)

In [22]:
x, y

(tensor([[ 4, 10,  2, 11],
         [ 0, 10,  4, 11],
         [ 3, 10,  2, 11],
         [ 0, 10,  0, 11],
         [ 4, 10,  3, 11],
         [ 3, 10,  0, 11],
         [ 1, 10,  0, 11],
         [ 3, 10,  1, 11],
         [ 3, 10,  2, 11],
         [ 2, 10,  3, 11]], dtype=torch.int32),
 tensor([6, 4, 5, 0, 7, 3, 1, 4, 5, 5]))

In [23]:
class Head:
    def __init__(self, embed_dim, head_dim, block_size):
        self.head_dim = head_dim
        self.E = torch.randn((len(vocabulary), embed_dim))
        self.Wq = torch.randn((embed_dim, head_dim)) * embed_dim**-0.5
        self.Wk = torch.randn((embed_dim, head_dim)) * embed_dim**-0.5
        self.Wv = torch.randn((embed_dim, head_dim)) * embed_dim**-0.5
        self.mask = torch.tril(torch.ones(block_size, block_size))
        self.dense = torch.randn((head_dim, head_dim)) * head_dim**-0.5
        self.dense_bias = torch.zeros((head_dim,))

        self.final = torch.randn((head_dim, len(vocabulary))) * head_dim**-0.5

    def parameters(self):
        return [self.E, self.Wq, self.Wk, self.Wv, self.dense, self.dense_bias, self.final]

    def __call__(self, x): # x.shape = [batch_size, block_size]
        embedded = self.E[x] # [batch_size, block_size, embed_dim]
        Q = embedded @ self.Wq # [batch_size, block_size, embed_dim] @ [embed_dim, head_dim] = [batch_size, block_size, head_dim]
        K = embedded @ self.Wk # [batch_size, block_size, embed_dim] @ [embed_dim, head_dim] = [batch_size, block_size, head_dim]
        V = embedded @ self.Wv # [batch_size, block_size, embed_dim] @ [embed_dim, head_dim] = [batch_size, block_size, head_dim]

        att_weights = Q @ K.transpose(-1,-2) # [batch_size, block_size, head_dim] @ [batch_size, head_dim, block_size] = [batch_size, block_size, block_size]
        att_weights = att_weights * self.head_dim**-0.5 # scale down
        masked_att = att_weights.masked_fill(self.mask == 0, -torch.inf)
        att_weights = torch.nn.functional.softmax(masked_att, dim=2)
        weighted_output = att_weights @ V # [batch_size, block_size, block_size] @ [batch_size, block_size, head_dim] = [batch_size, block_size, head_dim]

        output = torch.nn.functional.relu(weighted_output @ self.dense + self.dense_bias) # [batch_size, block_size, head_dim] @ [head_dim, head_dim] = [batch_size, block_size, head_dim]

        return output @ self.final

In [31]:
# Training loop
model = Head(8, 16, 4)
for p in model.parameters():
    p.requires_grad = True

for i in range(1000):
    x, y = get_batch(32)
    logits = model(x)[:,-1,:] # Take only the prediction for the last token
    loss = torch.nn.functional.cross_entropy(logits, y)

    print(i+1, loss.item())
    for p in model.parameters():
        p.grad = None
    loss.backward()
    for p in model.parameters():
        p.data -= 0.1 * p.grad

1 2.630934000015259
2 2.589439868927002
3 2.4839253425598145
4 2.4814367294311523
5 2.5044431686401367
6 2.408311128616333
7 2.4608681201934814
8 2.4794373512268066
9 2.445145845413208
10 2.4535348415374756
11 2.4528188705444336
12 2.410306453704834
13 2.430544137954712
14 2.3514509201049805
15 2.4499292373657227
16 2.459643602371216
17 2.380988121032715
18 2.431831121444702
19 2.365316390991211
20 2.396195650100708
21 2.386254072189331
22 2.3812763690948486
23 2.3393983840942383
24 2.319797992706299
25 2.285654306411743
26 2.3088815212249756
27 2.33225679397583
28 2.3125219345092773
29 2.3055038452148438
30 2.257030963897705
31 2.368298292160034
32 2.361767292022705
33 2.298804998397827
34 2.3047051429748535
35 2.2803378105163574
36 2.340421438217163
37 2.2432315349578857
38 2.1725854873657227
39 2.2233641147613525
40 2.2260704040527344
41 2.228787422180176
42 2.117684841156006
43 2.195376396179199
44 2.1228535175323486
45 2.124797821044922
46 2.148360252380371
47 2.153533458709717
48

In [32]:
def eval_model(model):
    matches = 0
    count = 0
    for i in range(32):
        x, y = get_batch(32)
        y_pred = model(x)[:,-1,:]
        y_pred = torch.argmax(y_pred, dim=1)
        matches += (y == y_pred).sum().item()
        count += len(x)
    return matches / count

In [33]:
eval_model(model)

1.0

In [38]:
d1 = 0
d2 = 2
x = torch.tensor([[stoi[d1], stoi['+'], stoi[d2], stoi['=']]])
y_pred = model(x)[:,-1,:]
y_pred = torch.argmax(y_pred, dim=1)[0]
print(itos[y_pred.item()])

2
