### Decoding strategy 2: Top-k sampling

In [1]:
### In the previous strategy, we combined probabilistic sampling with temperature scaling to increase the diversity of
### the outputs.
### Higher temperature values produce more uniformly distributed next-token probabilities, which makes the outputs more
### diverse because the model is less likely to select the most probable token repeatedly.
### This lets the generation process explore less likely but potentially more interesting and creative paths.
### One downside of this approach is that it sometimes leads to grammatically incorrect or completely nonsensical outputs.
### Top-k sampling addresses this by restricting sampling to the k most likely tokens and masking out all others.

In [2]:
import torch

vocab = {
    "closer": 0,
    "every": 1,
    "effort": 2,
    "forward": 3,
    "inches": 4,
    "moves": 5,
    "pizza": 6,
    "towards": 7,
    "you": 8
}

inverse_vocab = {v: k for k, v in vocab.items()}

# Assume the LLM is given the start context "every effort moves you" and produces the following next-token logits
next_token_logits = torch.tensor(
    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)
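
To make the downside of temperature-only sampling concrete, the following minimal sketch compares the probability assigned to an implausible continuation ("pizza", index 6) at two temperatures. The helper name `softmax_with_temperature` and the temperature values 1.0 and 5.0 are assumptions chosen for illustration; the logits are the ones defined above.

```python
import torch

# Same illustrative logits as above; index 6 corresponds to "pizza"
next_token_logits = torch.tensor(
    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)
pizza_id = 6


def softmax_with_temperature(logits, temperature):
    # Hypothetical helper: dividing by a temperature > 1 flattens the distribution
    return torch.softmax(logits / temperature, dim=0)


for temperature in (1.0, 5.0):  # assumed values, for illustration only
    probas = softmax_with_temperature(next_token_logits, temperature)
    print(f"T={temperature}: P('pizza') = {probas[pizza_id].item():.4f}")
```

The higher temperature gives "pizza" a noticeably larger share of the probability mass, which is the kind of token top-k filtering is meant to exclude.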

In [4]:
top_k = 3
top_logits, top_pos = torch.topk(next_token_logits, top_k)
print("Top logits:", top_logits)
print("Top positions:", top_pos)

Top logits: tensor([6.7500, 6.2800, 4.5100])
Top positions: tensor([3, 7, 0])


In [5]:
new_logits = torch.where(
    condition=next_token_logits < top_logits[-1],
    input=torch.tensor(float("-inf")),
    other=next_token_logits
)

print(new_logits)

tensor([4.5100,   -inf,   -inf, 6.7500,   -inf,   -inf,   -inf, 6.2800,   -inf])
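
As an aside, the same mask can be built from the positions returned by `torch.topk` instead of a threshold comparison. This is only a sketch of an alternative, not the approach used in this notebook. One difference worth noting: the threshold version keeps every logit equal to the k-th largest value (so ties can keep more than k entries), while the index version keeps exactly k.

```python
import torch

next_token_logits = torch.tensor(
    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)
top_k = 3
top_logits, top_pos = torch.topk(next_token_logits, top_k)

# Start from a tensor filled with -inf, then copy the top-k logits
# back to their original vocabulary positions
masked_logits = torch.full_like(next_token_logits, float("-inf"))
masked_logits[top_pos] = top_logits

print(masked_logits)
```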


In [6]:
topk_probas = torch.softmax(new_logits, dim=0)
print(topk_probas)

tensor([0.0615, 0.0000, 0.0000, 0.5775, 0.0000, 0.0000, 0.0000, 0.3610, 0.0000])
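
Because exp(-inf) is 0, the masked tokens contribute nothing to the softmax denominator, so the result is the same as renormalizing over the three kept logits alone. The short check below is a sketch that compares the two computations; the variable names are illustrative.

```python
import torch

next_token_logits = torch.tensor(
    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)
top_logits, top_pos = torch.topk(next_token_logits, 3)

# Softmax over only the three kept logits
probas_kept = torch.softmax(top_logits, dim=0)

# Softmax over the full vector with all other logits set to -inf
masked_logits = torch.where(
    next_token_logits < top_logits[-1],
    torch.tensor(float("-inf")),
    next_token_logits,
)
probas_masked = torch.softmax(masked_logits, dim=0)

print(probas_kept)
print(probas_masked[top_pos])
```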


### Merge temperature scaling and top-k sampling

In [7]:
### We can now apply temperature scaling and the multinomial function for probabilistic sampling, both introduced
### previously, to select the next token from among these 3 nonzero probability scores.
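
A minimal sketch of that combination on the example logits is shown below. The temperature of 1.5, the seed, and the sample count of 1,000 are assumptions chosen for illustration. Only the three tokens that survive the top-k filter can ever be drawn.

```python
import torch

torch.manual_seed(123)  # assumed seed, only for reproducibility

tokens = ["closer", "every", "effort", "forward", "inches",
          "moves", "pizza", "towards", "you"]  # ordered by token ID
next_token_logits = torch.tensor(
    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)
top_k = 3
temperature = 1.5  # assumed value

# Top-k filtering: everything below the k-th largest logit becomes -inf
top_logits, _ = torch.topk(next_token_logits, top_k)
masked_logits = torch.where(
    next_token_logits < top_logits[-1],
    torch.tensor(float("-inf")),
    next_token_logits,
)

# Temperature scaling followed by softmax; -inf entries stay at probability 0
probas = torch.softmax(masked_logits / temperature, dim=0)

# Draw many samples to see which tokens can be selected
samples = torch.multinomial(probas, num_samples=1000, replacement=True)
counts = torch.bincount(samples, minlength=len(tokens))

for token_id, count in enumerate(counts.tolist()):
    if count > 0:
        print(f"{count} x {tokens[token_id]}")
```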

In [8]:
### Step 1: The for-loop is the same as before: get the logits and focus only on the last time step
### Step 2: New: filter the logits with top-k sampling
### Step 3: New: apply temperature scaling, then sample the next token from the resulting probabilities
### Step 4: If temperature scaling is disabled, carry out greedy next-token selection as before
### Step 5: Stop generating early if the end-of-sequence token is encountered and eos_id is specified

In [9]:
def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
    
    # For-loop is the same as before: get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]
    
        # New: filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1:]  # (batch_size, 1), broadcasts across the vocabulary for any batch size
            logits = torch.where(
                logits < min_val,
                torch.tensor(float("-inf")).to(logits.device),
                logits
            )
        
        # New: apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature
                
            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)   # (batch_size, vocab_size)
                
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
        # Otherwise same as before: get the idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)   # (batch_size, 1)
            
        # Stop generating early if eos_id is specified and every sequence in the batch produced it
        if eos_id is not None and (idx_next == eos_id).all():
            break
        
        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)
        
    return idx
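
The top-k filtering step inside `generate` can also be examined in isolation on batched logits of shape `(batch_size, vocab_size)`. The sketch below uses a hypothetical helper, `top_k_filter`, and random logits as stand-ins for model output. It keeps the per-row threshold with shape `(batch_size, 1)` so the comparison broadcasts across the vocabulary dimension for any batch size.

```python
import torch


def top_k_filter(logits, top_k):
    # Hypothetical helper; logits has shape (batch_size, vocab_size)
    top_logits, _ = torch.topk(logits, top_k, dim=-1)  # (batch_size, top_k)
    min_val = top_logits[:, -1:]                       # (batch_size, 1)
    return torch.where(
        logits < min_val,
        torch.full_like(logits, float("-inf")),
        logits,
    )


torch.manual_seed(123)  # assumed seed, only for reproducibility
batch_logits = torch.randn(2, 9)  # random stand-in for model output
filtered = top_k_filter(batch_logits, top_k=3)

# Number of tokens per row that remain eligible for sampling
print(torch.isfinite(filtered).sum(dim=-1))
```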