(Based on http://www.peterbloem.nl/blog/transformers)

# Basic attention

The attention mechanism is just a sequence to sequence function. Given a sequence of vector inputs $x_1, x_2,\ldots, x_t \in \mathbb{R}^d$, the attention mechanism produces an output sequence of vector inputs $y_1, y_2, \ldots, y_t \in \mathbb{R}^d$ by doing the following:
$$y_i = \sum_{j} w_{ij} x_j.$$
where
$$w_{ij} = f(x_i, x_j).$$
Often, $f(x_i, x_j)$ is taken to be $x_i^Tx_j$. In this case, $y_i$ is the weighted sum of all $x_1,\ldots, x_t$ weighted by the similarity of $x_i$ and all $x_1,\ldots x_t$. Also, it is conventional to normalize the weights by a softmax: $w_{ij} \leftarrow \frac{\exp(w_{ij})}{\sum_{j} \exp(w_{ij})}$.

# An example of basic attention

In [8]:
import torch
import torch.nn.functional as F


def basic_attention(x):
    w = x.matmul(x.transpose(-1, -2))
    normalized_w = F.softmax(w, dim = -1)
    y = torch.matmul(normalized_w, x)
    return y

In [10]:
t, d = 4, 2
x = torch.randn(t, d)
y = basic_attention(x)
print(y)

tensor([[-1.2173,  2.4251],
        [-0.6028, -0.6962],
        [ 2.3391,  1.3497],
        [-0.9528, -1.0761]])


Notice that attention is permuation equivariant. If you permute the input, then the output is permuted in the same way.

In [12]:
perm = torch.LongTensor([0, 1, 3, 2])
x_new = x[perm]
y_new = basic_attention(x_new)
print(y_new)

tensor([[-1.2173,  2.4251],
        [-0.6028, -0.6962],
        [-0.9528, -1.0761],
        [ 2.3391,  1.3497]])


Here the last two vectors are permuted in the input and the output of attention.

 # Query, key, value
 
 Notice that the basic attention mechanism has no tunable parameters. We can give it some tunable parameters by noticing that each input vector $x_i$ appears three times in the computation:
 1. Computing the weights $x_i^Tx_j$ for all $j$
 2. Computing the weights $x_j^Tx_i$ for all $j$
 3. Acting as a basis vector $w_{i'i} x_i$
Thus we can introduce three linear maps $W_q, W_k, W_v: \mathbb{R}^d \to \mathbb{R}^d$ to compute the query, key, and value such that 
$$q_i = W_q x_i$$
$$k_i = W_k x_i$$
$$v_i = W_v x_i$$
$$w'_{ij} = q_i^T k_j$$
$$w_{ij} \leftarrow softmax(w'_{i1}, \ldots, w'_{it})$$
$$y_i = \sum_{j} w_{ij} v_j$$

In [13]:
wq = torch.randn(d, d)
wk = torch.randn(d, d)
wv = torch.randn(d, d)
def qkv_attention(x):
    d = x.size(-1)
    q = x.matmul(wq)
    k = x.matmul(wk)
    v = x.matmul(wv)
    w = q.matmul(k.transpose(-1, -2))
    y = w.matmul(v)
    return y

Recall that attention was permutation equivariant. In fact, even with query, keys, and values, it is still equivariant. 

In [14]:
t, d = 4, 2
x = torch.randn(t, d)
y = qkv_attention(x)
print(y)

tensor([[  2.9639,  22.1648],
        [ -2.6682, -11.1776],
        [  1.1420,   1.5704],
        [  2.7934,  17.0057]])


In [15]:
perm = torch.LongTensor([0, 1, 3, 2])
x_new = x[perm]
y_new = qkv_attention(x_new)
print(y_new)

tensor([[  2.9639,  22.1649],
        [ -2.6682, -11.1776],
        [  2.7934,  17.0057],
        [  1.1420,   1.5704]])


# Scaling the inner product

The softmax can be very sensitive to extreme values due to the exponentiation. This can cause vanishing gradients for the non-outliers.
Since the expected length of random normal vectors in $d$ dimensions is roughly $\sqrt d$, we rescale the inner products before the softmax:
$$w'_{ij} = \frac{q_i^T k_j}{\sqrt{d}}.$$

# Multi-head attention

We can introduce multi-head attention which creates $h$ query, key, and value linear maps and concatenate all the $y_i^r$ before reducing the dimension back down to $d$.
$$q_i^r = W_q^r x_i$$
$$k_i^r = W_k^r x_i$$
$$v_i^r = W_v^r x_i$$
$$w_{ij}^{'r} = q_i^{rT} k_j^r$$
$$w_{ij}^r \leftarrow softmax(w_{i1}^{'r}, \ldots, w_{it}^{'r})$$
$$y_i^r = \sum_{j} w_{ij}^r v_j^r$$
$$y_i = W stack(y_i^1, \ldots, y_i^h)$$
where $W\in \mathbb{R}^{d\times hd}$.