In [1]:
import os
import torch

# === PATHS ===
input_dir = "/kaggle/input/gpt-fineweb-100m-checkpoint"
extract_dir = "/kaggle/working/checkpoints"  # Not needed, but keeping for clarity
os.makedirs(extract_dir, exist_ok=True)

# === FIND .pth FILE ===
ckpt_files = [f for f in os.listdir(input_dir) if f.endswith('.pth')]

if not ckpt_files:
    raise FileNotFoundError("No .pth file found in input directory")

# Use the only (or latest) one
checkpoint_path = os.path.join(input_dir, ckpt_files[0])
print(f"Found checkpoint: {checkpoint_path}")

# === LOAD CHECKPOINT ===
# Replace with your actual device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(checkpoint_path, map_location=device)

tokens_seen = checkpoint.get('tokens_seen', 0)
step = checkpoint.get('step', 0)

print(f"Loaded: {checkpoint_path}")
print(f"   → Resuming at step {step}, tokens seen: {tokens_seen:,}")

Found checkpoint: /kaggle/input/gpt-fineweb-100m-checkpoint/checkpoint_100016128.pth
Loaded: /kaggle/input/gpt-fineweb-100m-checkpoint/checkpoint_100016128.pth
   → Resuming at step 12210, tokens seen: 100,016,128


In [2]:
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HF_TOKEN")
wandb_token = user_secrets.get_secret("wandb")

import os
os.environ["HF_TOKEN"] = hf_token
os.environ["WANDB_API_KEY"] = wandb_token
print("HF token loaded—rate limit bypassed!")  # Optional confirm

!uv pip install -q wandb transformers

import torch, random, os, math, time, wandb
import numpy as np
# import tiktoken
# from datatrove.pipeline.readers import ParquetReader
from transformers import AutoTokenizer
from datasets import load_dataset
# from itertools import cycle
from torch.utils.data import DataLoader, IterableDataset
import torch.nn as nn
import torch.nn.functional as F  # For scaled_dot_product_attention
import torch.optim as optim
from torch.amp import autocast, GradScaler
import matplotlib.pyplot as plt
# from torch.profiler import profile, record_function, ProfilerActivity

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
wandb.login(key=wandb_token) 

HF token loaded—rate limit bypassed!




tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtusharmishra802[0m ([33mtusharmishra802-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
CFG = {
    "seed": 42,
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # Model
    "vocab_size": 50257,          # GPT-2 vocab
    "emb_dim": 768,
    "context_length": 256,        # keep small for speed
    "n_heads": 12,
    "n_layers": 12,
    "drop_rate": 0.0, # 0.1 or 0.2 during fine tuning
    "qkv_bias": False,

    # Data
    "max_tokens": 200_000_000,        # STOP after this many tokens
    "warmup_tokens": 2_000_000,      # linear warm-up
    "batch_size": 16,
    "shuffle_buffer": 5_000,
    "val_max_tokens" : 10_000,
    "tokens_per_step": 16 * 256,

    # Optimiser
    "optimizer": "adamw",
    "lr": 3e-4,
    "final_lr": 6e-5,
    "eps": 1e-8,
    "weight_decay": 0.1,
    "betas": (0.9, 0.95),

    # Misc
    "log_interval": 20,           # steps
    "checkpoint_interval": 50_000_000,  # New: Save checkpoint every 50M tokens
    "validate_interval": 50_000_000,    # New: Validate every 50M tokens (align with checkpoints)
    "wandb_project": "gpt-fineweb-demo",
    "wandb_run_id": 'icy-night-14',       # auto-generated
}

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

if torch.cuda.is_available():
    try:
        torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False)
        print("PyTorch SDP kernels enabled (mem_efficient & not math for speed on T4/P100)!")
    except Exception as e:
        print(f"Could not enable SDP kernels: {e}")

PyTorch SDP kernels enabled (mem_efficient & not math for speed on T4/P100)!


  self.gen = func(*args, **kwds)


In [4]:
torch.manual_seed(CFG["seed"])
if CFG["device"] == "cuda":
    torch.cuda.manual_seed_all(CFG["seed"])

# ------------------- 2. WANDB INIT -------------------
# wandb.init(
#     project=CFG["wandb_project"],
#     name=CFG["wandb_run_name"],
#     config=CFG,
#     mode="online",   # set "offline" if you have no internet
# )

# Wandb: Resume existing run (or start new with resume='allow')
if 'wandb_run_id' in checkpoint:  # Saved from prior
    wandb.init(project=CFG['wandb_project'], id=checkpoint['wandb_run_id'], resume='allow')
else:
    wandb.init(project=CFG['wandb_project'], name=CFG.get('wandb_run_name', f"resume_{tokens_seen}"), resume='allow')

print(f"Wandb resumed: step {step}, tokens {tokens_seen:,}")

Wandb resumed: step 12210, tokens 100,016,128


In [5]:
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))
        
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x-mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

class GELU(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x):
        return 0.5*x*(1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) * 
            (x + 0.044715*torch.pow(x, 3))
        ))

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg['emb_dim'], 4*cfg['emb_dim']),
            GELU(),
            nn.Linear(4*cfg['emb_dim'], cfg['emb_dim'])
        )
        
    def forward(self, x):
        return self.layers(x)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # d_k = d_v = d_out / num_heads

        # Linear projections for Q, K, V
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Output projection

        self.dropout = dropout  # Store value for SDPA
        self.scale = self.head_dim ** -0.5

        # Optional: causal mask buffer (not needed with is_causal=True)
        # self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

        print(f"MHA: {num_heads} heads, head_dim={self.head_dim}, SDPA enabled")

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        # (b, num_tokens, d_out)
        queries = self.W_query(x)
        keys    = self.W_key(x)
        values  = self.W_value(x)

        # Split into heads: (b, num_tokens, num_heads, head_dim)
        # Use .reshape() instead of .view() for safety
        queries = queries.reshape(b, num_tokens, self.num_heads, self.head_dim)
        keys    = keys.reshape(b, num_tokens, self.num_heads, self.head_dim)
        values  = values.reshape(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1, 2)
        keys    = keys.transpose(1, 2)
        values  = values.transpose(1, 2)

        # Use SDPA with causal masking
        # IMPORTANT: dropout_p only applies during training
        attn_output = F.scaled_dot_product_attention(
            query=queries,
            key=keys,
            value=values,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True  # This creates the causal mask automatically
        )

        # Merge heads: (b, num_tokens, num_heads, head_dim) → (b, num_tokens, d_out)
        attn_output = attn_output.transpose(1, 2).contiguous()  # (b, num_tokens, num_heads, head_dim)
        context_vec = attn_output.reshape(b, num_tokens, self.d_out)

        # Final linear projection + dropout
        context_vec = self.out_proj(context_vec)
        context_vec = F.dropout(context_vec, p=self.dropout, training=self.training)

        return context_vec

In [7]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in = cfg['emb_dim'],
            d_out = cfg['emb_dim'],
            context_length = cfg['context_length'],
            num_heads = cfg['n_heads'],
            dropout = cfg['drop_rate'],
            qkv_bias = cfg['qkv_bias']        
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg['emb_dim'])
        self.norm2 = LayerNorm(cfg['emb_dim'])
        self.drop_shortcut = nn.Dropout(cfg['drop_rate'])
        
    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = shortcut + self.drop_shortcut(x)
        
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = shortcut + self.drop_shortcut(x)
        return x

In [8]:
class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        self.pos_emb = nn.Embedding(cfg['context_length'], cfg['emb_dim'])
        self.drop_emb = nn.Dropout(cfg['drop_rate'])
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg['n_layers'])]
        )
        self.final_norm = LayerNorm(cfg['emb_dim'])
        self.out_head = nn.Linear(
            cfg['emb_dim'], cfg['vocab_size'], bias=False
        )
        self.out_head.weight = self.tok_emb.weight # weight tying
        # === Weight Initialization (CRITICAL) ===
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        
    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(
            torch.arange(seq_len, device=in_idx.device)
        )
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

In [9]:
model_cfg = {
    "vocab_size": CFG["vocab_size"],
    "context_length": CFG["context_length"],
    "emb_dim": CFG["emb_dim"],
    "n_heads": CFG["n_heads"],
    "n_layers": CFG["n_layers"],
    "drop_rate": CFG["drop_rate"],
    "qkv_bias": CFG["qkv_bias"],
}
model = GPTModel(model_cfg).to(CFG["device"])
model = model.to(dtype=torch.bfloat16)      # FP16 for P100
model.load_state_dict(checkpoint['model_state_dict'])

MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled
MHA: 12 heads, head_dim=64, SDPA enabled


<All keys matched successfully>

In [10]:
ds = load_dataset("HuggingFaceFW/fineweb-edu", split='train', streaming=True)

def tokenize_function(examples):
    texts = examples['text']
    tokenized = tokenizer(texts, truncation=False, add_special_tokens=False)  # Batched for speed
    tokenized['input_ids'] = [ids + [tokenizer.eos_token_id] for ids in tokenized['input_ids']]
    return tokenized

ds = ds.map(tokenize_function, batched=True, batch_size=1000, remove_columns=['text'])  # Drops raw text, keeps input_ids
ds = ds.shuffle(buffer_size=CFG['shuffle_buffer'])


val_raw = load_dataset("wikitext", "wikitext-103-v1", split="test", streaming=True)  # Or "HuggingFaceFW/fineweb" for raw contrast
val_raw = val_raw.map(tokenize_function, batched=True, batch_size=1000, remove_columns=['text'])
val_raw = val_raw.shuffle(buffer_size=CFG['shuffle_buffer'], seed=CFG['seed'] + 1)  # Fixed seed, diff from train

README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

In [11]:
# Before dataset: Set pad_token for GPT-2 (safe default)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Set pad_token to EOS: {tokenizer.pad_token_id}")

class SlidingWindowDataset(IterableDataset):
    def __init__(self, ds, tokenizer, context_len, stride, target_tokens, skip_tokens=0, pad_labels=True):
        self.ds = ds
        self.tokenizer = tokenizer
        self.context_len = context_len
        self.stride = stride
        self.max_tokens = target_tokens
        self.skip_tokens = skip_tokens
        self.pad_id = tokenizer.pad_token_id  # Now safely EOS
        self.pad_labels = pad_labels  # New: Enable label masking

    def __iter__(self):
        buffer = []
        token_count = 0
        skipping = self.skip_tokens > 0
        pad_val_labels = -100 if self.pad_labels else self.pad_id

        for example in self.ds:
            toks = example['input_ids']
            if not toks:
                continue
            buffer.extend(toks)

            while len(buffer) > self.context_len:
                # Extract full sequences
                x = buffer[:self.context_len]  # Inputs: always full, no pad needed here
                y_base = buffer[1:self.context_len + 1]
                
                # Pad if short (remnant)
                num_pads = self.context_len - len(y_base)
                if num_pads > 0:
                    # Pad inputs too for consistency (though rare)
                    x += [self.pad_id] * num_pads
                    # Pad labels with -100 to mask
                    y_base += [pad_val_labels] * num_pads
                
                y = torch.tensor(y_base, dtype=torch.long)
                x = torch.tensor(x, dtype=torch.long)

                seq_tokens = self.context_len

                if skipping:
                    self.skip_tokens -= seq_tokens
                    if self.skip_tokens <= 0:
                        skipping = False
                        print(f"Skipped ~{token_count:,} tokens; resuming.")
                    buffer = buffer[self.stride:]
                    token_count += seq_tokens
                else:
                    yield {'input_ids': x, 'labels': y}
                    buffer = buffer[self.stride:]
                    token_count += seq_tokens
                    if token_count >= self.max_tokens:
                        return

                if len(buffer) > 2 * self.context_len:
                    buffer = buffer[-self.context_len:]

        # Remnant: Same padding logic
        if not skipping and len(buffer) >= 128:
            x = buffer[:self.context_len]
            y_base = buffer[1:min(self.context_len + 1, len(buffer) + 1)]
            num_pads = self.context_len - len(y_base)
            if num_pads > 0:
                x += [self.pad_id] * num_pads
                y_base += [pad_val_labels] * num_pads
            y = torch.tensor(y_base, dtype=torch.long)
            x = torch.tensor(x, dtype=torch.long)
            yield {'input_ids': x, 'labels': y}
            
# Usage
dataset = SlidingWindowDataset(ds, tokenizer, context_len=CFG['context_length'], stride=CFG["context_length"] // 2, target_tokens=CFG['max_tokens'], skip_tokens=tokens_seen)
dataloader = DataLoader(dataset, batch_size=32, num_workers=2, pin_memory=True, prefetch_factor=2, collate_fn=lambda b: {k: torch.stack([d[k] for d in b]) for k in b[0]})

In [12]:
dataset = SlidingWindowDataset(ds, tokenizer, context_len=CFG['context_length'], stride=CFG["context_length"] // 2, target_tokens=CFG['max_tokens'], skip_tokens=tokens_seen)
dataloader = DataLoader(dataset, batch_size=32, num_workers=2, pin_memory=True, prefetch_factor=2, collate_fn=lambda b: {k: torch.stack([d[k] for d in b]) for k in b[0]})

valid_dataset = SlidingWindowDataset(val_raw, tokenizer, context_len=CFG['context_length'], 
                                     stride=CFG["context_length"] // 2, target_tokens=CFG['val_max_tokens'])
valid_dataloader = DataLoader(valid_dataset, batch_size=32, num_workers=2, pin_memory=True, 
                              prefetch_factor=2, collate_fn=lambda b: {k: torch.stack([d[k] for d in b]) for k in b[0]})

In [13]:
# ------------------- 7. OPTIMISER + SCALER -------------------
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CFG["lr"],
    betas=CFG["betas"],
    weight_decay=CFG["weight_decay"],
    fused=True,                          # works on P100
    eps=CFG["eps"],               
)

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# ------------------- 8. LR SCHEDULER -------------------
total_train_tokens = CFG["max_tokens"]
warmup_tokens      = CFG["warmup_tokens"]
warmup_tokens = CFG["warmup_tokens"]
tokens_per_step = CFG["tokens_per_step"]
warmup_steps = warmup_tokens // tokens_per_step  # ~122 steps
total_steps = total_train_tokens // tokens_per_step  # ~24k steps

def lr_lambda(step):  # Now takes step, computes tokens internally
    tokens_current = tokens_seen + (step * tokens_per_step)
    if tokens_current < warmup_tokens:  # Use < to avoid float issues
        return tokens_current / warmup_tokens
    progress = (tokens_current - warmup_tokens) / (total_train_tokens - warmup_tokens)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return (CFG["final_lr"] / CFG["lr"]) + (1.0 - CFG["final_lr"] / CFG["lr"]) * cosine

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])  # Restores any internal state

print("Model, optimizer, and scheduler loaded successfully.")
print(f"LR Scheduler: Warmup ~{warmup_steps:,} steps ({warmup_tokens:,} tokens) → Cosine decay to {CFG['final_lr']:.1e}")

Model, optimizer, and scheduler loaded successfully.
LR Scheduler: Warmup ~488 steps (2,000,000 tokens) → Cosine decay to 6.0e-05


In [14]:
def validate(model, loader):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(CFG["device"], non_blocking=True) for k, v in batch.items()}
            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                logits = model(batch["input_ids"])
                loss = criterion(
                    logits.view(-1, logits.size(-1)),
                    batch["labels"].view(-1),
                )
            num_tokens = (batch["labels"] != -100).sum().item()
            total_loss += loss.item() * num_tokens
            total_tokens += num_tokens
    model.train()
    return total_loss / total_tokens

# Checkpoint dirs
os.makedirs("checkpoints", exist_ok=True)

In [15]:
# ------------------- 9. TRAINING LOOP -------------------
criterion = nn.CrossEntropyLoss(ignore_index=-100)
start_time = time.time()

print(f"Resuming training from {tokens_seen:,} tokens...")
for batch in dataloader:
    step += 1
    batch = {k: v.to(CFG["device"], non_blocking=True) for k, v in batch.items()}

    # ---- forward + loss -------------------------------------------------
    with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
        logits = model(batch["input_ids"])
        loss = criterion(
            logits.view(-1, logits.size(-1)),
            batch["labels"].view(-1),
        )
        if not torch.isfinite(loss):
            print(f"Non-finite loss at step {step}: {loss.item()} – Skipping.")
            optimizer.zero_grad(set_to_none=True)
            continue

    # ---- backward -------------------------------------------------------
    # scaler.scale(loss).backward()
    loss.backward()

    # Gradient clipping – caps norms to prevent explosion
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # scaler.step(optimizer)
    # scaler.update()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # ---- LR step (per token, not per step) -------------------------------
    tokens_seen += batch["input_ids"].numel()
    scheduler.step()                     # LambdaLR uses the *current* token count

    # ---- logging ---------------------------------------------------------
    if step % CFG["log_interval"] == 0:
        elapsed = time.time() - start_time
        tokens_per_sec = tokens_seen / elapsed
        lr = optimizer.param_groups[0]["lr"]

        wandb.log({
            "step": step,
            "loss": loss.item(),
            "lr": lr,
            "tokens_seen": tokens_seen,
            "tokens_per_sec": tokens_per_sec,
            "gpu_mem_gb": torch.cuda.max_memory_allocated() / 1e9,
        }, step=step)
        # print(
        #     f"Step {step:5d} | "
        #     f"Loss {loss.item():.4f} | "
        #     f"LR {lr:.2e} | "
        #     f"Tokens {tokens_seen:,}/{CFG['max_tokens']:,} | "
        #     f"Speed {tokens_per_sec:,.0f} t/s"
        # )

    # ---- early stop ------------------------------------------------------
    if tokens_seen >= CFG["max_tokens"]:
        print("\nReached target token count → stopping.")
        break

# --------------------------------------------------------------
wandb.finish()
print("Training finished!")
# --------------------------------------------------------------

Resuming training from 100,016,128 tokens...


Token indices sequence length is longer than the specified maximum sequence length for this model (2530 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1085 > 1024). Running this sequence through the model will result in indexing errors


Skipped ~100,015,872 tokens; resuming.
Skipped ~100,015,872 tokens; resuming.

Reached target token count → stopping.


0,1
gpu_mem_gb,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,▂▃▅▅▆▄▂█▄▄▃▅▃▃▂▇▅▄▃▅▇▄▅▂▄▄▆▃▁▃▄▄▆▄▄▄▄▂▂▄
lr,██▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▆▇▇▇▇▇▇▇▇███
tokens_per_sec,█▇▆▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
tokens_seen,▁▁▁▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇██

0,1
gpu_mem_gb,14.34544
loss,4.19841
lr,6e-05
step,24400.0
tokens_per_sec,11996.00085
tokens_seen,199876608.0


Training finished!


In [17]:
# ---- checkpoint + validation -----------------------------------------
# if tokens_seen >= CFG["checkpoint_interval"] + CFG["checkpoint_interval"]:
# Save checkpoint
# checkpoint_path = f"checkpoints/checkpoint_{tokens_seen}.pth"
# torch.save({
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(),
#     'scheduler_state_dict': scheduler.state_dict(),
#     'tokens_seen': tokens_seen,
#     'step': step,
# }, checkpoint_path)
# print(f"Saved checkpoint at {tokens_seen:,} tokens: {checkpoint_path}")

# Run validation (processes ~10k tokens via val_max_tokens)
val_loss = validate(model, valid_dataloader)
# wandb.log({"val_loss": val_loss, "tokens_seen": tokens_seen}, step=step)
print(f"Validation loss: {val_loss:.4f} (over {CFG['val_max_tokens']:,} tokens)")

Too many dataloader workers: 2 (max is dataset.num_shards=1). Stopping 1 dataloader workers.


Validation loss: 6.0194 (over 10,000 tokens)


In [18]:
import os, glob

ckpt_dir = "/kaggle/working/checkpoints"
print("=== RAW ls ===")
!ls -lh {ckpt_dir}

print("\n=== Python list ===")
files = sorted(glob.glob(f"{ckpt_dir}/*.pth"))
for f in files:
    size = os.path.getsize(f) / (1024**2)
    print(f"{f}  ({size:.2f} MB)")

=== RAW ls ===
total 709M
-rw-r--r-- 1 root root 709M Nov  6 15:56 checkpoint_200007680.pth

=== Python list ===
/kaggle/working/checkpoints/checkpoint_200007680.pth  (708.71 MB)


In [19]:
import os
import zipfile

ckpt_path = "/kaggle/working/checkpoints/checkpoint_200007680.pth"
zip_path = "/kaggle/working/checkpoint_200M.zip"

# Create zip
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    zipf.write(ckpt_path, os.path.basename(ckpt_path))

# Verify size
zip_size_mb = os.path.getsize(zip_path) / (1024**2)
print(f"Zipped: {zip_path} ({zip_size_mb:.2f} MB)")

Zipped: /kaggle/working/checkpoint_200M.zip (543.76 MB)


In [29]:
# Fixed generation function (top-k sampling)
def generate(model, tokenizer, prompt, max_new_tokens=50, top_k=50, temperature=0.8, pad_token_id=None):
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.bfloat16, enabled=True):
                logits = model(input_ids)[:, -1, :]  # Last position logits
                logits = logits / temperature
                # Top-k filtering
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                probs = torch.softmax(logits, dim=-1)
                mask = logits[0] < v[0].min()  # 1D boolean [vocab]
                probs[0][mask] = 0  # Set low-prob to 0
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                if next_token.item() == pad_token_id:
                    break  # Stop at EOS
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Test prompts (web/edu themed)
prompts = [
    "Web hosting is essential for",
    "Machine learning models train on datasets like",
    "A good developer should know",
    "FineWeb-Edu is a filtered version of",
    "The future of AI in education involves",
    "Why do you think people follow me"
]

print("=== Generation Tests (Final Train Loss: 4.19 | PPL: {:.0f}) ===".format(math.exp(4.19)))
for prompt in prompts:
    with torch.no_grad():  # Per-prompt for safety
        generated = generate(model, tokenizer, prompt, max_new_tokens=50, top_k=50, temperature=0.8)
        continuation = generated[len(prompt):].strip()  # Continuation only
        print(f"\nPrompt: {prompt}")
        print(f"Output: {continuation}")
        print("-" * 80)

=== Generation Tests (Final Train Loss: 4.19 | PPL: 66) ===

Prompt: Web hosting is essential for
Output: businesses and businesses alike, as it promotes a culture where people can interact and feel like they can’t take it home. This article will show why web sites and online platforms offer a range of opportunities, from the rise of new and new platforms
--------------------------------------------------------------------------------

Prompt: Machine learning models train on datasets like
Output: the GIS and Google Data. This model also helps to develop the tools to learn and create a more accurate analysis of user experience for a wide range of applications (e.g., virtual reality, virtual reality, or online devices), and so it
--------------------------------------------------------------------------------

Prompt: A good developer should know
Output: when they use the correct code.
In terms of a good user-friendly software, your software can be used for small, dynamic, easy, and con