In [2]:
import torch

In [3]:
# toy problem dimensions: batch size, sequence length (time), channels
B = 4
T = 8
C = 2
# random input of shape (B, T, C)
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [4]:
# goal: xbow[b, t] is the average of x[b, i] over all positions i <= t
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        # slice out positions 0..t (inclusive) -> shape (t+1, C)
        xprev = x[b, 0:t + 1]
        # collapse the time axis into a single (C,) mean vector
        xbow[b, t] = torch.mean(xprev, dim=0)

In [5]:
# inspect the first batch element of the random input; x[0] has shape (T, C)
x[0]

tensor([[ 0.3855, -1.2077],
        [-0.9504,  0.7236],
        [-0.1962, -0.0840],
        [ 0.5312, -0.0764],
        [ 0.6885,  1.0959],
        [-0.4674, -0.5356],
        [ 0.8577, -2.2124],
        [ 0.2828, -0.4025]])

In [6]:
# running means for the first batch element: row t is the mean of x[0, :t+1],
# so the first row is identical to x[0, 0]
xbow[0]

tensor([[ 0.3855, -1.2077],
        [-0.2825, -0.2421],
        [-0.2537, -0.1894],
        [-0.0575, -0.1611],
        [ 0.0917,  0.0903],
        [-0.0015, -0.0140],
        [ 0.1213, -0.3281],
        [ 0.1415, -0.3374]])

In [7]:
# we can be very efficient doing this with matrix operations
torch.manual_seed(42)
a = torch.ones(3, 3)
# now with a mask
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print("a=")
print(a)
print("b=")
print(b)
print("c=")
print(c)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [8]:
# lower-triangular matrix of ones: row i has ones in columns 0..i and zeros after
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [9]:
# we can be very efficient doing this with matrix operations
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
# if we now want a masked average, we can do this by normalizing the mask
a = a / a.sum(dim=1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print("a=")
print(a)
print("b=")
print(b)
print("c=")
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [10]:
# `a` was already row-normalized in the previous cell, so each row sums to 1
# and dividing by the row sums again leaves it unchanged (up to float rounding);
# this just displays the normalization expression on its own
a / a.sum(dim=1, keepdim=True)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [11]:
# vectorized version of the loop: a single matmul with a (T, T) averaging matrix
wei = torch.ones(T, T).tril()
# normalize each row so that row t holds equal weights 1/(t+1) over positions 0..t;
# every row is therefore a probability distribution over the current and earlier positions
wei = wei / torch.sum(wei, dim=1, keepdim=True)
# shapes: (T, T) @ (B, T, C) --> (B, T, C)
# matmul broadcasts the 2-D `wei` across the batch dimension, i.e. it behaves like
# (B, T, T) @ (B, T, C), computing (T, T) @ (T, C) --> (T, C) for each batch element
xbow2 = torch.matmul(wei, x)

# should agree with the explicit double loop
torch.allclose(xbow, xbow2)

True

In [12]:
# version 3: the same averaging weights, obtained via softmax
# (the form the original author intends to use in the real code)
tril = torch.ones(T, T).tril()
wei = torch.zeros(T, T)
# entries above the diagonal become -inf, so position t puts no weight on positions after t
wei = wei.masked_fill(tril == 0, float('-inf'))
# row-wise softmax: the remaining zeros turn into equal weights, the -inf entries into 0
wei = wei.softmax(dim=1)
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow, xbow3)

True