# 500M transformer based on pile uncopyrighted by stream training

In [2]:
%pip install tiktoken tqdm datasets tiktoken zstandard "fsspec[compression]"

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


# RESTART THE KERNEL IF RUNNING ON RUNPOD AFTER THIS PIP INSTALL

In [3]:
# -------- MODEL PARAMS --------
n_layers    = 64
n_embd      = 2048
n_heads     = 16
context_len = 1024
batch_size  = 8    # with grad accumulation
dropout     = 0
lr          = 3e-6
from tiktoken import get_encoding
tokenizer = get_encoding("gpt2")
vocab_size  = tokenizer.n_vocab

# ------------------------------

# *cool lr sinusoid trick*

In [4]:
import math

def get_lr(step, max_steps, base_lr=lr, warmup_steps=2000):
    """
    Warmup + cosine decay learning rate schedule.

    Args:
        step (int): current training step
        max_steps (int): total training steps
        base_lr (float): peak learning rate
        warmup_steps (int): number of warmup steps

    Returns:
        float: learning rate for this step
    """
    if step < warmup_steps:
        return base_lr * step / warmup_steps

    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


In [5]:
import torch
from torch import nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
class TokenEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.token = nn.Embedding(vocab_size, n_embd)
        self.pos   = nn.Embedding(context_len, n_embd)

    def forward(self, x):
        B, T = x.shape
        tok = self.token(x)                    # (B, T, C)
        pos = self.pos(torch.arange(T))        # (T, C)
        return tok + pos
emb = TokenEmbedding()



# Above is the definition of our embeddings

# Now define attention

In [7]:
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_heads, dropout=0.0):
        super().__init__()
        assert n_embd % n_heads == 0

        self.n_heads = n_heads
        self.head_dim = n_embd // n_heads

        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.dropout = dropout

    def forward(self, x):
        B, T, C = x.shape

        qkv = self.qkv(x)                       # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # this is the critical line
        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0,
        )

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)


# Constants again and positional embeddings


In [8]:
class TokenPosEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.token = nn.Embedding(vocab_size, n_embd)
        self.pos   = nn.Embedding(context_len, n_embd)

    def forward(self, x):
        B, T = x.shape
        tok = self.token(x)                                # (B,T,C)
        pos = self.pos(torch.arange(T, device=x.device))  # (T,C)
        return tok + pos


# After attention we have basic FFN

In [9]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)


In [10]:
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

        self.attn = CausalSelfAttention(n_embd, n_heads, dropout)
        self.ff   = FeedForward()

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


In [11]:
class TransformerLM(nn.Module):
    def __init__(self):
        super().__init__()

        # token + positional embeddings
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb   = nn.Embedding(context_len, n_embd)

        self.blocks = nn.Sequential(
            *[Block() for _ in range(n_layers)]
        )

        # final normalization + LM head
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, x, targets=None):
        B, T = x.shape

        # embeddings
        tok = self.token_emb(x)                              # (B,T,C)
        pos = self.pos_emb(torch.arange(T, device=x.device))# (T,C)
        x = tok + pos                                        # (B,T,C)

        # APPLY ALL BLOCKS HERE
        x = self.blocks(x)

        # final projection
        x = self.ln_f(x)
        logits = self.head(x)                                # (B,T,vocab)

        loss = None
        if targets is not None:
            logits = logits.view(B*T, vocab_size)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss



In [12]:
model = TransformerLM().to(device)


In [13]:
print(sum(p.numel() for p in model.parameters()) / 1e6, "M params")


3430.409297 M params


# optimus

In [14]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr
)


In [15]:
@torch.no_grad()
def estimate_loss():
    model.eval()
    out = {}
    for split in ["train", "test"]:
        losses = []
        for _ in range(20):
            xb, yb = get_batch(split)
            xb, yb = xb.to(device), yb.to(device)
            _, loss = model(xb, yb)
            losses.append(loss.item())
        out[split] = sum(losses) / len(losses)
    model.train() # go back because we will be in a train loop
    return out


In [16]:
def token_stream(hf_dataset, tokenizer):
    """
    Lazily yields lists of token ids from a streaming HF dataset.
    """
    for ex in hf_dataset:
        text = ex.get("text", "")
        if not text:
            continue
        ids = tokenizer.encode(text)
        if len(ids) > 1:
            yield ids
import random
import torch

def window_stream(token_iter, context_len, device):
    """
    Yields single (x, y) training samples of shape [context_len].
    """
    buffer = []

    for ids in token_iter:
        buffer.extend(ids)

        while len(buffer) >= context_len + 1:
            start = random.randint(0, len(buffer) - context_len - 1)

            x = buffer[start : start + context_len]
            y = buffer[start + 1 : start + context_len + 1]

            yield (
                torch.tensor(x, dtype=torch.long, device=device),
                torch.tensor(y, dtype=torch.long, device=device),
            )
def batch_stream(sample_iter, batch_size):
    """
    Groups single samples into batches.
    """
    while True:
        xb, yb = zip(*(next(sample_iter) for _ in range(batch_size)))
        yield torch.stack(xb), torch.stack(yb)
from datasets import load_dataset

# load streaming dataset
ds = load_dataset(
    "monology/pile-uncopyrighted",
    split="train",
    streaming=True,
)

# tokenizer
tokenizer = get_encoding("gpt2")

# build streams
tok_iter   = token_stream(ds, tokenizer)
sample_iter = window_stream(tok_iter, context_len, device)
batch_iter  = batch_stream(sample_iter, batch_size)


README.md:   0%|          | 0.00/776 [00:00<?, ?B/s]



Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [17]:
import os
os.environ["HF_TOKEN"] = "hf_yeahatokenhere"


In [18]:
from huggingface_hub import HfApi, upload_file
import os
import torch
from tqdm.notebook import tqdm
import torch.nn.functional as F

HF_REPO = "345rf4gt56t4r3e3/nnn"
CKPT_DIR = "checkpoints"

os.makedirs(CKPT_DIR, exist_ok=True)

api = HfApi()

# create repo if it doesn't exist
api.create_repo(
    repo_id=HF_REPO,
    exist_ok=True,
    repo_type="model",

)


RepoUrl('https://huggingface.co/345rf4gt56t4r3e3/nnn', endpoint='https://huggingface.co', repo_type='model', repo_id='345rf4gt56t4r3e3/nnn')

In [None]:
from tqdm.notebook import tqdm
import torch
import os
from huggingface_hub import upload_file

# ---------- CONFIG ----------
CKPT_DIR   = "weights"
SAVE_EVERY = 300
# ---------------------------

os.makedirs(CKPT_DIR, exist_ok=True)

scaler = torch.cuda.amp.GradScaler()

max_steps = 9000
history = []

model.train()

for step in tqdm(range(max_steps)):
    # ---- LR schedule ----
    lr = get_lr(step, max_steps)
    for g in optimizer.param_groups:
        g["lr"] = lr

    # ---- get batch ----
    xb, yb = next(batch_iter)
    xb = xb.to(device, non_blocking=True)
    yb = yb.to(device, non_blocking=True)

    # ---- forward + backward (AMP) ----
    with torch.cuda.amp.autocast():
        _, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # ---- logging ----
    loss_val = loss.item()
    history.append(loss_val)

    if step % 10 == 0:
        print(f"step {step:5d} | loss {loss_val:.4f}")

    # ---- SAVE *WEIGHTS ONLY* ----
    if step > 0 and step % SAVE_EVERY == 0:
        fname = f"model_step_{step}.pt"
        fpath = os.path.join(CKPT_DIR, fname)

        # ensure inference-friendly weights
        model_to_save = model.half().cpu()

        torch.save(model_to_save.state_dict(), fpath)
        print(f"[saved weights → {fpath}]")

        # restore training state
        model.to(device).train()

        # upload versioned snapshot
        upload_file(
            path_or_fileobj=fpath,
            path_in_repo=fname,
            repo_id=HF_REPO,
            repo_type="model",
            commit_message=f"weights @ step {step}",
        )

        # update rolling pointer
        upload_file(
            path_or_fileobj=fpath,
            path_in_repo="latest.pt",
            repo_id=HF_REPO,
            repo_type="model",
            commit_message=f"update latest @ step {step}",
        )

        print(f"[uploaded → hf://{HF_REPO}/{fname}]")


  scaler = torch.cuda.amp.GradScaler()


  0%|          | 0/9000 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


step     0 | loss 11.0095
step    10 | loss 10.9892
step    20 | loss 10.9466
step    30 | loss 10.8715
step    40 | loss 10.7718
step    50 | loss 10.6325
step    60 | loss 10.4818
step    70 | loss 10.2661
step    80 | loss 10.1145
step    90 | loss 9.8985
step   100 | loss 9.6991
step   110 | loss 9.5507
step   120 | loss 9.3283
step   130 | loss 9.1491
step   140 | loss 9.0098
step   150 | loss 8.9417
step   160 | loss 8.7541
step   170 | loss 8.6928
step   180 | loss 8.5124
step   190 | loss 8.5088


In [None]:
print(xb[0][:10])
print(yb[0][:10])


In [None]:
import torch

@torch.no_grad()
def complete(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int | None = 50,
    device: str = "cuda",
):
    model.eval()

    # encode prompt
    idx = torch.tensor(
        [tokenizer.encode(prompt)],
        dtype=torch.long,
        device=device,
    )

    for _ in range(max_new_tokens):
        # crop context if needed
        idx_cond = idx[:, -context_len :]

        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        idx = torch.cat([idx, next_token], dim=1)

    # decode only the completion
    completion = tokenizer.decode(idx[0].tolist())
    return completion


In [None]:
from tiktoken import get_encoding


text = complete(
    model,
    tokenizer,
    prompt="Blue is a",
    max_new_tokens=60,
    temperature=0.02,
    top_k=40,
    device=device,
)

print(text)


In [None]:
torch.cuda.empty_cache()