In [35]:
# Letting tokens communicate with each other through attention

import numpy
import torch

# Fixed seed so the random toy input is reproducible across runs
torch.manual_seed(1337)

B = 4  # batch: number of independent samples
T = 8  # time: number of token positions per sample
C = 2  # channels: features per token

# Toy input of shape (B, T, C) drawn from a standard normal distribution
x = torch.randn((B, T, C))

In [36]:
# Sanity check: x was created with torch.randn(B, T, C), so this displays (batch, time, channels)
x.shape

torch.Size([4, 8, 2])

In [37]:
# In decoder-style models the attention we perform has to be masked self-attention.

# Why is that?
# When generating new tokens from a given input sequence the model can only use the past tokens (the input) as reference for deciding what the next token should be
# So we have to mimic this behavior during training; concretely, this means that for predicting the token at timestep t the model is only allowed to "attend" to the tokens at positions 0 up to t - 1

# As demonstrated in this file, when we train a decoder we can turn one sample of k tokens into k-1 samples by using all of its prefixes, so
# (input: (t_0), label: t_1), (input: (t_0, t_1), label: t_2), ...

# => Information only flows from previous tokens to the current time step; we cannot use any information from the future, because the future is exactly what we are trying to predict

In [38]:
# How can you make tokens "attend" to each other?

# The simplest form of communication is an average: the representation at time step t
# becomes the mean of the token at t and all tokens before it, so every position
# aggregates information from its past (and only from its past).
x_avg = torch.zeros_like(x)

# The batch index is deliberately NOT called `b`: loop variables stay alive in the
# notebook namespace, and `b` is the name of the example matrix in the
# matrix-multiplication cells further down. Re-running this cell after those cells
# would otherwise silently replace that tensor with an int.
for batch_idx in range(B):
    # time dimension of one sample
    for t in range(T):
        x_prev = x[batch_idx, : t + 1]  # (t+1, C): all tokens up to and including step t
        x_avg[batch_idx, t] = torch.mean(x_prev, 0)  # mean over the time axis -> (C,)

# x_avg[i, t] now holds this running average for every time step of every sample in the batch

In [39]:
# The double loop above can be expressed much more efficiently as a single
# matrix multiplication. Two small matrices are enough to show the idea.

torch.manual_seed(42)  # reproducible values for the random example matrix

# a: 3x3 matrix filled with ones
a = torch.ones((3, 3))

# b: 3x2 matrix of random integers in [0, 10), converted to float so it can be
# multiplied with the float matrix `a`
b = torch.randint(low=0, high=10, size=(3, 2)).float()

In [40]:
# Show both operands first, then their product
print(f"A:\n{a}\nB:\n{b}")

# Matrix product (https://en.wikipedia.org/wiki/Matrix_multiplication); `a @ b` is the operator form.
# Because every entry of `a` is 1, each row of `c` is the sum of ALL rows of `b`.
c = torch.matmul(a, b)
print(f"C:\n{c}")

A:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
B:
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
C:
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [41]:
# Masking `a` with torch.tril keeps only the entries on and below the diagonal.
# Row i of the lower-triangular matrix then selects rows 0..i of B, so the product
# becomes a running (cumulative) sum over the rows of B instead of a sum over all rows.
a_triangular = torch.tril(a)

c = torch.matmul(a_triangular, b)
print(f"C:\n{c}")

# Think of B as one sample of 3 tokens with 2 channels each: row i of C (printed above)
# is the sum of tokens 0..i, i.e. each position only accumulates itself and its past.

print(f"B:\n{b}")

C:
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
B:
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])


In [42]:
# Taking this one step further: turn the running sum into a running average.
# Dividing every row of the lower-triangular matrix by its own row sum (the number
# of ones in that row) makes each row sum to 1, so it acts as uniform averaging weights.
# keepdim=True keeps the sums as a (3, 1) column so the division broadcasts per row.

a_triangular_average = a_triangular / a_triangular.sum(dim=1, keepdim=True)
a_triangular_average

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [44]:
# Multiplying by the row-normalised triangular matrix yields the running average:
# row i of c is the mean of rows 0..i of b
c = torch.matmul(a_triangular_average, b)
c

tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])

In [45]:
# We use this principle to replace the double for-loop averaging from the beginning with a single matrix multiplication:

In [55]:
# Compute the causal running average twice -- once with explicit loops, once with a
# single matrix multiplication -- and verify that both give the same result.

x_avg = torch.zeros_like(x)

# Naive implementation of attention

# The batch index is deliberately NOT called `b`: loop variables stay alive in the
# notebook namespace, and `b` is the example matrix from the matrix-multiplication
# cells. Using `b` here would leave that name bound to an int after this cell runs.
for batch_idx in range(B):
    # time dimension of one sample
    for t in range(T):
        x_prev = x[batch_idx, : t + 1]  # (t+1, C): all tokens up to and including step t
        x_avg[batch_idx, t] = torch.mean(x_prev, 0)  # mean over the time axis -> (C,)

print(f"Naive impl: {x_avg[0]}")

# Optimized implementation with matrix multiplication

# Lower-triangular ones, then each row divided by its row sum: row t of W puts
# weight 1/(t+1) on positions 0..t and weight 0 on all later positions.
W = torch.tril(torch.ones(size=(T,T)))
W = W / torch.sum(W, 1, keepdim=True)

out = W @ x  # shape: (T,T) @ (B,T,C) -> (B,T,C) | W is automatically broadcast across the batch dimension
print(f"Optimized impl: {out[0]}")

# Check every batch element, not only the one printed above. A small absolute
# tolerance absorbs float32 rounding differences between the two summation orders.
assert torch.allclose(out, x_avg, atol=1e-6), "matrix-multiplication result differs from the naive loop"

Naive impl: tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
Optimized impl: tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
