In [1]:
import torch
from torch import nn
import math

In [None]:

# 多头注意力
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention layer.

    The query, key and value inputs are each passed through a learned linear
    projection, split into ``num_heads`` heads of width ``d_k``, attended
    per head, concatenated back to ``d_model`` features and passed through an
    output projection.

    Args:
        d_model: Feature width of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature width (integer division)

        # nn.Linear(in_features, out_features). The construction order is kept
        # as W_q, W_k, W_v, W_o so that a fixed RNG seed yields the same weights.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Compute softmax(Q K^T / sqrt(d_k)) V for every head.

        Args:
            Q: Queries, (N, n_head, S_q, d_k).
            K: Keys, (N, n_head, S_k, d_k).
            V: Values, (N, n_head, S_k, d_k).
            mask: Optional tensor that broadcasts against the
                (N, n_head, S_q, S_k) score tensor, e.g. (N, 1, 1, S_k).
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            Attention output of shape (N, n_head, S_q, d_k).
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # attn_scores (N, n_head, S_q, S_k)
        if mask is not None:
            # Masked positions get a very negative (finite) score so that their
            # softmax weight underflows to zero. -1e9 is not representable in
            # float16 and would make masked_fill raise, so clamp the fill value
            # to the dtype's finite range; it stays -1e9 for float32/float64.
            fill_value = max(-1e9, torch.finfo(attn_scores.dtype).min)
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        # attn_probs (N, n_head, S_q, S_k); each row sums to 1 over the keys
        output = torch.matmul(attn_probs, V)
        # output (N, n_head, S_q, d_k)
        return output

    def split_heads(self, x):
        """
        Split the feature dimension into heads.

        Args:
            x: Tensor of shape (N, S, d_model).

        Returns:
            Tensor of shape (N, n_head, S, d_k).
        """
        batch_size, seq_length, _ = x.size()
        # (N, S, D) -> (N, S, n_head, d_k) -> (N, n_head, S, d_k)
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """
        Concatenate the per-head outputs back into one feature dimension.

        Args:
            x: Tensor of shape (N, n_head, S, d_k).

        Returns:
            Tensor of shape (N, S, d_model).
        """
        batch_size, _, seq_length, _ = x.size()
        # transpose() returns a non-contiguous view, so contiguous() is needed
        # before view() can merge the head and d_k dimensions.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Apply multi-head attention.

        Args:
            Q: Query input, (N, S_q, d_model).
            K: Key input, (N, S_k, d_model).
            V: Value input, (N, S_k, d_model).
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (N, S_q, d_model).
        """
        # Project, then reshape to (N, n_head, S, d_k).
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        # attn_output (N, n_head, S_q, d_k) -> (N, S_q, d_model) -> output projection
        output = self.W_o(self.combine_heads(attn_output))

        return output

In [3]:
# Seed the global RNG first so the layer's randomly initialised weights are reproducible.
torch.manual_seed(42)
d_model, n_heads = 16, 4  # model width and number of attention heads
attention = MultiHeadAttention(num_heads=n_heads, d_model=d_model)

In [None]:
batch_size = 1
seq_len = 4
# Four random token embeddings, each of shape (batch_size, d_model).
t1 = torch.randn(batch_size, d_model)
t2 = torch.randn(batch_size, d_model)
t3 = torch.randn(batch_size, d_model)
t4 = torch.randn(batch_size, d_model)
# The same four tokens in two different orders, stacked along the sequence axis.
x1 = torch.stack([t1, t2, t3, t4], dim=1)
x2 = torch.stack([t3, t1, t2, t4], dim=1)
assert x1.shape == x2.shape == (batch_size, seq_len, d_model)
output = attention(x1, x1, x1)
print(output[0])
reversed_output = attention(x2, x2, x2)
print(reversed_output[0])
# x2 holds tokens (t3, t1, t2, t4), i.e. rows [2, 0, 1, 3] of x1. Verify that the
# attention output is permuted the same way instead of relying on the printout.
token_order = [2, 0, 1, 3]
assert torch.allclose(reversed_output[0], output[0][token_order], atol=1e-5)
# Observation: without a mask, reordering the input tokens only reorders the rows
# of the attention output; the vector produced for each token does not change.
# Consequence for decoding: if the last token is the same, swapping the earlier
# tokens still yields the same decoded token, because for the attention output
# only the order of the terms in attention_output1 = q1k1*V1 + q1k2*V2 + ...
# changes, not their sum.
# This is why a positional encoding has to be added to the embeddings.


tensor([[-0.0737,  0.1454, -0.2397,  0.1644, -0.2396, -0.0370, -0.2983,  0.0133,
         -0.1099, -0.2542, -0.2101, -0.2282, -0.0183, -0.0015,  0.0715,  0.0826],
        [-0.1173,  0.2360, -0.1877,  0.2655, -0.2713, -0.1749, -0.2866, -0.0064,
         -0.0932, -0.1685, -0.2030, -0.2352, -0.1353,  0.0887,  0.1380,  0.2580],
        [-0.0735,  0.1678, -0.2673,  0.2655, -0.1810, -0.0833, -0.2807, -0.0929,
         -0.2139, -0.1899, -0.1074, -0.2367, -0.1392,  0.0705,  0.1032,  0.2191],
        [-0.0651,  0.1655, -0.2487,  0.2442, -0.1902, -0.0869, -0.2930, -0.0966,
         -0.1672, -0.1772, -0.1552, -0.2677, -0.1435,  0.0510,  0.0936,  0.2404]],
       grad_fn=<SelectBackward0>)
tensor([[-0.0735,  0.1678, -0.2673,  0.2655, -0.1810, -0.0833, -0.2807, -0.0929,
         -0.2139, -0.1899, -0.1074, -0.2367, -0.1392,  0.0705,  0.1032,  0.2191],
        [-0.0737,  0.1454, -0.2397,  0.1644, -0.2396, -0.0370, -0.2983,  0.0133,
         -0.1099, -0.2542, -0.2101, -0.2282, -0.0183, -0.0015,  0.071