In [1]:
import torch
import torch.nn.functional as F

from torch import nn

In [2]:
# Toy input: a batch of B sequences, each T timesteps long, with C channels per step.
B, T, C = 4, 8, 2
torch.manual_seed(1337)  # fixed seed so the tensor shown below is reproducible
x = torch.randn(B, T, C)
x

tensor([[[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]],

        [[ 1.3488, -0.1396],
         [ 0.2858,  0.9651],
         [-2.0371,  0.4931],
         [ 1.4870,  0.5910],
         [ 0.1260, -1.5627],
         [-1.1601, -0.3348],
         [ 0.4478, -0.8016],
         [ 1.5236,  2.5086]],

        [[-0.6631, -0.2513],
         [ 1.0101,  0.1215],
         [ 0.1584,  1.1340],
         [-1.1539, -0.2984],
         [-0.5075, -0.9239],
         [ 0.5467, -1.4948],
         [-1.2057,  0.5718],
         [-0.5974, -0.6937]],

        [[ 1.6455, -0.8030],
         [ 1.3514, -0.2759],
         [-1.5108,  2.1048],
         [ 2.7630, -1.7465],
         [ 1.4516, -1.5103],
         [ 0.8212, -0.2115],
         [ 0.7789,  1.5333],
         [ 1.6097, -0.4032]]])

- In the averaged tensor, for each timestep I need the average of the `C`-dimensional vectors from timestep 0 up to and including _that_ timestep.

In [3]:
# Naive reference implementation ("bag of words"): for every sequence b and
# timestep t, average the feature vectors of timesteps 0..t (inclusive).
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]                 # (t+1, C): timesteps 0..t of sequence b
        xbow[b, t] = torch.mean(xprev, 0)  # (C,): mean over the time axis
xbow

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

- Andrej says there's a more efficient way of doing this using matrix multiplication. Let's try to derive it ourselves.

In [4]:
m1 = x[0, :, :]  # first sequence of the batch, shape (T, C)
(m1, m1.shape)

(tensor([[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]]),
 torch.Size([8, 2]))

- The target matrix is also of shape `8, 2`. A first guess, then, is to right-multiply `m1` by a `2, 2` matrix, since `(8, 2) @ (2, 2)` gives `(8, 2)`.

In [5]:
m2 = torch.full((C, C), 0.5)  # candidate (C, C) right-multiplier with every entry 1/2
(m2, m2.shape)

(tensor([[0.5000, 0.5000],
         [0.5000, 0.5000]]),
 torch.Size([2, 2]))

- That's not correct: right-multiplying by a `2, 2` matrix only mixes the channels *within* each timestep and never mixes information *across* timesteps. What we can do instead is left-multiply `x` by a lower-triangular matrix of ones (initially).
- Each row i (0-indexed) can then be divided by i+1, which is the number of ones in that row.
- At each row, this masks out the future timesteps and averages only the current timestep and the ones before it.
- So, this matrix will be `8, 8` (result -> `8, 2`).

In [6]:
m1 = torch.ones(T, T).tril()  # lower-triangular ones: row t selects timesteps 0..t
(m1.shape, m1)

(torch.Size([8, 8]),
 tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1.]]))

In [7]:
m2 = x[0, :, :]  # the (T, C) sequence we want to average over time
(m2.shape, m2)

(torch.Size([8, 2]),
 tensor([[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]]))

In [8]:
torch.matmul(m1, m2)  # row t = running *sum* of rows 0..t of m2 -- not the average yet

tensor([[ 0.1808, -0.0700],
        [-0.1789, -0.9852],
        [ 0.4469, -0.9597],
        [ 1.4014, -0.8953],
        [ 1.7626,  0.2725],
        [ 0.4127, -0.2376],
        [ 0.6486, -0.4774],
        [-0.2725,  1.0659]])

In [9]:
torch.arange(1, T + 1)[:, None].shape  # column vector of per-row counts 1..T

torch.Size([8, 1])

In [10]:
# Row i of the tril matrix holds i+1 ones, so dividing by the column [1, 2, ..., T]
# (broadcast across columns) turns each row's sum into a mean.
avg_m = m1 / torch.arange(1, T + 1).view(T, 1)
avg_m

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [11]:
res = torch.matmul(avg_m, m2)  # (T, T) @ (T, C) -> (T, C): causal running mean of m2
res

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [12]:
res.allclose(xbow[0])  # compare against the loop-based reference for batch element 0

True

Yayy! It's correct!

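Another (less hard-coded) way to do it: normalize each row of the tril matrix by its own sum, so the divisors don't have to be spelled out with `arange`: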

In [13]:
# Normalize every row by its own sum instead of hard-coding the counts 1..T.
avg_m = m1 / m1.sum(dim=1, keepdim=True)
avg_m

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [14]:
res = torch.matmul(avg_m, m2)  # same causal running mean, via row-sum normalization
res.allclose(xbow[0])

True

In [15]:
# Batched version: build the row-normalized causal averaging matrix once and
# let matmul broadcast it over the batch dimension.
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=1, keepdim=True)
xbow2 = torch.matmul(wei, x)  # (T, T) broadcast over B, @ (B, T, C) --> (B, T, C)
xbow2

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [16]:
xbow2[0].allclose(res)  # batched result for sequence 0 matches the single-sequence one

True

There's another way of computing the same thing. This version uses softmax. Softmax basically creates a probability distribution from a set of logits.

So, if you put 0s where you want 1s and `-inf`s where you want 0s (bear with me), then exponentiating gives `exp(0) = 1` and `exp(-inf) = 0`, i.e., the lower-triangular matrix of ones. Softmax then divides each element by the sum of the exponentiated values in its row, which yields exactly the row-normalized averaging matrix from before.

In [17]:
tril = torch.ones(T, T).tril()
torch.eq(tril, 0)  # True marks the future (upper-triangle) positions that will be masked

tensor([[False,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False]])

In [18]:
tril = torch.ones(T, T).tril()
# All-zero logits with future positions pushed to -inf: softmax maps
# exp(0) = 1 and exp(-inf) = 0, then normalizes each row to sum to 1.
wei = F.softmax(
    torch.zeros(T, T).masked_fill(tril == 0, float('-inf')),
    dim=-1,
)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [19]:
xbow3 = torch.matmul(wei, x)  # softmax-derived weights, broadcast over the batch
res.allclose(xbow3[0])

True

Of course, all we have done so far is pass information from the previous tokens to the current token by averaging the embeddings up to and including the current token. This isn't actual self-attention yet.

Here's what Andrej says about self-attention (for starters):
- Simply averaging the token embeddings up to a timestep is a lossy way of passing information around, because every previous token gets the same weight. Instead, here's what happens in self-attention...
    - The current token has a query vector. It asks a question: "what am I looking for?"
    - The tokens at the current and previous timesteps each have a key vector. It says what that token has to offer.
    - The dot products between the current token's `q` vector and those tokens' `k` vectors give the affinities `wei`.

In [54]:
# Fresh, wider input for the self-attention experiment (C = 32 channels).
# NOTE: this rebinds B, T, C and x, so the earlier xbow* results no longer
# correspond to the current x. There is no re-seed here, so x depends on the
# RNG state left by earlier cells.
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 32])

In [55]:
head_sz = 16  # size of each key / query / value vector
# Bias-free linear projections C -> head_sz (created in the order key, query, value).
key = nn.Linear(C, head_sz, bias=False)
query = nn.Linear(C, head_sz, bias=False)
value = nn.Linear(C, head_sz, bias=False)
k, q = key(x), query(x)  # each (B, T, head_sz)

k.shape

torch.Size([4, 8, 16])

In [56]:
# Raw affinities: every query dotted with every key of the same sequence.
# (B, T, head_sz) @ (B, head_sz, T) ---> (B, T, T)
# NOTE(review): not scaled by head_sz**-0.5 here (unlike scaled dot-product attention).
wei = torch.matmul(q, k.transpose(-1, -2))
print(wei[0])
wei.shape

tensor([[-2.6408, -0.0168, -0.1657, -0.3016,  4.6135,  3.8161,  0.5657, -1.8342],
        [ 0.5343, -0.4730, -0.7929, -0.9196, -1.4602,  0.4356, -0.0270, -0.3583],
        [ 0.1870, -0.7272, -1.3527, -0.2296,  2.3319,  1.1802,  1.4095,  0.3595],
        [-0.3962,  0.3359,  0.6781,  0.1396,  1.4178,  1.1088,  0.2748, -0.9579],
        [ 0.8676, -0.9641, -2.2328, -0.5911, -2.0789,  0.2846, -0.4443,  1.2669],
        [ 0.0285, -0.1294,  1.5084,  1.3670,  1.7330,  1.9421, -1.6419, -0.4578],
        [-0.9873,  1.8264,  3.8634, -0.1285,  4.2294, -0.9790, -1.7341, -0.4707],
        [ 0.4985, -0.0867,  1.0004,  0.6012,  0.9321, -0.7500, -0.3899, -0.5351]],
       grad_fn=<SelectBackward0>)


torch.Size([4, 8, 8])

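We must now mask out the future timesteps, so that each token can only attend to itself and the tokens before it, and then normalize each row with softmax.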

In [57]:
# Causal mask: position t may only attend to positions 0..t, so the upper
# triangle (future timesteps) is set to -inf before the softmax.
tril = torch.tril(torch.ones(T, T))
# Recompute the raw affinities from q and k instead of reading `wei`, which this
# cell overwrites: otherwise re-running the cell would re-mask and re-softmax an
# already-normalized matrix and silently change the attention weights.
wei = q @ k.transpose(-2, -1)  # (B, T, head_sz) @ (B, head_sz, T) ---> (B, T, T)
wei = wei.masked_fill(tril==0, float('-inf')) # tril broadcasts from (T, T) to (B, T, T)
wei = F.softmax(wei, dim=-1) # still B, T, T; each row sums to 1
v = value(x)  # (B, T, head_sz)
# Aggregate the value projections rather than the raw inputs: the earlier
# averaging cells computed wei @ x, giving (B, T, C) instead.
out = wei @ v # (B, T, T) @ (B, T, head_sz) ---> (B, T, head_sz)

out.shape

torch.Size([4, 8, 16])