In this notebook we will iterate on the bigram model and build up to the attention mechanism.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Mathematical trick for self-attention:

In [52]:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch_size, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

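__How do we introduce interaction so that, say, the character at time step 5 only sees the characters up to that point as context, and nothing from the future?__

One (__poor__) way of capturing this is to average the channel vectors of those characters over the time dimension and use that average as the input for predicting the next character. <br>
We lose a lot of information about the order of the context characters, but for a start it's OK! Note that in the code below the average at position `t` includes position `t` itself.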

In [None]:
# we want: xbow[b,t] = mean_{i<=t} x[b,i]

xbow = torch.zeros((B,T,C))

# v1 - manual
for b in range(B):
    for t in range(T):
        x_prev = x[b, :t+1] # (t+1,C)
        # print(x_prev, x_prev.shape)
        # print(x_prev.mean(dim = 0))
        xbow[b,t] = x_prev.mean(dim = 0) # dim 0 is along 't'

In [33]:
x[0], xbow[0]
# row 0 matches; row k of xbow is the mean of rows 0..k of x, i.e. x[0, :k+1]

(tensor([[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))
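
As a cross-check on the loop above, the same running mean can be sketched with a cumulative sum divided by the number of positions seen so far. The tensor `x_demo` and the other names here are made up for illustration (small hand-checkable values rather than the notebook's random `x`), so this is only a sanity check under those assumptions.

```python
import torch

# Hypothetical small input with easy-to-check values: (B=1, T=4, C=3)
x_demo = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)
T_demo = x_demo.shape[1]

# Running sum over the time dimension, divided by how many rows were summed
counts = torch.arange(1, T_demo + 1, dtype=torch.float32).view(1, -1, 1)  # (1, T, 1)
xbow_cumsum = torch.cumsum(x_demo, dim=1) / counts

# Reference: the explicit double loop, same logic as v1
xbow_loop = torch.zeros_like(x_demo)
for b in range(x_demo.shape[0]):
    for t in range(T_demo):
        xbow_loop[b, t] = x_demo[b, : t + 1].mean(dim=0)

print(torch.allclose(xbow_cumsum, xbow_loop, atol=1e-6))
print(xbow_cumsum[0])
```

Both versions compute the mean over rows `0..t` for every `t`; the cumulative-sum form just avoids the Python loops.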

(Super) clutch trick: parallelize the accumulation with a matrix multiplication by a lower triangular matrix:

Think of matrix multiplication from first principles!

In [47]:
torch.manual_seed(1667)
L = torch.tril(torch.ones(3,3))
# L = L / L.sum(dim = 1, keepdim=True) # uncomment to get running means instead of running sums

a = torch.randint(1,10,(3,2)).float()

print(L)
print('------')
print(a)
print('------')
print(L @ a)


tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
------
tensor([[9., 5.],
        [4., 5.],
        [7., 4.]])
------
tensor([[ 9.,  5.],
        [13., 10.],
        [20., 14.]])
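
To make the row-by-row accumulation concrete, here is a small sketch that re-enters the printed values of `a` by hand (so it does not depend on the random seed) and compares the plain lower triangular matrix with a row-normalized one. The names `a_demo`, `L_demo` and `L_norm` are illustrative only.

```python
import torch

# Same values as the printed `a` above, typed in by hand for reproducibility
a_demo = torch.tensor([[9.0, 5.0], [4.0, 5.0], [7.0, 4.0]])

L_demo = torch.tril(torch.ones(3, 3))
# Divide each row by its sum so that every row of L_norm adds up to 1
L_norm = L_demo / L_demo.sum(dim=1, keepdim=True)

running_sum = L_demo @ a_demo   # row i = sum of rows 0..i of a_demo
running_mean = L_norm @ a_demo  # row i = mean of rows 0..i of a_demo

print(running_sum)
print(running_mean)
```

Row 1 of `running_mean`, for example, is `(9 + 4) / 2 = 6.5` and `(5 + 5) / 2 = 5`, which is the average of the first two rows of `a_demo`.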


See how row `i` of `L@a` is the sum of rows `0..i` of `a`? Further, if we normalize `L` so that each row sums to 1, `L@a` gives the running mean instead. <br>
Now let's implement the same thing for our $(B,T,C)$ tensor:

In [None]:
# v2 - using lower tril matrix

wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(dim = 1, keepdim= True) # normalize 

# weighted aggregation through the matrix multiplication seen above
xbow2 = wei @ x # (T,T) @ (B,T,C): wei is broadcast across the batch => (B,T,C)

# verify
# xbow[0], xbow2[0] -- same (up to floating-point error)

`(T,T) @ (B,T,C)` $\implies$ PyTorch broadcasts `wei` across the batch dimension $\implies$ `(B,T,T) @ (B,T,C)` $\implies$ for each of the `B` batch elements, `(T,T) @ (T,C)` $\implies$ `(B,T,C)`
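
The broadcasting step can be checked directly. The sketch below (with an arbitrary seed and a fresh `x_demo`, both assumptions made for this example) compares the broadcast product with an explicit batched product in which the weight matrix is copied once per batch element.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, chosen only for this sketch
B, T, C = 4, 8, 2
x_demo = torch.randn(B, T, C)

# Row-normalized lower triangular weights, as in v2
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)

# Broadcast version: (T,T) @ (B,T,C) -> (B,T,C)
out_broadcast = wei @ x_demo

# Explicit version: copy wei B times, then do a batched matmul
wei_batched = wei.unsqueeze(0).repeat(B, 1, 1)  # (B,T,T)
out_bmm = torch.bmm(wei_batched, x_demo)        # (B,T,C)

print(out_broadcast.shape)
print(torch.allclose(out_broadcast, out_bmm, atol=1e-6))
```

The explicit copy is only there to show what broadcasting does conceptually; the broadcast form avoids materializing `B` copies of `wei`.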

In [None]:
# v3 - using softmax