In [None]:
@torch.no_grad()
def greedy_search(model, max_length, bos_token_id, eos_token_id, num_beams,
                  batch_size=None, device=None):
    """Generate token sequences by always picking the highest-scoring token.

    Args:
        model: Callable taking token ids of shape (batch_size, seq_len) and
            returning scores indexable as (batch_size, seq_len, vocab_size).
        max_length: Maximum number of tokens to generate after the BOS token.
        bos_token_id: Token id every sequence starts with.
        eos_token_id: Token id that marks a sequence as finished.
        num_beams: Ignored (greedy decoding keeps one hypothesis); retained
            only so existing call sites keep working.
        batch_size: Number of sequences to decode in parallel. When omitted,
            a module-level ``batch_size`` is used if one exists (the previous
            implicit behaviour), otherwise 1.
        device: Optional device on which the id and bookkeeping tensors are
            created; it should match the device of the model's output.

    Returns:
        Long tensor of shape (batch_size, 1 + n_generated) that starts with
        ``bos_token_id``. Rows that finish early are padded with
        ``eos_token_id``.
    """
    if batch_size is None:
        # Legacy fallback: this function used to read `batch_size` from the
        # enclosing module, which raised NameError when it was not defined.
        batch_size = globals().get("batch_size", 1)

    input_ids = torch.full(
        (batch_size, 1), bos_token_id, dtype=torch.long, device=device
    )
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for _ in range(max_length):
        # Only the scores at the last position are needed: (batch, vocab).
        logits = model(input_ids)[:, -1, :]
        next_token = torch.argmax(logits, dim=-1)
        # Sequences that already emitted EOS keep emitting EOS as padding.
        next_token = next_token.masked_fill(finished, eos_token_id)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)

        finished |= next_token == eos_token_id
        if finished.all():
            break

    return input_ids
            
            

In [None]:
def beam_search(probs, beam_size, eps=1e-12):
    """Run beam search over a fixed table of per-step token probabilities.

    Args:
        probs: Tensor of shape (B, T, V) holding, for each batch element and
            time step, a probability for every vocabulary token. T must be
            at least 1.
        beam_size: Maximum number of hypotheses kept per batch element.
        eps: Small constant added before taking the log so that zero
            probabilities do not produce ``-inf``.

    Returns:
        A tuple ``(seq, scores)``:
            seq: (B, K, T) token ids of the surviving hypotheses, best first.
            scores: (B, K) accumulated log-probability of each hypothesis.
        K equals ``beam_size`` whenever at least that many distinct
        hypotheses exist; otherwise K is the number available (for example
        when ``beam_size`` exceeds V at the first step).
    """
    B, T, V = probs.shape
    log_probs = (probs + eps).log()  # work in log-space for numerical stability

    # t=0: seed the beams with the best first tokens. Only V candidates exist
    # here, so the width is clamped instead of letting topk raise.
    scores, first_tokens = log_probs[:, 0, :].topk(min(beam_size, V), dim=-1)
    seq = first_tokens.unsqueeze(-1)  # (B, K, 1)

    for t in range(1, T):
        # Score of every (beam, token) continuation:
        # (B, K, 1) + (B, 1, V) -> (B, K, V)
        candidates = scores.unsqueeze(-1) + log_probs[:, t, :].unsqueeze(1)

        # Keep the best continuations per batch element, never asking for
        # more than the K*V that actually exist.
        width = min(beam_size, candidates.size(1) * V)
        scores, flat_indices = candidates.reshape(B, -1).topk(width, dim=-1)

        # Decode each flat index back into (source beam, chosen token).
        parent_beam = flat_indices // V  # (B, width)
        next_token = flat_indices % V    # (B, width)

        # Reorder the stored prefixes so row i is the parent of candidate i.
        gather_idx = parent_beam.unsqueeze(-1).expand(-1, -1, seq.size(-1))
        parent_seqs = seq.gather(dim=1, index=gather_idx)  # (B, width, t)

        seq = torch.cat([parent_seqs, next_token.unsqueeze(-1)], dim=-1)

    return seq, scores


In [None]:
def temperature_sampling(logits, temperature=1.0):
    """Draw one token id per row from the temperature-scaled softmax of ``logits``.

    Args:
        logits: Unnormalised scores whose last dimension is the vocabulary.
        temperature: Divisor applied to the logits before the softmax; it is
            floored at 1e-8 so a value of 0 does not divide by zero.

    Returns:
        Tensor of sampled token indices with the sampling dimension squeezed.
    """
    safe_temperature = max(1e-8, float(temperature))
    distribution = (logits / safe_temperature).softmax(dim=-1)
    drawn = torch.multinomial(distribution, num_samples=1)
    return drawn.squeeze(-1)

@torch.no_grad()
def top_k_sampling(logits, temperature=1.0, top_k=10):
    """Sample one token id per row, restricted to the ``top_k`` best tokens.

    Args:
        logits: Unnormalised scores whose last dimension is the vocabulary.
        temperature: Divisor applied to the logits before the softmax; it is
            floored at 1e-8 so a value of 0 does not divide by zero.
        top_k: Number of highest-scoring tokens eligible for sampling. A
            falsy, non-positive, or >= vocabulary-size value disables the
            filter and samples from the full distribution.

    Returns:
        Tensor of sampled token indices with the sampling dimension squeezed.
    """
    vocab_size = logits.size(-1)
    scaled_logits = logits / max(1e-8, float(temperature))

    if (not top_k) or top_k <= 0 or top_k >= vocab_size:
        # No filtering requested. multinomial needs non-negative weights, so
        # the logits must go through softmax first.
        full_probs = torch.softmax(scaled_logits, dim=-1)
        return torch.multinomial(full_probs, num_samples=1).squeeze(-1)

    # Renormalise over the k survivors only, sample a position among them,
    # then translate that position back into a vocabulary index.
    topk_logits, topk_indices = torch.topk(scaled_logits, top_k, dim=-1)
    topk_probs = torch.softmax(topk_logits, dim=-1)
    choice = torch.multinomial(topk_probs, num_samples=1)
    return topk_indices.gather(dim=-1, index=choice).squeeze(-1)

@torch.no_grad()
def top_p_sampling(logits, temperature=1.0, top_p=0.9, min_tokens_to_keep=1):
    """Sample one token id per row from the nucleus (top-p) of the distribution.

    Tokens are sorted by probability and kept until their cumulative mass
    reaches ``top_p``; the token that crosses the threshold is included.

    Args:
        logits: Unnormalised scores whose last dimension is the vocabulary.
        temperature: Divisor applied to the logits before the softmax; it is
            floored at 1e-8 so a value of 0 does not divide by zero.
        top_p: Cumulative-probability threshold defining the nucleus.
        min_tokens_to_keep: Lower bound on how many of the most probable
            tokens stay eligible. At least one token is always kept so the
            distribution can never become empty.

    Returns:
        Tensor of sampled token indices with the sampling dimension squeezed.
    """
    scaled_logits = logits / max(1e-8, float(temperature))
    probs = torch.softmax(scaled_logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)

    # Keep a token when the mass *before* it is still below top_p; this keeps
    # the token that makes the cumulative sum cross the threshold.
    keep_mask = (cum_probs - sorted_probs) < top_p
    # Always retain the top token(s): with top_p <= 0 the mask would
    # otherwise be empty and normalisation would divide by zero.
    keep_mask[..., :max(min_tokens_to_keep, 1)] = True

    nucleus = sorted_probs.masked_fill(~keep_mask, 0.0)
    nucleus = nucleus / nucleus.sum(dim=-1, keepdim=True)

    # Sample a position in sorted order, then map it back to a vocabulary id.
    # gather needs the index to have the same rank as its input, so the
    # trailing dimension is squeezed only after the lookup.
    sampled_sorted = torch.multinomial(nucleus, num_samples=1)
    return sorted_indices.gather(dim=-1, index=sampled_sorted).squeeze(-1)
    
    
    

    

