This notebook demonstrates the attention mechanism. Code and images are from:

http://www.peterbloem.nl/blog/transformers

In [26]:
import torch
from torch import nn
import torch.nn.functional as F

In [22]:
torch.set_printoptions(precision=4, sci_mode=False)

# Linear algebra and useful functions in PyTorch

`torch.bmm(X, Y)` performs a batch matrix-matrix product. Both inputs must be 3D tensors whose first dimension is the batch dimension.

If X is a $(b \times n \times m)$ tensor and Y is a $(b \times m \times p)$ tensor, the output is a $(b \times n \times p)$ tensor.

https://pytorch.org/docs/stable/torch.html#torch.bmm

In [2]:
x = torch.randn(10, 3, 4)
y = torch.randn(10, 4, 5)
res = torch.bmm(x, y)
res.size()

torch.Size([10, 3, 5])
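
As a quick sanity check, the sketch below compares `torch.bmm` with an explicit loop over the batch dimension. The fixed seed and the names `batched` and `looped` are illustrative choices, not part of the original notebook.

```python
import torch

torch.manual_seed(0)  # illustrative seed, only for reproducibility
x = torch.randn(10, 3, 4)
y = torch.randn(10, 4, 5)

# Batched product in a single call
batched = torch.bmm(x, y)

# Same computation, one ordinary matrix product per batch element
looped = torch.stack([x[i] @ y[i] for i in range(x.size(0))])

print(batched.shape)                               # torch.Size([10, 3, 5])
print(torch.allclose(batched, looped, atol=1e-5))  # True
```

The loop makes explicit that the batch elements never interact: each of the 10 products is computed independently.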

`torch.transpose(input, dim0, dim1)` (or `X.transpose(dim0, dim1)`) returns a transposed version of `input` in which dimensions `dim0` and `dim1` are swapped.

https://pytorch.org/docs/stable/torch.html#torch.transpose

In [3]:
x = torch.randn(2, 3)
x

tensor([[-1.6354, -0.3138,  1.5790],
        [-0.9388,  1.3939,  0.5190]])

In [4]:
torch.transpose(x, 0, 1)

tensor([[-1.6354, -0.9388],
        [-0.3138,  1.3939],
        [ 1.5790,  0.5190]])
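
One detail worth knowing is that the transposed tensor is a view that shares storage with the original. The following sketch illustrates this, and also shows the batched case used later in this notebook, where only the last two dimensions are swapped. The variable names are illustrative.

```python
import torch

x = torch.arange(6.0).reshape(2, 3)
xt = x.transpose(0, 1)

print(xt.shape)            # torch.Size([3, 2])
print(xt.is_contiguous())  # False: only the strides changed, no data was copied

# Writing through the view changes the original tensor as well
xt[0, 1] = 100.0
print(x[1, 0])             # tensor(100.)

# For a batch of matrices, swap the last two dimensions and keep the batch
b = torch.randn(2, 3, 4)
print(b.transpose(1, 2).shape)  # torch.Size([2, 4, 3])
```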

`torch.nn.functional.softmax(input, dim)` applies softmax along dimension `dim`, so that the values along that dimension are positive and sum to 1.

https://pytorch.org/docs/stable/nn.functional.html#softmax

In [5]:
x = torch.randn(2, 3, 4)
x

tensor([[[-0.5387, -1.3863,  0.1120,  0.2355],
         [-2.0100,  0.8054,  0.2112,  0.4322],
         [-1.9613,  0.5597, -0.1515, -0.5779]],

        [[ 1.2681, -0.1804, -0.3403, -1.1587],
         [-2.9798,  1.0406, -0.5751, -0.0330],
         [-0.0475,  1.7537,  3.5292,  1.2491]]])

In [6]:
F.softmax(x, dim=0)

tensor([[[0.1410, 0.2304, 0.6112, 0.8013],
         [0.7251, 0.4415, 0.6870, 0.6143],
         [0.1286, 0.2326, 0.0246, 0.1386]],

        [[0.8590, 0.7696, 0.3888, 0.1987],
         [0.2749, 0.5585, 0.3130, 0.3857],
         [0.8714, 0.7674, 0.9754, 0.8614]]])

In [7]:
F.softmax(x, dim=1)

tensor([[[0.6800, 0.0590, 0.3481, 0.3758],
         [0.1561, 0.5280, 0.3844, 0.4575],
         [0.1639, 0.4130, 0.2675, 0.1666]],

        [[0.7797, 0.0884, 0.0201, 0.0658],
         [0.0111, 0.2998, 0.0159, 0.2029],
         [0.2092, 0.6117, 0.9640, 0.7313]]])

In [8]:
F.softmax(x, dim=2)

tensor([[[0.1814, 0.0777, 0.3476, 0.3933],
         [0.0260, 0.4347, 0.2400, 0.2993],
         [0.0425, 0.5285, 0.2595, 0.1694]],

        [[0.6564, 0.1542, 0.1314, 0.0580],
         [0.0115, 0.6417, 0.1275, 0.2193],
         [0.0215, 0.1303, 0.7695, 0.0787]]])
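
To make the role of `dim` concrete, the sketch below checks that the result sums to 1 along whichever dimension was chosen, and writes out the softmax by hand for the last dimension. The seed and variable names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # illustrative seed
x = torch.randn(2, 3, 4)

for dim in range(x.dim()):
    probs = F.softmax(x, dim=dim)
    sums = probs.sum(dim=dim)
    # Every slice along `dim` sums to 1, so `sums` is all ones
    print(dim, tuple(sums.shape), torch.allclose(sums, torch.ones_like(sums), atol=1e-5))

# Softmax written out by hand for the last dimension
manual = torch.exp(x) / torch.exp(x).sum(dim=2, keepdim=True)
print(torch.allclose(manual, F.softmax(x, dim=2), atol=1e-5))
```

The hand-written version is for illustration only; it is less robust to large inputs than the library function.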

# Basic self-attention

Input vectors: $x_1, x_2, \ldots, x_t$. Output vectors: $y_1, y_2, \ldots, y_t$.

The self-attention operation is simply a weighted average over all input vectors:

$$
y_i = \sum_j w_{ij} x_j
$$

The simplest choice for the weight $w_{ij}$ is the dot product of $x_i$ and $x_j$, passed through a softmax over $j$ so that the weights for each $i$ sum to 1:

$$
w'_{ij} = x_i^T x_j \\
w_{ij} = \frac{\exp w'_{ij}}{\sum_{j'} \exp w'_{ij'}}
$$

<img src="self-attention.svg" width="500"/>

The input is a sequence of $t$ vectors of dimension $k$ with minibatch dimension $b$, i.e. a $(b \times t \times k)$ tensor.

In [9]:
x = torch.randn(2, 3, 4)
x

tensor([[[ 0.2688,  0.3804, -1.7762,  0.8495],
         [-0.1935, -0.3447, -0.3844,  0.7467],
         [ 1.3795, -0.3551,  0.0151, -1.9090]],

        [[-0.3196,  1.8688, -0.8605,  0.5735],
         [-0.2754, -0.9110, -0.9624, -1.8642],
         [ 1.0176, -2.2407, -0.6599,  1.0171]]])

In [10]:
raw_weights = torch.bmm(x, x.transpose(1, 2))
raw_weights

tensor([[[ 4.0935,  1.1339, -1.4127],
         [ 1.1339,  0.8616, -1.5759],
         [-1.4127, -1.5759,  5.6737]],

        [[ 4.6638, -1.8554, -3.3616],
         [-1.8554,  5.3073,  0.5000],
         [-3.3616,  0.5000,  7.5263]]])
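
Because the raw weights are dot products of the inputs with themselves, each $(t \times t)$ matrix is symmetric and its diagonal holds the squared norms of the input vectors. A small sketch that checks both properties on freshly sampled data (the seed and names are illustrative):

```python
import torch

torch.manual_seed(0)  # illustrative seed
x = torch.randn(2, 3, 4)
raw_weights = torch.bmm(x, x.transpose(1, 2))

# w'_ij = x_i . x_j = x_j . x_i, so each matrix equals its own transpose
print(torch.allclose(raw_weights, raw_weights.transpose(1, 2), atol=1e-5))

# Diagonal entries are the squared norms ||x_i||^2
diag = torch.diagonal(raw_weights, dim1=1, dim2=2)
print(torch.allclose(diag, (x ** 2).sum(dim=2), atol=1e-5))
```

Since a vector's dot product with itself is never smaller than its dot product with a shorter vector pointing elsewhere, the diagonal often dominates, which is why basic self-attention tends to put most of the weight on the vector's own position.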

In [12]:
weights = F.softmax(raw_weights, dim=2)
weights

tensor([[[    0.9471,     0.0491,     0.0038],
         [    0.5470,     0.4166,     0.0364],
         [    0.0008,     0.0007,     0.9985]],

        [[    0.9982,     0.0015,     0.0003],
         [    0.0008,     0.9911,     0.0081],
         [    0.0000,     0.0009,     0.9991]]])

In [13]:
y = torch.bmm(weights, x)
y

tensor([[[ 0.2504,  0.3420, -1.7010,  0.8338],
         [ 0.1166,  0.0516, -1.1312,  0.7063],
         [ 1.3775, -0.3544,  0.0133, -1.9048]],

        [[-0.3191,  1.8633, -0.8606,  0.5700],
         [-0.2650, -0.9196, -0.9599, -1.8390],
         [ 1.0164, -2.2395, -0.6602,  1.0146]]])

In [14]:
def basic_self_attention(x):
    # x: (b, t, k); weights: (b, t, t), each row sums to 1
    weights = F.softmax(torch.bmm(x, x.transpose(1, 2)), dim=2)
    y = torch.bmm(weights, x)
    return y

basic_self_attention(x)

tensor([[[ 0.2504,  0.3420, -1.7010,  0.8338],
         [ 0.1166,  0.0516, -1.1312,  0.7063],
         [ 1.3775, -0.3544,  0.0133, -1.9048]],

        [[-0.3191,  1.8633, -0.8606,  0.5700],
         [-0.2650, -0.9196, -0.9599, -1.8390],
         [ 1.0164, -2.2395, -0.6602,  1.0146]]])
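
To connect the batched code back to the formula $y_i = \sum_j w_{ij} x_j$, here is a deliberately slow reference version with explicit loops, compared against the batched function. `basic_self_attention_loop` is a hypothetical helper written for this check, not part of the original code.

```python
import torch
import torch.nn.functional as F

def basic_self_attention(x):
    weights = F.softmax(torch.bmm(x, x.transpose(1, 2)), dim=2)
    return torch.bmm(weights, x)

def basic_self_attention_loop(x):
    """Reference version: explicit loops over batch and position (slow)."""
    b, t, k = x.shape
    y = torch.zeros_like(x)
    for n in range(b):
        for i in range(t):
            # w'_ij = x_i^T x_j for all j, then softmax over j
            raw = torch.stack([torch.dot(x[n, i], x[n, j]) for j in range(t)])
            w = F.softmax(raw, dim=0)
            # y_i = sum_j w_ij x_j
            y[n, i] = sum(w[j] * x[n, j] for j in range(t))
    return y

torch.manual_seed(0)  # illustrative seed
x = torch.randn(2, 3, 4)
same = torch.allclose(basic_self_attention(x), basic_self_attention_loop(x), atol=1e-5)
print(same)
```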

# Complete self-attention

To give self-attention learnable parameters, we compute three linear projections of the input vectors:
* query $q_i$
* key $k_j$
* value $v_j$

This terminology is derived from key-value stores: we use a query to find a matching key and return its value.

To compute them, we use three $k \times k$ weight matrices $W_q, W_k, W_v$.

$$
q_i = W_q x_i \\
k_j = W_k x_j \\
v_j = W_v x_j \\
w'_{ij} = q_i^T k_j \\
w_{ij} = \mathrm{softmax}_j(w'_{ij}) \\
y_i = \sum_j w_{ij} v_j
$$

<img src="key-query-value.svg" width="400"/>
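
The key-value store analogy can be made concrete with a toy "soft lookup". In the sketch below the keys, values, and query are made-up numbers chosen so that the query matches one key almost exclusively; in that case the softmax weights are nearly one-hot and the output is close to the corresponding value.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy store with three well-separated keys
keys = 10.0 * torch.eye(3)              # (3, 3), one key per row
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [5.0, 5.0]])     # (3, 2), one value per row

query = keys[1]                         # a query that matches key 1

scores = keys @ query                   # dot product of the query with every key
weights = F.softmax(scores, dim=0)      # soft "match" over the keys
result = weights @ values               # weighted average of the values

print(weights)  # almost all of the mass is on index 1
print(result)   # close to values[1]
```

With less sharply separated keys the weights spread out, and the result becomes a blend of several values rather than a single retrieved entry.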

In [15]:
w_q = torch.randn(4, 4)
w_k = torch.randn(4, 4)
w_v = torch.randn(4, 4)

In [19]:
q = torch.matmul(x, w_q)
q

tensor([[[ 0.9703,  1.9240,  3.1712,  1.9942],
         [-0.2578,  1.1100,  0.6023,  0.8200],
         [ 1.8106, -2.0853,  2.7100,  0.7022]],

        [[ 0.1233,  0.7111, -0.6614, -0.8625],
         [ 2.4325,  0.7309,  1.6502,  1.1887],
         [ 0.1435,  1.2382,  4.8879,  3.7766]]])

In [21]:
k = torch.matmul(x, w_k)
v = torch.matmul(x, w_v)
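
A note on conventions: `torch.matmul(x, w_q)` broadcasts the 2D matrix over the batch and multiplies each row vector $x_i$ from the right, so relative to the column-vector formula $q_i = W_q x_i$ the tensor `w_q` plays the role of $W_q^T$. For random (or learned) matrices this makes no practical difference. The sketch below checks the correspondence for a single vector; the seed and the names `W` and `q_00` are illustrative.

```python
import torch

torch.manual_seed(0)  # illustrative seed
x = torch.randn(2, 3, 4)   # (b, t, k)
w_q = torch.randn(4, 4)    # (k, k)

# matmul broadcasts the 2D matrix over the batch dimension
q = torch.matmul(x, w_q)
print(q.shape)             # torch.Size([2, 3, 4])

# Per vector, this equals W x_i with W = w_q transposed
W = w_q.T
q_00 = W @ x[0, 0]
print(torch.allclose(q[0, 0], q_00, atol=1e-5))
```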

In [23]:
raw_weights = torch.bmm(q, k.transpose(1, 2))
raw_weights

tensor([[[ 13.8004,   7.0465,  -7.8591],
         [  1.6687,   0.9556,   0.5353],
         [ 13.7186,   6.7265, -12.4250]],

        [[  0.0856,   2.1770,  -3.7426],
         [  1.4583,   2.7894,  10.4175],
         [  1.9852,  -7.4579,  25.9419]]])

In [24]:
weights = F.softmax(raw_weights, dim=2)
weights

tensor([[[    0.9988,     0.0012,     0.0000],
         [    0.5519,     0.2705,     0.1777],
         [    0.9991,     0.0009,     0.0000]],

        [[    0.1097,     0.8879,     0.0024],
         [    0.0001,     0.0005,     0.9994],
         [    0.0000,     0.0000,     1.0000]]])

In [25]:
# Weight the values v (not the raw inputs x), as in y_i = sum_j w_ij v_j
y = torch.bmm(weights, v)
y.size()

torch.Size([2, 3, 4])
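
The steps of this section can be collected into one function, mirroring `basic_self_attention`. This is a minimal sketch: `self_attention` is a hypothetical helper name, the weight matrices are random rather than learned, and no scaling or multiple heads are included. As a sanity check, identity weight matrices should reduce it to basic self-attention.

```python
import torch
import torch.nn.functional as F

def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention with explicit weight matrices.

    x: (b, t, k); w_q, w_k, w_v: (k, k). Illustrative helper only.
    """
    q = torch.matmul(x, w_q)   # queries, (b, t, k)
    k = torch.matmul(x, w_k)   # keys,    (b, t, k)
    v = torch.matmul(x, w_v)   # values,  (b, t, k)
    weights = F.softmax(torch.bmm(q, k.transpose(1, 2)), dim=2)  # (b, t, t)
    return torch.bmm(weights, v)

torch.manual_seed(0)  # illustrative seed
x = torch.randn(2, 3, 4)
w_q, w_k, w_v = (torch.randn(4, 4) for _ in range(3))
y = self_attention(x, w_q, w_k, w_v)
print(y.shape)  # torch.Size([2, 3, 4])

# With identity matrices, q = k = v = x, which is basic self-attention
eye = torch.eye(4)
basic = torch.bmm(F.softmax(torch.bmm(x, x.transpose(1, 2)), dim=2), x)
print(torch.allclose(self_attention(x, eye, eye, eye), basic, atol=1e-5))
```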