In [23]:
import torch
import torch.nn.functional as F

In [4]:
# Fix the global torch RNG so every random tensor created below
# (torch.randn / torch.randint) and its printed output is reproducible.
SEED = 1337
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f1204726c90>

In [6]:
# Toy dimensions for the averaging experiments below.
B = 4  # batch: number of independent sequences
T = 8  # time steps: tokens per sequence
C = 2  # channels: embedding dimensions per token
x = torch.randn((B, T, C))  # random token embeddings, shape (B, T, C)
x.shape

torch.Size([4, 8, 2])

In [8]:
# Naive "bag of words": for every sequence b and time step t, replace the token
# at t with the mean of tokens 0..t (inclusive). Each token thereby aggregates
# information from its past without seeing the future. The double Python loop
# is deliberately simple; later cells compute the same result with a matmul.
xbow = torch.zeros_like(x)  # same shape/dtype/device as x: (B, T, C)
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # tokens 0..t of sequence b, shape (t+1, C)
        xbow[b, t] = torch.mean(xprev, dim=0)  # average over time -> shape (C,)

In [12]:
# Show the first sequence's raw tokens next to their running averages.
for label, rows in (('Origin', x[0]), ('Average', xbow[0])):
    print(f'{label}: {rows}')

Origin: tensor([[-0.8345,  0.5978],
        [-0.0514, -0.0646],
        [-0.4970,  0.4658],
        [-0.2573, -1.0673],
        [ 2.0089, -0.5370],
        [ 0.2228,  0.6971],
        [-1.4267,  0.9059],
        [ 0.1446,  0.2280]])
Average: tensor([[-0.8345,  0.5978],
        [-0.4429,  0.2666],
        [-0.4610,  0.3330],
        [-0.4100, -0.0171],
        [ 0.0738, -0.1210],
        [ 0.0986,  0.0153],
        [-0.1193,  0.1425],
        [-0.0863,  0.1532]])


In [14]:
# Matrix multiplication aggregates rows in one call: with an all-ones `a`,
# every row of c = a @ b equals the column-wise sum of b.
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3, 2)).float()
c = torch.matmul(a, b)
panels = (('a', a), ('b', b), ('c', c))
print('\n--------\n'.join(f'{name}=\n{mat}' for name, mat in panels))

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
--------
b=
tensor([[3., 1.],
        [7., 9.],
        [6., 1.]])
--------
c=
tensor([[16., 11.],
        [16., 11.],
        [16., 11.]])


In [15]:
# A lower-triangular all-ones `a` limits which tokens contribute: row i of
# a @ b sums only rows 0..i of b, i.e. a cumulative sum over time.
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0, 10, (3, 2)).float()
c = torch.matmul(a, b)
panels = (('a', a), ('b', b), ('c', c))
print('\n--------\n'.join(f'{name}=\n{mat}' for name, mat in panels))

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
--------
b=
tensor([[5., 4.],
        [6., 1.],
        [1., 5.]])
--------
c=
tensor([[ 5.,  4.],
        [11.,  5.],
        [12., 10.]])


In [16]:
# Normalise each row of the lower-triangular matrix so it sums to 1; row i of
# a @ b is then the mean of rows 0..i of b (a running average of the tokens).
lower = torch.tril(torch.ones(3, 3))
a = lower / lower.sum(dim=1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = torch.matmul(a, b)
panels = (('a', a), ('b', b), ('c', c))
print('\n--------\n'.join(f'{name}=\n{mat}' for name, mat in panels))

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--------
b=
tensor([[1., 4.],
        [9., 2.],
        [5., 0.]])
--------
c=
tensor([[1., 4.],
        [5., 3.],
        [5., 2.]])


In [17]:
# Row-normalised lower-triangular weights at full context size T:
# row t holds 1/(t+1) in columns 0..t and 0 in every future column.
causal_ones = torch.tril(torch.ones(T, T))
weight = causal_ones / causal_ones.sum(dim=1, keepdim=True)
weight

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [21]:
# (T, T) @ (B, T, C) broadcasts over the batch dimension -> (B, T, C).
xbow2 = torch.matmul(weight, x)

In [37]:
# Loop version vs. matmul version for the first sequence; they should agree
# up to floating-point tolerance.
for avg in (xbow[0], xbow2[0]):
    print(avg)
torch.allclose(xbow, xbow2)

tensor([[-0.8345,  0.5978],
        [-0.4429,  0.2666],
        [-0.4610,  0.3330],
        [-0.4100, -0.0171],
        [ 0.0738, -0.1210],
        [ 0.0986,  0.0153],
        [-0.1193,  0.1425],
        [-0.0863,  0.1532]])
tensor([[-0.8345,  0.5978],
        [-0.4429,  0.2666],
        [-0.4610,  0.3330],
        [-0.4100, -0.0171],
        [ 0.0738, -0.1210],
        [ 0.0986,  0.0153],
        [-0.1193,  0.1425],
        [-0.0863,  0.1532]])


True

In [34]:
# Re-derive the same averaging weights the way self-attention does: start from
# affinity scores (all zeros here), block future positions with -inf, then
# softmax each row. With all-zero scores this reproduces the uniform running
# average from the previous cells.
masked_tril = torch.tril(torch.ones(T, T))  # 1 where attending is allowed (column <= row)
print('masked triangular matrix: ')
print(masked_tril)

w_t = torch.zeros_like(masked_tril)  # initial affinities: every allowed pair scores equally

# Mask the weight matrix with -inf wherever the triangular matrix is 0 (above the
# diagonal): position t must not attend to (receive information from) any
# future position > t.
w_t = w_t.masked_fill(masked_tril == 0, float('-inf'))
print('w_t with -inf:')
print(w_t)

# Softmax over the last dim makes every row sum to 1; exp(-inf) == 0, so future
# positions get exactly zero weight. dim=-1 also works for batched (B, T, T) scores.
w_t = F.softmax(w_t, dim=-1)
print('w_t with softmax:')
print(w_t)

masked triangular matrix: 
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
w_t with -inf:
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
w_t with softmax:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.000

In [36]:
# Third formulation: apply the softmax-derived weights, (T, T) @ (B, T, C) -> (B, T, C),
# then compare against the loop version.
xbow3 = torch.matmul(w_t, x)
for avg in (xbow[0], xbow3[0]):
    print(avg)
torch.allclose(xbow, xbow3)

tensor([[-0.8345,  0.5978],
        [-0.4429,  0.2666],
        [-0.4610,  0.3330],
        [-0.4100, -0.0171],
        [ 0.0738, -0.1210],
        [ 0.0986,  0.0153],
        [-0.1193,  0.1425],
        [-0.0863,  0.1532]])
tensor([[-0.8345,  0.5978],
        [-0.4429,  0.2666],
        [-0.4610,  0.3330],
        [-0.4100, -0.0171],
        [ 0.0738, -0.1210],
        [ 0.0986,  0.0153],
        [-0.1193,  0.1425],
        [-0.0863,  0.1532]])


True