In [1]:
import torch
import torch.nn.functional as F

In [2]:
# Helper
class Sampler:
    """Common interface for next-token sampling strategies.

    A sampler is a callable object that receives ``logits``; the concrete
    subclasses in this notebook return the selected token index (or indices).
    The base class itself implements no strategy.
    """

    def __call__(self, logits):
        """Select the next token from ``logits``; subclasses must override.

        Raises:
            NotImplementedError: always, when called on the base class.
        """
        raise NotImplementedError()

# Greedy Sampling

Greedy sampling (also called greedy decoding) selects the most probable next token at each step in the sequence generation — no randomness, just the top choice every time.

In [3]:
class GreedySampler(Sampler):
    """Deterministic sampler that always picks the highest-scoring token."""

    def __call__(self, logits):
        """Return the index of the largest logit along the last dimension."""
        return torch.argmax(logits, dim=-1)

In [4]:
# Test
sampler = GreedySampler()
# Sort the characters: the iteration order of a set of strings depends on
# per-process string hashing, so list(set(...)) could map the same token id
# to a different character every time the kernel is restarted.
vocabulary = sorted(set("abcdefghijklmnopqrstuvwxyz."))
# One row of random logits, one column per vocabulary entry.
logits = torch.randn(1, len(vocabulary))
next_token_id = sampler(logits)
next_token = vocabulary[next_token_id.item()]
print(f"{next_token = }")

next_token = 'f'


# Temperature Sampling

Temperature sampling is a technique used in language model sampling to control the randomness (or confidence) of token selection by scaling the logits before applying softmax.

$$
\text{scaled logits}_i = \frac{u_i}{T}
$$

$$
P(x_i = V_l \mid x_{1:i-1}) = \frac{\exp\left( \frac{u_l}{T} \right)}{\sum_{j} \exp\left( \frac{u_j}{T} \right)}
$$

where:
- $u_j$: Logit for token $j$
- $T$: Temperature
- $V_l$: Vocabulary token at index $l$
- $P(x_i = V_l \mid x_{1:i-1})$: Probability that the next token is vocabulary token $V_l$, given the sequence so far.

---

When $T = 0$, sampling is treated as greedy sampling.

As $T \to 0^+$, the gaps between the scaled logits $\frac{u_i}{T}$ grow without bound, so the scaled maximum logit $\frac{u_{max}}{T}$ becomes arbitrarily larger than all the others. Assuming the maximum is unique, this causes the softmax distribution to concentrate all probability mass on the token with the highest logit:

$$
\lim_{T \to 0} \text{softmax}\left(\frac{u}{T}\right) = \text{one-hot vector at } \arg\max_i u_i
$$

Therefore, sampling from this distribution is equivalent to choosing the token with the highest logit, which is exactly greedy sampling. In practice, since dividing by zero is undefined and numerically unstable, return the $\arg\max$ when $T = 0$.


In [5]:
class TemperatureSampler(Sampler):
    """Sample a token after dividing the logits by a temperature.

    A temperature of exactly 0 is treated as greedy decoding (argmax), because
    dividing by zero is undefined.

    Args:
        temperature: Non-negative scaling factor applied as ``logits / temperature``.

    Raises:
        ValueError: If ``temperature`` is negative. A negative value would flip
            the sign of the logits and favour the *least* likely tokens.
    """

    def __init__(self, temperature=1):
        if temperature < 0:
            raise ValueError(
                f"temperature must be non-negative, got {temperature}"
            )
        self.temperature = temperature

    def __call__(self, logits):
        """Return sampled token indices with the last dimension removed.

        ``logits`` must be 1-D or 2-D when ``temperature`` is non-zero, since
        ``torch.multinomial`` only accepts vectors or matrices of probabilities.
        """
        if self.temperature == 0:
            # Greedy sampling: the T -> 0 limit of the softmax is a one-hot
            # vector at the argmax.
            return logits.argmax(dim=-1)

        # Scale the logits and convert them to probabilities
        probs = F.softmax(logits / self.temperature, dim=-1)

        # Draw one sample per row and drop the trailing sample dimension
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

In [6]:
# Test
sampler = TemperatureSampler(temperature=5)
# Sort the characters: the iteration order of a set of strings depends on
# per-process string hashing, so list(set(...)) could map the same token id
# to a different character every time the kernel is restarted.
vocabulary = sorted(set("abcdefghijklmnopqrstuvwxyz."))
# One row of random logits, one column per vocabulary entry.
logits = torch.randn(1, len(vocabulary))
next_token_id = sampler(logits)
next_token = vocabulary[next_token_id.item()]
print(f"{next_token = }")

next_token = 'b'


# Top-K Sampling

Top-k sampling limits the candidate tokens to the $k$ highest-probability tokens. It sets the probabilities of all other tokens to zero, renormalizes, and then samples the next token only from these top-k tokens. This controls randomness by restricting sampling to a smaller subset of likely tokens.

Top-k sampling can be combined with Temperature Scaling.

In [7]:
class TopKSampler(Sampler):
    """Sample the next token from the ``k`` highest-scoring tokens only.

    Args:
        k: Number of candidate tokens to keep. Values larger than the
            vocabulary size are clamped to the vocabulary size at call time.

    Raises:
        ValueError: If ``k`` is smaller than 1 (there would be nothing to
            sample from).
    """

    def __init__(self, k=10):
        if k < 1:
            raise ValueError(f"k must be at least 1, got {k}")
        self.k = k

    def __call__(self, logits):
        """Return sampled vocabulary indices, keeping a trailing dimension of 1.

        ``logits`` may be 1-D ``(vocab,)`` or 2-D ``(batch, vocab)``; the result
        has shape ``(1,)`` or ``(batch, 1)`` respectively.
        """
        # torch.topk raises if k exceeds the size of the dimension, so clamp it
        k = min(self.k, logits.size(-1))

        # Get top-k logits and their indices along the vocabulary dimension
        topk_logits, topk_indices = torch.topk(logits, k, dim=-1)

        # Convert the surviving logits to probabilities (renormalized over top-k)
        topk_probs = F.softmax(topk_logits, dim=-1)

        # Sample a position inside the top-k list, one per row
        next_token_idx = torch.multinomial(topk_probs, num_samples=1)

        # Map back to original vocabulary indices. gather selects along the
        # last dimension per row; plain indexing (topk_indices[idx]) would
        # index the batch dimension instead and is only correct for 1-D input.
        return torch.gather(topk_indices, -1, next_token_idx)

In [8]:
# Test
sampler = TopKSampler(3)
# Sort the characters: the iteration order of a set of strings depends on
# per-process string hashing, so list(set(...)) could map the same token id
# to a different character every time the kernel is restarted.
vocabulary = sorted(set("abcdefghijklmnopqrstuvwxyz."))
# One row of random logits, one column per vocabulary entry.
logits = torch.randn(1, len(vocabulary))
# Drop the batch dimension so the sampler receives a 1-D vector of logits
next_token_id = sampler(logits.squeeze(0))
next_token = vocabulary[next_token_id.item()]
print(f"{next_token = }")

next_token = 'z'


# Top-P Sampling (Nucleus Sampling)

Top-p sampling selects tokens from the smallest set of tokens whose cumulative probability mass reaches a threshold $p$ (e.g., 0.75). Instead of restricting to a fixed top-k tokens, it dynamically chooses how many tokens to consider based on their cumulative probability.

- Sort tokens by probability.
- Take the top tokens until their cumulative probability is $\geq p$.
- Sample the next token from this reduced set.

In [9]:
class TopPSampler(Sampler):
    """Nucleus sampling: sample from the smallest set of most-probable tokens
    whose cumulative probability reaches ``p``.

    Args:
        p: Cumulative-probability threshold. The most probable token is always
            kept, so ``p <= 0`` degenerates to picking the top token; a ``p``
            above the total mass keeps every token.
    """

    def __init__(self, p=0.75):
        self.p = p

    def __call__(self, logits):
        """Return sampled vocabulary indices, keeping a trailing dimension of 1.

        ``logits`` may be 1-D ``(vocab,)`` or 2-D ``(batch, vocab)``; the result
        has shape ``(1,)`` or ``(batch, 1)`` respectively.
        """
        # Convert logits to probabilities
        probs = F.softmax(logits, dim=-1)

        # Sort probabilities in descending order
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)

        # Compute cumulative probabilities
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Probability mass accumulated strictly before each sorted position
        # (the cumulative sum shifted right by one, starting at 0).
        mass_before = torch.zeros_like(cumulative_probs)
        mass_before[..., 1:] = cumulative_probs[..., :-1]

        # Keep a token while the mass before it is still below p. This keeps
        # everything up to and including the first position whose cumulative
        # probability is >= p, and it works row-wise for batched input.
        keep = mass_before < self.p
        # The top token is always a candidate, even when p <= 0
        keep[..., 0] = True

        # Zero out tokens beyond the cutoff so they can't be sampled
        filtered_probs = sorted_probs.masked_fill(~keep, 0.0)

        # Normalize filtered probabilities
        filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)

        # Sample a position in the sorted order, one per row
        next_token_index = torch.multinomial(filtered_probs, num_samples=1)

        # Map back to original token indices along the last dimension
        return torch.gather(sorted_indices, -1, next_token_index)

In [10]:
# Test
sampler = TopPSampler(0.7)
# Sort the characters: the iteration order of a set of strings depends on
# per-process string hashing, so list(set(...)) could map the same token id
# to a different character every time the kernel is restarted.
vocabulary = sorted(set("abcdefghijklmnopqrstuvwxyz."))
# One row of random logits, one column per vocabulary entry.
logits = torch.randn(1, len(vocabulary))
# Drop the batch dimension so the sampler receives a 1-D vector of logits
next_token_id = sampler(logits.squeeze(0))
next_token = vocabulary[next_token_id.item()]
print(f"{next_token = }")

next_token = 'e'
