In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class MultiGroupAttention(nn.Module):
    def __init__(self, query_dim, key_dim, value_dim, num_heads, num_groups):
        super(MultiGroupAttention, self).__init__()
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.value_dim = value_dim

        assert query_dim % (num_heads * num_groups) == 0
        assert key_dim % (num_heads * num_groups) == 0
        assert value_dim % (num_heads * num_groups) == 0

        self.depth = query_dim // (num_groups * num_heads)

        # 定义线性层用于对查询、键和值进行变换
        self.query_layer = nn.Linear(query_dim, query_dim)
        self.key_layer = nn.Linear(key_dim, key_dim)
        self.value_layer = nn.Linear(value_dim, value_dim)
        self.output_layer = nn.Linear(query_dim, query_dim)

    def split_heads_and_groups(self, x, batch_size):
        # 将张量的形状调整为 (batch_size, seq_length, num_groups, num_heads, depth)
        x = x.view(batch_size, -1, self.num_groups, self.num_heads, self.depth)
        x.permute(0, 2, 3, 1, 4)
        # 打印转换后的形状
        print("Shape of x after permutation:", x.shape)
        # 交换维度，使其形状为 (batch_size, num_groups, num_heads, seq_length, depth)
        return x

    def forward(self, query, keys, values):
        batch_size = query.size(0)

        query = self.query_layer(query)
        keys = self.key_layer(keys)
        values = self.value_layer(values)

        # 计算缩放点积注意力得分
        attention_scores = torch.matmul(query, keys.transpose(-2, -1)) / (self.depth ** 0.5)
        # 通过softmax计算注意力权重
        attention_weights = F.softmax(attention_scores, dim=-1)
         # 用注意力权重加权求和值
        weighted_values = torch.matmul(attention_weights, values)
        # 交换维度并调整形状，使其变回 (batch_size, seq_length, query_dim)
        weighted_values = weighted_values.transpose(1, 2).contiguous().view(batch_size, -1, self.query_dim)

        # 通过线性层输出最终结果
        output = self.output_layer(weighted_values)

        return output, attention_weights

In [4]:
# 示例
multi_group_attention = MultiGroupAttention(query_dim=12, key_dim=12, value_dim=12, num_heads=2, num_groups=2)
query = torch.randn(1, 3, 12)  # 单个查询，维度为6
keys = torch.randn(1, 5, 12)   # 5个键，维度为6
values = torch.randn(1, 5, 12) # 5个值，维度为6

output, weights = multi_group_attention(query, keys, values)
print("多组注意力输出:", output)
print("多组注意力权重:", weights)

多组注意力输出: tensor([[[-0.2787,  0.2405,  0.0251, -0.3208,  0.1949, -0.0251,  0.0369,
           0.0381, -0.1912, -0.1752,  0.3274,  0.1324],
         [-0.2174, -0.0020, -0.2271, -0.0426,  0.2986,  0.4000,  0.3628,
          -0.2052,  0.1341,  0.0652,  0.2452,  0.1521],
         [-0.2798,  0.3276, -0.2936, -0.2366, -0.0550, -0.0022,  0.1846,
           0.0132, -0.1206, -0.3476,  0.2301,  0.2930]]],
       grad_fn=<ViewBackward0>)
多组注意力权重: tensor([[[0.1131, 0.0872, 0.1408, 0.3662, 0.2927],
         [0.2915, 0.2067, 0.2004, 0.1587, 0.1427],
         [0.1323, 0.1613, 0.3088, 0.1702, 0.2274]]],
       grad_fn=<SoftmaxBackward0>)
