In [3]:
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HF_TOKEN")
wandb_token = user_secrets.get_secret("wandb")

import os
os.environ["HF_TOKEN"] = hf_token
os.environ["WANDB_API_KEY"] = wandb_token
print("HF token loaded—rate limit bypassed!")  # Optional confirm

HF token loaded—rate limit bypassed!


In [4]:
!uv pip install -q wandb transformers

In [5]:
import torch, random, os, math, time, wandb
import numpy as np
# import tiktoken
# from datatrove.pipeline.readers import ParquetReader
from transformers import AutoTokenizer
from datasets import load_dataset
# from itertools import cycle
from torch.utils.data import DataLoader, IterableDataset
import torch.nn as nn
import torch.nn.functional as F  # For scaled_dot_product_attention
import torch.optim as optim
from torch.amp import autocast, GradScaler
import matplotlib.pyplot as plt
# from torch.profiler import profile, record_function, ProfilerActivity

wandb.login(key=wandb_token) 

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtusharmishra802[0m ([33mtusharmishra802-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-30b')
tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json:   0%|          | 0.00/700 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.


In [7]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [8]:
CFG = {
    "seed": 42,
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # Model
    "vocab_size": 32000,          # llama vocab
    "emb_dim": 128,               # 4096 
    "context_length": 8,          # 4096
    "n_heads": 4,                 # 32
    "num_kv_heads": 2,
    "n_layers": 1,                # 32
    "drop_rate": 0.1,
    "qkv_bias": False,
    'base': 10000,
    "intermediate_size": int(8/3 * 4096),  # ~11008 for full

    # Data
    "max_tokens": 500_000,        # STOP after this many tokens
    "warmup_tokens": 10_000,      # linear warm-up
    "batch_size": 32,
    "shuffle_buffer": 5_000,

    # Optimiser
    "optimizer": "adamw",
    "lr": 3e-4,
    "final_lr": 3e-5,
    "weight_decay": 0.1,
    "betas": (0.9, 0.95),

    # SwiGLU
    'beta' : 1,

    # Misc
    "log_interval": 20,           # steps
    "wandb_project": "llama-demo",
    "wandb_run_name": None,       # auto-generated
}

In [9]:
torch.manual_seed(CFG["seed"])
if CFG["device"] == "cuda":
    torch.cuda.manual_seed_all(CFG["seed"])

# ------------------- 2. WANDB INIT -------------------
wandb.init(
    project=CFG["wandb_project"],
    name=CFG["wandb_run_name"],
    config=CFG,
    mode="online",   # set "offline" if you have no internet
)

In [10]:
class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        x_norm = x / rms
        return x_norm * self.weight

class SiLU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * F.sigmoid(CFG["beta"] * x)

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        intermediate_size = cfg.get("intermediate_size", int(8/3 * cfg["emb_dim"]))
        self.gate_proj = nn.Linear(cfg["emb_dim"], intermediate_size, bias=False)
        self.up_proj = nn.Linear(cfg["emb_dim"], intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, cfg["emb_dim"], bias=False)
        self.act_fn = SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


In [11]:
class RotaryEmbeddings(nn.Module):
    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device=CFG["device"]):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        
        t = torch.arange(self.max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # (T, D/2)
        self.register_buffer("cos_cached", freqs.cos(), persistent=False)
        self.register_buffer("sin_cached", freqs.sin(), persistent=False)

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[-2]
        cos = self.cos_cached[:seq_len, ...].unsqueeze(0).unsqueeze(0)  # (1,1,T,D//2)
        sin = self.sin_cached[:seq_len, ...].unsqueeze(0).unsqueeze(0)
        
        x1 = x[..., : self.dim : 2]  # (..., D//2)
        x2 = x[..., 1 : self.dim : 2]
        
        rotated_x1 = x1 * cos - x2 * sin
        rotated_x2 = x1 * sin + x2 * cos
        
        return torch.cat((rotated_x1, rotated_x2), dim=-1)

In [12]:
ds = load_dataset("HuggingFaceFW/fineweb", split="train", streaming=True)

def tokenize_function(examples):
    texts = examples['text']
    tokenized = tokenizer(texts, truncation=False, add_special_tokens=False)  # Batched for speed
    tokenized['input_ids'] = [ids + [tokenizer.eos_token_id] for ids in tokenized['input_ids']]
    return tokenized

ds = ds.map(tokenize_function, batched=True, batch_size=1000, remove_columns=['text'])  # Drops raw text, keeps input_ids
ds = ds.shuffle(buffer_size=CFG['shuffle_buffer'])

README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/27468 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/27468 [00:00<?, ?it/s]

In [13]:
class SlidingWindowDataset(IterableDataset):
    def __init__(self, ds, tokenizer, context_len, stride, target_tokens):
        self.ds = ds  # HF streaming iterable
        self.tokenizer = tokenizer
        self.context_len = context_len
        self.stride = stride
        self.max_tokens = target_tokens
        self.pad_id = tokenizer.pad_token_id

    def __iter__(self):
        buffer = []
        token_count = 0
        for example in self.ds:  # Streams tokenized input_ids
            toks = example['input_ids']
            if not toks:
                continue
            buffer.extend(toks)

            # Fixed buffer logic
            while len(buffer) > self.context_len:
                x = buffer[:self.context_len]
                y = buffer[1:self.context_len + 1]
                # Pad if needed (rare post-fix)
                if len(y) < self.context_len:
                    y += [self.pad_id] * (self.context_len - len(y))
                yield {'input_ids': torch.tensor(x, dtype=torch.long),
                       'labels': torch.tensor(y, dtype=torch.long)}  # Dict for HF Trainer
                buffer = buffer[self.stride:]
                token_count += self.context_len
                if token_count >= self.max_tokens:
                    return

            # Cap buffer to prevent OOM
            if len(buffer) > 2 * self.context_len:
                buffer = buffer[-self.context_len:]

        # Remnant with padding
        if len(buffer) >= 128:  # Min threshold
            x = buffer[:self.context_len]
            y = buffer[1:min(self.context_len + 1, len(buffer) + 1)]
            if len(y) < self.context_len:
                y += [self.pad_id] * (self.context_len - len(y))
            yield {'input_ids': torch.tensor(x, dtype=torch.long),
                   'labels': torch.tensor(y, dtype=torch.long)}

# Usage
dataset = SlidingWindowDataset(ds, tokenizer, context_len=CFG['context_length'], stride=CFG["context_length"] // 2, target_tokens=CFG['max_tokens'])
dataloader = DataLoader(dataset, batch_size=32, num_workers=2, pin_memory=True, prefetch_factor=2, collate_fn=lambda b: {k: torch.stack([d[k] for d in b]) for k in b[0]})

In [14]:
if torch.cuda.is_available():
    try:
        torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False)
        print("PyTorch SDP kernels enabled (mem_efficient & not math for speed on T4/P100)!")
    except Exception as e:
        print(f"Could not enable SDP kernels: {e}")

PyTorch SDP kernels enabled (mem_efficient & not math for speed on T4/P100)!


  self.gen = func(*args, **kwds)


In [15]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False, device=CFG["device"]):
        super().__init__()
        assert d_out % num_heads == 0
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out, bias=False)
        self.rope = RotaryEmbeddings(self.head_dim, device=device)
        self.dropout = dropout
        print(f"MHA: {num_heads} heads, head_dim={self.head_dim}")

    def forward(self, x):
        b, t, _ = x.shape
        q = self.W_query(x).reshape(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.W_key(x).reshape(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.W_value(x).reshape(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        q = self.rope(q)
        k = self.rope(k)
        attn = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        attn = attn.transpose(1, 2).reshape(b, t, self.d_out)
        return self.out_proj(attn)


In [16]:
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_in, d_out, dropout, num_heads, num_kv_heads=None, qkv_bias=False, device=CFG["device"]):
        super().__init__()
        assert d_out % num_heads == 0
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        assert self.num_kv_heads <= self.num_heads, "num_kv_heads must <= num_heads"
        assert d_out % self.num_kv_heads == 0 or self.num_kv_heads == num_heads, "Inconsistent head dims for GQA"
        
        # Projections
        self.W_query = nn.Linear(d_in, self.num_heads * self.head_dim, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, self.num_kv_heads * self.head_dim, bias=qkv_bias)  # Renamed for clarity
        self.W_value = nn.Linear(d_in, self.num_kv_heads * self.head_dim, bias=qkv_bias)
        self.out_proj = nn.Linear(self.num_heads * self.head_dim, d_out, bias=False)
        
        # RoPE with device
        self.rope = RotaryEmbeddings(self.head_dim, device=device)
        self.dropout = dropout

    def forward(self, x):
        b, t, d_in = x.shape
        
        # Project Q/K/V
        q = self.W_query(x).reshape(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.W_key(x).reshape(b, t, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.W_value(x).reshape(b, t, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # GQA Repeat
        if self.num_kv_heads != self.num_heads:
            repeat_factor = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)

        q = self.rope(q)
        k = self.rope(k)

        # SDPA (use full k/v seq_len for mask/attn)
        attn_output = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=True, attn_mask=None
        )
        
        # Merge
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(b, t, self.d_out)
        output = self.out_proj(attn_output)
                
        return output

In [17]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg, device=CFG["device"]):
        super().__init__()
        self.att = GroupedQueryAttention(cfg["emb_dim"], cfg["emb_dim"], cfg["drop_rate"], cfg["n_heads"], cfg['num_kv_heads'], cfg["qkv_bias"], device)
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg["emb_dim"])
        self.norm2 = RMSNorm(cfg["emb_dim"])

    def forward(self, x):
        x = x + self.att(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

In [18]:
class LlamaModel(nn.Module):
    def __init__(self, cfg, device=CFG["device"]):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg, device) for _ in range(cfg["n_layers"])])
        self.final_norm = RMSNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
        self.out_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        x = self.tok_emb(idx)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        return self.out_head(x)

In [19]:
model_cfg = {
    "vocab_size": CFG["vocab_size"],
    "context_length": CFG["context_length"],
    "emb_dim": CFG["emb_dim"],
    "n_heads": CFG["n_heads"],
    "num_kv_heads": CFG["num_kv_heads"],
    "n_layers": CFG["n_layers"],
    "drop_rate": CFG["drop_rate"],
    "qkv_bias": CFG["qkv_bias"],
}
model = LlamaModel(model_cfg, device=CFG["device"])  # Pass device here!

# 2. Move to GPU + FP16
model = model.to(CFG["device"]).half()

In [20]:
def count_params(model):
    return sum(p.numel() for p in model.parameters())

print(f"Total params: {count_params(model):,}")
print("Data equal:", torch.equal(model.out_head.weight, model.tok_emb.weight))

model.tok_emb.weight.data[0, 0] = 999.0  # Modify embedding
print("Shared? out_head[0,0] after change:", model.out_head.weight.data[0, 0])  # Should be 999.0
print("Total params after mod:", count_params(model))  # Still ~124M—no extra

Total params: 4,276,480
Data equal: True
Shared? out_head[0,0] after change: tensor(999., device='cuda:0', dtype=torch.float16)
Total params after mod: 4276480


In [21]:
# ------------------- 7. OPTIMISER + SCALER -------------------
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CFG["lr"],
    betas=CFG["betas"],
    weight_decay=CFG["weight_decay"],
    fused=True,                          # works on P100
)

scaler = GradScaler()

# ------------------- 8. LR SCHEDULER -------------------
total_train_tokens = CFG["max_tokens"]
warmup_tokens      = CFG["warmup_tokens"]
base_lr            = CFG["lr"]
final_lr           = CFG["final_lr"]

def lr_lambda(tokens_seen):
    if tokens_seen <= warmup_tokens:
        return tokens_seen / max(1, warmup_tokens)               # linear warm-up
    progress = (tokens_seen - warmup_tokens) / max(1, total_train_tokens - warmup_tokens)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return (final_lr / base_lr) + (1.0 - final_lr / base_lr) * cosine

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

print(f"LR Scheduler: Warmup {warmup_tokens:,} → Cosine decay to {final_lr:.1e}")

LR Scheduler: Warmup 10,000 → Cosine decay to 3.0e-05


In [22]:
def validate(model, loader):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(CFG["device"], non_blocking=True) for k, v in batch.items()}
            with autocast(device_type=CFG['device'], dtype=torch.bfloat16, enabled=True):
                logits = model(batch["input_ids"])
                loss = criterion(
                    logits.view(-1, logits.size(-1)),
                    batch["labels"].view(-1),
                )
            num_tokens = (batch["labels"] != -100).sum().item()
            total_loss += loss.item() * num_tokens
            total_tokens += num_tokens
    model.train()
    return total_loss / total_tokens

In [23]:
# ------------------- 9. TRAINING LOOP -------------------
criterion = nn.CrossEntropyLoss(ignore_index=-100)

tokens_seen = 0
step = 0
start_time = time.time()

print("Starting training …")
for batch in dataloader:
    step += 1
    batch = {k: v.to(CFG["device"], non_blocking=True) for k, v in batch.items()}

    # ---- forward + loss -------------------------------------------------
    with autocast(device_type=CFG["device"], dtype=torch.float16):
        logits = model(batch["input_ids"])
        loss = criterion(
            logits.view(-1, logits.size(-1)),
            batch["labels"].view(-1),
        )

    # ---- backward -------------------------------------------------------
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    # ---- LR step (per token, not per step) -------------------------------
    tokens_seen += batch["input_ids"].numel()
    scheduler.step()                     # LambdaLR uses the *current* token count

    # ---- logging ---------------------------------------------------------
    if step % CFG["log_interval"] == 0:
        elapsed = time.time() - start_time
        tokens_per_sec = tokens_seen / elapsed
        lr = optimizer.param_groups[0]["lr"]

        wandb.log({
            "step": step,
            "loss": loss.item(),
            "lr": lr,
            "tokens_seen": tokens_seen,
            "tokens_per_sec": tokens_per_sec,
            "gpu_mem_gb": torch.cuda.max_memory_allocated() / 1e9,
        }, step=step)
        print(
            f"Step {step:5d} | "
            f"Loss {loss.item():.4f} | "
            f"LR {lr:.2e} | "
            f"Tokens {tokens_seen:,}/{CFG['max_tokens']:,} | "
            f"Speed {tokens_per_sec:,.0f} t/s"
        )

    # ---- early stop ------------------------------------------------------
    if tokens_seen >= CFG["max_tokens"]:
        print("\nReached target token count → stopping.")
        break

# --------------------------------------------------------------
wandb.finish()
print("Training finished!")
# --------------------------------------------------------------

Starting training …


Token indices sequence length is longer than the specified maximum sequence length for this model (2092 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3506 > 2048). Running this sequence through the model will result in indexing errors


Step    20 | Loss 610.6035 | LR 6.00e-07 | Tokens 5,120/500,000 | Speed 324 t/s
Step    40 | Loss 565.7599 | LR 1.20e-06 | Tokens 10,240/500,000 | Speed 641 t/s
Step    60 | Loss 422.3762 | LR 1.80e-06 | Tokens 15,360/500,000 | Speed 950 t/s
Step    80 | Loss 445.9661 | LR 2.40e-06 | Tokens 20,480/500,000 | Speed 1,251 t/s
Step   100 | Loss 323.7099 | LR 3.00e-06 | Tokens 25,600/500,000 | Speed 1,550 t/s
Step   120 | Loss 501.3358 | LR 3.60e-06 | Tokens 30,720/500,000 | Speed 1,845 t/s
Step   140 | Loss 483.0356 | LR 4.20e-06 | Tokens 35,840/500,000 | Speed 2,136 t/s
Step   160 | Loss 275.6973 | LR 4.80e-06 | Tokens 40,960/500,000 | Speed 2,421 t/s
Step   180 | Loss 296.7721 | LR 5.40e-06 | Tokens 46,080/500,000 | Speed 2,701 t/s
Step   200 | Loss 502.7380 | LR 6.00e-06 | Tokens 51,200/500,000 | Speed 2,974 t/s
Step   220 | Loss 500.6068 | LR 6.60e-06 | Tokens 56,320/500,000 | Speed 3,243 t/s
Step   240 | Loss 151.7938 | LR 7.20e-06 | Tokens 61,440/500,000 | Speed 3,510 t/s
Step   260 

0,1
gpu_mem_gb,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,███▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇██
tokens_per_sec,▁▁▁▁▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████
tokens_seen,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██

0,1
gpu_mem_gb,0.14571
loss,10.4626
lr,6e-05
step,1940.0
tokens_per_sec,16549.27979
tokens_seen,496640.0


Training finished!


In [24]:
# Fixed generation function (top-k sampling)
def generate(model, tokenizer, prompt, max_new_tokens=50, top_k=50, temperature=0.8, pad_token_id=None):
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(CFG["device"])
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.bfloat16, enabled=True):
                logits = model(input_ids)[:, -1, :]  # Last position logits
                logits = logits / temperature
                # Top-k filtering
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                probs = torch.softmax(logits, dim=-1)
                mask = logits[0] < v[0].min()  # 1D boolean [vocab]
                probs[0][mask] = 0  # Set low-prob to 0
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                if next_token.item() == pad_token_id:
                    break  # Stop at EOS
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Test prompts (web/edu themed)
prompts = [
    "Web hosting is essential for",
    "Machine learning models train on datasets like",
    "A good developer should know",
    "FineWeb-Edu is a filtered version of",
    "The future of AI in education involves",
    "Why do you think people follow me"
]

print("=== Generation Tests (Final Train Loss: 4.19 | PPL: {:.0f}) ===".format(math.exp(4.19)))
for prompt in prompts:
    with torch.no_grad():  # Per-prompt for safety
        generated = generate(model, tokenizer, prompt, max_new_tokens=50, top_k=50, temperature=0.8)
        continuation = generated[len(prompt):].strip()  # Continuation only
        print(f"\nPrompt: {prompt}")
        print(f"Output: {continuation}")
        print("-" * 80)

=== Generation Tests (Final Train Loss: 4.19 | PPL: 66) ===

Prompt: Web hosting is essential for
Output: 0, they- in is, it have will who of is that,' at , “ the
 it in when of not I the one't an but0. we. for you, they to of to to with
 that0
--------------------------------------------------------------------------------

Prompt: Machine learning models train on datasets like
Output: the of and have as. his of I to is be it
 to haveed of to1 we is,
 you of aed
 for who in as1 is of
 in in be and'’  to
ation be
--------------------------------------------------------------------------------

Prompt: A good developer should know
Output: 5 and the for the that,
 to' of it in1 we that your have I to
 to a for I bes them the one ofation when is them with is, his in1. their not15 it 0
--------------------------------------------------------------------------------

Prompt: FineWeb-Edu is a filtered version of
Output: . of it1 for can of in that one is  of so, of the to with we. to can the