# In [1]:
import torch
import torch.nn as nn
import math

# In [16]:
class SimpleDecoder(nn.Module):
    """A single post-norm Transformer decoder layer with causal self-attention.

    The layer runs multi-head causal self-attention followed by a
    position-wise feed-forward network (Linear -> ReLU -> Linear -> Dropout).
    Each sub-block adds a residual connection and then applies LayerNorm.
    A final Linear projection (``proj_ffn``) is applied to the FFN output.

    Args:
        hidden_dim: Size of the model (feature) dimension.
        head_num: Number of attention heads; must evenly divide ``hidden_dim``.
        dropout: Dropout probability used on the attention weights and on the
            FFN output.

    Raises:
        ValueError: If ``head_num`` is not a positive divisor of ``hidden_dim``.
    """

    def __init__(self, hidden_dim, head_num, dropout=0.1):
        super().__init__()
        # Fail early with a clear message instead of a shape error in .view().
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"head_num ({head_num})"
            )

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num

        # Q/K/V projections; all heads are packed into one hidden_dim-wide Linear.
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        # Output projection applied after the heads are concatenated.
        self.proj_o = nn.Linear(hidden_dim, hidden_dim)
        # Dropout on the attention probabilities.
        self.att_dropout = nn.Dropout(dropout)
        # LayerNorm applied after the attention residual connection.
        self.layernorm_att = nn.LayerNorm(hidden_dim, eps=0.00001)

        # Feed-forward network: expand to 4x, activate, project back.
        self.up = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.down = nn.Linear(hidden_dim * 4, hidden_dim)
        self.layernorm_ffn = nn.LayerNorm(hidden_dim, eps=0.00001)
        self.act_fn = nn.ReLU()
        self.ffn_dropout = nn.Dropout(dropout)

        # Final projection applied to the layer output.
        self.proj_ffn = nn.Linear(hidden_dim, hidden_dim)

    def causal_attention(self, q, k, v, att_mask=None):
        """Scaled dot-product attention with a causal (lower-triangular) mask.

        Args:
            q, k, v: Tensors of shape (batch, head_num, seq_len, head_dim).
            att_mask: Optional mask broadcastable to
                (batch, head_num, seq_len, seq_len). Entries equal to 0 are
                masked out. It is combined with the causal mask, so a query
                can never attend to a later position even when a mask is given.

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim) after ``proj_o``.
        """
        batch, _, seq_len, _ = q.size()
        att_weight = q @ k.transpose(-1, -2) / math.sqrt(self.head_dim)

        # keep[i, j] is True when query i may attend to key j (j <= i).
        keep = torch.ones(
            seq_len, k.size(-2), dtype=torch.bool, device=q.device
        ).tril()
        if att_mask is not None:
            # AND the caller's mask into the causal mask instead of replacing it.
            # Zeroing scores (e.g. via att_weight.tril()) is not a mask: softmax
            # would still give exp(0) weight to future positions.
            keep = keep & (att_mask != 0)
        att_weight = att_weight.masked_fill(~keep, float('-inf'))

        att_weight = torch.softmax(att_weight, dim=-1)
        # A query whose keys are all masked gets a row of NaNs from softmax.
        # Zero every masked entry: this clears those NaN rows and leaves the
        # other masked entries at the 0.0 softmax already produced.
        att_weight = att_weight.masked_fill(~keep, 0.0)
        att_weight = self.att_dropout(att_weight)

        # (batch, head_num, seq_len, head_dim)
        output_mid = att_weight @ v
        # Merge heads back into (batch, seq_len, hidden_dim), then project.
        att_output = self.proj_o(
            output_mid.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        )
        return att_output

    def att_block(self, X, att_mask=None):
        """Run multi-head self-attention, add the residual, and apply LayerNorm.

        ``X`` is split into heads via
        ``.view(batch, seq_len, head_num, head_dim)``, so its last dimension
        must equal ``hidden_dim``.
        """
        batch, seq_len, _ = X.size()
        q = self.query(X).view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        k = self.key(X).view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        v = self.value(X).view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        att_output = self.causal_attention(q, k, v, att_mask=att_mask)

        return self.layernorm_att(X + att_output)

    def ffn_block(self, X):
        """Apply the position-wise FFN, add the residual, and apply LayerNorm."""
        up = self.act_fn(self.up(X))
        down = self.down(up)
        output_mid = self.ffn_dropout(down)

        return self.layernorm_ffn(X + output_mid)

    def forward(self, X, att_mask=None):
        """Apply attention, then the FFN, then the final ``proj_ffn`` projection.

        Args:
            X: Input of shape (batch, seq_len, hidden_dim).
            att_mask: Optional mask passed through to ``causal_attention``.

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim).
        """
        att_output = self.att_block(X, att_mask=att_mask)
        ffn_output = self.ffn_block(att_output)
        output = self.proj_ffn(ffn_output)

        return output





# Smoke test: a batch of 3 sequences, 4 tokens each, 64 features, 8 heads.
x = torch.rand(3, 4, 64)
net = SimpleDecoder(64, 8)
# Per-sequence key mask (1 = attend, 0 = masked). It is reshaped to
# (3, 1, 1, 4) and then copied across heads and query rows to get
# (batch=3, head_num=8, seq_len=4, seq_len=4).
mask = torch.tensor(
    [[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]]
)[:, None, None, :].repeat(1, 8, 4, 1)

net(x, mask).shape
            



# Out[16]: torch.Size([3, 4, 64])