多头自注意力机制

$\alpha$

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input
    ``x``. The projected features are split into ``head_num`` heads of
    ``embedding_dim // head_num`` features each, attention is computed per
    head in parallel, and the concatenated head outputs are passed through a
    final linear projection.

    Args:
        embedding_dim: Size of the last dimension of the input and output.
        head_num: Number of attention heads; must be positive and divide
            ``embedding_dim`` evenly.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide
            ``embedding_dim``.
    """

    def __init__(self, embedding_dim: int, head_num: int):
        super().__init__()
        # Fail early with a clear message; otherwise the mismatch only shows
        # up later as an opaque shape error inside ``view`` in ``forward``.
        if head_num <= 0:
            raise ValueError(f"head_num must be positive, got {head_num}")
        if embedding_dim % head_num != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by "
                f"head_num ({head_num})"
            )

        self.q = nn.Linear(embedding_dim, embedding_dim)
        self.k = nn.Linear(embedding_dim, embedding_dim)
        self.v = nn.Linear(embedding_dim, embedding_dim)
        # Feature size handled by each individual head.
        self.head_dim = embedding_dim // head_num
        self.head_num = head_num

        # Output projection applied after the heads are concatenated.
        self.proj = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, embedding_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, embedding_dim)``.
        """
        batch_size, seq_length, _ = x.shape
        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)
        # Split the feature axis into heads and move the head axis forward so
        # every head is processed in parallel:
        # (b, seq_len, embedding_dim) -> (b, head_num, seq_len, head_dim)
        Q = Q.view(batch_size, seq_length, self.head_num, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_length, self.head_num, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_length, self.head_num, self.head_dim).transpose(1, 2)
        # Scaled dot-product self-attention weights, normalised over the key
        # axis: (b, head_num, seq_len, seq_len)
        attn = F.softmax((Q @ K.transpose(2, 3)) / self.head_dim ** 0.5, dim=-1)
        # Diagnostic output: shape of the attention-weight tensor.
        print(attn.shape)
        # ``@`` is a batched matrix multiply here: it broadcasts over the
        # leading (b, head_num) axes and multiplies the last two.
        output = attn @ V
        # Concatenate the heads:
        # (b, head_num, seq_len, head_dim) -> (b, seq_len, head_num, head_dim)
        # -> (b, seq_len, embedding_dim)
        output = output.transpose(1, 2).reshape(batch_size, seq_length, -1)
        output = self.proj(output)
        return output


# Smoke test: run one random batch through the module and print the output
# shape. The input width and head count are taken from the variables below so
# the demo stays consistent if either value is changed.
embedding_dim = 768
head_num = 8
dummy_input = torch.randn(1, 196, embedding_dim)
multi_head_self_attention = MultiHeadSelfAttention(embedding_dim=embedding_dim, head_num=head_num)
output = multi_head_self_attention(dummy_input)
print(output.shape)

torch.Size([1, 8, 196, 196])
torch.Size([1, 196, 768])
