# 投机推理（Speculative Decoding）

相关文章链接：[LLM提速利器：投机推理的原理与常见方案](https://zhuanlan.zhihu.com/p/1978037808544370747)

示例内容：
* 投机推理原理演示
* Ngram原理演示
* vLLM的示例


Author: kaiyuan

Email: kyxie@zju.edu.cn

## 1 投机推理原理演示

模拟草稿模型+大模型：草稿模型推理，大模型负责验证。

可修改草稿模型每步提议的字符个数（默认 3 个），例如：

spec_sim.speculative_step(3)  # 每次提议3个字符

In [None]:
import random
import time
from typing import List, Tuple, Callable

class SpeculativeInferenceSimulator:
    """Toy, character-level simulation of speculative decoding.

    A "small model" drafts a few characters per step (each drafted character
    is correct with probability ``_DRAFT_ACCURACY``), and a "large model"
    verifies the draft against ``target_text`` and keeps the longest matching
    prefix. When a draft is rejected outright, the step still advances by one
    character taken from the target, so every step makes progress.
    """

    # Probability that the simulated small model drafts the correct character.
    _DRAFT_ACCURACY = 0.7

    # Look-alike characters used to simulate small-model mistakes. The first
    # entry of every list is the correct character itself.
    _SIMILAR_CHARS = {
        "人": ["人", "入", "八"],
        "工": ["工", "二", "干"],
        "智": ["智", "知", "日"],
        "能": ["能", "熊", "态"],
        "是": ["是", "定", "日"],
        "未": ["未", "末", "木"],
        "来": ["来", "米", "夹"],
        "科": ["科", "料", "和"],
        "技": ["技", "枝", "支"],
        "发": ["发", "友", "皮"],
        "展": ["展", "层", "屋"],
        "的": ["的", "得", "地"],
        "核": ["核", "该", "孩"],
        "心": ["心", "必", "芯"],
        "驱": ["驱", "欧", "区"],
        "动": ["动", "运", "力"],
        "力": ["力", "刀", "九"],
        "。": ["。", ".", "．"],
    }

    def __init__(self, target_text: str = "人工智能是未来科技发展的核心驱动力。"):
        """Create a simulator that tries to reproduce ``target_text``."""
        self.target_text = target_text
        # Index into target_text of the next character still to be generated.
        self.current_index = 0
        # Text accepted so far.
        self.generated_text = ""
        # One dict per completed speculative step (see speculative_step).
        self.steps_log = []

    def small_model_generate(self, num_chars: int = 3) -> str:
        """Draft up to ``num_chars`` characters with the simulated small model.

        Each drafted character equals the target character with probability
        ``_DRAFT_ACCURACY``; otherwise it is picked at random from the
        look-alike list (which also contains the correct character).

        Returns:
            The drafted string, or "" when the target is already complete or
            ``num_chars`` is not positive.
        """
        if self.current_index >= len(self.target_text):
            return ""

        remaining = self.target_text[self.current_index:]
        draft_len = max(min(num_chars, len(remaining)), 0)

        drafted = []
        for correct_char in remaining[:draft_len]:
            if random.random() < self._DRAFT_ACCURACY:
                drafted.append(correct_char)
            else:
                # Simulated mistake: pick from the look-alike candidates.
                drafted.append(random.choice(self._get_similar_chars(correct_char)))

        return "".join(drafted)

    def large_model_validate(self, candidate: str) -> Tuple[bool, str]:
        """Verify ``candidate`` against the not-yet-generated part of the target.

        Returns:
            ``(is_valid, valid_part)`` where ``valid_part`` is the longest
            prefix of ``candidate`` that matches the target from
            ``current_index`` on, and ``is_valid`` is True only when the whole
            candidate matched. An empty candidate yields ``(False, "")``.
        """
        if not candidate:
            return False, ""

        remaining = self.target_text[self.current_index:]

        # zip stops at the shorter string, so a candidate that runs past the
        # end of the target keeps its verified prefix instead of losing it.
        matched = 0
        for drafted_char, expected_char in zip(candidate, remaining):
            if drafted_char != expected_char:
                break
            matched += 1

        return matched == len(candidate), candidate[:matched]

    def _get_similar_chars(self, char: str) -> List[str]:
        """Return look-alike candidates for ``char`` (used to simulate errors).

        Characters without a table entry fall back to ``[char, "?", "X"]``.
        A fresh list is returned so callers may mutate it freely.
        """
        return list(self._SIMILAR_CHARS.get(char, [char, "?", "X"]))

    def speculative_step(self, small_model_chars: int = 3) -> bool:
        """Run one draft-then-verify step and print a trace of it.

        The verified prefix of the draft is appended to ``generated_text``.
        If nothing was verified, one character is taken directly from the
        target so the step still advances; this happens at every position,
        including the very first character.

        Args:
            small_model_chars: Number of characters the small model drafts.

        Returns:
            False when the target is already complete or the draft is empty;
            True after a step has been executed and logged.
        """
        if self.current_index >= len(self.target_text):
            print("✓ 文本生成完成！")
            return False

        step_info = {
            "step": len(self.steps_log) + 1,
            "current_text": self.generated_text,
            "remaining_target": self.target_text[self.current_index:],
            "small_model_generated": "",
            "validation_result": "",
            "accepted": False,
            "final_text": ""
        }

        print(f"\n{'='*60}")
        print(f"步骤 {step_info['step']}: 当前已生成: '{self.generated_text}'")
        print(f"目标剩余: '{self.target_text[self.current_index:]}'")

        # Small model drafts a candidate.
        candidate = self.small_model_generate(small_model_chars)
        step_info["small_model_generated"] = candidate

        print(f"小模型生成: '{candidate}'")
        print("-"*40)

        if not candidate:
            return False

        # Large model verifies the draft.
        is_valid, valid_part = self.large_model_validate(candidate)
        step_info["validation_result"] = f"验证结果: {'通过' if is_valid else '拒绝'} | 有效部分: '{valid_part}'"

        print(f"大模型验证: {step_info['validation_result']}")

        if valid_part:
            # Accept the verified prefix.
            self.generated_text += valid_part
            self.current_index += len(valid_part)
            step_info["accepted"] = True
            step_info["final_text"] = self.generated_text

            print(f"✓ 接受: '{valid_part}'")
            print(f"更新文本: '{self.generated_text}'")
        else:
            print("✗ 全部拒绝，继续尝试...")

            # Fallback: the draft was rejected outright, so emit a single
            # character from the target. Applied at index 0 as well, otherwise
            # a rejected first draft would repeat the step without progress.
            single_char = self.target_text[self.current_index]
            self.generated_text += single_char
            self.current_index += 1
            step_info["final_text"] = self.generated_text
            print(f"→ 回退策略: 接受单字符 '{single_char}'")
            print(f"更新文本: '{self.generated_text}'")

        self.steps_log.append(step_info)
        return True

    def run_full_generation(self, max_steps: int = 20,
                            small_model_chars: int = 3,
                            delay: float = 0.5):
        """Run speculative steps until the target is complete, then summarize.

        Args:
            max_steps: Upper bound on the number of speculative steps.
            small_model_chars: Characters drafted per step.
            delay: Seconds to sleep after each step to simulate processing
                time; values <= 0 disable the pause.
        """
        print("开始投机推理模拟")
        print(f"目标文本: '{self.target_text}'")
        print("="*60)

        steps = 0
        while steps < max_steps and self.current_index < len(self.target_text):
            if not self.speculative_step(small_model_chars):
                break
            steps += 1
            if delay > 0:
                time.sleep(delay)  # simulated processing time

        print(f"\n{'='*60}")
        print("生成总结:")
        print(f"目标文本: '{self.target_text}'")
        print(f"生成文本: '{self.generated_text}'")
        print(f"准确率: {self._calculate_accuracy():.1%}")
        print(f"总步数: {steps}")
        print(f"文本长度: {len(self.generated_text)}")

        if self.generated_text == self.target_text:
            print("完全匹配！")
        else:
            print("存在差异")

    def _calculate_accuracy(self) -> float:
        """Return the fraction of target characters reproduced in place.

        Positions are compared pairwise over the shorter of the two strings
        and the match count is divided by the target length. Returns 0.0 when
        either string is empty.
        """
        if not self.generated_text or not self.target_text:
            return 0.0

        matches = sum(
            1 for produced, expected in zip(self.generated_text, self.target_text)
            if produced == expected
        )
        return matches / len(self.target_text)

    def print_statistics(self):
        """Print one table row per logged step."""
        print("\n 详细统计:")
        print(f"{'步骤':<4} {'小模型输出':<10} {'验证结果':<20} {'接受':<6} {'当前文本'}")
        print("-"*80)

        for log in self.steps_log:
            small_out = log['small_model_generated'] or 'None'

            # Truncate long validation messages so the column stays aligned.
            valid_res = log['validation_result']
            if len(valid_res) > 18:
                valid_res = valid_res[:18] + '...'

            accepted = '✓' if log['accepted'] else '✗'

            current_text = log['current_text']
            if log['small_model_generated']:
                current_text += f"+{log['small_model_generated']}"

            print(f"{log['step']:<4} {small_out:<10} {valid_res:<20} {accepted:<6} {current_text}")



def demo_rejection_sampling():
    """Print a walkthrough comparing plain and speculative generation.

    Part 1 emits the target one character per step. Part 2 drives a
    SpeculativeInferenceSimulator for at most 10 steps, drafting 3 characters
    per step. Short sleeps pace the printed output.
    """
    target = "人工智能是未来科技发展的核心驱动力。"

    print("\n\n 拒绝采样优化演示")
    print("="*60)
    print(f"目标文本: {target}")
    print("\n传统生成 vs 投机推理对比:")

    # Baseline: the large model emits exactly one character per step.
    print("\n1. 传统大模型生成（逐个字符）:")
    for position, char in enumerate(target, start=1):
        print(f"   步骤 {position}: 生成 '{char}' -> '{target[:position]}'")
        time.sleep(0.1)

    # Speculative run: the small model proposes, the large model verifies.
    print("\n2. 投机推理生成（小模型提议，大模型验证）:")
    spec_sim = SpeculativeInferenceSimulator(target)
    for _ in range(10):
        if spec_sim.current_index >= len(target):
            break
        spec_sim.speculative_step(3)  # draft 3 characters per step
        time.sleep(0.2)

# Run the comparison demo when this cell/script is executed directly.
if __name__ == "__main__":
    demo_rejection_sampling()



 拒绝采样优化演示
目标文本: 人工智能是未来科技发展的核心驱动力。

传统生成 vs 投机推理对比:

1. 传统大模型生成（逐个字符）:
   步骤 1: 生成 '人' -> '人'
   步骤 2: 生成 '工' -> '人工'
   步骤 3: 生成 '智' -> '人工智'
   步骤 4: 生成 '能' -> '人工智能'
   步骤 5: 生成 '是' -> '人工智能是'
   步骤 6: 生成 '未' -> '人工智能是未'
   步骤 7: 生成 '来' -> '人工智能是未来'
   步骤 8: 生成 '科' -> '人工智能是未来科'
   步骤 9: 生成 '技' -> '人工智能是未来科技'
   步骤 10: 生成 '发' -> '人工智能是未来科技发'
   步骤 11: 生成 '展' -> '人工智能是未来科技发展'
   步骤 12: 生成 '的' -> '人工智能是未来科技发展的'
   步骤 13: 生成 '核' -> '人工智能是未来科技发展的核'
   步骤 14: 生成 '心' -> '人工智能是未来科技发展的核心'
   步骤 15: 生成 '驱' -> '人工智能是未来科技发展的核心驱'
   步骤 16: 生成 '动' -> '人工智能是未来科技发展的核心驱动'
   步骤 17: 生成 '力' -> '人工智能是未来科技发展的核心驱动力'
   步骤 18: 生成 '。' -> '人工智能是未来科技发展的核心驱动力。'

2. 投机推理生成（小模型提议，大模型验证）:

步骤 1: 当前已生成: ''
目标剩余: '人工智能是未来科技发展的核心驱动力。'
小模型生成: '八工智'
----------------------------------------
大模型验证: 验证结果: 拒绝 | 有效部分: ''
✗ 全部拒绝，继续尝试...

步骤 2: 当前已生成: ''
目标剩余: '人工智能是未来科技发展的核心驱动力。'
小模型生成: '人干智'
----------------------------------------
大模型验证: 验证结果: 拒绝 | 有效部分: '人'
✓ 接受: '人'
更新文本: '人'

步骤 3: 当前已生成: '人'
目标剩余: '工智能是未来科技发展的核心驱动

## 2 Ngram原理演示

注意：由于没有固定 random 的 seed，用例每次运行的结果会有差异。

In [None]:
import random
from collections import defaultdict
from typing import List, Dict, Tuple

class SimpleNGram:
    """Minimal count-based n-gram language model.

    Sentences are padded with ``n - 1`` ``'<s>'`` start markers and one
    ``'</s>'`` end marker. ``counts[context][word]`` stores how often ``word``
    followed the ``(n - 1)``-token ``context`` tuple in the training corpus.
    """

    def __init__(self, n: int = 2):
        """Create an untrained model of order ``n``.

        Raises:
            ValueError: If ``n`` is smaller than 1.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.n = n  # order of the n-gram
        # context tuple -> {next word -> occurrence count}
        self.counts = defaultdict(lambda: defaultdict(int))
        self.vocab = set()  # every token seen, including <s> and </s>

    def train(self, corpus: List[List[str]]):
        """Accumulate n-gram counts from a corpus of tokenized sentences."""
        context_len = self.n - 1
        for sentence in corpus:
            tokens = ['<s>'] * context_len + sentence + ['</s>']
            self.vocab.update(tokens)

            for start in range(len(tokens) - self.n + 1):
                context = tuple(tokens[start:start + context_len])
                word = tokens[start + context_len]
                self.counts[context][word] += 1

    def get_probability(self, word: str, context: Tuple[str, ...]) -> float:
        """Return the relative frequency P(word | context).

        Returns 0.0 when the context or the word under it was never seen.
        """
        # .get() keeps the lookup read-only: indexing the defaultdict would
        # insert an empty entry for an unseen context.
        followers = self.counts.get(context)
        if not followers or word not in followers:
            return 0.0

        return followers[word] / sum(followers.values())

    def generate_next(self, context: Tuple[str, ...], k: int = 3) -> List[Tuple[str, float]]:
        """Return up to ``k`` ``(word, probability)`` pairs, most likely first.

        Returns an empty list for an unseen context. Ties keep the order in
        which the words were first counted (the sort is stable).
        """
        followers = self.counts.get(context)
        if not followers:
            return []

        total = sum(followers.values())  # computed once, not once per word
        ranked = sorted(
            ((word, count / total) for word, count in followers.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return ranked[:k]

    def generate_sentence(self, max_len: int = 10) -> str:
        """Sample a sentence of at most ``max_len`` tokens, joined by spaces.

        Sampling starts from the all-``'<s>'`` context and stops at ``'</s>'``,
        at an unseen context, or after ``max_len`` tokens. The model's counts
        are not modified.
        """
        context = tuple(['<s>'] * (self.n - 1))
        result = []

        for _ in range(max_len):
            # Read-only lookup; see get_probability.
            followers = self.counts.get(context)
            if not followers:
                break

            # Raw counts serve as sampling weights; random.choices normalizes.
            words = list(followers.keys())
            weights = list(followers.values())
            next_word = random.choices(words, weights=weights, k=1)[0]

            if next_word == '</s>':
                break

            result.append(next_word)
            # Slide the context window; a unigram model has an empty context.
            context = context[1:] + (next_word,) if self.n > 1 else ()

        return ' '.join(result)


# 运行演示
if __name__ == "__main__":
    # Banner for the small training walkthrough.
    print("\n" + "=" * 50)
    print("简单训练演示")

    # Toy corpus: each sentence is already split into tokens.
    corpus = [
        ["我", "喜欢", "吃", "苹果"],
        ["我", "喜欢", "吃", "香蕉"],
        ["他", "喜欢", "吃", "西瓜"],
        ["她", "喜欢", "吃", "葡萄"],
        ["我们", "都", "喜欢", "水果"],
    ]

    # Fit a bigram model (n=2) on the corpus.
    bigram = SimpleNGram(n=2)
    bigram.train(corpus)

    # Query the most likely continuations of a one-token context.
    test_context = ("喜欢",)
    next_words = bigram.generate_next(test_context)

    print("\n在 '喜欢' 之后可能出现的词:")
    for candidate, probability in next_words:
        print(f"  '{candidate}': {probability:.2f}")

    # Sample three sentences; results vary between runs because the random
    # module is not seeded here.
    print("\n生成的句子示例:")
    for number in range(1, 4):
        print(f"  示例{number}: {bigram.generate_sentence()}")


简单训练演示

在 '喜欢' 之后可能出现的词:
  '吃': 0.80
  '水果': 0.20

生成的句子示例:
  示例1: 我们 都 喜欢 水果
  示例2: 她 喜欢 水果
  示例3: 我们 都 喜欢 吃 苹果


## 3 vLLM运行示例

运行要求：在支持vLLM的环境中运行用例。

社区仍在不断改进相关用例，具体可参考 vLLM 官网的最新用例。

### 3.1 草稿模型

In [None]:
# 在vLLM中使用投机推理的示例：
# 注意：草稿模型的功能需要在0.9.x中使用，v0.10.x中使用会报NotImplementedError错误。
from vllm import LLM, SamplingParams


# Single demo prompt; the text is passed to the model unchanged.
prompts = ["InfraTech是什么类型的公众号？"]
sampling_params = SamplingParams(temperature=0.9)

# Speculative decoding settings for the draft-model variant: "model" names the
# draft model and "num_speculative_tokens" is set to 5.
# NOTE(review): exact semantics of these keys are defined by vLLM; confirm
# against the documentation of the installed version.
draft_speculative_config = {
    "model": "Qwen/Qwen2.5-VL-3B-Instruct",
    "num_speculative_tokens": 5,
}

llm = LLM(
    model="Qwen/Qwen3-VL-235B-A22B-Thinking",
    tensor_parallel_size=1,
    speculative_config=draft_speculative_config,
)
outputs = llm.generate(prompts, sampling_params)

# Print each prompt next to the text of its first returned completion.
for result in outputs:
    print(f"Prompt: {result.prompt!r}, Generated text: {result.outputs[0].text!r}")

### 3.2 N-grams

In [None]:
from vllm import LLM, SamplingParams

# Single demo prompt; the text is passed to the model unchanged.
prompts = ["如何关注InfraTech公众号？"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Speculative decoding settings for the "ngram" method.
# NOTE(review): "prompt_lookup_max" presumably bounds the n-gram window looked
# up in the prompt; confirm against the vLLM documentation.
ngram_speculative_config = {
    "method": "ngram",
    "num_speculative_tokens": 5,
    "prompt_lookup_max": 4,
}

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_config=ngram_speculative_config,
)
outputs = llm.generate(prompts, sampling_params)

# Print each prompt next to the text of its first returned completion.
for result in outputs:
    print(f"Prompt: {result.prompt!r}, Generated text: {result.outputs[0].text!r}")

### 3.3 Suffix Decoding/MTP

In [None]:
from vllm import LLM, SamplingParams

# Single demo prompt; the text is passed to the model unchanged.
prompts = ["如何关注InfraTech公众号？"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Speculative decoding settings for the "suffix" method.
# The original inline notes mention an alternative: method "MTP" with
# "num_speculative_tokens" set to 1.
# NOTE(review): whether the model below supports MTP is not shown here; confirm
# before switching.
suffix_speculative_config = {
    "method": "suffix",
    "num_speculative_tokens": 5,
}

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_config=suffix_speculative_config,
)
outputs = llm.generate(prompts, sampling_params)

# Print each prompt next to the text of its first returned completion.
for result in outputs:
    print(f"Prompt: {result.prompt!r}, Generated text: {result.outputs[0].text!r}")