# Attention

http://nlp.seas.harvard.edu/2018/04/03/attention.html

In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import math

In [2]:
torch.__version__

'0.4.1'

In [3]:
q1 = torch.ones(2); q1

tensor([1., 1.])

In [4]:
k1 = 2 * torch.ones(2); k1

tensor([2., 2.])

In [5]:
v1 = 3 * torch.ones(2); v1

tensor([3., 3.])

In [6]:
k2 = torch.ones(2); k2

tensor([1., 1.])

In [7]:
v2 = 4 * torch.ones(2); v2

tensor([4., 4.])

We first calculate the compatibility between q1 and k1. A simple compatibility function is the dot product.

In [8]:
C1 = torch.matmul(q1, k1); C1

tensor(4.)

Similarly between q1 and k2.

In [9]:
C2 = torch.matmul(q1, k2); C2

tensor(2.)

Okay, we could just multiply our values by these raw scores, but what matters is how large each score is relative to the others.
So we normalize them with a softmax.

In [10]:
torch.tensor([C1, C2])

tensor([4., 2.])

In [11]:
p_attn = F.softmax(torch.tensor([C1, C2]), dim=0); p_attn

tensor([0.8808, 0.1192])

Note how the softmax transformed our scores of 4 and 2 into a probability vector (i.e. the elements are non-negative and sum to 1). So the weights that we're going to apply to our value vectors v1 and v2 are 0.8808 and 0.1192 respectively.

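However, note that if we scale our scores down from (4, 2) to (2, 1), the probability vector changes to (0.7311, 0.2689). The softmax depends on the differences between the scores, not on their ratios, so rescaling the scores changes the weights. This will be important to remember later.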

In [12]:
F.softmax(torch.tensor([2.0, 1]), dim=0)

tensor([0.7311, 0.2689])

Great, so let's compute our output!

0.8808 * v1

In [13]:
p_attn[0] * v1

tensor([2.6424, 2.6424])

0.1192 * v2

In [14]:
p_attn[1] * v2

tensor([0.4768, 0.4768])

In [15]:
output = (p_attn[0] * v1) + (p_attn[1] * v2); output

tensor([3.1192, 3.1192])

Now, let's do this in vector/matrix form!

In [16]:
Q = q1.unsqueeze(0); Q

tensor([[1., 1.]])

In [17]:
K = torch.stack([k1, k2], dim=0); K

tensor([[2., 2.],
        [1., 1.]])

In [18]:
V = torch.stack([v1, v2], dim=0); V

tensor([[3., 3.],
        [4., 4.]])

In [19]:
def dotproduct_attention(query, key, value, mask=None, dropout=None):
    "Compute dot product attention. Returns (scores, output, attention weights)."
    scores = torch.matmul(query, key.transpose(-2, -1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return scores, torch.matmul(p_attn, value), p_attn

In [20]:
dotproduct_attention(Q, K, V)

(tensor([[4., 2.]]), tensor([[3.1192, 3.1192]]), tensor([[0.8808, 0.1192]]))
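The function above also takes a mask argument that we have not used yet. Here is a minimal sketch of what it does, assuming a mask of shape (1, 2) where 0 means "hide this key". The mask values are made up for illustration, and the function is a trimmed copy of the one above (no dropout, returns only the output and the weights).

```python
import torch
import torch.nn.functional as F

def dotproduct_attention(query, key, value, mask=None):
    # Trimmed copy of the function above, kept here so the block runs on its own
    scores = torch.matmul(query, key.transpose(-2, -1))
    if mask is not None:
        # Masked positions get a very negative score, so softmax sends them to ~0
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    return torch.matmul(p_attn, value), p_attn

Q = torch.ones(1, 2)
K = torch.tensor([[2., 2.], [1., 1.]])
V = torch.tensor([[3., 3.], [4., 4.]])

# Illustrative mask: hide the first key (the one with the higher score)
mask = torch.tensor([[0, 1]])

output, p_attn = dotproduct_attention(Q, K, V, mask=mask)
print(p_attn)
print(output)
```

With the first key hidden, all of the weight lands on the second key, so the output is just v2.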

The problem with the plain dot product as the compatibility function is that the scores grow as we increase the dimensionality of our query and key vectors. Larger scores push the softmax toward putting almost all of the weight on a single key.

In [21]:
Q = torch.ones(1, 8)

In [22]:
K = torch.stack([2 * torch.ones(8), torch.ones(8)])

In [23]:
V = torch.stack([3 * torch.ones(8), 4 * torch.ones(8)])

In [24]:
dotproduct_attention(Q, K, V)

(tensor([[16.,  8.]]),
 tensor([[3.0003, 3.0003, 3.0003, 3.0003, 3.0003, 3.0003, 3.0003, 3.0003]]),
 tensor([[0.9997, 0.0003]]))

By increasing our dimensionality to 8, the compatibility scores become:

C(q1, k1) = 16
and
C(q1, k2) = 8

This gives much more weight to v1 (0.9997) than when C(q1, k1) = 4 and C(q1, k2) = 2 (0.8808).

We mitigate this by dividing the scores by the square root of d_k, the dimensionality of the query and key vectors.
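Why the square root? A rough sketch, assuming the entries of the queries and keys are independent with mean 0 and variance 1 (an assumption for this illustration, not something our all-ones toy vectors satisfy): the dot product then has variance about d_k, so dividing by sqrt(d_k) brings the variance back to about 1. The helper name dot_product_variance is made up for this example.

```python
import math
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

def dot_product_variance(d_k, n_samples=10000):
    # Hypothetical helper: sample random queries and keys with unit-variance entries
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)  # one dot product per sampled (q, k) pair
    raw_var = dots.var().item()
    scaled_var = (dots / math.sqrt(d_k)).var().item()
    return raw_var, scaled_var

for d_k in (2, 8, 64):
    raw_var, scaled_var = dot_product_variance(d_k)
    print(d_k, round(raw_var, 2), round(scaled_var, 2))
```

The raw variance should come out close to d_k and the scaled variance close to 1, whatever d_k is.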

In [25]:
def scaled_dotproduct_attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    print(scores)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return scores, torch.matmul(p_attn, value), p_attn

In [26]:
scaled_dotproduct_attention(Q, K, V)

tensor([[5.6569, 2.8284]])


(tensor([[5.6569, 2.8284]]),
 tensor([[3.0558, 3.0558, 3.0558, 3.0558, 3.0558, 3.0558, 3.0558, 3.0558]]),
 tensor([[0.9442, 0.0558]]))
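To see the effect side by side, here is a small sketch that rebuilds the same toy vectors (a query of ones, keys of twos and ones) at several dimensionalities and compares the attention weights with and without scaling. The helper name weights_for is made up for this example.

```python
import math
import torch
import torch.nn.functional as F

def weights_for(d, scale):
    # Hypothetical helper: same toy query and keys as above, with d dimensions
    q = torch.ones(1, d)
    k = torch.stack([2 * torch.ones(d), torch.ones(d)])
    scores = torch.matmul(q, k.transpose(-2, -1))
    if scale:
        scores = scores / math.sqrt(d)
    return F.softmax(scores, dim=-1)

for d in (2, 8, 32):
    print(d, weights_for(d, scale=False), weights_for(d, scale=True))
```

For these constant vectors the raw scores grow like d and the scaled scores like sqrt(d), so scaling slows the saturation of the softmax rather than removing it entirely.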

# Let's experiment with dimensions

In [58]:
Q = torch.ones(3, 8, 4)

K = torch.ones(3, 3, 8, 4)
V = torch.ones(3, 3, 8, 4)

In [59]:
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value)

In [60]:
o = attention(Q, K, V)

In [61]:
o.size()

torch.Size([3, 3, 8, 4])
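Where does that output shape come from? torch.matmul treats the last two dimensions as the matrices and broadcasts the leading ones, so the 3-d Q is broadcast against the 4-d K. A sketch that prints the intermediate shapes, repeating the steps of attention() inline:

```python
import math
import torch
import torch.nn.functional as F

Q = torch.ones(3, 8, 4)
K = torch.ones(3, 3, 8, 4)
V = torch.ones(3, 3, 8, 4)

# K.transpose(-2, -1) has shape (3, 3, 4, 8); Q's batch dim (3,) broadcasts to (3, 3)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
p_attn = F.softmax(scores, dim=-1)  # normalizes over the last (key) dimension
out = torch.matmul(p_attn, V)       # (3, 3, 8, 8) @ (3, 3, 8, 4)

print(scores.shape, p_attn.shape, out.shape)
```

The scores and weights have shape (3, 3, 8, 8): one 8 x 8 query-by-key matrix per leading index pair. Multiplying by V brings the last dimension back to 4.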