## Decoding Strategies for Text Generation

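The decoding strategy of autoregressive models such as GPT directly affects the quality of a large model's output. When predicting the $N$-th token, the model uses the preceding $N-1$ tokens to compute the conditional probability distribution of the $N$-th token, $P(w_N | w_1, \cdots, w_{N-1})$.
The decoding strategy then uses this predicted distribution to decide which token the model outputs at position $N$. Common decoding strategies include: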
* Greedy Search
* Beam Search
* Sample
* Beam Sample
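The role of a decoding strategy can be sketched with a tiny stand-in for a language model. `TOY_LM` is a hypothetical next-token table with invented probabilities. A real model such as GPT-2 conditions on the whole prefix, but this toy looks only at the last token. `decode` is a generic autoregressive loop in which the strategy is simply the function that picks a token from $P(w_N \mid w_1, \cdots, w_{N-1})$. None of these names come from transformers.

```python
from typing import Callable, Dict, List

# Hypothetical toy "language model" (probabilities invented for illustration).
# A real model conditions on the whole prefix w_1..w_{N-1}; this table only
# looks at the last token to keep the example small.
TOY_LM: Dict[str, Dict[str, float]] = {
    "The": {"nice": 0.5, "dog": 0.4, "car": 0.1},
    "nice": {"woman": 0.4, "house": 0.3, "guy": 0.3},
    "dog": {"has": 0.9, "runs": 0.05, "and": 0.05},
    "car": {"is": 0.3, "drives": 0.5, "turns": 0.2},
}


def next_token_dist(prefix: List[str]) -> Dict[str, float]:
    """Toy P(w_N | w_1, ..., w_{N-1}); an empty dict means 'no continuation'."""
    return TOY_LM.get(prefix[-1], {})


def decode(prefix: List[str], choose: Callable[[Dict[str, float]], str],
           max_new_tokens: int) -> List[str]:
    """Generic autoregressive loop: the decoding strategy is the `choose` callback."""
    tokens = list(prefix)
    for _ in range(max_new_tokens):
        dist = next_token_dist(tokens)
        if not dist:  # nothing left to predict: stop, as if EOS were produced
            break
        tokens.append(choose(dist))
    return tokens


def greedy_choice(dist: Dict[str, float]) -> str:
    """Greedy strategy: take the most probable token."""
    return max(dist, key=dist.get)


print(decode(["The"], greedy_choice, max_new_tokens=2))  # ['The', 'nice', 'woman']
```

If you replace `greedy_choice` with a function that draws from `dist`, the same loop performs sampling. Beam search differs because it keeps several prefixes alive at once.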

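```python
# https://huggingface.co/docs/transformers/main_classes/text_generation
# Note: the cells below call the low-level decoding methods (greedy_search,
# beam_search, sample, beam_sample) directly. Newer transformers releases
# deprecate and then remove these public methods in favour of model.generate(...),
# so run this notebook with an older release that still provides them.
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    StoppingCriteriaList,
    MaxLengthCriteria,
)
```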

```python
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# GPT-2 has no PAD token, so reuse the EOS token for padding
model.generation_config.pad_token_id = model.generation_config.eos_token_id
```

### Greedy Search


<div>
<img src="./Figure/DecoderDemo.png" width=512>
</div>

Greedy search picks the most probable token at every step:
$$ w_N = \mathop{\mathrm{argmax}}_{w} P(w \mid w_1, \cdots, w_{N-1}) $$

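Drawbacks:
* It may miss the sequence with the highest overall probability. In the figure above, "The dog has" has the largest total probability, but greedy search only ever takes the locally most probable token at each step.
* Because there is no randomness, once the model outputs a repeated token it can get stuck in a loop that repeats the same sequence over and over.
* Greedy decoding closely mirrors the training objective (maximizing the likelihood of the next token), so it tends to reproduce training data and lacks creativity.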
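You can check the first drawback on a two-step example. The probabilities are hypothetical and were chosen so that, as in the figure, "The dog has" is the most probable sequence. `greedy` applies the argmax rule above one step at a time. `best_sequence` scores every path exhaustively, which is only feasible for toy vocabularies.

```python
from typing import Dict, List, Tuple

# Hypothetical next-token probabilities (illustrative, not taken from a real model).
TOY_LM: Dict[str, Dict[str, float]] = {
    "The": {"nice": 0.5, "dog": 0.4, "car": 0.1},
    "nice": {"woman": 0.4, "house": 0.3, "guy": 0.3},
    "dog": {"has": 0.9, "runs": 0.05, "and": 0.05},
    "car": {"is": 0.3, "drives": 0.5, "turns": 0.2},
}


def greedy(start: str, steps: int) -> Tuple[List[str], float]:
    """Greedy search: append argmax_w P(w | prefix) at each step."""
    seq, prob = [start], 1.0
    for _ in range(steps):
        dist = TOY_LM[seq[-1]]
        tok = max(dist, key=dist.get)
        seq.append(tok)
        prob *= dist[tok]
    return seq, prob


def best_sequence(start: str, steps: int) -> Tuple[List[str], float]:
    """Enumerate every path and return the one with the largest probability product."""
    paths = [([start], 1.0)]
    for _ in range(steps):
        paths = [
            (seq + [tok], prob * q)
            for seq, prob in paths
            for tok, q in TOY_LM[seq[-1]].items()
        ]
    return max(paths, key=lambda path: path[1])


print(greedy("The", 2)[0])         # ['The', 'nice', 'woman'], probability 0.5 * 0.4 = 0.2
print(best_sequence("The", 2)[0])  # ['The', 'dog', 'has'],    probability 0.4 * 0.9 = 0.36
```

Greedy search commits to "nice" at the first step (0.5 > 0.4), so it never reaches the higher-probability continuation "dog has".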

#### Greedy Search: Code Execution Flow

<div>
<img src="./Figure/GreedySearch.png" width=1280>
</div>

#### Greedy Search: Quick-Start Example

```python
in_prompt = "This sunday we may"
in_ids = tokenizer(in_prompt, return_tensors="pt").input_ids

# Logits processors modify the next-token scores before the argmax:
# - forbid EOS until the sequence (prompt included) has at least 10 tokens
# - penalise tokens that already appear in the sequence
logit_processor = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
        RepetitionPenaltyLogitsProcessor(1.2),
    ]
)

# Stop once the sequence (prompt + generated tokens) reaches 20 tokens
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

# Implementation: $install_path/transformers/generation/utils.py
outputs = model.greedy_search(
    in_ids, logits_processor=logit_processor, stopping_criteria=stopping_criteria
)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
```
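The sketch below shows, in plain Python, what the two processors in the cell above do to the scores before the argmax. It follows the rules implemented in transformers. `MinLengthLogitsProcessor` sets the EOS score to $-\infty$ while the sequence (prompt included) is shorter than `min_length`. `RepetitionPenaltyLogitsProcessor` divides the positive scores of already-present tokens by the penalty and multiplies their negative scores by it. The helper names and list-based signatures are hypothetical; the real processors operate on batched `torch` tensors.

```python
import math
from typing import List


def min_length_process(input_ids: List[int], scores: List[float],
                       min_length: int, eos_token_id: int) -> List[float]:
    """Sketch of MinLengthLogitsProcessor: ban EOS until len(input_ids) >= min_length."""
    scores = list(scores)
    if len(input_ids) < min_length:
        scores[eos_token_id] = -math.inf
    return scores


def repetition_penalty_process(input_ids: List[int], scores: List[float],
                               penalty: float) -> List[float]:
    """Sketch of RepetitionPenaltyLogitsProcessor: make already-seen tokens less likely."""
    scores = list(scores)
    for tok in set(input_ids):
        s = scores[tok]
        scores[tok] = s * penalty if s < 0 else s / penalty
    return scores


def greedy_next_token(input_ids: List[int], scores: List[float], min_length: int,
                      eos_token_id: int, penalty: float) -> int:
    """Apply the processors in list order, then take the argmax (one greedy step)."""
    scores = min_length_process(input_ids, scores, min_length, eos_token_id)
    scores = repetition_penalty_process(input_ids, scores, penalty)
    return max(range(len(scores)), key=scores.__getitem__)


# Toy vocabulary of 4 token ids; id 3 plays the role of EOS.
history = [0, 1]
raw_scores = [2.0, 0.1, 1.8, 3.0]
print(greedy_next_token(history, raw_scores, min_length=10, eos_token_id=3, penalty=1.2))  # 2
```

With `penalty=1.0` the same call returns token 0, because nothing is penalised. Once `min_length` is reached, the EOS id 3 is no longer masked and wins the argmax.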

### Beam Search

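The core idea of beam search is to keep the $n$ best sequences found so far, extend each of them with its $n$ most probable next tokens, and then keep the $n$ sequences with the largest product of token probabilities among the resulting $n\times n$ candidates.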
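The procedure above can be sketched in pure Python on a hypothetical next-token table, `TOY_LM`, whose probabilities are invented. The sketch sums log-probabilities instead of multiplying probabilities. Both rank sequences the same way, but summing avoids underflow on long sequences. This toy ignores EOS handling and length penalties; real implementations such as `BeamSearchScorer` take care of those.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical next-token probabilities (illustrative, not from a real model).
TOY_LM: Dict[str, Dict[str, float]] = {
    "The": {"nice": 0.5, "dog": 0.4, "car": 0.1},
    "nice": {"woman": 0.4, "house": 0.3, "guy": 0.3},
    "dog": {"has": 0.9, "runs": 0.05, "and": 0.05},
    "car": {"is": 0.3, "drives": 0.5, "turns": 0.2},
}


def beam_search(start: str, num_beams: int, steps: int) -> List[Tuple[List[str], float]]:
    """Keep the num_beams best prefixes; expand each with its num_beams most likely tokens."""
    beams = [([start], 0.0)]  # (sequence, sum of log-probabilities)
    for _ in range(steps):
        candidates = []
        for seq, logp in beams:
            dist = TOY_LM.get(seq[-1], {})
            top = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)[:num_beams]
            for tok, p in top:
                candidates.append((seq + [tok], logp + math.log(p)))
        # From the (at most) n x n candidates, keep the n with the largest log-prob sum
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return beams


for seq, logp in beam_search("The", num_beams=2, steps=2):
    print(seq, round(math.exp(logp), 2))
# ['The', 'dog', 'has'] 0.36
# ['The', 'nice', 'woman'] 0.2
```

With `num_beams=1` the function reduces to greedy search and returns "The nice woman". Keeping a second beam is enough to recover "The dog has".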

#### Beam Search: Code Execution Flow

<div>
<img src="./Figure/GreedySearch.png" width=1280>
</div>

#### Beam Search: Quick-Start Example

```python
from transformers import (
    AutoModelForSeq2SeqLM,
    NoRepeatNGramLogitsProcessor,
    BeamSearchScorer,
)

# T5 is an encoder-decoder (seq2seq) model
tokenizer_beam = AutoTokenizer.from_pretrained("t5-base")
model_beam = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
```

```python
import torch

encoder_in_str = "translate English to German: That is good."
encoder_in_ids = tokenizer_beam(encoder_in_str, return_tensors="pt").input_ids

num_beams = 3
# Every beam starts decoding from the decoder start token
in_ids = torch.ones((num_beams, 1), device=model_beam.device, dtype=torch.long)
in_ids = in_ids * model_beam.config.decoder_start_token_id

# Run the encoder once, with the input repeated for each beam
model_kwargs = {
    "encoder_outputs": model_beam.get_encoder()(
        encoder_in_ids.repeat_interleave(num_beams, dim=0),
        return_dict=True,
    )
}

# Tracks the beam hypotheses; num_beam_hyps_to_keep=2 returns the 2 best sequences
beam_scorer = BeamSearchScorer(
    batch_size=1,
    num_beams=num_beams,
    num_beam_hyps_to_keep=2,
    device=model_beam.device,
)

# Forbid any 2-gram (bigram) from appearing twice in a sequence
logit_processor = LogitsProcessorList(
    [
        NoRepeatNGramLogitsProcessor(2),
    ]
)

outputs = model_beam.beam_search(
    in_ids, beam_scorer, logits_processor=logit_processor, **model_kwargs
)

result = tokenizer_beam.batch_decode(outputs, skip_special_tokens=True)
print(result)
```
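`NoRepeatNGramLogitsProcessor(2)` in the cell above prevents any bigram from appearing twice in a hypothesis. The rule works like this: find every earlier n-gram whose first $n-1$ tokens match the last $n-1$ generated tokens, then set the score of that n-gram's final token to $-\infty$. The sketch below implements the rule in plain Python with list-based inputs. The helper names are hypothetical; the real processor works on batched tensors.

```python
import math
from typing import List, Set


def banned_tokens(input_ids: List[int], ngram_size: int) -> Set[int]:
    """Tokens that would complete an n-gram already present in input_ids."""
    if len(input_ids) + 1 < ngram_size:
        return set()
    prefix = tuple(input_ids[len(input_ids) - ngram_size + 1:])  # last n-1 tokens
    banned = set()
    for i in range(len(input_ids) - ngram_size + 1):
        ngram = input_ids[i:i + ngram_size]
        if tuple(ngram[:-1]) == prefix:
            banned.add(ngram[-1])
    return banned


def no_repeat_ngram_process(input_ids: List[int], scores: List[float],
                            ngram_size: int) -> List[float]:
    """Set the score of every banned token to -inf."""
    scores = list(scores)
    for tok in banned_tokens(input_ids, ngram_size):
        scores[tok] = -math.inf
    return scores


# The sequence 1, 2, 1 already contains the bigram (1, 2), so 2 may not follow the final 1.
print(no_repeat_ngram_process([1, 2, 1], [0.1, 0.2, 5.0, 0.3], ngram_size=2))
# [0.1, 0.2, -inf, 0.3]
```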

### Sample

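Random sampling draws the $N$-th token according to its predicted probability distribution.
To keep the generated text fluent, a temperature $t\in (0, 1]$ is introduced into the softmax. It reshapes the distribution so that higher-probability tokens are favored. Lowering $t$ sharpens the distribution, which helps avoid sampling incoherent tokens from the long tail:
$$ P(w_N = w \mid w_1, \cdots, w_{N-1}) = \frac{e^{u_w / t}}{\sum_{w^{\prime}} e^{u_{w^{\prime}} / t}} $$
where $u_w$ is the logit (pre-softmax score) the model assigns to token $w$; $t = 1$ recovers the original distribution.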
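The formula can be evaluated directly. The sketch below uses plain Python and a hypothetical helper name. It divides the logits $u_w$ by $t$ before the softmax and subtracts the maximum for numerical stability, which does not change the result. With made-up logits, it shows the distribution getting sharper as $t$ decreases while the ranking of tokens stays the same.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], t: float) -> List[float]:
    """P(w) = exp(u_w / t) / sum_w' exp(u_w' / t)."""
    scaled = [u / t for u in logits]
    m = max(scaled)  # subtracting the max cancels out in the ratio
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]


logits = [2.0, 1.0, 0.1]  # made-up logits u_w for three tokens
for t in (1.0, 0.5, 0.1):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```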

#### Top-k sampling

<div>
<img src="./Figure/TopKSampling.png" width=1024>
</div>

Top-k sampling keeps only the $k$ most probable tokens, renormalizes their probabilities, and samples from the resulting distribution.
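Here is a minimal sketch of the top-k rule on a probability dictionary, using a hypothetical helper and invented probabilities. `TopKLogitsWarper` applies the same idea to logits: it sets all but the top $k$ scores to $-\infty$ before the softmax, which produces the same renormalized distribution.

```python
import random
from typing import Dict


def top_k_filter(probs: Dict[str, float], k: int) -> Dict[str, float]:
    """Keep the k most probable tokens and renormalise their probabilities."""
    top = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    z = sum(p for _, p in top)
    return {tok: p / z for tok, p in top}


probs = {"nice": 0.5, "dog": 0.4, "car": 0.1}
filtered = top_k_filter(probs, k=2)
print(filtered)  # "nice" -> 0.5 / 0.9, "dog" -> 0.4 / 0.9; "car" is dropped

# Sample the next token from the truncated distribution (seeded for reproducibility).
token = random.Random(0).choices(list(filtered), weights=list(filtered.values()), k=1)[0]
print(token)
```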

#### Top-p (Nucleus) sampling

<div>
<img src="Figure/TopPSampling.png" width=1028>
</div>

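Top-p sampling keeps the smallest set of most probable tokens whose cumulative probability is at least $p$, renormalizes the probabilities over that set, and samples from it.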
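The same kind of sketch works for top-p, again with a hypothetical helper and invented probabilities. Sort tokens by probability, keep adding them until the cumulative probability reaches $p$, then renormalize. `TopPLogitsWarper` implements this on logits and also always keeps at least `min_tokens_to_keep` tokens.

```python
import random
from typing import Dict


def top_p_filter(probs: Dict[str, float], p: float) -> Dict[str, float]:
    """Keep the smallest most-probable set with cumulative probability >= p, renormalised."""
    kept, cum = [], 0.0
    for tok, q in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept.append((tok, q))
        cum += q
        if cum >= p:
            break
    z = sum(q for _, q in kept)
    return {tok: q / z for tok, q in kept}


probs = {"a": 0.5, "b": 0.3, "c": 0.15, "d": 0.05}
filtered = top_p_filter(probs, p=0.9)  # 0.5 + 0.3 + 0.15 >= 0.9, so "d" is dropped
print(filtered)

token = random.Random(0).choices(list(filtered), weights=list(filtered.values()), k=1)[0]
print(token)
```

Unlike top-k, the number of surviving tokens adapts to the shape of the distribution. A confident prediction keeps few tokens, and a flat one keeps many.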

#### Sample: Code Execution Flow

<img src="./Figure/Sample.png" width=1280>

```python
from transformers import (
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Same as above: GPT-2 has no PAD token
model.config.pad_token_id = model.config.eos_token_id

in_prompt = "Today is a beautiful day, and"
in_ids = tokenizer(in_prompt, return_tensors="pt").input_ids

# Processors run first: forbid EOS until the sequence has at least 15 tokens
logit_processor = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
    ]
)

# Warpers then reshape the distribution before sampling: top-k (k=50), then top-p (p=0.9)
logit_warper = LogitsProcessorList(
    [
        TopKLogitsWarper(50),
        TopPLogitsWarper(0.9),
    ]
)

stopping_criteria = StoppingCriteriaList(
    [
        MaxLengthCriteria(max_length=32)
    ]
)

# Sampling is random, so the output below will differ from run to run
outputs = model.sample(
    in_ids,
    logits_processor=logit_processor,
    logits_warper=logit_warper,
    stopping_criteria=stopping_criteria,
)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
```

["Today is a beautiful day, and we're all so glad that our country is finally recognizing that the world isn't always a better place than it was five decades"]
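The plain-Python sketch below walks through one iteration of the `model.sample` cell above to show the order of operations. First the logits processors run (here a MinLength-style EOS mask). Then the warpers run (top-k, then top-p). Finally a softmax produces probabilities and a random draw picks the token. This follows the processor-then-warper order of the older `sample` loop in transformers. However, `sample_next_token` is a hypothetical helper that works on a single list of scores; it is not the library's implementation. Keeping only the top-k ids and taking a softmax over them is equivalent to first setting the other logits to $-\infty$.

```python
import math
import random
from typing import List


def sample_next_token(scores: List[float], cur_len: int, min_length: int,
                      eos_token_id: int, top_k: int, top_p: float,
                      rng: random.Random) -> int:
    """One hypothetical sampling step: processors, then warpers, then a random draw."""
    scores = list(scores)
    # 1) Logits processor (MinLength-style): forbid EOS while the sequence is too short.
    if cur_len < min_length:
        scores[eos_token_id] = -math.inf
    # 2) Top-k warper: keep the top_k highest-scoring token ids, best first.
    kept = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
    # 3) Softmax over the kept logits (max subtracted for numerical stability).
    m = scores[kept[0]]
    weights = [math.exp(scores[i] - m) for i in kept]
    z = sum(weights)
    probs = [w / z for w in weights]
    # 4) Top-p warper: keep the shortest prefix whose cumulative probability reaches top_p.
    cutoff, cum = len(kept), 0.0
    for j, q in enumerate(probs):
        cum += q
        if cum >= top_p:
            cutoff = j + 1
            break
    kept, probs = kept[:cutoff], probs[:cutoff]
    # 5) Draw one token; random.choices accepts unnormalised weights.
    return rng.choices(kept, weights=probs, k=1)[0]


# Made-up logits for a 5-token vocabulary in which id 4 plays the role of EOS.
scores = [1.0, 3.0, 2.5, 0.5, 4.0]
print(sample_next_token(scores, cur_len=3, min_length=15, eos_token_id=4,
                        top_k=3, top_p=0.9, rng=random.Random(0)))
# 1 or 2: EOS is masked, top-k keeps ids 1, 2, 0, and top-p then drops id 0
```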


### Beam Sample

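Beam Sample ranks sampled candidates and keeps the current best $n$ sequences. At each step, candidate next tokens are sampled from the (warped) distribution instead of being taken as the most probable ones, and only the $n$ highest-scoring sequences survive. The aim is to keep the diversity and creativity of sampling while avoiding incoherent output.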
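Below is a toy sketch of the idea on a hypothetical next-token table, `TOY_LM`, with invented probabilities. Each beam samples candidate next tokens from its temperature-adjusted distribution instead of taking the top $n$. All candidates are then ranked by their summed log-probabilities. This is a simplification. The actual `beam_sample` in transformers samples from the combined distribution of all beams, and it handles EOS and length penalties through `BeamSearchScorer`.

```python
import math
import random
from typing import Dict, List, Tuple

# Hypothetical next-token probabilities (invented for illustration).
TOY_LM: Dict[str, Dict[str, float]] = {
    "The": {"nice": 0.5, "dog": 0.4, "car": 0.1},
    "nice": {"woman": 0.4, "house": 0.3, "guy": 0.3},
    "dog": {"has": 0.9, "runs": 0.05, "and": 0.05},
    "car": {"is": 0.3, "drives": 0.5, "turns": 0.2},
}


def sample_without_replacement(items: List[str], weights: List[float], k: int,
                               rng: random.Random) -> List[str]:
    """Draw up to k distinct items, each draw proportional to the remaining weights."""
    items, weights = list(items), list(weights)
    picked = []
    for _ in range(min(k, len(items))):
        i = rng.choices(range(len(items)), weights=weights, k=1)[0]
        picked.append(items.pop(i))
        weights.pop(i)
    return picked


def beam_sample(start: str, num_beams: int, steps: int, rng: random.Random,
                temperature: float = 1.0) -> List[Tuple[List[str], float]]:
    """Toy beam sample: sampled candidates per beam, ranked by summed log-probability."""
    beams = [([start], 0.0)]  # (sequence, sum of log-probabilities)
    for _ in range(steps):
        candidates = []
        for seq, logp in beams:
            dist = TOY_LM.get(seq[-1], {})
            # Temperature-adjusted distribution: p_t(w) is proportional to p(w) ** (1 / t)
            adjusted = {w: q ** (1.0 / temperature) for w, q in dist.items()}
            z = sum(adjusted.values())
            adjusted = {w: q / z for w, q in adjusted.items()}
            # Sample candidates instead of taking the top-n (the key difference from beam search)
            for w in sample_without_replacement(list(adjusted), list(adjusted.values()),
                                                num_beams, rng):
                candidates.append((seq + [w], logp + math.log(adjusted[w])))
        # Rank all candidates and keep the current best num_beams sequences
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return beams


for seq, score in beam_sample("The", num_beams=2, steps=2, rng=random.Random(0), temperature=0.8):
    print(seq, round(math.exp(score), 3))
```

The ranking step is what separates this from plain sampling: an unlikely sampled continuation can still be discarded if the other beams found better ones.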

#### Beam Sample: Code Execution Flow

<img src="./Figure/BeamSample.png" width=1280>

```python
from transformers import (
    TemperatureLogitsWarper,
)

encoder_in_str = "translate English into German: How old are you?"
encoder_in_ids = tokenizer_beam(encoder_in_str, return_tensors="pt").input_ids

num_beams = 3
# Every beam starts decoding from the decoder start token
in_ids = torch.ones((num_beams, 1), device=model_beam.device, dtype=torch.long)
in_ids = in_ids * model_beam.config.decoder_start_token_id

# Run the encoder once, with the input repeated for each beam
model_kwargs = {
    "encoder_outputs": model_beam.get_encoder()(
        encoder_in_ids.repeat_interleave(num_beams, dim=0),
        return_dict=True,
    )
}

beam_scorer = BeamSearchScorer(
    batch_size=1,
    max_length=model_beam.config.max_length,
    num_beams=num_beams,
    device=model_beam.device,
)

# Warpers: temperature 0.8 sharpens the distribution, then top-k keeps the 50 most likely tokens
logit_warper = LogitsProcessorList(
    [
        TemperatureLogitsWarper(0.8),
        TopKLogitsWarper(50),
    ]
)

stopping_criteria = StoppingCriteriaList(
    [
        MaxLengthCriteria(max_length=32)
    ]
)

# Candidates are sampled, so the output can differ between runs
outputs = model_beam.beam_sample(
    in_ids,
    beam_scorer,
    logits_warper=logit_warper,
    stopping_criteria=stopping_criteria,
    **model_kwargs
)

result = tokenizer_beam.batch_decode(outputs, skip_special_tokens=True)
print(result)
```

['Wie alt bist du?']
