In [None]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)
# !mkdir -p /content/drive/MyDrive/slm
# !cd /content/drive/MyDrive/slm
# !pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install -q datasets tiktoken tqdm

Mounted at /content/drive


In [9]:
import os
from tqdm import tqdm
import numpy as np
import tiktoken
from datasets import load_dataset

# ---- subset configuration ----
SUBSET_PERCENT = 1        # share of OpenWebText to use, in percent
VAL_FRACTION = 0.01       # share of that subset held out for validation

# ---- output location and tokenizer ----
out_dir = "data"
os.makedirs(out_dir, exist_ok=True)
enc = tiktoken.get_encoding("gpt2")

# ---- derive absolute document counts ----
TOTAL_EXAMPLES = 8_000_000  # rough corpus size (OpenWebText ≈ 8M examples)
NUM_SAMPLES = int(TOTAL_EXAMPLES * SUBSET_PERCENT / 100)
NUM_VAL = int(NUM_SAMPLES * VAL_FRACTION)
NUM_TRAIN = NUM_SAMPLES - NUM_VAL

print(f"Loading {NUM_SAMPLES:,} samples (streaming)")

# Open the train split in streaming mode and keep only the first NUM_SAMPLES documents.
dataset = load_dataset("openwebtext", split="train", streaming=True).take(NUM_SAMPLES)

# Flat token-id buffers, one per split.
train_tokens, val_tokens = [], []

def encode(text):
    """Tokenize one document and terminate it with the end-of-text id.

    ``encode_ordinary`` is used instead of ``encode`` so that special-token
    strings occurring inside the raw text (for example a literal
    "<|endoftext|>") are tokenized as ordinary text. With its default
    arguments ``enc.encode`` rejects such input with a ``ValueError``, which
    would abort the whole preparation run on a single document.

    Args:
        text: Raw document string.

    Returns:
        list[int]: Token ids of ``text`` followed by ``enc.eot_token``.
    """
    ids = enc.encode_ordinary(text)
    # Document separator so consecutive documents stay distinguishable
    # once they are concatenated into one flat token stream.
    ids.append(enc.eot_token)
    return ids

print("Tokenizing & splitting")

# The first NUM_VAL documents feed the validation buffer; all later ones go to training.
for doc_idx, example in enumerate(tqdm(dataset, total=NUM_SAMPLES)):
    bucket = val_tokens if doc_idx < NUM_VAL else train_tokens
    bucket.extend(encode(example["text"]))

# ---- write binaries ----
def write_bin(path, tokens, chunk_size=1 << 20):
    """Write token ids to ``path`` as a flat binary file of native ``uint16`` values.

    The file is written in chunks so that only ``chunk_size`` ids are held as
    a NumPy array at any time. An empty ``tokens`` sequence produces an empty
    file instead of failing.

    Args:
        path: Destination file path; an existing file is overwritten.
        tokens: Sequence of integer token ids supporting ``len`` and slicing.
        chunk_size: Number of ids converted and written per chunk.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or if an id does not
            fit in ``uint16``. In the latter case the chunks written before
            the offending one remain in the file.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    max_id = np.iinfo(np.uint16).max
    with open(path, "wb") as fh:
        for start in range(0, len(tokens), chunk_size):
            # Convert through int64 first so out-of-range ids are detected
            # instead of being wrapped or rejected with an unclear error.
            chunk = np.asarray(tokens[start:start + chunk_size], dtype=np.int64)
            if chunk.min() < 0 or chunk.max() > max_id:
                raise ValueError(
                    f"token ids must lie in [0, {max_id}] to be stored as uint16"
                )
            fh.write(chunk.astype(np.uint16).tobytes())

# Persist each split under out_dir; the validation tokens are stored as test.bin.
for filename, tokens in (("train.bin", train_tokens), ("test.bin", val_tokens)):
    print(f"Writing {filename} ({len(tokens):,} tokens)")
    write_bin(os.path.join(out_dir, filename), tokens)

print("Dataset preparation complete.")


Loading 80,000 samples (streaming)


Resolving data files:   0%|          | 0/80 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/80 [00:00<?, ?it/s]

Tokenizing & splitting


100%|██████████| 80000/80000 [01:10<00:00, 1126.98it/s]


Writing train.bin (89,498,669 tokens)
Writing test.bin (934,461 tokens)
Dataset preparation complete.


In [11]:
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    """Layer normalization over the trailing dimension with an optional bias."""

    def __init__(self, ndim, bias=True):
        super().__init__()
        # Learnable element-wise scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(ndim))
        # Learnable element-wise shift, initialised to zeros; None when disabled.
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        """Normalize ``x`` over its last ``ndim`` features using eps = 1e-5."""
        return F.layer_norm(
            x,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )

# RoPE
def rotate_half(x):
    """Swap the two halves of the last dimension and negate the moved-up half.

    With ``x`` split at ``x.shape[-1] // 2`` into ``[front, back]``, the
    result is ``[-back, front]`` concatenated along the last dimension.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)

def apply_rope(q, k, cos, sin):
    """Apply the rotary transform to the query and key tensors.

    Each tensor ``t`` is mapped to ``t * cos + rotate_half(t) * sin``; the
    same ``cos`` / ``sin`` tables are used for both inputs.

    Returns:
        tuple: ``(q_rotated, k_rotated)``.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)

class RotaryEmbedding:
    """Precomputed cosine / sine tables for rotary position embeddings.

    The tables have shape ``(1, 1, max_seq_len, dim)`` and are exposed as the
    attributes ``cos`` and ``sin``. They are built on the CPU and moved to the
    requesting device the first time ``get`` is called for it; the moved
    tensors are kept, so later calls for the same device do not copy again.

    Args:
        dim: Size of the rotated feature dimension; must be even.
        max_seq_len: Number of positions to precompute.
        base: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, max_seq_len, base=10000):
        if dim % 2 != 0:
            # An odd dim would yield tables one column wider than the tensors
            # they are multiplied with.
            raise ValueError(f"RotaryEmbedding requires an even dim, got {dim}")
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register(inv_freq, max_seq_len)

    def register(self, inv_freq, max_seq_len):
        """(Re)build the ``cos`` / ``sin`` tables for ``max_seq_len`` positions."""
        # Positions share dtype and device with inv_freq so the product never
        # relies on implicit int/float promotion.
        t = torch.arange(max_seq_len, dtype=inv_freq.dtype, device=inv_freq.device)
        freqs = torch.outer(t, inv_freq)
        # Duplicate the angles so both halves handled by rotate_half see the
        # same frequency.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos = emb.cos()[None, None, :, :]
        self.sin = emb.sin()[None, None, :, :]

    def get(self, seq_len, device):
        """Return the tables for the first ``seq_len`` positions on ``device``.

        Args:
            seq_len: Number of leading positions to return.
            device: Target device (string or ``torch.device``).

        Returns:
            tuple: ``(cos, sin)``, each of shape ``(1, 1, seq_len, dim)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of precomputed positions.
        """
        available = self.cos.size(2)
        if seq_len > available:
            raise ValueError(
                f"seq_len {seq_len} exceeds the {available} precomputed positions"
            )
        # .to() returns the tensor itself when it already lives on `device`,
        # so storing the result back makes the transfer a one-time cost.
        self.cos = self.cos.to(device)
        self.sin = self.sin.to(device)
        return (
            self.cos[:, :, :seq_len, :],
            self.sin[:, :, :seq_len, :],
        )


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embed % config.n_head == 0

        self.n_head = config.n_head
        self.head_dim = config.n_embed // config.n_head
        self.dropout = config.dropout

        # One linear layer yields queries, keys and values (split in forward).
        self.qkv = nn.Linear(config.n_embed, 3 * config.n_embed, bias=config.bias)
        # Output projection applied after the heads are merged back together.
        self.proj = nn.Linear(config.n_embed, config.n_embed, bias=config.bias)

        # Rotation tables sized per head and for the full context length.
        self.rope = RotaryEmbedding(self.head_dim, config.block_size)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C) and return a tensor of the same shape."""
        batch, seq_len, width = x.size()

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.qkv(x).chunk(3, dim=-1))

        # Positions are encoded by rotating queries and keys; values are left as is.
        q, k = apply_rope(q, k, *self.rope.get(seq_len, x.device))

        # Attention dropout is only active in training mode.
        attn_dropout = self.dropout if self.training else 0.0
        heads = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=attn_dropout,
            is_causal=True,
        )

        # (B, n_head, T, head_dim) -> (B, T, C)
        merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.proj(merged)


class MLP(nn.Module):
    """Feed-forward sub-layer: expand to 4x width, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embed
        hidden = 4 * width
        self.fc = nn.Linear(width, hidden, bias=config.bias)
        self.proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply fc -> GELU -> proj -> dropout to ``x``."""
        return self.dropout(self.proj(F.gelu(self.fc(x))))

class Block(nn.Module):
    """Transformer block: normalize-then-attend and normalize-then-MLP, each added residually."""

    def __init__(self, config):
        super().__init__()
        width, use_bias = config.n_embed, config.bias
        self.ln1 = LayerNorm(width, use_bias)
        self.attn = MultiHeadAttention(config)
        self.ln2 = LayerNorm(width, use_bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run the attention and MLP sub-layers with residual connections."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))

@dataclass
class SLMconfig:
    """Architecture hyperparameters consumed by the SLM model and its layers."""

    vocab_size: int = 50_257  # rows of the token embedding / columns of the logits
    block_size: int = 64      # longest sequence the model's forward accepts
    n_layer: int = 4          # number of stacked transformer blocks
    n_head: int = 6           # attention heads per block; must divide n_embed
    n_embed: int = 384        # embedding width
    dropout: float = 0.1      # dropout probability
    bias: bool = False        # bias flag passed to the blocks' Linear and LayerNorm layers

class SLM(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings.

    Args:
        config: ``SLMconfig`` holding the architecture hyperparameters.
    """

    def __init__(self, config: SLMconfig):
        super().__init__()
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.n_embed)
        self.drop = nn.Dropout(config.dropout)

        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embed, config.bias)

        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)
        self.token_emb.weight = self.lm_head.weight  # weight tying

        self.apply(self._init_weights)

        print(f"Number of parameters: {self.num_params()/1e6:.2f}M")

    def _init_weights(self, module):
        """Initialise Linear / Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def num_params(self):
        """Return the total number of elements across ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: Integer token ids of shape (B, T) with ``T <= block_size``.
            targets: Optional token ids aligned with ``idx``; positions equal
                to -1 are ignored by the loss.

        Returns:
            tuple: ``(logits, loss)`` where ``logits`` has shape
            (B, T, vocab_size) and ``loss`` is the cross-entropy over all
            positions, or ``None`` when ``targets`` is ``None``.
        """
        B, T = idx.shape
        assert T <= self.config.block_size, (
            f"sequence length {T} exceeds block_size {self.config.block_size}"
        )

        x = self.token_emb(idx)
        x = self.drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1
            )

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        The model is switched to eval mode while sampling (so dropout is
        inactive) and its previous train/eval mode is restored afterwards.

        Args:
            idx: Integer token ids of shape (B, T) used as the prompt.
            max_new_tokens: Number of tokens to append.
            temperature: Positive divisor applied to the logits before softmax.
            top_k: If given, sampling is restricted to the ``top_k`` most
                likely tokens; values above the vocabulary size are clamped.

        Returns:
            Tensor of shape (B, T + max_new_tokens) containing prompt and samples.

        Raises:
            ValueError: If ``temperature`` is not positive or ``top_k`` is
                smaller than 1.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the most recent block_size tokens fit into the model.
                idx_cond = idx[:, -self.config.block_size :]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / temperature

                if top_k is not None:
                    # torch.topk rejects k larger than the dimension size.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float("Inf")

                probs = F.softmax(logits, dim=-1)
                next_idx = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_idx), dim=1)
        finally:
            self.train(was_training)

        return idx


In [None]:
import os
import math
import time
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path

# =========================
# DEVICE
# =========================
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# =========================
# TRAINING HYPERPARAMS
# =========================
batch_size = 8              # sequences per micro-batch
block_size = 256            # tokens per sequence
gradient_accum_steps = 8    # micro-batches accumulated per optimizer step
max_iters = 20_000          # total optimizer steps
eval_interval = 1_000       # steps between evaluations (a checkpoint follows each one)
learning_rate = 3e-4        # peak learning rate of the schedule
eval_iters = 200            # batches averaged per split in each evaluation

# =========================
# CHECKPOINT CONFIG
# =========================
# Checkpoints are written below the mounted Google Drive folder.
from google.colab import drive
drive.mount("/content/drive")

CKPT_ROOT = Path("/content/drive/MyDrive/slm_checkpoints")
CKPT_ROOT.mkdir(parents=True, exist_ok=True)

SAVE_EVERY = eval_interval  # same value as eval_interval
MAX_KEEP = 3                # number of newest checkpoints kept by the cleanup

# =========================
# DATA LOADING
# =========================
def load_data(split, data_dir="/content/data"):
    """Open a tokenized split as a read-only ``uint16`` memory map.

    Args:
        split: Base name of the file to open; ``"<split>.bin"`` is read.
        data_dir: Directory containing the ``.bin`` files. Defaults to the
            location this notebook has always read from.

    Returns:
        numpy.memmap: One-dimensional read-only array of ``uint16`` token ids.
    """
    path = os.path.join(data_dir, f"{split}.bin")
    return np.memmap(path, dtype=np.uint16, mode="r")

# Memory-mapped token streams; the validation split is read from test.bin.
train_data, val_data = load_data("train"), load_data("test")

def get_batch(data):
    """Sample a random batch of input / target windows from a token array.

    ``batch_size`` start offsets are drawn uniformly; each input is the
    ``block_size`` tokens starting at its offset and the matching target is
    the same window shifted one token to the right.

    Args:
        data: One-dimensional array of token ids, longer than ``block_size``.

    Returns:
        tuple: ``(x, y)`` int64 tensors of shape (batch_size, block_size) on ``device``.
    """
    starts = torch.randint(len(data) - block_size, (batch_size,))

    def window(begin):
        # int64 copy of block_size consecutive tokens beginning at `begin`
        return torch.from_numpy(data[begin:begin + block_size].astype(np.int64))

    inputs = torch.stack([window(s) for s in starts])
    targets = torch.stack([window(s + 1) for s in starts])
    return inputs.to(device), targets.to(device)

# =========================
# MODEL
# =========================
# Architecture hyperparameters; the context length follows the training block_size.
config = SLMconfig(vocab_size=50257, block_size=block_size, n_layer=4, n_head=6, n_embed=384, dropout=0.1)

# Build the model, then move its parameters to the selected device.
model = SLM(config)
model = model.to(device)

# AdamW over every model parameter; the learning rate is overwritten each step by the schedule.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.06)

# =========================
# LR SCHEDULE
# =========================
warmup_steps = int(0.05 * max_iters)   # first 5% of the run ramps the rate up
min_lr = learning_rate * 0.1           # floor reached at the end of the cosine decay

def get_lr(step):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    During the first ``warmup_steps`` steps the rate grows linearly from 0
    towards ``learning_rate``; afterwards it follows a half cosine from
    ``learning_rate`` down to ``min_lr`` at ``max_iters``.
    """
    in_warmup = step < warmup_steps
    if in_warmup:
        return learning_rate * step / warmup_steps

    decay_span = max_iters - warmup_steps
    progress = (step - warmup_steps) / decay_span
    cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
    return min_lr + cosine_decay * (learning_rate - min_lr)

# =========================
# EVAL
# =========================
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the train and validation data.

    For each split, ``eval_iters`` random batches are evaluated with the
    model in eval mode and their losses averaged. The model is put back into
    train mode before returning.

    Returns:
        dict: ``{"train": float, "val": float}`` mean losses.
    """
    model.eval()
    out = {}
    for split, data in (("train", train_data), ("val", val_data)):
        batch_losses = [model(*get_batch(data))[1].item() for _ in range(eval_iters)]
        out[split] = sum(batch_losses) / len(batch_losses)
    model.train()
    return out

# =========================
# ATOMIC SAVE
# =========================
def atomic_save(obj, path):
    """Serialize ``obj`` to ``path`` so readers never see a partial file.

    The object is written to a sibling file with a ``.tmp`` suffix, flushed
    and fsynced, and only then renamed over ``path``. If anything fails
    before the rename, the temporary file is removed and any previous file at
    ``path`` is left untouched.

    Args:
        obj: Object to serialize with ``torch.save``.
        path: ``pathlib.Path`` of the destination file.

    Raises:
        Exception: Whatever ``torch.save`` or the file operations raise; the
            error is re-raised after the temporary file has been cleaned up.
    """
    tmp = path.with_suffix(".tmp")
    try:
        with open(tmp, "wb") as fh:
            torch.save(obj, fh, _use_new_zipfile_serialization=False)
            # Push the bytes to stable storage before the rename makes them visible.
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Best-effort cleanup; the original error is what the caller needs to see.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise

def save_checkpoint(step):
    """Write a checkpoint for ``step`` and prune all but the newest ``MAX_KEEP``.

    A directory ``step_<step>`` under ``CKPT_ROOT`` receives the model
    weights, the optimizer state and a metadata file holding the step number
    and the CPU / CUDA RNG states, written in that order.
    """
    target = CKPT_ROOT / f"step_{step}"
    target.mkdir(exist_ok=True)

    atomic_save(model.state_dict(), target / "model.pt")
    atomic_save(optimizer.state_dict(), target / "optimizer.pt")

    if torch.cuda.is_available():
        cuda_rng = torch.cuda.get_rng_state_all()
    else:
        cuda_rng = None
    meta = {
        "step": step,
        "torch_rng": torch.get_rng_state(),
        "cuda_rng": cuda_rng,
    }
    atomic_save(meta, target / "meta.pt")

    # cleanup old
    def step_of(folder):
        return int(folder.name.split("_")[1])

    existing = sorted(CKPT_ROOT.glob("step_*"), key=step_of)
    for stale in existing[:-MAX_KEEP]:
        for entry in stale.iterdir():
            entry.unlink()
        stale.rmdir()

    print(f"[✓] checkpoint saved at step {step}")

# =========================
# RESUME (AUTO)
# =========================
def try_resume():
    """Restore training state from the newest complete checkpoint, if any.

    Checkpoint directories under ``CKPT_ROOT`` are tried from the highest
    step downwards. A directory is used only when ``model.pt``,
    ``optimizer.pt`` and ``meta.pt`` are all present; ``save_checkpoint``
    writes ``meta.pt`` last, so a run interrupted while saving leaves a
    directory that is skipped here instead of breaking the resume.
    Entries whose name does not end in a plain integer are ignored.

    Side effects:
        Loads the weights into ``model``, the state into ``optimizer`` and
        restores the CPU (and, when available, CUDA) RNG states.

    Returns:
        int: Step to continue from: 0 when no usable checkpoint exists,
        otherwise the saved step plus one.
    """
    required = ("model.pt", "optimizer.pt", "meta.pt")

    def step_of(folder):
        # "step_1000" -> 1000; None when the part after "step_" is not a plain integer.
        suffix = folder.name.split("_", 1)[1]
        return int(suffix) if suffix.isdigit() else None

    candidates = [
        d for d in CKPT_ROOT.glob("step_*")
        if d.is_dir() and step_of(d) is not None
    ]
    candidates.sort(key=step_of, reverse=True)

    for latest in candidates:
        missing = [name for name in required if not (latest / name).is_file()]
        if missing:
            print(f"[!] skipping incomplete checkpoint {latest.name} (missing {', '.join(missing)})")
            continue

        print(f"[↻] resuming from {latest.name}")

        model.load_state_dict(torch.load(latest / "model.pt", map_location=device))
        optimizer.load_state_dict(torch.load(latest / "optimizer.pt", map_location=device))

        # meta.pt holds plain Python objects next to tensors, hence weights_only=False.
        meta = torch.load(latest / "meta.pt", map_location="cpu", weights_only=False)
        torch.set_rng_state(meta["torch_rng"])
        if meta["cuda_rng"] is not None and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(meta["cuda_rng"])

        # The checkpoint was taken after this step's update, so continue with the next one.
        return meta["step"] + 1

    return 0

# =========================
# TRAIN LOOP
# =========================
print("Starting training...")
start_step = try_resume()
t0 = time.time()  # wall-clock start of this run

# Make sure dropout is active even if an earlier, interrupted evaluation
# left the model in eval mode.
model.train()

for step in range(start_step, max_iters):

    # Apply the scheduled learning rate to every parameter group.
    lr = get_lr(step)
    for pg in optimizer.param_groups:
        pg["lr"] = lr

    optimizer.zero_grad(set_to_none=True)

    # Accumulate gradients over several micro-batches; dividing each loss
    # keeps the summed gradient equal to the mean over the micro-batches.
    for _ in range(gradient_accum_steps):
        X, Y = get_batch(train_data)
        _, loss = model(X, Y)
        (loss / gradient_accum_steps).backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    # Evaluate and checkpoint on the regular interval and also after the very
    # last step, so the final weights are never lost.
    is_last_step = step == max_iters - 1
    if step % eval_interval == 0 or is_last_step:
        losses = estimate_loss()
        print(
            f"step {step} | "
            f"train {losses['train']:.4f} | "
            f"val {losses['val']:.4f}"
        )
        save_checkpoint(step)

print("Training complete.")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Number of parameters: 26.38M
Starting training...
step 0 | train 10.9084 | val 10.9070
[✓] checkpoint saved at step 0


Before rerunning training from scratch, delete this checkpoint folder using:

In [None]:
# !rm -rf /content/drive/MyDrive/slm_checkpoints/*