# 3.4B transformer based on pile uncopyrighted streamed via huggingface datasets

In [1]:
%pip install tiktoken tqdm datasets tiktoken zstandard "fsspec[compression]"

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


# RESTART THE KERNEL IF RUNNING ON RUNPOD AFTER THIS PIP INSTALL

In [2]:
# -------- MODEL PARAMS --------
n_layers    = 64
n_embd      = 2048
n_heads     = 16
context_len = 1024
batch_size  = 8    # with grad accumulation
dropout     = 0
lr          = 3e-6
from tiktoken import get_encoding
tokenizer = get_encoding("gpt2")
vocab_size  = tokenizer.n_vocab

# ------------------------------

# *cool lr sinusoid trick*

In [3]:
import math

def get_lr(step, max_steps, base_lr=lr, warmup_steps=2000):
    """
    Warmup + cosine decay learning rate schedule.

    Args:
        step (int): current training step
        max_steps (int): total training steps
        base_lr (float): peak learning rate
        warmup_steps (int): number of warmup steps

    Returns:
        float: learning rate for this step
    """
    if step < warmup_steps:
        return base_lr * step / warmup_steps

    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


In [4]:
import torch
from torch import nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class TokenEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.token = nn.Embedding(vocab_size, n_embd)
        self.pos   = nn.Embedding(context_len, n_embd)

    def forward(self, x):
        B, T = x.shape
        tok = self.token(x)                    # (B, T, C)
        pos = self.pos(torch.arange(T))        # (T, C)
        return tok + pos
emb = TokenEmbedding()



# Above is the definition of our embeddings

# Now define attention

In [6]:
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_heads, dropout=0.0):
        super().__init__()
        assert n_embd % n_heads == 0

        self.n_heads = n_heads
        self.head_dim = n_embd // n_heads

        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.dropout = dropout

    def forward(self, x):
        B, T, C = x.shape

        qkv = self.qkv(x)                       # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # this is the critical line
        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0,
        )

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)


# Constants again and positional embeddings


In [7]:
class TokenPosEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.token = nn.Embedding(vocab_size, n_embd)
        self.pos   = nn.Embedding(context_len, n_embd)

    def forward(self, x):
        B, T = x.shape
        tok = self.token(x)                                # (B,T,C)
        pos = self.pos(torch.arange(T, device=x.device))  # (T,C)
        return tok + pos


# After attention we have basic FFN

In [8]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)


In [9]:
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

        self.attn = CausalSelfAttention(n_embd, n_heads, dropout)
        self.ff   = FeedForward()

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


In [10]:
class TransformerLM(nn.Module):
    def __init__(self):
        super().__init__()

        # token + positional embeddings
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb   = nn.Embedding(context_len, n_embd)

        self.blocks = nn.Sequential(
            *[Block() for _ in range(n_layers)]
        )

        # final normalization + LM head
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, x, targets=None):
        B, T = x.shape

        # embeddings
        tok = self.token_emb(x)                              # (B,T,C)
        pos = self.pos_emb(torch.arange(T, device=x.device))# (T,C)
        x = tok + pos                                        # (B,T,C)

        # APPLY ALL BLOCKS HERE
        x = self.blocks(x)

        # final projection
        x = self.ln_f(x)
        logits = self.head(x)                                # (B,T,vocab)

        loss = None
        if targets is not None:
            logits = logits.view(B*T, vocab_size)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss



In [11]:
model = TransformerLM().to(device)


In [12]:
print(sum(p.numel() for p in model.parameters()) / 1e6, "M params")


3430.409297 M params


# optimus

In [13]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr
)


In [14]:
@torch.no_grad()
def estimate_loss():
    model.eval()
    out = {}
    for split in ["train", "test"]:
        losses = []
        for _ in range(20):
            xb, yb = get_batch(split)
            xb, yb = xb.to(device), yb.to(device)
            _, loss = model(xb, yb)
            losses.append(loss.item())
        out[split] = sum(losses) / len(losses)
    model.train() # go back because we will be in a train loop
    return out


In [15]:
def token_stream(hf_dataset, tokenizer):
    """
    Lazily yields lists of token ids from a streaming HF dataset.
    """
    for ex in hf_dataset:
        text = ex.get("text", "")
        if not text:
            continue
        ids = tokenizer.encode(text)
        if len(ids) > 1:
            yield ids
import random
import torch

def window_stream(token_iter, context_len, device):
    """
    Yields single (x, y) training samples of shape [context_len].
    """
    buffer = []

    for ids in token_iter:
        buffer.extend(ids)

        while len(buffer) >= context_len + 1:
            start = random.randint(0, len(buffer) - context_len - 1)

            x = buffer[start : start + context_len]
            y = buffer[start + 1 : start + context_len + 1]

            yield (
                torch.tensor(x, dtype=torch.long, device=device),
                torch.tensor(y, dtype=torch.long, device=device),
            )
def batch_stream(sample_iter, batch_size):
    """
    Groups single samples into batches.
    """
    while True:
        xb, yb = zip(*(next(sample_iter) for _ in range(batch_size)))
        yield torch.stack(xb), torch.stack(yb)
from datasets import load_dataset

# load streaming dataset
ds = load_dataset(
    "monology/pile-uncopyrighted",
    split="train",
    streaming=True,
)

# tokenizer
tokenizer = get_encoding("gpt2")

# build streams
tok_iter   = token_stream(ds, tokenizer)
sample_iter = window_stream(tok_iter, context_len, device)
batch_iter  = batch_stream(sample_iter, batch_size)


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [16]:
import os
os.environ["HF_TOKEN"] = "hf_mytokenyeah"


In [19]:
from huggingface_hub import HfApi, upload_file
import os
import torch
from tqdm.notebook import tqdm
import torch.nn.functional as F

HF_REPO = "345rf4gt56t4r3e3/nnn"
CKPT_DIR = "checkpoints"

os.makedirs(CKPT_DIR, exist_ok=True)

api = HfApi()

# create repo if it doesn't exist
api.create_repo(
    repo_id=HF_REPO,
    exist_ok=True,
    repo_type="model",

)


RepoUrl('https://huggingface.co/345rf4gt56t4r3e3/nnn', endpoint='https://huggingface.co', repo_type='model', repo_id='345rf4gt56t4r3e3/nnn')

In [20]:
from tqdm.notebook import tqdm
import torch
import os
from huggingface_hub import upload_file

# ---------- CONFIG ----------
CKPT_DIR   = "weights"
SAVE_EVERY = 700
# ---------------------------

os.makedirs(CKPT_DIR, exist_ok=True)

scaler = torch.cuda.amp.GradScaler()

max_steps = 9000
history = []

model.train()

for step in tqdm(range(max_steps)):
    # ---- LR schedule ----
    lr = get_lr(step, max_steps)
    for g in optimizer.param_groups:
        g["lr"] = lr

    # ---- get batch ----
    xb, yb = next(batch_iter)
    xb = xb.to(device, non_blocking=True)
    yb = yb.to(device, non_blocking=True)

    # ---- forward + backward (AMP) ----
    with torch.cuda.amp.autocast():
        _, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # ---- logging ----
    loss_val = loss.item()
    history.append(loss_val)

    if step % 10 == 0:
        print(f"step {step:5d} | loss {loss_val:.4f}")

    # ---- SAVE *WEIGHTS ONLY* (SAFE) ----
    if step > 0 and step % SAVE_EVERY == 0:
        fname = f"model_step_{step}.pt"
        fpath = os.path.join(CKPT_DIR, fname)

        # ðŸ”’ SAFE FP16 EXPORT (does NOT touch training model)
        state_fp16 = {
            k: v.detach().half().cpu()
            for k, v in model.state_dict().items()
        }

        torch.save(state_fp16, fpath)
        print(f"[saved FP16 weights â†’ {fpath}]")

        # upload versioned snapshot
        upload_file(
            path_or_fileobj=fpath,
            path_in_repo=fname,
            repo_id=HF_REPO,
            repo_type="model",
            commit_message=f"weights @ step {step}",
        )

        # update rolling pointer
        upload_file(
            path_or_fileobj=fpath,
            path_in_repo="latest.pt",
            repo_id=HF_REPO,
            repo_type="model",
            commit_message=f"update latest @ step {step}",
        )

        print(f"[uploaded â†’ hf://{HF_REPO}/{fname}]")


  scaler = torch.cuda.amp.GradScaler()


  0%|          | 0/9000 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


step     0 | loss 10.9504
step    10 | loss 10.9479
step    20 | loss 10.9091
step    30 | loss 10.8186
step    40 | loss 10.7316
step    50 | loss 10.5976
step    60 | loss 10.4588
step    70 | loss 10.2639
step    80 | loss 10.0770
step    90 | loss 9.8780
step   100 | loss 9.7151
step   110 | loss 9.5252
step   120 | loss 9.3798
step   130 | loss 9.1531
step   140 | loss 9.0567
step   150 | loss 8.9283
step   160 | loss 8.7812
step   170 | loss 8.7036
step   180 | loss 8.7216
step   190 | loss 8.4415
step   200 | loss 8.3514
step   210 | loss 8.4170
step   220 | loss 8.3396
step   230 | loss 8.2484
step   240 | loss 8.1511
step   250 | loss 7.9582
step   260 | loss 8.0959
step   270 | loss 7.9606
step   280 | loss 7.8577
step   290 | loss 7.7751
step   300 | loss 7.6415
step   310 | loss 7.5857
step   320 | loss 7.5396
step   330 | loss 7.5784
step   340 | loss 7.3402
step   350 | loss 7.3432
step   360 | loss 7.2335
step   370 | loss 7.3027
step   380 | loss 7.0501
step   390 | los

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

'The read operation timed out' thrown while requesting POST https://huggingface.co/api/models/345rf4gt56t4r3e3/nnn/preupload/main
Retrying in 1s [Retry 1/5].


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded â†’ hf://345rf4gt56t4r3e3/nnn/model_step_700.pt]
step   710 | loss 3.4508
step   720 | loss 3.0982
step   730 | loss 3.2763
step   740 | loss 3.2156
step   750 | loss 2.9231
step   760 | loss 2.9953
step   770 | loss 3.1089
step   780 | loss 2.9970
step   790 | loss 2.9663
step   800 | loss 3.0363
step   810 | loss 2.6403
step   820 | loss 2.7821
step   830 | loss 2.6905
step   840 | loss 2.6176
step   850 | loss 2.6778
step   860 | loss 2.6074
step   870 | loss 2.7695
step   880 | loss 2.5941
step   890 | loss 2.4486
step   900 | loss 2.6034
step   910 | loss 2.5522
step   920 | loss 2.4749
step   930 | loss 2.8069
step   940 | loss 2.5082
step   950 | loss 2.5039
step   960 | loss 2.6154
step   970 | loss 2.5087
step   980 | loss 2.3259
step   990 | loss 2.5937
step  1000 | loss 2.4408
step  1010 | loss 2.3217
step  1020 | loss 2.2404
step  1030 | loss 2.5770
step  1040 | loss 2.4223
step  1050 | loss 2.2085
step  1060 | loss 2.1079
step  1070 | loss 2.2974
step  1080 | los

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded â†’ hf://345rf4gt56t4r3e3/nnn/model_step_1400.pt]
step  1410 | loss 1.4795
step  1420 | loss 1.4394
step  1430 | loss 1.7100
step  1440 | loss 1.3239
step  1450 | loss 1.3800
step  1460 | loss 1.4112
step  1470 | loss 1.3611
step  1480 | loss 1.5482
step  1490 | loss 1.3672
step  1500 | loss 1.2588
step  1510 | loss 1.2188
step  1520 | loss 1.4458
step  1530 | loss 1.3950
step  1540 | loss 1.3066
step  1550 | loss 1.1370
step  1560 | loss 1.3426
step  1570 | loss 1.2565
step  1580 | loss 1.1822
step  1590 | loss 1.2953
step  1600 | loss 1.2644
step  1610 | loss 1.4146
step  1620 | loss 1.2526
step  1630 | loss 1.1391
step  1640 | loss 1.1153
step  1650 | loss 1.2446
step  1660 | loss 0.9831
step  1670 | loss 1.1703
step  1680 | loss 1.0244
step  1690 | loss 1.1867
step  1700 | loss 0.9559
step  1710 | loss 0.9530
step  1720 | loss 0.9545
step  1730 | loss 0.9174
step  1740 | loss 0.8532
step  1750 | loss 0.9469
step  1760 | loss 0.9774
step  1770 | loss 0.9164
step  1780 | lo

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded â†’ hf://345rf4gt56t4r3e3/nnn/model_step_2100.pt]
step  2110 | loss 0.4070
step  2120 | loss 0.4219
step  2130 | loss 0.4588
step  2140 | loss 0.3288
step  2150 | loss 0.4323
step  2160 | loss 0.3760
step  2170 | loss 0.3659
step  2180 | loss 0.3389
step  2190 | loss 0.3540
step  2200 | loss 0.2749
step  2210 | loss 0.3233
step  2220 | loss 0.2835
step  2230 | loss 0.3643
step  2240 | loss 0.3046
step  2250 | loss 0.2563
step  2260 | loss 0.2901
step  2270 | loss 0.2622
step  2280 | loss 0.1930
step  2290 | loss 0.2778
step  2300 | loss 0.3602
step  2310 | loss 0.1758
step  2320 | loss 0.2913
step  2330 | loss 0.2402
step  2340 | loss 0.1749
step  2350 | loss 0.2563
step  2360 | loss 0.2387
step  2370 | loss 0.2438
step  2380 | loss 0.2262
step  2390 | loss 0.2054
step  2400 | loss 0.1955
step  2410 | loss 0.2703
step  2420 | loss 0.1868
step  2430 | loss 0.2155
step  2440 | loss 0.1850
step  2450 | loss 0.1941
step  2460 | loss 0.2285
step  2470 | loss 0.1247
step  2480 | lo

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

[uploaded â†’ hf://345rf4gt56t4r3e3/nnn/model_step_2800.pt]
step  2810 | loss 0.0448
step  2820 | loss 0.0446
step  2830 | loss 0.0760
step  2840 | loss 0.0481
step  2850 | loss 0.0713
step  2860 | loss 0.0621
step  2870 | loss 0.0608
step  2880 | loss 0.0501
step  2890 | loss 0.0514
step  2900 | loss 0.0490
step  2910 | loss 0.0659
step  2920 | loss 0.0693
step  2930 | loss 0.0555
step  2940 | loss 0.0626
step  2950 | loss 0.0675
step  2960 | loss 0.0654
step  2970 | loss 0.0697
step  2980 | loss 0.0748
step  2990 | loss 0.0632
step  3000 | loss 0.0512
step  3010 | loss 0.0416
step  3020 | loss 0.0382
step  3030 | loss 0.0511
step  3040 | loss 0.0438
step  3050 | loss 0.0618
step  3060 | loss 0.0506
step  3070 | loss 0.0395
step  3080 | loss 0.0555
step  3090 | loss 0.0324
step  3100 | loss 0.0258
step  3110 | loss 0.0253
step  3120 | loss 0.0309
step  3130 | loss 0.0278
step  3140 | loss 0.0381
step  3150 | loss 0.0265


KeyboardInterrupt: 

In [None]:
print(xb[0][:10])
print(yb[0][:10])


In [21]:
import torch

@torch.no_grad()
def complete(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int | None = 50,
    device: str = "cuda",
):
    model.eval()

    # encode prompt
    idx = torch.tensor(
        [tokenizer.encode(prompt)],
        dtype=torch.long,
        device=device,
    )

    for _ in range(max_new_tokens):
        # crop context if needed
        idx_cond = idx[:, -context_len :]

        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        idx = torch.cat([idx, next_token], dim=1)

    # decode only the completion
    completion = tokenizer.decode(idx[0].tolist())
    return completion


In [22]:
from tiktoken import get_encoding


text = complete(
    model,
    tokenizer,
    prompt="Blue is a color",
    max_new_tokens=60,
    temperature=0.8,
    top_k=40,
    device=device,
)

print(text)


Blue is a color denotes the longest time. It was great seeing other people working â€“ I had a few tabs opened on my second monitor all the time. Itâ€™s actually a bit sad, because if I could, I could have spent the whole weekend just watching other people working! But I had to do my


In [None]:
torch.cuda.empty_cache()