# Multi-head Attention
In practice, given the same set of queries, keys, and values, we would like the model to learn different behaviors from the same attention mechanism and combine them as knowledge, for example to capture dependencies of different ranges within a sequence. It can therefore be beneficial to let the attention mechanism jointly use different representation subspaces of the queries, keys, and values.
To this end, we can transform the queries, keys, and values with several independently learned linear projections. The projected queries, keys, and values are then fed into attention pooling in parallel. Finally, the attention pooling outputs are concatenated and transformed by another learned linear projection to produce the final output. Each of these attention pooling outputs is called a *head*.
![multihead](../statics/imgs/section10.5_fig1.jpg)
Each attention head $\mathbf{h}_i$ could be calculated as
$$
\mathbf{h}_i = f(\mathbf{W}_i^{(q)}\mathbf{q}, \mathbf{W}_i^{(k)}\mathbf{k}, \mathbf{W}_i^{(v)}\mathbf{v})
$$
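where $\mathbf{W}_i^{(q)}$, $\mathbf{W}_i^{(k)}$ and $\mathbf{W}_i^{(v)}$ are the learnable projection matrices of head $i$, and $f$ is the attention pooling function, which can be additive attention or scaled dot-product attention.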

In [1]:
import math
import torch
from torch import nn
from d2l import torch as d2l

In [5]:
def transpose_qkv(X, num_heads):
    """Split the hidden dimension into heads and fold the heads into the batch axis.

    This lets all attention heads be computed in one batched call.

    Shape flow:
        (batch_size, len(queries or kv pairs), num_hiddens)
        -> (batch_size, len(queries or kv pairs), num_heads, num_hiddens/num_heads)
        -> (batch_size, num_heads, len(queries or kv pairs), num_hiddens/num_heads)
        -> (batch_size * num_heads, len(queries or kv pairs), num_hiddens/num_heads)

    Args:
        X: tensor whose first two axes are the batch and sequence axes.
        num_heads: number of heads the remaining elements are split into.

    Returns:
        The reshaped tensor with heads merged into the leading axis.
    """
    batch_size, seq_len = X.shape[0], X.shape[1]
    # Carve out an explicit head axis, then move it next to the batch axis.
    per_head = X.reshape(batch_size, seq_len, num_heads, -1).permute(0, 2, 1, 3)
    # Merge (batch, head) into one leading axis.
    return per_head.reshape(-1, seq_len, per_head.shape[3])

def transpose_output(X, num_heads):
    """Undo ``transpose_qkv``: unfold the heads from the batch axis and concatenate them.

    Shape flow:
        (batch_size * num_heads, seq_len, head_dim)
        -> (batch_size, num_heads, seq_len, head_dim)
        -> (batch_size, seq_len, num_heads, head_dim)
        -> (batch_size, seq_len, num_heads * head_dim)

    Args:
        X: 3-D tensor whose leading axis holds batch and head merged together.
        num_heads: number of heads that were merged into the leading axis.

    Returns:
        Tensor with the per-head features concatenated along the last axis.
    """
    seq_len, head_dim = X.shape[1], X.shape[2]
    # Recover the separate head axis, then place it just before the feature axis.
    unfolded = X.reshape(-1, num_heads, seq_len, head_dim).permute(0, 2, 1, 3)
    # Flatten (head, feature) into a single feature axis.
    return unfolded.reshape(unfolded.shape[0], seq_len, -1)

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``d2l.DotProductAttention``.

    Queries, keys and values are each linearly projected to ``num_hiddens``
    features, split into ``num_heads`` heads that are attended in a single
    batched call, then concatenated and passed through an output projection.

    Args:
        key_size: input feature size of the key projection.
        query_size: input feature size of the query projection.
        value_size: input feature size of the value projection.
        num_hiddens: output feature size of every projection.
        num_heads: number of attention heads.
        dropout: dropout argument passed to ``d2l.DotProductAttention``.
        bias: whether the linear projections use a bias term.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # Input projections for queries, keys and values.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        # Output projection applied to the concatenated heads.
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        Args:
            queries, keys, values:
                (batch_size, len(queries) or len(kv pairs), num_hiddens)
            valid_lens: (batch_size, len(queries)) or (batch_size, ), or None.

        Returns:
            The output projection of the concatenated heads.
        """
        heads = self.num_heads
        # Project, then fold the heads into the batch axis:
        # (batch_size*num_heads, len(queries or kv pairs), num_hiddens/num_heads)
        q = transpose_qkv(self.W_q(queries), heads)
        k = transpose_qkv(self.W_k(keys), heads)
        v = transpose_qkv(self.W_v(values), heads)

        if valid_lens is not None:
            # Repeat each entry num_heads times along axis 0 so the lengths
            # line up with the folded (batch, head) axis.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=heads, dim=0)

        # per_head: (batch_size*num_heads, len(queries), num_hiddens/num_heads)
        per_head = self.attention(q, k, v, valid_lens)
        concatenated = transpose_output(per_head, heads)
        return self.W_o(concatenated)

In [7]:
# Model width and number of parallel attention heads.
num_hiddens = 100
num_heads = 5
attention = MultiHeadAttention(
    key_size=num_hiddens,
    query_size=num_hiddens,
    value_size=num_hiddens,
    num_hiddens=num_hiddens,
    num_heads=num_heads,
    dropout=0.5,
)
# Put the module in evaluation mode; as the cell's last expression, the result is displayed.
attention.eval()

MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)

In [8]:
# Toy inputs: 4 queries attend over 6 key-value pairs for each of 2 examples.
batch_size = 2
num_queries = 4
num_kvpairs = 6
valid_lens = torch.tensor([3, 2])
X = torch.ones(batch_size, num_queries, num_hiddens)
Y = torch.ones(batch_size, num_kvpairs, num_hiddens)
# Y serves as both keys and values; display the shape of the attention output.
attention(X, Y, Y, valid_lens).shape

torch.Size([2, 4, 100])