In [1]:
import torch
import torch.nn as nn
import math
import copy

In [2]:
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    The scores ``query @ key^T`` are divided by ``sqrt(d_k)`` (``d_k`` being
    the size of the last dimension of ``query``) so that large feature sizes
    do not push the softmax into its saturated, small-gradient region.

    Args:
        query: Tensor of shape ``(..., n_q, d_k)``.
        key: Tensor of shape ``(..., n_k, d_k)``.
        value: Tensor of shape ``(..., n_k, d_v)``.
        mask: Optional tensor broadcastable to ``(..., n_q, n_k)``. Positions
            where ``mask == 0`` are excluded from attention (e.g. padding).
        dropout: Optional callable (typically ``nn.Dropout``) applied to the
            attention probabilities.

    Returns:
        A tuple ``(weighted_value, prob)``: the attention-weighted sum of
        ``value`` with shape ``(..., n_q, d_v)`` and the attention
        probabilities with shape ``(..., n_q, n_k)``.
    """
    d_k = query.size(-1)
    # One score per (query position, key position) pair.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # `mask` is a tensor, so test for presence explicitly: the truth value of
    # a multi-element tensor is ambiguous and raises a RuntimeError.
    if mask is not None:
        # A large negative score becomes ~0 probability after the softmax.
        scores = scores.masked_fill(mask == 0, -1e9)
    # Normalise over the key axis so each query's weights sum to 1.
    prob = scores.softmax(dim=-1)
    if dropout is not None:
        prob = dropout(prob)

    weighted_query = torch.matmul(prob, value)

    return weighted_query, prob

In [3]:
def clones(module, N):
    """Build an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    Each copy is made with ``copy.deepcopy``, so the entries start with the
    same parameter values as ``module`` but do not share storage with it or
    with each other.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected, split into ``h`` heads of size
    ``d_k = d_model // h``, attended per head via ``attention``, merged back
    to ``d_model`` features and passed through a final linear layer.

    Args:
        h: Number of attention heads; must divide ``d_model``.
        d_model: Feature size of the inputs and of the output.
        dropout: Dropout probability applied to the attention probabilities.

    Attributes:
        att_prob: Attention probabilities from the most recent ``forward``
            call, or ``None`` before the first call.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0, "d_model must be divisible by h"

        self.d_model = d_model
        self.d_k = d_model // h
        self.h = h
        self.dropout = nn.Dropout(p=dropout)
        # Four d_model -> d_model projections: query, key, value, and the
        # final output projection. Held in a ModuleList so that their
        # parameters are registered with the module.
        self.modulelist = clones(nn.Linear(d_model, d_model), N=4)
        self.att_prob = None

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: Tensor of shape ``(batch, n_q, d_model)``.
            key: Tensor of shape ``(batch, n_k, d_model)``.
            value: Tensor of shape ``(batch, n_k, d_model)``.
            mask: Optional tensor; positions equal to 0 are not attended to.
                Presumably ``(batch, 1, n_k)`` or ``(batch, n_q, n_k)`` --
                confirm against the caller.

        Returns:
            Tensor of shape ``(batch, n_q, d_model)``.
        """
        # `mask` is a tensor: test for presence, not truthiness.
        if mask is not None:
            # Insert a head axis so the same mask broadcasts over all heads.
            mask = mask.unsqueeze(1)
        num_batch = query.size(0)

        # Project each input with its own linear layer (zip stops after the
        # first three layers), then reshape
        # (batch, n, d_model) -> (batch, h, n, d_k).
        query, key, value = (
            lin(x).view(num_batch, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.modulelist, (query, key, value))
        )

        # Attend using the projected per-head tensors, not the raw inputs.
        weighted_query, self.att_prob = attention(
            query, key, value, mask=mask, dropout=self.dropout
        )

        # Concatenate the heads: (batch, h, n_q, d_k) -> (batch, n_q, d_model).
        weighted_query = (
            weighted_query.transpose(1, 2)
            .contiguous()
            .view(num_batch, -1, self.h * self.d_k)
        )

        return self.modulelist[-1](weighted_query)
        
            