In [1]:
import math
import torch
import torch.nn as nn


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``nums_head`` heads of width ``hidden_dim // nums_head``, attended
    independently per head, and the heads are concatenated again.

    Args:
        hidden_dim: Width of the input (and output) feature dimension.
        nums_head: Number of attention heads; must divide ``hidden_dim``.
        bias: Whether the query/key/value projections carry a bias term.
        drop_rate: Dropout probability applied to the attention weights.
        reflect_matrix: If true, apply a final ``hidden_dim -> hidden_dim``
            linear projection to the concatenated heads; otherwise return
            the concatenated heads directly.

    Raises:
        AssertionError: If ``hidden_dim`` is not divisible by ``nums_head``.
    """

    def __init__(self,hidden_dim,nums_head,bias=False,drop_rate=0.1,reflect_matrix=True)->None:
        super().__init__()

        # Validate before deriving head_dim so a bad configuration fails
        # with a clear message instead of silently truncating the head width.
        assert hidden_dim % nums_head == 0, (
            f"hidden_dim ({hidden_dim}) must be divisible by nums_head ({nums_head})"
        )

        self.hidden_dim = hidden_dim
        self.nums_head = nums_head
        self.head_dim = hidden_dim // nums_head
        # Alias kept for backward compatibility with code that reads `num_heads`.
        self.num_heads = nums_head

        self.query = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.key = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.value = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.att_dropout = nn.Dropout(drop_rate)

        if reflect_matrix:
            # NOTE: the output projection uses nn.Linear's default bias
            # setting; it is not controlled by the `bias` argument.
            self.outputs = nn.Linear(hidden_dim, hidden_dim)
        else:
            self.outputs = None

    def forward(self,x,mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(b, s, hidden_dim)``.
            mask: Optional tensor broadcastable to ``(b, nums_head, s, s)``.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(b, s, hidden_dim)``. A query position whose
            keys are all masked receives zero attention weights (instead of
            NaN), so its pre-projection output is all zeros.
        """
        b, s, _ = x.size()

        def split_heads(projected):
            # (b, s, hidden_dim) -> (b, nums_head, s, head_dim)
            return projected.view(b, s, self.nums_head, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # attention_score.size() = (b, nums_head, s, s)
        attention_score = q @ k.transpose(-1, -2) / math.sqrt(self.head_dim)

        fully_masked = None
        if mask is not None:
            blocked = mask == 0
            # Rows in which every key is blocked would become all -inf and
            # make softmax return NaN (also poisoning gradients). Leave
            # those rows unfilled here and zero their weights after softmax.
            fully_masked = blocked.all(dim=-1, keepdim=True)
            attention_score = attention_score.masked_fill(
                blocked & ~fully_masked, float('-inf')
            )

        attention_weight = torch.softmax(attention_score, dim=-1)
        if fully_masked is not None:
            attention_weight = attention_weight.masked_fill(fully_masked, 0.0)
        attention_weight = self.att_dropout(attention_weight)

        # context.size() = (b, nums_head, s, head_dim)
        context = attention_weight @ v

        # Merge the heads back into one feature dimension:
        # (b, nums_head, s, head_dim) -> (b, s, nums_head, head_dim) -> (b, s, hidden_dim)
        output_mid = context.transpose(1, 2).contiguous().view(b, s, -1)

        if self.outputs is not None:
            return self.outputs(output_mid)
        return output_mid

In [3]:
if __name__ == "__main__":
    # Smoke test: run a random batch through the layer and print the output shape.
    # Per-sample key visibility (batch=3, seq=2); entries equal to 0 are the
    # positions the layer masks out.
    attention_mask = torch.tensor([[0, 1], [0, 0], [1, 0]])
    # (3, 2) -> (3, 1, 1, 2) -> (3, 8, 2, 2): one copy per head and per query position.
    attention_mask = attention_mask[:, None, None, :].expand(3, 8, 2, 2)

    x = torch.rand(3, 2, 128)  # (batch, seq, hidden_dim)
    net = MultiHeadAttention(128, 8)
    print(net(x, attention_mask).shape)

torch.Size([3, 2, 128])
