![MTP](./img/MTP.png)

## Multi-Token Prediction (MTP) 总体思路

### 目标与动机
  - **目标**：在每个位置预测多个未来令牌（$t_{i+1}, t_{i+2}, \dots, t_{i+D}$）
  - **动机**：
    - 增加训练信号密度，提高数据效率
    - 帮助模型预规划表示，改善未来预测
    - 保持完整因果链（顺序预测而非并行）

### 架构概览
  - **模块数量**：$D$个顺序MTP模块（对应$D$个预测深度）
  - **共享组件**：嵌入层$\text{Emb}(\cdot)$和输出头$\text{OutHead}(\cdot)$与主模型共享
  - **专用组件**：每个深度$k$有专用Transformer块$\text{TRM}_k(\cdot)$和投影矩阵 $M_k$


## 公式原理

### 1. 表示结合与投影
  对于第$i$个令牌，在深度$k$的表示计算：
  $$h_i^k = M_k \left[ \text{RMSNorm}(h_i^{k-1}); \text{RMSNorm}(\text{Emb}(t_{i+k})) \right]$$
  其中：
  - $h_i^{k-1} \in \mathbb{R}^d$：前一个深度的表示
  - $\text{Emb}(t_{i+k}) \in \mathbb{R}^d$：未来令牌的嵌入
  - $M_k \in \mathbb{R}^{d \times 2d}$：投影矩阵
  - $[:,:]$：向量连接操作

**特殊情况**：当$k=1$时，$h_i^{k-1}$来自主模型输出。

### 2. Transformer处理
  将投影后的表示输入Transformer块：
  $$h_{1:T-k}^k = \text{TRM}_k(h_{1:T-k}^{tk})$$
  - $T$：序列长度
  - $t_{ij}$：切片操作（包含边界）

### 3. 概率分布计算
  使用共享输出头预测概率：
  $$P_{i+k+1}^k = \text{OutHead}(h_i^k)$$
  - $\text{OutHead}(\cdot)$：线性映射 + Softmax
  - $P_{i+k+1}^k \in \mathbb{R}^V$：词汇表上的概率分布
  - $V$：词汇表大小

## Loss计算

### 单个深度损失
  对于深度$k$的交叉熵损失：
  $$\mathcal{L}_{MTP}^{k} = -\frac{1}{T} \sum_{i=2+k}^{T+1} \log P_i^k[t_i]$$
  其中：
  - $t_i$：第$i$个位置的真实令牌
  - $P_i^k[t_i]$：模型对$t_i$的预测概率

### 整体MTP损失
  对所有深度取平均并加权：
  $$\mathcal{L}_{MTP} = \frac{\lambda}{D} \sum_{k=1}^{D} \mathcal{L}_{MTP}^k$$
  - $\lambda$：权重因子（超参数）
  - $D$：总预测深度

## 推理阶段

### 主要模式
  $$\text{推理} = \text{主模型（丢弃MTP模块）}$$
  - MTP模块仅用于训练增强
  - 推理时主模型独立工作，无需修改
  

In [1]:
import copy
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from typing import List, Dict, Any
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast, GradScaler



In [2]:
#获取model对应的 transformers block
model_path = "../model/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

In [3]:
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
import inspect

inspect.signature(Qwen2DecoderLayer.forward)

<Signature (self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[transformers.cache_utils.Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs]) -> torch.Tensor>

MTP

In [4]:
class MLPModule(nn.Module):
    """轻量 MLP 版本（作为 Transformer block 的 fallback）"""
    def __init__(self, hidden_size):
        super().__init__()
        self.ff1 = nn.Linear(2 * hidden_size, 4 * hidden_size)
        self.act = nn.GELU()
        self.ff2 = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, x):
        # x: (B, S, 2H)
        x = self.ff1(x)
        x = self.act(x)
        x = self.ff2(x)
        return x  # (B, S, H)

In [5]:
class TransformerWrapper(nn.Module):
    """
    从 main_model 中拿到一个 Transformer block
    对于 concat 后的 proj_down : 2H --> H
    """
    def __init__(self, block_module, hidden_size):
        super().__init__()
        # block_module 是从主模型复制来的单层 transformer block（nn.Module）
        self.block = copy.deepcopy(block_module)
        # 如果 block 的输入期望 H 维，但我们拼接后是 2H，先用线性降维
        self.input_proj = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, prev_hidden, input_embed, attention_mask=None, **kwargs):
        # prev_hidden & input_embed: (B, S, H)
        x = torch.cat([prev_hidden, input_embed], dim=-1)  # (B,S,2H)
        x = self.input_proj(x)  # (B,S,H)
        # 假定 block 的 forward 接口为 block(x, attention_mask=...)
        out = self.block(x, attention_mask=attention_mask, **kwargs)
        # block 可能返回 tuple，取第0项作为 hidden
        if isinstance(out, (tuple, list)):
            out = out[0]
        return out  # (B,S,H)

In [6]:
class MTPHead(nn.Module):
    """把 hidden -> vocab logits 的 head，便于做权重共享或替换"""
    def __init__(self, hidden_size, vocab_size, tie_embedding=None):
        super().__init__()
        self.linear = nn.Linear(hidden_size, vocab_size)
        # 权重绑定到 embedding 的 weight（weight tying）
        if tie_embedding is not None:
            self.linear.weight = tie_embedding

    def forward(self, hidden_states):
        logits = self.linear(hidden_states)  # (B, S, H) --> (B, S, V)
        return logits

In [7]:
class MTP(nn.Module):
    def __init__(self,
                 model,
                 predict_tokens_num: int = 5,
                 mtp_lambda: float = 0.5,
                 random_depth_rate: float = 1.0,  # 1.0 表示每个 batch 全部 depth 都计算；<1 表示随机采样部分 depth
                 use_mlp: bool = True,
                 freeze_base_model: bool = True,
                 use_peft: bool = False
                 ):
        """
        random_depth_rate:
          - 1.0: 训练时计算所有 MTP depth 的 loss
          - 0.5: 随机采样约一半的 depth 来计算（节省计算）
        """
        super().__init__()
        self.predict_tokens_num = predict_tokens_num
        self.mtp_lambda = mtp_lambda
        self.random_depth_rate = random_depth_rate
        self.use_mlp = use_mlp
        self.use_peft = use_peft    

        # 载入主模型（取 base_model 方便直接拿 last_hidden_state / embeddings）
        if self.use_peft:
            self.main_model = model
        else:
            self.main_model = model.base_model


        # 冻结基础模型参数（关键步骤！）
        if freeze_base_model:
            for param in self.main_model.parameters():
                param.requires_grad = False
            if hasattr(model, 'lm_head'):
                for param in model.lm_head.parameters():
                    param.requires_grad = True

        H = self.main_model.config.hidden_size
        V = self.main_model.config.vocab_size

        mtp_modules = []

        if use_mlp:
            for _ in range(self.predict_tokens_num-1):
                mtp_modules.append(MLPModule(H))
        else:
            transformers_block = self.main_model.layers[0]
            for _ in range(self.predict_tokens_num-1):
                mtp_modules.append(
                    TransformerWrapper(transformers_block, H)
                )
        self.mtp_modules = nn.ModuleList(mtp_modules)

        # 输出 head，默认与 embedding weight tying
        embedding_weight = self.main_model.get_input_embeddings().weight
        self.output_head = MTPHead(H, V, tie_embedding=embedding_weight)

    def forward_main(self, input_ids, attention_mask=None, **kwargs):
        """
        main_hidden_output: (B,S,H) and main_logits: (B,S,V)
        """
        with torch.no_grad():
            if self.use_peft:
                base = self.main_model.get_base_model()          
                outputs = base(input_ids=input_ids,
                            attention_mask=attention_mask,
                            output_hidden_states=True,           # 显式要求返回 hidden_states
                            **kwargs)
                last_hidden = outputs.hidden_states[-1] 
            else:    
                outputs = self.main_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
                last_hidden = outputs.last_hidden_state  # (B,S,H)
            logits = self.output_head(last_hidden)   # (B,S,V)          # 取最后一层
        return last_hidden, logits

    def forward_mtp_once(self, input_ids, prev_hidden, head_index, attention_mask=None, **kwargs):
        """
        mtp_hidden: (B,S,H) and mtp_logits: (B,S,V)
        如果使用 transformer wrapper，需要传 attention_mask
        """
        # embedding
        input_embed = self.main_model.get_input_embeddings()(input_ids)  # (B,S,H)
        if self.use_mlp:
            # 直接拼接并输入MLP
            concat = torch.cat([prev_hidden, input_embed], dim=-1)  # (B,S,2H)
            mtp_hidden = self.mtp_modules[head_index](concat)  # (B,S,H)
        else:
            # Transformer版本,使用TransformerWrapper
            module = self.mtp_modules[head_index]
            mtp_hidden = module(prev_hidden, input_embed, 
                                attention_mask=attention_mask, **kwargs)
        mtp_logits = self.output_head(mtp_hidden)  # (B,S,V)
        return mtp_hidden, mtp_logits

    def forward(self, input_ids, attention_mask=None, training=True, sample_depths=None, **kwargs):
        """
        若 training=True：默认返回 dict 包含 'head_main' 和 'mtp_head_i'（可能只包含部分 depth，受 sample_depths 控制）
        若 training=False：只返回 main head 的 logits（推理）
        sample_depths: 如果为 None，依据 random_depth_rate 决定是否随机采样 depths 以节省计算
        """
        outputs = {}
        main_hidden, main_logits = self.forward_main(input_ids, attention_mask=attention_mask, **kwargs)
        outputs['head_main'] = main_logits
        if not training:
            return outputs  # 推理只需主 head

        # 训练时决定计算哪些 depth
        D = self.predict_tokens_num - 1
        if sample_depths is None:
            if self.random_depth_rate >= 1.0:  # deafult
                sample_depths = list(range(D))
            else:
                # 随机采样 depth 索引 随机选择 num 个
                num = max(1, int(D * self.random_depth_rate))
                all_idx = list(range(D))
                sample_depths = sorted(torch.randperm(D)[:num].tolist())
                
        prev_hidden = main_hidden
        for idx in range(D):
            # 如果当前 depth 不在要计算的 sample_depths，仍需推进 prev_hidden（因为 MTP 是链式）
            prev_hidden, mtp_logits = self.forward_mtp_once(input_ids, prev_hidden, idx, attention_mask=attention_mask, **kwargs)
            # 只有在 sample_depths 中才把 logits 写入 outputs（避免不必要的内存）
            if idx in sample_depths:
                outputs[f'mtp_head_{idx}'] = mtp_logits

        return outputs

    def compute_loss(self, outputs: dict, labels: torch.Tensor):
        """
        统一计算 main loss 与 mtp losses。
        labels: (B, S), 且 pad 已经被 -100 填充
        返回： main_loss, mtp_loss (可能为 0), total_loss
        """
        device = labels.device
        main_logits = outputs['head_main']  # (B,S,V)
        B, S, V = main_logits.shape

        # main loss: predict t+1
        main_logits_flat = main_logits[:, :-1, :].reshape(-1, V)
        main_targets = labels[:, 1:].reshape(-1)
        main_loss = F.cross_entropy(main_logits_flat, main_targets, ignore_index=-100)

        # mtp losses: 对每个存在的 mtp_head_i 进行对齐计算
        mtp_losses = []
        for key in outputs:
            if not key.startswith('mtp_head_'):
                continue
            idx = int(key.split('_')[-1])  # 0-based head index
            mtp_logits = outputs[key]  # (B,S,V)
            offset = idx + 2  # head_index=0 -> predict t+2
            valid_len = S - offset
            if valid_len <= 0:
                continue
            logits = mtp_logits[:, :valid_len, :].reshape(-1, V)
            targets = labels[:, offset:offset+valid_len].reshape(-1)
            loss_i = F.cross_entropy(logits, targets, ignore_index=-100)
            mtp_losses.append(loss_i)

        if len(mtp_losses) > 0:
            mtp_loss = torch.stack(mtp_losses).mean()
            total_loss = main_loss + self.mtp_lambda * mtp_loss
        else:
            mtp_loss = torch.tensor(0.0, device=device)
            total_loss = main_loss

        return main_loss, mtp_loss, total_loss

    @torch.no_grad()
    def generate_autoregressive(self, input_ids, attention_mask=None, max_length=50):
        """
        简化的生成函数：不启用 speculative decoding，仅用主模型逐步生成。
        若需要复杂的 speculative decoding，可以在此基础上扩展。
        """
        self.eval()
        device = input_ids.device
        seq = input_ids.clone()
        batch_size = seq.size(0)
        for _ in range(max_length - seq.size(1)):
            outputs = self.forward(seq, attention_mask=attention_mask, training=False)
            logits = outputs['head_main']  # (B, S, V)
            next_logits = logits[:, -1, :]  # (B, V)
            next_token = torch.argmax(next_logits, dim=-1, keepdim=True)  # (B,1)
            seq = torch.cat([seq, next_token], dim=1)
        return seq


In [8]:
def train(model: MTP, dataloader, optimizer, writer, device='cuda',
          epochs=5, print_step=10, save_step=1000, save_path='../model/mtp/checkpoint',
          mtp_lambda=0.5, grad_clip=1.0):
    
    os.makedirs(save_path, exist_ok=True)
    scaler = GradScaler()
    model.to(device)
    steps = 0
    model.train()
    
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            optimizer.zero_grad()
            
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            # 混合精度前向传播
            with autocast():
                outputs = model(input_ids, attention_mask=None, training=True)
                main_loss, mtp_loss, total_loss = model.compute_loss(outputs, labels)

            # 正确的顺序：scale -> backward -> unscale -> clip -> step -> update
            scaler.scale(total_loss).backward()
            
            # 取消缩放梯度
            scaler.unscale_(optimizer)
            
            # 梯度裁剪（必须在unscale之后）
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            
            # 更新参数
            scaler.step(optimizer)
            scaler.update()

            # 记录日志
            if steps % print_step == 0:
                writer.add_scalar('train/main_loss', main_loss.item(), steps)
                writer.add_scalar('train/mtp_loss', mtp_loss.item(), steps)
                writer.add_scalar('train/total_loss', total_loss.item(), steps)
                print(f"[Epoch {epoch+1}] Step {steps}, main_loss={main_loss.item():.4f}, mtp_loss={mtp_loss.item():.4f}, total_loss={total_loss.item():.4f}")

            if steps % save_step == 0 and steps > 0:
                torch.save({
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'step': steps,
                    'scaler_state': scaler.state_dict()
                }, f"{save_path}/checkpoint_{steps}.pt")

            steps += 1

In [9]:
class MyDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, max_length=512):
        super().__init__()
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        user = sample["input"]
        assistant = sample["target"]

        # 构造 prompt
        q = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": user}],
            tokenize=False,
            add_generation_prompt=True
        )

        # 拼接答案（带 eos）
        a = assistant + self.tokenizer.eos_token

        q_input_ids = self.tokenizer(q)["input_ids"]
        a_input_ids = self.tokenizer(a)["input_ids"]

        input_ids = q_input_ids + a_input_ids
        labels = [-100] * len(q_input_ids) + a_input_ids

        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
            labels = labels[:self.max_length]

        return {
            "input_ids": input_ids,
            "labels": labels,
        }


In [10]:
class MyDataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
    
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        max_len = max(len(feature['input_ids']) for feature in features)
        input_ids = []
        labels = []
        attention_mask = []

        for feature in features:
            # pad input_ids
            padding_len = max_len - len(feature["input_ids"])
            input_ids.append(feature["input_ids"] + [self.tokenizer.pad_token_id] * padding_len)

            # pad labels (用 -100 而不是 pad_token_id)
            labels.append(feature["labels"] + [-100] * padding_len)

            # attention_mask
            attention_mask.append([1] * len(feature["input_ids"]) + [0] * padding_len)
            
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }


In [11]:
ds = load_dataset("YeungNLP/firefly-train-1.1M")
ds["train"]

Repo card metadata block was not found. Setting CardData to empty.


Dataset({
    features: ['kind', 'input', 'target'],
    num_rows: 1649399
})

In [12]:
dataset = ds["train"].shuffle(42).select(range(2000))
train_data = MyDataset(dataset, tokenizer, max_length=512)
collator = MyDataCollator(tokenizer)
loader = DataLoader(train_data, batch_size=2, collate_fn=collator)

In [13]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
use_peft=False
freeze_base_model = True
model = AutoModelForCausalLM.from_pretrained(model_path)
if use_peft:
    model = get_peft_model(model, lora_config)     
mtp_model = MTP(model=model,
                predict_tokens_num=5,
                mtp_lambda=0.3,
                use_mlp=True,
                freeze_base_model=freeze_base_model,
                use_peft=use_peft)             

In [14]:
def print_trainable_parameters(model):
    """
    打印模型的可训练参数数量
    """
    trainable_params = 0
    all_params = 0
    for name, param in model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    
    print(f"可训练参数总数: {trainable_params / 1e6:.2f}M")
    print(f"总参数: {all_params / 1e6:.2f}M")
    print(f"训练参数占比: {100 * trainable_params / all_params:.2f}%")
    return trainable_params, all_params

trainable_params, all_params = print_trainable_parameters(mtp_model)

可训练参数总数: 174.84M
总参数: 532.74M
训练参数占比: 32.82%


In [15]:
writer = SummaryWriter('../model/mtp3/runs')
optimizer = torch.optim.Adam(mtp_model.parameters(), lr=1e-4)
train(mtp_model, loader, optimizer, 
      writer, device='cuda', 
      epochs=1, print_step=10, 
      save_step=1000, save_path='../model/mtp3/checkpoint')

  attn_output = torch.nn.functional.scaled_dot_product_attention(


[Epoch 1] Step 0, main_loss=4.9228, mtp_loss=12.0181, total_loss=8.5282
[Epoch 1] Step 10, main_loss=4.0091, mtp_loss=9.8869, total_loss=6.9752
[Epoch 1] Step 20, main_loss=1.8060, mtp_loss=9.1845, total_loss=4.5614
[Epoch 1] Step 30, main_loss=2.4778, mtp_loss=8.7691, total_loss=5.1086
[Epoch 1] Step 40, main_loss=2.5200, mtp_loss=9.1255, total_loss=5.2576
[Epoch 1] Step 50, main_loss=3.0501, mtp_loss=8.8223, total_loss=5.6968


KeyboardInterrupt: 

![MTP loss](./img/MTP_loss.png)