In [None]:
# Side length of the square lower-triangular (causal) mask buffer that
# MaskedMultiheadAttention registers when constructed with mask=True.
# NOTE(review): `_` looks like an unfilled placeholder -- presumably this
# should be a concrete positive integer; confirm before running.
MAX_LEN = _
class MaskedMultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional masking.

    Layer sizes and the attention dropout rate are read from the
    module-level ``args`` object (``nhid_tran``, ``nhead``, ``attn_pdrop``).

    Args:
        mask: When true, register a lower-triangular ``MAX_LEN x MAX_LEN``
            buffer named ``mask``. It is applied as a causal mask whenever
            ``forward`` is called without an explicit ``mask`` argument.
    """

    def __init__(self, mask=False):
        super().__init__()
        assert args.nhid_tran % args.nhead == 0
        self.key = nn.Linear(args.nhid_tran, args.nhid_tran)
        self.query = nn.Linear(args.nhid_tran, args.nhid_tran)
        self.value = nn.Linear(args.nhid_tran, args.nhid_tran)
        # regularization applied to the attention probabilities
        self.attn_drop = nn.Dropout(args.attn_pdrop)
        # output projection
        self.proj = nn.Linear(args.nhid_tran, args.nhid_tran)
        # causal mask to ensure that attention is only applied to the left
        # in the input sequence
        if mask:
            self.register_buffer("mask", torch.tril(torch.ones(MAX_LEN, MAX_LEN)))
        self.nhead = args.nhead
        self.d_k = args.nhid_tran // args.nhead

    def _split_heads(self, x):
        """Reshape ``(B, T, nhead * d_k)`` into ``(B, nhead, T, d_k)``."""
        return x.reshape(x.shape[0], x.shape[1], self.nhead, -1).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input of shape ``(B, T_q, nhid_tran)``.
            k: Key input of shape ``(B, T, nhid_tran)``.
            v: Value input of shape ``(B, T, nhid_tran)``.
            mask: Optional ``(B, T)`` tensor; key positions where it equals
                0 are excluded from attention. When omitted, the causal
                buffer is used if the module was built with ``mask=True``;
                otherwise no masking is applied.

        Returns:
            Tensor of shape ``(B, T_q, nhid_tran)``.
        """
        q = self._split_heads(self.query(q))
        k = self._split_heads(self.key(k))
        v = self._split_heads(self.value(v))

        # scores.shape = (B, nhead, T_q, T)
        scores = torch.matmul(q, k.transpose(-1, -2)) / self.d_k ** 0.5

        if mask is not None:
            # NOTE(review): an explicit mask replaces the causal buffer
            # rather than being combined with it (original behaviour kept).
            # (B, T) --> (B, 1, 1, T); broadcasts over heads and queries.
            attn_mask = mask[:, None, None, :]
        else:
            # The buffer only exists when the module was built with
            # mask=True; without it there is simply nothing to mask.
            causal = getattr(self, "mask", None)
            if causal is None:
                attn_mask = None
            else:
                # (T_q, T) broadcasts over batch and heads.
                attn_mask = causal[:scores.shape[-2], :scores.shape[-1]]

        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float("-inf"))

        weights = self.attn_drop(torch.softmax(scores, dim=-1))
        # (B, nhead, T_q, d_k) --> (B, T_q, nhead, d_k) --> (B, T_q, nhid_tran)
        merged = torch.matmul(weights, v).transpose(1, 2)
        merged = merged.reshape(merged.shape[0], merged.shape[1], -1)
        return self.proj(merged)
