In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.useful_func import *

In [7]:
def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X为3D batch_size*x*y valid_len 最后一个轴的有效长度
    # 如果valid_len是1d 则len必须等于X最外围的长度
    # 如果valid_len是2d 则 reshape后必须等于 x的0层 *1层长度
    if valid_lens is None:
        return F.softmax(X, dim=-1)
    else:
        shape=X.shape
        if valid_lens.ndim==1:
            ## valid_lens是最后一个轴的长度 因此要扩展以匹配X维度
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
    X=sequence_mask(X.reshape(-1, shape[-1]), valid_lens,value=-1e6)
    return F.softmax(X, dim=-1)
        

In [11]:
x=torch.rand([2,3,5])

In [14]:
masked_softmax(x,valid_lens=torch.tensor([3,2]))

tensor([[0.2861, 0.3829, 0.3310, 0.0000, 0.0000],
        [0.2443, 0.4275, 0.3282, 0.0000, 0.0000],
        [0.3074, 0.3590, 0.3336, 0.0000, 0.0000],
        [0.3905, 0.6095, 0.0000, 0.0000, 0.0000],
        [0.3660, 0.6340, 0.0000, 0.0000, 0.0000],
        [0.4616, 0.5384, 0.0000, 0.0000, 0.0000]])

In [15]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.2686, 0.2477, 0.4837, 0.0000],
        [0.5050, 0.4950, 0.0000, 0.0000],
        [0.3359, 0.2150, 0.2300, 0.2191]])

# 1、加性注意力

In [None]:
class AdditiveAttention(nn.Module):
    """加性注意力"""
    def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):
        super(AdditiveAttention,self).__init__()
        self.W_k = nn.Linear(key_size,num_hiddens,bias=False)
        self.W_q = nn.Linear(query_size,num_hiddens,bias=False)
        self.W_v = nn.Linear(num_hiddens,1,bias=False)
        self.dropout = nn.Dropout(dropout)
    # queries 维度应该是 batch_size * 要查询的数量 * q_size向量长度
    # keys 维度是 batch_size * keys的数量（key-value)键值对 * key向量长度
    # values与key相等 value_size可以不一样
    def forward(self, queries, keys, values, valid_lens): ## valid_len从输入来的 屏蔽掉填充部分
        queries,keys=self.W_q(queries),self.W_k(keys)
        queries=queries.unsqueeze(2)
        keys=keys.unsqueeze(1)
        features = queries + keys
        features = torch.tanh(features)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

In [25]:
x=torch.ones([2,3,5])
y=torch.ones([3,5])

In [26]:
x+y

tensor([[[2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.]],

        [[2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.]]])