# Multi-Head Attention



In [1]:
import math
import torch
from torch import nn

Tools

In [2]:
def masked_softmax(X, valid_lens):
    def _sequence_mask(X, valid_len, value=0):
        maxlen = X.size(1)

        # 创建掩码：对于每个位置，如果索引小于有效长度则为True，否则为False
        mask = torch.arange((maxlen), dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]

        # 将掩码为False的位置设置为指定的值
        X[~mask] = value
        return X

    # 没有提供有效长度，直接返回标准sofemax
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape

        # 处理不同维度的有效长度
        if valid_lens.dim() == 1:
            # 如果是一维,复制到每个位置
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 如果是二维,展平
            valid_lens = valid_lens.reshape(-1)

        # 应用序列掩码,将无效位置设置为很小的值
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)

        # 应用softmax并恢复原始形状
        return nn.functional.softmax(X.reshape(shape), dim=-1)

class DotProductAttention(nn.Module):
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # 获取特征维度（用于缩放）
        d = queries.shape[-1]

        # 计算注意力分数：Q * K^T / √d
        # bmm: 批量矩阵乘法
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)

        # 应用带掩码的softmax得到注意力权重
        self.attention_weights = masked_softmax(scores, valid_lens)

        # 对注意力权重应用dropout，然后与值相乘得到最终输出
        return torch.bmm(self.dropout(self.attention_weights), values)

def check_shape(tensor, expected_shape):
    assert tensor.shape == expected_shape, \
        f'Expected shape: {expected_shape}, got: {tensor.shape}'
    return tensor

Choose the scaled dot product attention
for each head

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        output = self.attention(queries, keys, values, valid_lens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)

Parallel computation of multiple heads

In [4]:
class MultiHeadAttention(nn.Module):
  def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
    super().__init__()
    self.num_heads=num_heads
    self.num_hiddens=num_hiddens
    self.W_q=nn.Linear(num_hiddens,num_hiddens)
    self.W_k=nn.Linear(num_hiddens,num_hiddens)
    self.W_v=nn.Linear(num_hiddens,num_hiddens)
    self.W_o=nn.Linear(num_hiddens,num_hiddens)
    self.attention=DotProductAttention(dropout)

  def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
    X = X.permute(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])

  def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

  def forward(self, queries, keys, values, valid_lens):
    Q=self.transpose_qkv(self.W_q(queries))
    K=self.transpose_qkv(self.W_k(keys))
    V=self.transpose_qkv(self.W_v(values))

    if valid_lens is not None:
      valid_lens=torch.repeat_interleave(valid_lens,repeats=self.num_heads,dim=0)

    output=self.attention(Q,K,V,valid_lens)
    output=self.transpose_output(output)
    return self.W_o(output)

Test our implemented

In [5]:
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
check_shape(attention(X, Y, Y, valid_lens),
                (batch_size, num_queries, num_hiddens))

tensor([[[ 3.1588e-01, -5.3620e-02, -5.3872e-01,  1.3297e-02,  8.4461e-01,
           3.7934e-01,  3.3786e-01,  3.6676e-01, -2.4644e-01,  6.9524e-01,
          -5.3751e-01, -1.2681e+00, -2.4270e-01,  5.9949e-01,  2.0265e-01,
           1.6614e-01,  3.2251e-01,  5.8388e-01,  7.1025e-01,  1.0319e-01,
           8.9562e-01, -3.3539e-01, -1.5246e-01,  5.6413e-01,  5.2562e-02,
           5.7807e-01,  1.2123e-01,  3.6487e-01,  4.5249e-01, -5.4267e-01,
          -7.8027e-02,  3.3589e-01, -4.2924e-01,  4.9593e-01,  2.6356e-01,
          -4.8317e-01,  1.1016e+00, -4.1329e-02,  2.3063e-01,  1.2471e-01,
          -5.8109e-01, -2.8962e-01, -6.7634e-01, -7.4485e-01, -5.7372e-01,
          -4.7555e-01, -4.3039e-01,  8.9539e-01, -1.3001e-01, -4.7291e-01,
          -7.2822e-01,  6.3764e-01,  3.0498e-01, -3.5024e-01, -9.9975e-02,
          -6.7348e-01, -4.3543e-01,  5.3985e-01, -5.5539e-01, -4.9863e-01,
          -1.6256e-01,  2.2108e-01,  1.0541e+00, -8.7572e-01, -3.2463e-01,
           3.5782e-02,  2