# In [None]:
import torch
import torch.nn as nn
import numpy as np

"""掩码张量"""


def subsequent_mask(size):
    """Build a mask that hides subsequent (future) positions.

    Args:
        size: Length of the two trailing (square) dimensions of the mask.

    Returns:
        A uint8 tensor of shape ``(1, size, size)`` holding 1 on and below
        the main diagonal and 0 strictly above it.
    """
    # The lower triangle (diagonal included) is exactly the complement of
    # the strict upper triangle, so it can be built directly.
    visible = np.tril(np.ones((1, size, size), dtype='uint8'))
    return torch.from_numpy(visible)


"""注意力机制"""
import math
import torch.nn.functional as F


def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Args:
        query: Tensor whose last dimension is the feature size ``d_k``.
        key: Tensor whose last two dimensions are swapped and multiplied
            with ``query`` to produce the attention scores.
        value: Tensor that is weighted by the attention probabilities.
        mask: Optional tensor broadcastable to the score tensor. Positions
            where ``mask == 0`` are excluded from attention.
        dropout: Optional callable (e.g. ``nn.Dropout``) applied to the
            attention probabilities.

    Returns:
        A tuple ``(output, p_attn)``: the attention-weighted values and the
        attention probabilities (after dropout, if given).
    """
    # Size of the last dimension of ``query``; used to scale the scores.
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Fill masked positions with a very large negative number so that
        # softmax assigns them (numerically) zero weight. The previous
        # value, -1e-9, is practically 0 and therefore masked nothing.
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)

    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


"""克隆"""
import copy


def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    Args:
        module: The module to duplicate.
        N: Number of copies to create.

    Returns:
        An ``nn.ModuleList`` of length ``N``; each entry is a separate
        ``copy.deepcopy`` of ``module``, so no parameters are shared.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

# In [None]:
class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on top of ``attention``.

    The inputs are linearly projected, split into ``head`` slices of size
    ``d_k = embedding_dim // head``, attended per head, concatenated back
    to ``embedding_dim`` features and passed through an output projection.
    """

    def __init__(self, head, embedding_dim, dropout=0.1):
        """Set up the projections and dropout.

        Args:
            head: Number of attention heads; must divide ``embedding_dim``.
            embedding_dim: Feature size of the inputs and of the output.
            dropout: Dropout probability applied to the attention weights.

        Raises:
            AssertionError: If ``embedding_dim`` is not divisible by ``head``.
        """
        super().__init__()
        assert embedding_dim % head == 0
        # Per-head feature size.
        self.d_k = embedding_dim // head
        self.head = head
        # Four identically shaped projections: the first three are applied
        # to query, key and value; the last one is the output projection.
        self.linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)
        # Attention weights from the most recent forward pass.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query, key, value: Tensors whose first dimension is the batch
                size and whose last dimension is ``embedding_dim``
                (presumably ``(batch, seq_len, embedding_dim)`` - confirm
                against callers).
            mask: Optional mask; a head dimension is inserted at index 1 so
                the same mask is broadcast over all heads.

        Returns:
            Tensor of shape ``(batch, -1, embedding_dim)`` produced by the
            output projection.
        """
        if mask is not None:
            # Insert a head axis so the mask broadcasts across heads.
            mask = mask.unsqueeze(1)
        batch_size = query.size(0)
        # Project, split the feature axis into (head, d_k), then move the
        # head axis in front of the sequence axis:
        # (batch, seq, head, d_k) -> (batch, head, seq, d_k).
        # Without this transpose the attention scores would be computed
        # between heads instead of between sequence positions.
        query, key, value = [
            linear(x).view(batch_size, -1, self.head, self.d_k).transpose(1, 2)
            for linear, x in zip(self.linears, (query, key, value))
        ]
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # Undo the head split: (batch, head, seq, d_k) -> (batch, seq, head * d_k).
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.head * self.d_k)
        return self.linears[-1](x)