# 模型初始化

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, Qwen2ForCausalLM
device = "cuda"  # the device to load the model onto

# Raw string: the Windows path contains "\l", "\p" and "\Q", which are invalid escape
# sequences in a normal literal (SyntaxWarning on Python 3.12+), and any folder starting
# with a valid escape such as "\n" or "\t" would silently corrupt the path.
model_path = r'D:\learning\python\pretrain_checkpoint\Qwen2.5-1.5B-Instruct'
# Load the weights in bfloat16 onto `device`, the same device the prompt tensors are moved to later.
model: Qwen2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_path)



In [2]:
# Build the chat prompt (system + user turn), render it with the model's chat template
# followed by the assistant generation prefix, then tokenize it onto `device`.
messages = [
    {"role": "system", "content": "你是一个人工智能助手"},
    {"role": "user", "content": '5个字夸一下微博'},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer(text, return_tensors="pt").to(device)

# SAMPLE 随机采样

In [5]:
# Generate with multinomial sampling restricted to the top-5 tokens, keeping the per-step
# scores and logits so the next cell can inspect what each token was drawn from.
generated_ids = model.generate(
    **model_inputs,
    do_sample=True,
    top_k=5,
    top_p=1,
    max_new_tokens=64,
    num_return_sequences=1,
    return_dict_in_generate=True,
    output_scores=True,
    output_logits=True,
)
# Decode only the newly generated tokens (everything after the prompt).
prompt_len = model_inputs.input_ids.shape[1]
print(tokenizer.decode(generated_ids['sequences'][0][prompt_len:]))

信息交流快<|im_end|>


In [6]:
# Per-step view of the sampling run: for every generated position, list the candidate tokens
# whose processed score is not -inf, with the probability each one was sampled with.
print(tokenizer.decode(generated_ids['sequences'][0][model_inputs.input_ids.shape[1]:]))
print('-'*100)
for step_scores in generated_ids['scores']:
    row_scores = step_scores[0]  # first (and only) sequence in the batch
    # Tokens filtered out by the logits processors carry a score of -inf.
    kept_token_ids = (row_scores != float('-inf')).nonzero().view(-1)
    # Softmax over the surviving tokens only: computed once per step instead of once per token.
    probs = torch.softmax(row_scores[kept_token_ids], dim=-1)
    for token_id, prob in zip(kept_token_ids, probs):
        print(token_id.item(), tokenizer.decode(token_id), round(prob.item(), 2))
    print('-'*100)

信息交流快<|im_end|>
----------------------------------------------------------------------------------------------------
27369 信息 0.35
43815 内容 0.14
50007 更新 0.06
93149 分享 0.1
100848 精彩 0.35
----------------------------------------------------------------------------------------------------
26288 大 0.06
50007 更新 0.06
93149 分享 0.08
100667 及时 0.11
101069 交流 0.22
104793 传递 0.46
----------------------------------------------------------------------------------------------------
80942 广 0.02
99234 快 0.69
100133 平台 0.16
105066 无限 0.03
105499 便捷 0.1
----------------------------------------------------------------------------------------------------
1773 。 0.12
3837 ， 0.0
6313 ！ 0.0
29524 如 0.01
151645 <|im_end|> 0.87
----------------------------------------------------------------------------------------------------


In [None]:
def _sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generate token ids with **multinomial sampling** (or greedy argmax when `generation_config.do_sample`
    is False) for models with a language-modeling head: text decoders, text-to-text, speech-to-text and
    vision-to-text models.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            Called as `logits_processor(input_ids, next_token_logits)` at every step to turn the
            last-position logits into the scores used for token selection.
        stopping_criteria (`StoppingCriteriaList`):
            Called as `stopping_criteria(input_ids, scores)` after every step; sequences it flags are marked
            finished. If any criterion has an `eos_token_id` attribute, finished rows receive the pad token
            from then on.
        generation_config ([`~generation.GenerationConfig`]):
            Supplies the pad token, the output flags, `max_length` and `do_sample`.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object; each step's new tokens are passed through `streamer.put(token_ids)` and
            `streamer.end()` is called after the loop.
        model_kwargs:
            Additional model-specific kwargs forwarded to the model's `forward` function. If the model is an
            encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
        A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.
    """
    # Pull the settings used throughout the loop out of the generation config.
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    max_length = generation_config.max_length
    # Whether any stopping criterion is EOS-based; if so, finished rows are padded below.
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample

    # Per-step accumulators: empty tuples when they will be returned, otherwise None.
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # Encoder-decoder models: take the encoder attentions / hidden states from `encoder_outputs`.
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    batch_size, cur_len = input_ids.shape
    # Becomes True once every sequence in this process's batch has finished.
    this_peer_finished = False
    # 1 for sequences still generating, 0 for finished ones.
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    # Set up the initial cache position for the prefill step (presumably taken from the cache when one
    # is present, otherwise from the length of `input_ids` — see `_get_initial_cache_position`).
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    while self._has_unfinished_sequences(
        this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
    ):
        # Build this step's model inputs.
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # Request attentions / hidden states only when asked for (not every model accepts these flags).
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        # Forward pass to get the next-token logits.
        outputs = self(**model_inputs, return_dict=True)

        # synced_gpus: don't waste resources running code we don't need; kwargs must be updated before skipping.
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            continue

        # Keep only the last position's logits. Slice BEFORE cloning: the clone drops the reference to
        # `outputs.logits` (which covers the whole prompt on the prefill step) while copying just the
        # (batch, vocab) slice. Cloning first would duplicate the entire logits tensor.
        next_token_logits = outputs.logits[:, -1, :].clone().float()
        next_token_logits = next_token_logits.to(input_ids.device)

        # Let the configured logits processors turn the logits into selection scores.
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Store the requested per-step outputs.
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        if do_sample:
            # Sampling: draw one token per row from softmax(scores). `torch.nn.functional` is spelled out
            # because `nn` is not imported anywhere in this notebook.
            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
            # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            # Greedy: take the highest-scoring token.
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # With an EOS-based stopping criterion, rows that already finished receive the pad token.
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Append the new tokens and forward them to the streamer, if any.
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        # Update the per-sequence finished flags and the overall stop flag.
        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # Drop `outputs` so its logits (large on the first iteration) are not kept alive into the next one.
        del outputs

    if streamer is not None:
        streamer.end()

    # Package the result as a structured output or as the bare token ids.
    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids

# BEAM_SEARCH 束搜索

In [None]:
# Deterministic beam search with 2 beams. `top_k` / `top_p` are sampling-only settings: with
# `do_sample=False` they have no effect (and non-default values make GenerationConfig emit
# validation warnings), so they are not passed here.
generated_ids = model.generate(
    max_new_tokens=64,
    num_return_sequences=1,
    output_scores=True,
    output_logits=True,
    return_dict_in_generate=True,
    num_beams=2,
    do_sample=False,
    **model_inputs,
)
print(tokenizer.decode(generated_ids['sequences'][0][model_inputs.input_ids.shape[1]:]))

From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


信息传递快<|im_end|>


In [None]:
# Per-step view of the beam search run: for every step, print the top-5 entries of each row of
# the stored score tensor. Assumes one row per beam (batch_size * num_beams rows), which is what
# the original hard-coded indices [0] and [1] relied on with num_beams=2.
print(tokenizer.decode(generated_ids['sequences'][0][model_inputs.input_ids.shape[1]:]))
print('-'*100)
top_n = 5
for step_scores in generated_ids['scores']:
    for beam_idx, beam_scores in enumerate(step_scores):
        top = torch.topk(beam_scores, k=top_n)
        print(f"score{beam_idx + 1}: ")
        for token_id, value in zip(top.indices.tolist(), top.values.tolist()):
            print(tokenizer.decode(token_id), token_id, round(value, 2))
    print('-'*100)


信息传递快<|im_end|>
----------------------------------------------------------------------------------------------------
score1: 
精彩 100848 -1.76
信息 27369 -1.89
内容 43815 -2.51
分享 93149 -2.64
更新 50007 -3.14
score2: 
精彩 100848 -1.76
信息 27369 -1.89
内容 43815 -2.51
分享 93149 -2.64
更新 50007 -3.14
----------------------------------------------------------------------------------------------------
score1: 
无限 105066 -0.65
每一天 114169 -3.15
分享 93149 -3.15
随时 102422 -3.4
纷 100100 -3.4
score2: 
传递 104793 -1.68
交流 101069 -2.43
及时 100667 -2.93
大 26288 -3.18
更新 50007 -3.31
----------------------------------------------------------------------------------------------------
score1: 
微博 101216 -2.56
聊 100281 -2.7
界 97120 -2.76
时 13343 -3.01
多 42140 -3.2
score2: 
快 99234 -0.04
迅速 104015 -4.67
速 94299 -4.92
便捷 105499 -5.17
快速 101098 -5.54
----------------------------------------------------------------------------------------------------
score1: 
<|im_end|> 151645 -0.06
如 29524 -4.11
。 1773 -4.36
， 3837 -4.93


In [None]:
# BeamScorer abstract base class
class BeamScorer(ABC):
    """
    Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
    [`~PreTrainedModel.beam_sample`].

    Subclasses must implement `process` (one beam-selection step over the candidate tokens) and
    `finalize` (building the final sequences once generation stops). `BeamSearchScorer` in this
    notebook is one such implementation.
    """

    # The method docstring is injected by `add_start_docstrings`; PROCESS_INPUTS_DOCSTRING is
    # presumably a module-level constant from transformers (not defined in this notebook).
    @abstractmethod
    @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        # Abstract: concrete scorers override this.
        # NOTE(review): BeamSearchScorer.process returns a UserDict, not a Tuple — the annotation is loose.
        raise NotImplementedError("This is an abstract method.")

    # Same docstring-injection pattern, using FINALIZE_INPUTS_DOCSTRING.
    @abstractmethod
    @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
    def finalize(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        max_length: int,
        **kwargs,
    ) -> torch.LongTensor:
        # Abstract: concrete scorers override this.
        raise NotImplementedError("This is an abstract method.")

## BeamSearchScorer

In [None]:
class BeamSearchScorer(BeamScorer):
    r"""
    [`BeamScorer`] implementing standard beam search decoding.

    Adapted in part from [Facebook's XLM beam search code]
    (https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).

    Reference for the diverse beam search algorithm and implementation: [Ashwin Kalyan's DBS implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)

    Args:
        batch_size (`int`):
            Batch size of `input_ids` for which standard beam search decoding is run.
        num_beams (`int`):
            Number of beams for beam search; must be an integer strictly greater than 1.
        device (`torch.device`):
            Device for the internal `_done` flags and for the `best_scores` tensor built in `finalize`
            (*e.g.*, `"cpu"` or `"cuda"`).
        length_penalty (`float`, *optional*, defaults to 1.0):
            Exponential penalty on sequence length: a finished hypothesis' score is its summed
            log-probability divided by `length ** length_penalty`. Since log-probabilities are negative,
            `length_penalty` > 0.0 favors longer sequences and `length_penalty` < 0.0 favors shorter ones.
        do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
            Stopping condition passed to each [`BeamHypotheses`]:
            `True` — a group is done as soon as it holds `num_beams // num_beam_groups` finished hypotheses;
            `False` — heuristic: done when the best running score, normalized at the current length, cannot beat
            the worst kept hypothesis;
            `"never"` — done only when no better candidate is possible (canonical beam search); requires `max_length`.
        num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
            Number of hypotheses returned per batch item by [`~transformers.BeamSearchScorer.finalize`].
        num_beam_groups (`int`, *optional*, defaults to 1):
            Number of groups `num_beams` is split into to ensure diversity among groups of beams.
            Each group has `num_beams // num_beam_groups` beams.
        max_length (`int`, *optional*):
            Maximum length of the sequence to be generated; forwarded to each [`BeamHypotheses`].

    Raises:
        ValueError: if `num_beams` is not an int > 1, or `num_beam_groups` is not a positive int that
            divides `num_beams`.
    """

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[Union[bool, str]] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        max_length: Optional[int] = None,
    ):
        # Validate before use: both values size the per-group containers below
        # (`num_beams // num_beam_groups`). Validating afterwards let a bad `num_beam_groups`
        # (0, None) escape as ZeroDivisionError/TypeError instead of the intended ValueError.
        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )

        # `num_beam_groups < 1` is tested before the modulo so that 0 cannot divide by zero and
        # negative values (which satisfy `num_beams % num_beam_groups == 0`) are rejected.
        if (
            not isinstance(num_beam_groups, int)
            or num_beam_groups < 1
            or (num_beam_groups > num_beams)
            or (num_beams % num_beam_groups != 0)
        ):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups  # beams per group (group beam search)

        self._is_init = False  # NOTE(review): never read within this class
        # self._beam_hyps[i*self.num_beam_groups+j] holds the finished hypotheses of group j of batch item i.
        # Without group beam search this is simply one BeamHypotheses per batch item.
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.group_size,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
                max_length=max_length,
            )
            for _ in range(batch_size * self.num_beam_groups)
        ]
        # self._done[i*self.num_beam_groups+j] tells whether group j of batch item i has finished.
        self._done = torch.tensor(
            [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
        )

    @property
    def is_done(self) -> bool:
        # True once every (batch item, group) has finished.
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
        eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
        group_index: Optional[int] = 0,
        decoder_prompt_len: Optional[int] = 0,
    ) -> Dict[str, torch.Tensor]:
        """
        Run one beam-selection step for beam group `group_index`.

        For each batch item, the candidates `(next_tokens, next_scores, next_indices)[batch_idx]` are walked
        in order. EOS candidates ranked within the top `group_size` are stored as finished hypotheses, and the
        remaining candidates fill the `group_size` running beams of the next step. Batch items that are already
        done get `pad_token_id` with score 0 and index 0.

        Args:
            input_ids: running sequences; must have `batch_size * group_size` rows.
            next_scores, next_tokens, next_indices: per batch item, candidate scores, token ids and source-beam
                indices (within the group). Presumably each row is sorted best-first — TODO confirm with caller.
            pad_token_id, eos_token_id: required once a batch item is done.
            beam_indices: optional per-row index tuples; extended with the source row when a hypothesis is stored.
            group_index: which beam group this call updates.
            decoder_prompt_len: prompt length excluded from the length used for length normalization.

        Returns:
            `UserDict` with flattened `next_beam_scores`, `next_beam_tokens` and `next_beam_indices`.

        Raises:
            ValueError: on a beam-count mismatch with `input_ids`, when a done item lacks pad/eos ids, or when
                too many candidates are EOS to fill the running beams.
        """
        # Length that next_scores correspond to: the current input (prompt included) plus the token being chosen.
        cur_len = input_ids.shape[-1] + 1
        # One BeamHypotheses per (batch item, group), so this recovers the batch size.
        batch_size = len(self._beam_hyps) // self.num_beam_groups

        # `input_ids` must hold exactly `group_size` beams per batch item.
        if not (batch_size == (input_ids.shape[0] // self.group_size)):
            if self.num_beam_groups > 1:
                raise ValueError(
                    f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
                    f"size of {self.group_size} is expected by the beam scorer."
                )
            else:
                raise ValueError(
                    f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                    f"{self.group_size} is expected by the beam scorer."
                )

        # Buffers for the next step's running beams: scores, tokens and source rows.
        device = input_ids.device
        next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

        # Normalize eos_token_id to a tensor so `token in eos_token_id` works for int / list inputs.
        if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id)

        for batch_idx in range(batch_size):
            # Index of this (batch item, group) in `_beam_hyps` / `_done`.
            batch_group_idx = batch_idx * self.num_beam_groups + group_index
            if self._done[batch_group_idx]:
                # Sanity check: an item can only be done once its pool is full. A pool holds at most
                # `group_size` hypotheses, so the previous `num_beams < len(...)` test could never fire.
                if len(self._beam_hyps[batch_group_idx]) < self.group_size:
                    raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
                if eos_token_id is None or pad_token_id is None:
                    raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
                # Finished item: emit padding so the batch keeps its shape.
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # Next free slot among this item's running beams.
            beam_idx = 0
            # Walk the candidates in rank order.
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                # Row of `input_ids` that this candidate extends.
                batch_beam_idx = batch_idx * self.group_size + next_index
                # EOS candidate: store it as a finished hypothesis instead of continuing it.
                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
                    # Only EOS candidates ranked within the top `group_size` are eligible.
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    # Extend the beam's index trail with its source row when beam indices are tracked.
                    if beam_indices is not None:
                        beam_index = beam_indices[batch_beam_idx]
                        beam_index = beam_index + (batch_beam_idx,)
                    else:
                        beam_index = None

                    # Store the sequence (without the EOS token itself) and its score in this item's pool.
                    self._beam_hyps[batch_group_idx].add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                        beam_indices=beam_index,
                        generated_len=cur_len - decoder_prompt_len,
                    )
                else:
                    # Non-EOS candidate: it becomes one of the running beams of the next step.
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # Stop once all running-beam slots are filled.
                if beam_idx == self.group_size:
                    break

            # Not enough non-EOS candidates to fill the running beams.
            if beam_idx < self.group_size:
                raise ValueError(
                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
                    f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
                )

            # Mark the item done when its pool reports that no better hypothesis can still appear.
            self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
            )

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
        eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
        decoder_prompt_len: Optional[int] = 0,
    ) -> Tuple[torch.LongTensor]:
        """
        Build the final output sequences.

        Running beams of groups that are not done are first added to their hypothesis pools. Then, per batch
        item, the `num_beam_hyps_to_keep` highest-scoring hypotheses across all of its groups are selected,
        written into a `(batch_size * num_beam_hyps_to_keep, sent_max_len)` tensor (pad-filled when lengths
        differ), and followed by the first EOS id when there is room.

        Returns:
            `UserDict` with `sequences`, `sequence_scores` and `beam_indices` (None unless tracked).

        Raises:
            ValueError: if hypothesis lengths differ and `pad_token_id` is None.
        """
        batch_size = len(self._beam_hyps) // self.num_beam_groups

        # Normalize eos_token_id to a tensor (its first entry is appended below).
        if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id)

        # Add every running beam of unfinished groups to its pool; the pool keeps only the best ones.
        for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_group_idx]:
                continue

            for index_per_group in range(self.group_size):
                batch_beam_idx = batch_group_idx * self.group_size + index_per_group
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                generated_len = final_tokens.shape[-1] - decoder_prompt_len
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

        # Buffers for the selected hypotheses.
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
        best = []
        best_indices = []
        best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

        # For each batch item, pool the hypotheses of all its groups and take the highest-scoring ones.
        for i in range(batch_size):
            beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
            candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
            # Ascending by score, so `pop()` returns the best remaining hypothesis.
            sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                best.append(best_hyp)

                best_indices.append(best_index)

                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # Output width: longest hypothesis plus one slot for EOS, capped at max_length.
        sent_lengths_max = sent_lengths.max().item() + 1
        sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

        if len(best_indices) > 0 and best_indices[0] is not None:
            indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
        else:
            indices = None

        # Shorter hypotheses need padding when lengths differ.
        if sent_lengths.min().item() != sent_lengths.max().item():
            if pad_token_id is None:
                raise ValueError("`pad_token_id` has to be defined")
            decoded.fill_(pad_token_id)

        if indices is not None:
            indices.fill_(-1)

        # Copy each hypothesis in, followed by the first EOS id when there is room for it.
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            decoded[i, : sent_lengths[i]] = hypo

            if indices is not None:
                indices[i, : len(best_idx)] = torch.tensor(best_idx)

            if sent_lengths[i] < sent_max_len:
                # Only the first eos_token_id is inserted.
                decoded[i, sent_lengths[i]] = eos_token_id[0]

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
                "beam_indices": indices,
            }
        )



## BeamHypotheses 束假设类

In [None]:
class BeamHypotheses:
    """
    Bounded pool of finished beam-search hypotheses.

    Holds at most `num_beams` entries `(score, hyp, beam_indices)`, where `score` is the summed
    log-probability divided by `length ** length_penalty`. `worst_score` (starting at 1e9) tracks the
    lowest kept score; it is the admission bar in `add` once the pool is full, and `is_done` compares it
    with the best score that could still be reached.
    """

    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
        """
        Create an empty pool.

        Args:
            num_beams: maximum number of hypotheses kept.
            length_penalty: exponent applied to the length when normalizing a hypothesis' score.
            early_stopping: `True`, `False`, or a non-bool mode such as `"never"` (see `is_done`).
            max_length: maximum sequence length; mandatory when `early_stopping` is not a bool.

        Raises:
            ValueError: if `early_stopping` is not a bool and `max_length` is None.
        """
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.max_length = max_length
        self.beams = []
        self.worst_score = 1e9

        # A non-bool mode may need `max_length` to bound the best achievable score in `is_done`.
        string_mode = not isinstance(self.early_stopping, bool)
        if string_mode and self.max_length is None:
            raise ValueError(
                "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
                " BeamScorer class instance at initialization time."
            )

    def __len__(self):
        """Return how many hypotheses are currently kept."""
        return len(self.beams)

    def add(
        self,
        hyp: torch.LongTensor,
        sum_logprobs: float,
        beam_indices: Optional[torch.LongTensor] = None,
        generated_len: Optional[int] = None,
    ):
        """
        Offer a finished hypothesis to the pool.

        The score is `sum_logprobs / length ** length_penalty`, where `length` is `generated_len` when given
        and otherwise the last dimension of `hyp`. The hypothesis is kept if the pool is not yet full or its
        score beats `worst_score`; if that overfills the pool, the lowest-scoring entry is evicted.
        """
        length = hyp.shape[-1] if generated_len is None else generated_len
        score = sum_logprobs / (length**self.length_penalty)

        # Pool already full and no improvement over the current worst: reject.
        if len(self) >= self.num_beams and not score > self.worst_score:
            return

        self.beams.append((score, hyp, beam_indices))
        if len(self) <= self.num_beams:
            self.worst_score = min(score, self.worst_score)
            return

        # Overfilled by one: drop the lowest (score, position) entry; the runner-up becomes the new worst.
        ranking = sorted((entry[0], position) for position, entry in enumerate(self.beams))
        del self.beams[ranking[0][1]]
        self.worst_score = ranking[1][0]

    def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
        """
        Decide whether no better hypothesis can still be produced.

        Returns False while fewer than `num_beams` hypotheses are kept. Otherwise, by `early_stopping`:
        `True` -> done; `False` -> done when `worst_score` >= `best_sum_logprobs` normalized by the current
        generated length `cur_len - decoder_prompt_len`; any other value (e.g. `"never"`) -> the same test,
        except that with `length_penalty > 0` the length used is `max_length - decoder_prompt_len`.

        Raises:
            ValueError: in a non-bool mode with `length_penalty > 0` when `max_length <= decoder_prompt_len`.
        """
        if len(self) < self.num_beams:
            return False
        if self.early_stopping is True:
            return True

        if self.early_stopping is False:
            # Heuristic: assume the best running beam would finish at the current length.
            denominator = cur_len - decoder_prompt_len
        elif self.length_penalty > 0.0:
            # Log-prob sums are negative, so the largest reachable length gives the highest score.
            if self.max_length <= decoder_prompt_len:
                raise ValueError("max_length is not larger than decoder prompt length")
            denominator = self.max_length - decoder_prompt_len
        else:
            # Non-positive penalty: the current length already gives the highest score.
            denominator = cur_len - decoder_prompt_len

        highest_attainable_score = best_sum_logprobs / denominator**self.length_penalty
        return self.worst_score >= highest_attainable_score

## _beam_search 束搜索函数

In [17]:
def _beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    **model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
    r"""
    Generate token-id sequences with **beam search decoding** for models with a language modeling head
    (text-decoder, text-to-text, speech-to-text and vision-to-text models).

    Written as a free function that takes the model as `self`.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
            The prompt sequences, already expanded to `num_beams` rows per batch entry.
        beam_scorer (`BeamScorer`):
            A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
            sorted during generation. See the [`BeamScorer`] documentation for more information.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]: a list of [`LogitsProcessor`]-derived instances used to
            modify the language modeling head's prediction scores at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]: a list of [`StoppingCriteria`]-derived instances used to
            tell whether the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration used to parametrize the decoding method.
        synced_gpus (`bool`):
            Whether to keep running the while loop until max_length (needed to avoid deadlocks with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        model_kwargs:
            Extra model-specific kwargs forwarded to the model's `forward` function. If the model is an
            encoder-decoder model, the kwargs should include `encoder_outputs`.

    Return:
        If `generation_config.return_dict_in_generate` is set: a [`~generation.GenerateBeamEncoderDecoderOutput`]
        when `self.config.is_encoder_decoder` is true, otherwise a [`~generation.GenerateBeamDecoderOnlyOutput`].
        Otherwise a `torch.LongTensor` containing the generated sequences.

    Raises:
        ValueError: if the first dimension of `input_ids` is not `batch_size * num_beams`.
    """
    # Unpack the generation settings used throughout the loop
    pad_token_id = generation_config._pad_token_tensor  # pad token id
    eos_token_id = generation_config._eos_token_tensor  # EOS token id(s); may be None (handled below)
    output_attentions = generation_config.output_attentions  # whether to return attentions
    output_hidden_states = generation_config.output_hidden_states  # whether to return hidden states
    output_scores = generation_config.output_scores  # whether to return the processed per-step scores
    output_logits = generation_config.output_logits  # whether to return the raw per-step logits
    return_dict_in_generate = generation_config.return_dict_in_generate  # return an output object, not a bare tensor
    sequential = generation_config.low_memory  # low_memory: run the forward pass in sub-batches (see below)
    do_sample = generation_config.do_sample  # sample the candidates instead of taking the top-k

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    # batch_beam_size is the expanded batch dimension; it must equal batch_size * num_beams (checked below)
    batch_beam_size, cur_len = input_ids.shape
    # Prefill stage: initialise the cache position from the prompt
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    # input_ids must already be expanded to num_beams rows per batch entry
    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # Optional per-step accumulators; each stays None unless it will be returned
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    # One tuple of ancestor-row indices per row, extended at every step
    beam_indices = (
        tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    )
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # For encoder-decoder models, also fetch the encoder attentions and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # The first beam of every batch entry starts at score 0 and the others at -1e9, so only the first beam's
    # tokens are considered at the first step (all beams share the same prompt, so otherwise every beam would
    # pick the same token).
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False

    decoder_prompt_len = input_ids.shape[-1]  # record the prompt length (passed to the beam scorer)

    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # Prepare the model inputs (e.g. position_ids, KV cache, attention_mask)
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # Prepare variable output controls (note: some models don't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        # low_memory: split the inputs into sub-batches of size batch_size and run them one after another
        if sequential:
            # NOTE(review): the lowercased class name is searched for these substrings, so entries that contain
            # "_" (e.g. "gpt_bigcode", "transo_xl") only match class names that themselves contain an
            # underscore — confirm they match the intended models.
            if any(
                model_name in self.__class__.__name__.lower()
                for model_name in [
                    "fsmt",
                    "reformer",
                    "ctrl",
                    "gpt_bigcode",
                    "transo_xl",
                    "xlnet",
                    "cpm",
                    "jamba",
                ]
            ):
                raise RuntimeError(
                    f"Currently generation for {self.__class__.__name__} is not supported "
                    f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
                )

            inputs_per_sub_batches = _split_model_inputs(
                model_inputs,
                split_size=batch_size,
                full_batch_size=batch_beam_size,
                config=self.config.get_text_config(),
            )
            outputs_per_sub_batch = [
                self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
            ]

            outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config())
        # Default: a single forward pass over all batch_size * num_beams rows
        else:
            outputs = self(**model_inputs, return_dict=True)

        # Update model_kwargs for the next step from this step's outputs (e.g. attention_mask, KV cache)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        # synced_gpus: this peer has already finished but keeps stepping so the other peers are not blocked;
        # skip the remaining work (model_kwargs had to be updated first)
        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue

        # Logits of the last position only.
        # The clone avoids keeping a reference to outputs.logits, which can be very large on the first iteration
        # (the clone itself is always small); .float() keeps precision for the logits manipulations below.
        next_token_logits = outputs.logits[:, -1, :].clone().float()
        next_token_logits = next_token_logits.to(input_ids.device)
        # Log-probabilities over the vocabulary (all values <= 0)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        # Apply the logits processors to the log-probabilities
        next_token_scores_processed = logits_processor(input_ids, next_token_scores)
        # Add each row's running beam score, giving the cumulative score of every (beam, next token) extension
        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
            next_token_scores_processed
        )

        # Record the per-step outputs that will be returned. Note: `scores` stores the processed scores
        # *before* the beam scores are added.
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores_processed,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # Flatten beams and vocabulary into one axis: (batch_size, num_beams * vocab_size)
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # Candidate selection: keep max(2, 1 + n_eos_tokens) * num_beams candidates per batch entry, so at least
        # one non-EOS candidate per beam survives even if the EOS tokens rank highest.
        # n_eos_tokens is 0 when there is no EOS token.
        n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
        n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
        if do_sample:
            # Turn the cumulative log-scores into a probability distribution over all (beam, token) pairs
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            # Sample n_tokens_to_keep candidate (flat) indices from that distribution (multinomial sampling)
            next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
            # Look up the cumulative scores of the sampled candidates
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
            # Sort the candidates by score, highest first ...
            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            # ... and reorder the candidate indices the same way
            next_tokens = torch.gather(next_tokens, -1, _indices)
        else:
            # Deterministic: the n_tokens_to_keep highest-scoring candidates, sorted in descending order
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
            )

        # Source beam of each candidate (its beam index within the batch entry)
        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
        # Token id of each candidate within the vocabulary
        next_tokens = next_tokens % vocab_size

        # Let the beam scorer process the candidates; it returns the scores, tokens and source-row indices of
        # the beams kept for the next step
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=beam_indices,
            decoder_prompt_len=decoder_prompt_len,
        )
        # Scores, tokens and source rows of the beams kept for the next step
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        # Reorder the sequences to follow their selected source rows and append the chosen tokens
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        # Delete `outputs` so outputs.logits (possibly very large on the first iteration) is not kept alive into
        # the next iteration through a lingering reference.
        # IMPORTANT: this must come before the cache reordering below so the memory peak does not include
        # outputs.logits.
        del outputs

        # Reorder the KV cache (past_key_values) to follow the selected rows
        if model_kwargs.get("past_key_values", None) is not None:
            model_kwargs["past_key_values"] = self._temporary_reorder_cache(
                model_kwargs["past_key_values"], beam_idx
            )

        # Extend each row's beam ancestry (only tracked when scores are returned)
        if return_dict_in_generate and output_scores:
            beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

        # One more token has been generated
        cur_len = cur_len + 1

        # Stop when the beam scorer reports it is done, or when every stopping-criteria result is True
        if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
            this_peer_finished = True

    # Finalize: the beam scorer builds the final sequences (plus their scores and beam indices)
    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=beam_indices,
        decoder_prompt_len=decoder_prompt_len,
    )

    # return_dict_in_generate: wrap everything in the output class matching the model type
    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None

        if self.config.is_encoder_decoder:
            return GenerateBeamEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateBeamDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    # Otherwise return only the generated sequences
    else:
        return sequence_outputs["sequences"]

# GROUP_BEAM_SEARCH 分组束搜索(待做)

要求：`diversity_penalty` 不为 0，`do_sample` 必须为 `False`，且 `num_beams % num_beam_groups` 必须为 0。


In [None]:
def _group_beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    **model_kwargs,
):
    r"""
    Generates sequences of token ids for models with a language modeling head using **diverse (group) beam
    search decoding**. It can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
    Written as a free function that takes the model as `self`.

    The `num_beams` beams of every batch entry are split into `num_beam_groups` groups of
    `num_beams // num_beam_groups` beams. At every step the groups are decoded one after another. The logits
    processors receive `current_tokens` (the tokens already chosen by earlier groups in this step) and the
    group index. Presumably this is what lets a diversity penalty act across groups — confirm against the
    logits processor in use.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
            The sequence used as a prompt for the generation, already expanded to `num_beams` rows per batch
            entry.
        beam_scorer (`BeamScorer`):
            A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
            sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        model_kwargs:
            Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
            model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        If `generation_config.return_dict_in_generate` is set: a [`~generation.GenerateBeamEncoderDecoderOutput`]
        when `self.config.is_encoder_decoder` is true, otherwise a [`~generation.GenerateBeamDecoderOnlyOutput`].
        Otherwise a `torch.LongTensor` containing the generated sequences.

    Raises:
        ValueError: if the first dimension of `input_ids` is not `batch_size * num_beams`.
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    eos_token_id = generation_config._eos_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate

    num_beams = beam_scorer.num_beams
    num_beam_groups = beam_scorer.num_beam_groups
    num_sub_beams = num_beams // num_beam_groups  # beams per group
    # The scorer holds num_beam_groups hypothesis containers per batch entry
    batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
    device = input_ids.device

    batch_beam_size, cur_len = input_ids.shape
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    # Same ancestry bookkeeping as plain beam search, but kept separately for each beam group
    if return_dict_in_generate and output_scores:
        beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
    else:
        beam_indices = None

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # For encoder-decoder models, fetch the encoder attentions and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # The first beam of every group starts at score 0 and all other beams at -1e9, so the beams inside one group
    # do not all pick the same token at the first step.
    beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
    # `::num_sub_beams` selects the first beam of each group
    beam_scores[:, ::num_sub_beams] = 0
    # Flatten to (batch_size * num_beams,)
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False

    decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # predicted tokens in cur_len step
        current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

        # indices which will form the beams in the next time step
        reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

        # do one decoder step on all beams of all sentences in batch
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        outputs = self(**model_inputs, return_dict=True)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue

        if output_scores:
            processed_score = torch.zeros_like(outputs.logits[:, -1, :])
        if output_logits:
            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            raw_logit_score = outputs.logits[:, -1, :].clone()
            raw_logit_score = raw_logit_score.to(input_ids.device)

        for beam_group_idx in range(num_beam_groups):
            # Beams [group_start_idx, group_end_idx) of every batch entry belong to this group
            group_start_idx = beam_group_idx * num_sub_beams
            # min() is defensive only: with num_sub_beams = num_beams // num_beam_groups it never clips
            group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
            group_size = group_end_idx - group_start_idx

            # indices of beams of current group among all sentences in batch
            batch_group_indices = []

            for batch_idx in range(batch_size):
                batch_group_indices.extend(
                    [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                )
            group_input_ids = input_ids[batch_group_indices]

            # select outputs of beams of current group only
            # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
            # .float() is needed to retain precision for later logits manipulations
            next_token_logits = outputs.logits[batch_group_indices, -1, :].float()
            next_token_logits = next_token_logits.to(input_ids.device)

            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * group_size, vocab_size)
            vocab_size = next_token_scores.shape[-1]

            next_token_scores_processed = logits_processor(
                group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
            )
            next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
            next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

            if output_scores:
                processed_score[batch_group_indices] = next_token_scores_processed

            # reshape for beam search
            next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

            # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
            n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True
            )

            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            # stateless
            process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
            beam_outputs = beam_scorer.process(
                group_input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=process_beam_indices,
                group_index=beam_group_idx,
                decoder_prompt_len=decoder_prompt_len,
            )
            beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            if return_dict_in_generate and output_scores:
                beam_indices[beam_group_idx] = tuple(
                    beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
                )

            input_ids[batch_group_indices] = group_input_ids[beam_idx]
            group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            current_tokens[batch_group_indices] = group_input_ids[:, -1]

            # (beam_idx // group_size) -> batch_idx
            # (beam_idx % group_size) -> offset of idx inside the group
            reordering_indices[batch_group_indices] = (
                num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
                + group_start_idx
                + (beam_idx % group_size)
            )

        # Everything below runs once per decoding step, after ALL groups have been processed.

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (processed_score,)
            if output_logits:
                raw_logits += (raw_logit_score,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
        # (that way the memory peak does not include outputs.logits)
        del outputs

        if model_kwargs.get("past_key_values", None) is not None:
            model_kwargs["past_key_values"] = self._temporary_reorder_cache(
                model_kwargs["past_key_values"], reordering_indices
            )

        # increase cur_len
        cur_len = cur_len + 1

        if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
            this_peer_finished = True

    # Generation loop finished: build the final sequences from the stored hypotheses
    final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=final_beam_indices,
        decoder_prompt_len=decoder_prompt_len,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None

        if self.config.is_encoder_decoder:
            return GenerateBeamEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateBeamDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return sequence_outputs["sequences"]