# 多头注意力的实现

In [ ]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The queries, keys and values are each projected with a learned linear
    layer, split into ``num_heads`` sub-spaces of width ``d_k``, attended
    independently, concatenated back to ``d_model`` and passed through a
    final output projection.

    Args:
        d_model: Feature width of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        # Width of each head's sub-space.
        self.d_k = d_model // self.num_heads

        # The projections keep the full d_model width; the result is split
        # into heads afterwards inside forward().
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection applied to the concatenated heads.
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Compute multi-head attention.

        Args:
            q: Query tensor of shape ``(batch, q_len, d_model)``.
            k: Key tensor of shape ``(batch, k_len, d_model)``.
            v: Value tensor of shape ``(batch, k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, q_len, k_len)``. Positions where the
                mask equals 0 are excluded from attention. Note that a query
                row whose keys are *all* masked yields NaN weights, because
                the softmax is taken over a row of ``-inf``.

        Returns:
            A tuple ``(output, weights)`` where ``output`` has shape
            ``(batch, q_len, d_model)`` and ``weights`` has shape
            ``(batch, num_heads, q_len, k_len)``.
        """
        batch_size = q.size(0)

        # Linear projections.
        q = self.W_q(q)
        k = self.W_k(k)
        v = self.W_v(v)

        # Split into heads: (batch, seq, d_model) -> (batch, heads, seq, d_k).
        q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product scores. Dividing by sqrt(d_k) keeps the logits
        # in a range where the softmax does not saturate.
        # (batch, heads, q_len, d_k) @ (batch, heads, d_k, k_len)
        #   -> (batch, heads, q_len, k_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Masked positions get -inf so they receive zero weight after softmax.
        if mask is not None:
            scores = scores.masked_fill(
                mask.to(scores.dtype) == 0, float('-inf')
            )

        # Attention weights: softmax over the key dimension.
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of the values.
        # (batch, heads, q_len, k_len) @ (batch, heads, k_len, d_k)
        #   -> (batch, heads, q_len, d_k)
        attended = torch.matmul(weights, v)

        # Concatenate heads: (batch, heads, q_len, d_k) -> (batch, q_len, d_model).
        # transpose() returns a non-contiguous view, and .view() cannot merge
        # the head and d_k axes without a copy, so make it contiguous first.
        concat = (
            attended.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.d_model)
        )

        # Final output projection.
        output = self.W_o(concat)

        return output, weights