In [None]:
import torch

@torch.no_grad()
def greedy_search(model, max_length, bos_token_id, eos_token_id, num_beams=None, batch_size=1):
    """
    Greedy (argmax) autoregressive decoding for a decoder-only model.

    Args:
        model: callable mapping token ids (B, t) to logits (B, t, V), either
            directly or via an output object exposing a ``.logits`` attribute.
        max_length: maximum number of new tokens to generate.
        bos_token_id: token every sequence starts with.
        eos_token_id: end-of-sequence token, or None to always generate
            ``max_length`` tokens. Rows that already emitted EOS are padded
            with EOS while the other rows keep decoding.
        num_beams: unused; kept only so existing call sites keep working.
        batch_size: number of sequences to decode. This used to be read from a
            notebook-global ``batch_size`` (hidden state); it is now explicit.

    Returns:
        LongTensor of shape (B, 1 + n) with the leading BOS, n <= max_length.
    """
    # Put inputs on the model's device when it has parameters (nn.Module);
    # fall back to CPU for plain callables.
    first_param = next(model.parameters(), None) if hasattr(model, "parameters") else None
    device = first_param.device if first_param is not None else torch.device("cpu")

    input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for _ in range(max_length):
        out = model(input_ids)
        logits = out.logits if hasattr(out, "logits") else out   # (B, t, V)
        next_token = logits[:, -1, :].argmax(dim=-1)             # (B,)
        if eos_token_id is not None:
            # Rows that already finished keep emitting EOS as padding.
            next_token = next_token.masked_fill(finished, eos_token_id)
        # (B, t) + (B, 1) -> (B, t+1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)

        if eos_token_id is not None:
            finished |= next_token == eos_token_id
            if finished.all():
                break

    return input_ids
            

In [None]:
@torch.no_grad()
def beam_search(model, batch_size, max_new_tokens, bos_token_id, eos_token_id, beam_size):
    """
    Autoregressive beam search (decoder-only).
    Shapes in comments use: B=batch_size, K=beam_size, t=seq_len, V=vocab_size.

    Args:
        model: callable mapping token ids (N, t) to logits (N, t, V), either
            directly or via an output object exposing a ``.logits`` attribute.
        batch_size: number of independent sequences B to decode.
        max_new_tokens: maximum number of tokens to generate per sequence.
        bos_token_id: token used as the context prefix of every beam.
        eos_token_id: end-of-sequence token, or None to disable early stopping.
            A beam that emitted EOS can only be extended by EOS at zero cost,
            so its score is frozen.
        beam_size: number of beams K kept per batch element.

    Returns:
        best_seq: (B, t) generated tokens of the best beam (BOS excluded),
            t <= max_new_tokens.
        best_scores: (B,) summed log-probabilities of those beams.

    NOTE(review): a later cell redefines ``beam_search(probs, beam_size, eps)``;
    after Run All that definition shadows this one. Consider renaming one.
    """
    # ---- init ----
    first_param = next(model.parameters(), None) if hasattr(model, "parameters") else None
    device = first_param.device if first_param is not None else torch.device("cpu")

    def next_token_logprobs(ids):
        """Log-probabilities of the token following each row: (N, t) -> (N, V)."""
        out = model(ids)
        logits = out.logits if hasattr(out, "logits") else out
        return torch.log_softmax(logits[:, -1, :], dim=-1)

    if max_new_tokens <= 0:
        empty_seq = torch.empty((batch_size, 0), dtype=torch.long, device=device)
        return empty_seq, torch.zeros(batch_size, device=device)

    # start tokens: (B, 1)
    bos = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)

    # t=0 forward to seed the beams with the top-K first tokens
    logp0 = next_token_logprobs(bos)                                           # (B, V)
    V = logp0.size(-1)
    scores, init_tokens = logp0.topk(beam_size, dim=-1)                        # (B, K)

    # Every beam keeps its BOS prefix so later forward passes see the same
    # context as the seed step; BOS is stripped again before returning.
    seq = torch.cat(
        [bos.unsqueeze(1).expand(batch_size, beam_size, 1), init_tokens.unsqueeze(-1)],
        dim=-1,
    )                                                                           # (B, K, 2)
    if eos_token_id is None:
        finished = torch.zeros_like(init_tokens, dtype=torch.bool)             # (B, K)
    else:
        finished = init_tokens == eos_token_id                                  # (B, K)

    # ---- iterate steps 1..max_new_tokens-1 ----
    for _ in range(1, max_new_tokens):
        # flatten beams -> run the model once for all B*K beams
        flat_seq = seq.reshape(batch_size * beam_size, -1)                     # (B*K, t)
        logp = next_token_logprobs(flat_seq).view(batch_size, beam_size, V)     # (B, K, V)

        if eos_token_id is not None:
            # Finished beams may only be extended by EOS, at zero cost.
            eos_logp = logp[..., eos_token_id]                                  # (B, K)
            logp = logp.masked_fill(finished.unsqueeze(-1), float("-inf"))
            logp[..., eos_token_id] = eos_logp.masked_fill(finished, 0.0)

        # expand: scores (B,K,1) + logp (B,K,V) -> candidate scores (B,K,V)
        candidate = scores.unsqueeze(-1) + logp

        # select top-K out of the K*V candidates per batch element
        scores, flat_idx = candidate.reshape(batch_size, -1).topk(beam_size, dim=-1)  # (B, K)

        # map flat indices back to (parent_beam, next_token)
        parent = flat_idx // V                                                  # (B, K)
        next_t = flat_idx % V                                                   # (B, K)

        # copy the parent sequences and append the chosen tokens
        gather_idx = parent.unsqueeze(-1).expand(-1, -1, seq.size(-1))          # (B, K, t)
        seq = torch.cat([seq.gather(1, gather_idx), next_t.unsqueeze(-1)], dim=-1)  # (B, K, t+1)

        # finished flags follow their parent; a newly emitted EOS also finishes
        finished = finished.gather(1, parent)
        if eos_token_id is not None:
            finished = finished | (next_t == eos_token_id)

        # early stop if all beams in all batches are finished
        if finished.all():
            break

    # pick the best beam per batch element and drop the BOS prefix
    best_scores, best_idx = scores.max(dim=1)                                   # (B,)
    batch_idx = torch.arange(batch_size, device=seq.device)
    best_seq = seq[batch_idx, best_idx, 1:]                                     # (B, t)

    return best_seq, best_scores  # also return (seq, scores) if you want all beams


In [None]:
def beam_search(probs, beam_size, eps=1e-12):
    """
    Beam search over a precomputed table of per-step token probabilities.

    Args:
        probs: (B, T, V) probabilities of each token at each time step.
        beam_size: number of beams K to keep per batch element.
        eps: added to probs before the log so zero probabilities stay finite.

    Returns:
        (seq, scores): seq is (B, K, T) token ids of the K surviving beams and
        scores is (B, K) their accumulated log-probabilities.

    NOTE(review): this cell redefines ``beam_search``, shadowing the
    model-based version from the previous cell; consider renaming one of them.
    """
    batch, steps, vocab = probs.shape
    logp = torch.log(probs + eps)

    # seed: the K most likely tokens at the first step start the K beams
    beam_scores, first_tok = torch.topk(logp[:, 0, :], beam_size, dim=-1)   # (B, K)
    beams = first_tok[..., None]                                             # (B, K, 1)

    for step in range(1, steps):
        # (B, K, 1) + (B, 1, V) -> (B, K, V) extended scores
        total = beam_scores[:, :, None] + logp[:, step, None, :]
        beam_scores, flat = torch.topk(total.reshape(batch, -1), beam_size, dim=-1)

        src_beam = torch.div(flat, vocab, rounding_mode="floor")   # beam each winner extends
        tok = torch.remainder(flat, vocab)                         # token each winner appends

        idx = src_beam[..., None].expand(-1, -1, beams.shape[-1])
        beams = torch.cat((torch.gather(beams, 1, idx), tok[..., None]), dim=-1)

    return beams, beam_scores


In [None]:
@torch.no_grad()
def temperature_sampling(logits, temperature=1.0):
    """
    Draw one token id per row from softmax(logits / temperature).

    Temperatures below 1e-8 are clamped to 1e-8, which approaches argmax.
    Returns a tensor with the last (vocabulary) dimension removed, e.g. (B,).
    """
    temp = float(temperature)
    temp = temp if temp > 1e-8 else 1e-8
    weights = (logits / temp).softmax(dim=-1)   # (B, V)
    return torch.multinomial(weights, 1)[..., 0]

@torch.no_grad()
def top_k_sampling(logits, temperature=1.0, top_k=10):
    """
    Sample one token id per row from the ``top_k`` highest-scoring tokens.

    Args:
        logits: (B, V) unnormalized scores (a 1-D (V,) tensor also works).
        temperature: softmax temperature; clamped below at 1e-8.
        top_k: number of best tokens to sample from. ``None``, values <= 0, or
            values >= V disable the filter (plain temperature sampling).

    Returns:
        (B,) LongTensor of sampled token ids.
    """
    V = logits.size(-1)
    scaled_logits = logits / max(1e-8, float(temperature))  # (B, V)

    if (not top_k) or top_k <= 0 or top_k >= V:
        # No filtering. The logits still go through softmax below:
        # torch.multinomial needs non-negative weights, not raw logits.
        filtered_logits = scaled_logits
    else:
        topk_logits, topk_indices = torch.topk(scaled_logits, top_k, dim=-1)  # (B, k)
        # Tokens outside the top-k get -inf, i.e. probability 0 after softmax.
        filtered_logits = torch.full_like(scaled_logits, float("-inf"))
        filtered_logits.scatter_(dim=-1, index=topk_indices, src=topk_logits)

    # convert logits to probs again
    filtered_probs = torch.softmax(filtered_logits, dim=-1)
    idx = torch.multinomial(filtered_probs, num_samples=1).squeeze(-1)  # [B]
    return idx

@torch.no_grad()
def top_p_sampling(logits, temperature=1.0, top_p=0.9, min_tokens_to_keep=1):
    """
    Nucleus (top-p) sampling.

    Sample from the smallest set of most likely tokens whose cumulative
    probability reaches ``top_p``.

    Args:
        logits: (B, V) unnormalized scores (a 1-D (V,) tensor also works).
        temperature: softmax temperature; clamped below at 1e-8.
        top_p: cumulative-probability threshold of the nucleus.
        min_tokens_to_keep: the nucleus always holds at least this many of the
            most likely tokens. At least one token is always kept, so the
            distribution can never become empty.

    Returns:
        (B,) LongTensor of sampled token ids in original vocabulary order.
    """
    scaled_logits = logits / max(1e-8, float(temperature))
    probs = torch.softmax(scaled_logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)  # (B, V)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)                              # (B, V)

    # Keep a token if the mass strictly before it is below top_p; this
    # includes the token that makes the cumulative sum cross the threshold.
    keep_mask = (cum_probs - sorted_probs) < top_p
    keep_mask[..., :max(1, int(min_tokens_to_keep))] = True

    kept_probs = sorted_probs * keep_mask
    normalized_probs = kept_probs / kept_probs.sum(dim=-1, keepdim=True)

    # Sampled position in sorted order. Keep it (B, 1) so it can index
    # sorted_indices, whose rank is the same.
    sampled_sorted = torch.multinomial(normalized_probs, num_samples=1)

    # convert sorted positions to original vocabulary indices
    idx = torch.gather(sorted_indices, dim=-1, index=sampled_sorted).squeeze(-1)  # [B]
    return idx
    