# 11.5. Multi-Head Attention

In [21]:
import math
import torch
from torch import nn
from d2l import torch as d2l

In [22]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention.

    Queries, keys and values are each projected by their own linear layer to
    `d_hiddens` features. Each projection is split into `num_heads` heads of
    `d_hiddens // num_heads` features, and all heads are attended in parallel
    by `d2l.DotProductAttention`. The concatenated heads then go through a
    final linear output projection.

    Args:
        d_hiddens: Output size of every projection. Must be divisible by
            `num_heads` so that every head gets the same number of features.
        num_heads: Number of attention heads (must be positive).
        dropout: Dropout argument forwarded to `d2l.DotProductAttention`.
        bias: Whether the four linear projections use a bias term.
        **kwargs: Accepted for backward compatibility; currently ignored.

    Raises:
        ValueError: If `num_heads` is not positive or does not divide
            `d_hiddens`.
    """
    def __init__(self, d_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        # Validate up front. Otherwise a bad combination only surfaces later,
        # during forward(), as an opaque reshape error in transpose_qkv.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if d_hiddens % num_heads != 0:
            raise ValueError(
                f"d_hiddens ({d_hiddens}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # LazyLinear infers its input feature size on the first forward call,
        # so the inputs do not need to have d_hiddens features already.
        self.W_q = nn.LazyLinear(d_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(d_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(d_hiddens, bias=bias)
        self.W_out = nn.LazyLinear(d_hiddens, bias=bias)

    def forward(self, queries, keys, values, attn_mask=None):
        """
        Compute multi-head attention.

        Args:
            queries: Tensor of shape (batch_size, no. of queries, in_features).
            keys: Tensor of shape (batch_size, no. of key-value pairs,
                in_features).
            values: Tensor of shape (batch_size, no. of key-value pairs,
                in_features).
            attn_mask: Optional tensor of shape (batch_size,) or
                (batch_size, no. of queries). It is used like the valid
                lengths in the usage example and is forwarded to
                `d2l.DotProductAttention` after being repeated once per head.
                Defaults to None, which is forwarded unchanged (presumably
                meaning "no masking" in d2l; confirm).

        Returns:
            Tensor of shape (batch_size, no. of queries, d_hiddens).
        """
        # After transposing, shape of queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        #  d_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if attn_mask is not None:
            # transpose_qkv places the heads of one example in consecutive
            # rows (row index = example * num_heads + head). Each example's
            # mask is therefore repeated num_heads times back-to-back.
            attn_mask = torch.repeat_interleave(
                attn_mask, repeats=self.num_heads, dim=0
            )

        output = self.attention(queries, keys, values, attn_mask)
        # Merge the heads back: (batch_size, no. of queries, d_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_out(output_concat)

In [23]:
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """
    Transposition for parallel computation of multiple attention heads.

    Splits the last axis of X into `self.num_heads` equal chunks and folds the
    head axis into the batch axis, so that one batched attention call
    processes every head.

    Args:
        X: Tensor of shape (batch_size, no. of queries or key-value pairs,
            d_hiddens); d_hiddens is expected to be divisible by num_heads.

    Returns:
        Tensor of shape (batch_size * num_heads, no. of queries or key-value
        pairs, d_hiddens / num_heads). Rows for one example's heads are
        consecutive.
    """
    batch_size, num_steps, d_hiddens = X.shape
    head_dim = d_hiddens // self.num_heads
    # Sizes are spelled out instead of using -1: PyTorch cannot infer a -1
    # dimension for zero-element tensors (empty batch or empty sequence).
    # Shape: (batch_size, num_steps, num_heads, head_dim)
    X = X.reshape(batch_size, num_steps, self.num_heads, head_dim)
    # Shape: (batch_size, num_heads, num_steps, head_dim)
    X = X.permute(0, 2, 1, 3)
    # Shape: (batch_size * num_heads, num_steps, head_dim)
    return X.reshape(batch_size * self.num_heads, num_steps, head_dim)

In [24]:
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """
    Reverse the operation of transpose_qkv.

    Args:
        X: Tensor of shape (batch_size * num_heads, no. of queries,
            d_hiddens / num_heads), with each example's heads in consecutive
            rows.

    Returns:
        Tensor of shape (batch_size, no. of queries, d_hiddens), with the
        heads concatenated along the last axis.
    """
    batch_heads, num_steps, head_dim = X.shape
    batch_size = batch_heads // self.num_heads
    # Sizes are spelled out instead of using -1: PyTorch cannot infer a -1
    # dimension for zero-element tensors (empty batch or empty sequence).
    # Shape: (batch_size, num_heads, num_steps, head_dim)
    X = X.reshape(batch_size, self.num_heads, num_steps, head_dim)
    # Shape: (batch_size, num_steps, num_heads, head_dim)
    X = X.permute(0, 2, 1, 3)
    # Shape: (batch_size, num_steps, num_heads * head_dim)
    return X.reshape(batch_size, num_steps, self.num_heads * head_dim)

In [25]:
# Smoke test: multi-head attention keeps the (batch_size, num_queries,
# d_hiddens) shape of the queries.
batch_size, num_queries, num_kvpairs = 2, 4, 6
d_hiddens, num_heads = 100, 5
# One entry per batch example, passed to the model as attn_mask.
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, d_hiddens))  # queries
Y = torch.ones((batch_size, num_kvpairs, d_hiddens))  # keys and values
attention = MultiHeadAttention(d_hiddens, num_heads, dropout=0.5)
d2l.check_shape(
    attention(X, Y, Y, valid_lens),
    (batch_size, num_queries, d_hiddens),
)