In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
from torch.nn.parameter import Parameter
import numpy as np

In [8]:
# Implementation of Scaled Dot-Product Attention
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_model)) @ V``.

    Args:
        d_model: Dimensionality of the query/key vectors. Its square root is
            the scaling factor applied to the raw scores (e.g. 64 -> 8).
    """

    def __init__(self,d_model):
        # The original wrote `super(...),__init__()` (comma instead of dot),
        # which raised NameError and never initialised nn.Module.
        super(ScaledDotProductAttention,self).__init__()
        # Scaling factor sqrt(d_model), stored as a plain Python float so no
        # NumPy scalar type is mixed into tensor arithmetic.
        self.temper = float(d_model) ** 0.5

    def forward(self,Q,K,V):
        """Attend over ``V`` using the similarity between ``Q`` and ``K``.

        Args:
            Q: Queries, shape ``(batch, len_q, d_model)``.
            K: Keys, shape ``(batch, len_k, d_model)``.
            V: Values, shape ``(batch, len_k, d_v)``.

        Returns:
            Tensor of shape ``(batch, len_q, d_v)``: for every query position,
            the attention-weighted average of the value vectors.
        """
        # Raw similarity scores, (batch, len_q, len_k), scaled by sqrt(d_model).
        scaled_qk = torch.bmm(Q, K.transpose(1, 2)) / self.temper
        # Normalise over the key axis (last dim) so each query's weights sum
        # to 1; the original normalised over the query axis (dim=1).
        attention_score = F.softmax(scaled_qk, dim=-1)
        # Weights multiply the values from the left:
        # (batch, len_q, len_k) @ (batch, len_k, d_v) -> (batch, len_q, d_v).
        V_attention = torch.bmm(attention_score, V)
        return V_attention

In [10]:
class MutilHeadAttention(nn.Module):
    """Multi-head attention followed by dropout, a residual add and LayerNorm.

    Args:
        model_dim: Size of the input/output feature dimension.
        num_heads: Number of attention heads; must divide ``model_dim``.
        dropout: Dropout probability applied to the output projection.

    Raises:
        ValueError: If ``model_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, model_dim = 512, num_heads = 8,dropout = 0.0):
        super(MutilHeadAttention,self).__init__()
        if num_heads <= 0 or model_dim % num_heads != 0:
            # Otherwise the concatenated heads would not match final_linear's
            # input size and forward() would fail with an opaque shape error.
            raise ValueError(
                "model_dim (%d) must be divisible by num_heads (%d)"
                % (model_dim, num_heads)
            )
        self.per_head_dim = model_dim//num_heads
        self.num_heads = num_heads
        # Linear projections whose output is later split into num_heads chunks.
        self.linear_k = nn.Linear(model_dim, self.per_head_dim*num_heads)
        self.linear_q = nn.Linear(model_dim, self.per_head_dim*num_heads)
        self.linear_v = nn.Linear(model_dim, self.per_head_dim*num_heads)

        # Attention is scaled by sqrt(per_head_dim) (64 -> 8 with the
        # defaults). The original passed an undefined name `d_model` here.
        self.attention_net = ScaledDotProductAttention(self.per_head_dim)

        # Output projection applied to the concatenated heads.
        self.final_linear = nn.Linear(model_dim,model_dim)

        self.dropout = nn.Dropout(dropout)
        # Layer norm applied after the residual connection.
        self.layer_norm = nn.LayerNorm(model_dim)

    def _split_heads(self, x, batch_size):
        """Reshape ``(batch, length, heads*dim)`` to ``(batch*heads, length, dim)``."""
        x = x.view(batch_size, -1, self.num_heads, self.per_head_dim)
        # Move the head axis next to the batch axis *before* flattening; a bare
        # .view() would interleave sequence positions with heads.
        return x.transpose(1, 2).reshape(
            batch_size * self.num_heads, -1, self.per_head_dim
        )

    def _merge_heads(self, x, batch_size):
        """Inverse of ``_split_heads``: back to ``(batch, length, heads*dim)``."""
        x = x.reshape(batch_size, self.num_heads, -1, self.per_head_dim)
        return x.transpose(1, 2).reshape(
            batch_size, -1, self.num_heads * self.per_head_dim
        )

    def forward(self,query,key,value):
        """Apply multi-head attention.

        Args:
            query: Tensor of shape ``(batch, len_q, model_dim)``.
            key: Tensor of shape ``(batch, len_k, model_dim)``.
            value: Tensor of shape ``(batch, len_k, model_dim)``.

        Returns:
            Tensor of shape ``(batch, len_q, model_dim)``.
        """
        # Keep the un-projected query for the residual connection.
        residual = query
        # The original read `k.size(0)`, but no `k` exists in this scope.
        batch_size = query.size(0)

        # Linear projections; each input goes through its own layer (the
        # original routed all three through linear_k).
        key = self.linear_k(key)
        query = self.linear_q(query)
        value = self.linear_v(value)

        # Split into heads: (batch*heads, length, per_head_dim).
        key = self._split_heads(key, batch_size)
        query = self._split_heads(query, batch_size)
        value = self._split_heads(value, batch_size)

        # attention_net is expected to return (batch*heads, len_q, per_head_dim).
        outputs = self.attention_net(query,key,value)

        # Concatenate the heads back into the model dimension.
        outputs = self._merge_heads(outputs, batch_size)

        # Output projection, then dropout.
        outputs = self.final_linear(outputs)
        outputs = self.dropout(outputs)

        # Residual connection followed by layer norm.
        outputs = self.layer_norm(outputs+residual)

        return outputs