<a href="https://colab.research.google.com/github/QasimWani/simple-transformer/blob/main/transformers/decoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Implements several common decoding algorithms such as top-k and top-p (nucleus sampling)

import torch
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper

In [2]:
def get_top_k(logits, k: int, temperature: float, generator: torch.Generator = None):
    '''
    Top-k sampling: draw one token per row from the k highest-scoring tokens.

    Step 1. Scale logits by 1/temperature and keep the top k per row
    Step 2. Convert those k logits to probabilities
    Step 3. Sample a position in range(0, k) per row
    Step 4. Map that position back to a vocab id via the top-k indices

    Args:
      logits: (batch_size, vocab_size) unnormalized scores.
      k: number of candidates per row; values larger than vocab_size are clamped to vocab_size.
      temperature: softmax temperature. Values <= 1e-8 fall back to greedy argmax (k is then ignored);
        higher values flatten the distribution over the k candidates.
      generator: optional torch.Generator for reproducible sampling (None uses the global RNG).

    Returns:
      LongTensor of shape (batch_size,) holding one vocab id per row.

    Raises:
      ValueError: if logits is not 2-D, or if k < 1 on the sampling path (temperature > 1e-8).
    '''
    # Unpacking enforces a 2-D input, exactly like before.
    _, vocab_size = logits.shape

    if temperature <= 1e-8:  # greedy decoding
        return logits.argmax(dim=-1)  # (batch_size,)

    # Without this check, k <= 0 only fails later inside softmax/multinomial with an opaque error.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    k = min(k, vocab_size)

    top_k_logits, top_k_indices = torch.topk(logits / temperature, k)  # (batch_size, k) each
    top_k_probs = torch.softmax(top_k_logits, dim=-1)

    # Position of the chosen candidate within the top-k list: (batch_size, 1), values in range(0, k).
    sampled_pos = torch.multinomial(top_k_probs, num_samples=1, generator=generator)

    # gather along dim=1: out[i, 0] = top_k_indices[i, sampled_pos[i, 0]], i.e. the vocab id of the pick.
    return torch.gather(top_k_indices, dim=1, index=sampled_pos).squeeze(1)  # (batch_size,)


In [3]:
def get_top_p(logits: torch.Tensor, p: float, temperature: float, generator: torch.Generator = None):
    '''
    Nucleus (top-p) sampling: draw one token per row from the smallest set of highest-probability
    tokens whose cumulative probability is >= p.

    Step 1. Scale logits by 1/temperature and sort them in descending order
    Step 2. Convert the sorted logits to probabilities
    Step 3. Keep the shortest prefix of the sorted distribution whose cdf >= p
    Step 4. Sample a position inside that prefix and map it back to a vocab id

    Args:
      logits: (batch_size, vocab_size) unnormalized scores.
      p: cumulative-probability threshold. p <= 0 keeps only the most likely token; a p the cdf
        never reaches (e.g. p >= 1) keeps the whole vocabulary.
      temperature: softmax temperature; values <= 1e-8 fall back to greedy argmax.
      generator: optional torch.Generator for reproducible sampling (None uses the global RNG).

    Returns:
      LongTensor of shape (batch_size,) holding one vocab id per row.
    '''
    # Unpacking enforces a 2-D input, exactly like before.
    _, vocab_size = logits.shape

    if temperature <= 1e-8:  # greedy decoding
        return logits.argmax(dim=-1)  # (batch_size,)

    # Step 1. Sort (O(V log V) per row). sorted_idx[i, j] is the vocab id at rank j in row i.
    sorted_logits, sorted_idx = (logits / temperature).sort(dim=-1, descending=True)

    # Step 2. Probabilities in descending order; higher temperature -> flatter distribution.
    sorted_probs = torch.softmax(sorted_logits, dim=-1)  # (batch_size, vocab_size)

    # Step 3. Number of tokens to keep: every token whose cdf is still strictly below p,
    # plus the one token that pushes the cdf to >= p.
    #   probs [0.5, 0.3, 0.2] -> cdf [0.5, 0.8, 1.0]
    #   p = 0.7: (cdf < p) = [T, F, F] -> keep 2 tokens (0.5 + 0.3 = 0.8 >= 0.7)
    #   p = 0.9: (cdf < p) = [T, T, F] -> keep 3 tokens (0.8 alone is still < 0.9)
    # Strict `<` matters when the cdf hits p exactly: for cdf [0.5, 0.7, 1.0] and p = 0.7, `<` keeps
    # 2 tokens (the smallest set reaching 0.7), whereas `<=` would keep 3.
    # If rounding leaves cdf[-1] below p, counts becomes vocab_size + 1, which the mask treats as "keep all".
    cdf = torch.cumsum(sorted_probs, dim=-1)  # (batch_size, vocab_size)
    counts = (cdf < p).sum(dim=-1) + 1  # (batch_size,), values in [1, vocab_size + 1]

    # Rows keep different numbers of tokens, so express the cut-off as a boolean mask over ranks.
    # arange must live on the same device as logits, or the comparison fails for GPU tensors.
    ranks = torch.arange(vocab_size, device=logits.device).unsqueeze(0)  # (1, vocab_size)
    mask = ranks < counts.unsqueeze(-1)  # (batch_size, vocab_size)

    kept_probs = sorted_probs * mask.float()
    denom = kept_probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)
    normalized_probs = kept_probs / denom  # (batch_size, vocab_size)

    # Masked-out ranks have probability exactly 0, so multinomial never draws them.
    sampled_rank = torch.multinomial(normalized_probs, num_samples=1, generator=generator)  # (batch_size, 1)

    # Step 4. sampled_rank indexes the sorted order; gather maps it back to the vocab id:
    # out[i, 0] = sorted_idx[i, sampled_rank[i, 0]].
    return torch.gather(sorted_idx, dim=1, index=sampled_rank).squeeze(1)  # (batch_size,)


In [4]:
def verification(logits, **kwargs):
    '''
    Reference sampler built on HuggingFace's logits warpers, used to check get_top_k / get_top_p.

    The sampler first applies TemperatureLogitsWarper. If 'p' is given it then applies TopPLogitsWarper;
    otherwise, if 'k' is given, it applies TopKLogitsWarper. If neither is given, only temperature
    scaling is applied. Finally it draws one token per row with torch.multinomial.

    Args:
      logits: (batch_size, vocab_size) unnormalized scores.
      kwargs:
        temperature (required): forwarded to TemperatureLogitsWarper.
        p: top-p threshold; takes precedence over k when both are passed.
        k: top-k size.
        generator (optional): torch.Generator for reproducible sampling; None uses the global RNG.

    Returns:
      LongTensor of shape (batch_size,) with one sampled token id per row.
    '''
    batch_size = logits.shape[0]

    # Dummy input_ids to satisfy the warpers' (input_ids, scores) call signature; kept on logits' device.
    input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=logits.device)

    logits = TemperatureLogitsWarper(kwargs['temperature'])(input_ids, logits)

    if 'p' in kwargs:
        logits = TopPLogitsWarper(top_p=kwargs['p'])(input_ids, logits)
    elif 'k' in kwargs:
        logits = TopKLogitsWarper(top_k=kwargs['k'])(input_ids, logits)

    probs = torch.softmax(logits, dim=-1)

    # generator is optional: .get avoids the KeyError the documented kwargs forms would otherwise hit.
    return torch.multinomial(probs, num_samples=1, generator=kwargs.get('generator')).squeeze(-1)  # (batch_size,)

In [5]:
batch_size, vocab_size = 10, 50_000
k = 10
p = 0.8
# A near-zero temperature makes both samplers almost greedy. My implementations differ slightly from
# HuggingFace's warpers, so at higher temperatures they could pick different tokens even with the same seed.
temperature = 1e-2


def gen() -> torch.Generator:
    '''Return a freshly seeded generator so every sampler call starts from the same RNG state.'''
    return torch.Generator().manual_seed(314159)


def accuracy(gt: torch.Tensor, pred: torch.Tensor) -> float:
    '''Fraction of positions where pred equals gt, rounded to 3 decimals.'''
    return round(((gt == pred).sum() / len(gt)).item(), 3)


# Standard normal logits: ~68% of values fall in [-1, 1], ~95% in [-2, 2], ~99.7% in [-3, 3].
logits = torch.randn(batch_size, vocab_size, generator=gen())

In [6]:
# Sample with HuggingFace's warpers (ground truth) and with my implementation. Each call gets a freshly
# seeded generator; report the fraction of batch rows where both pick the same token.
gt_top_p_samples = verification(logits, p=p, temperature=temperature, generator=gen())
pred_top_p_samples = get_top_p(logits, p, temperature, generator=gen())
top_p_accuracy = accuracy(gt_top_p_samples, pred_top_p_samples)
print("Accuracy of Top-p with official implementation:", top_p_accuracy)

gt_top_k_samples = verification(logits, k=k, temperature=temperature, generator=gen())
pred_top_k_samples = get_top_k(logits, k, temperature, generator=gen())
top_k_accuracy = accuracy(gt_top_k_samples, pred_top_k_samples)
print("Accuracy of Top-k with official implementation:", top_k_accuracy)

Accuracy of Top-p with official implementation: 1.0
Accuracy of Top-k with official implementation: 1.0
