Self-attention is a sequence-to-sequence operation: a sequence of vectors goes in, and a sequence of vectors comes out. Let’s call the input vectors $\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_t$ and the corresponding output vectors $\mathbf{y}_1, \mathbf{y}_2, \ldots, \mathbf{y}_t$. The vectors all have dimension $k$.

To produce output vector $\mathbf{y}_i$, the self-attention operation simply takes a weighted average over all the input vectors:

$$\mathbf{y}_i = \sum_j w_{ij}\,\mathbf{x}_j.$$
Here $j$ indexes over the whole sequence, and the weights sum to one over all $j$. The weight $w_{ij}$ is not a parameter, as in a normal neural net, but is derived from a function over $\mathbf{x}_i$ and $\mathbf{x}_j$. The simplest option for this function is the dot product:

$$w'_{ij} = \mathbf{x}_i^T \mathbf{x}_j.$$
Note that $\mathbf{x}_i$ is the input vector at the same position as the current output vector $\mathbf{y}_i$. For the next output vector, we get an entirely new series of dot products, and a different weighted sum.
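To make the raw weights concrete, here is a minimal sketch that computes all pairwise dot products at once as a matrix product. The three toy vectors and the name `raw` are assumptions chosen only for illustration; they do not come from the notebook.

```python
import torch

# Three toy input vectors of dimension k = 2 (illustrative values only).
x = torch.tensor([[1.0, 0.0],
                  [0.0, 2.0],
                  [1.0, 1.0]])

# raw[i, j] = x_i^T x_j, so row i holds the unnormalized weights for output y_i.
raw = x @ x.T
print(raw)
# tensor([[1., 0., 1.],
#         [0., 4., 2.],
#         [1., 2., 2.]])
```

Each row is a different series of dot products, which is why each output position ends up with its own weighted sum.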
The dot product gives us a value anywhere between negative and positive infinity, so we apply a softmax to map the values to $[0, 1]$ and to ensure that they sum to 1 over the whole sequence:

$$w_{ij} = \frac{\exp w'_{ij}}{\sum_j \exp w'_{ij}}.$$
And that’s the basic operation of self-attention.
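The three formulas can be put together in a few lines. The following is a minimal sketch, assuming a batched input of shape `(batch, t, k)`; the helper name `basic_self_attention` and the random test input are hypothetical, and the sketch has no learned parameters.

```python
import torch

def basic_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: maps x of shape (batch, t, k) to y of the same shape."""
    raw = torch.bmm(x, x.transpose(1, 2))   # raw[b, i, j] = x_i^T x_j
    weights = torch.softmax(raw, dim=2)     # each row i sums to 1 over j
    return torch.bmm(weights, x)            # y_i = sum_j w_ij x_j

torch.manual_seed(0)
x = torch.randn(2, 5, 8)                    # batch of 2, t = 5, k = 8
y = basic_self_attention(x)

# Cross-check one output vector against the formulas written out explicitly.
i = 0
w_row = torch.softmax(torch.stack([x[0, i] @ x[0, j] for j in range(5)]), dim=0)
y_i = sum(w_row[j] * x[0, j] for j in range(5))
print(y.shape, torch.allclose(y[0, i], y_i, atol=1e-4))
```

The batched matrix products compute every output position at once, while the explicit loop follows the per-position definition; the two should agree up to floating-point error.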

<img src="http://peterbloem.nl/files/transformers/self-attention.svg"/>
ref: <a href="http://peterbloem.nl/blog/transformers">http://peterbloem.nl/blog/transformers</a>

In [5]:
import torch
import torch.nn.functional as f
import numpy as np

In [35]:
embedding = torch.nn.Embedding(1000,128)

In [54]:
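x1 = embedding(torch.LongTensor([1]))
x2 = embedding(torch.LongTensor([2]))
x3 = embedding(torch.LongTensor([3]))
# Concatenate along the sequence axis and add a batch axis: (batch=1, t=3, k=128).
# Stacking to (3, 1, 128) would make bmm treat each vector as its own batch,
# so every vector would only attend to itself.
x = torch.cat([x1, x2, x3]).unsqueeze(0)
x.size()

torch.Size([1, 3, 128])

In [55]:
# Raw weights w'_ij = x_i^T x_j for every pair of positions: (1, 3, 3).
wij = torch.bmm(x, x.transpose(1,2))
wij.size()

torch.Size([1, 3, 3])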

In [56]:
wij = f.softmax(wij, dim=2)
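
As a small illustration of what a softmax over `dim=2` does, the sketch below applies it to a hypothetical `(1, 3, 3)` tensor of raw weights (the values are made up for the example). The last axis is the `j` axis, so each row of weights is normalized independently.

```python
import torch

# Hypothetical raw weights with shape (batch=1, t=3, t=3).
raw = torch.tensor([[[1.0, 0.0, 1.0],
                     [0.0, 4.0, 2.0],
                     [1.0, 2.0, 2.0]]])

# Normalize over the last axis (j), one row per output position i.
weights = torch.softmax(raw, dim=2)
row_sums = weights.sum(dim=2)
print(row_sums)  # every entry is 1 up to floating-point error
```

Normalizing over a different axis would make the columns rather than the rows sum to one, and the outputs would no longer be weighted averages of the inputs.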

In [57]:
y = torch.bmm(wij, x)

In [58]:
y

tensor([[[ 0.0235,  0.0644, -1.0557,  0.3499,  0.8542,  1.0027, -0.7468,
          -0.7461,  0.8663, -0.3890,  0.4735,  2.3888,  0.9989, -0.7964,
           0.7887, -0.2837, -0.0906,  0.3525, -0.2325, -1.0213,  0.7719,
          -0.3725,  0.2081,  0.6937,  0.9120, -0.9780,  0.1345,  0.6599,
           0.2458, -0.8819,  0.1030, -0.9069, -0.6137,  0.5116, -0.1451,
          -0.0123,  0.2573, -0.9507, -0.2157, -0.3429, -0.1541, -0.2230,
          -1.2359,  0.4929, -1.4787, -0.5101,  2.3545, -0.0465,  1.3382,
          -0.3760,  1.2486, -1.9955, -0.3367, -1.4961,  1.0573,  0.4445,
           2.0985,  1.2877, -0.1124,  1.9950, -0.2542,  0.3731, -0.4743,
           2.5063, -0.9619, -0.0963, -0.0860, -0.2414, -1.1402, -1.2531,
           0.7519,  0.3334, -0.1457, -0.7368,  0.7149, -1.8806, -0.7455,
           0.4977, -0.9126, -1.1121,  2.6944,  0.9263,  0.7032, -0.6567,
           2.0906, -0.1071, -1.4508,  0.0037,  0.6656,  0.0280,  1.0099,
          -1.5742, -0.1638,  0.3440,  1.4420,  0.18