### 手撕transformer_decoder架构

原版的比较复杂，一般也不会让写。这里的 Decoder 一般指的是 CausalLM，具体变化是少了 encoder 部分的输入，所以也就没有了 encoder and decoder cross attention

causal lm: self_attention + FFN

1. causalLM decoder 的流程是 input -> self-attention -> FFN

2. [self-attention, FFN] 是一个 block，一般会有很多的 block

3. FFN 矩阵有两次变化，一次升维度，一次降维度。其中 LLaMA 对于 GPT 的改进还有把 GeLU 变成了 SwiGLU，多了一个门控（gate）矩阵。为了让参数量大致保持不变，中间层维度一般会从 4h 缩小为 4h * 2 / 3

4. 原版的 Transformer 用 post-norm，后面 GPT-2、LLaMA 系列用的是 pre-norm。其中 LLaMA 系列一般用 RMSNorm 代替 GPT 和原版 Transformer decoder 中的 LayerNorm。（下面的示例代码为了简单，用的是 post-norm + LayerNorm。）

In [2]:
import torch
import math
import torch.nn as nn

In [None]:
# block / layer
class SimpleDecoderLayer(nn.Module):
    """One causal-LM decoder block: multi-head self-attention followed by an FFN.

    Both sub-layers use a residual connection followed by LayerNorm
    (post-norm). There is no encoder/decoder cross-attention.

    Args:
        hidden_dim: Model width ``h``; must be divisible by ``head_num``.
        head_num: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights
            and to the FFN output.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``head_num``.
    """

    def __init__(self, hidden_dim, head_num, dropout_rate = 0.1):
        super().__init__()
        if hidden_dim % head_num != 0:
            # Fail early with a clear message instead of an obscure
            # ``view`` error on the first forward pass.
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"head_num ({head_num})"
            )
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num
        self.dropout = dropout_rate

        # Multi-head self-attention: Q/K/V/output projections.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.drop_att = nn.Dropout(self.dropout)
        self.att_ln = nn.LayerNorm(hidden_dim, eps=1e-6)

        # FFN: up-projection -> activation -> down-projection -> LayerNorm.
        self.up_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.down_proj = nn.Linear(4 * hidden_dim, hidden_dim)
        self.act_ffn = nn.GELU()
        # Use the configured rate (this was previously hard-coded to 0.1).
        self.drop_ffn = nn.Dropout(self.dropout)
        self.ffn_ln = nn.LayerNorm(hidden_dim, eps=1e-6)

    def attention_layer(self, query, key, value, attn_mask = None):
        """Scaled dot-product attention with a causal mask.

        Args:
            query, key, value: Tensors of shape
                ``(batch, head_num, seq, head_dim)``.
            attn_mask: Optional mask where zero marks positions that must not
                be attended to. Accepted shapes are ``(batch, seq)``,
                ``(batch, seq, seq)`` and anything broadcastable to
                ``(batch, head_num, seq, seq)``. The causal (lower-triangular)
                constraint is always applied on top of it.

        Returns:
            Tensor of shape ``(batch, seq, hidden_dim)`` after the output
            projection.
        """
        scores = (query @ key.transpose(2, 3)) / math.sqrt(self.head_dim)

        # Causal constraint: position i may only attend to positions <= i.
        keep = torch.ones(
            query.size(-2), key.size(-2),
            dtype=torch.bool, device=scores.device,
        ).tril()

        if attn_mask is not None:
            if attn_mask.dim() == 2:
                # (batch, seq) padding mask -> (batch, 1, 1, seq)
                attn_mask = attn_mask[:, None, None, :]
            elif attn_mask.dim() == 3:
                # (batch, seq, seq) -> (batch, 1, seq, seq)
                attn_mask = attn_mask.unsqueeze(1)
            keep = keep & (attn_mask != 0)

        scores = scores.masked_fill(~keep, float("-inf"))

        # Softmax first, dropout second.
        attn_weight = torch.softmax(scores, dim = -1)
        # A fully masked query row (e.g. left padding) is all -inf, so softmax
        # yields NaN; zero those rows so NaN cannot leak into later layers.
        attn_weight = attn_weight.masked_fill(
            ~keep.any(dim = -1, keepdim = True), 0.0
        )
        attn_weight = self.drop_att(attn_weight)

        # (b, head_num, seq, head_dim)
        mid_out = attn_weight @ value

        # Merge heads back: (b, seq, head_num, head_dim) -> (b, seq, hidden_dim)
        mid_out = mid_out.transpose(1, 2).contiguous()
        batch, seq, _, _ = mid_out.size()
        mid_out = mid_out.view(batch, seq, -1)
        return self.o_proj(mid_out)

    def mha(self, X, mask = None):
        """Self-attention sub-layer with residual connection and post-norm.

        Args:
            X: Input of shape ``(batch, seq, hidden_dim)``.
            mask: Optional attention mask, see ``attention_layer``.

        Returns:
            Tensor of shape ``(batch, seq, hidden_dim)``.
        """
        # (b, s, h) -> (b, head_num, s, head_dim)
        batch, seq, _ = X.size()
        query = self.q_proj(X).view(batch, seq, self.head_num, -1).transpose(1, 2)
        key = self.k_proj(X).view(batch, seq, self.head_num, -1).transpose(1, 2)
        value = self.v_proj(X).view(batch, seq, self.head_num, -1).transpose(1, 2)

        output = self.attention_layer(query, key, value, mask)

        # Post-norm: LayerNorm over residual + sub-layer output, (b, s, h).
        return self.att_ln(X + output)

    def ffn(self, X):
        """Feed-forward sub-layer with residual connection and post-norm.

        Args:
            X: Input of shape ``(batch, seq, hidden_dim)``.

        Returns:
            Tensor of the same shape as ``X``.
        """
        up = self.act_ffn(self.up_proj(X))
        down = self.drop_ffn(self.down_proj(up))
        # Post-norm: LayerNorm over residual + sub-layer output.
        return self.ffn_ln(X + down)

    def forward(self, X, attention_mask = None):
        """Run self-attention and then the FFN on ``X``."""
        X = self.mha(X, attention_mask)
        X = self.ffn(X)
        return X

# Smoke test: 3 sequences of length 4 with hidden size 64.
x = torch.rand(3, 4, 64)
net = SimpleDecoderLayer(64, 8)     # 8 heads, each of width 64 / 8 = 8

# Padding mask (1 = keep, 0 = masked), one row per sequence: (3, 4).
# Insert head and query axes -> (3, 1, 1, 4), then tile to (3, 8, 4, 4).
mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
    ]
)[:, None, None, :].repeat(1, 8, 4, 1)

net(x, mask)

tensor([[[ 1.6452e+00, -6.5689e-01,  2.0645e+00,  8.3845e-01,  1.8109e-01,
           5.5268e-01, -1.0549e+00, -1.4376e+00,  1.2504e+00,  4.8248e-01,
           1.1925e+00,  8.2608e-01, -3.5209e-01, -5.0234e-01, -2.7126e-01,
           2.2702e-01, -4.5262e-02,  6.3747e-01, -2.0026e+00,  4.7674e-01,
          -8.0082e-01,  4.0579e-01, -1.6028e+00,  1.8945e-01, -2.6150e-01,
           2.3285e+00,  1.0221e+00,  4.6685e-01,  2.1390e-01, -3.8437e-01,
           5.5431e-01, -1.1759e+00, -8.3214e-02, -8.1166e-01,  9.1685e-01,
          -4.6888e-01,  5.5990e-02, -2.7177e+00,  3.0247e-01, -8.8141e-01,
          -7.8093e-01, -1.7215e+00, -2.3153e-01,  1.8139e+00,  2.8382e-02,
           1.3981e+00, -9.2162e-01,  2.0731e-01,  1.5194e+00, -6.9500e-01,
           2.6262e-01,  3.7997e-01,  8.3887e-01,  8.0925e-01, -7.7429e-01,
          -1.8555e-01, -1.4461e+00, -7.1366e-01,  2.5071e-01, -8.7361e-01,
          -4.9056e-01, -1.0009e+00,  1.3282e+00, -3.2088e-01],
         [-4.1449e-01,  2.9160e-01, -

In [None]:
class Decoder(nn.Module):
    """Minimal causal-LM decoder: embedding -> stacked decoder layers -> vocab head.

    No positional encoding is added; token embeddings are fed straight into
    the layer stack.

    Args:
        vocab_size: Number of token ids; sizes the embedding table and the
            output projection.
        hidden_dim: Model width passed to every ``SimpleDecoderLayer``.
        head_num: Number of attention heads in every layer.
        num_layers: Number of stacked ``SimpleDecoderLayer`` blocks.
        dropout_rate: Dropout probability passed to every layer.

    The defaults reproduce the original fixed configuration
    (12 tokens, width 64, 8 heads, 5 layers, dropout 0.1).
    """

    def __init__(
        self,
        vocab_size = 12,
        hidden_dim = 64,
        head_num = 8,
        num_layers = 5,
        dropout_rate = 0.1,
    ):
        super().__init__()
        self.layer_list = nn.ModuleList(
            [
                SimpleDecoderLayer(hidden_dim, head_num, dropout_rate)
                for _ in range(num_layers)
            ]
        )
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, X, mask = None):
        """Return per-position probabilities over the vocabulary.

        Args:
            X: Integer token ids of shape ``(batch_size, seq_len)``.
            mask: Optional attention mask, forwarded unchanged to every layer.

        Returns:
            Tensor of shape ``(batch_size, seq_len, vocab_size)`` whose last
            dimension is softmax-normalised.
        """
        X = self.emb(X)
        for layer in self.layer_list:
            X = layer(X, mask)
        # Shape trace kept for the notebook demo: (batch, seq, hidden_dim).
        print(X.shape)
        output = self.out(X)
        return torch.softmax(output, dim = -1)
    
# Smoke test: 3 sequences of 4 token ids drawn from the 12-token vocabulary.
x = torch.randint(0, 12, (3, 4))
net = Decoder()

# Padding mask (1 = keep, 0 = masked) tiled to (3, 8, 4, 4).
mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
    ]
)[:, None, None, :].repeat(1, 8, 4, 1)

net(x, mask)

torch.Size([3, 4, 64])


tensor([[[0.0721, 0.1170, 0.0497, 0.1347, 0.0463, 0.0856, 0.1799, 0.0394,
          0.0957, 0.1127, 0.0316, 0.0354],
         [0.0474, 0.0769, 0.0737, 0.0942, 0.2074, 0.0532, 0.0978, 0.0261,
          0.0713, 0.1569, 0.0418, 0.0531],
         [0.0891, 0.1044, 0.0462, 0.2078, 0.0408, 0.0805, 0.1207, 0.0316,
          0.1288, 0.0826, 0.0290, 0.0385],
         [0.0602, 0.0867, 0.0329, 0.0656, 0.0986, 0.1495, 0.0727, 0.0480,
          0.1367, 0.0928, 0.0611, 0.0951]],

        [[0.0247, 0.1308, 0.0865, 0.0783, 0.0267, 0.0344, 0.1504, 0.1370,
          0.0460, 0.0662, 0.0612, 0.1579],
         [0.0180, 0.0669, 0.1251, 0.0642, 0.0789, 0.1124, 0.1500, 0.1136,
          0.0369, 0.0866, 0.0877, 0.0596],
         [0.0293, 0.0486, 0.0515, 0.0349, 0.0461, 0.0764, 0.1028, 0.1529,
          0.0489, 0.0702, 0.0785, 0.2599],
         [0.0574, 0.0947, 0.0114, 0.0568, 0.0471, 0.0574, 0.2577, 0.2192,
          0.0193, 0.0587, 0.0272, 0.0932]],

        [[0.0243, 0.0746, 0.0758, 0.0594, 0.0618, 0.1802, 0.