<a href="https://colab.research.google.com/github/SunshineGreeny/Dive-into-deep-learning-Pytorch/blob/main/chapter_attention-mechanisms-and-transformers/multihead-attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Head Attention



In [4]:
import math
import torch
from torch import nn

Tools

In [5]:
def masked_softmax(X, valid_lens):
    def _sequence_mask(X, valid_len, value=0):
        maxlen = X.size(1)

        # 创建掩码：对于每个位置，如果索引小于有效长度则为True，否则为False
        mask = torch.arange((maxlen), dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]

        # 将掩码为False的位置设置为指定的值
        X[~mask] = value
        return X

    # 没有提供有效长度，直接返回标准sofemax
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape

        # 处理不同维度的有效长度
        if valid_lens.dim() == 1:
            # 如果是一维,复制到每个位置
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 如果是二维,展平
            valid_lens = valid_lens.reshape(-1)

        # 应用序列掩码,将无效位置设置为很小的值
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)

        # 应用softmax并恢复原始形状
        return nn.functional.softmax(X.reshape(shape), dim=-1)


def check_shape(tensor, expected_shape):
    assert tensor.shape == expected_shape, \
        f'Expected shape: {expected_shape}, got: {tensor.shape}'
    return tensor

Choose the scaled dot product attention
for each head

In [6]:
class DotProductAttention(nn.Module):
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # 获取特征维度（用于缩放）
        d = queries.shape[-1]

        # 计算注意力分数：Q * K^T / √d
        # bmm: 批量矩阵乘法
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)

        # 应用带掩码的softmax得到注意力权重
        self.attention_weights = masked_softmax(scores, valid_lens)

        # 对注意力权重应用dropout，然后与值相乘得到最终输出
        return torch.bmm(self.dropout(self.attention_weights), values)

class MultiHeadAttention(nn.Module):
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        output = self.attention(queries, keys, values, valid_lens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)

Parallel computation of multiple heads

In [7]:
class MultiHeadAttention(nn.Module):
  def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
    super().__init__()
    self.num_heads=num_heads
    self.num_hiddens=num_hiddens
    self.W_q=nn.Linear(num_hiddens,num_hiddens)
    self.W_k=nn.Linear(num_hiddens,num_hiddens)
    self.W_v=nn.Linear(num_hiddens,num_hiddens)
    self.W_o=nn.Linear(num_hiddens,num_hiddens)
    self.attention=DotProductAttention(dropout)

  def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
    X = X.permute(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])

  def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

  def forward(self, queries, keys, values, valid_lens):
    Q=self.transpose_qkv(self.W_q(queries))
    K=self.transpose_qkv(self.W_k(keys))
    V=self.transpose_qkv(self.W_v(values))

    if valid_lens is not None:
      valid_lens=torch.repeat_interleave(valid_lens,repeats=self.num_heads,dim=0)

    output=self.attention(Q,K,V,valid_lens)
    output=self.transpose_output(output)
    return self.W_o(output)

Test our implemented

In [8]:
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
check_shape(attention(X, Y, Y, valid_lens),
                (batch_size, num_queries, num_hiddens))

tensor([[[ 0.3629,  0.3063,  0.2922,  0.0157,  0.7288,  0.2118, -0.1649,
          -0.0513,  0.2093, -0.0042, -0.5062,  0.1646,  0.7694, -0.2548,
          -0.1410,  0.1100,  0.5411, -0.2249, -0.5944, -0.3839,  0.3824,
           0.2624,  0.8636,  0.0092,  0.2005,  0.0354,  0.2192,  0.2643,
           0.4009, -0.2797, -0.2515, -0.2407, -0.0861, -0.1620, -0.1907,
           0.5468, -0.1624, -0.2522,  0.1096, -0.3083,  0.0548, -0.2064,
           0.3342, -0.0238,  0.0226,  0.2633,  0.0821, -0.3914, -0.5324,
          -0.2480, -0.0423,  0.1026, -0.0859, -0.1098, -0.2245,  0.3068,
          -0.0097,  0.1018, -0.3227, -0.2032,  0.1347,  0.2842, -0.0829,
          -0.1709,  0.4872,  0.2822, -0.2002,  0.5427, -0.3976, -0.3927,
          -0.1316,  0.1736,  0.0751, -0.5188, -0.2101,  0.1032,  0.0685,
          -0.4454,  0.3371,  0.5203, -0.7297,  0.6100,  0.0792,  0.1182,
          -0.1401, -0.4456,  0.2232, -0.1571, -0.2637,  0.1433, -0.1899,
           0.2908, -0.2296,  0.4885,  0.2207, -0.07