# 07. Text Generation

Once the model is trained, we can use it to generate text. This notebook explores different decoding strategies:

1.  **Greedy Decoding**: Selecting the most likely token at each step.
2.  **Sampling**: Drawing the next token at random from the model's probability distribution, reshaped by a temperature parameter.
3.  **Top-k Sampling**: Restricting sampling to the top $k$ tokens.
4.  **Top-p (Nucleus) Sampling**: Restricting sampling to the smallest set of tokens whose cumulative probability exceeds $p$.
5.  **Beam Search**: Keeping several candidate sequences (beams) alive at each step to find a high-probability sequence.
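Strategies 1 to 4 share the same autoregressive loop. Each step runs the model, takes the logits for the last position, chooses one token, appends it, and repeats. They differ only in the choice step. Beam search keeps several candidate sequences alive at once, so it does not fit this single-sequence shape. The following minimal sketch shows that shared loop. `decode_loop`, `toy_next_logits`, and the two selector functions are hypothetical names for illustration, and the toy "model" stands in for a real network:

```python
import torch

def decode_loop(next_logits_fn, input_ids, choose_fn, max_new_tokens):
    """Shared loop for greedy and sampling-based decoding (illustrative sketch).

    next_logits_fn: maps input_ids [batch, seq] to next-position logits [batch, vocab].
    choose_fn: maps [batch, vocab] logits to next token ids of shape [batch, 1].
    """
    for _ in range(max_new_tokens):
        logits = next_logits_fn(input_ids)
        next_token = choose_fn(logits)
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return input_ids

VOCAB = 5

def toy_next_logits(input_ids):
    # Hypothetical toy model: strongly prefers token (last_token + 1) % VOCAB
    last = input_ids[:, -1]
    logits = torch.zeros(input_ids.size(0), VOCAB)
    logits[torch.arange(input_ids.size(0)), (last + 1) % VOCAB] = 5.0
    return logits

def greedy_choice(logits):
    return logits.argmax(dim=-1, keepdim=True)

def sample_choice(logits):
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

print(decode_loop(toy_next_logits, torch.tensor([[0]]), greedy_choice, max_new_tokens=4))
# tensor([[0, 1, 2, 3, 4]])
```

Swapping `greedy_choice` for `sample_choice` changes the strategy without touching the loop.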

In [1]:
import torch
import torch.nn.functional as F

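# Dummy model and tokenizer for demonstration: they stand in for a trained
# GPT-style model so the decoding functions below run end to end.
class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Minimal config object exposing the context length (n_positions)
        self.config = type('Config', (), {'n_positions': 1024})()
    def forward(self, x):
        # Return random logits of shape [batch, seq_len, vocab_size],
        # plus None as a placeholder second output
        return torch.randn(x.size(0), x.size(1), 1000), None

class DummyTokenizer:
    # Ignores its input and always returns the same fixed ids / string
    def encode(self, text): return [1, 2, 3]
    def decode(self, ids): return "generated text"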

model = DummyModel()
tokenizer = DummyTokenizer()
device = torch.device("cpu")

## 1. Greedy Decoding

Greedy decoding selects the highest-probability token at each step. It is fast and deterministic (the same prompt always yields the same continuation), but it tends to produce repetitive, dull text.
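Greedy decoding is also myopic. The locally best token can lead to a worse sequence overall, which is the gap beam search tries to close. Here is a toy two-step example with made-up probabilities (the words and numbers are purely illustrative). It compares the greedy path with an exhaustive search over all two-token sequences:

```python
# Hypothetical next-token probabilities for two steps (made-up numbers).
step1 = {"the": 0.5, "a": 0.4, "one": 0.1}
step2 = {
    "the": {"cat": 0.4, "dog": 0.3, "end": 0.3},
    "a": {"cat": 0.9, "dog": 0.05, "end": 0.05},
    "one": {"cat": 0.5, "dog": 0.5, "end": 0.0},
}

# Greedy: take the argmax at each step independently.
w1 = max(step1, key=step1.get)
w2 = max(step2[w1], key=step2[w1].get)
greedy_prob = step1[w1] * step2[w1][w2]

# Exhaustive search: score every two-token sequence by its joint probability.
best_seq, best_prob = max(
    (((a, b), step1[a] * step2[a][b]) for a in step1 for b in step2[a]),
    key=lambda item: item[1],
)

print("greedy:", (w1, w2), round(greedy_prob, 3))  # greedy: ('the', 'cat') 0.2
print("best:  ", best_seq, round(best_prob, 3))    # best:   ('a', 'cat') 0.36
```

Greedy commits to "the" because 0.5 > 0.4 at step one. It never discovers that "a" leads to a much more likely continuation.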

In [2]:
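@torch.no_grad()  # inference only: no gradient tracking needed
def generate_greedy(model, tokenizer, prompt, max_new_tokens=20):
    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(device)
    
    for _ in range(max_new_tokens):
        idx_cond = input_ids[:, -model.config.n_positions:]  # crop to the context window
        logits, _ = model(idx_cond)
        next_token_logits = logits[:, -1, :]  # logits for the next position: [batch, vocab]
        # keepdim=True gives shape [batch, 1], ready to append along the sequence dimension
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        
    return tokenizer.decode(input_ids[0].tolist())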

print(generate_greedy(model, tokenizer, "Hello world"))

generated text


## 2. Sampling with Temperature

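Sampling introduces randomness: instead of always taking the argmax, we draw the next token from the softmax distribution over the vocabulary. The **temperature** $T$ divides the logits $z$ before the softmax, $p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$, and so controls how sharp that distribution is:
- Low temperature ($T < 1$): sharpens the distribution (more confident, less random). As $T \to 0$, sampling approaches greedy decoding.
- High temperature ($T > 1$): flattens the distribution (more random, more diverse).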
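The snippet below makes the effect of dividing the logits by $T$ before the softmax concrete. It applies three temperatures to one arbitrary logit vector (the values are made up) and reports the top-token probability and the entropy. Lower $T$ should raise the top probability and lower the entropy. The ranking of tokens never changes, because dividing by a positive $T$ preserves order:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])  # arbitrary example logits over 4 tokens
temperatures = [0.5, 1.0, 2.0]
max_probs, entropies = [], []

for T in temperatures:
    probs = F.softmax(logits / T, dim=-1)
    # Shannon entropy in nats: larger means a flatter, more random distribution
    entropy = -(probs * probs.log()).sum().item()
    max_probs.append(probs.max().item())
    entropies.append(entropy)
    print(f"T={T}: top-token prob={max_probs[-1]:.3f}, entropy={entropy:.3f} nats, "
          f"argmax={probs.argmax().item()}")
```

The argmax is the same at every temperature. That is why temperature alone cannot change what greedy decoding would pick, only how often sampling strays from it.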

In [3]:
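@torch.no_grad()  # inference only: no gradient tracking needed
def generate_sampling(model, tokenizer, prompt, max_new_tokens=20, temperature=1.0):
    if temperature <= 0:
        raise ValueError("temperature must be > 0 (use generate_greedy for the T -> 0 limit)")
    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(device)
    
    for _ in range(max_new_tokens):
        idx_cond = input_ids[:, -model.config.n_positions:]  # crop to the context window
        logits, _ = model(idx_cond)
        # Scale the last-position logits by 1/T before the softmax
        next_token_logits = logits[:, -1, :] / temperature
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # shape [batch, 1]
        input_ids = torch.cat([input_ids, next_token], dim=1)
        
    return tokenizer.decode(input_ids[0].tolist())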

print(generate_sampling(model, tokenizer, "Hello world", temperature=0.8))

generated text


## 3. Top-k Sampling

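Top-k sampling restricts the sampling pool to the $k$ most likely tokens and renormalizes their probabilities. This prevents the model from choosing very unlikely (and potentially irrelevant) tokens from the long tail of the vocabulary. Note that $k$ is fixed, regardless of how confident the model is at a given step.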
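The filtering step can be pulled out into a standalone helper to show what happens to a small distribution. This is a sketch. `top_k_filter` is a hypothetical name, and it uses `masked_fill` to return a new tensor rather than modifying the logits in place. It also clamps `k` to the vocabulary size, because `torch.topk` raises an error when `k` exceeds it:

```python
import torch
import torch.nn.functional as F

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Return a copy of logits with everything below the k-th largest set to -inf."""
    k = min(k, logits.size(-1))  # torch.topk raises if k exceeds the vocabulary size
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]  # shape [..., 1]
    # Ties with the k-th largest value are kept, so slightly more than k tokens can survive
    return logits.masked_fill(logits < kth_largest, float("-inf"))

logits = torch.tensor([[3.0, 2.0, 1.0, 0.0, -1.0]])  # made-up logits over a 5-token vocabulary
print(F.softmax(logits, dim=-1))                     # all 5 tokens have nonzero probability
print(F.softmax(top_k_filter(logits, k=2), dim=-1))  # only tokens 0 and 1; renormalized to sum to 1
```

Setting the excluded logits to $-\infty$ makes their softmax weight exactly zero. The softmax then renormalizes the survivors, so no separate division step is needed.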

In [4]:
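@torch.no_grad()  # inference only: no gradient tracking needed
def generate_top_k(model, tokenizer, prompt, max_new_tokens=20, k=50):
    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(device)
    
    for _ in range(max_new_tokens):
        idx_cond = input_ids[:, -model.config.n_positions:]  # crop to the context window
        logits, _ = model(idx_cond)
        next_token_logits = logits[:, -1, :]
        
        # Keep only the k largest logits; anything below the k-th largest becomes -inf
        # (k is clamped to the vocabulary size because torch.topk fails otherwise)
        top_k_logits, _ = torch.topk(next_token_logits, min(k, next_token_logits.size(-1)))
        next_token_logits[next_token_logits < top_k_logits[:, [-1]]] = -float('Inf')
        
        # Softmax renormalizes the surviving logits into a probability distribution
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        
    return tokenizer.decode(input_ids[0].tolist())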

print(generate_top_k(model, tokenizer, "Hello world", k=10))

generated text


## 4. Top-p (Nucleus) Sampling

Top-p sampling samples from the smallest set of tokens whose cumulative probability exceeds $p$. The size of this pool adapts to the model's confidence: it shrinks when one or two tokens hold most of the probability mass and grows when the distribution is flat.
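The helper below illustrates the adaptive pool size. `nucleus_size` is a hypothetical name, and it follows the same rule as the notebook's top-p code: the token that first pushes the cumulative mass past $p$ is still kept. It counts the tokens kept for a peaked and a flat five-token distribution, both made up for illustration:

```python
import torch

def nucleus_size(probs: torch.Tensor, p: float) -> int:
    """Count the tokens top-p sampling keeps for a 1-D probability vector."""
    sorted_probs, _ = torch.sort(probs, descending=True)
    # Mass of all higher-ranked tokens; a token is kept while that mass is <= p,
    # so the token that first pushes the total past p is still included
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    return int((mass_before <= p).sum().item())

peaked = torch.tensor([0.95, 0.02, 0.01, 0.01, 0.01])  # a confident model
flat = torch.full((5,), 0.2)                            # an uncertain model

print(nucleus_size(peaked, p=0.9))  # 1
print(nucleus_size(flat, p=0.9))    # 5
```

A fixed `k` would keep the same number of tokens in both cases. With the same $p$, the nucleus holds a single token for the peaked distribution and all five for the flat one.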

In [5]:
@torch.no_grad()  # inference only: no gradient tracking needed
def generate_top_p(model, tokenizer, prompt, max_new_tokens=20, p=0.9):
    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(device)
    
    for _ in range(max_new_tokens):
        idx_cond = input_ids[:, -model.config.n_positions:]  # crop to the context window
        logits, _ = model(idx_cond)
        next_token_logits = logits[:, -1, :]
        
        # Sort logits in descending order and compute cumulative probabilities
        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        
        # Mark tokens whose cumulative probability exceeds the threshold
        sorted_indices_to_remove = cumulative_probs > p
        # Shift the mask right by one so the token that first crosses the threshold
        # is also kept, and always keep the most likely token
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        
        # Map the mask from sorted order back to vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        next_token_logits[indices_to_remove] = -float('Inf')
        
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        
    return tokenizer.decode(input_ids[0].tolist())

print(generate_top_p(model, tokenizer, "Hello world", p=0.9))

generated text
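Temperature, top-k, and top-p all act only on the last-position logits, so they can be combined in one helper. In the following sketch, `sample_next_token` is a hypothetical name. It assumes the filters are applied in the order temperature, then top-k, then top-p, which means the nucleus is computed over the already-truncated distribution. It also treats `temperature == 0` as greedy decoding. Both the ordering and the `T = 0` convention are assumptions, not something the notebook specifies:

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, top_k=None, top_p=None, generator=None):
    """Choose next-token ids [batch, 1] from last-position logits [batch, vocab].

    Assumed order: temperature scaling, then top-k, then top-p.
    temperature == 0 is treated as greedy decoding (an assumed convention).
    """
    if temperature == 0:
        return logits.argmax(dim=-1, keepdim=True)
    logits = logits / temperature  # new tensor; the caller's logits are untouched

    if top_k is not None:
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))

    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Mass of all higher-ranked tokens; it is 0 for the top token, so at least
        # one token always survives (same keep rule as generate_top_p above)
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        sorted_remove = mass_before > top_p
        remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)


logits = torch.tensor([[4.0, 3.0, 0.0, -2.0]])  # made-up logits over 4 tokens
gen = torch.Generator().manual_seed(0)
print(sample_next_token(logits, temperature=0))  # tensor([[0]])
print(sample_next_token(logits, temperature=0.7, top_k=2, top_p=0.9, generator=gen))
```

Inside any of the generation loops above, this helper would replace the lines between slicing `logits[:, -1, :]` and `torch.cat`. Passing a seeded `torch.Generator` makes the sampled continuations reproducible.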
