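This notebook shows how tokens in a sequence can "communicate" with earlier tokens through matrix multiplication (each position averaging the tokens that came before it, including itself). It also covers how softmax can be used to apply that causal mask, so that no position can see future tokens.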

In [8]:
import torch
from torch.nn import functional as F


In [2]:
# Reproducible toy input: B sequences of T tokens, each token a C-dim vector.
torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch size, time steps (sequence length), channels per token
x = torch.randn((B, T, C))
x.size()

torch.Size([4, 8, 2])

In [3]:
# Slow reference version: for each batch b and position t, store the mean of
# the token vectors x[b, 0..t] (current position included, future excluded).
xbag_of_words = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        x_prev = x[b, : t + 1]  # rows 0..t inclusive, shape (t + 1, C)
        xbag_of_words[b, t] = x_prev.mean(dim=0)

In [4]:
x[0, :, :]  # first sequence in the batch: all T time steps, all C channels

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [5]:
# Row t is the mean of x[0, 0..t] (all rows up to AND including t), so row 0
# is the mean of a single row, i.e. exactly x[0, 0].
assert torch.equal(xbag_of_words[0, 0], x[0, 0])
xbag_of_words[0]


tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [6]:
# Faster version of the same averaging: a lower-triangular matrix whose rows
# are normalised to sum to 1, so row t holds 1/(t+1) over columns 0..t.
weight = torch.tril(torch.ones(T, T))
weight = weight / weight.sum(dim=1, keepdim=True)
# (T, T) @ (B, T, C): matmul broadcasts `weight` across the batch dim -> (B, T, C)
xbag_of_words2 = torch.matmul(weight, x)
torch.allclose(xbag_of_words, xbag_of_words2)

True

In [9]:
# Same averaging built with softmax: entries above the diagonal (future
# positions) are set to -inf, which softmax maps to weight 0. The remaining
# equal (zero) logits turn into a uniform average over positions 0..t.
triangle = torch.tril(torch.ones(T, T))
weight = torch.zeros((T, T)).masked_fill(triangle == 0, float('-inf'))
weight = weight.softmax(dim=-1)
xbag_of_words3 = torch.matmul(weight, x)
torch.allclose(xbag_of_words, xbag_of_words3)

True