In [1]:
import math

import torch
from torch import nn
from torch import functional as F

In [None]:
def attention(query, key, value, dropout=None):
    """Compute scaled dot-product attention.

    :param query: query tensor, one query per row (n * m)
    :param key: key tensor, one key vector per row (k * m)
    :param value: value tensor with one row per key (k * 1)
    :param dropout: optional callable applied to the attention weights
        right after the softmax; skipped when ``None``
    :return: tuple ``(attn, p_attn)`` - the weighted values (n * 1) and the
        attention weights (n * k) they were computed from
    """
    # The scaling factor is the square root of the size of the last
    # (feature) dimension, read from the query.
    scale = math.sqrt(query.size(-1))

    # Swap the last two axes of the keys: (n * m) @ (m * k) -> (n * k).
    key_t = key.transpose(-2, -1)
    logits = (query @ key_t) / scale

    # Normalise each row of scores into a distribution over the keys.
    weights = logits.softmax(dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    # Weighted sum of the values: (n * k) @ (k * 1) -> (n * 1).
    output = weights @ value
    return output, weights

def gen_mask(max_seqlen, *, device=None, dtype=None):
    """Build an additive causal (look-ahead) mask.

    Entries strictly above the main diagonal are ``-inf``; the diagonal and
    everything below it are ``0``. Adding the mask to a score matrix before a
    softmax therefore removes the contribution of the masked positions.

    :param max_seqlen: side length of the square mask
    :param device: optional device to allocate the mask on; ``None`` keeps
        the torch default
    :param dtype: optional floating-point dtype for the mask; ``None`` keeps
        the torch default
    :return: tensor of shape ``(1, max_seqlen, max_seqlen)``
    """
    mask = torch.full(
        (1, max_seqlen, max_seqlen),
        float("-inf"),
        dtype=dtype,
        device=device,
    )
    # Keep only the upper triangle; diagonal=1 excludes the main diagonal,
    # so the diagonal and the lower triangle are zeroed out.
    return torch.triu(mask, diagonal=1)

def masked(scores: torch.Tensor, seqlen, max_seqlen):
    """Apply a causal mask to attention scores and softmax the result.

    A ``(1, max_seqlen, max_seqlen)`` mask is generated with ``gen_mask`` and
    cropped to its leading ``seqlen * seqlen`` corner before being added to
    ``scores``, so ``scores`` must broadcast against ``(1, seqlen, seqlen)``.

    :param scores: raw attention scores
    :param seqlen: actual sequence length covered by ``scores``
    :param max_seqlen: size used to generate the full mask
    :return: softmax over the last dimension of the masked scores
    """
    # gen_mask allocates on the default device; move the mask next to
    # `scores` so the addition does not fail with a device mismatch when the
    # scores live on an accelerator. The dtype is left untouched so the usual
    # type promotion of the addition is unchanged.
    mask = gen_mask(max_seqlen).to(scores.device)
    scores = scores + mask[:, :seqlen, :seqlen]
    scores = torch.softmax(scores, dim=-1)
    return scores
