### Generating text from output tokens

In [1]:
### Step 1: idx is a (batch, n_tokens) array of indices in the current context
### Step 2: Crop the current context if it exceeds the supported context size. Example: if the LLM supports only 5 tokens
### and the context holds 10 tokens, only the last 5 tokens are used as context.
### Step 3: Focus only on the last time step, so that (batch, n_tokens, vocab_size) becomes (batch, vocab_size)
### Step 4: Apply softmax to get probas, which has shape (batch, vocab_size)
### Step 5: Take the argmax of probas to get idx_next, which has shape (batch, 1)
### Step 6: Append the selected index to the running sequence, so that idx has shape (batch, n_tokens + 1)

In [2]:
import torch
import torch.nn as nn

In [3]:
def generate_text_simple(model, idx, max_new_tokens, context_size):
    
    for _ in range(max_new_tokens):
        
        # Crop the current context if it exceeds the supported context size
        # e.g. if the LLM supports only 5 tokens and the context holds 10 tokens,
        # only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]
        
        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)
            
        # Focus only on the last time step
        # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]
        
        # Apply softmax to get probabilities
        probas = torch.softmax(logits, dim=-1)  # (batch, vocab_size)
        
        # Get the idx of the vocab entry with the highest probability value
        idx_next = torch.argmax(probas, dim=-1, keepdim=True)   # (batch, 1)
        
        # Append the selected (argmax) index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)    # (batch, n_tokens + 1)
    
    return idx
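
### A minimal usage sketch for the loop above. The DummyNextTokenModel class is a hypothetical stand-in, not a real LLM:
### it always predicts (last token + 1) % vocab_size, so the generated sequence is easy to check by hand.
### The function is repeated in compact form only so that the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DummyNextTokenModel(nn.Module):
    """Hypothetical stand-in for an LLM: predicts (token + 1) % vocab_size."""

    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, idx):
        # idx: (batch, n_tokens) -> logits: (batch, n_tokens, vocab_size)
        targets = (idx + 1) % self.vocab_size
        return F.one_hot(targets, num_classes=self.vocab_size).float()


def generate_text_simple(model, idx, max_new_tokens, context_size):
    # Same greedy loop as above, repeated so this snippet is self-contained
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]            # crop to supported context
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]                    # (batch, vocab_size)
        probas = torch.softmax(logits, dim=-1)       # (batch, vocab_size)
        idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)
        idx = torch.cat((idx, idx_next), dim=1)      # (batch, n_tokens + 1)
    return idx


model = DummyNextTokenModel(vocab_size=10)
model.eval()  # disable training-only behavior such as dropout (none here)

start = torch.tensor([[0, 1, 2]])  # (batch=1, n_tokens=3)
out = generate_text_simple(model, start, max_new_tokens=4, context_size=2)

print(out)        # tensor([[0, 1, 2, 3, 4, 5, 6]])
print(out.shape)  # torch.Size([1, 7])
```

### Even though context_size=2 means the model only ever sees the last two tokens, the returned idx keeps the full
### running sequence: 3 starting tokens plus 4 new ones.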

In [4]:
### The softmax function is monotonic: it preserves the order of its inputs in its outputs.
### So, in practice, the softmax step is redundant here, because the position of the highest score in the softmax output
### is the same as the position of the highest score in the logits tensor.
### In other words, we could apply torch.argmax to the logits tensor directly and get identical results.
### However, we coded the conversion to illustrate the full process of transforming logits into probabilities. This adds
### intuition: the model picks the most likely next token at each step, which is known as greedy decoding.
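
### A small check of the claim above, using made-up logits (the values are arbitrary and chosen only for illustration):
### taking the argmax of the logits and the argmax of the softmax probabilities should select the same positions.

```python
import torch

# Hypothetical logits with shape (batch=2, vocab_size=4)
logits = torch.tensor([[2.0, -1.0, 0.5, 3.2],
                       [0.1, 0.4, -2.0, 0.3]])

probas = torch.softmax(logits, dim=-1)  # (batch, vocab_size), rows sum to 1

from_probas = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)
from_logits = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

print(from_logits)                            # tensor([[3], [1]])
print(torch.equal(from_probas, from_logits))  # True
```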