# 2.6B transformer based on pile uncopyrighted streamed via huggingface datasets

In [1]:
%pip install tiktoken tqdm datasets tiktoken zstandard "fsspec[compression]"

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


# RESTART THE KERNEL IF RUNNING ON RUNPOD AFTER THIS PIP INSTALL

In [2]:
# -------- MODEL PARAMS --------
n_layers    = 48
n_embd      = 2048
n_heads     = 16
context_len = 1024
batch_size  = 8    # with grad accumulation
dropout     = 0
lr          = 3e-6
from tiktoken import get_encoding
tokenizer = get_encoding("gpt2")
vocab_size  = tokenizer.n_vocab

# ------------------------------

# *fuck this trick i suspended this trick*

In [22]:
import math

def get_lr(step, max_steps, base_lr=lr, warmup_steps=2000):
    """
    Warmup + cosine decay learning rate schedule.

    Args:
        step (int): current training step
        max_steps (int): total training steps
        base_lr (float): peak learning rate
        warmup_steps (int): number of warmup steps

    Returns:
        float: learning rate for this step
    """
    return base_lr


In [4]:
import torch
from torch import nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class TokenEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.token = nn.Embedding(vocab_size, n_embd)
        self.pos   = nn.Embedding(context_len, n_embd)

    def forward(self, x):
        B, T = x.shape
        tok = self.token(x)                    # (B, T, C)
        pos = self.pos(torch.arange(T))        # (T, C)
        return tok + pos
emb = TokenEmbedding()



# Above is the definition of our embeddings

# Now define attention

In [6]:
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_heads, dropout=0.0):
        super().__init__()
        assert n_embd % n_heads == 0

        self.n_heads = n_heads
        self.head_dim = n_embd // n_heads

        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.dropout = dropout

    def forward(self, x):
        B, T, C = x.shape

        qkv = self.qkv(x)                       # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # this is the critical line
        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0,
        )

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)


# Constants again and positional embeddings


In [7]:
class TokenPosEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.token = nn.Embedding(vocab_size, n_embd)
        self.pos   = nn.Embedding(context_len, n_embd)

    def forward(self, x):
        B, T = x.shape
        tok = self.token(x)                                # (B,T,C)
        pos = self.pos(torch.arange(T, device=x.device))  # (T,C)
        return tok + pos


# After attention we have basic FFN

In [8]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)


In [9]:
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

        self.attn = CausalSelfAttention(n_embd, n_heads, dropout)
        self.ff   = FeedForward()

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


In [10]:
class TransformerLM(nn.Module):
    def __init__(self):
        super().__init__()

        # token + positional embeddings
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb   = nn.Embedding(context_len, n_embd)

        self.blocks = nn.Sequential(
            *[Block() for _ in range(n_layers)]
        )

        # final normalization + LM head
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, x, targets=None):
        B, T = x.shape

        # embeddings
        tok = self.token_emb(x)                              # (B,T,C)
        pos = self.pos_emb(torch.arange(T, device=x.device))# (T,C)
        x = tok + pos                                        # (B,T,C)

        # APPLY ALL BLOCKS HERE
        x = self.blocks(x)

        # final projection
        x = self.ln_f(x)
        logits = self.head(x)                                # (B,T,vocab)

        loss = None
        if targets is not None:
            logits = logits.view(B*T, vocab_size)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss



In [11]:
model = TransformerLM().to(device)


In [12]:
print(sum(p.numel() for p in model.parameters()) / 1e6, "M params")


2624.808017 M params


# optimus

In [13]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr
)


In [14]:
@torch.no_grad()
def estimate_loss_from_iters(model, train_batch_iter, val_batch_iter,
                             eval_batches=20, device="cuda"):
    model.eval()
    out = {}

    for name, it in [("train", train_batch_iter), ("val", val_batch_iter)]:
        losses = []
        for _ in range(eval_batches):
            xb, yb = next(it)
            # already on device in this pipeline, but keep safe:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            _, loss = model(xb, yb)
            losses.append(loss.item())
        out[name] = sum(losses) / len(losses)

    model.train()
    return out


In [15]:
def token_stream(hf_dataset, tokenizer):
    """
    Lazily yields lists of token ids from a streaming HF dataset.
    """
    for ex in hf_dataset:
        text = ex.get("text", "")
        if not text:
            continue
        ids = tokenizer.encode(text)
        if len(ids) > 1:
            yield ids
import random
import torch

def window_stream(token_iter, context_len, device):
    """
    Yields single (x, y) training samples of shape [context_len].
    """
    buffer = []

    for ids in token_iter:
        buffer.extend(ids)

        while len(buffer) >= context_len + 1:
            start = random.randint(0, len(buffer) - context_len - 1)

            x = buffer[start : start + context_len]
            y = buffer[start + 1 : start + context_len + 1]

            yield (
                torch.tensor(x, dtype=torch.long, device=device),
                torch.tensor(y, dtype=torch.long, device=device),
            )
def batch_stream(sample_iter, batch_size):
    """
    Groups single samples into batches.
    """
    while True:
        xb, yb = zip(*(next(sample_iter) for _ in range(batch_size)))
        yield torch.stack(xb), torch.stack(yb)
from datasets import load_dataset

# load streaming dataset
ds = load_dataset(
    "monology/pile-uncopyrighted",
    split="train",
    streaming=True,
)

# tokenizer
tokenizer = get_encoding("gpt2")

# build streams
tok_iter   = token_stream(ds, tokenizer)
sample_iter = window_stream(tok_iter, context_len, device)
batch_iter  = batch_stream(sample_iter, batch_size)


README.md:   0%|          | 0.00/776 [00:00<?, ?B/s]



Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [18]:
import os
os.environ["HF_TOKEN"] = "hf_my_token"


In [19]:
from huggingface_hub import HfApi, upload_file
import os
import torch
from tqdm.notebook import tqdm
import torch.nn.functional as F

HF_REPO = "345rf4gt56t4r3e3/nnn"
CKPT_DIR = "checkpoints"

os.makedirs(CKPT_DIR, exist_ok=True)

api = HfApi()

# create repo if it doesn't exist
api.create_repo(
    repo_id=HF_REPO,
    exist_ok=True,
    repo_type="model",

)


RepoUrl('https://huggingface.co/345rf4gt56t4r3e3/nnn', endpoint='https://huggingface.co', repo_type='model', repo_id='345rf4gt56t4r3e3/nnn')

In [20]:
import os, math, random, hashlib, time
from collections import deque
import torch
import torch.nn.functional as F
from datasets import load_dataset
from itertools import islice

# ---------------- SPLIT CONFIG ----------------
SEED = 1337
TRAIN_EXAMPLES = 2_000_000
VAL_EXAMPLES   = 10_000
SHUFFLE_BUFFER = 50_000
# -------------- STREAM CONFIG -----------------
STRIDE = context_len
MAX_BUFFER_TOKENS = 2_000_000
# ------------------------------------------------

def _drain_with_progress(it, n, label="[skip]"):
    """
    Drain exactly n items from iterator it, printing progress + ETA.
    Returns the SAME iterator positioned after draining.
    """
    t0 = time.time()
    last_t = t0
    last_i = 0

    # We can’t “seek” in streaming; we must consume.
    for i in range(1, n + 1):
        try:
            next(it)
        except StopIteration:
            print(f"{label} iterator ended early at {i-1}/{n}")
            return it

        # Print often at start, then every 10k
        if i <= 10 or i % 10_000 == 0:
            now = time.time()
            elapsed = now - t0
            dt = now - last_t
            di = i - last_i
            rate = (di / dt) if dt > 0 else 0.0

            # overall avg rate + ETA
            avg_rate = (i / elapsed) if elapsed > 0 else 0.0
            remaining = n - i
            eta = (remaining / avg_rate) if avg_rate > 0 else float("inf")

            print(
                f"{label} {i:,}/{n:,} "
                f"| inst={rate:,.1f} ex/s avg={avg_rate:,.1f} ex/s "
                f"| elapsed={elapsed/60:.1f}m eta={eta/60:.1f}m"
            )

            last_t = now
            last_i = i

    return it

def build_streaming_splits(dataset_name="monology/pile-uncopyrighted",
                           split="train",
                           train_examples=TRAIN_EXAMPLES,
                           val_examples=VAL_EXAMPLES,
                           seed=SEED,
                           shuffle_buffer=SHUFFLE_BUFFER):
    print("[build_streaming_splits] loading dataset...")
    t0 = time.time()

    ds = load_dataset(dataset_name, split=split, streaming=True)

    print(f"[build_streaming_splits] dataset object ready in {time.time()-t0:.2f}s")

    # IMPORTANT:
    # We MUST create two independent streaming iterators if we want to both:
    # - train on first N (shuffled)
    # - val on the next M after skipping N
    # But streaming datasets are single-pass iterables; calling .take() and .skip()
    # on the same ds object can still be okay lazily, but progress during skip is invisible.
    #
    # Here: we explicitly build train stream from ds, and build val stream by creating
    # a fresh streaming dataset iterator, draining N with progress, then taking M.

    print("[build_streaming_splits] building train stream...")
    t1 = time.time()
    train_ds = ds.take(train_examples).shuffle(seed=seed, buffer_size=shuffle_buffer)
    print(f"[build_streaming_splits] train stream ready in {time.time()-t1:.2f}s")

    print("[build_streaming_splits] building val stream (skip with progress)...")
    t2 = time.time()

    # Fresh iterator for val path (so we don't consume the train path)
    ds_val_base = load_dataset(dataset_name, split=split, streaming=True)
    it = iter(ds_val_base)

    # Drain exactly TRAIN_EXAMPLES with visible progress
    it = _drain_with_progress(it, train_examples, label="[val.skip]")

    # Now take val_examples from remaining stream
    # Wrap back into a simple iterable of dict examples
    def _val_iter():
        for ex in islice(it, val_examples):
            yield ex

    val_ds = _val_iter()

    print(f"[build_streaming_splits] val stream ready in {time.time()-t2:.2f}s")

    return train_ds, val_ds

def token_stream(hf_dataset, tokenizer):
    doc_count = 0
    t0 = time.time()

    for ex in hf_dataset:
        doc_count += 1
        if doc_count % 10_000 == 0:
            elapsed = time.time() - t0
            print(f"[token_stream] {doc_count} docs in {elapsed:.1f}s ({doc_count/elapsed:.1f} docs/s)")

        text = ex.get("text", "")
        if not text:
            continue

        ids = tokenizer.encode(text)
        if len(ids) >= 2:
            yield ids

def window_stream(token_iter, context_len, *, device, stride=STRIDE, max_buffer_tokens=MAX_BUFFER_TOKENS):
    buf = deque()
    total_tokens = 0
    yielded = 0
    t0 = time.time()

    for ids in token_iter:
        buf.extend(ids)
        total_tokens += len(ids)

        if total_tokens % 100_000 == 0:
            elapsed = time.time() - t0
            print(
                f"[window_stream] tokens buffered={len(buf)} "
                f"total_seen={total_tokens} "
                f"rate={total_tokens/elapsed:.0f} tok/s "
                f"samples_yielded={yielded}"
            )

        while len(buf) > max_buffer_tokens:
            buf.popleft()

        while len(buf) >= context_len + 1:
            x = [buf[i] for i in range(context_len)]
            y = [buf[i+1] for i in range(context_len)]

            for _ in range(stride):
                if not buf:
                    break
                buf.popleft()

            yielded += 1
            if yielded % 100 == 0:
                elapsed = time.time() - t0
                print(
                    f"[window_stream] yielded={yielded} "
                    f"avg_time_per_sample={elapsed/yielded:.3f}s"
                )

            yield (
                torch.tensor(x, dtype=torch.long, device=device),
                torch.tensor(y, dtype=torch.long, device=device),
            )

def batch_stream(sample_iter, batch_size):
    batch_count = 0
    t0 = time.time()

    while True:
        batch_count += 1
        tb0 = time.time()
        xb, yb = zip(*(next(sample_iter) for _ in range(batch_size)))
        tb1 = time.time()

        if batch_count <= 5 or batch_count % 10 == 0:
            elapsed = time.time() - t0
            print(
                f"[batch_stream] batch={batch_count} "
                f"batch_time={tb1-tb0:.2f}s "
                f"avg_batch_time={elapsed/batch_count:.2f}s"
            )

        yield torch.stack(xb, 0), torch.stack(yb, 0)

@torch.no_grad()
def sanity_check_batch(xb, yb, vocab_size, context_len):
    assert xb.ndim == 2 and yb.ndim == 2
    B, T = xb.shape
    assert yb.shape == (B, T)
    assert T == context_len

    x_min, x_max = int(xb.min()), int(xb.max())
    y_min, y_max = int(yb.min()), int(yb.max())
    assert 0 <= x_min and x_max < vocab_size
    assert 0 <= y_min and y_max < vocab_size

    assert torch.equal(yb[:, :-1], xb[:, 1:])
    assert not torch.equal(xb, yb)

    return {"B": B, "T": T, "xb_range": (x_min, x_max), "yb_range": (y_min, y_max)}

@torch.no_grad()
def estimate_loss_from_iters(model, train_batch_iter, val_batch_iter, eval_batches=20, device="cuda"):
    model.eval()
    out = {}
    for name, it in [("train", train_batch_iter), ("val", val_batch_iter)]:
        losses = []
        t0 = time.time()
        for i in range(eval_batches):
            xb, yb = next(it)
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            _, loss = model(xb, yb)
            losses.append(loss.item())
        elapsed = time.time() - t0
        print(f"[estimate_loss] {name} {eval_batches} batches in {elapsed:.2f}s")
        out[name] = sum(losses) / len(losses)
    model.train()
    return out

def batch_fingerprint(xb):
    b = xb[0, :16].detach().cpu().numpy().tobytes()
    return hashlib.sha1(b).hexdigest()[:10]

# ---------------- BUILD ----------------
print("=== BUILDING STREAMS ===")
train_ds, val_ds = build_streaming_splits()

print("=== BUILDING TOKEN STREAMS ===")
train_tok = token_stream(train_ds, tokenizer)
val_tok   = token_stream(val_ds, tokenizer)

print("=== BUILDING WINDOW STREAMS ===")
train_samples = window_stream(train_tok, context_len, device=device)
val_samples   = window_stream(val_tok,   context_len, device=device)

print("=== BUILDING BATCH STREAMS ===")
train_batch_iter = batch_stream(train_samples, batch_size)
val_batch_iter   = batch_stream(val_samples,   batch_size)

# ---------------- SANITY ----------------
print("=== SANITY CHECK: TRAIN ===")
t0 = time.time()
xb, yb = next(train_batch_iter)
print(f"[sanity] train batch fetched in {time.time()-t0:.2f}s")
print("train batch sanity:", sanity_check_batch(xb, yb, vocab_size, context_len), "fp", batch_fingerprint(xb))

print("=== SANITY CHECK: VAL ===")
t1 = time.time()
xb, yb = next(val_batch_iter)
print(f"[sanity] val batch fetched in {time.time()-t1:.2f}s")
print("val batch sanity:", sanity_check_batch(xb, yb, vocab_size, context_len), "fp", batch_fingerprint(xb))


=== BUILDING STREAMS ===
[build_streaming_splits] loading dataset...


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

[build_streaming_splits] dataset object ready in 1.46s
[build_streaming_splits] building train stream...
[build_streaming_splits] train stream ready in 0.00s
[build_streaming_splits] building val stream (skip with progress)...


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

[val.skip] 1/2,000,000 | inst=1.2 ex/s avg=1.2 ex/s | elapsed=0.0m eta=27371.0m
[val.skip] 2/2,000,000 | inst=14,218.0 ex/s avg=2.4 ex/s | elapsed=0.0m eta=13686.7m
[val.skip] 3/2,000,000 | inst=80,659.7 ex/s avg=3.7 ex/s | elapsed=0.0m eta=9124.6m
[val.skip] 4/2,000,000 | inst=127,100.1 ex/s avg=4.9 ex/s | elapsed=0.0m eta=6843.5m
[val.skip] 5/2,000,000 | inst=76,260.1 ex/s avg=6.1 ex/s | elapsed=0.0m eta=5474.9m
[val.skip] 6/2,000,000 | inst=131,072.0 ex/s avg=7.3 ex/s | elapsed=0.0m eta=4562.4m
[val.skip] 7/2,000,000 | inst=119,837.3 ex/s avg=8.5 ex/s | elapsed=0.0m eta=3910.7m
[val.skip] 8/2,000,000 | inst=87,381.3 ex/s avg=9.7 ex/s | elapsed=0.0m eta=3421.9m
[val.skip] 9/2,000,000 | inst=116,508.4 ex/s avg=11.0 ex/s | elapsed=0.0m eta=3041.7m
[val.skip] 10/2,000,000 | inst=135,300.1 ex/s avg=12.2 ex/s | elapsed=0.0m eta=2737.6m
[val.skip] 10,000/2,000,000 | inst=9,068.8 ex/s avg=5,200.6 ex/s | elapsed=0.0m eta=6.4m
[val.skip] 20,000/2,000,000 | inst=9,133.5 ex/s avg=6,627.5 ex/s |

In [23]:
from tqdm.notebook import tqdm
from huggingface_hub import upload_file

# ---------- CONFIG ----------
CKPT_DIR   = "weights"
SAVE_EVERY = 700
VAL_EVERY  = 200          # evaluate train/val loss every N steps
EVAL_BATCHES = 10         # how many batches to average for eval
MAX_STEPS  = 9000
GRAD_CLIP  = 1.0          # helps stability
# ---------------------------

os.makedirs(CKPT_DIR, exist_ok=True)

scaler = torch.amp.GradScaler("cuda") if device.type == "cuda" else None

history = []
model.train()

for step in tqdm(range(MAX_STEPS)):
    # ---- LR schedule ----
    lr_now = get_lr(step, MAX_STEPS)
    for g in optimizer.param_groups:
        g["lr"] = lr_now

    # ---- get batch ----
    xb, yb = next(train_batch_iter)  # ✅ train stream only
    xb = xb.to(device, non_blocking=True)
    yb = yb.to(device, non_blocking=True)

    # optional cheap loop health check
    if step % 500 == 0:
        print("fp", batch_fingerprint(xb))

    # ---- forward ----
    if device.type == "cuda":
        with torch.amp.autocast("cuda"):
            _, loss = model(xb, yb)
    else:
        _, loss = model(xb, yb)

    # ---- safety: NaN/Inf loss ----
    if not torch.isfinite(loss):
        print(f"[FATAL] non-finite loss at step {step}: {loss.item()}")
        # dump a debug batch
        torch.save({"xb": xb.detach().cpu(), "yb": yb.detach().cpu()}, os.path.join(CKPT_DIR, f"bad_batch_step_{step}.pt"))
        break

    optimizer.zero_grad(set_to_none=True)

    # ---- backward (AMP if cuda) ----
    if device.type == "cuda":
        scaler.scale(loss).backward()
        # unscale -> clip -> step
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()

    # ---- logging ----
    loss_val = loss.item()
    history.append(loss_val)

    if step % 10 == 0:
        print(f"step {step:5d} | lr {lr_now:.3e} | train_loss {loss_val:.4f}")

    # ---- periodic eval (REAL val) ----
    if step > 0 and step % VAL_EVERY == 0:
        stats = estimate_loss_from_iters(
            model,
            train_batch_iter,
            val_batch_iter,
            eval_batches=EVAL_BATCHES,
            device=device.type,
        )
        print(f"[eval] step {step:5d} | train {stats['train']:.4f} | val {stats['val']:.4f}")

    # ---- SAVE *WEIGHTS ONLY* (SAFE) ----
    if step > 0 and step % SAVE_EVERY == 0:
        fname = f"model_step_{step}.pt"
        fpath = os.path.join(CKPT_DIR, fname)

        # 🔒 SAFE FP16 EXPORT (does NOT touch training model)
        state_fp16 = {k: v.detach().half().cpu() for k, v in model.state_dict().items()}
        torch.save(state_fp16, fpath)
        print(f"[saved FP16 weights → {fpath}]")

        # upload versioned snapshot
        upload_file(
            path_or_fileobj=fpath,
            path_in_repo=fname,
            repo_id=HF_REPO,
            repo_type="model",
            commit_message=f"weights @ step {step}",
        )

        # update rolling pointer
        upload_file(
            path_or_fileobj=fpath,
            path_in_repo="latest.pt",
            repo_id=HF_REPO,
            repo_type="model",
            commit_message=f"update latest @ step {step}",
        )

        print(f"[uploaded → hf://{HF_REPO}/{fname}]")


  0%|          | 0/9000 [00:00<?, ?it/s]

[batch_stream] batch=240 batch_time=0.01s avg_batch_time=0.56s
fp ef30f40788
step     0 | lr 3.000e-06 | train_loss 10.0673
[window_stream] yielded=2000 avg_time_per_sample=0.069s
[batch_stream] batch=250 batch_time=0.01s avg_batch_time=0.56s
step    10 | lr 3.000e-06 | train_loss 9.1881
[batch_stream] batch=260 batch_time=0.01s avg_batch_time=0.55s
step    20 | lr 3.000e-06 | train_loss 9.1061
[window_stream] yielded=2100 avg_time_per_sample=0.069s
[batch_stream] batch=270 batch_time=0.00s avg_batch_time=0.54s
step    30 | lr 3.000e-06 | train_loss 9.5590
[window_stream] yielded=2200 avg_time_per_sample=0.068s
[batch_stream] batch=280 batch_time=0.00s avg_batch_time=0.54s
step    40 | lr 3.000e-06 | train_loss 8.3156
[window_stream] yielded=2300 avg_time_per_sample=0.067s
[batch_stream] batch=290 batch_time=0.00s avg_batch_time=0.54s
step    50 | lr 3.000e-06 | train_loss 6.9631
[window_stream] yielded=2400 avg_time_per_sample=0.066s
[batch_stream] batch=300 batch_time=0.01s avg_batch

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded → hf://345rf4gt56t4r3e3/nnn/model_step_700.pt]
[window_stream] yielded=7800 avg_time_per_sample=0.075s
[batch_stream] batch=980 batch_time=0.01s avg_batch_time=0.60s
step   710 | lr 3.000e-06 | train_loss 6.9647
[window_stream] yielded=7900 avg_time_per_sample=0.075s
[batch_stream] batch=990 batch_time=0.00s avg_batch_time=0.60s
step   720 | lr 3.000e-06 | train_loss 5.8218
[window_stream] yielded=8000 avg_time_per_sample=0.074s
[batch_stream] batch=1000 batch_time=0.01s avg_batch_time=0.59s
step   730 | lr 3.000e-06 | train_loss 5.8928
[batch_stream] batch=1010 batch_time=0.00s avg_batch_time=0.59s
step   740 | lr 3.000e-06 | train_loss 6.1133
[window_stream] yielded=8100 avg_time_per_sample=0.074s
[batch_stream] batch=1020 batch_time=0.01s avg_batch_time=0.59s
step   750 | lr 3.000e-06 | train_loss 7.4464
[window_stream] yielded=8200 avg_time_per_sample=0.074s
[batch_stream] batch=1030 batch_time=0.02s avg_batch_time=0.59s
step   760 | lr 3.000e-06 | train_loss 7.1951
[wind

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded → hf://345rf4gt56t4r3e3/nnn/model_step_1400.pt]
[window_stream] yielded=13700 avg_time_per_sample=0.076s
[batch_stream] batch=1720 batch_time=0.01s avg_batch_time=0.61s
step  1410 | lr 3.000e-06 | train_loss 7.1621
[window_stream] yielded=13800 avg_time_per_sample=0.076s
[batch_stream] batch=1730 batch_time=0.01s avg_batch_time=0.61s
step  1420 | lr 3.000e-06 | train_loss 6.8651
[window_stream] yielded=13900 avg_time_per_sample=0.076s
[batch_stream] batch=1740 batch_time=0.01s avg_batch_time=0.61s
step  1430 | lr 3.000e-06 | train_loss 6.9542
[window_stream] yielded=14000 avg_time_per_sample=0.076s
[batch_stream] batch=1750 batch_time=0.00s avg_batch_time=0.60s
step  1440 | lr 3.000e-06 | train_loss 6.9205
[batch_stream] batch=1760 batch_time=0.00s avg_batch_time=0.60s
step  1450 | lr 3.000e-06 | train_loss 6.8813
[window_stream] yielded=14100 avg_time_per_sample=0.075s
[batch_stream] batch=1770 batch_time=0.01s avg_batch_time=0.60s
step  1460 | lr 3.000e-06 | train_loss 5.31

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded → hf://345rf4gt56t4r3e3/nnn/model_step_2100.pt]
[window_stream] yielded=19600 avg_time_per_sample=0.076s
[batch_stream] batch=2450 batch_time=0.00s avg_batch_time=0.61s
step  2110 | lr 3.000e-06 | train_loss 5.0834
[batch_stream] batch=2460 batch_time=0.03s avg_batch_time=0.61s
step  2120 | lr 3.000e-06 | train_loss 6.4235
[window_stream] yielded=19700 avg_time_per_sample=0.076s
[batch_stream] batch=2470 batch_time=0.00s avg_batch_time=0.61s
step  2130 | lr 3.000e-06 | train_loss 5.7862
[window_stream] yielded=19800 avg_time_per_sample=0.076s
[batch_stream] batch=2480 batch_time=0.00s avg_batch_time=0.60s
step  2140 | lr 3.000e-06 | train_loss 5.6431
[window_stream] yielded=19900 avg_time_per_sample=0.075s
[batch_stream] batch=2490 batch_time=0.01s avg_batch_time=0.60s
step  2150 | lr 3.000e-06 | train_loss 6.8243
[window_stream] yielded=20000 avg_time_per_sample=0.075s
[batch_stream] batch=2500 batch_time=0.01s avg_batch_time=0.60s
step  2160 | lr 3.000e-06 | train_loss 5.99

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded → hf://345rf4gt56t4r3e3/nnn/model_step_2800.pt]
[window_stream] yielded=25500 avg_time_per_sample=0.076s
[batch_stream] batch=3190 batch_time=0.00s avg_batch_time=0.61s
step  2810 | lr 3.000e-06 | train_loss 6.2171
[window_stream] yielded=25600 avg_time_per_sample=0.076s
[batch_stream] batch=3200 batch_time=0.01s avg_batch_time=0.61s
step  2820 | lr 3.000e-06 | train_loss 6.2830
[batch_stream] batch=3210 batch_time=0.00s avg_batch_time=0.61s
step  2830 | lr 3.000e-06 | train_loss 6.5752
[window_stream] yielded=25700 avg_time_per_sample=0.076s
[batch_stream] batch=3220 batch_time=0.00s avg_batch_time=0.61s
step  2840 | lr 3.000e-06 | train_loss 6.8074
[window_stream] yielded=25800 avg_time_per_sample=0.076s
[batch_stream] batch=3230 batch_time=0.01s avg_batch_time=0.61s
step  2850 | lr 3.000e-06 | train_loss 6.1810
[window_stream] yielded=25900 avg_time_per_sample=0.076s
[batch_stream] batch=3240 batch_time=0.00s avg_batch_time=0.61s
step  2860 | lr 3.000e-06 | train_loss 6.62

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded → hf://345rf4gt56t4r3e3/nnn/model_step_3500.pt]
[window_stream] yielded=31300 avg_time_per_sample=0.077s
[batch_stream] batch=3920 batch_time=0.01s avg_batch_time=0.61s
step  3510 | lr 3.000e-06 | train_loss 6.8460
[window_stream] yielded=31400 avg_time_per_sample=0.077s
[batch_stream] batch=3930 batch_time=0.01s avg_batch_time=0.61s
step  3520 | lr 3.000e-06 | train_loss 6.9041
[window_stream] yielded=31500 avg_time_per_sample=0.077s
[batch_stream] batch=3940 batch_time=0.00s avg_batch_time=0.61s
step  3530 | lr 3.000e-06 | train_loss 5.3311
[window_stream] yielded=31600 avg_time_per_sample=0.076s
[batch_stream] batch=3950 batch_time=0.00s avg_batch_time=0.61s
step  3540 | lr 3.000e-06 | train_loss 4.8917
[batch_stream] batch=3960 batch_time=0.01s avg_batch_time=0.61s
step  3550 | lr 3.000e-06 | train_loss 6.7529
[window_stream] yielded=31700 avg_time_per_sample=0.076s
[batch_stream] batch=3970 batch_time=0.00s avg_batch_time=0.61s
step  3560 | lr 3.000e-06 | train_loss 7.00

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded → hf://345rf4gt56t4r3e3/nnn/model_step_4200.pt]
[batch_stream] batch=4660 batch_time=0.01s avg_batch_time=0.61s
step  4210 | lr 3.000e-06 | train_loss 5.8252
[window_stream] yielded=37300 avg_time_per_sample=0.077s
[batch_stream] batch=4670 batch_time=0.00s avg_batch_time=0.61s
step  4220 | lr 3.000e-06 | train_loss 5.5544
[window_stream] yielded=37400 avg_time_per_sample=0.077s
[batch_stream] batch=4680 batch_time=0.01s avg_batch_time=0.61s
step  4230 | lr 3.000e-06 | train_loss 6.7073
[window_stream] yielded=37500 avg_time_per_sample=0.077s
[batch_stream] batch=4690 batch_time=0.01s avg_batch_time=0.61s
step  4240 | lr 3.000e-06 | train_loss 5.0921
[window_stream] yielded=37600 avg_time_per_sample=0.076s
[batch_stream] batch=4700 batch_time=0.01s avg_batch_time=0.61s
step  4250 | lr 3.000e-06 | train_loss 6.5903
[batch_stream] batch=4710 batch_time=0.00s avg_batch_time=0.61s
step  4260 | lr 3.000e-06 | train_loss 6.7459
[window_stream] yielded=37700 avg_time_per_sample=0.07

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded → hf://345rf4gt56t4r3e3/nnn/model_step_4900.pt]
[window_stream] yielded=43100 avg_time_per_sample=0.077s
[batch_stream] batch=5390 batch_time=0.00s avg_batch_time=0.61s
step  4910 | lr 3.000e-06 | train_loss 6.4536
[window_stream] yielded=43200 avg_time_per_sample=0.077s
[batch_stream] batch=5400 batch_time=0.01s avg_batch_time=0.61s
step  4920 | lr 3.000e-06 | train_loss 5.8772
[batch_stream] batch=5410 batch_time=0.00s avg_batch_time=0.61s
step  4930 | lr 3.000e-06 | train_loss 5.8425
[window_stream] yielded=43300 avg_time_per_sample=0.077s
[batch_stream] batch=5420 batch_time=0.00s avg_batch_time=0.61s
step  4940 | lr 3.000e-06 | train_loss 5.9022
[window_stream] yielded=43400 avg_time_per_sample=0.076s
[batch_stream] batch=5430 batch_time=0.00s avg_batch_time=0.61s
step  4950 | lr 3.000e-06 | train_loss 1.5981
[window_stream] yielded=43500 avg_time_per_sample=0.076s
[batch_stream] batch=5440 batch_time=0.01s avg_batch_time=0.61s
step  4960 | lr 3.000e-06 | train_loss 6.46

KeyboardInterrupt: 

In [None]:
print(xb[0][:10])
print(yb[0][:10])


In [None]:
import os, math, random, time
import torch
import numpy as np

def seed_all(seed=1337):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_all(1337)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

# How many fixed batches to evaluate on.
# 50 is already much more stable than your current 10.
FIXED_VAL_BATCHES = 50

# Assumes you already have val_batch_iter defined (as in your notebook).
# This consumes from the stream ONCE and stores batches for repeatable evals.
fixed_val = []
for i in range(FIXED_VAL_BATCHES):
    xb, yb = next(val_batch_iter)  # <- your existing iterator
    fixed_val.append((xb.cpu(), yb.cpu()))
print(f"cached fixed_val batches: {len(fixed_val)}")
import math
import torch

@torch.no_grad()
def eval_batches(model, batches, device="cuda", amp=True):
    model.eval()

    total_loss = 0.0

    correct = 0
    total = 0

    for xb_cpu, yb_cpu in batches:
        xb = xb_cpu.to(device, non_blocking=True)
        yb = yb_cpu.to(device, non_blocking=True)

        if device == "cuda" and amp:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                logits, loss = model(xb, yb)
        else:
            logits, loss = model(xb, yb)

        total_loss += float(loss.item())

        # --- Make logits comparable to yb for accuracy ---
        # yb is (B, T)
        B, T = yb.shape

        if logits.dim() == 3:
            # logits: (B, T, V)
            preds = torch.argmax(logits, dim=-1)          # (B, T)
            correct += int((preds == yb).sum().item())
            total += int(yb.numel())

        elif logits.dim() == 2:
            # logits: either (B*T, V) OR (B, V) for last token
            # If (B*T, V), compare to flattened yb
            if logits.shape[0] == B * T:
                preds = torch.argmax(logits, dim=-1)      # (B*T,)
                y_flat = yb.reshape(-1)                   # (B*T,)
                correct += int((preds == y_flat).sum().item())
                total += int(y_flat.numel())
            elif logits.shape[0] == B:
                # last-token-only logits (rare in training forward)
                # compare only last token of yb
                preds = torch.argmax(logits, dim=-1)      # (B,)
                correct += int((preds == yb[:, -1]).sum().item())
                total += int(B)
            else:
                raise RuntimeError(f"Unexpected logits shape {tuple(logits.shape)} for yb {tuple(yb.shape)}")
        else:
            raise RuntimeError(f"Unexpected logits dim {logits.dim()} with shape {tuple(logits.shape)}")

    mean_loss = total_loss / len(batches)
    ppl = math.exp(mean_loss) if mean_loss < 50 else float("inf")
    acc = correct / total if total > 0 else float("nan")

    return {"loss": mean_loss, "ppl": ppl, "acc": acc, "tokens": total, "batches": len(batches)}

def load_fp16_state_into_model(model, ckpt_path, device="cuda"):
    sd = torch.load(ckpt_path, map_location="cpu")
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing or unexpected:
        print("WARNING load_state_dict mismatch")
        print("missing keys:", missing[:10], "..." if len(missing) > 10 else "")
        print("unexpected keys:", unexpected[:10], "..." if len(unexpected) > 10 else "")
    model.to(device)
    model.eval()
    return model

# Put paths you want to compare here:
CKPTS = [
      "weights/model_step_700.pt",
    "weights/model_step_2800.pt",
    "weights/model_step_3500.pt",
    "weights/model_step_4200.pt",
    # add more if you have them:
    "weights/model_step_4900.pt",
]
results = {}

for p in CKPTS:
    assert os.path.exists(p), f"missing: {p}"
    print("\n===", p, "===")

    # IMPORTANT: build a fresh model instance each time to avoid any weirdness
    # Replace TransformerLM(...) with however you construct your model in the notebook
    model = TransformerLM().to(device)

    load_fp16_state_into_model(model, p, device=device)

    t0 = time.time()
    val_metrics = eval_batches(model, fixed_val, device=device, amp=True)
    dt = time.time() - t0

    out = {"val": val_metrics, "sec": dt}

    # Optional train-eval slice
    if "fixed_train_eval" in globals():
        t1 = time.time()
        tr_metrics = eval_batches(model, fixed_train_eval, device=device, amp=True)
        out["train_eval"] = tr_metrics
        out["sec_train_eval"] = time.time() - t1

    results[p] = out

    print("val:", val_metrics, f"(time {dt:.2f}s)")
    if "train_eval" in out:
        print("train_eval:", out["train_eval"], f"(time {out['sec_train_eval']:.2f}s)")



device: cuda
[batch_stream] batch=310 batch_time=0.00s avg_batch_time=11.04s
[window_stream] yielded=2500 avg_time_per_sample=1.369s
[batch_stream] batch=320 batch_time=0.00s avg_batch_time=10.69s
[window_stream] yielded=2600 avg_time_per_sample=1.316s
[batch_stream] batch=330 batch_time=0.00s avg_batch_time=10.37s
[window_stream] yielded=2700 avg_time_per_sample=1.267s
[batch_stream] batch=340 batch_time=0.01s avg_batch_time=10.06s
[window_stream] yielded=2800 avg_time_per_sample=1.222s
[batch_stream] batch=350 batch_time=0.01s avg_batch_time=9.78s
cached fixed_val batches: 50

=== weights/model_step_700.pt ===


  sd = torch.load(ckpt_path, map_location="cpu")


val: {'loss': 6.854935159683228, 'ppl': 948.5506230881616, 'acc': 0.16411376953125, 'tokens': 409600, 'batches': 50} (time 5.80s)

=== weights/model_step_2800.pt ===
val: {'loss': 6.150047655105591, 'ppl': 468.7397240912154, 'acc': 0.19770751953125, 'tokens': 409600, 'batches': 50} (time 5.80s)

=== weights/model_step_3500.pt ===
val: {'loss': 6.025877189636231, 'ppl': 414.0046433717589, 'acc': 0.2038427734375, 'tokens': 409600, 'batches': 50} (time 5.80s)

=== weights/model_step_4200.pt ===


In [None]:
results

In [None]:
import torch

@torch.no_grad()
def complete(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int | None = 50,
    device: str = "cuda",
):
    model.eval()

    # encode prompt
    idx = torch.tensor(
        [tokenizer.encode(prompt)],
        dtype=torch.long,
        device=device,
    )

    for _ in range(max_new_tokens):
        # crop context if needed
        idx_cond = idx[:, -context_len :]

        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        idx = torch.cat([idx, next_token], dim=1)

    # decode only the completion
    completion = tokenizer.decode(idx[0].tolist())
    return completion


In [None]:
from tiktoken import get_encoding


text = complete(
    model,
    tokenizer,
    prompt="Blue is a color",
    max_new_tokens=60,
    temperature=0.8,
    top_k=40,
    device=device,
)

print(text)


In [None]:
torch.cuda.empty_cache()