# 手写transformer decode

In [26]:
import math,torch

torch.seed()

38536810831500

In [None]:
import math,torch
import torch.nn as nn

class SimpleDecoderLayer(nn.Module):
    def __init__(self,hidden_dim,head_num,attn_dropout_rate=0.1):
        super().__init__()

        self.head_num = head_num
        # 每个头的维度
        self.head_dim = hidden_dim // head_num

        # qkv,o映射+drop和laynorm
        self.q_proj = nn.Linear(hidden_dim,hidden_dim)
        self.k_proj = nn.Linear(hidden_dim,hidden_dim)
        self.v_proj = nn.Linear(hidden_dim,hidden_dim)
        self.o_proj = nn.Linear(hidden_dim,hidden_dim)
        self.attn_dropout = nn.Dropout(attn_dropout_rate)
        self.attn_ln = nn.LayerNorm(hidden_dim,eps=0.00001)

        # ffn
        self.up_proj = nn.Linear(hidden_dim,hidden_dim*4)
        self.ffn_act = nn.ReLU()
        self.down_proj = nn.Linear(hidden_dim*4,hidden_dim)
        self.ffn_dropout = nn.Dropout(attn_dropout_rate)
        self.ffn_ln = nn.LayerNorm(hidden_dim,eps=0.00001)

    def attention_output(self,q,k,v,attention_mask):
         # 注意力分数
        attn_score = torch.matmul(q,k.transpose(-1,-2)) / math.sqrt(self.head_dim)
        
        # 掩码矩阵
        if attention_mask is not None:
            # 下三角矩阵，解码需要
            attention_mask = attention_mask.tril()
            attn_score = attn_score.masked_fill(
                attention_mask==0,
                float('-inf')
            )
        else: # 没有padding，也就是没有attention_mask
            # 下三角来自attention_score
            attention_mask = torch.ones_like(attn_score).tril()
            attn_score = attn_score.masked_fill(
                attention_mask==0,
                float('-inf')
            )
        attn_weight = torch.softmax(attn_score,dim=-1)
        attn_weight = self.attn_dropout(attn_weight)
        # output [batch_size,head_num,seq_len,head_dim]
    
        output = attn_weight@v
        
        # concat head
        # output [batch_size,seq_len,head_num,head_dim]
        output = output.transpose(1,2).contiguous()
       
        # output [batch_size,seq_len,head_num * head_dim]
        # 传进来的q是转置过seq_len和head_num的，所以直接用输出的前两个维度
        # batch,_,seq,_ = q.size()
        batch,seq,_,_ = output.size()
        output = output.view(batch,seq,-1)

        output = self.o_proj(output)

        return output
        
    def mha(self,x,attention_mask=None):
        # x [batch_size,seq_len,hidden_dim]
        batch,seq,_ = x.size()
        q = self.q_proj(x).view(batch,seq,self.head_num,-1).transpose(1,2)
        
        k = self.k_proj(x).view(batch,seq,self.head_num,-1).transpose(1,2)
        v = self.v_proj(x).view(batch,seq,self.head_num,-1).transpose(1,2)

        output = self.attention_output(q,k,v,attention_mask)
        return self.attn_ln(output+x)

    def ffn(self,x):
        up = self.up_proj(x)
        act = self.ffn_act(up)
        down = self.down_proj(act)

        # dropout
        down = self.ffn_dropout(down)
        # add + layernorm
        return self.ffn_ln(down + x)

    def forward(self,x,attention_mask=None):
        x = self.mha(x,attention_mask)
        x = self.ffn(x)
        return x

# 测试
# x = torch.rand(3, 4, 64)
# net = SimpleDecoderLayer(64, 8)
# mask = (
#     torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]])
#     .unsqueeze(1)
#     .unsqueeze(2)
#     .repeat(1, 8, 4, 1)
# )

# net(x, mask).shape

torch.Size([3, 4, 64])

## chaofa用代码打点酱油 手写版本

In [None]:
# 导入相关需要的包
import math
import torch
import torch.nn as nn

import warnings
warnings.filterwarnings(action="ignore")

# 写一个 Block
class SimpleDecoder(nn.Module):
    def __init__(self, hidden_dim, nums_head, dropout=0.1):
        super().__init__()

        self.nums_head = nums_head
        self.head_dim = hidden_dim // nums_head

        self.dropout = dropout

        # 这里按照 transformers 中的 decoder 来写，用 post_norm 的方式实现，主意有 残差链接
        # eps 是为了防止溢出；其中 llama 系列的模型一般用的是 RMSnorm 以及 pre-norm（为了稳定性）
        # RMSnorm 没有一个 recenter 的操作，而 layernorm 是让模型重新变成 均值为 0，方差为 1
        # RMS 使用 w平方根均值进行归一化 $\sqrt{\frac{1}{n} \sum_{1}^{n}{a_i^2} }$
        self.layernorm_att = nn.LayerNorm(hidden_dim, eps=0.00001)

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.drop_att = nn.Dropout(self.dropout)

        # for ffn 准备
        self.up_proj = nn.Linear(hidden_dim, hidden_dim * 4)
        self.down_proj = nn.Linear(hidden_dim * 4, hidden_dim)
        self.layernorm_ffn = nn.LayerNorm(hidden_dim, eps=0.00001)
        self.act_fn = nn.ReLU()
        
        self.drop_ffn = nn.Dropout(self.dropout)

    def attention_output(self, query, key, value, attention_mask=None):
        # 计算两者相关性
        key = key.transpose(2, 3)  # (batch, num_head, head_dim, seq)
        att_weight = torch.matmul(query, key) / math.sqrt(self.head_dim)

        # attention mask 进行依次调整；变成 causal_attention
        if attention_mask is not None:
            # 变成下三角矩阵
            attention_mask = attention_mask.tril()
            att_weight = att_weight.masked_fill(attention_mask == 0, float("-1e20"))
        else:
            # 人工构造一个下三角的 attention mask
            attention_mask = torch.ones_like(att_weight).tril()
            att_weight = att_weight.masked_fill(attention_mask == 0, float("-1e20"))

        att_weight = torch.softmax(att_weight, dim=-1)
        print(att_weight)

        att_weight = self.drop_att(att_weight)

        mid_output = torch.matmul(att_weight, value)
        # mid_output shape is: (batch, nums_head, seq, head_dim)

        mid_output = mid_output.transpose(1, 2).contiguous()
        batch, seq, _, _ = mid_output.size()
        mid_output = mid_output.view(batch, seq, -1)
        output = self.o_proj(mid_output)
        return output

    def attention_block(self, X, attention_mask=None):
        batch, seq, _ = X.size()
        query = self.q_proj(X).view(batch, seq, self.nums_head, -1).transpose(1, 2)
        key = self.k_proj(X).view(batch, seq, self.nums_head, -1).transpose(1, 2)
        value = self.v_proj(X).view(batch, seq, self.nums_head, -1).transpose(1, 2)

        output = self.attention_output(
            query,
            key,
            value,
            attention_mask=attention_mask,
        )
        return self.layernorm_att(X + output)

    def ffn_block(self, X):
        up = self.act_fn(
            self.up_proj(X),
        )
        down = self.down_proj(up)

        # 执行 dropout
        down = self.drop_ffn(down)

        # 进行 norm 操作
        return self.layernorm_ffn(X + down)

    def forward(self, X, attention_mask=None):
        # X 一般假设是已经经过 embedding 的输入， (batch, seq, hidden_dim)
        # attention_mask 一般指的是 tokenizer 后返回的 mask 结果，表示哪些样本需要忽略
        # shape 一般是： (batch, nums_head, seq)

        att_output = self.attention_block(X, attention_mask=attention_mask)
        ffn_output = self.ffn_block(att_output)
        return ffn_output


# 测试

# x = torch.rand(3, 4, 64)
# net = SimpleDecoder(64, 8)
# mask = (
#     torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]])
#     .unsqueeze(1)
#     .unsqueeze(2)
#     .repeat(1, 8, 4, 1)
# )

# net(x, mask).shape

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4636, 0.5364, 0.0000, 0.0000],
          [0.3056, 0.3526, 0.3418, 0.0000],
          [0.2186, 0.2720, 0.2724, 0.2370]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4781, 0.5219, 0.0000, 0.0000],
          [0.3292, 0.3364, 0.3344, 0.0000],
          [0.2520, 0.2495, 0.2418, 0.2566]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4706, 0.5294, 0.0000, 0.0000],
          [0.3260, 0.3509, 0.3231, 0.0000],
          [0.2512, 0.2583, 0.2446, 0.2460]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5406, 0.4594, 0.0000, 0.0000],
          [0.3464, 0.3334, 0.3203, 0.0000],
          [0.2658, 0.2430, 0.2423, 0.2489]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5143, 0.4857, 0.0000, 0.0000],
          [0.3335, 0.3027, 0.3638, 0.0000],
          [0.2503, 0.2514, 0.2570, 0.2414]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4990, 0.5010, 0.0000, 0.0000],
          [0.3308, 0.3

torch.Size([3, 4, 64])

### 对比两个代码

差异为0

In [None]:
import torch
import math
import torch.nn as nn
import random, numpy as np

# ========== 固定随机种子 ==========
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(1234)

# ========== 输入与mask ==========
x = torch.rand(3, 4, 64)
mask = (
    torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]])
    .unsqueeze(1)
    .unsqueeze(2)
    .repeat(1, 8, 4, 1)
)

# ========== 初始化模型 ==========


set_seed(1234)
net1 = SimpleDecoderLayer(64, 8)
set_seed(1234)
net2 = SimpleDecoder(64, 8)
# Dropout 一定要关掉（或固定随机数）使用net.eval() 模式
net1.eval()
net2.eval()

# ========== 验证输出 ==========
with torch.no_grad():
    out1 = net1(x, mask)
    out2 = net2(x, mask)

print("输出是否相同：", torch.allclose(out1, out2, atol=1e-6))
print("输出差异：", torch.abs(out1 - out2).max())


tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4997, 0.5003, 0.0000, 0.0000],
          [0.3383, 0.3257, 0.3360, 0.0000],
          [0.2631, 0.2395, 0.2341, 0.2633]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5132, 0.4868, 0.0000, 0.0000],
          [0.3386, 0.3299, 0.3315, 0.0000],
          [0.2562, 0.2445, 0.2496, 0.2497]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4922, 0.5078, 0.0000, 0.0000],
          [0.3637, 0.3182, 0.3181, 0.0000],
          [0.2501, 0.2442, 0.2310, 0.2747]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4938, 0.5062, 0.0000, 0.0000],
          [0.3514, 0.3171, 0.3316, 0.0000],
          [0.2508, 0.2479, 0.2487, 0.2526]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5050, 0.4950, 0.0000, 0.0000],
          [0.3271, 0.3270, 0.3459, 0.0000],
          [0.2438, 0.2399, 0.2589, 0.2574]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4868, 0.5132, 0.0000, 0.0000],
          [0.3333, 0.3

# 完整Decoder层

In [33]:
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_list = nn.ModuleList(
            [   # hidden_dim,head_num 64,8
                SimpleDecoderLayer(64,8) for i in range(5)
            ]
        )
        # hidden_dim 64
        self.emb = nn.Embedding(12,64)
        self.out = nn.Linear(64,12)

    def forward(self,x,mask=None):
        # x (b,s)
        x = self.emb(x)
        for i,l in enumerate(self.layer_list):
            x = l(x,mask)
        print(x.shape)
        output = self.out(x)
        return torch.softmax(output,dim=-1)

# 完整测试代码
x = torch.randint(low=0,high=12,size=(3,4))

net = Decoder()
mask = (
    torch.tensor(
        [
            [1,1,1,1],
            [1,1,0,0],
            [1,1,1,0]
        ]
    )
    .unsqueeze(1)
    .unsqueeze(2)
    .repeat(1,8,4,1) # head_num,seq_len: 8,4

)

net(x,mask)

torch.Size([3, 4, 64])


tensor([[[0.0493, 0.0713, 0.1049, 0.1347, 0.1210, 0.0952, 0.0727, 0.0426,
          0.0845, 0.0876, 0.0280, 0.1082],
         [0.0637, 0.0771, 0.1638, 0.1294, 0.0771, 0.0827, 0.0429, 0.0499,
          0.0515, 0.0619, 0.1050, 0.0950],
         [0.0520, 0.0907, 0.0710, 0.1408, 0.0933, 0.1512, 0.0784, 0.0479,
          0.0252, 0.1195, 0.0292, 0.1007],
         [0.1514, 0.1203, 0.0526, 0.0758, 0.0534, 0.0576, 0.0628, 0.1585,
          0.0442, 0.1429, 0.0401, 0.0404]],

        [[0.0508, 0.0223, 0.1029, 0.2019, 0.1459, 0.0300, 0.0330, 0.0841,
          0.0943, 0.1302, 0.0474, 0.0573],
         [0.1275, 0.0979, 0.0654, 0.0678, 0.0346, 0.0635, 0.0599, 0.1868,
          0.0745, 0.1091, 0.0613, 0.0516],
         [0.0632, 0.0465, 0.0421, 0.1633, 0.0864, 0.0526, 0.1039, 0.1056,
          0.0694, 0.0764, 0.1070, 0.0836],
         [0.0586, 0.0563, 0.0470, 0.0998, 0.1076, 0.2248, 0.0673, 0.0567,
          0.0461, 0.1066, 0.0428, 0.0863]],

        [[0.0249, 0.0410, 0.1306, 0.2135, 0.0965, 0.0464, 0.