In [1]:
import torch
from torch import nn
# src, tgt = tokenize_nmt(preprocess_nmt(read_data_nmt()),num_examples=600)
# src_vocab = Vocal(src, min_feq=2,
#                   reserved_tokens=['<pad>', '<bos>', '<eos>'])
# tgt_vocab = Vocal(tgt, min_feq=2,
#                   reserved_tokens=['<pad>', '<bos>', '<eos>'])
# src_data, src_valid = build_array_nmt(src, src_vocab, 10)
# tgt_data, tgt_valid = build_array_nmt(tgt, tgt_vocab, 10)
# dataset = torch.utils.data.TensorDataset(src_data, src_valid, tgt_data, tgt_valid)
# ## 训练数据
# train_data = torch.utils.data.DataLoader(dataset=dataset, batch_size=64, shuffle=False)

In [2]:
# First, implement scaled dot-product attention (used below as the per-head scorer).
class DotProductAttention(nn.Module):
    """Scaled dot-product attention with optional masking of key positions.

    Args:
        dropout: dropout probability applied to the attention weights.

    Attributes:
        attention_weights: masked softmax weights from the most recent
            ``forward`` call, shape (batch, num_queries, num_kv_pairs).
    """

    def __init__(self, dropout=0.2):
        super(DotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _masked_softmax(scores, valid_lens):
        """Softmax over the last axis, ignoring key positions >= the valid length.

        Args:
            scores: (batch, num_queries, num_kv_pairs) attention scores.
            valid_lens: None (no masking), a 1-D tensor of shape (batch,)
                giving one length per batch item, or a 2-D tensor of shape
                (batch, num_queries) giving one length per query.
        """
        if valid_lens is None:
            return nn.functional.softmax(scores, dim=-1)
        shape = scores.shape
        if valid_lens.dim() == 1:
            # Every query of a batch item shares that item's valid length.
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        positions = torch.arange(shape[-1], device=scores.device)
        keep = positions[None, :] < valid_lens.to(scores.device)[:, None]
        # A large negative value (not -inf) keeps rows whose valid length is 0
        # finite: they become uniform instead of NaN.
        flat = scores.reshape(-1, shape[-1]).masked_fill(~keep, -1e6)
        return nn.functional.softmax(flat.reshape(shape), dim=-1)

    def forward(self, q, k, v, valid_lens):
        """Attend from queries ``q`` over key/value pairs ``k``/``v``.

        Args:
            q: (batch, num_queries, d) queries.
            k: (batch, num_kv_pairs, d) keys; same last dim as ``q``.
            v: (batch, num_kv_pairs, value_dim) values.
            valid_lens: see ``_masked_softmax``; None disables masking.

        Returns:
            (batch, num_queries, value_dim) attention-weighted values.
        """
        d = q.shape[-1]
        # Scale by sqrt(d) so score variance does not grow with the feature size.
        scores = torch.bmm(q, k.transpose(1, 2)) / d ** 0.5
        self.attention_weights = self._masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), v)

#@save
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Queries, keys and values are linearly projected to ``num_hiddens``
    features. They are split into ``num_heads`` heads of
    ``num_hiddens // num_heads`` features each, and all heads are attended
    in parallel by a single ``DotProductAttention`` call. The heads are then
    concatenated and projected by ``W_o``.

    Args:
        key_size, query_size, value_size: last-dim sizes of k, q and v.
        num_hiddens: projection width; must be divisible by ``num_heads`` so
            the heads can be computed in parallel.
        num_heads: number of attention heads.
        dropout: dropout probability passed to ``DotProductAttention``.
        bias: whether the four linear projections have a bias term.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``num_hiddens``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        # Fail fast here instead of with an opaque reshape error in forward().
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def _split_heads(self, x):
        """(batch, seq, num_hiddens) -> (batch * num_heads, seq, num_hiddens // num_heads)."""
        x = x.reshape(x.shape[0], x.shape[1], self.num_heads, -1)
        x = x.permute(0, 2, 1, 3)
        # Row b * num_heads + h now holds head h of batch item b.
        return x.reshape(-1, x.shape[2], x.shape[3])

    def _merge_heads(self, x):
        """Inverse of ``_split_heads``: (batch * num_heads, seq, d) -> (batch, seq, num_heads * d)."""
        x = x.reshape(-1, self.num_heads, x.shape[1], x.shape[2])
        x = x.permute(0, 2, 1, 3)
        return x.reshape(x.shape[0], x.shape[1], -1)

    def forward(self, q, k, v, valid_lens=None):
        """Return (batch, num_queries, num_hiddens) multi-head attention output."""
        queries = self._split_heads(self.W_q(q))
        keys = self._split_heads(self.W_k(k))
        values = self._split_heads(self.W_v(v))

        if valid_lens is not None:
            # Repeat each batch item's valid length num_heads times along axis 0
            # (first item num_heads times, then the second, and so on) so it
            # lines up with the rows produced by _split_heads.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # The attended values per head, not the attention weights.
        head_outputs = self.attention(queries, keys, values, valid_lens)
        return self.W_o(self._merge_heads(head_outputs))


In [10]:
# Smoke test: multi-head attention keeps the (batch, num_queries, num_hiddens)
# shape of its queries regardless of how many key/value pairs are attended.
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
output = attention(X, Y, Y, valid_lens)
assert output.shape == (batch_size, num_queries, num_hiddens)
output.shape

torch.Size([2, 4, 100])