#### **Top-K Sampling**

Top-K sampling restricts token selection to the **K most likely tokens**, so extremely low-probability tokens can never be sampled.

With temperature sampling alone, even very low-probability tokens keep a small chance of being selected, because softmax gives every token a nonzero probability.  
Top-K fixes this by **keeping only the K largest logits** and masking the rest with `-inf`.
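
A minimal sketch of that contrast, assuming PyTorch (which the notebook uses further down) and reusing the logits from the example below. Variable names such as `temp_only` and `top_k_probs` are illustrative only.

```python
import torch

logits = torch.tensor([4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79])

# Temperature sampling alone: softmax leaves every token with some probability
temp_only = torch.softmax(logits / 1.0, dim=-1)

# Top-K with K = 3: mask everything below the 3rd-largest logit, then softmax
k = 3
kth_largest = torch.topk(logits, k).values[-1]
masked = logits.masked_fill(logits < kth_largest, float("-inf"))
top_k_probs = torch.softmax(masked, dim=-1)

print((temp_only > 0).sum().item())    # 9 tokens can still be sampled
print((top_k_probs > 0).sum().item())  # 3 tokens can be sampled
```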

---

### **Example**

**Logits:**  
`[4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]`

**K = 3**

Step 1 — **Select top 3 highest logits**  
Top 3 logits are:  
- 6.75  
- 6.28  
- 4.51  
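
Step 1 maps directly onto `torch.topk`, which returns the K largest values (sorted in descending order by default) along with their positions. A small sketch on the example logits:

```python
import torch

logits = torch.tensor([4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79])

# K largest logits and their indices in the original tensor
top_values, top_indices = torch.topk(logits, k=3)

print(top_values)   # tensor([6.7500, 6.2800, 4.5100])
print(top_indices)  # tensor([3, 7, 0])
```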

Step 2 — **Mask all other logits with `-inf`**  
Because `exp(-inf) = 0`, these tokens receive zero probability after softmax.

**Masked logits:**  
`[4.51, -inf, -inf, 6.75, -inf, -inf, -inf, 6.28, -inf]`
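
One way to do the masking in Step 2, sketched with PyTorch: use the K-th largest logit as a cutoff and fill everything strictly below it with `-inf`. This mirrors the thresholding used in the notebook's `generate_next_token` function further down.

```python
import torch

logits = torch.tensor([4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79])
k = 3

# The K-th largest logit is the cutoff: anything strictly below it is masked
threshold = torch.topk(logits, k).values[-1]
masked_logits = logits.masked_fill(logits < threshold, float("-inf"))
print(masked_logits)

# exp(-inf) is exactly 0, which is why masked tokens get zero probability
print(torch.exp(torch.tensor(float("-inf"))).item())  # 0.0
```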

Step 3 — **Apply temperature scaling (optional)**  
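Divide the masked logits by a temperature `T > 0`. This reshapes the distribution among the surviving tokens, while masked tokens stay masked because `-inf / T` is still `-inf`. The probabilities below use `T = 1` (no scaling).
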
Step 4 — **Apply softmax**  

**Softmax probabilities:**  
`[0.06, 0, 0, 0.58, 0, 0, 0, 0.36, 0]`
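
Steps 3 and 4 together, as a sketch that starts from the masked logits of Step 2. With `T = 1` the three surviving probabilities round to 0.06, 0.58 and 0.36; the loop then tries two other temperatures (0.5 and 2.0, chosen arbitrarily for illustration) to show the sharpening and flattening effect.

```python
import torch

neg_inf = float("-inf")
masked_logits = torch.tensor([4.51, neg_inf, neg_inf, 6.75, neg_inf, neg_inf, neg_inf, 6.28, neg_inf])

# T = 1: no temperature scaling, just softmax over the surviving logits
probs = torch.softmax(masked_logits / 1.0, dim=-1)
print([round(p, 2) for p in probs.tolist()])
# [0.06, 0.0, 0.0, 0.58, 0.0, 0.0, 0.0, 0.36, 0.0]

# Lower T sharpens toward the 6.75 token, higher T flattens the top-K distribution;
# masked positions stay exactly 0 either way
for T in (0.5, 2.0):
    p = torch.softmax(masked_logits / T, dim=-1)
    print(T, [round(x, 2) for x in p.tolist()])
```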

Step 5 — **Sample from these probabilities**  
Sampling occurs **only among the top-K tokens**.
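
Step 5 can be sketched with `torch.multinomial`, the same call the notebook's sampling function uses. The seeded `torch.Generator` is only there to make the draws reproducible; drawing many samples shows that the zero-probability tokens never come up.

```python
import torch

neg_inf = float("-inf")
masked_logits = torch.tensor([4.51, neg_inf, neg_inf, 6.75, neg_inf, neg_inf, neg_inf, 6.28, neg_inf])
probs = torch.softmax(masked_logits, dim=-1)

# Seeded generator so the draws are reproducible
gen = torch.Generator().manual_seed(0)

# A single draw, as in next-token generation
next_token = torch.multinomial(probs, num_samples=1, generator=gen)

# Many draws with replacement: only indices 0, 3 and 7 can appear
samples = torch.multinomial(probs, num_samples=1000, replacement=True, generator=gen)
print(sorted(set(samples.tolist())))  # [0, 3, 7]
print(torch.bincount(samples, minlength=9))  # rough counts per token index
```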

---

### **Key Insight**

Top-K sampling ensures that:

- Only the top **K** most likely tokens have a chance to be selected  
- Extremely unlikely (noisy) tokens are completely removed  
- Sampling stays random among the K candidates, so outputs still vary, but only over plausible tokens  

**→ The next token will *always* be chosen from the top K logits.**
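
Two boundary cases make the role of K concrete: with K = 1 only the argmax survives, so top-K sampling reduces to greedy decoding, and with K equal to the vocabulary size nothing is masked, so it reduces to plain softmax sampling. The `top_k_filter` helper below is a hypothetical name used only for this sketch.

```python
import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: mask every logit below the k-th largest with -inf."""
    threshold = torch.topk(logits, k).values[..., -1:]
    return logits.masked_fill(logits < threshold, float("-inf"))

logits = torch.tensor([4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79])

# K = 1: all probability mass on the argmax, i.e. greedy decoding
probs_k1 = torch.softmax(top_k_filter(logits, 1), dim=-1)
print(probs_k1.argmax().item(), probs_k1.max().item())  # 3 1.0

# K = vocabulary size: nothing is masked, identical to plain softmax
probs_full = torch.softmax(top_k_filter(logits, logits.numel()), dim=-1)
print(torch.allclose(probs_full, torch.softmax(logits, dim=-1)))  # True
```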

---


```python
# Build the GPT model architecture (124M-parameter configuration)
import torch
from transformer_blocks import GPTModel

GPT_CONFIG_124M = {
    "vocab_size": 50257,   # Vocabulary size
    "context_length": 256, # Shortened context length (orig: 1024)
    "emb_dim": 768,        # Embedding dimension
    "n_heads": 12,         # Number of attention heads
    "n_layers": 12,        # Number of layers
    "drop_rate": 0.1,      # Dropout rate
    "qkv_bias": False      # Query-key-value bias
}

gpt_model = GPTModel(GPT_CONFIG_124M)
```

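```python
# Load the trained weights onto the CPU
state_dict = torch.load("checkpoints/gpt_model.pth", map_location="cpu")
# Switch to evaluation mode so dropout (drop_rate=0.1) is disabled during generation
gpt_model.eval()
gpt_model.load_state_dict(state_dict)
```

```
<All keys matched successfully>
```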

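```python
# torch.topk on a toy tensor: the 3 largest entries along the last dimension
t1 = torch.rand(2, 3, 7)
t2 = torch.topk(t1, k=3, dim=-1)  # named tuple with .values and .indices
```

```python
t2.values.shape
```

```
torch.Size([2, 3, 3])
```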
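
Because `torch.topk` returns values sorted in descending order (its default), the last entry along the top-K dimension is the K-th largest value of each row. The sampling function below relies on exactly this to build its masking threshold. A quick check on a seeded toy tensor:

```python
import torch

torch.manual_seed(0)
t1 = torch.rand(2, 3, 7)
top = torch.topk(t1, k=3, dim=-1)

# Last column of the sorted values = K-th largest entry per row
kth_largest = top.values[..., -1:]
print(top.indices.shape)   # torch.Size([2, 3, 3])
print(kth_largest.shape)   # torch.Size([2, 3, 1])

# The values are just the original entries gathered at the returned indices
print(torch.equal(top.values, torch.gather(t1, -1, top.indices)))  # True

# Each row has at least K entries >= its K-th largest value
print(((t1 >= kth_largest).sum(dim=-1) >= 3).all().item())  # True
```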

```python
# Sample the next token with top-k filtering followed by temperature scaling
def generate_next_token(model, top_k, temp, inputs):
    with torch.no_grad():
        logits = model(inputs)     # (batch_size, seq_len, vocab_size)
        logits = logits[:, -1, :]  # last position only: (batch_size, vocab_size)
        if top_k is not None:
            # topk values are sorted in descending order, so the last one is the K-th largest
            top_logits, _ = torch.topk(logits, top_k)
            min_value = top_logits[:, -1].unsqueeze(1)  # (batch_size, 1)
            # Mask every logit below the K-th largest with -inf
            logits = logits.masked_fill(logits < min_value, float("-inf"))
        if temp > 0.0:
            # Temperature scaling, then softmax over the remaining tokens
            logits = logits / temp
            probs = torch.softmax(logits, dim=-1)
            # Sample one token per batch row from the probability distribution
            next_token = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
        else:
            # temp <= 0: greedy decoding, pick the highest-scoring token
            next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        return next_token
```
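
The `logits < min_value` comparison in `generate_next_token` keeps every token whose logit equals the K-th largest, so when there are ties more than K tokens can survive. If exactly K is wanted, one option is to scatter the `torch.topk` results into a tensor filled with `-inf`. `top_k_mask_exact` is a hypothetical helper for illustration, not part of the notebook; with real float logits exact ties are rare, so the threshold version is usually fine.

```python
import torch

def top_k_mask_exact(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: keep exactly k logits per row (ties broken by torch.topk)."""
    top_values, top_indices = torch.topk(logits, k, dim=-1)
    masked = torch.full_like(logits, float("-inf"))
    # Write the k surviving logits back at their original positions
    return masked.scatter(-1, top_indices, top_values)

# Batch of 1 with a tie at the K-th position (two logits equal to 1.0)
logits = torch.tensor([[3.0, 1.0, 1.0, 0.5]])
k = 2

threshold = torch.topk(logits, k).values[:, -1:]
threshold_mask = logits.masked_fill(logits < threshold, float("-inf"))
exact_mask = top_k_mask_exact(logits, k)

print(torch.isfinite(threshold_mask).sum().item())  # 3 (both tied logits survive)
print(torch.isfinite(exact_mask).sum().item())      # 2
```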



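```python
# A batch of 2 random token-id sequences, each as long as the context window (256)
inputs = torch.randint(0, 50257, (2, 256))
```

```python
generate_next_token(gpt_model, top_k=3, temp=0.3, inputs=inputs)
```

```
tensor([[314],
        [339]])
```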