In [None]:
!pip install --upgrade transformers datasets tokenizers accelerate huggingface-hub

In [None]:
pip install -U datasets

In [None]:
!pip install torch torchvision torchaudio
!pip install transformers==4.35.0
!pip install datasets==2.14.0
!pip install accelerate==0.24.0
!pip install wandb
!pip install tiktoken
!pip install matplotlib pandas numpy
!pip install gradio

In [None]:
# Upgrade datasets to latest version
!pip install datasets --upgrade

In [None]:
# Load TinyGSM (training data) and the GSM8K "main" config (evaluation data),
# then print split sizes and one raw example from each.
from datasets import load_dataset
tinygsm_dataset = load_dataset("TinyGSM/TinyGSM")
gsm8k_dataset = load_dataset("gsm8k", "main")  # for evaluation

tinygsm_train = tinygsm_dataset['train']
print(f"TinyGSM train size: {len(tinygsm_train)}")
for split_name in ("train", "test"):
    print(f"GSM8k {split_name} size {len(gsm8k_dataset[split_name])}")

first_tinygsm = tinygsm_train[0]
print("\nTinyGSM example:")
print(first_tinygsm)

print("\nGSM8k example:")
print(gsm8k_dataset['train'][0])

print("Keys available in TinyGSM dataset example:", first_tinygsm.keys())

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
import numpy as np
from datasets import load_dataset, DatasetDict
import json
from pathlib import Path
from tqdm.auto import tqdm
from contextlib import nullcontext
import tiktoken
import random

In [None]:
def get_result_from_code(code_str: str):
    """Execute a TinyGSM solution snippet and return what its function returns.

    The snippet is expected to define a function (TinyGSM solutions are a single
    zero-argument function). The first ``def NAME(`` found in the text is called
    with no arguments and its return value is returned.

    Returns:
        The function's return value, or a string when nothing could be computed:
        ``"Execution Failed: Function name not found."`` if no ``def`` is present,
        ``"Code Execution Error: <exc>"`` if executing or calling the code raised.

    WARNING: ``exec`` runs arbitrary code from the dataset with the notebook's
    privileges and no timeout. Only use it on trusted data, or in a sandbox.
    """
    import re

    # One fresh dict acts as both globals and locals. Names the snippet defines
    # at top level (imports, constants, helper functions) are then visible inside
    # the function it defines. With separate globals/locals they would land only in
    # the locals dict, and the function would raise NameError on them. It also
    # keeps the snippet from reading or overwriting notebook globals.
    namespace = {}
    try:
        exec(code_str, namespace)
        function_name_match = re.search(r'def\s+(\w+)\s*\(', code_str)
        if function_name_match is None:
            return "Execution Failed: Function name not found."
        result_func = namespace[function_name_match.group(1)]
        return result_func()
    except Exception as e:
        return f"Code Execution Error: {e}"

# Show the first three TinyGSM rows together with the answer obtained by running their code.
train_split = tinygsm_dataset['train']
separator = '=' * 60
for position in range(3):
    sample = train_split[position]
    snippet = sample['code']
    print(f"\n{separator}")
    print(f"Example {position + 1}:")
    print(f"Question: {sample['question']}")
    print(f"Code:\n{snippet}")
    print(f"Calculated Answer: {get_result_from_code(snippet)}")


In [None]:
!pip install tiktoken

In [None]:
# Notebook-level constants. The data-preparation cells read VOCAB_SIZE, DATA_DIR,
# CACHE_DIR, NUM_PROC, BATCH_SIZE and WRITER_BATCH_SIZE.
# NOTE(review): N_EMBD/N_LAYER/N_HEAD/SEQ_LEN, GRAD_ACCUM_STEPS, LEARNING_RATE,
# WEIGHT_DECAY, WARMUP_STEPS, MAX_STEPS, EVAL_INTERVAL, CHECKPOINT_DIR and
# NUM_WORKERS are not read by train(), which takes its settings from TrainingConfig.
VOCAB_SIZE = 50257
N_EMBD = 384       # 12 layers x 12*384^2 ~= 21M non-embedding params (+~19M token embedding), not ~80M
N_LAYER = 12
N_HEAD = 8
SEQ_LEN = 128

# Output location for the flat .bin memmaps (rebound by later cells).
DATA_DIR = Path("./tinygsm_data")
DATA_DIR.mkdir(parents=True, exist_ok=True)
# On-disk cache of the tokenized HF dataset, keyed only by split name.
CACHE_DIR = Path("./tokenized_cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)
GRAD_ACCUM_STEPS = 16
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.1
WARMUP_STEPS = 500
MAX_STEPS = 20000
EVAL_INTERVAL = 500
SEED = 42
CHECKPOINT_DIR = Path("./checkpoints")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
NUM_WORKERS = 4
NUM_PROC = 11            # worker processes for Dataset.map tokenization
BATCH_SIZE = 25000       # examples per Dataset.map batch (not a training batch size)
WRITER_BATCH_SIZE = 10000  # rows buffered per Arrow write during Dataset.map

# Seed every RNG used in this notebook.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
enc = tiktoken.get_encoding("gpt2")  # GPT-2 BPE, used by the tokenization helpers below
vocab_reported = getattr(enc, "n_vocab", VOCAB_SIZE)
print("tiktoken vocab size (n_vocab):", vocab_reported)

def format_io(question: str, answer: str):
    """Build the (prompt, target) text pair for one training example.

    The prompt is ``"Question: <question>\\nAnswer:"``. The target is ``answer``
    converted with ``str()`` and stripped of surrounding whitespace.
    """
    prompt = f"Question: {question}\nAnswer:"
    return prompt, str(answer).strip()

def tokenize_example(example):
    """Tokenize one TinyGSM row into prompt+answer ids with prompt-masked labels.

    Args:
        example: mapping with "question" and "code" entries.

    Returns:
        ``{"ids", "labels", "len"}``. ``ids`` is the GPT-2 encoding of the prompt
        followed by the solution code. ``labels`` equals ``ids`` with the prompt
        positions set to -100 (the loss's ignore_index). ``len`` is ``len(ids)``.
        Labels are position-aligned with ``ids`` (not shifted). If anything raises,
        the error is printed and an empty record is returned.
    """
    try:
        q_text, a_text = format_io(example['question'], example['code'])
        # disallowed_special=() encodes special-token text such as "<|endoftext|>"
        # as ordinary text. tiktoken's default raises ValueError on it, which
        # used to drop the whole example.
        q_ids = enc.encode(q_text, disallowed_special=())
        a_ids = enc.encode(a_text, disallowed_special=())
        ids = q_ids + a_ids

        labels = [-100] * len(q_ids) + a_ids
        return {"ids": ids, "labels": labels, "len": len(ids)}
    except Exception as e:
        print(f"Skipping example due to error: {e}")
        return {"ids": [], "labels": [], "len": 0}

In [None]:
# Inspect one raw TinyGSM row (index 9).
row_9 = tinygsm_dataset['train'][9]
print(row_9)

In [None]:
import os
# Compare with NUM_PROC, the hard-coded process count passed to Dataset.map.
cores = os.cpu_count()
print(f"Available CPU cores: {cores}")

In [None]:
from datasets import load_from_disk
def batch_tokenize_example(batch):
    """Tokenize a batch of TinyGSM rows for ``Dataset.map(batched=True)``.

    Args:
        batch: dict of columns, with at least "question" and "code" lists.

    Returns:
        dict of equal-length lists: "ids" (prompt + answer token ids), "labels"
        (ids with the prompt positions replaced by -100) and "len" (len(ids)).
        Labels are position-aligned with ids; no next-token shift is applied.

    NOTE(review): no end-of-text token is appended. Examples are later concatenated
    back-to-back, so the model never sees where an answer ends (enc.eot_token could
    mark it).
    """
    ids_list = []
    labels_list = []
    len_list = []

    for question, code in zip(batch["question"], batch["code"]):
        q_text, a_text = format_io(question, code)
        # disallowed_special=() encodes text like "<|endoftext|>" as ordinary
        # characters. tiktoken's default raises ValueError on it, which would
        # abort the entire map() call.
        q_ids = enc.encode(q_text, disallowed_special=())
        a_ids = enc.encode(a_text, disallowed_special=())

        ids = q_ids + a_ids
        labels = [-100] * len(q_ids) + a_ids

        ids_list.append(ids)
        labels_list.append(labels)
        len_list.append(len(ids))

    return {"ids": ids_list, "labels": labels_list, "len": len_list}


def get_or_create_tokenized_dataset(dataset, split="train", force_retokenize=False):
    """Return the tokenized form of ``dataset``, reusing an on-disk copy when present.

    The cache lives at ``CACHE_DIR / f"{split}_tokenized"`` and is keyed only by
    the split name. If the tokenization logic changes, pass ``force_retokenize=True``
    (otherwise the stale cache is loaded).

    Args:
        dataset: a ``datasets.Dataset`` with "question" and "code" columns.
        split: name used for the cache directory and progress messages.
        force_retokenize: ignore an existing cache and tokenize again.

    Returns:
        A Dataset with columns "ids", "labels" and "len" only.
    """
    cache_path = CACHE_DIR / f"{split}_tokenized"

    use_cached = cache_path.exists() and not force_retokenize
    if use_cached:
        print(f" Loading cached tokenized dataset from {cache_path}")
        cached = load_from_disk(str(cache_path))
        print(f" Loaded {len(cached)} examples from cache")
        return cached

    print(f"Tokenizing {split} dataset (this will be cached)...")
    map_kwargs = dict(
        remove_columns=dataset.column_names,
        batched=True,
        batch_size=BATCH_SIZE,
        num_proc=NUM_PROC,
        desc=f"tokenizing {split}",
        writer_batch_size=WRITER_BATCH_SIZE,
        load_from_cache_file=True,
    )
    fresh = dataset.map(batch_tokenize_example, **map_kwargs)

    print(f"Saving tokenized dataset to cache: {cache_path}")
    fresh.save_to_disk(str(cache_path))
    print(f"Cached tokenized dataset for future use!")
    return fresh


def build_memmap_fast(dataset_dict, splits=("train",), overwrite=False, use_cache=True):
    """Write each split's token ids and labels into flat, position-aligned memmap files.

    For every split in ``splits``, the examples are tokenized through
    ``get_or_create_tokenized_dataset``, which reuses the on-disk tokenization
    cache unless ``use_cache=False``. The results are concatenated, example after
    example, into:

    * ``DATA_DIR/{split}_ids.bin``    -- uint16 token ids
    * ``DATA_DIR/{split}_labels.bin`` -- int32 labels (-100 on prompt tokens)
    * ``DATA_DIR/{split}_meta.json``  -- ``{"total_len": <token count>}``

    label[t] belongs to ids[t]; any next-token shift must be applied by the reader.

    Args:
        dataset_dict: mapping from split name to a ``datasets.Dataset``.
        splits: split names to process.
        overwrite: rebuild memmaps that already exist (otherwise they are skipped).
        use_cache: reuse cached tokenization; ``False`` forces re-tokenizing.

    Raises:
        ValueError: a token id does not fit in uint16.
        RuntimeError: the number of tokens written differs from the precomputed total.
    """
    # Data types for memmap
    ids_dtype = np.uint16
    lbls_dtype = np.int32
    ids_max = int(np.iinfo(ids_dtype).max)
    batch_size = 100000  # examples materialized in memory per write

    def _flatten(seqs):
        """Concatenate a list of integer sequences into one int32 array."""
        if len(seqs) == 0:
            return np.empty(0, dtype=lbls_dtype)
        return np.concatenate([np.asarray(s, dtype=lbls_dtype) for s in seqs])

    for split in splits:
        out_ids = DATA_DIR / f"{split}_ids.bin"
        out_labels = DATA_DIR / f"{split}_labels.bin"
        meta_json = DATA_DIR / f"{split}_meta.json"

        # Check if memmap already exists
        if out_ids.exists() and out_labels.exists() and not overwrite:
            print(f"{split} memmap already exists, skipping (use overwrite=True to rebuild).")

            # Load and print stats if metadata exists
            if meta_json.exists():
                with open(meta_json, "r") as f:
                    meta = json.load(f)
                    print(f"  Total tokens: {meta['total_len']:,}")
            continue

        print(f"\n{'='*60}")
        print(f"Processing split: {split}")
        print(f"{'='*60}")

        ds = dataset_dict[split]
        print(f"Dataset size: {len(ds):,} examples")

        # Use cached tokenization or create new
        tokenized = get_or_create_tokenized_dataset(
            ds,
            split=split,
            force_retokenize=not use_cache
        )

        # Calculate total length
        print("Calculating total token count...")
        total_len = int(np.sum(np.asarray(tokenized["len"]), dtype=np.uint64))
        print(f" Total tokens for {split}: {total_len:,}")

        if total_len <= 0:
            print(f"⚠ No data to write for split {split}. Skipping memmap creation.")
            continue

        # Calculate expected file sizes
        ids_size_gb = (total_len * np.dtype(ids_dtype).itemsize) / (1024**3)
        lbls_size_gb = (total_len * np.dtype(lbls_dtype).itemsize) / (1024**3)
        print(f"Expected file sizes:")
        print(f"  - IDs: {ids_size_gb:.2f} GB")
        print(f"  - Labels: {lbls_size_gb:.2f} GB")
        print(f"  - Total: {ids_size_gb + lbls_size_gb:.2f} GB")

        # Create memmaps
        print(f"\nCreating memmap files...")
        out_ids.parent.mkdir(parents=True, exist_ok=True)
        ids_arr = np.memmap(out_ids, dtype=ids_dtype, mode="w+", shape=(total_len,))
        labels_arr = np.memmap(out_labels, dtype=lbls_dtype, mode="w+", shape=(total_len,))

        print(f"Writing {total_len:,} tokens to disk...")
        idx = 0
        n_examples = len(tokenized)

        for i in range(0, n_examples, batch_size):
            batch_end = min(i + batch_size, n_examples)
            # Slicing a Dataset returns {column: list}, reading each column once.
            # Indexing a column once per row (batch["ids"][j]) can re-read the
            # whole column for every row, depending on the datasets version, which
            # makes each chunk quadratic.
            rows = tokenized[i:batch_end]
            chunk_ids = _flatten(rows["ids"])
            chunk_labels = _flatten(rows["labels"])
            n = int(chunk_ids.size)

            if n and (chunk_ids.min() < 0 or chunk_ids.max() > ids_max):
                raise ValueError(
                    f"Token id outside uint16 range in {split} examples {i}-{batch_end - 1}"
                )
            if idx + n > total_len:
                raise RuntimeError(
                    f"{split}: more tokens than the precomputed total ({idx + n:,} > {total_len:,})"
                )

            ids_arr[idx:idx + n] = chunk_ids
            labels_arr[idx:idx + n] = chunk_labels
            idx += n

            # Progress update every batch
            progress = 100 * idx / total_len
            print(f"  Progress: {idx:,}/{total_len:,} tokens ({progress:.1f}%) - Processed {batch_end:,}/{len(tokenized):,} examples")

            # Flush periodically
            if i % (batch_size * 5) == 0:
                ids_arr.flush()
                labels_arr.flush()

        # Final flush, then release the file mappings.
        ids_arr.flush()
        labels_arr.flush()
        del ids_arr, labels_arr
        print(f" Finished writing {idx:,} tokens")

        if idx != total_len:
            raise RuntimeError(
                f"{split}: wrote {idx:,} tokens but expected {total_len:,}; memmap tail is uninitialized"
            )

        # Save metadata
        meta = {"total_len": total_len}
        with open(meta_json, "w") as f:
            json.dump(meta, f)

        print(f"\n{'='*60}")
        print(f" SUCCESS! Memmap files created for {split}")
        print(f"  - IDs: {out_ids}")
        print(f"  - Labels: {out_labels}")
        print(f"  - Metadata: {meta_json}")
        print(f"{'='*60}\n")

# Build the train memmaps.
#   overwrite=True  -> the .bin files are rebuilt on every run (False reuses them)
#   use_cache=True  -> tokenization is loaded from CACHE_DIR when present
build_memmap_fast(
    tinygsm_dataset,
    splits=("train",),
    overwrite=True,
    use_cache=True,
)

banner = "=" * 60
print("\n" + banner)
print("COMPLETE! Your data is ready for training.")
print(banner)
for hint in (
    "\nNext runs will be MUCH faster because tokenization is cached!",
    "To force re-tokenization: build_memmap_fast(..., use_cache=False)",
    "To rebuild memmaps: build_memmap_fast(..., overwrite=True)",
):
    print(hint)

In [None]:
import torch
cuda_ok = torch.cuda.is_available()
print("GPU Available:" if cuda_ok else "GPU Not Available - Check back later")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
import numpy as np
from tqdm.auto import tqdm
from contextlib import nullcontext
import os
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import time
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias.

    Args:
        ndim: size of the normalized (last) dimension.
        bias: when falsy, no bias parameter is created and ``self.bias`` is None.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = None if not bias else nn.Parameter(torch.zeros(ndim))

    def forward(self, x):
        eps = 1e-5
        return F.layer_norm(x, self.weight.shape, weight=self.weight, bias=self.bias, eps=eps)


class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention.

    Q, K and V come from one fused projection (``c_attn``). They are split into
    ``n_head`` heads and attended with a causal mask (position t sees positions
    <= t). The heads are then merged and projected by ``c_proj``. When
    ``F.scaled_dot_product_attention`` exists it is used. Otherwise an explicit
    softmax(QK^T / sqrt(head_dim)) with a lower-triangular mask buffer is computed.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # Lower-triangular mask for the manual path; the buffer name "bias"
            # is what appears in the state_dict.
            size = config.block_size
            causal_mask = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
            self.register_buffer("bias", causal_mask)

    def forward(self, x):
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (split_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=2))

        if not self.flash:
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            scores = scores.masked_fill(self.bias[:, :, :seq_len, :seq_len] == 0, float('-inf'))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            out = weights @ v
        else:
            drop_p = self.attn_dropout.p if self.training else 0.0
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                                 dropout_p=drop_p, is_causal=True)

        merged = out.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))


class MLP(nn.Module):
    """Position-wise feed-forward layer: Linear(C -> 4C), GELU, Linear(4C -> C), dropout."""

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = self.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.dropout(projected)


class Block(nn.Module):
    """Pre-norm transformer layer with residuals: x + attn(ln1(x)), then + mlp(ln2(.))."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))


@dataclass
class GPTConfig:
    """Architecture hyperparameters for GPT (field order = positional constructor order)."""
    block_size: int = 512     # max sequence length; size of the position-embedding table
    vocab_size: int = 50257   # rows of the token embedding / width of lm_head output
    n_layer: int = 12         # number of Block layers
    n_head: int = 12          # attention heads; must divide n_embd
    n_embd: int = 768         # model width
    dropout: float = 0.1      # probability for every nn.Dropout in the model
    bias: bool = True         # include biases in Linear and LayerNorm layers


class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embedding + learned position embedding, then ``n_layer`` pre-norm
    ``Block``s, a final LayerNorm and ``lm_head``. ``lm_head`` shares its weight
    matrix with the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight  # weight tying

        self.apply(self._init_weights)
        # Residual output projections get a smaller std, scaled by depth.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """N(0, 0.02) for Linear/Embedding weights, zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids ``idx`` of shape (B, T), T <= block_size.

        With ``targets`` (B, T): returns (logits for every position, mean
        cross-entropy over targets != -100). Targets are compared position by
        position with no internal shift, so ``targets[:, t]`` must already be the
        token that should follow ``idx[:, t]``.
        Without ``targets``: returns (logits of the last position only, shaped
        (B, 1, vocab_size), None).
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, (
            f"Cannot forward sequence of length {t}; block_size is {self.config.block_size}"
        )
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   targets.view(-1), ignore_index=-100)
            return logits, loss
        else:
            # Inference shortcut: only the last position is needed for sampling.
            logits = self.lm_head(x[:, [-1], :])
            return logits, None

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` (B, T) and return it.

        The context is cropped to the last ``block_size`` tokens each step.
        ``temperature == 0`` selects the arg-max token (greedy decoding, ``top_k``
        unused). Otherwise logits are divided by ``temperature``, optionally
        restricted to the ``top_k`` largest, and sampled.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if temperature == 0:
                # Dividing by 0 would produce inf/NaN probabilities; decode greedily.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

    def get_num_params(self, non_embedding=True):
        """Parameter count; by default the position embedding is excluded.

        The token embedding is shared with lm_head, so it is counted once either way.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class MemmapDataset(Dataset):
    """Fixed-length next-token-prediction windows over flat token/label memmaps.

    ``ids_path`` holds uint16 token ids and ``labels_path`` int32 labels (-100 at
    positions excluded from the loss), position-aligned: label[t] belongs to ids[t].

    Item ``idx`` is the window starting at ``idx * seq_length``:
      * inputs  = ``ids[start : start + seq_length]``
      * targets = ``labels[start + 1 : start + seq_length + 1]``

    So position t is trained to predict the token at t+1. The GPT model applies
    cross-entropy to targets exactly as given (no internal shift), so the shift
    must happen here. Without it, every position's target is its own input token.
    Windows are cut without regard to example boundaries.
    Both tensors are returned as int64.
    """

    def __init__(self, ids_path, labels_path, seq_length=512):
        self.ids = np.memmap(ids_path, dtype=np.uint16, mode='r')
        self.labels = np.memmap(labels_path, dtype=np.int32, mode='r')
        if len(self.labels) != len(self.ids):
            raise ValueError(
                f"ids ({len(self.ids)}) and labels ({len(self.labels)}) memmaps differ in length"
            )
        self.seq_length = seq_length
        self.total_length = len(self.ids)

    def __len__(self):
        # Each window needs seq_length inputs plus one extra position for the
        # shifted target; never negative, even for files shorter than one window.
        return max(0, (self.total_length - 1) // self.seq_length)

    def __getitem__(self, idx):
        start = idx * self.seq_length
        end = start + self.seq_length

        ids = torch.from_numpy(self.ids[start:end].astype(np.int64))
        labels = torch.from_numpy(self.labels[start + 1:end + 1].astype(np.int64))

        return ids, labels

# Standalone dataset/loader over DATA_DIR for quick inspection; train() builds its
# own loaders from TrainingConfig.
dataset = MemmapDataset(
    DATA_DIR / "train_ids.bin",
    DATA_DIR / "train_labels.bin",
    seq_length=512,  # keep <= the model's block_size
)

loader_kwargs = dict(
    batch_size=32,   # adjust to GPU memory
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
dataloader = DataLoader(dataset, **loader_kwargs)

print(f"Dataset ready: {len(dataset):,} sequences")

In [None]:
from pathlib import Path
# NOTE(review): rebinds DATA_DIR (first set to ./tinygsm_data) to a Google Drive
# folder; presumably Drive is mounted at /content/drive — confirm before running.
DATA_DIR = Path("/content/drive/MyDrive") / "tinygsm_data"
DATA_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
def estimate_loss_tinygsm(model, get_batch, eval_iters=200, ctx=torch.no_grad()):
    """Average loss over ``eval_iters`` batches from ``get_batch('train')`` and ``get_batch('val')``.

    ``model(X, Y)`` must return ``(logits, loss)``. All forward passes run inside
    ``ctx``. The model is put in eval mode and switched back to train mode at the end.

    Returns:
        ``{'train': float, 'val': float}``.
    """
    model.eval()
    results = {}
    with ctx:
        for split_name in ('train', 'val'):
            batch_losses = torch.zeros(eval_iters)
            for step in range(eval_iters):
                inputs, targets = get_batch(split_name)
                _, batch_loss = model(inputs, targets)
                batch_losses[step] = batch_loss.item()
            results[split_name] = batch_losses.mean().item()
    model.train()
    return results

In [None]:
class TrainingConfig:
    """Hyperparameters and paths read by ``train()`` and ``get_lr()``.

    Values are plain class attributes used as defaults. ``train()`` instantiates
    this class and overrides ``device`` on the instance.
    """
    # Model
    block_size: int = 512
    vocab_size: int = 50257
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    dropout: float = 0.1
    bias: bool = True

    # Training
    batch_size: int = 32
    # Micro-batches per optimizer step (effective batch = batch_size * this);
    # read by train(). 1 means no gradient accumulation.
    gradient_accumulation_steps: int = 1
    learning_rate: float = 6e-4
    max_iters: int = 50000
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # Learning rate schedule (get_lr reads warmup_iters / lr_decay_iters)
    warmup_steps: int = 2000
    warmup_iters: int = warmup_steps   # defaults to warmup_steps
    lr_decay_iters: int = max_iters    # cosine decay spans the whole run
    min_lr: float = 6e-5

    # Evaluation
    eval_interval: int = 500
    eval_iters: int = 200

    # Logging
    log_interval: int = 10

    # Checkpointing
    checkpoint_dir: str = "/content/drive/MyDrive/tinygsm_checkpoints"
    save_interval: int = 5000

    # Data
    data_dir: str = "/content/drive/MyDrive/tinygsm_data"

    # System
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype: str = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
    compile: bool = False

In [None]:
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

def compute_loss_and_backprop(model, X, Y, optimizer, scaler=None):
    """Run one forward/backward/optimizer step on a single batch.

    Args:
        model: module called as ``model(X, Y)``. It may return logits of shape
            (B, T, V), or a tuple whose first element is such logits (as GPT does).
        X: input token ids (B, T).
        Y: target ids (B, T); -100 entries are ignored by the loss.
        optimizer: optimizer stepped once per call. Its gradients are cleared first.
        scaler: optional ``GradScaler``; when given, backward/step go through it.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    # Move the batch to wherever the model's parameters live.
    device = next(model.parameters()).device
    X = X.to(device)
    Y = Y.to(device)
    # Clear gradients from the previous call; without this they accumulate.
    optimizer.zero_grad(set_to_none=True)
    # GPT.forward returns (logits, loss), with per-position logits only when
    # targets are passed, so pass Y and take the first element.
    output = model(X, Y)
    logits = output[0] if isinstance(output, tuple) else output
    B, T, V = logits.shape
    loss = loss_fn(logits.reshape(B * T, V), Y.reshape(B * T))
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return loss.item()

In [None]:
# Create the Google Drive folder that will receive the memmap files
# (presumably Drive is mounted at /content/drive — TODO confirm).
!mkdir -p /content/drive/MyDrive/tinygsm_data

In [None]:
# Copy the memmaps from the /teamspace workspace into the Drive folder that
# TrainingConfig.data_dir points at.
# NOTE(review): source (/teamspace/...) and destination (/content/drive/...) look
# like two different hosting environments — confirm both paths exist on this machine.
!cp /teamspace/studios/this_studio/tinygsm_data/* /content/drive/MyDrive/tinygsm_data/

In [None]:
from pathlib import Path
DATA_DIR = Path("/teamspace/studios/this_studio/tinygsm_data")

# Report whether each memmap produced by build_memmap_fast is present.
for fname in ("train_ids.bin", "train_labels.bin"):
    print(f"{fname} exists:", (DATA_DIR / fname).exists())

In [None]:
from pathlib import Path
import os
import torch

# NOTE(review): this DATA_DIR differs from TrainingConfig.data_dir
# (/content/drive/MyDrive/tinygsm_data), which is what train() actually reads.
DATA_DIR = Path("/teamspace/studios/this_studio/tinygsm_data")

rule = "=" * 60
print(rule)
print("QUICK VERIFICATION")
print(rule)

# Check GPU
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'
print(f"GPU: {gpu_name}")

# Check data files
print(f"\nData directory: {DATA_DIR}")
for file in ["train_ids.bin", "train_labels.bin"]:
    path = DATA_DIR / file
    if not path.exists():
        print(f" {file}: NOT FOUND")
        continue
    print(f" {file}: {os.path.getsize(path) / (1024**2):.2f} MB")

print(rule)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from pathlib import Path
import time
import math
import sys
from contextlib import nullcontext
from tqdm import tqdm


def get_lr(it, config):
    """Learning rate for iteration ``it``: linear warmup, then cosine decay to ``min_lr``.

    Reads ``config.learning_rate`` and ``config.min_lr``, plus the schedule lengths
    ``config.warmup_iters`` and ``config.lr_decay_iters``. When a config lacks those
    two attributes, ``warmup_steps`` and ``max_iters`` are used instead.

    * ``it < warmup``: ``learning_rate * it / warmup``
    * ``warmup <= it <= decay``: cosine from ``learning_rate`` down to ``min_lr``
    * ``it > decay`` (or a degenerate ``decay <= warmup``): ``min_lr``
    """
    warmup_iters = getattr(config, 'warmup_iters', None)
    if warmup_iters is None:
        warmup_iters = config.warmup_steps
    decay_iters = getattr(config, 'lr_decay_iters', None)
    if decay_iters is None:
        decay_iters = config.max_iters

    if it < warmup_iters:
        return config.learning_rate * it / warmup_iters
    # The second condition avoids dividing by zero when decay_iters == warmup_iters.
    if it > decay_iters or decay_iters <= warmup_iters:
        return config.min_lr
    decay_ratio = (it - warmup_iters) / (decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return config.min_lr + coeff * (config.learning_rate - config.min_lr)


@torch.no_grad()
def estimate_loss(model, train_loader, val_loader, config, ctx, eval_iters=50):
    """Mean loss over the first ``eval_iters`` batches of each available loader.

    Args:
        model: callable as ``model(X, Y) -> (logits, loss)``.
        train_loader, val_loader: iterables with ``len()`` yielding ``(X, Y)``;
            ``val_loader`` may be None, in which case 'val' is omitted.
        config: object whose ``device`` attribute the batches are moved to.
        ctx: context manager wrapped around each forward pass (e.g. autocast).
        eval_iters: maximum number of batches read from each loader.

    Returns:
        dict mapping 'train' (and 'val' when available) to the average loss,
        or 0.0 for a loader that produced no batches. The model is left in train mode.
    """
    model.eval()
    out = {}
    for split, loader in (('train', train_loader), ('val', val_loader)):
        if loader is None:
            continue
        batch_iter = iter(loader)
        n_batches = min(eval_iters, len(loader))
        split_losses = []
        while len(split_losses) < n_batches:
            batch = next(batch_iter, None)
            if batch is None:
                break
            X, Y = batch
            inputs, targets = X.to(config.device), Y.to(config.device)
            with ctx:
                _, loss = model(inputs, targets)
            split_losses.append(loss.item())
        out[split] = sum(split_losses) / len(split_losses) if split_losses else 0.0

    model.train()
    return out


def _fill_schedule_defaults(config):
    """Add attributes that train()/get_lr() read but a config class may not define.

    Sets ``gradient_accumulation_steps`` (default 1), ``warmup_iters`` (default
    ``warmup_steps``) and ``lr_decay_iters`` (default ``max_iters``) on the
    instance, only when missing. Existing values are left untouched.
    """
    fallbacks = {
        'gradient_accumulation_steps': 1,
        'warmup_iters': getattr(config, 'warmup_steps', 0),
        'lr_decay_iters': config.max_iters,
    }
    for name, value in fallbacks.items():
        if not hasattr(config, name):
            setattr(config, name, value)


def _config_as_dict(config):
    """Public, non-callable attributes of ``config`` as a plain dict.

    Checkpoints store this dict instead of the config object. torch.load with
    ``weights_only=True`` (the default since PyTorch 2.6) refuses to unpickle
    arbitrary classes, so an object there would make resuming fail.
    """
    return {
        name: getattr(config, name)
        for name in dir(config)
        if not name.startswith('_') and not callable(getattr(config, name))
    }


def train(resume_from=None):
    """Train a GPT on the TinyGSM memmap files and return the loss history.

    Hyperparameters come from ``TrainingConfig()``. Data is read through
    ``MemmapDataset`` from ``{config.data_dir}/train_ids.bin`` / ``train_labels.bin``,
    plus the ``val_*`` pair when present. Checkpoints go to ``config.checkpoint_dir``:
      * ``best_model.pt`` whenever the validation loss improves
      * ``ckpt_iter_{N}.pt`` and ``ckpt_latest.pt`` every ``save_interval`` iterations
      * ``ckpt_final.pt`` when the loop ends, including after Ctrl-C

    Args:
        resume_from: optional checkpoint file name inside ``config.checkpoint_dir``.
            The model, optimizer, grad scaler (when saved), iteration count and
            loss history are restored from it.

    Returns:
        dict with ``train_loss_list``, ``val_loss_list`` and ``best_val_loss``.

    Raises:
        FileNotFoundError: ``config.data_dir`` does not exist.
        ValueError: the training memmap is too short for a single sequence.
    """
    config = TrainingConfig()
    _fill_schedule_defaults(config)
    grad_accum = config.gradient_accumulation_steps

    print("="*70)
    print("TinyGSM Training - Full Dataset Mode")
    print("="*70)
    print(f"Model: {config.n_layer} layers, {config.n_head} heads, {config.n_embd} embd")
    print(f"Block size: {config.block_size}")
    print(f"Batch size: {config.batch_size} × {grad_accum} = {config.batch_size * grad_accum} effective")
    print(f"Learning rate: {config.learning_rate} → {config.min_lr}")
    print(f"Max iterations: {config.max_iters:,}")

    # Calculate total tokens
    tokens_per_iter = config.batch_size * grad_accum * config.block_size
    total_tokens = tokens_per_iter * config.max_iters
    print(f"Total tokens to process: {total_tokens:,} ({total_tokens/1e9:.2f}B)")
    print(f"Data directory: {config.data_dir}")
    print("="*70)

    # Device setup
    if torch.cuda.is_available():
        device = 'cuda'
        print(f" Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = 'cpu'
        print(" Using CPU")
    config.device = device

    # Precision setup: autocast on CUDA; gradient scaling only for float16.
    dtype_map = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}
    ptdtype = dtype_map.get(config.dtype, torch.float32)
    ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type='cuda', dtype=ptdtype)
    use_amp = (config.dtype == 'float16') and (device == 'cuda')

    # Load datasets
    print("\nLoading datasets...")
    data_dir = Path(config.data_dir)

    if not data_dir.exists():
        raise FileNotFoundError(f"Data directory not found: {data_dir}")

    train_dataset = MemmapDataset(
        data_dir / "train_ids.bin",
        data_dir / "train_labels.bin",
        seq_length=config.block_size
    )
    if len(train_dataset) == 0:
        # An empty loader would make the training loop below spin forever.
        raise ValueError(
            f"Training data in {data_dir} is too short for one sequence of {config.block_size} tokens"
        )

    # Check if validation exists
    val_loader = None
    if (data_dir / "val_ids.bin").exists():
        val_dataset = MemmapDataset(
            data_dir / "val_ids.bin",
            data_dir / "val_labels.bin",
            seq_length=config.block_size
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=(device == 'cuda')
        )
        print(f" Validation dataset: {len(val_dataset):,} sequences")

    # The shuffling sampler draws its permutation on the CPU with this generator,
    # so it must be a CPU generator even when the model trains on CUDA.
    g = torch.Generator()
    g.manual_seed(42)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,
        generator=g,
        pin_memory=(device == 'cuda')
    )

    print(f" Training dataset: {len(train_dataset):,} sequences")
    print(f" Training batches per epoch: {len(train_loader):,}")

    # Initialize model
    print("\nInitializing model...")
    model_config = GPTConfig(
        block_size=config.block_size,
        vocab_size=config.vocab_size,
        n_layer=config.n_layer,
        n_head=config.n_head,
        n_embd=config.n_embd,
        dropout=config.dropout,
        bias=config.bias
    )

    model = GPT(model_config).to(device)
    model.train()
    print(f" Model parameters: {model.get_num_params()/1e6:.2f}M")

    # Compile model
    if config.compile and hasattr(torch, 'compile'):
        try:
            print("Compiling model...")
            model = torch.compile(model)
            print(" Model compiled")
        except Exception as e:
            print(f" Compile failed: {e}")

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        weight_decay=config.weight_decay
    )

    # Gradient scaler (a no-op when disabled)
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    if use_amp:
        print(" Mixed precision training enabled")

    # Resume from checkpoint
    iter_num = 0
    best_val_loss = float('inf')
    train_loss_list = []
    val_loss_list = []

    if resume_from:
        checkpoint_path = Path(config.checkpoint_dir) / resume_from
        if checkpoint_path.exists():
            print(f"\nLoading checkpoint: {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if 'scaler' in checkpoint:
                scaler.load_state_dict(checkpoint['scaler'])
            iter_num = checkpoint.get('iter_num', 0)
            best_val_loss = checkpoint.get('best_val_loss', float('inf'))
            train_loss_list = checkpoint.get('train_loss_list', [])
            val_loss_list = checkpoint.get('val_loss_list', [])
            print(f" Resumed from iteration {iter_num}")
        else:
            print(f"\n Checkpoint not found: {checkpoint_path} - starting from scratch")

    config_snapshot = _config_as_dict(config)

    def make_checkpoint():
        """Snapshot of the current training state (reads the live loop variables)."""
        return {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'iter_num': iter_num,
            'best_val_loss': best_val_loss,
            'train_loss_list': train_loss_list,
            'val_loss_list': val_loss_list,
            'config': config_snapshot,
        }

    # Training loop
    print("\n" + "="*70)
    print("Starting Training")
    print("="*70 + "\n")

    t0 = time.time()
    running_loss = 0.0
    running_iters = 0      # optimizer steps folded into running_loss
    local_iter_num = 0     # iterations run in this session (excludes resumed ones)
    epoch = 0
    batch_iter = iter(train_loader)

    sys.stdout.flush()

    pbar = tqdm(total=config.max_iters, initial=iter_num, desc="Training")

    def next_batch():
        """Next (X, Y) from train_loader; starts a new shuffled epoch when exhausted."""
        nonlocal batch_iter, epoch
        try:
            return next(batch_iter)
        except StopIteration:
            epoch += 1
            pbar.write(f" Completed epoch {epoch}")
            sys.stdout.flush()
            batch_iter = iter(train_loader)
            return next(batch_iter)

    try:
        while iter_num < config.max_iters:
            # Update learning rate
            lr = get_lr(iter_num, config)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # Evaluation
            if iter_num % config.eval_interval == 0 and iter_num > 0:
                losses = estimate_loss(model, train_loader, val_loader, config, ctx,
                                       eval_iters=config.eval_iters)

                pbar.write(f"\n{'='*70}")
                pbar.write(f"Iteration {iter_num}/{config.max_iters} ({100*iter_num/config.max_iters:.1f}%)")
                pbar.write(f"Train loss: {losses['train']:.4f}", end="")
                if 'val' in losses:
                    pbar.write(f" | Val loss: {losses['val']:.4f}")
                    val_loss_list.append(losses['val'])

                    # Save best model
                    if losses['val'] < best_val_loss:
                        best_val_loss = losses['val']
                        best_path = Path(config.checkpoint_dir) / 'best_model.pt'
                        best_path.parent.mkdir(parents=True, exist_ok=True)
                        torch.save(make_checkpoint(), best_path)
                        pbar.write(f" New best validation loss: {best_val_loss:.4f}")
                else:
                    pbar.write("")

                pbar.write(f"Learning rate: {lr:.2e}")
                train_loss_list.append(losses['train'])

                elapsed = (time.time() - t0) / 3600
                # Extrapolate from this session's iterations only; iter_num also
                # counts iterations restored from a checkpoint.
                remaining = (elapsed / max(1, local_iter_num)) * (config.max_iters - iter_num)
                pbar.write(f"Time: {elapsed:.2f}h elapsed | ~{remaining:.2f}h remaining")
                pbar.write(f"{'='*70}")
                sys.stdout.flush()

            # Save checkpoint
            if iter_num % config.save_interval == 0 and iter_num > 0:
                ckpt_dir = Path(config.checkpoint_dir)
                ckpt_dir.mkdir(parents=True, exist_ok=True)

                checkpoint = make_checkpoint()
                torch.save(checkpoint, ckpt_dir / f'ckpt_iter_{iter_num}.pt')
                torch.save(checkpoint, ckpt_dir / 'ckpt_latest.pt')
                pbar.write(f" Checkpoint saved at iteration {iter_num}")

            # Training step with gradient accumulation
            optimizer.zero_grad(set_to_none=True)

            for _ in range(grad_accum):
                # Every micro-step gets its own batch; reusing one batch would just
                # repeat the same gradient grad_accum times.
                X, Y = next_batch()
                X_batch = X.to(device, non_blocking=True)
                Y_batch = Y.to(device, non_blocking=True)

                with ctx:
                    logits, loss = model(X_batch, Y_batch)
                    loss = loss / grad_accum

                scaler.scale(loss).backward()
                running_loss += loss.item()

            # Gradient clipping
            if config.grad_clip != 0.0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

            # Optimizer step
            scaler.step(optimizer)
            scaler.update()
            running_iters += 1

            # Logging: running_loss sums loss/grad_accum over every micro-step since
            # the last log, so rescale and average over the iterations it covers.
            if iter_num % config.log_interval == 0:
                lossf = running_loss * grad_accum / running_iters
                pbar.set_postfix({'loss': f'{lossf:.4f}', 'lr': f'{lr:.2e}'})
                running_loss = 0.0
                running_iters = 0

            iter_num += 1
            local_iter_num += 1
            pbar.update(1)

    except KeyboardInterrupt:
        print("\n Training interrupted by user")
    except Exception as e:
        print(f"\n Training failed: {e}")
        import traceback
        traceback.print_exc()
        raise
    finally:
        pbar.close()

    # Save final checkpoint
    print("\n" + "="*70)
    print("Training Complete!")
    print("="*70)

    ckpt_dir = Path(config.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    torch.save(make_checkpoint(), ckpt_dir / 'ckpt_final.pt')

    print(f" Final checkpoint saved")
    print(f" Total iterations: {iter_num:,}")
    print(f" Total epochs: {epoch}")
    if val_loss_list:
        print(f" Best validation loss: {best_val_loss:.4f}")
    if train_loss_list:
        print(f" Final training loss: {train_loss_list[-1]:.4f}")

    total_time = (time.time() - t0) / 3600
    print(f" Total training time: {total_time:.2f} hours")
    print("="*70)

    return {
        'train_loss_list': train_loss_list,
        'val_loss_list': val_loss_list,
        'best_val_loss': best_val_loss
    }


# In a Jupyter kernel __name__ is normally "__main__", so running this cell
# starts a full training run; the returned loss history is kept in `results`.
if __name__ == "__main__":
    results = train()

In [None]:
!ls -lh checkpoints/