In [1]:
import torch
from torch.nn import functional as F

In [2]:
torch.manual_seed(1337)

<torch._C.Generator at 0x720cec1d15d0>

## The Problem

We have a batch of examples, where each example is a sequence of tokens and each token is represented by an embedding vector. For each token, we want to calculate the average of that token and all the tokens before it, which serves as a simple form of communication between tokens. A token must not receive information from later tokens, since those are in the future and are what we are trying to predict. Eventually, we will use this to predict the next token in the sequence.

### Version 1 - For Loop

We use nested for loops over the batch and time dimensions and, at each position, average the current token with all the previous ones.

In [3]:
# B - batch
# T - time
# C - channel
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [4]:
# BOW stands for "bag of words"
xbow = torch.zeros((B, T, C))
xbow.shape

torch.Size([4, 8, 2])

In [5]:
for b in range(B):
    for t in range(T):
        # xprev is of shape (t+1, C)
        xprev = x[b, :t+1]
        xbow[b, t] = torch.mean(xprev, 0)

In [6]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [7]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

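Note that we average across the time dimension (dimension 0 of `xprev`), so each channel is averaged independently. We can check the first few values by hand: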

In [8]:
print(f'x[0][0][0] = {x[0][0][0].item():.4f}, mean(x[0][0][0]) = {(x[0][0][0]).item():.4f}, xbow[0][0][0] = {xbow[0][0][0].item():.4f}')
print(f'x[0][0][0] = {x[0][0][0].item():.4f}, x[0][1][0] = {x[0][1][0].item():.4f}, mean(x[0][0][0], x[0][1][0]) = {((x[0][0][0]+x[0][1][0])/2).item():.4f}, xbow[0][1][0] = {xbow[0][1][0].item():.4f}')
print(f'x[0][0][0] = {x[0][0][0].item():.4f}, x[0][1][0] = {x[0][1][0].item():.4f}, x[0][2][0] = {x[0][2][0].item():.4f}, mean(x[0][0][0], x[0][1][0], x[0][2][0]) = {((x[0][0][0]+x[0][1][0]+x[0][2][0])/3).item():.4f}, xbow[0][2][0] = {xbow[0][2][0].item():.4f}')

x[0][0][0] = 0.1808, mean(x[0][0][0]) = 0.1808, xbow[0][0][0] = 0.1808
x[0][0][0] = 0.1808, x[0][1][0] = -0.3596, mean(x[0][0][0], x[0][1][0]) = -0.0894, xbow[0][1][0] = -0.0894
x[0][0][0] = 0.1808, x[0][1][0] = -0.3596, x[0][2][0] = 0.6258, mean(x[0][0][0], x[0][1][0], x[0][2][0]) = 0.1490, xbow[0][2][0] = 0.1490


### Version 2 - Replacing the For Loop with Matrix Multiplication

Matrix multiplication can be viewed as a collection of dot products, each of which can be implemented as a for loop. For example, consider two matrices, $\mathbf{A}$ and $\mathbf{B}$.

$$
\mathbf{A} = \begin{bmatrix}
a_{1 \, 1} & a_{1 \, 2} & \dots & a_{1 \, k}\\
a_{2 \, 1} & a_{2 \, 2} & \dots & a_{2 \, k}\\
\vdots & \vdots & \ddots & \vdots\\
a_{m \, 1} & a_{m \, 2} & \dots & a_{m \, k}\\
\end{bmatrix}
\quad\quad\quad
\mathbf{B} = \begin{bmatrix}
b_{1 \, 1} & b_{1 \, 2} & \dots & b_{1 \, n}\\
b_{2 \, 1} & b_{2 \, 2} & \dots & b_{2 \, n}\\
\vdots & \vdots & \ddots & \vdots\\
b_{k \, 1} & b_{k \, 2} & \dots & b_{k \, n}\\
\end{bmatrix}
$$

The matrix $\mathbf{C} = \mathbf{A} \mathbf{B}$ is defined as

$$
\mathbf{C} = \mathbf{A} \mathbf{B} = \begin{bmatrix}
a_{1 \, 1} b_{1 \, 1} + a_{1 \, 2} b_{2 \, 1} + \dots + a_{1 \, k} b_{k \, 1} & a_{1 \, 1} b_{1 \, 2} + a_{1 \, 2} b_{2 \, 2} + \dots + a_{1 \, k} b_{k \, 2} & \dots & a_{1 \, 1} b_{1 \, n} + a_{1 \, 2} b_{2 \, n} + \dots + a_{1 \, k} b_{k \, n}\\
a_{2 \, 1} b_{1 \, 1} + a_{2 \, 2} b_{2 \, 1} + \dots + a_{2 \, k} b_{k \, 1} & a_{2 \, 1} b_{1 \, 2} + a_{2 \, 2} b_{2 \, 2} + \dots + a_{2 \, k} b_{k \, 2} & \dots & a_{2 \, 1} b_{1 \, n} + a_{2 \, 2} b_{2 \, n} + \dots + a_{2 \, k} b_{k \, n}\\
\vdots & \vdots & \ddots & \vdots\\
a_{m \, 1} b_{1 \, 1} + a_{m \, 2} b_{2 \, 1} + \dots + a_{m \, k} b_{k \, 1} & a_{m \, 1} b_{1 \, 2} + a_{m \, 2} b_{2 \, 2} + \dots + a_{m \, k} b_{k \, 2} & \dots & a_{m \, 1} b_{1 \, n} + a_{m \, 2} b_{2 \, n} + \dots + a_{m \, k} b_{k \, n}\\
\end{bmatrix}
$$

Note that for two vectors $\mathbf{x}$ and $\mathbf{y}$

$$
\mathbf{x} = \begin{bmatrix}
x_1 & x_2 & \dots & x_n
\end{bmatrix}
\quad\quad\quad
\mathbf{y} = \begin{bmatrix}
y_1\\
y_2\\
\vdots\\
y_n
\end{bmatrix}
$$

the dot product $\mathbf{x} \cdot \mathbf{y}$ is defined as

$$
\mathbf{x} \cdot \mathbf{y} = \begin{bmatrix}
x_1 & x_2 & \dots & x_n
\end{bmatrix} \begin{bmatrix}
y_1\\
y_2\\
\vdots\\
y_n
\end{bmatrix} = x_1 y_1 + x_2 y_2 + \dots + x_n y_n
$$

If we denote row $i$ of matrix $\mathbf{M}$ as $\mathbf{m}_i$ and column $j$ of matrix $\mathbf{M}$ as $\mathbf{m}^j$, then we can express the matrix multiplication $\mathbf{C} = \mathbf{A} \mathbf{B}$ as

$$
\mathbf{C} = \mathbf{A} \mathbf{B} = \begin{bmatrix}
\mathbf{a}_1 \mathbf{b}^1 & \mathbf{a}_1 \mathbf{b}^2 & \dots & \mathbf{a}_1 \mathbf{b}^n\\
\mathbf{a}_2 \mathbf{b}^1 & \mathbf{a}_2 \mathbf{b}^2 & \dots & \mathbf{a}_2 \mathbf{b}^n\\
\vdots & \vdots & \ddots & \vdots\\
\mathbf{a}_m \mathbf{b}^1 & \mathbf{a}_m \mathbf{b}^2 & \dots & \mathbf{a}_m \mathbf{b}^n
\end{bmatrix}
$$

Naturally, we can implement the dot product as a for loop:

```
function dot(X, Y):
    assert len(X) == len(Y)
    result = 0
    for i = 1..len(X):
        result += X_i * Y_i
    return result
```
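
As a runnable counterpart to the pseudocode above, here is a minimal PyTorch sketch that builds a matrix product out of individual row-by-column dot products and compares it with the built-in `@` operator. The helper names `dot` and `matmul_by_dots` and the small matrices are made up for illustration.

```python
import torch

def dot(x, y):
    """Dot product of two 1-D tensors of equal length, as an explicit loop."""
    assert x.shape == y.shape
    total = 0.0
    for i in range(len(x)):
        total += (x[i] * y[i]).item()
    return total

def matmul_by_dots(A, B):
    """Compute A @ B element by element: C[i, j] = (row i of A) . (column j of B)."""
    m, k = A.shape
    k2, n = B.shape
    assert k == k2, "inner dimensions must match"
    C = torch.zeros(m, n)
    for i in range(m):
        for j in range(n):
            C[i, j] = dot(A[i, :], B[:, j])
    return C

A = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])  # shape (3, 2)
B = torch.tensor([[1., 0., 2.], [0., 1., 3.]])    # shape (2, 3)
C = matmul_by_dots(A, B)

print(C)
print(torch.allclose(C, A @ B))  # the loop version should agree with the built-in product
```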

Notice how similar calculating the dot product is to calculating the average of the elements of a vector:

```
function average(X):
    avg = 0
    for i = 1..len(X):
        avg += X_i
    avg /= len(X)
    return avg
```

If we first scale the vector by dividing each element by the vector's length, then we don't need to divide at the end.

```
function average(X):
    Xnorm = X / len(X) # Element-wise division
    avg = 0
    for i = 1..len(X):
        avg += Xnorm_i
    return avg
```

Further, the dot product of a vector with an all-ones vector of the same length is equal to the sum of that vector's elements. This gives us another way to calculate the average of a vector.

```
function average(X):
    Xnorm = X / len(X) # Element-wise division
    ones = vector of length len(X) where all elements are 1
    avg = dot(Xnorm, ones)
    return avg
```
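
The same idea in runnable form: a small sketch, with an arbitrary example vector `v`, checking that dividing first and then taking a dot product with ones gives the same value as `v.mean()`. The last variant moves the $1/n$ factor into the other vector instead, which is the form the weights take once a matrix of ones is normalized. All variable names here are chosen for illustration.

```python
import torch

v = torch.tensor([8., 2., 7., 3.])  # arbitrary example vector
n = v.numel()

# Sum with a Python loop (via built-in sum), then divide at the end.
avg_loop = sum(v.tolist()) / n

# Divide each element first, then take the dot product with a vector of ones.
ones = torch.ones(n)
avg_dot = torch.dot(v / n, ones)

# Equivalent: leave v untouched and put the 1/n factor in the weight vector.
weights = torch.full((n,), 1.0 / n)
avg_weights = torch.dot(weights, v)

print(avg_loop, avg_dot.item(), avg_weights.item(), v.mean().item())
```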

This is useful because matrix multiplication runs in highly optimized, vectorized routines, which are typically much faster than an explicit Python for loop. One thing remains: computing, for every position, the average over that position and all earlier ones. For this, we can use a triangular matrix.

If we have a lower triangular matrix of ones $\mathbf{L} \in \mathbb{R}^{m \times m}$

$$
\mathbf{L} = \begin{bmatrix}
1 & 0 & 0 & \dots & 0 \\
1 & 1 & 0 & \dots & 0 \\
1 & 1 & 1 & \dots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & 1 & 1 & \dots & 1 \\
\end{bmatrix}
$$

Then the matrix multiplication $\mathbf{L} \mathbf{A}$ is

$$
\begin{align*}
\mathbf{L} \mathbf{A} &= \begin{bmatrix}
1 & 0 & 0 & \dots & 0 \\
1 & 1 & 0 & \dots & 0 \\
1 & 1 & 1 & \dots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & 1 & 1 & \dots & 1 \\
\end{bmatrix} \begin{bmatrix}
a_{1 \, 1} & a_{1 \, 2} & \dots & a_{1 \, k}\\
a_{2 \, 1} & a_{2 \, 2} & \dots & a_{2 \, k}\\
\vdots & \vdots & \ddots & \vdots\\
a_{m \, 1} & a_{m \, 2} & \dots & a_{m \, k}\\
\end{bmatrix} \\[35pt]
&= \begin{bmatrix}
a_{1 \, 1} + 0 + \dots + 0 & a_{1 \, 2} + 0 + \dots + 0 & \dots & a_{1 \,k} + 0 + \dots + 0\\
a_{1 \, 1} + a_{2 \, 1} + \dots + 0 & a_{1 \, 2} + a_{2 \, 2} + \dots + 0 & \dots & a_{1 \,k} + a_{2 \,k} + \dots + 0\\
\vdots & \vdots & \ddots & \vdots\\
a_{1 \, 1} + a_{2 \, 1} + \dots + a_{m \, 1} & a_{1 \, 2} + a_{2 \, 2} + \dots + a_{m \, 2} & \dots & a_{1 \,k} + a_{2 \,k} + \dots + a_{m \, k}\\
\end{bmatrix} \\[35pt]
\mathbf{L} \mathbf{A} &= \begin{bmatrix}
a_{1 \, 1} & a_{1 \, 2} & \dots & a_{1 \,k}\\
a_{1 \, 1} + a_{2 \, 1} & a_{1 \, 2} + a_{2 \, 2} & \dots & a_{1 \,k} + a_{2 \,k}\\
\vdots & \vdots & \ddots & \vdots\\
a_{1 \, 1} + a_{2 \, 1} + \dots + a_{m \, 1} & a_{1 \, 2} + a_{2 \, 2} + \dots + a_{m \, 2} & \dots & a_{1 \,k} + a_{2 \,k} + \dots + a_{m \, k}\\
\end{bmatrix}
\end{align*}
$$

If we divide matrix $\mathbf{A}$ by $m$ before multiplying it by the lower triangular matrix, then only the last row of the product holds true averages: each of its entries is the mean of one column of $\mathbf{A}$ over all $m$ rows. The earlier rows sum fewer than $m$ terms but are still divided by $m$. To get the correct running average in every row, we instead normalize the triangular matrix so that each of its rows sums to $1$; row $i$ then contains $i$ entries equal to $1/i$.

Let's implement this in PyTorch.

In [9]:
a = torch.randint(0, 10, (3, 3)).float()
ones = torch.ones(3, 3)
c = ones @ a
print(f'{ones=}\n\n{a=}\n\n{c=}')

ones=tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

a=tensor([[8., 6., 5.],
        [2., 4., 4.],
        [7., 4., 5.]])

c=tensor([[17., 14., 14.],
        [17., 14., 14.],
        [17., 14., 14.]])


In [10]:
L = torch.tril(torch.ones(3, 3))
c = L @ a
print(f'{L=}\n\n{a=}\n\n{c=}')

L=tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

a=tensor([[8., 6., 5.],
        [2., 4., 4.],
        [7., 4., 5.]])

c=tensor([[ 8.,  6.,  5.],
        [10., 10.,  9.],
        [17., 14., 14.]])


In [11]:
L_norm = L / L.sum(1, keepdim=True)
c = L_norm @ a
print(f'{L=}\n\n{L_norm=}\n\n{a=}\n\n{c=}')

L=tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

L_norm=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

a=tensor([[8., 6., 5.],
        [2., 4., 4.],
        [7., 4., 5.]])

c=tensor([[8.0000, 6.0000, 5.0000],
        [5.0000, 5.0000, 4.5000],
        [5.6667, 4.6667, 4.6667]])
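
For contrast, the sketch below (re-creating `a` by hand so that it runs on its own) tries the other option: scaling `a` by the number of rows instead of normalizing `L`. Under that choice, only the last row of the product should match the running mean. The names `c_scaled_a`, `c_scaled_l`, and `expected` are introduced here for illustration.

```python
import torch

m = 3
a = torch.tensor([[8., 6., 5.],
                  [2., 4., 4.],
                  [7., 4., 5.]])
L = torch.tril(torch.ones(m, m))

# Option 1: divide `a` by m first. Every row of the result is divided by m,
# even though row i only sums i + 1 terms.
c_scaled_a = L @ (a / m)

# Option 2: divide each row of L by its own row sum (1, 2, ..., m).
L_norm = L / L.sum(dim=1, keepdim=True)
c_scaled_l = L_norm @ a

# Reference: running mean over the rows of `a`, computed directly.
expected = torch.stack([a[: i + 1].mean(dim=0) for i in range(m)])

print(torch.allclose(c_scaled_l, expected))          # every row matches
print(torch.allclose(c_scaled_a[-1], expected[-1]))  # last row matches
print(torch.allclose(c_scaled_a[0], expected[0]))    # first row does not
```

Only the row-normalized triangular matrix gives a correct average at every position, so that is the form used from here on.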


Applying the row-normalized triangular matrix to our example:

In [12]:
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [13]:
xbow2 = wei @ x
xbow2[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [14]:
torch.allclose(xbow[1], xbow2[1], atol=1e-7)

True
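
One detail worth spelling out: `wei` has shape `(T, T)` while `x` has shape `(B, T, C)`. PyTorch's matrix multiplication broadcasts `wei` across the batch dimension, so the same averaging matrix is applied to every example. A minimal sketch that checks this against an explicit loop over the batch, using its own seed and the illustrative names `out_broadcast` and `out_per_batch`:

```python
import torch

torch.manual_seed(0)  # arbitrary seed for this sketch, not the notebook's
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)  # shape (T, T), rows sum to 1

# (T, T) @ (B, T, C): wei is broadcast over the batch, giving (B, T, C).
out_broadcast = wei @ x

# The same computation with the batch handled one example at a time.
out_per_batch = torch.stack([wei @ x[b] for b in range(B)])

print(out_broadcast.shape)
print(torch.allclose(out_broadcast, out_per_batch, atol=1e-6))
```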

### Version 3 - Softmax

We can achieve the same result using softmax. We start with the lower triangular matrix, as before.

In [15]:
tril = torch.tril(torch.ones(T, T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

Then we have `wei` start as all zeros.

In [16]:
wei = torch.zeros((T,  T))
wei

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

Then, wherever `tril` is $0$, we set the corresponding element of `wei` to $-\infty$.

In [17]:
wei = wei.masked_fill(tril == 0, float('-inf'))
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

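Next, we apply softmax along the last dimension, so that each row is normalized independently. Softmax exponentiates each element and divides it by the sum of the exponentials in its row. Since $e^{0} = 1$ and $e^{-\infty} = 0$, each row becomes a uniform distribution over its unmasked positions.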

In [18]:
wei = F.softmax(wei, dim=-1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
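
To see where these values come from, softmax can be written out by hand: exponentiate, then divide each row by its sum. The sketch below uses a smaller `T` than the notebook to keep the output short, and the names `scores` and `manual` are chosen for illustration.

```python
import torch
from torch.nn import functional as F

T = 4  # smaller than the notebook's T, just to keep the printout short
tril = torch.tril(torch.ones(T, T))
scores = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))

# exp(0) = 1 and exp(-inf) = 0, so exponentiating recovers the 0/1 pattern of tril.
exp_scores = scores.exp()

# Dividing each row by its sum is the normalization step of softmax.
manual = exp_scores / exp_scores.sum(dim=-1, keepdim=True)

print(manual)
print(torch.allclose(manual, F.softmax(scores, dim=-1)))
print(torch.allclose(manual, tril / tril.sum(dim=1, keepdim=True)))
```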

Finally, we perform the matrix multiplication.

In [19]:
xbow3 = wei @ x
xbow3[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [20]:
torch.allclose(xbow, xbow3, atol=1e-7)

True

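This version will be more useful later, because softmax still produces valid weights when the zeros in `wei` are replaced with other values, turning the uniform average into a weighted one.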