In [46]:
import seaborn as sns
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import rand, tensor, Tensor, nn

%matplotlib inline

sns.set_theme(style="white")

In [47]:
# 加性注意力
class AdditiveAttention(nn.Module):
    def __init__(
        self,
        keySize: int,
        hiddenNum: int,
        querySize: int,
        pairNum: int,
    ) -> None:
        super().__init__()

        self.keySize = keySize
        self.querySize = querySize
        self.hiddenNum = hiddenNum
        self.pairNum = pairNum

        self.KLinear = nn.Linear(keySize, hiddenNum, bias = False)
        self.QLinear = nn.Linear(querySize, hiddenNum, bias = False)
        self.VLinear = nn.Linear(hiddenNum, 1, bias = False)

    def forward(self, queries, keys, values) -> Tensor:
        batchSize = queries.size(0)

        # 映射到 [feature, hiddenNum] 的空间上
        # unsqueeze 使两个张量可以进行广播运算
        queries = self.QLinear(queries).unsqueeze(2)
        keys = self.KLinear(keys).unsqueeze(1)
        print(queries.shape)
        print(keys.shape)
        assert queries.shape[-1] == keys.shape[-1], "Dimension error"

        # 利用广播机制 加和 求得评分函数的值
        features = torch.tanh(queries + keys)

        scores = self.VLinear(features).squeeze(-1)
        self.attentionWeights = F.softmax(scores, dim = 1)
        return torch.bmm(self.attentionWeights, values)

In [49]:
batchSize = 16
dictLength = 128
seqLength = 64
featureNum = 64
kFeature, vFeature = 10, 20

net = AdditiveAttention(kFeature, 32, seqLength, dictLength)

# feature -> 经过 embedding 后 每一个词元的编码长度
queries = rand(batchSize, seqLength, featureNum)
keys = rand(batchSize, dictLength, kFeature)
values = rand(batchSize, dictLength, vFeature)


output = net(queries, keys, values)
print(output.shape)

torch.Size([16, 64, 1, 32])
torch.Size([16, 1, 128, 32])
torch.Size([16, 64, 20])
