# 采样

采样时需按以下顺序依次处理 logits：temperature -> top-k -> top-p

In [7]:
import torch

def sample(logits, greedy=False, temperature=1.0, top_k=0, top_p=0.0):
    """Pick one next-token id per row from unnormalized logits.

    Filters are applied in the order temperature -> top-k -> top-p; the
    surviving logits are renormalized with softmax and one token is drawn.

    Args:
        logits: Float tensor of shape [batch_size, vocab_size] (a single
            decoding step). A 1-D [vocab_size] tensor is also accepted.
        greedy: If True, return the argmax instead of sampling.
        temperature: Logits are divided by this value when it is positive.
            0 selects greedy decoding; negative values leave the logits
            unscaled.
        top_k: When > 0, keep only the ``top_k`` largest logits per row
            (entries tied with the k-th value are kept as well). Values
            larger than the vocabulary size are clamped. 0 disables it.
        top_p: Nucleus threshold, applied only when ``0 < top_p < 1``.

    Returns:
        Integer tensor of token ids with shape [batch_size, 1]
        ([1] for 1-D input).
    """
    if greedy or temperature == 0:  # greedy decoding
        return torch.argmax(logits, dim=-1).unsqueeze(-1)

    if temperature > 0:
        logits = logits / temperature

    if top_k > 0:
        # torch.topk raises if k exceeds the last dimension, so clamp it.
        k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, k)  # [..., k]
        min_values = values[..., -1].unsqueeze(-1)  # [..., 1]
        # Mask in place of gathering so the vocab dimension (and thus the
        # token indices seen by multinomial) stays unchanged.
        logits = logits.masked_fill(logits < min_values, -float("inf"))

    if 0 < top_p < 1:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        # Shift the mask right by one: a token is dropped only when the
        # cumulative probability *before* it already exceeds top_p, so the
        # token that crosses the threshold is kept and the top token always
        # survives.
        mask = cumprobs > top_p
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False

        sorted_logits[mask] = -float("inf")
        # Scatter the filtered values back to their original vocab positions.
        logits = torch.full_like(logits, -float("inf")).scatter(
            -1, sorted_indices, sorted_logits
        )

    probs = torch.softmax(logits, dim=-1)
    # Draw one token id per row according to the filtered distribution.
    return torch.multinomial(probs, num_samples=1)