In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Batched matrix multiply: (10, 3, 4) @ (10, 4, 5) -> (10, 3, 5), one product per batch entry.
mat1 = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = mat1.bmm(mat2)
for label, tensor in (("mat1", mat1), ("mat2", mat2), ("res", res)):
    print(f"{label}.shape: {tensor.shape}")

mat1.shape: torch.Size([10, 3, 4])
mat2.shape: torch.Size([10, 4, 5])
res.shape: torch.Size([10, 3, 5])


# 注意力机制（规则一：Q、K 拼接 → 线性层 → softmax 得到注意力权重）

In [2]:
class Attn(nn.Module):
    """Concatenation-based attention ("rule one" in this notebook).

    ``forward`` performs three steps:
      1. concatenate Q[0] and K[0] along the last axis, project them with a
         linear layer to ``value_size1`` scores, and softmax the scores into weights;
      2. batch-multiply the weights with V;
      3. concatenate Q[0] with that result and project it to ``output_size``.

    Args:
        query_size: last dimension of Q.
        key_size: last dimension of K.
        value_size1: second dimension of V (also the number of attention weights).
        value_size2: last dimension of V; V is expected as (1, value_size1, value_size2).
        output_size: last dimension of the returned output.
    """

    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        super().__init__()
        # Remember the configured sizes.
        self.query_size, self.key_size = query_size, key_size
        self.value_size1, self.value_size2 = value_size1, value_size2
        self.output_size = output_size

        # Step-1 projection: [Q; K] -> one score per row of V.
        # The two Linear layers are created in this order so seeded init is reproducible.
        self.attn = nn.Linear(query_size + key_size, value_size1)
        # Step-3 projection: [Q; attention context] -> output features.
        self.attn_combin = nn.Linear(query_size + value_size2, output_size)

    def forward(self, Q, K, V):
        """Apply attention and return ``(output, attn_weight)``.

        Q, K and V are assumed to be 3-D. Only the first batch entry of Q and K
        is used (``Q[0]``, ``K[0]``), and V must have batch size 1 for the
        ``bmm`` below. So in practice Q is (1, L, query_size), K is
        (1, L, key_size) and V is (1, value_size1, value_size2).

        Returns:
            output: tensor of shape (1, L, output_size).
            attn_weight: softmax-normalised weights of shape (L, value_size1).
        """
        query, key = Q[0], K[0]

        # Step 1: score each row of V from the concatenated query/key, then normalise.
        scores = self.attn(torch.cat([query, key], dim=1))
        weights = scores.softmax(dim=1)

        # Step 2: weighted combination of V's rows -> (1, L, value_size2).
        context = weights.unsqueeze(0).bmm(V)

        # Step 3: fuse the query with the context, project, and restore the batch axis.
        fused = torch.cat([query, context[0]], dim=1)
        return self.attn_combin(fused).unsqueeze(0), weights

In [4]:
# Sizes for the demo; the input tensors below are derived from these so that
# changing one constant keeps Q/K/V consistent with the layers inside Attn.
query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)

# Attn only uses batch entry 0 of Q/K and needs V with batch size 1.
Q = torch.randn(1, 1, query_size)
K = torch.randn(1, 1, key_size)
V = torch.randn(1, value_size1, value_size2)

output = attn(Q, K, V)  # (attention output, attention weights)
attn_output, attn_weight = output

# Guard the expected shapes so a bad edit to the sizes above fails loudly here.
assert attn_output.shape == (1, 1, output_size)
assert attn_weight.shape == (1, value_size1)

print(attn_output.size())
print(attn_weight.size())

torch.Size([1, 1, 64])
torch.Size([1, 32])
