<a href="https://colab.research.google.com/github/SunshineGreeny/Dive-into-deep-learning-Pytorch/blob/colab-experiments/chapter_attention-mechanisms-and-transformers/multihead-attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Head Attention



In [13]:
import math
import torch
from torch import nn


Tools

In [14]:
def masked_softmax(X, valid_lens):
    """Softmax over the last axis that ignores positions past a valid length.

    Args:
        X: Score tensor; the softmax is taken over its last axis. The 1-D
            ``valid_lens`` path repeats lengths ``X.shape[1]`` times, so X is
            presumably 3-D, e.g. (batch, num_queries, num_keys).
        valid_lens: ``None`` (plain softmax), a 1-D tensor holding one length
            per entry along X's first axis, or a 2-D tensor holding one
            length per row of X (it is flattened).

    Returns:
        A tensor shaped like ``X``. Positions whose index along the last axis
        is >= the corresponding valid length receive (numerically) zero
        weight. ``X`` itself is left unmodified.
    """
    def _sequence_mask(rows, valid_len, value=0):
        # rows: (num_rows, maxlen); valid_len: (num_rows,)
        maxlen = rows.size(1)

        # mask[i, j] is True while column j lies inside row i's valid length.
        mask = torch.arange(maxlen, dtype=torch.float32,
                            device=rows.device)[None, :] < valid_len[:, None]

        # Out-of-place fill: an in-place write here would go through the
        # reshaped view into the caller's tensor (and fails outright on a
        # leaf tensor that requires grad).
        return rows.masked_fill(~mask, value)

    # No valid lengths supplied: fall back to the standard softmax.
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape

    if valid_lens.dim() == 1:
        # One length per leading entry: repeat it for every row of that entry.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        # One length per row already: just flatten.
        valid_lens = valid_lens.reshape(-1)

    # Replace masked scores with a very large negative value so that their
    # exponentials underflow to zero.
    masked = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)

    # Softmax over the last axis, restoring the original shape.
    return nn.functional.softmax(masked.reshape(shape), dim=-1)

class DotProductAttention(nn.Module):
    """Scaled dot-product attention with dropout on the attention weights.

    After each forward pass the (pre-dropout) attention weights are kept in
    ``self.attention_weights``.
    """

    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """Attend from ``queries`` to ``keys``/``values``.

        All three inputs are 3-D (``torch.bmm`` requires it); ``queries`` and
        ``keys`` must share their last dimension. ``valid_lens`` is passed
        through unchanged to ``masked_softmax``.
        """
        # Scaling factor: square root of the query feature dimension.
        scale = math.sqrt(queries.shape[-1])

        # Batched Q @ K^T, scaled.
        keys_t = keys.transpose(1, 2)
        scores = torch.bmm(queries, keys_t) / scale

        # Masked softmax turns the scores into attention weights.
        self.attention_weights = masked_softmax(scores, valid_lens)

        # Dropout on the weights, then the weighted sum of the values.
        dropped = self.dropout(self.attention_weights)
        return torch.bmm(dropped, values)

def check_shape(tensor, expected_shape):
    """Return ``tensor`` unchanged after verifying its shape.

    Args:
        tensor: Any object exposing a ``shape`` attribute.
        expected_shape: Value that ``tensor.shape`` must compare equal to.

    Returns:
        The same ``tensor`` object, so the call can wrap an expression.

    Raises:
        AssertionError: If ``tensor.shape != expected_shape``.
    """
    # Raise explicitly rather than using an ``assert`` statement so the check
    # still runs when Python is started with -O (which strips asserts).
    if not (tensor.shape == expected_shape):
        raise AssertionError(
            f'Expected shape: {expected_shape}, got: {tensor.shape}')
    return tensor

Use scaled dot-product attention for each head

In [15]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on scaled dot-product attention.

    Queries, keys and values are projected with lazily-initialised linear
    layers, split into ``num_heads`` heads that are processed as extra batch
    entries, and merged again before the output projection.

    Args:
        num_hiddens: Output size of every projection layer.
        num_heads: Number of attention heads.
        dropout: Dropout probability passed to ``DotProductAttention``.
        bias: Whether the projection layers use a bias term.
        **kwargs: Accepted for call compatibility and ignored.
    """
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def transpose_qkv(self, X):
        """Split the last axis into heads and fold the heads into the batch.

        Maps (batch, seq_len, num_hiddens) to
        (batch * num_heads, seq_len, num_hiddens / num_heads).
        """
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        X = X.permute(0, 2, 1, 3)
        return X.reshape(-1, X.shape[2], X.shape[3])

    def transpose_output(self, X):
        """Reverse ``transpose_qkv``: unfold the heads and concatenate them."""
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        ``valid_lens`` may be ``None``; otherwise its first axis is repeated
        ``num_heads`` times so that it lines up with the folded batch axis.
        """
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # Each batch entry became num_heads consecutive entries.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        output = self.attention(queries, keys, values, valid_lens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)

Parallel computation of multiple heads

In [16]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention that computes all heads in parallel.

    The heads are folded into the batch axis so a single scaled dot-product
    attention call handles every head at once.

    Args:
        num_hiddens: Input and output feature size of every projection.
        num_heads: Number of attention heads; must divide ``num_hiddens``.
        dropout: Dropout probability used by the attention module.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``num_hiddens``.
    """
    def __init__(self, num_hiddens, num_heads, dropout=0.0):
        super().__init__()
        # Fail early with a clear message; otherwise the mismatch only
        # surfaces later as a reshape error inside transpose_qkv.
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f'num_hiddens ({num_hiddens}) must be divisible by a '
                f'positive num_heads ({num_heads})')
        self.num_heads = num_heads
        self.num_hiddens = num_hiddens
        # Linear projections for queries, keys, values and the output.
        self.W_q = nn.Linear(num_hiddens, num_hiddens)
        self.W_k = nn.Linear(num_hiddens, num_hiddens)
        self.W_v = nn.Linear(num_hiddens, num_hiddens)
        self.W_o = nn.Linear(num_hiddens, num_hiddens)
        # NOTE: not used by forward(); kept so the attribute stays available.
        self.dropout = nn.Dropout(dropout)
        self.attention = DotProductAttention(dropout)

    def transpose_qkv(self, X):
        """Transposition for parallel computation of multiple attention heads."""
        # X: (batch_size, seq_len, num_hiddens)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        X = X.permute(0, 2, 1, 3)  # (B, num_heads, seq_len, head_dim)
        # Merge the batch and head axes.
        return X.reshape(-1, X.shape[2], X.shape[3])

    def transpose_output(self, X):
        """Reverse the operation of transpose_qkv."""
        # X: (B*num_heads, seq_len, head_dim)
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)  # (B, seq_len, num_heads, head_dim)
        # Concatenate the heads along the feature axis.
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        ``queries``/``keys``/``values`` are expected to be
        (batch, seq_len, num_hiddens). ``valid_lens`` may be ``None``;
        otherwise its first axis is repeated once per head.
        """
        Q = self.transpose_qkv(self.W_q(queries))
        K = self.transpose_qkv(self.W_k(keys))
        V = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # Each batch entry became num_heads consecutive entries, so the
            # lengths are repeated along axis 0 to match.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Scaled dot-product attention over all heads at once.
        output = self.attention(Q, K, V, valid_lens)
        output = self.transpose_output(output)
        return self.W_o(output)

Test our implementation

In [17]:
# Smoke test: the output of multi-head attention must have the shape
# (batch_size, num_queries, num_hiddens).
batch_size, num_queries, num_kvpairs = 2, 4, 6
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)

# One valid key/value length per batch entry.
valid_lens = torch.tensor([3, 2])

# X acts as the queries; Y is used as both keys and values.
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))

check_shape(
    attention(X, Y, Y, valid_lens),
    (batch_size, num_queries, num_hiddens),
)

tensor([[[-5.0682e-01, -9.1867e-02, -9.6111e-02,  3.5473e-01,  4.1095e-01,
           5.9812e-01, -1.0575e-01, -1.7675e-01,  4.7958e-01,  4.4001e-01,
          -6.4413e-01,  6.1430e-01,  1.9326e-01, -2.2981e-02, -2.0853e-01,
          -5.9309e-02,  9.8570e-01,  4.1541e-01, -8.1727e-01,  2.6197e-01,
           4.5640e-01,  4.1301e-02,  3.0625e-01, -4.3465e-01,  5.2807e-02,
           7.6951e-02,  2.5736e-01,  7.4846e-02, -3.3225e-01,  3.3010e-01,
          -4.2965e-01,  3.9745e-01, -1.3902e-01, -3.9534e-01, -6.4195e-01,
          -1.7562e-01,  2.5535e-01,  6.6362e-01,  5.9950e-01, -2.9680e-01,
           4.2467e-01,  7.9417e-02, -5.6159e-01,  2.3377e-01, -2.9319e-01,
          -7.0728e-01,  2.0414e-01, -2.5556e-01,  3.5403e-01, -3.0231e-01,
          -3.6119e-02,  6.9782e-01, -1.2039e+00, -1.9244e-03, -2.7149e-01,
          -2.3990e-01,  2.7966e-01, -2.2049e-01, -4.8323e-02, -9.8118e-01,
           2.1713e-02,  2.7151e-01, -3.9364e-01, -7.9916e-01, -1.3124e-01,
           1.4663e-01,  2