In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from d2l import torch as d2l

In [2]:
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis.

    Defined in :numref:`sec_attention-scoring-functions`

    Args:
        X: 3D score tensor of shape (batch_size, num_queries, num_keys).
            It is NOT modified (the d2l reference version overwrote masked
            entries of the caller's tensor in place).
        valid_lens: ``None`` to apply a plain softmax; a 1D tensor of shape
            (batch_size,) giving one valid length shared by every query of
            that batch element; or a 2D tensor of shape
            (batch_size, num_queries) giving one valid length per query.

    Returns:
        Tensor of the same shape as ``X``. Along the last axis, positions at
        index >= the valid length get (numerically) zero weight and the
        remaining positions sum to 1.
    """
    if valid_lens is None:
        return F.softmax(X, dim=-1)
    shape = X.shape
    # Expand valid_lens to one length per row of the flattened (-1, num_keys) view.
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    flat = X.reshape(-1, shape[-1])
    positions = torch.arange(shape[-1], dtype=torch.float32, device=X.device)
    keep = positions[None, :] < valid_lens[:, None]
    # masked_fill is out-of-place, so the caller's X is never mutated. A very
    # large negative value exponentiates to 0 inside the softmax.
    masked = flat.masked_fill(~keep, -1e6)
    return F.softmax(masked.reshape(shape), dim=-1)

In [3]:
# Additive attention
class AdditiveAttention(nn.Module):
    """Additive attention.

    Scores every query/key pair as ``w_v(tanh(W_q(q) + W_k(k)))``, turns the
    scores into weights with ``masked_softmax`` (respecting ``valid_lens``),
    and returns the weighted sum of ``values``.

    Args:
        key_size: feature size of each key.
        query_size: feature size of each query.
        num_hiddens: size of the hidden space keys and queries are projected into.
        dropout: dropout probability applied to the attention weights.
        verbose: if True (the default, used by this notebook's walkthrough),
            print intermediate tensor shapes on every forward pass.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, key_size, query_size, num_hiddens, dropout,
                 verbose=True, **kwargs):
        super().__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.verbose = verbose

    def forward(self, queries, keys, values, valid_lens):
        """Return attention-pooled values.

        Args:
            queries: (batch_size, num_queries, query_size).
            keys: (batch_size, num_kv_pairs, key_size).
            values: (batch_size, num_kv_pairs, value_size).
            valid_lens: ``None``, or lengths passed to ``masked_softmax``.

        Returns:
            Tensor of shape (batch_size, num_queries, value_size). The
            pre-dropout weights, shaped (batch_size, num_queries, num_kv_pairs),
            are stored in ``self.attention_weights``.
        """
        queries, keys = self.W_q(queries), self.W_k(keys)
        # Insert singleton axes so broadcasting adds every query to every key:
        # queries: (batch_size, num_queries, 1, num_hiddens)
        # keys:    (batch_size, 1, num_kv_pairs, num_hiddens)
        queries, keys = queries.unsqueeze(2), keys.unsqueeze(1)
        if self.verbose:
            print("q, k升维后的维度：")
            print(queries.shape, keys.shape)
        features = torch.tanh(queries + keys)
        if self.verbose:
            print("特征维度是：", features.shape)
        # w_v has a single output; apply it once (it was previously evaluated
        # twice just to print the shape) and drop that trailing axis.
        projected = self.w_v(features)
        if self.verbose:
            print("v的维度是：", projected.shape)
        # scores: (batch_size, num_queries, num_kv_pairs)
        scores = projected.squeeze(-1)
        if self.verbose:
            print("注意力得分的维度是：", scores.shape)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values: (batch_size, num_kv_pairs, value_size)
        return torch.bmm(self.dropout(self.attention_weights), values)


In [4]:
# Demo: run additive attention on random inputs and inspect shapes and weights.
# Seed first so parameter init and the random inputs (and thus the printed
# weights) are reproducible on Restart & Run All.
torch.manual_seed(0)

KEY_SIZE = 5
QUERY_SIZE = 6
VALUE_SIZE = 4
NUM_HIDDENS = 4

attention = AdditiveAttention(key_size=KEY_SIZE, query_size=QUERY_SIZE,
                              num_hiddens=NUM_HIDDENS, dropout=0.1)
# Evaluation mode disables dropout, so `output` is deterministic in this demo.
attention.eval()

# A batch of queries, keys and values
batch_size = 3
num_queries = 4
num_key_value_pairs = 6
keys = torch.randn((batch_size, num_key_value_pairs, KEY_SIZE))      # (batch_size, num_key_value_pairs, key_size)
values = torch.randn((batch_size, num_key_value_pairs, VALUE_SIZE))  # (batch_size, num_key_value_pairs, value_size)
queries = torch.randn((batch_size, num_queries, QUERY_SIZE))         # (batch_size, num_queries, query_size)

# Every query may attend to all key-value pairs (no padding). Derive the
# length from num_key_value_pairs so the two cannot drift apart.
valid_lens = torch.tensor([num_key_value_pairs] * batch_size)

# Compute the attention output
output = attention(queries, keys, values, valid_lens)
print("输出的形状: ", output.shape)
print("注意力权重: ", attention.attention_weights)

q, k升维后的维度：
torch.Size([3, 4, 1, 4]) torch.Size([3, 1, 6, 4])
特征维度是： torch.Size([3, 4, 6, 4])
v的维度是： torch.Size([3, 4, 6, 1])
注意力得分的维度是： torch.Size([3, 4, 6])
输出的形状:  torch.Size([3, 4, 4])
注意力权重:  tensor([[[0.1598, 0.1438, 0.2024, 0.1870, 0.1553, 0.1516],
         [0.1630, 0.1399, 0.2013, 0.1888, 0.1571, 0.1500],
         [0.1687, 0.1577, 0.1763, 0.1720, 0.1673, 0.1580],
         [0.1690, 0.1518, 0.1838, 0.1763, 0.1667, 0.1524]],

        [[0.2336, 0.1637, 0.1296, 0.1214, 0.1786, 0.1732],
         [0.1831, 0.1673, 0.1618, 0.1580, 0.1666, 0.1633],
         [0.1994, 0.1732, 0.1406, 0.1288, 0.1807, 0.1774],
         [0.2390, 0.1611, 0.1336, 0.1262, 0.1746, 0.1655]],

        [[0.1307, 0.1993, 0.2078, 0.1505, 0.1453, 0.1665],
         [0.1255, 0.2036, 0.2153, 0.1458, 0.1394, 0.1704],
         [0.1239, 0.2010, 0.2096, 0.1491, 0.1450, 0.1715],
         [0.1219, 0.2027, 0.2128, 0.1471, 0.1430, 0.1725]]],
       grad_fn=<SoftmaxBackward0>)
