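This notebook shows how to let tokens communicate with the tokens that come before them in a sequence. They must not communicate with future tokens. The model is trained to predict the next token, so letting it see the tokens that follow would give away the answer.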

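The easiest way to make them communicate is to average the channels of the current token and all previous tokens. This creates a feature vector that summarizes the context up to that point. This kind of average is often called a bag of words, which is where the variable name `xbow` comes from. Self-attention generalizes it by replacing the uniform average with a weighted one.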
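Written out for sequence $b$ and time step $t$ (counting from 0), the average described above is:

$$
\text{xbow}[b, t] = \frac{1}{t+1} \sum_{i=0}^{t} x[b, i]
$$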

Averaging the channels is very lossy. It ignores the order of the tokens, so every previous token contributes equally no matter where it appears in the sequence.
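The following small sketch shows why this is lossy. The three token vectors are made up for illustration. Two sequences that contain the same tokens in a different order end with exactly the same averaged vector, so the order information is gone.

```python
import torch

# Three made-up tokens, each a 2-channel vector (C = 2)
tok_a = torch.tensor([1.0, 0.0])
tok_b = torch.tensor([0.0, 2.0])
tok_c = torch.tensor([3.0, 1.0])

seq1 = torch.stack([tok_a, tok_b, tok_c])  # order a, b, c -> shape (3, 2)
seq2 = torch.stack([tok_c, tok_a, tok_b])  # order c, a, b -> shape (3, 2)

# Average over the time dimension: a summary of the whole sequence
avg1 = seq1.mean(dim=0)
avg2 = seq2.mean(dim=0)

print(avg1)                       # tensor([1.3333, 1.0000])
print(torch.allclose(avg1, avg2))  # True: the order of the tokens is lost
```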

The goal is to visit every sequence in the batch and every time step t, and compute the average of the channels of the tokens at positions 0 through t.


B is the batch size: how many independent sequences are processed in parallel.

T is the number of time steps: how many tokens each sequence contains (the context length).

C is the number of channels (features) at each time step, i.e. the size of the vector that represents each token. Setting this to 2 means that each token is represented as a 2D feature vector.
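This short sketch shows how the three dimensions index into `x`, using the same `B, T, C` values as the code that follows:

```python
import torch

B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

print(x.shape)        # torch.Size([4, 8, 2]): B sequences of T tokens with C channels
print(x[0].shape)     # torch.Size([8, 2]): one sequence, all T tokens
print(x[0, 3].shape)  # torch.Size([2]): a single token as a C-dimensional vector
```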

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)  # seeded randomness for reproducibility
B, T, C = 4, 8, 2  # batch size, time steps, channels
x = torch.randn(B, T, C)  # random input
print(x.shape)

# xbow[b, t] = mean of x[b, 0], ..., x[b, t] (bag-of-words average)
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]  # (t + 1, C): all tokens up to and including t
        xbow[b, t] = torch.mean(xprev, dim=0)  # (C,)
print(x[0], "\n")
print(xbow[0])
```


```
torch.Size([4, 8, 2])
tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]]) 

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
```
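As a sanity check on the printed output, the first row of `xbow[0]` equals the first row of `x[0]`, since it averages a single token. The second row is the average of the first two rows of `x[0]`:

$$
\text{xbow}[0, 1] = \tfrac{1}{2}\big(x[0, 0] + x[0, 1]\big) = \tfrac{1}{2}\big((0.1808, -0.0700) + (-0.3596, -0.9152)\big) = (-0.0894, -0.4926)
$$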


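This can be done more efficiently with matrix multiplication. Build a lower triangular matrix `a` of ones, normalize each row so that it sums to 1, and multiply it with the data matrix `b`. Row t of `a @ b` is then a weighted sum of rows 0 through t of `b` in which every weight is 1/(t + 1). In other words, it is their average.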

![alt text](https://algebra1course.wordpress.com/wp-content/uploads/2013/02/slide10.jpg)
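The following minimal sketch checks the idea described above. It verifies that row t of `a @ b` really is the mean of the first t + 1 rows of `b`. The sizes are arbitrary choices for the example.

```python
import torch

torch.manual_seed(0)
n = 4                               # number of rows (tokens)
a = torch.tril(torch.ones(n, n))    # lower triangular ones: row t has t + 1 ones
a = a / a.sum(dim=1, keepdim=True)  # normalize each row to sum to 1
b = torch.randn(n, 3)               # n tokens with 3 channels each
c = a @ b                           # (n, n) @ (n, 3) -> (n, 3)

# Compare every row with an explicit running mean
for t in range(n):
    expected = b[: t + 1].mean(dim=0)
    print(t, torch.allclose(c[t], expected, atol=1e-5))  # True for every t
```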

```python
# example
a = torch.tril(torch.ones((5, 5)))
a = a / a.sum(dim=1, keepdim=True)  # each row now sums to 1
b = torch.randn(5, 5)
c = a @ b  # matrix multiplication: row t of c is the mean of rows 0..t of b

print(a)
print("----------------")
print(b)
print("----------------")
print(c)
```

```
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])
----------------
tensor([[-0.8345,  0.5978, -0.0514, -0.0646, -0.4970],
        [ 0.4658, -0.2573, -1.0673,  2.0089, -0.9665],
        [ 0.3583,  0.1073,  1.2463,  1.2460,  0.3534],
        [ 0.9425, -1.6669, -0.7960,  0.1298, -1.9446],
        [ 0.0610, -0.2379,  1.9020, -1.1763, -0.1772]])
----------------
tensor([[-0.8345,  0.5978, -0.0514, -0.0646, -0.4970],
        [-0.1844,  0.1703, -0.5593,  0.9722, -0.7318],
        [-0.0035,  0.1493,  0.0426,  1.0635, -0.3700],
        [ 0.2330, -0.3048, -0.1671,  0.8300, -0.7637],
        [ 0.1986, -0.2914,  0.2467,  0.4288, -0.6464]])
```
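For comparison, here is a sketch of a different technique that the notebook does not use. It takes a running sum along the time dimension and divides it by the number of tokens seen so far. This gives the same averages with O(T) work per channel instead of the O(T²) of the matrix product. Unlike the matrix version, however, it cannot express the non-uniform weights that self-attention needs.

```python
import torch

torch.manual_seed(0)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Running sum over time divided by how many tokens have been summed so far
counts = torch.arange(1, T + 1, dtype=x.dtype).view(1, T, 1)  # (1, T, 1): 1, 2, ..., T
xbow_cumsum = torch.cumsum(x, dim=1) / counts                  # (B, T, C)

# Same result as the normalized lower triangular matrix
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(dim=1, keepdim=True)
xbow_matmul = weights @ x                                      # (B, T, C)

print(torch.allclose(xbow_cumsum, xbow_matmul, atol=1e-5))     # True
```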


```python
# Averaging with a normalized lower triangular matrix

# `weights` plays the role of `a` in the example above (the attention weights)
# and `x`, with shape (B, T, C), plays the role of `b`
weights = torch.tril(torch.ones((T, T)))  # (T, T): one row and column per token in a sequence
weights = weights / weights.sum(dim=1, keepdim=True)  # normalize each row to sum to 1
xbow = weights @ x  # (T, T) is broadcast to (B, T, T); (B, T, T) @ (B, T, C) -> (B, T, C)
```
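Here `weights` has shape (T, T) while `x` has shape (B, T, C). PyTorch's `@` (`torch.matmul`) broadcasts the (T, T) matrix over the batch dimension, so the same weights are applied to every sequence. The self-contained sketch below checks this against multiplying each sequence separately. The seed and sizes mirror the first cell.

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(dim=1, keepdim=True)  # (T, T)

batched = weights @ x  # (T, T) broadcast against (B, T, C) -> (B, T, C)
per_sequence = torch.stack([weights @ x[b] for b in range(B)])  # B separate (T, T) @ (T, C) products

print(batched.shape)                                       # torch.Size([4, 8, 2])
print(torch.allclose(batched, per_sequence, atol=1e-6))    # True
```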

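A better way to implement this is to use softmax as the normalization. Start from a weight matrix of zeros and set every entry above the diagonal to `-inf`. These are the entries where the lower triangular matrix is 0, i.e. the future tokens. Applying softmax along each row then turns the `-inf` entries into 0, because exp(-inf) = 0. It also spreads the weight evenly over the remaining entries, because they are all equal. The result is exactly the same weight matrix as before.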
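The following sketch checks this claim on a small T. The value T = 4 is arbitrary. It builds the weights both ways and compares them, and it also shows that the exponential of `-inf` is exactly 0.

```python
import torch
from torch.nn import functional as F

T = 4
tril = torch.tril(torch.ones(T, T))

# Approach 1: normalize the rows of the lower triangular matrix directly
avg_weights = tril / tril.sum(dim=1, keepdim=True)

# Approach 2: start from zeros, mask the future with -inf, softmax along each row
scores = torch.zeros(T, T)
scores = scores.masked_fill(tril == 0, float("-inf"))
soft_weights = F.softmax(scores, dim=1)

print(torch.exp(torch.tensor(float("-inf"))))     # tensor(0.)
print(torch.allclose(avg_weights, soft_weights))  # True
```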

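This formulation is what later allows tokens to decide which past tokens they want to communicate with. The zeros are only a starting point. In self-attention they are replaced by data-dependent scores, which the lecture calls affinities, so some past tokens can receive more weight than others. Meanwhile, the `-inf` mask still blocks all future tokens.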
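The next sketch illustrates what changes once the weights are no longer all zeros. Here the affinities are simply random numbers. This is a stand-in assumption, since in the lecture they are computed from the tokens themselves. The rows are no longer uniform, but the mask still prevents any token from seeing the future: editing the last token leaves every earlier output unchanged.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, T, C = 1, 5, 2
x = torch.randn(B, T, C)

tril = torch.tril(torch.ones(T, T))
scores = torch.randn(T, T)                             # placeholder affinities (random, NOT learned)
scores = scores.masked_fill(tril == 0, float("-inf"))  # forbid attending to future positions
weights = F.softmax(scores, dim=1)                     # each row sums to 1 but is not uniform

out = weights @ x  # (T, T) @ (B, T, C) -> (B, T, C)

# Change only the last token and recompute
x_edited = x.clone()
x_edited[:, -1] = 100.0
out_edited = weights @ x_edited

print(weights[2])  # non-zero only in the first three positions
print(torch.allclose(out[:, :-1], out_edited[:, :-1]))  # True: earlier positions are unaffected
```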

```python
tril = torch.tril(torch.ones((T, T)))
weights = torch.zeros((T, T))  # initialize all affinities to 0
weights = weights.masked_fill(tril == 0, float("-inf"))  # future positions get -inf
weights = F.softmax(weights, dim=1)  # softmax along each row, so every row sums to 1
xbow = weights @ x  # (T, T) is broadcast to (B, T, T); (B, T, T) @ (B, T, C) -> (B, T, C)
xbow[0]
```

```
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
```