# 描述    
基于 PyTorch 实现 Transformer 的 decoder 模块（非传统版本，是 CausalLM decoder）   

区别：
- 传统Decoder: 包含自注意力层+编码器-解码器注意力层+前馈网络，依赖编码器的输出（如机器翻译中的源语言编码），典型模型Transformer、BART、T5
- CausalLM decoder: 仅包含自注意力层+前馈网络，仅依赖自身生成的序列（无编码器输入），典型模型GPT系列、LLaMA

参考文献：
- [手写 transformer decoder（CausalLM）](https://bruceyuan.com/hands-on-code/hands-on-causallm-decoder.html)

In [None]:
# 导入相关需要的包
import math
import torch
import torch.nn as nn

class SimpleDecoder(nn.Module):
    """One decoder-only (CausalLM) Transformer block: causal self-attention + FFN.

    Both sub-layers use a residual connection followed by LayerNorm
    (post-norm): ``LayerNorm(x + sublayer(x))``.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
        nums_head: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability used on the attention probabilities and
            on the FFN output.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``nums_head``.
    """

    def __init__(self, hidden_dim, nums_head, dropout=0.1):
        super(SimpleDecoder, self).__init__()
        # Fail early: otherwise the head split in attention_block breaks later
        # with a much less readable reshape error.
        if hidden_dim % nums_head != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by nums_head ({nums_head})"
            )
        self.nums_head = nums_head
        self.head_dim = hidden_dim // nums_head
        self.dropout_prob = dropout

        # Multi-head attention
        self.layernorm_att = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout_att = nn.Dropout(self.dropout_prob)

        # Feed-forward network
        self.up_proj = nn.Linear(hidden_dim, hidden_dim * 4)  # expand to 4x width
        self.down_proj = nn.Linear(hidden_dim * 4, hidden_dim)  # project back
        self.layernorm_ffn = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.activation_ffn = nn.ReLU()
        self.dropout_ffn = nn.Dropout(self.dropout_prob)

    def _split_heads(self, tensor):
        """Reshape (batch, seq_len, hidden_dim) to (batch, nums_head, seq_len, head_dim)."""
        batch_size, seq_len, _ = tensor.size()
        return tensor.view(batch_size, seq_len, self.nums_head, -1).transpose(1, 2)

    def attention_output(self, query, key, value, attention_mask=None):
        """Compute causal scaled dot-product attention and the output projection.

        Args:
            query, key, value: Tensors of shape
                (batch_size, nums_head, seq_len, head_dim).
            attention_mask: Optional padding mask where 0 marks positions that
                must not be attended to. It has to be broadcastable to
                (batch_size, nums_head, seq_len, seq_len), e.g. a fully
                expanded mask or one of shape (batch_size, 1, 1, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        att_weight = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)

        # Causal mask: lower-triangular, so each position only sees itself and
        # earlier positions. It is built independently of the padding mask so
        # that broadcastable padding masks are handled correctly as well.
        q_len, k_len = att_weight.size(-2), att_weight.size(-1)
        keep = torch.ones(
            q_len, k_len, dtype=torch.bool, device=att_weight.device
        ).tril()
        if attention_mask is not None:
            # Padding mask: additionally hide padded key positions.
            keep = keep & (attention_mask != 0)
        att_weight = att_weight.masked_fill(~keep, float('-inf'))

        # Softmax first, dropout second: dropping -inf logits would give
        # -inf * 0 = NaN, and dropout is meant to act on the probabilities.
        att_weight = torch.softmax(att_weight, dim=-1)
        att_weight = self.dropout_att(att_weight)

        mid_output = torch.matmul(att_weight, value)  # (batch_size, nums_head, seq_len, head_dim)
        mid_output = mid_output.transpose(1, 2).contiguous()
        batch_size, seq_len, _, _ = mid_output.size()
        mid_output = mid_output.view(batch_size, seq_len, -1)  # (batch_size, seq_len, hidden_dim)
        return self.o_proj(mid_output)

    def attention_block(self, x, att_mask=None):
        """Self-attention sub-layer: ``LayerNorm(x + attention(x))``."""
        # Each of Q, K and V has its own projection.
        query = self._split_heads(self.q_proj(x))
        key = self._split_heads(self.k_proj(x))
        value = self._split_heads(self.v_proj(x))
        output = self.attention_output(query, key, value, attention_mask=att_mask)
        return self.layernorm_att(x + output)

    def ffn_block(self, x):
        """Feed-forward sub-layer: ``LayerNorm(x + dropout(down(relu(up(x)))))``."""
        up = self.up_proj(x)
        up = self.activation_ffn(up)
        down = self.down_proj(up)
        down = self.dropout_ffn(down)
        return self.layernorm_ffn(x + down)

    def forward(self, x, att_mask=None):
        """Run the block.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_dim).
            att_mask: Optional padding mask (typically derived from the
                tokenizer's attention mask); 0 marks positions to ignore.
                See ``attention_output`` for the accepted shapes.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        att_output = self.attention_block(x, att_mask=att_mask)
        ffn_output = self.ffn_block(att_output)
        return ffn_output

class Decoder(nn.Module):
    """Stack of ``SimpleDecoder`` blocks between an embedding and an output head.

    Args:
        layer_nums: Number of ``SimpleDecoder`` blocks to stack.
        emb_size: Number of rows in the embedding table; also the size of the
            last dimension of the output.
        hidden_dim: Embedding width, passed to every block.
        nums_head: Number of attention heads, passed to every block.
        dropout: Dropout probability, passed to every block.
    """

    def __init__(self, layer_nums, emb_size=12, hidden_dim=64, nums_head=8, dropout=0.1):
        super().__init__()
        self.layer_nums = layer_nums
        self.emb_size = emb_size
        self.hidden_dim = hidden_dim
        self.nums_head = nums_head
        self.dropout_prob = dropout

        # Blocks are created before the embedding and the head so that
        # parameter initialisation order stays the same.
        blocks = []
        for _ in range(layer_nums):
            blocks.append(
                SimpleDecoder(self.hidden_dim, self.nums_head, self.dropout_prob)
            )
        self.layer_list = nn.ModuleList(blocks)
        self.emb = nn.Embedding(self.emb_size, self.hidden_dim)
        self.out = nn.Linear(self.hidden_dim, self.emb_size)

    def forward(self, x, att_mask=None):
        """Embed ``x``, run every block in order, and return softmax scores.

        Args:
            x: Integer index tensor fed to the embedding layer.
            att_mask: Optional mask forwarded unchanged to every block.

        Returns:
            Tensor whose last dimension has size ``emb_size`` and sums to 1.
        """
        hidden = self.emb(x)
        for block in self.layer_list:
            hidden = block(hidden, att_mask=att_mask)
        scores = self.out(hidden)
        return scores.softmax(dim=-1)

In [None]:
# Smoke test: run a 5-layer Decoder on a random batch of 3 sequences of length 4.
x = torch.randint(0, 12, (3, 4))
net = Decoder(5)
# Per-sequence padding mask (0 = padded position), expanded to
# (batch, heads, query_len, key_len) = (3, 8, 4, 4).
mask = torch.tensor(
    [[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]]
)[:, None, None, :].repeat(1, 8, 4, 1)
print(x.size())
print(mask.size())
net(x, mask).shape