In [10]:
import math
import torch
from torch import nn
from d2l import torch as d2l

In [11]:
#@save
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Queries, keys and values are each projected to ``num_hiddens`` features.
    The feature dimension is split into ``num_heads`` heads, and
    ``d2l.DotProductAttention`` runs on all heads in parallel (heads are folded
    into the batch dimension). The heads are then concatenated back and passed
    through an output projection.

    Args:
        key_size: Feature size of the input keys.
        query_size: Feature size of the input queries.
        value_size: Feature size of the input values.
        num_hiddens: Feature size after projection. Must be divisible by
            ``num_heads`` so each head gets ``num_hiddens // num_heads``
            features.
        num_heads: Number of attention heads (positive integer).
        dropout: Dropout probability passed to ``d2l.DotProductAttention``.
        bias: Whether the four linear projections use a bias term.
        **kwargs: Forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``num_hiddens`` is not
            divisible by ``num_heads``.
    """
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads,
                 dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        # Validate up front: otherwise a bad configuration only surfaces later
        # as an opaque reshape error inside transpose_qkv during forward().
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by "
                f"num_heads ({num_heads})")
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        Args:
            queries: (batch_size, num_queries, query_size)
            keys: (batch_size, num_keys, key_size)
            values: (batch_size, num_keys, value_size)
            valid_lens: None, or a tensor of shape (batch_size,) or
                (batch_size, num_queries), forwarded (repeated per head) to
                ``d2l.DotProductAttention``.

        Returns:
            Tensor of shape (batch_size, num_queries, num_hiddens).
        """
        # Project, then fold heads into the batch dimension:
        # (batch_size*num_heads, num_queries or num_keys, num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # transpose_qkv orders rows as b * num_heads + h, and
            # repeat_interleave keeps the num_heads copies of each batch
            # element's valid length adjacent, so the two orderings line up.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

        # (batch_size*num_heads, num_queries, num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # Unfold heads and concatenate: (batch_size, num_queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

In [13]:
#@save
def transpose_qkv(X, num_heads):
    """Split the feature dimension into heads and fold the heads into the batch.

    Args:
        X: Tensor of shape (batch_size, num_steps, num_hiddens), where
            num_steps is the number of queries or keys.
        num_heads: Number of attention heads. The last dimension is split into
            ``num_heads`` chunks of ``num_hiddens / num_heads`` features.

    Returns:
        Tensor of shape (batch_size * num_heads, num_steps,
        num_hiddens / num_heads). Row ``b * num_heads + h`` holds head ``h``
        of batch element ``b``.
    """
    # (batch_size, num_steps, num_heads, head_size)
    per_head = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # Swap step and head axes -> (batch_size, num_heads, num_steps, head_size)
    heads_major = per_head.transpose(1, 2)
    # Merge batch and head axes -> (batch_size * num_heads, num_steps, head_size)
    _, _, num_steps, head_size = heads_major.shape
    return heads_major.reshape(-1, num_steps, head_size)

#@save
def transpose_output(X, num_heads):
    """Reverse ``transpose_qkv``: unfold heads from the batch and concatenate them.

    Args:
        X: Tensor of shape (batch_size * num_heads, num_steps, head_size).
        num_heads: Number of attention heads folded into the first dimension.

    Returns:
        Tensor of shape (batch_size, num_steps, num_heads * head_size).
        Head ``h`` occupies features ``h * head_size`` up to
        ``(h + 1) * head_size``.
    """
    num_steps, head_size = X.shape[1], X.shape[2]
    # (batch_size, num_heads, num_steps, head_size)
    per_head = X.reshape(-1, num_heads, num_steps, head_size)
    # Swap head and step axes -> (batch_size, num_steps, num_heads, head_size)
    steps_major = per_head.transpose(1, 2)
    batch_size, steps = steps_major.shape[0], steps_major.shape[1]
    # Concatenate the heads along the feature axis.
    return steps_major.reshape(batch_size, steps, -1)

In [14]:
# Small configuration for a shape check: 5 heads of 100 // 5 = 20 features each.
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(
    query_size=num_hiddens, key_size=num_hiddens, value_size=num_hiddens,
    num_hiddens=num_hiddens, num_heads=num_heads, dropout=0.5,
)
# eval() switches off dropout and returns the module, so the cell shows its repr.
attention.eval()

MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)

In [16]:
# Toy inputs: 2 sequences of 4 queries each, attending over 6 key-value pairs.
# valid_lens presumably restricts attention to the first 3 / 2 key-value pairs
# of each sequence (the masking happens inside d2l.DotProductAttention).
batch_size, num_queries = 2, 4
num_kvpairs = 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
output_shape = attention(X, Y, Y, valid_lens).shape
# Check the expected shape instead of only stating it in a comment.
assert output_shape == (batch_size, num_queries, num_hiddens), output_shape
output_shape

torch.Size([2, 4, 100])