In [2]:
import torch
from torch import nn

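Attention pooling returns a weighted average of the values, where the weight of each value depends on how well its key matches the query:

$$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,$$

The attention weights are the softmax of a scoring function $a$ taken over the $m$ keys, so they are non-negative and sum to 1:

$$\alpha(\mathbf{q}, \mathbf{k}_i) = \operatorname{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))}.$$

Additive attention uses the scoring function

$$a(\mathbf q, \mathbf k) = \mathbf w_v^\top \tanh(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},$$

where $\mathbf W_q \in \mathbb{R}^{h\times q}$, $\mathbf W_k \in \mathbb{R}^{h\times k}$ and $\mathbf w_v \in \mathbb{R}^{h}$ are learnable parameters, and $h$ is the number of hidden units (`num_hiddens` below).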

In [3]:
# Toy inputs for additive attention: each batch element has one query that
# attends over `num_keys` key-value pairs.
# Seed the global torch RNG first so that `queries` (and the nn.Linear weights
# created in the next cell) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 2
query_size = 10   # features per query
num_keys = 5      # key-value pairs per batch element
key_size = 2      # features per key
value_size = 4    # features per value

# (batch_size, num_queries=1, query_size): random standard-normal queries
queries = torch.normal(0, 1, (batch_size, 1, query_size))
# (batch_size, num_keys, key_size): every key is identical (all ones), so each
# query scores all keys equally in the cells below.
keys = torch.ones((batch_size, num_keys, key_size))
# One (1, num_keys, value_size) block holding 0 .. num_keys*value_size - 1,
# tiled across the batch -> (batch_size, num_keys, value_size)
val_instance = torch.arange(num_keys * value_size, dtype=torch.float32).reshape(1, num_keys, value_size)
values = val_instance.repeat(batch_size, 1, 1)
print("q size: ", queries.shape)
print("k size: ", keys.shape)
print("v size: ", values.shape)

q size:  torch.Size([2, 1, 10])
k size:  torch.Size([2, 5, 2])
v size:  torch.Size([2, 5, 4])


In [4]:
# Learnable, bias-free projections for additive attention:
#   W_q : query_size  -> num_hiddens
#   W_k : key_size    -> num_hiddens
#   w_v : num_hiddens -> 1   (collapses a hidden vector to a scalar score)
# Keep the creation order (W_k, W_q, w_v): each nn.Linear draws its initial
# weights from the global RNG, so reordering would change the weights.
num_hiddens = 20
W_k = nn.Linear(key_size, num_hiddens, bias=False)
W_q = nn.Linear(query_size, num_hiddens, bias=False)
w_v = nn.Linear(num_hiddens, 1, bias=False)

# Map queries and keys into the shared hidden space.
Wq = W_q(queries)
Wk = W_k(keys)

for label, projected, note in (
    ("W_q*q size:", Wq, "* Note: (batch_size, num_queries, num_hidden)"),
    ("W_k*k size:", Wk, "* Note: (batch_size, number of key value pairs, num_hidden)"),
):
    print(label, projected.shape)
    print(note)

W_q*q size: torch.Size([2, 1, 20])
* Note: (batch_size, num_queries, num_hidden)
W_k*k size: torch.Size([2, 5, 20])
* Note: (batch_size, number of key value pairs, num_hidden)


In [5]:
# Score every (query, key) pair at once by broadcasting:
#   Wq[:, :, None, :] -> (batch_size, num_queries, 1, num_hiddens)
#   Wk[:, None, :, :] -> (batch_size, 1, num_keys, num_hiddens)
# Their sum is (batch_size, num_queries, num_keys, num_hiddens), i.e.
# W_q q + W_k k_i for each pair, passed through tanh.
features = torch.tanh(Wq[:, :, None, :] + Wk[:, None, :, :])
print("W_q*q+W_k*k size: ",features.shape)

# w_v reduces each hidden vector to one number; drop the trailing size-1 axis
# -> scores: (batch_size, num_queries, num_keys)
scores = w_v(features).squeeze(-1)
print("attention score size: ",scores.shape)

W_q*q+W_k*k size:  torch.Size([2, 1, 5, 20])
attention score size:  torch.Size([2, 1, 5])


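Here the batch size is 2, there is one query per batch element, and there are 5 key-value pairs, each with a value of size 4. Attention pooling for the whole batch is therefore a batched matrix product of the attention weights with the stacked values:

$$f(\mathbf{q}) = \underbrace{\boldsymbol{\alpha}}_{2\times1\times5}\;\underbrace{\mathbf{V}}_{2\times5\times4} \in \mathbb{R}^{2\times1\times4},$$

where $\boldsymbol{\alpha}$ holds the softmax-normalized weights (normalized over the 5 keys) and $\mathbf{V}$ stacks the value vectors $\mathbf{v}_i$. Each query's output is a vector in $\mathbb{R}^v$ with $v = 4$.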

In [6]:
# Turn scores into attention weights by normalizing over the *keys* axis.
# `scores` is (batch_size, num_queries, num_keys), so the keys axis is the last
# one. Softmax over dim=1 (the size-1 queries axis) would set every weight to
# 1.0 and make f a plain sum of the values rather than a weighted average.
attention_weights = torch.softmax(scores, dim=-1)
weight_totals = attention_weights.sum(dim=-1)
assert torch.allclose(weight_totals, torch.ones_like(weight_totals)), "attention weights must sum to 1 over keys"

# Batched matmul: (batch_size, num_queries, num_keys) @ (batch_size, num_keys, value_size)
#              -> (batch_size, num_queries, value_size)
f = torch.bmm(attention_weights, values)
print("f(q) size: ",f.shape)

f(q) size:  torch.Size([2, 1, 4])
