In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [23]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries are projected into the key space so that queries and keys share
    the same per-head depth (required for the Q @ K^T product). Values keep
    their own per-head depth, and the concatenated heads are mapped back to
    ``query_dim`` by the output layer. When ``query_dim == key_dim ==
    value_dim`` every layer has the same shape as a plain square projection.

    Args:
        query_dim: Feature size of the incoming queries and of the output.
        key_dim: Feature size of the incoming keys.
        value_dim: Feature size of the incoming values.
        num_heads: Number of attention heads; must divide all three sizes.
        verbose: When True (default) shape-tracing messages are printed on
            every forward pass. Pass False to silence them.
    """

    def __init__(self, query_dim, key_dim, value_dim, num_heads, verbose=True):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.verbose = verbose

        assert query_dim % num_heads == 0, "query_dim must be divisible by num_heads"
        assert key_dim % num_heads == 0, "key_dim must be divisible by num_heads"
        assert value_dim % num_heads == 0, "value_dim must be divisible by num_heads"

        # Per-head depth shared by queries and keys.
        self.depth = key_dim // num_heads
        # Per-head depth of the values; may differ from the query/key depth.
        self.value_depth = value_dim // num_heads

        # Queries are projected to key_dim so Q and K can be split with the
        # same depth; otherwise view() would silently fold feature elements
        # into the sequence axis whenever query_dim != key_dim.
        self.query_layer = nn.Linear(query_dim, key_dim)
        self.key_layer = nn.Linear(key_dim, key_dim)
        self.value_layer = nn.Linear(value_dim, value_dim)
        # Concatenated heads carry value_dim features; map them to query_dim.
        self.output_layer = nn.Linear(value_dim, query_dim)

    def _log(self, *args):
        """Print a tracing message only when verbose mode is enabled."""
        if self.verbose:
            print(*args)

    def split_heads(self, x, batch_size, depth=None):
        """Reshape (batch, seq, heads * depth) into (batch, heads, seq, depth).

        Args:
            x: Projected tensor whose last dimension equals
                ``num_heads * depth``.
            batch_size: Size of the leading batch dimension.
            depth: Per-head feature size; defaults to ``self.depth``.

        Returns:
            The tensor viewed per head, with the head axis moved in front of
            the sequence axis.
        """
        if depth is None:
            depth = self.depth
        self._log(f"Before split: {x.shape}")
        # Split the feature dimension into (num_heads, depth).
        x = x.view(batch_size, -1, self.num_heads, depth)
        x = x.transpose(1, 2)
        self._log(f"After split: {x.shape}")
        return x

    def forward(self, query, keys, values):
        """Compute multi-head attention.

        Args:
            query: Tensor of shape (batch, query_len, query_dim).
            keys: Tensor of shape (batch, key_len, key_dim).
            values: Tensor of shape (batch, key_len, value_dim).

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            (batch, query_len, query_dim) and ``attention_weights`` has shape
            (batch, num_heads, query_len, key_len).
        """
        batch_size = query.size(0)

        self._log("Batch size is ", batch_size)

        query = self.query_layer(query)
        keys = self.key_layer(keys)
        values = self.value_layer(values)

        self._log("Split Query.")
        query = self.split_heads(query, batch_size)
        self._log("Split Key.")
        keys = self.split_heads(keys, batch_size)
        self._log("Split Value.")
        values = self.split_heads(values, batch_size, self.value_depth)

        # Attention(Q, K, V) = softmax(Q @ K^T / sqrt(depth)) @ V
        attention_scores = torch.matmul(query, keys.transpose(-2, -1)) / (self.depth ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        weighted_values = torch.matmul(attention_weights, values)

        # Merge heads back: (batch, heads, q_len, v_depth) -> (batch, q_len, value_dim).
        weighted_values = weighted_values.transpose(1, 2).contiguous().view(batch_size, -1, self.value_dim)
        output = self.output_layer(weighted_values)

        return output, attention_weights

In [29]:
# Example
# Seed the RNG so the layer weights and the random inputs (and therefore the
# printed output) are reproducible across kernel restarts.
torch.manual_seed(0)

multi_head_attention = MultiHeadAttention(query_dim=6, key_dim=6, value_dim=6, num_heads=2)
query = torch.randn(1, 3, 6)  # 3 query positions of size 6: (batch_size=1, query_len=3, query_dim=6)
keys = torch.randn(1, 5, 6)   # 5 keys of size 6: (batch_size=1, key_len=5, key_dim=6)
values = torch.randn(1, 5, 6) # 5 values of size 6: (batch_size=1, value_len=5, value_dim=6)

output, weights = multi_head_attention(query, keys, values)
print("多头注意力输出:", "输出维度：",output.shape, "\n", output)
print("多头注意力权重:", "权重维度：", weights.shape, "\n", weights)

Batch size is  1
Split Query.
Before split: torch.Size([1, 3, 6])
After split: torch.Size([1, 2, 3, 3])
Split Key.
Before split: torch.Size([1, 5, 6])
After split: torch.Size([1, 2, 5, 3])
Split Value.
Before split: torch.Size([1, 5, 6])
After split: torch.Size([1, 2, 5, 3])
多头注意力输出: 输出维度： torch.Size([1, 3, 6]) 
 tensor([[[-0.1535, -0.4247, -0.1395, -0.4718,  0.1190,  0.3651],
         [-0.1183, -0.3199, -0.1579, -0.4439,  0.1964,  0.3679],
         [-0.0967, -0.4078, -0.1471, -0.4125,  0.1205,  0.3756]]],
       grad_fn=<ViewBackward0>)
多头注意力权重: 权重维度： torch.Size([1, 2, 3, 5]) 
 tensor([[[[0.0768, 0.2107, 0.2565, 0.1368, 0.3192],
          [0.3251, 0.2179, 0.1229, 0.1612, 0.1728],
          [0.0869, 0.2524, 0.2104, 0.1082, 0.3421]],

         [[0.2062, 0.1938, 0.1699, 0.1516, 0.2784],
          [0.2344, 0.2726, 0.1321, 0.1613, 0.1997],
          [0.2618, 0.2863, 0.0658, 0.0781, 0.3080]]]],
       grad_fn=<SoftmaxBackward0>)
