# In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are three separate learned linear projections
    of the same input, each mapping ``feature_dim`` features to
    ``feature_dim`` features.
    """

    def __init__(self, feature_dim):
        """Create the three projections and the score scaling factor.

        Args:
            feature_dim: Size of the input's last dimension; also the output
                size of every projection.
        """
        super().__init__()
        self.feature_dim = feature_dim
        # Raw scores are multiplied by 1/sqrt(feature_dim) before the softmax.
        self.scale = 1.0 / (feature_dim ** 0.5)
        # Layers are created in query, key, value order so that parameter
        # initialisation under a fixed seed stays the same.
        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)

    def forward(self, x):
        """Apply self-attention over the second-to-last dimension of ``x``.

        Args:
            x: Tensor with at least two dimensions whose last dimension has
                size ``feature_dim``.

        Returns:
            Tensor with the same shape as ``x``: for each position, the
            softmax-weighted sum of the value projections of all positions.
        """
        queries, keys, values = self.query(x), self.key(x), self.value(x)

        # Pairwise query/key similarities, scaled, then normalised so that
        # the weights along the last dimension sum to one.
        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        return weights @ values

# Parameters: five parts of 5825 rows each, 128 features per row.
num_matrices, num_rows, feature_dim = 5, 5825, 128
total_rows = num_matrices * num_rows

# Create the model instance.
attention_layer = SimpleAttention(feature_dim)

# Create a (total_rows, feature_dim) = 29125 x 128 random matrix.
large_matrix = torch.randn(total_rows, feature_dim)

# Apply the attention layer to the whole matrix in one call.
# NOTE(review): every row attends to all total_rows rows, so the intermediate
# score matrix is total_rows x total_rows and rows from different parts mix.
# Confirm this is intended and that memory allows it.
attention_result = attention_layer(large_matrix)

# Suppose we want to split the result back into the original five parts and
# average them element-wise.
split_attention_results = attention_result.view(num_matrices, num_rows, -1)
averaged_result = split_attention_results.mean(dim=0)

# Check the shape of the result.
print(averaged_result.shape)

# Expected output: torch.Size([5825, 128])
