In [1]:
# Notebook with a very simple attention mechanism
# Author: Szymon Manduk
# Date: Feb 09, 2023

In [2]:
import torch

In [3]:
# Fix the RNG seed so "Restart Kernel -> Run All" reproduces the same tensors and outputs.
SEED = 0
torch.manual_seed(SEED)

N_TOKENS = 10  # sequence length: number of rows of X
D_MODEL = 32   # embedding size: number of columns of X

# Input matrix X: N_TOKENS tokens (rows), each embedded in D_MODEL dimensions (columns)
X = torch.randn(N_TOKENS, D_MODEL)

# Parameter matrices (random here), each D_MODEL x D_MODEL:
# pQ projects X to queries, pK to keys, pV to values
pQ = torch.randn(D_MODEL, D_MODEL)
pK = torch.randn(D_MODEL, D_MODEL)
pV = torch.randn(D_MODEL, D_MODEL)


In [4]:
# steps 1-3 of attention: project the inputs into queries, keys and values.
# X is 10x32 and every parameter matrix is 32x32, so Q, K and V each come out 10x32
# (computed in the order pQ, pK, pV).
Q, K, V = (X @ proj for proj in (pQ, pK, pV))

In [5]:
# step 4 of attention: scaled dot-product scores
# Q is 10x32 and K.T is 32x10, so the result is 10x10 (one score per query/key pair).
# Scale by 1/sqrt(d_k), with d_k read from K instead of hard-coded, so the scale stays
# correct if the embedding size changes and the scores' magnitude does not grow with
# the key dimension (which would saturate the softmax in the next step).
# Note: `weights` holds raw scores here; the next cell turns them into probabilities.
d_k = K.shape[-1]
weights = Q @ K.T / d_k ** 0.5

In [6]:
# step 5 of attention: row-wise softmax turns each row of 10x10 scores into
# probabilities over the keys (the result is still 10x10).
# NOTE: this reassigns `weights` in place, so re-running this cell on its own
# applies softmax a second time; re-run the previous cell first.
weights = weights.softmax(dim=-1)

In [7]:
# step 6 of attention: each output row is a weighted mix of the value rows.
# weights (10x10) @ V (10x32) -> output (10x32)
output = weights @ V

In [8]:
# All of those operations can be chained into a single expression:
# softmax((X pQ)(X pK)^T / sqrt(32)) (X pV)
output2 = torch.softmax(
    (X @ pQ) @ (X @ pK).T / torch.sqrt(torch.tensor(32)),
    dim=1,
) @ (X @ pV)

In [9]:
# Sanity check: the step-by-step and the single-expression results must agree up to
# floating-point tolerance. Assert so a mismatch fails "Run All" instead of just printing False.
outputs_match = torch.allclose(output, output2)
assert outputs_match, "step-by-step and one-expression attention outputs differ"
outputs_match

True


In [10]:
# A small random 4x4 tensor used to compare torch.stack and torch.cat below
t = torch.randn(size=(4, 4))
t

tensor([[-0.7028, -0.6439,  1.5270,  0.5268],
        [-0.9692, -0.2577, -0.2966,  0.1289],
        [-0.8208, -1.0258,  0.3171,  0.9205],
        [-0.6280,  0.3318, -0.2130,  0.0966]])

In [11]:
# torch.stack inserts a NEW dimension at `dim` and lays the inputs out along it:
# five (4, 4) copies of t stacked at dim=0 give shape (5, 4, 4).
# (Contrast with torch.cat, which joins along an EXISTING dimension.)
torch.stack([t] * 5, dim=0).shape

torch.Size([5, 4, 4])

In [12]:
# Tile t five times side by side along dim 1 (columns): (4, 4) -> (4, 20).
# Concatenate once from a list of pieces rather than growing `out` in a loop,
# which re-copies the accumulated tensor on every iteration and starts from a
# default-dtype/device torch.empty that may not match t.
out = torch.cat([t] * 5, dim=1)
out.shape

torch.Size([4, 20])

In [13]:
# Display the tiled tensor: each row of t repeated 5 times across the columns
out

tensor([[-0.7028, -0.6439,  1.5270,  0.5268, -0.7028, -0.6439,  1.5270,  0.5268,
         -0.7028, -0.6439,  1.5270,  0.5268, -0.7028, -0.6439,  1.5270,  0.5268,
         -0.7028, -0.6439,  1.5270,  0.5268],
        [-0.9692, -0.2577, -0.2966,  0.1289, -0.9692, -0.2577, -0.2966,  0.1289,
         -0.9692, -0.2577, -0.2966,  0.1289, -0.9692, -0.2577, -0.2966,  0.1289,
         -0.9692, -0.2577, -0.2966,  0.1289],
        [-0.8208, -1.0258,  0.3171,  0.9205, -0.8208, -1.0258,  0.3171,  0.9205,
         -0.8208, -1.0258,  0.3171,  0.9205, -0.8208, -1.0258,  0.3171,  0.9205,
         -0.8208, -1.0258,  0.3171,  0.9205],
        [-0.6280,  0.3318, -0.2130,  0.0966, -0.6280,  0.3318, -0.2130,  0.0966,
         -0.6280,  0.3318, -0.2130,  0.0966, -0.6280,  0.3318, -0.2130,  0.0966,
         -0.6280,  0.3318, -0.2130,  0.0966]])

In [14]:
# Random 8x8 "score" matrix used to demonstrate causal masking before softmax
m = torch.randn(size=(8, 8))
m

tensor([[ 0.8136, -0.0446,  0.4911,  0.8046,  1.3285,  0.8012, -0.8975, -0.2174],
        [ 0.3294, -1.0827, -0.6132, -1.1408, -0.8309,  0.4316,  1.2974, -2.2530],
        [-0.0656, -0.1330,  0.3573,  0.7055, -0.5388, -0.6816, -0.1902,  0.2447],
        [ 0.4911, -0.2504,  0.3871,  0.4993, -0.2478, -1.0736,  1.4788,  0.0781],
        [ 0.7119, -0.7082, -1.5157,  0.0637, -1.2900,  0.1505, -2.1048, -1.0875],
        [ 0.4207,  0.6608, -0.9840, -1.2339,  0.1558,  0.1938, -1.0271, -1.6201],
        [-0.2279, -1.0045, -0.4449,  0.7974,  1.1239, -1.1883, -0.6843,  1.9002],
        [-0.4206,  0.3109,  1.0727,  0.6116, -3.4153,  0.3660, -1.5124, -1.8122]])

In [15]:
# Lower-triangular (causal) mask: row i keeps columns 0..i (value 1) and masks the rest (0).
# Built with ones_like(m) so its shape, dtype and device always match m instead of
# relying on a hard-coded 8x8 that must be kept in sync by hand.
mask = torch.tril(torch.ones_like(m), diagonal=0)
mask

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [16]:
# Set every masked-out position (mask == 0) to -inf so softmax assigns it zero weight.
# NOTE: this overwrites m, so the original unmasked values are no longer available.
# Re-running is harmless because those positions are already -inf.
m = m.masked_fill(mask.logical_not(), float('-inf'))
m

tensor([[ 0.8136,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.3294, -1.0827,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0656, -0.1330,  0.3573,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.4911, -0.2504,  0.3871,  0.4993,    -inf,    -inf,    -inf,    -inf],
        [ 0.7119, -0.7082, -1.5157,  0.0637, -1.2900,    -inf,    -inf,    -inf],
        [ 0.4207,  0.6608, -0.9840, -1.2339,  0.1558,  0.1938,    -inf,    -inf],
        [-0.2279, -1.0045, -0.4449,  0.7974,  1.1239, -1.1883, -0.6843,    -inf],
        [-0.4206,  0.3109,  1.0727,  0.6116, -3.4153,  0.3660, -1.5124, -1.8122]])

In [17]:
# Row-wise softmax over the masked scores: -inf entries become exactly 0 and each row
# sums to 1. The diagonal is never masked, so no row is entirely -inf (which would give NaN).
m.softmax(dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8041, 0.1959, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2889, 0.2701, 0.4410, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2953, 0.1407, 0.2662, 0.2978, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4981, 0.1204, 0.0537, 0.2605, 0.0673, 0.0000, 0.0000, 0.0000],
        [0.2341, 0.2976, 0.0575, 0.0447, 0.1796, 0.1865, 0.0000, 0.0000],
        [0.1007, 0.0463, 0.0810, 0.2807, 0.3890, 0.0385, 0.0638, 0.0000],
        [0.0759, 0.1578, 0.3381, 0.2132, 0.0038, 0.1668, 0.0255, 0.0189]])