<b>(A) Hugging Face generate() with advanced sampling knobs</b>

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

# Checkpoint used for the demo.
model_id = "gpt2"  # swap to your local Llama/Mistral checkpoint if available

tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    # No pad token is configured for this tokenizer, so reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

# Load once, move to the best available device, and switch to eval mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_id).to(device).eval()

prompt = "Explain self-attention in one paragraph."
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Seed right before sampling so the generated text is reproducible on re-run.
set_seed(42)

gen = model.generate(
    **inputs,
    max_new_tokens=120,

    # Pass the pad id explicitly; otherwise generate() falls back to
    # eos_token_id and logs a warning on every call.
    pad_token_id=tokenizer.pad_token_id,

    # ----- sampling must be ON -----
    do_sample=True,

    # ----- core sampling -----
    temperature=0.8,      # <1 = more deterministic
    top_p=0.9,            # nucleus
    top_k=50,             # top-k (often combined with top_p)
    typical_p=0.95,       # locally typical sampling (optional)

    # ----- anti-repetition -----
    repetition_penalty=1.1,
    no_repeat_ngram_size=3,
)

# gen[0] holds the prompt tokens followed by the sampled continuation.
print(tokenizer.decode(gen[0], skip_special_tokens=True))


  from .autonotebook import tqdm as notebook_tqdm
Using pad_token, but it is not set yet.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Explain self-attention in one paragraph.
Acknowledge the importance of reading and writing while staying true to your original purpose, such as understanding how things work (which will help you keep track of what is important) or keeping an eye on yourself when it comes time for a "drink." This can be helpful if there are situations where something seems out of place: For example, some people don't want to talk about their weight problems because they feel that's upsetting them; but others just think this may have been done by someone else who wants answers from other folks without having read through all his/her thoughts first before giving him


<b>(B) Implement temperature + top-k + top-p from scratch (single step)</b>

In [3]:
import torch

def sample_next_token(logits, temperature=1.0, top_k=None, top_p=None):
    """
    Sample one token id from next-token logits (single sequence, single step).

    Filters run in the order temperature -> top-k -> top-p, then one token is
    drawn from the re-normalised distribution. The input tensor is not
    modified in place.

    logits: (V,) tensor for next token
    temperature: None leaves the logits unscaled; 0 means greedy decoding
        (argmax, no sampling); a positive value divides the logits by it.
    top_k: keep only the k largest logits (ties with the k-th value are kept).
        None or <= 0 disables the filter; values above V are clamped to V.
    top_p: nucleus threshold. Keeps the smallest set of most-probable tokens
        whose cumulative probability reaches top_p. Only applied when
        0 < top_p < 1.
    returns: sampled token_id (int)
    raises: ValueError if temperature is negative
    """
    # 1) temperature
    if temperature is not None:
        if temperature < 0:
            raise ValueError(f"temperature must be non-negative, got {temperature}")
        if temperature == 0:
            # T -> 0 collapses the distribution onto the argmax.
            return int(torch.argmax(logits, dim=-1).item())
        logits = logits / temperature

    # 2) top-k filter
    if top_k is not None and top_k > 0:
        # Clamp so that top_k > V behaves like "no filtering" instead of raising.
        k = min(int(top_k), logits.size(-1))
        kth_value = torch.topk(logits, k).values[-1]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))

    # 3) top-p (nucleus) filter
    if top_p is not None and 0 < top_p < 1:
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)

        # Probability mass strictly BEFORE each token in sorted order. A token
        # is kept while that mass is still below top_p, so the token that
        # crosses the threshold is included and the most likely token
        # (mass_before == 0) is always kept.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        drop_sorted = mass_before >= top_p

        drop = torch.zeros_like(logits, dtype=torch.bool)
        drop[sorted_idx[drop_sorted]] = True
        logits = logits.masked_fill(drop, float("-inf"))

    # 4) sample
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())


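<b>(C) Batched PyTorch implementation of temperature + top-k + top-p (logits of shape [batch_size, vocab_size]). Note: this cell redefines <code>sample_next_token</code> from (B) and returns a tensor instead of an int.</b>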

In [4]:
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, top_k=0, top_p=0.0):
    """
    Sample the next token from raw logits with temperature, top-k and top-p.

    Args:
        logits: [batch_size, vocab_size] - Raw output from the model's last
            layer. A 1-D [vocab_size] tensor is also accepted. The input is
            not modified in place.
        temperature: Divides the logits. Higher values flatten the
            distribution, lower values sharpen it. 0 selects the argmax
            (greedy) without sampling.
        top_k: Keep only the k largest logits per row (ties with the k-th
            value are kept). 0 disables the filter; values above vocab_size
            are clamped.
        top_p: Nucleus threshold. Applied only when 0 < top_p < 1; other
            values disable the filter.

    Returns:
        Tensor of sampled token indices, shape [batch_size, 1] (or [1] for
        1-D input).

    Raises:
        ValueError: if temperature is negative.
    """
    # 1. Apply Temperature
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")
    if temperature == 0:
        # Dividing by zero would produce inf/NaN logits; T -> 0 is greedy.
        return torch.argmax(logits, dim=-1, keepdim=True)
    if temperature != 1.0:
        logits = logits / temperature

    # 2. Filter: Top-K Sampling
    # Keep only the top K tokens, mask the rest to -infinity
    if top_k > 0:
        # Clamp so that top_k > vocab_size degrades to "keep everything".
        k = min(int(top_k), logits.size(-1))
        # The cutoff is the smallest value among the top K of each row.
        k_cutoff = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < k_cutoff, float('-inf'))

    # 3. Filter: Top-P (Nucleus) Sampling
    # Keep the most likely tokens until their cumulative probability exceeds top_p
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # A token is removed when the cumulative mass of the tokens BEFORE it
        # already exceeds top_p. Shifting the comparison right by one keeps
        # the token that crosses the threshold and always keeps the first.
        exceeds = cumulative_probs > top_p
        sorted_indices_to_remove = torch.zeros_like(exceeds)
        sorted_indices_to_remove[..., 1:] = exceeds[..., :-1]

        # Scatter the mask from sorted order back to vocabulary order.
        indices_to_remove = torch.zeros_like(exceeds).scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, float('-inf'))

    # 4. Final Sampling
    # Masked (-inf) entries get probability 0 after the softmax.
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# --- Example Usage ---
# Seed the RNG so the sampled index printed below is reproducible on re-run.
torch.manual_seed(42)

# Fake logits for a vocab of 10 words (batch of one row)
fake_logits = torch.tensor([[10.0, 9.0, 8.0, 2.0, 1.0, 1.0, 1.0, 0.5, 0.5, 0.1]])

# Scenario A: Greedy (Effectively Temp -> 0)
# argmax over the vocab axis; without dim it would index the flattened batch.
print(f"Greedy: {torch.argmax(fake_logits, dim=-1).item()}")

# Scenario B: High Temp + Nucleus (Creative)
token = sample_next_token(fake_logits, temperature=1.2, top_p=0.9)
print(f"Sampled Index: {token.item()}")

Greedy: 0
Sampled Index: 0


In [5]:
# Pseudo-code for TensorRT-LLM Sampling Configuration
# NOTE: `trtllm` is not imported anywhere in this notebook, so this cell raises
# NameError unless that name is bound to the TensorRT-LLM package first.
sampling_kwargs = {
    "temperature": 0.7,         # Balance creativity/coherence
    "top_k": 0,                 # Disable Top-K (rely on Top-P)
    "top_p": 0.95,              # Nucleus Sampling
    "repetition_penalty": 1.2,  # Penalize tokens that have already appeared
}
sampling_config = trtllm.runtime.SamplingConfig(**sampling_kwargs)

NameError: name 'trtllm' is not defined