In [1]:
import torch

In [2]:
# toy batch: B sequences, each T time steps long, with C channels per step
B = 4
T = 8
C = 2
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [3]:
# goal: xbow[b, t] = mean of x[b, i] over every i <= t (a causal running average)
xbow = torch.zeros(B, T, C)
for batch_idx in range(B):
    for step in range(T):
        # rows 0..step of this sequence, inclusive -> shape (step + 1, C)
        history = x[batch_idx, : step + 1]
        # collapse the time axis, leaving a single C-vector
        xbow[batch_idx, step] = torch.mean(history, dim=0)

In [4]:
# first batch element of the raw input, shape (T, C)
x[0]

tensor([[ 1.8778,  0.6996],
        [ 0.9471,  0.6889],
        [-0.2117,  0.4191],
        [ 1.1670,  0.5994],
        [-0.7523,  0.1631],
        [ 0.9327, -0.5975],
        [-0.4691,  0.4282],
        [ 0.9282,  0.6231]])

In [5]:
# running averages for the first batch element: row t is the mean of x[0, :t+1]
xbow[0]

tensor([[1.8778, 0.6996],
        [1.4124, 0.6943],
        [0.8711, 0.6025],
        [0.9451, 0.6018],
        [0.6056, 0.5140],
        [0.6601, 0.3288],
        [0.4988, 0.3430],
        [0.5525, 0.3780]])

In [10]:
# we can be very efficient doing this with matrix operations:
# multiplying by a lower-triangular matrix of ones gives a running (prefix) sum,
# because row i of `a` selects rows 0..i of `b`, so c[i] = b[0] + ... + b[i]
torch.manual_seed(42)  # fixed seed so the random `b` is reproducible
a = torch.tril(torch.ones(3, 3))  # mask: ones on and below the diagonal, zeros above
b = torch.randint(0, 10, (3, 2)).float()  # integers in [0, 10), cast to float for the matmul
c = a @ b  # (3, 3) @ (3, 2) --> (3, 2)
print("a=")
print(a)
print("b=")
print(b)
print("c=")
print(c)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [9]:
# tril keeps the entries on and below the main diagonal and zeroes everything above it
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [15]:
# same matmul trick, but with each mask row rescaled to sum to 1,
# so the product is a running *mean* instead of a running sum
torch.manual_seed(42)
a = torch.ones(3, 3).tril()
# row i holds i + 1 ones; dividing by the row total gives weights of 1 / (i + 1)
a = a / a.sum(dim=-1, keepdim=True)
b = torch.randint(0, 10, size=(3, 2)).float()
c = torch.matmul(a, b)  # (3, 3) @ (3, 2) --> (3, 2)
print("a=", a, sep="\n")
print("b=", b, sep="\n")
print("c=", c, sep="\n")

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [14]:
# divide each row by its own sum (keepdim=True keeps a column to broadcast against)
# so that every row adds up to 1
a / a.sum(dim=1, keepdim=True)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [22]:
# now let's vectorize the toy attention: a row-normalized lower-triangular
# matrix replaces the explicit per-step loop with a single matmul
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=-1, keepdim=True)
# row t of wei puts equal weight 1 / (t + 1) on positions 0..t and zero on later
# ones, i.e. a probability distribution over the tokens seen so far (the "averaging mask")
xbow2 = torch.matmul(wei, x)  # (T, T) @ (B, T, C) --> (B, T, C)
# wei has no batch dimension, so matmul broadcasts it across the batch:
# effectively (B, T, T) @ (B, T, C) --> (B, T, C), where each batch element
# is an independent (T, T) @ (T, C) --> (T, C) product

In [18]:
# sanity check: x was created as (B, T, C)
x.shape

torch.Size([4, 8, 2])

In [19]:
# xbow was allocated as (B, T, C), so it should have the same shape as x
xbow.shape

torch.Size([4, 8, 2])

tensor([[[ 1.8778,  0.6996],
         [ 1.4124,  0.6943],
         [ 0.8711,  0.6025],
         [ 0.9451,  0.6018],
         [ 0.6056,  0.5140],
         [ 0.6601,  0.3288],
         [ 0.4988,  0.3430],
         [ 0.5525,  0.3780]],

        [[-0.5954, -0.0072],
         [ 0.4261, -0.3368],
         [ 0.4964, -0.2207],
         [ 0.2745, -0.2707],
         [-0.0337, -0.0335],
         [-0.2426,  0.1255],
         [-0.2242,  0.0303],
         [-0.0497,  0.0400]],

        [[-1.1438,  0.0843],
         [-0.6994,  1.2355],
         [-0.4707,  1.1656],
         [-0.2883,  0.6383],
         [-0.2519,  0.3507],
         [-0.4115,  0.4310],
         [-0.3521,  0.6447],
         [-0.5144,  0.4690]],

        [[-2.0635, -0.2788],
         [-1.0403, -0.9372],
         [-0.4238, -0.4352],
         [-0.2936, -0.0962],
         [-0.0933, -0.1059],
         [ 0.1296,  0.1284],
         [ 0.3538,  0.2544],
         [ 0.1949,  0.3895]]])