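With the Transformer architecture hand-built, I'm moving on to building GPT-2 from scratch; a self-written implementation will make probing experiments easier later.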

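To count the parameters of a Transformer (using GPT-2 as the example), break the model down module by module, compute each module's parameter count, and sum them to get the total.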

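Assume the following notation:
- \( n_{\text{heads}} \): number of attention heads
- \( n_{\text{layers}} \): number of Transformer decoder layers
- \( d_{\text{model}} \): embedding dimension of each token (the model dimension)
- \( d_{\text{ffn}} \): hidden dimension of the feed-forward network
- \( d_k \): dimension of each head's query and key vectors
- \( d_v \): dimension of each head's value vectors
- \( V \): vocabulary size
- \( L \): maximum sequence length (number of positions)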

Below are the per-module parameter formulas, followed by the formula for the total.

---

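### **1. Input embeddings and positional embeddings**
- **Token embedding matrix:**
  \[
  \text{Params}_{\text{embeddings}} = V \times d_{\text{model}}
  \]
  (vocabulary size \(V\) × model dimension \(d_{\text{model}}\))

- **Positional embedding matrix:**
  \[
  \text{Params}_{\text{positional embeddings}} = L \times d_{\text{model}}
  \]
  (maximum sequence length \(L\) × model dimension \(d_{\text{model}}\)). GPT-2 learns this matrix; a fixed sinusoidal encoding, as in the original Transformer, has no trainable parameters.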

---

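### **2. Attention (Multi-Head Self-Attention, MHA)**
For a single MHA layer:
1. **Key, query, and value projection layers:**
   - Per head, the key, query, and value weights are \( d_{\text{model}} \times d_k \), \( d_{\text{model}} \times d_k \), and \( d_{\text{model}} \times d_v \).
   - Assuming \( d_v = d_k \) (as in GPT-2, where \( d_k = d_v = d_{\text{model}} / n_{\text{heads}} \)), the combined weight matrices hold:
     \[
     \text{Params}_{\text{QKV}} = 3 \times d_{\text{model}} \times (d_k \times n_{\text{heads}})
     \]

2. **Attention output projection:**
   - Projects the concatenated multi-head output back to \( d_{\text{model}} \):
     \[
     \text{Params}_{\text{attention output}} = (d_k \times n_{\text{heads}}) \times d_{\text{model}}
     \]
   - These counts cover weights only; GPT-2's attention projections also carry biases, another \( 4 d_{\text{model}} \) parameters per layer, which are negligible here.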
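A minimal plain-Python sketch (the helper name `mha_params` is hypothetical) that counts the attention weights head by head and checks that the sum equals the fused-matrix formula. It also accepts \( d_v \neq d_k \), in which case the Q/K/V projections hold \( d_{\text{model}} (2 d_k + d_v) n_{\text{heads}} \) weights.

```python
def mha_params(d_model, n_heads, d_k, d_v, bias=False):
    """Count MHA parameters (hypothetical helper) by summing per-head matrices."""
    per_head_qkv = d_model * d_k + d_model * d_k + d_model * d_v  # W_Q, W_K, W_V of one head
    qkv = n_heads * per_head_qkv
    out = (n_heads * d_v) * d_model  # output projection over the concatenated heads
    if bias:
        qkv += n_heads * (2 * d_k + d_v)  # one bias per output unit of each projection
        out += d_model
    return qkv, out


qkv, out = mha_params(768, 12, 64, 64)
# Splitting into heads does not change the count: it equals 3 * d_model * (d_k * n_heads)
assert qkv == 3 * 768 * (64 * 12)
print(qkv, out, qkv + out)  # 1769472 589824 2359296
```

Summing per head or using one fused matrix gives the same number, which is why implementations are free to store a single `d_model × (n_heads · d_k)` projection.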

---

### **3. Feed-Forward Network (FFN)**
The feed-forward network consists of two fully connected layers:
1. First-layer weights:
   \[
   \text{Params}_{\text{FFN1}} = d_{\text{model}} \times d_{\text{ffn}}
   \]

2. Second-layer weights:
   \[
   \text{Params}_{\text{FFN2}} = d_{\text{ffn}} \times d_{\text{model}}
   \]

3. Biases of the two layers (optional, and small compared with the weights):
   \[
   \text{Params}_{\text{FFN bias}} = d_{\text{ffn}} + d_{\text{model}}
   \]
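When \( d_{\text{ffn}} = 4 d_{\text{model}} \) (as in GPT-2, where \( 3072 = 4 \times 768 \)), the FFN weights dominate each layer. A quick derivation, ignoring biases and LayerNorm and assuming \( d_k n_{\text{heads}} = d_{\text{model}} \):

\[
\begin{aligned}
\text{Params}_{\text{FFN}} &\approx 2\, d_{\text{model}}\, d_{\text{ffn}} = 8\, d_{\text{model}}^2 \\
\text{Params}_{\text{MHA}} &= 4\, d_{\text{model}}^2 \\
\text{Params}_{\text{per layer}} &\approx 12\, d_{\text{model}}^2 = 12 \times 768^2 \approx 7.08M
\end{aligned}
\]

So the stack of layers is roughly \( 12\, n_{\text{layers}}\, d_{\text{model}}^2 \), about 84.9M for 12 layers, with two thirds of it in the FFNs.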

---

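### **4. Layer Normalization**
Each LayerNorm has two parameter vectors of length \( d_{\text{model}} \), a gain (weight) and a bias:
\[
\text{Params}_{\text{LayerNorm}} = 2 \times d_{\text{model}}
\]
A GPT-2 block actually contains two LayerNorms (one before attention, one before the FFN), and one more follows the last block. The simplified count below uses a single LayerNorm per layer, which undercounts GPT-2 small by only about 20K parameters.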
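To see where the \( 2 d_{\text{model}} \) comes from, here is a minimal plain-Python LayerNorm over one vector. The names `gamma`, `beta`, and `eps` are illustrative (PyTorch's `nn.LayerNorm` calls the first two `weight` and `bias`, with a default `eps` of `1e-5`); `gamma` and `beta` are the only learned parameters, one entry per dimension.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize one d_model-sized vector, then scale by gamma and shift by beta."""
    d = len(x)
    mean = sum(x) / d
    var = sum((v - mean) ** 2 for v in x) / d  # biased variance, as LayerNorm uses
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]


d_model = 4
gamma = [1.0] * d_model  # learned gain, length d_model
beta = [0.0] * d_model   # learned bias, length d_model
y = layer_norm([1.0, 2.0, 3.0, 4.0], gamma, beta)
print(len(gamma) + len(beta))  # 8, i.e. 2 * d_model trainable parameters
```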

---

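### **5. Output layer**
The output layer is usually a linear layer projecting the model dimension \(d_{\text{model}}\) to the vocabulary size \(V\):
\[
\text{Params}_{\text{output}} = d_{\text{model}} \times V
\]
GPT-2 ties this weight to the token embedding matrix (the output projection reuses the \( V \times d_{\text{model}} \) embedding), so in GPT-2 this term adds no new parameters; count it only when the weights are untied.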
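A plain-Python sketch of weight tying, with toy sizes and hypothetical names: the logits are dot products of the final hidden state with each row of the token embedding matrix \(E\) (\(V \times d_{\text{model}}\)), so the output layer computes \(h E^\top\) without storing a second \( V \times d_{\text{model}} \) matrix. In PyTorch the same effect is usually obtained by having the output `nn.Linear` share the embedding's `weight` parameter.

```python
# Toy token embedding matrix E: V = 3 tokens, d_model = 2
E = [
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
]


def embed(token_id):
    """Input side: look up one row of E."""
    return E[token_id]


def tied_logits(h):
    """Output side: h @ E^T, reusing E instead of a separate output matrix."""
    return [sum(hi * ei for hi, ei in zip(h, row)) for row in E]


h = [0.5, 2.0]  # pretend final hidden state
print(tied_logits(h))  # [0.5, 2.0, 2.5]
# Only E is stored: V * d_model = 6 parameters (an untied head would add another 6)
```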

---

### **Total parameter count**
For a Transformer with \( n_{\text{layers}} \) layers, the total parameter count is:

\[
\text{Total Parameters} =
\text{Params}_{\text{embeddings}} + \text{Params}_{\text{positional embeddings}} + n_{\text{layers}} \times \left( \text{Params}_{\text{MHA}} + \text{Params}_{\text{FFN}} + \text{Params}_{\text{LayerNorm}} \right) + \text{Params}_{\text{output}}
\]

where:
- **MHA parameters:**
  \[
  \text{Params}_{\text{MHA}} = \text{Params}_{\text{QKV}} + \text{Params}_{\text{attention output}}
  = 3 \times d_{\text{model}} \times (d_k \times n_{\text{heads}}) + (d_k \times n_{\text{heads}}) \times d_{\text{model}}
  \]

- **FFN parameters:**
  \[
  \text{Params}_{\text{FFN}} = \text{Params}_{\text{FFN1}} + \text{Params}_{\text{FFN2}} + \text{Params}_{\text{FFN bias}}
  = d_{\text{model}} \times d_{\text{ffn}} + d_{\text{ffn}} \times d_{\text{model}} + d_{\text{ffn}} + d_{\text{model}}
  \]

- **LayerNorm parameters:**
  \[
  \text{Params}_{\text{LayerNorm}} = 2 \times d_{\text{model}}
  \]
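The total formula translates directly into a small Python function. This is a sketch: the name `transformer_params` and the `tie_embeddings` flag are hypothetical. It follows the simplified accounting of the formulas (weight-only attention, FFN with biases, one LayerNorm per layer) and adds a flag for GPT-2-style weight tying, where the output layer reuses the embedding matrix and contributes nothing new.

```python
def transformer_params(V, L, d_model, d_ffn, n_layers, n_heads, tie_embeddings=False):
    """Total parameters from the per-module formulas (hypothetical helper)."""
    d_k = d_model // n_heads  # assumes d_k = d_v = d_model / n_heads
    embeddings = V * d_model
    positional = L * d_model
    mha = 3 * d_model * (d_k * n_heads) + (d_k * n_heads) * d_model
    ffn = d_model * d_ffn + d_ffn * d_model + d_ffn + d_model
    layer_norm = 2 * d_model
    # With weight tying the output projection reuses the embedding matrix
    output = 0 if tie_embeddings else d_model * V
    return embeddings + positional + n_layers * (mha + ffn + layer_norm) + output


print(transformer_params(50_000, 512, 768, 3072, 12, 12, tie_embeddings=True))   # 123792384
print(transformer_params(50_000, 512, 768, 3072, 12, 12, tie_embeddings=False))  # 162192384
```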

---

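### **Worked example**
Assume:
- \( n_{\text{heads}} = 12 \)
- \( n_{\text{layers}} = 12 \)
- \( d_{\text{model}} = 768 \)
- \( d_{\text{ffn}} = 3072 \)
- \( d_k = d_v = d_{\text{model}} / n_{\text{heads}} = 64 \)
- \( V = 50{,}000 \)
- \( L = 512 \)

GPT-2 small actually uses \( V = 50{,}257 \) and \( L = 1{,}024 \); with those values the embedding parameters grow by about 0.6M.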

Calculation:
1. **Embedding layers:**
   \[
   \text{Params}_{\text{embeddings}} = 50{,}000 \times 768 = 38.4M
   \]
   \[
   \text{Params}_{\text{positional embeddings}} = 512 \times 768 = 0.39M
   \]

2. **MHA (per layer):**
   \[
   \text{Params}_{\text{QKV}} = 3 \times 768 \times 64 \times 12 = 1.77M
   \]
   \[
   \text{Params}_{\text{attention output}} = 768 \times 768 = 0.59M
   \]
   Total:
   \[
   \text{Params}_{\text{MHA}} = 1.77M + 0.59M = 2.36M
   \]

3. **FFN (per layer):**
   \[
   \text{Params}_{\text{FFN}} = 768 \times 3072 + 3072 \times 768 + 3072 + 768 = 4.72M
   \]

4. **LayerNorm (per layer):**
   \[
   \text{Params}_{\text{LayerNorm}} = 2 \times 768 = 1.5K
   \]

5. **Parameters per layer:**
   \[
   \text{Params}_{\text{per layer}} = \text{Params}_{\text{MHA}} + \text{Params}_{\text{FFN}} + \text{Params}_{\text{LayerNorm}} \approx 7.08M
   \]

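6. **Total parameters:**
   GPT-2 ties the output layer to the token embedding matrix, so the output layer adds no new parameters:
   \[
   \text{Total Parameters} = 38.4M + 0.39M + 12 \times 7.08M \approx 123.8M \approx 124M
   \]
   With an untied output layer, the extra \( 50{,}000 \times 768 = 38.4M \) parameters would bring the total to about 162.2M.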
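As a cross-check against the commonly cited 124M, here is a sketch that counts GPT-2 small exactly under its real configuration (\( V = 50{,}257 \), \( L = 1{,}024 \)) and includes what the simplified formula leaves out: biases on the attention projections, two LayerNorms per block, and a final LayerNorm, with the output layer tied to the token embeddings. The function name is hypothetical; the layout follows GPT-2's block structure.

```python
def gpt2_exact_params(V=50_257, L=1_024, d_model=768, n_layers=12, d_ffn=3_072):
    """Exact GPT-2 parameter count with tied output weights (hypothetical helper)."""
    def linear(n_in, n_out):
        return n_in * n_out + n_out  # weight + bias

    layer_norm = 2 * d_model  # gain + bias
    block = (
        layer_norm                        # LayerNorm before attention
        + linear(d_model, 3 * d_model)    # fused Q, K, V projection
        + linear(d_model, d_model)        # attention output projection
        + layer_norm                      # LayerNorm before the FFN
        + linear(d_model, d_ffn)          # FFN expansion
        + linear(d_ffn, d_model)          # FFN contraction
    )
    embeddings = V * d_model + L * d_model
    return embeddings + n_layers * block + layer_norm  # + final LayerNorm; output head is tied


print(gpt2_exact_params())  # 124439808
```

The simplified formula with the rounded \(V\) and \(L\) lands within about 0.65M of this exact figure.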

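---

### **Transformer components (PyTorch)**
The Transformer building blocks written earlier, which the GPT-2 implementation builds on:

```python
import math

import numpy as np
import torch
import torch.nn as nn

# ====================================================================================================
# Positional Encoding

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        '''
        Initialize the positional encoding.
        d_model:    model dimension, i.e. the embedding size of each token
        dropout:    dropout probability, to reduce overfitting
        max_len:    maximum sequence length for which encodings are precomputed
        '''
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # holds the positional encodings
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Sine on even dimensions, cosine on odd dimensions
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add a batch dimension: [max_len, 1, d_model] (sequence-first layout)
        pe = pe.unsqueeze(0).transpose(0, 1)

        # Fixed sinusoidal encoding stored as a non-trainable buffer (GPT-2 instead learns its position embeddings)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Add the positional encoding to x.
        x:       [seq_len, batch_size, d_model]
        returns: x plus the positional encoding, after dropout
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
```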
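The `div_term` line computes \( 1 / 10000^{2i/d_{\text{model}}} \) in log space. A plain-Python sketch of the same formula for a single position (the helper name `sinusoidal_pe` is hypothetical), including a check that the exp/log form equals the direct power:

```python
import math


def sinusoidal_pe(pos, d_model):
    """PE(pos, 2i) = sin(pos / 10000^(2i/d_model)); PE(pos, 2i+1) = cos of the same angle."""
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):  # i runs over even dimensions, like arange(0, d_model, 2)
        div_term = math.exp(i * (-math.log(10000.0) / d_model))  # log-space form used above
        assert math.isclose(div_term, 1.0 / 10000 ** (i / d_model))
        pe[i] = math.sin(pos * div_term)
        if i + 1 < d_model:
            pe[i + 1] = math.cos(pos * div_term)
    return pe


print(sinusoidal_pe(0, 4))  # [0.0, 1.0, 0.0, 1.0]
```

Each (sin, cos) pair shares one frequency, so every pair lies on the unit circle and lower dimensions rotate faster with position than higher ones.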


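```python
def get_attn_pad_mask(seq_q, seq_k):
    """
    Build a padding mask that hides PAD positions (assumes the PAD token id is 0).
    seq_q: [batch_size, len_q]
    seq_k: [batch_size, len_k]
    returns:
        pad_attn_mask: [batch_size, len_q, len_k], True marks masked positions
    """
    batch_size, len_q = seq_q.size()  # seq_q is only used to get len_q for expand
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.eq(0).unsqueeze(1)  # insert a dim at index 1: [batch_size, 1, len_k]
    return pad_attn_mask.expand(batch_size, len_q, len_k)


def get_attn_subsequence_mask(seq):
    """
    Build a causal (subsequence) mask so each position cannot attend to later positions.
    seq:     [batch_size, tgt_len]
    returns: [batch_size, tgt_len, tgt_len], True marks masked (future) positions
    """
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # ones strictly above the diagonal
    # masked_fill_ expects a bool mask; keep it on the same device as the input
    subsequence_mask = torch.from_numpy(subsequence_mask).bool().to(seq.device)
    return subsequence_mask
```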
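What the two masks contain, sketched in plain Python for one sequence (PAD id 0, as the code assumes; the helper names are hypothetical). A decoder-only model such as GPT-2 needs the causal mask, and combining it with the padding mask is an element-wise OR:

```python
def pad_mask(seq_q, seq_k, pad_id=0):
    """[len_q][len_k]: True where the key token is PAD."""
    return [[tok == pad_id for tok in seq_k] for _ in seq_q]


def causal_mask(n):
    """[n][n]: True strictly above the diagonal (future positions), like np.triu(..., k=1)."""
    return [[j > i for j in range(n)] for i in range(n)]


seq = [5, 7, 9, 0]  # the last token is padding
combined = [
    [p or c for p, c in zip(prow, crow)]
    for prow, crow in zip(pad_mask(seq, seq), causal_mask(len(seq)))
]
for row in combined:
    print([int(v) for v in row])
# [0, 1, 1, 1]
# [0, 0, 1, 1]
# [0, 0, 0, 1]
# [0, 0, 0, 1]
```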


```python
# ==========================================================================================
# Scaled Dot-Product Attention

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        """
        Q:         [batch_size, n_heads, len_q, d_k]
        K:         [batch_size, n_heads, len_k, d_k]
        V:         [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, len_q, len_k], True marks masked positions
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(K.size(-1))  # scaled dot products
        scores = scores.masked_fill(attn_mask, -1e9)  # masked positions get ~zero weight after softmax

        attn = torch.softmax(scores, dim=-1)  # normalize over the key dimension
        context = torch.matmul(attn, V)  # weighted sum of values: [batch_size, n_heads, len_q, d_v]

        return context, attn
```
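A plain-Python sketch of the same computation for a single head and a single sequence (no batching; names are hypothetical), handy for checking the masking logic by hand. After subtracting the row maximum, a masked score of -1e9 underflows to exactly zero weight in the softmax:

```python
import math


def attention(Q, K, V, mask):
    """softmax(Q K^T / sqrt(d_k), with masked scores set to -1e9) V, for lists of row vectors."""
    d_k = len(K[0])
    out = []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        scores = [-1e9 if mask[i][j] else s for j, s in enumerate(scores)]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))])
    return out


n = 3
Q = K = [[1.0, 1.0]] * n  # identical queries and keys, so all unmasked scores are equal
V = [[1.0], [2.0], [3.0]]
causal = [[j > i for j in range(n)] for i in range(n)]
print(attention(Q, K, V, causal))  # [[1.0], [1.5], [2.0]]
```

Row \(i\) averages the first \(i + 1\) values because the causal mask hides every later position.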


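```python
# Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, d_k, d_v, n_heads):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads

        # Linear layers that project the input into the per-head subspaces
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)

        # Maps the concatenated head outputs back to d_model
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

        # Registered as submodules so they are trained, saved, and moved to the right device
        self.attention = ScaledDotProductAttention()
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """
        input_Q:   [batch_size, len_q, d_model]
        input_K:   [batch_size, len_k, d_model]
        input_V:   [batch_size, len_k, d_model]
        attn_mask: [batch_size, len_q, len_k]
        """
        residual, batch_size = input_Q, input_Q.size(0)

        # Project Q, K, V and split into heads: [batch_size, n_heads, seq_len, d_k or d_v]
        Q = self.W_Q(input_Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(input_K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(input_V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        # Repeat the mask for every head: [batch_size, n_heads, len_q, len_k]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)

        # Scaled dot-product attention
        context, attn = self.attention(Q, K, V, attn_mask)

        # Concatenate the heads: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, self.n_heads * self.d_v)

        # Project back to d_model, then residual connection + LayerNorm (post-LN)
        output = self.fc(context)
        return self.layer_norm(output + residual), attn
```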
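The `view(batch_size, -1, n_heads, d_k).transpose(1, 2)` step cuts each projected row of width \( n_{\text{heads}} d_k \) into \( n_{\text{heads}} \) consecutive slices of width \( d_k \); the later `transpose(1, 2).reshape(...)` concatenates them back. A plain-Python sketch for one sequence (function names are hypothetical):

```python
def split_heads(rows, n_heads, d_k):
    """[seq_len][n_heads*d_k] -> [n_heads][seq_len][d_k]; head h takes columns h*d_k:(h+1)*d_k."""
    return [[row[h * d_k:(h + 1) * d_k] for row in rows] for h in range(n_heads)]


def merge_heads(heads):
    """[n_heads][seq_len][d_k] -> [seq_len][n_heads*d_k] by concatenating heads per position."""
    seq_len = len(heads[0])
    return [[x for head in heads for x in head[t]] for t in range(seq_len)]


rows = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]  # seq_len=2, n_heads=3, d_k=2
heads = split_heads(rows, n_heads=3, d_k=2)
print(heads[1])  # [[2, 3], [8, 9]]
assert merge_heads(heads) == rows  # the round trip restores the original layout
```

Because each head owns a disjoint slice of one fused projection, the fused matrices have exactly as many parameters as \( n_{\text{heads}} \) separate per-head matrices.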


```python
# Position-wise Feed-Forward Network

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
        # Registered in __init__ so its gain and bias are learned and follow the module's device
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        """
        inputs: [batch_size, seq_len, d_model]
        """
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)  # residual connection + LayerNorm
```


```python
# Encoder Layer

class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_k, d_v, n_heads, d_ff):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        """
        enc_inputs:         [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        """
        # Self-attention: Q, K, and V all come from enc_inputs
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)  # [batch_size, src_len, d_model]
        return enc_outputs, attn
```

