In [105]:
#Create example tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
# Seed PyTorch's default RNG so the random example tensor is identical on every run.
torch.manual_seed(1337)
# Toy dimensions: batch of 4 sequences, 8 timesteps each, 2 channels per timestep.
B, T, C = (4, 8, 2)
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

Now our task is to let the 8 elements of the time dimension talk to each other, but only backwards in time: timestep 3 only talks to timesteps 2, 1, and 0 (plus itself). To do this, for every position x[b,t] we compute the mean of x[b,i] over all i <= t. This is a basic attention mechanism in which every visible timestep gets the same weight.

In [90]:
# Bag-of-words baseline: xbow[b, t] is the mean of x[b, 0], ..., x[b, t],
# i.e. each timestep is replaced by the average of itself and everything before it.
xbow = torch.zeros((B, T, C))  # "x, bag of words"
for b in range(B):
    for t in range(T):
        past = x[b, :t + 1]  # shape (t + 1, C): the current timestep and all earlier ones
        xbow[b, t] = past.mean(dim=0)  # dim 0 of the slice is time, so this averages over timesteps

print(x[0], '\n')
print(xbow[0])

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]]) 

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


Notice above how the first timestep of each tensor is the same. After the first timestep, each row of the second tensor is the average of the first tensor's rows up to and including that timestep, so the two tensors diverge.

Demonstrating how a matrix multiplication can compute these running averages: row i of the row-normalized lower-triangular matrix a holds the weights used to average rows 0..i of b.

In [92]:
torch.manual_seed(42)
# Lower-triangular matrix of ones: row i has ones in columns 0..i and zeros after.
a = torch.ones(3, 3).tril()
# Sum each row (dim=1). keepdim=True keeps the result as shape (3, 1) so it
# broadcasts across the columns in the division below.
row_totals = a.sum(dim=1, keepdim=True)
a = a / row_totals  # every row now sums to 1: uniform weights over columns 0..i
b = torch.randint(10,(3,2), dtype=torch.float32)
c = torch.matmul(a, b)  # row i of c is the mean of rows 0..i of b
print('a \n', a)
print('b \n', b)
print('c \n', c)

a 
 tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b 
 tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c 
 tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Now creating the actual implementation with matrix multiplication:
wei has shape (T,T), while x has shape (B,T,C). PyTorch broadcasts wei by adding a leading batch dimension, so it is treated as (B,T,T) with the same weights for every batch element. Multiplying this with x is the same procedure as in the previous step; imagine B=1, T=3, C=2 there.
Thus we will get xbow2 as a (B,T,C) tensor containing the same running averages as xbow (up to floating-point rounding).

In [93]:
# (T, T) averaging weights, one row per timestep: row t puts equal weight on
# timesteps 0..t and zero weight on every later timestep.
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=1, keepdim=True)  # normalize each row to sum to 1
# (T, T) weights applied to each (T, C) slice of x -> result has shape (B, T, C)
xbow2 = torch.matmul(wei, x)

In [107]:
# Same causal average again, this time built from a masked softmax.
tril = torch.ones(T, T).tril()
# Start from all-zero scores, then set every position above the diagonal
# (a later timestep) to -inf so that softmax assigns it weight 0.
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
# Softmax over the last dim: the remaining equal scores in each row become uniform weights.
wei = wei.softmax(dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True