# Introduction

> Guide article: [09. Understanding Beam Search in Depth: Principles, Examples, and Code Implementation](https://github.com/Hoper-J/LLM-Guide-and-Demos-zh_CN/blob/master/09.%20深入理解%20Beam%20Search：原理%2C%20示例与代码实现.md#具体是怎么处理-eos-的)

Run online: [Kaggle](https://www.kaggle.com/code/aidemos/07-beam-search) | [Colab](https://colab.research.google.com/drive/1apYBAQ6HNlo4xJDBT0RtUCgmNo_mQVXF?usp=sharing)

# Example: Process Walkthrough

![Process walkthrough](../Guide/assets/%E5%9B%BE%E7%89%87%201-6584229.png)
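The scores in this walkthrough are sums of log-probabilities rather than products of probabilities. The short sketch below checks that the two are equivalent for the prefix `AC`, using the values 0.4 and 0.4 from the toy table in `get_next_probs`, and illustrates one common reason for working in log space: a long product of small probabilities underflows. The variable names and the 400-step example are assumptions made for illustration only.

```python
import math

# Probabilities taken from the toy table in this section:
# P(A | "") = 0.4 and P(C | "A") = 0.4, so P("AC") = 0.16.
step_probs = [0.4, 0.4]

# Summing logs gives the log of the product (up to floating-point rounding).
log_score = sum(math.log(p) for p in step_probs)
product = math.prod(step_probs)
print(log_score, math.log(product))

# Illustrative assumption: 400 steps, each with probability 0.1.
# The plain product underflows to 0.0; the log-sum stays finite.
many = [0.1] * 400
underflowed = math.prod(many)
log_sum = sum(math.log(p) for p in many)
print(underflowed, log_sum)
```

Because the logarithm is monotonic, ranking candidates by summed log-probability selects the same sequences as ranking by the product would.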


In [1]:
import math

def beam_search(initial_sequence, beam_width, max_length, vocab, get_next_probs):
    # Note: `vocab` is accepted for completeness but is not used in this demo;
    # the candidate tokens come from get_next_probs.
    beam = [(initial_sequence, 0.0)]  # (sequence, log_prob)
    completed = []

    for step in range(max_length):
        print(f"\nStep {step + 1}:")
        all_candidates = []
        for seq, score in beam:
            # A sequence that already ends with <eos> is moved to `completed`
            # and is not expanded further
            if seq.endswith('<eos>'):
                completed.append((seq, score))
                print(f"Completed sequence: {seq}, score {score}")
                continue
            next_probs = get_next_probs(seq)
            print(f"Expanding sequence: {seq}, current score {score}")
            for token, prob in next_probs.items():
                new_seq = seq + token
                new_score = score + math.log(prob)
                all_candidates.append((new_seq, new_score))
                print(f"  Candidate sequence: {new_seq}, score {new_score}")

        # Sort all candidates by score in descending order and keep the top beam_width
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beam = all_candidates[:beam_width]

        # Print the sequences kept in the beam
        print(f"\nTop {beam_width} sequences selected:")
        for seq, score in beam:
            print(f"  {seq}, score {score}")

        # Exit the loop if there is nothing left to expand
        if not beam:
            break

    # Add whatever is still in the beam to the completed sequences
    completed += beam

    # Sort the completed sequences by score in descending order; the first one is the best
    completed.sort(key=lambda x: x[1], reverse=True)

    print("\nAll completed sequences:")
    for seq, score in completed:
        print(f"  {seq}, score {score}")

    return completed[0][0]

# The probabilities used in the earlier example
def get_next_probs(seq):
    probs = {
        "": {"A": 0.4, "B": 0.3, "C": 0.2, "<eos>": 0.1},
        "A": {"A": 0.3, "B": 0.1, "C": 0.4, "<eos>": 0.2},
        "B": {"A": 0.1, "B": 0.1, "C": 0.3, "<eos>": 0.5},
        "AC": {"A": 0.1, "B": 0.2, "C": 0.5, "<eos>": 0.2},
    }
    # Any prefix not listed above is forced to end with <eos>
    return probs.get(seq, {"<eos>": 1.0})

initial_sequence = ""
beam_width = 2
max_length = 5
vocab = {"A", "B", "C", "<eos>"}

best_sequence = beam_search(initial_sequence, beam_width, max_length, vocab, get_next_probs)
print("\nBest sequence:", best_sequence)


Step 1:
Expanding sequence: , current score 0.0
  Candidate sequence: A, score -0.916290731874155
  Candidate sequence: B, score -1.2039728043259361
  Candidate sequence: C, score -1.6094379124341003
  Candidate sequence: <eos>, score -2.3025850929940455

Top 2 sequences selected:
  A, score -0.916290731874155
  B, score -1.2039728043259361

Step 2:
Expanding sequence: A, current score -0.916290731874155
  Candidate sequence: AA, score -2.120263536200091
  Candidate sequence: AB, score -3.2188758248682006
  Candidate sequence: AC, score -1.83258146374831
  Candidate sequence: A<eos>, score -2.525728644308255
Expanding sequence: B, current score -1.2039728043259361
  Candidate sequence: BA, score -3.506557897319982
  Candidate sequence: BB, score -3.506557897319982
  Candidate sequence: BC, score -2.4079456086518722
  Candidate sequence: B<eos>, score -1.8971199848858813

Top 2 sequences selected:
  AC, score -1.83258146374831
  B<eos>, score -1.8971199848858813

Step 3:
Expanding sequence: AC, current score -1.83258146374831
  Candidate sequence: ACA, score -4.135166556742355
  Candidate sequence: ACB, score -3.4420193761824103
  Candidate sequence: ACC, score -2.525728644308255
  Candidate sequence: AC<eos>, score -3.4420193761824103
Completed sequence: B<eos>, score -1.8971199848858813

Top 2 sequences selected:
  ACC, score -2.525728644308255
  ACB, score -3.4420193761824103

Step 4:
Expanding sequence: ACC, current score -2.52572864430
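The trace above is cut off partway through step 4. As a sketch, the same selection rule can be rerun without the print statements to see the final ranking. `beam_search_quiet` and `PROBS` are hypothetical names introduced here; the probability table and the logic are copied from the cell above.

```python
import math

# Toy next-token table copied from the walkthrough cell.
PROBS = {
    "": {"A": 0.4, "B": 0.3, "C": 0.2, "<eos>": 0.1},
    "A": {"A": 0.3, "B": 0.1, "C": 0.4, "<eos>": 0.2},
    "B": {"A": 0.1, "B": 0.1, "C": 0.3, "<eos>": 0.5},
    "AC": {"A": 0.1, "B": 0.2, "C": 0.5, "<eos>": 0.2},
}

def get_next_probs(seq):
    # Any prefix not in the table is forced to end with <eos>.
    return PROBS.get(seq, {"<eos>": 1.0})

def beam_search_quiet(beam_width, max_length):
    """Hypothetical helper: same rule as beam_search, but returns the ranking."""
    beam = [("", 0.0)]  # (sequence, log_prob)
    completed = []
    for _ in range(max_length):
        candidates = []
        for seq, score in beam:
            if seq.endswith("<eos>"):
                completed.append((seq, score))
                continue
            for token, prob in get_next_probs(seq).items():
                candidates.append((seq + token, score + math.log(prob)))
        # Keep the beam_width highest-scoring candidates.
        candidates.sort(key=lambda x: x[1], reverse=True)
        beam = candidates[:beam_width]
        if not beam:
            break
    completed += beam
    completed.sort(key=lambda x: x[1], reverse=True)
    return completed

ranking = beam_search_quiet(beam_width=2, max_length=5)
for seq, score in ranking:
    print(f"{seq}: {score:.4f}")
```

With this table the ranking should come out as `B<eos>`, `ACC<eos>`, `ACB<eos>`, so the short sequence `B<eos>` (probability 0.3 × 0.5 = 0.15) beats the longer `ACC<eos>` (0.4 × 0.4 × 0.5 = 0.08). Since raw log-probabilities only decrease as tokens are added, this scoring tends to favor shorter sequences.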

# Example: Using the Hugging Face Transformers Library

In [None]:
!uv add transformers
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [3]:
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Suppress FutureWarning messages
warnings.filterwarnings("ignore", category=FutureWarning)

# Model name
model_name = "distilgpt2"

# Load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Move the model to the available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Put the model in evaluation mode
model.eval()

# Input text
input_text = "Hello GPT"

# Encode the input text and build an attention mask
# (all ones, since there is a single unpadded prompt)
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
attention_mask = torch.ones_like(inputs).to(device)

# Generate text with beam search
beam_width = 5
with torch.no_grad():
    outputs = model.generate(
        inputs,
        attention_mask=attention_mask,
        max_length=50,
        num_beams=beam_width,  # the beam width is passed as num_beams
        no_repeat_ngram_size=2,
        early_stopping=True,  # stop as soon as num_beams finished candidates are available
        pad_token_id=tokenizer.eos_token_id
    )

# Decode the generated text
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:")
print(generated_text)


Generated text:
Hello GPT.

This article was originally published on The Conversation. Read the original article.
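The `no_repeat_ngram_size=2` argument used above prevents any bigram from appearing twice in the generated sequence. The sketch below illustrates that idea in plain Python on word tokens. It is not the library's implementation, and `banned_next_tokens` is a hypothetical helper written for this explanation.

```python
def banned_next_tokens(tokens, ngram_size):
    """Return the tokens that would complete an n-gram already present in `tokens`."""
    if ngram_size < 1 or len(tokens) + 1 < ngram_size:
        return set()
    # The last (ngram_size - 1) tokens are the prefix of the next n-gram.
    prefix = tuple(tokens[len(tokens) - (ngram_size - 1):])
    banned = set()
    for i in range(len(tokens) - ngram_size + 1):
        ngram = tuple(tokens[i:i + ngram_size])
        # If an earlier n-gram starts with the same prefix,
        # its last token must not be generated again here.
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    return banned

history = ["the", "cat", "sat", "on", "the"]
print(banned_next_tokens(history, ngram_size=2))
```

Here the sequence ends with `the`, and the bigram `the cat` has already occurred, so `cat` is the only banned continuation. During generation, banned tokens would be given a score of negative infinity before the beam candidates are ranked.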


## Comparing Outputs Across Beam Widths

In [4]:
# Input text
input_text = "Hello GPT"

# Encode the input text and build an attention mask
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
attention_mask = torch.ones_like(inputs).to(device)

# Beam widths to compare (a width of 1 is equivalent to greedy search)
beam_widths = [1, 3, 5]

# Generate and print the result for each beam width
for beam_width in beam_widths:
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            attention_mask=attention_mask,
            max_length=50,
            num_beams=beam_width,
            no_repeat_ngram_size=2,
            early_stopping=True,
            pad_token_id=tokenizer.eos_token_id
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Generated text with beam width {beam_width}:")
    print(generated_text)
    print('-' * 50)


Generated text with beam width 1:
Hello GPT is a free and open source software project that aims to provide a platform for developers to build and use GPGP-based GPSP based GPCs. GPP is an open-source software development platform that is designed to
--------------------------------------------------
Generated text with beam width 3:
Hello GPT.

This article is part of a series of articles on the topic, and will be updated as more information becomes available.
--------------------------------------------------
Generated text with beam width 5:
Hello GPT.

This article was originally published on The Conversation. Read the original article.
--------------------------------------------------
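A beam width of 1 keeps only the single best token at each step, which is greedy decoding. The toy probability table from the first example gives a small illustration of why a wider beam can help: the greedy path commits to `A` first and ends at a lower-probability sequence than `B<eos>`. This is a sketch under the assumption that the toy table is reused unchanged; `greedy_decode` and `PROBS` are hypothetical names.

```python
import math

# Toy next-token table copied from the first example.
PROBS = {
    "": {"A": 0.4, "B": 0.3, "C": 0.2, "<eos>": 0.1},
    "A": {"A": 0.3, "B": 0.1, "C": 0.4, "<eos>": 0.2},
    "B": {"A": 0.1, "B": 0.1, "C": 0.3, "<eos>": 0.5},
    "AC": {"A": 0.1, "B": 0.2, "C": 0.5, "<eos>": 0.2},
}

def get_next_probs(seq):
    # Any prefix not in the table is forced to end with <eos>.
    return PROBS.get(seq, {"<eos>": 1.0})

def greedy_decode(max_length):
    """Pick the single most probable token at every step (beam width 1)."""
    seq, score = "", 0.0
    for _ in range(max_length):
        if seq.endswith("<eos>"):
            break
        token, prob = max(get_next_probs(seq).items(), key=lambda kv: kv[1])
        seq += token
        score += math.log(prob)
    return seq, score

greedy_seq, greedy_score = greedy_decode(max_length=5)
# Log-probability of B<eos>, read directly from the table: 0.3 * 0.5.
b_eos_score = math.log(0.3) + math.log(0.5)

print(greedy_seq, round(math.exp(greedy_score), 2))
print("B<eos>", round(math.exp(b_eos_score), 2))
```

Greedy decoding ends at `ACC<eos>` with probability 0.08, while `B<eos>` has probability 0.15. A beam of width 2 keeps `B` alive after the first step and can therefore reach it.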
