<a href="https://colab.research.google.com/github/DavinciDreams/JuliaGPT/blob/main/juliagpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# JuliaGPT — A Minimal GPT in Pure Julia

Faithful Julia port of Karpathy's microgpt: the most atomic way to train a GPT.  
Array-based autograd via AutoGrad.jl, a transformer, and an Adam optimizer.  
No external dependencies beyond the Julia stdlib, AutoGrad.jl, and JSON3.jl (the optional W&B and HuggingFace integration uses Python tools).

**Architecture** (following GPT-2 with simplifications):
- AutoGrad.jl array-based automatic differentiation (`Param` wrapped matrices)
- Single-layer transformer with multi-head attention
- RMSNorm (not LayerNorm), no biases, ReLU (not GELU)
- KV cache for natural causal masking
- Adam optimizer with linear LR decay
- Temperature-controlled generation
- Best-model checkpointing with validation loss tracking (local + HuggingFace Hub)

Based on: https://gist.github.com/karpathy/8627fe009c40f57531cb18360106ce95

## 0. Login & Setup

This cell runs in **Python** to read your Colab secrets and save them for Julia.

1. Add secrets via the key icon in the left sidebar:
   - `HF_TOKEN` — your HuggingFace access token
   - `WANDB_KEY` — your Weights & Biases API key
   - `HF_REPO` — your model repo (e.g. `LisaMegaWatts/JuliaGPT`)
2. Run cells 0-1 (login + install Julia, ~3-5 min)
3. **Runtime > Change runtime type > Julia 1.10**
4. Continue with the remaining Julia cells

In [None]:
# ── Minimal Python setup: install tools + save Colab secrets for Julia ──
!pip install -q wandb huggingface_hub

import json, pathlib

secrets = {}
try:
    from google.colab import userdata
    for key in ("HF_TOKEN", "WANDB_KEY", "HF_REPO"):
        try: secrets[key] = userdata.get(key)
        except Exception: pass
except ImportError:
    pass

secrets_path = pathlib.Path.home() / ".julia_secrets.json"
secrets_path.write_text(json.dumps(secrets))
secrets_path.chmod(0o600)

found = [k for k in secrets if secrets[k]]
print(f"Secrets saved: {', '.join(found) if found else 'none found'}")
print(f"  -> {secrets_path}")
if not found:
    print("Add HF_TOKEN, WANDB_KEY, HF_REPO via the key icon in the sidebar.")

print("\nDone! Now run the next cell to install Julia (~3-5 min).")
print("Then: Runtime > Change runtime type > Julia 1.10")

## 1. Install Julia Kernel

This cell downloads and installs Julia + IJulia. **Takes ~3-5 minutes** on first run.

**After it finishes:**
1. Go to **Runtime → Change runtime type**
2. You may see both "Julia" and "Julia 1.10" — pick **Julia 1.10**
3. Continue running the cells below

In [None]:
%%shell
set -e

JULIA_VERSION="1.10.5"
JULIA_MINOR="1.10"

if [ ! -d "/usr/local/julia-${JULIA_VERSION}" ]; then
    echo "Downloading Julia ${JULIA_VERSION}..."
    wget -q https://julialang-s3.julialang.org/bin/linux/x64/${JULIA_MINOR}/julia-${JULIA_VERSION}-linux-x86_64.tar.gz
    tar xzf julia-${JULIA_VERSION}-linux-x86_64.tar.gz -C /usr/local/
    rm julia-${JULIA_VERSION}-linux-x86_64.tar.gz
    ln -sf /usr/local/julia-${JULIA_VERSION}/bin/julia /usr/local/bin/julia
    echo "Julia installed."
else
    echo "Julia already installed."
fi

julia -e '
    using Pkg
    Pkg.add("IJulia")
    Pkg.add("JSON3")
    Pkg.add("AutoGrad")
    using IJulia
    installkernel("Julia")
'

echo ""
echo "==========================================================="
echo "  Julia kernel installed!                                   "
echo "  Now: Runtime -> Change runtime type -> pick Julia 1.10       "
echo "  Then run the cells below.                              "
echo "==========================================================="

## 1b. Load Credentials + W&B / HuggingFace Helpers (Julia)

Reads tokens from `~/.julia_secrets.json` (written by the Python setup cell).
W&B logging uses a persistent Python subprocess fed JSON lines from Julia.
HuggingFace helpers use `huggingface-cli` to push/pull checkpoints.

In [None]:
using JSON3

# ── Ensure pip-installed binaries (huggingface-cli, wandb) are on PATH ──
for p in ["/usr/local/bin", joinpath(homedir(), ".local/bin"), "/root/.local/bin"]
    # Compare whole PATH entries; a substring test would also match e.g. "/usr/local/bin2"
    if isdir(p) && !(p in split(get(ENV, "PATH", ""), ':'))
        ENV["PATH"] = p * ":" * get(ENV, "PATH", "")
    end
end

# ── Read credentials from ~/.julia_secrets.json (written by Python setup cell) ──
function load_secrets()
    path = expanduser("~/.julia_secrets.json")
    if !isfile(path)
        @warn "No secrets file found at $path — run the Python setup cell first"
        return Dict{String,String}()
    end
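    raw = JSON3.read(read(path, String))
    # Skip null secrets: string(nothing) would otherwise be stored as the literal "nothing"
    return Dict{String,String}(string(k) => string(v) for (k, v) in pairs(raw)
                               if v !== nothing && !isempty(string(v)))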
end

secrets = load_secrets()

# W&B
if haskey(secrets, "WANDB_KEY")
    ENV["WANDB_API_KEY"] = secrets["WANDB_KEY"]
    println("W&B API key: found")
else
    println("W&B API key: not found (add WANDB_KEY to Colab secrets)")
end

# HuggingFace token
if haskey(secrets, "HF_TOKEN")
    ENV["HF_TOKEN"] = secrets["HF_TOKEN"]
    println("HF token: found")
else
    println("HF token: not found (add HF_TOKEN to Colab secrets)")
end

# HuggingFace repo ID
HF_REPO_ID = get(secrets, "HF_REPO", "")
if !isempty(HF_REPO_ID)
    println("HF repo: ", HF_REPO_ID)
else
    println("HF repo: not set (add HF_REPO to Colab secrets or set HF_REPO_ID manually)")
end

In [None]:
# ═══════════════════════════════════════════════════════════════
# W&B logging via persistent Python subprocess
# ═══════════════════════════════════════════════════════════════

WANDB_PROJECT = "microgpt-philosophy"
WANDB_RUN_ID = "microgpt-" * join(rand('a':'z', 6))

# Write a tiny Python helper that reads JSON lines on stdin
write("_wandb_log.py", """
import wandb, json, sys, os
project = os.environ.get("WANDB_PROJECT", "microgpt-philosophy")
run_id = os.environ.get("WANDB_RUN_ID", None)
run = wandb.init(project=project, id=run_id, resume="allow",
                 config={"model": "microgpt", "architecture": "1-layer transformer"})
print(f"W&B run: {run.url}", flush=True)
for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    try:
        data = json.loads(line)
        wandb.log(data)
    except Exception as e:
        print(f"wandb log error: {e}", file=sys.stderr, flush=True)
wandb.finish()
""")

wandb_proc = nothing

function wandb_init()
    global wandb_proc, WANDB_PROJECT, WANDB_RUN_ID
    if !haskey(ENV, "WANDB_API_KEY") || isempty(ENV["WANDB_API_KEY"])
        println("W&B: skipped (no API key)")
        return
    end
    ENV["WANDB_PROJECT"] = WANDB_PROJECT
    ENV["WANDB_RUN_ID"] = WANDB_RUN_ID
    wandb_proc = open(`python3 _wandb_log.py`, "r+")
    println("W&B: initialized ($WANDB_PROJECT / $WANDB_RUN_ID)")
end

function wandb_log(; kwargs...)
    global wandb_proc
    wandb_proc === nothing && return
    metrics = Dict(string(k) => v for (k, v) in kwargs)
    try
        println(wandb_proc, JSON3.write(metrics))
        flush(wandb_proc)
    catch e
        println("W&B log error: $e")
    end
end

function wandb_finish()
    global wandb_proc
    wandb_proc === nothing && return
    try close(wandb_proc) catch end
    wandb_proc = nothing
    println("W&B: run finished")
end

# ═══════════════════════════════════════════════════════════════
# HuggingFace Hub helpers
# ═══════════════════════════════════════════════════════════════

function hf_push(repo_id::String, local_path::String; remote_path::String="")
    rp = isempty(remote_path) ? basename(local_path) : remote_path
    run(`huggingface-cli upload $repo_id $local_path $rp`)
    println("Pushed $local_path -> $repo_id/$rp")
end

function hf_pull(repo_id::String, remote_path::String; local_dir::String="checkpoints")
    mkpath(local_dir)
    run(`huggingface-cli download $repo_id $remote_path --local-dir $local_dir`)
    println("Pulled $repo_id/$remote_path -> $local_dir/")
end

function hf_push_checkpoint(repo_id::String; checkpoint_path::String="checkpoints/best_model.json")
    isfile(checkpoint_path) || error("Checkpoint not found: $checkpoint_path")
    hf_push(repo_id, checkpoint_path)
end

function hf_create_repo(repo_id::String)
    try
        run(`huggingface-cli repo create $repo_id --type model`)
        println("Created HF repo: $repo_id")
    catch
        println("HF repo already exists or creation skipped: $repo_id")
    end
end

# ═══════════════════════════════════════════════════════════════
# Sync helper: push checkpoint to HuggingFace Hub
# ═══════════════════════════════════════════════════════════════

function hf_sync(local_path::String)
    if !@isdefined(HF_REPO_ID) || isempty(HF_REPO_ID)
        return
    end
    try
        hf_push(HF_REPO_ID, local_path)
    catch e
        println("  HF sync failed: $e")
    end
end

println("W&B + HuggingFace helpers defined")

---
## 1. AutoGrad.jl Setup

Array-based automatic differentiation using AutoGrad.jl.  
Parameters are wrapped with `Param()` for gradient tracking; `@diff` builds the computation tape and `grad()` extracts gradients.
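
The workflow described above can be sketched on a toy example. This is only an illustration; the names `w`, `tape`, `loss_val`, and `g` are placeholders and not part of the model code:

```julia
using AutoGrad

# A 2x2 parameter; Param marks it for gradient tracking.
w = Param(Float32[1 2; 3 4])

# @diff records the computation and returns a tape holding the result.
tape = @diff sum(w .* w)

loss_val = value(tape)   # scalar loss: 1 + 4 + 9 + 16
g = grad(tape, w)        # gradient of sum(w .* w) with respect to w, i.e. 2w

println("loss = ", loss_val)
println("grad = ", g)
```

The training loop follows the same pattern, with the model's loss in place of `sum(w .* w)` and one `grad` call per parameter.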

In [None]:
using Pkg
Pkg.add("JSON3")
Pkg.add("AutoGrad")

using Random
using Printf
using JSON3
using AutoGrad
using LinearAlgebra

Random.seed!(42)

println("AutoGrad.jl loaded - array-based automatic differentiation")

In [None]:
# ═══════════════════════════════════════════════════════════════
# Resource Sensing & Auto-Configuration
# Inspired by Unsloth's resource-aware training patterns:
#   - Sense available hardware (CPU cores, RAM, GPU VRAM)
#   - Limit usage to RESOURCE_LIMIT (default 50%)
#   - Configure BLAS threads, GC hints, and training params
# ═══════════════════════════════════════════════════════════════

RESOURCE_LIMIT = 0.50  # Use at most 50% of available resources

# ── CPU Detection ──
total_threads = Sys.CPU_THREADS
total_cores = max(1, div(total_threads, 2))  # assume 2-way hyperthreading; never report 0 cores
max_blas_threads = max(1, floor(Int, total_cores * RESOURCE_LIMIT))

# Set BLAS thread count (OpenBLAS)
using LinearAlgebra
BLAS.set_num_threads(max_blas_threads)

# round(Int, ...) instead of Int(...): Int throws if the product is not an exact integer
limit_pct = round(Int, RESOURCE_LIMIT * 100)
println("=== Resource Configuration ($(limit_pct)% limit) ===")
println("CPU: $(total_cores) cores / $(total_threads) threads")
println("  BLAS threads: $max_blas_threads ($(limit_pct)% of $total_cores cores)")

# ── Memory Detection ──
total_ram_gb = round(Sys.total_memory() / 1024^3, digits=1)
free_ram_gb = round(Sys.free_memory() / 1024^3, digits=1)
ram_limit_gb = round(free_ram_gb * RESOURCE_LIMIT, digits=1)
println("RAM: $(total_ram_gb) GB total, $(free_ram_gb) GB free")
println("  Training limit: $(ram_limit_gb) GB")

# ── GPU Detection ──
HAS_GPU = false
GPU_NAME = ""
GPU_VRAM_GB = 0.0
GPU_VRAM_FREE_GB = 0.0
GPU_VRAM_LIMIT_GB = 0.0

try
    using CUDA
    if CUDA.functional()
        HAS_GPU = true
        dev = CUDA.device()
        GPU_NAME = CUDA.name(dev)
        GPU_VRAM_GB = round(CUDA.total_memory() / 1024^3, digits=2)
        GPU_VRAM_FREE_GB = round(CUDA.available_memory() / 1024^3, digits=2)
        GPU_VRAM_LIMIT_GB = round(GPU_VRAM_FREE_GB * RESOURCE_LIMIT, digits=2)
        # Set CUDA memory pool limit
        CUDA.pool.max_memory!(Int(floor(GPU_VRAM_LIMIT_GB * 1024^3)))
        println("GPU: $GPU_NAME ($(GPU_VRAM_GB) GB total, $(GPU_VRAM_FREE_GB) GB free)")
        println("  VRAM limit: $(GPU_VRAM_LIMIT_GB) GB")
    else
        println("GPU: CUDA available but not functional")
    end
catch e
    println("GPU: not available ($e)")
end

# ── Compute Device Selection ──
# For tiny models (<100K params), CPU is faster than GPU due to transfer overhead.
# GPU becomes beneficial for larger models or batched operations.
USE_GPU = false  # Default to CPU for this tiny model
GPU_PARAM_THRESHOLD = 100_000  # Switch to GPU above this param count

println("\nDevice: CPU (model too small for GPU benefit; threshold=$(GPU_PARAM_THRESHOLD) params)")
println("  Override: set USE_GPU = true to force GPU")

# ── GC Configuration ──
# Periodic GC between training steps to prevent memory accumulation
# (Unsloth pattern: controlled memory cleanup between batches)
GC_INTERVAL = 100  # Run GC.gc(false) every N steps

# ── Sequence Packing Configuration (Unsloth-inspired) ──
# Pack multiple short documents into block_size-length sequences
# to eliminate wasted computation on short sequences
ENABLE_PACKING = true  # Unsloth's padding-free packing pattern

println("\nOptimizations:")
println("  Sequence packing: $(ENABLE_PACKING ? "enabled" : "disabled") (Unsloth-inspired)")
println("  GC interval: every $GC_INTERVAL steps")
println("===========================================")

---
## 1c. Resource Sensing & Configuration

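The cell above auto-detects CPU, RAM, and GPU resources and configures training to use at most 50% of available capacity (`RESOURCE_LIMIT`).
Inspired by Unsloth's resource-aware training patterns, it senses the hardware, sets the BLAS thread count, caps the CUDA memory pool when a GPU is present, and reports a RAM budget (printed, not enforced).
The cell below defines the AutoGrad-compatible array primitives (`relu_ag`, `rmsnorm_ag`, `softmax_ag`, `to_dense_grad`) used by the model.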

In [None]:
# ── Array-based neural network primitives (AutoGrad-compatible) ──
# All operations use standard Julia array ops that AutoGrad can differentiate.
# No in-place mutation inside @diff context.

# ReLU for arrays (element-wise)
relu_ag(x) = max.(x, Float32(0))

# RMSNorm for a vector
function rmsnorm_ag(x)
    n = length(x)
    ms = sum(x .* x) / n
    scale = (ms + Float32(1e-5)) ^ Float32(-0.5)
    return x .* scale
end

# Softmax for a vector (numerically stable)
function softmax_ag(logits)
    mx = maximum(logits)
    exps = exp.(logits .- mx)
    s = sum(exps)
    return exps ./ s
end

# Type-dispatch gradient densifier (replaces try/catch in hot loop)
function to_dense_grad(g)
    if g isa AutoGrad.Sparse
        Float32.(AutoGrad.full(g))
    else
        Float32.(g)
    end
end

println("Array-based primitives defined (relu_ag, rmsnorm_ag, softmax_ag, to_dense_grad)")

In [None]:
# backward! is no longer needed -- AutoGrad.jl handles backpropagation via @diff + grad()
# Usage:
#   tape = @diff loss_expression    # builds computation tape
#   g = grad(tape, some_param)      # extracts gradient for a Param
#   loss_val = value(tape)          # extracts scalar loss value
println("Using AutoGrad.jl for automatic differentiation (no manual backward! needed)")

---
## 2. Dataset — Download & Load Training Data

Downloads **54 classical philosophy texts** from Project Gutenberg and MIT Classics Archive, spanning:
- **Plato** (12 works): Republic, Apology, Symposium, Phaedo, Crito, Meno, Phaedrus, Timaeus, Laws, Gorgias, Protagoras, Theaetetus
- **Aristotle** (13 works): Categories, Ethics, Rhetoric, Physics, Metaphysics, Poetics, Politics, On the Soul, On the Heavens, Prior Analytics, Posterior Analytics, Topics, On Generation & Corruption
- **Stoics** (4 works): Marcus Aurelius Meditations, Epictetus Discourses & Enchiridion, Seneca Moral Essays
- **Roman** (4 works): Lucretius, Cicero (On Duties, Nature of Gods, On Friendship)
- **Early Modern** (7 works): Descartes (x2), Kant, Spinoza, Hobbes, Locke, Bacon
- **Enlightenment/19th c.** (10 works): Hume, Rousseau, Nietzsche (x2), Mill (x2), Machiavelli, Emerson, Thoreau, Montaigne
- **Other** (4 works): Boethius, Diogenes/Epicurus, Latin Grammar, Schopenhauer

Each text is lowercased, restricted to a-z, space, and period, and split into paragraphs to form the `TRAINING_DATA` document array. Paragraphs shorter than 20 characters are dropped, and longer ones are cut into chunks of at most 512 characters, preferably at a sentence break.
Vocabulary is built dynamically from the data (a-z, space, period, plus BOS).
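
A minimal sketch of how a character vocabulary with a BOS token can be derived from the documents. The names `uchars`, `BOS`, and `vocab_size` mirror those used by the checkpoint loader; `docs`, `stoi`, `encode`, and `decode` are hypothetical helpers for illustration, and wrapping each document in BOS on both sides is an assumption rather than something this section specifies:

```julia
# Hypothetical stand-in for TRAINING_DATA (the real array is built in the next cell).
docs = ["the good life.", "know thyself."]

# Unique characters, sorted so token ids are reproducible across runs.
uchars = sort(unique(join(docs)))

# BOS takes the id just past the last character, so vocab_size == BOS.
BOS = length(uchars) + 1
vocab_size = BOS

# Character -> 1-based token id lookup.
stoi = Dict(c => i for (i, c) in enumerate(uchars))

# Assumed convention: wrap each document in BOS tokens.
encode(doc) = [BOS; [stoi[c] for c in doc]; BOS]
decode(ids) = join(uchars[i] for i in ids if i != BOS)

ids = encode(docs[1])
println("vocab_size = ", vocab_size)
println(decode(ids))
```

Sorting the characters keeps token ids stable, which matters when a checkpoint stores `uchars` and is reloaded in a later session.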

In [None]:
using Downloads

function download_and_clean(url::String, fn::String; is_gutenberg=true)
    if !isfile(fn)
        println("Downloading $fn ...")
        try
            Downloads.download(url, fn)
        catch e
            @warn "Download failed: $url -> $e"
            return ""
        end
    end
    txt = read(fn, String)
    if is_gutenberg
        txt = replace(txt, r"(?is)^.*?\*{3}\s*START OF (THE|THIS) PROJECT GUTENBERG.*?\*{3}[\r\n]*" => "")
        txt = replace(txt, r"(?is)\*{3}\s*END OF (THE|THIS) PROJECT GUTENBERG.*$" => "")
    end
    txt = replace(txt, r"\r\n" => "\n")
    txt = replace(txt, r"\n{3,}" => "\n\n")
    return strip(txt)
end

sources = Dict(
    # ── Original sources (14) ──
    "grammar"             => ("https://www.gutenberg.org/files/15665/15665-0.txt",        "latin_grammar.txt",         true),
    "categories"          => ("https://www.gutenberg.org/ebooks/2412.txt.utf-8",          "aristotle_categories.txt",  true),
    "rhetoric"            => ("http://classics.mit.edu/Aristotle/rhetoric.mb.txt",        "aristotle_rhetoric.txt",    false),
    "prior_analytics"     => ("http://classics.mit.edu/Aristotle/prior.mb.txt",           "prior_analytics.txt",       false),
    "posterior_analytics"  => ("http://classics.mit.edu/Aristotle/posterior.mb.txt",       "posterior_analytics.txt",   false),
    "topics"              => ("http://classics.mit.edu/Aristotle/topics.mb.txt",          "topics.txt",                false),
    "boethius"            => ("https://www.gutenberg.org/files/14328/14328-0.txt",        "boethius_consolation.txt",  true),
    "heavens"             => ("http://classics.mit.edu/Aristotle/heavens.mb.txt",         "aristotle_heavens.txt",     false),
    "republic"            => ("https://www.gutenberg.org/files/1497/1497-0.txt",          "plato_republic.txt",        true),
    "apology"             => ("https://www.gutenberg.org/files/1656/1656-0.txt",          "plato_apology.txt",         true),
    "ethics"              => ("https://www.gutenberg.org/files/8438/8438-0.txt",          "aristotle_ethics.txt",      true),
    "emerson"             => ("https://www.gutenberg.org/files/2944/2944-0.txt",          "emerson_essays.txt",        true),
    "walden"              => ("https://www.gutenberg.org/files/205/205-0.txt",            "thoreau_walden.txt",        true),
    "epicurus"            => ("https://www.gutenberg.org/files/57342/57342-0.txt",        "diogenes_epicurus.txt",     true),
    # ── Plato (10 new) ──
    "plato_symposium"     => ("https://www.gutenberg.org/ebooks/1600.txt.utf-8",          "plato_symposium.txt",       true),
    "plato_phaedo"        => ("https://www.gutenberg.org/ebooks/1658.txt.utf-8",          "plato_phaedo.txt",          true),
    "plato_crito"         => ("https://www.gutenberg.org/ebooks/1657.txt.utf-8",          "plato_crito.txt",           true),
    "plato_meno"          => ("https://www.gutenberg.org/ebooks/1643.txt.utf-8",          "plato_meno.txt",            true),
    "plato_phaedrus"      => ("https://www.gutenberg.org/ebooks/1636.txt.utf-8",          "plato_phaedrus.txt",        true),
    "plato_timaeus"       => ("https://www.gutenberg.org/ebooks/1572.txt.utf-8",          "plato_timaeus.txt",         true),
    "plato_laws"          => ("https://www.gutenberg.org/ebooks/1750.txt.utf-8",          "plato_laws.txt",            true),
    "plato_gorgias"       => ("https://www.gutenberg.org/ebooks/1672.txt.utf-8",          "plato_gorgias.txt",         true),
    "plato_protagoras"    => ("https://www.gutenberg.org/ebooks/1591.txt.utf-8",          "plato_protagoras.txt",      true),
    "plato_theaetetus"    => ("https://www.gutenberg.org/ebooks/1726.txt.utf-8",          "plato_theaetetus.txt",      true),
    # ── Aristotle (6 new) ──
    "aristotle_physics"   => ("http://classics.mit.edu/Aristotle/physics.mb.txt",         "aristotle_physics.txt",     false),
    "aristotle_metaphysics" => ("http://classics.mit.edu/Aristotle/metaphysics.mb.txt",   "aristotle_metaphysics.txt", false),
    "aristotle_soul"      => ("http://classics.mit.edu/Aristotle/soul.mb.txt",            "aristotle_soul.txt",        false),
    "aristotle_poetics"   => ("https://www.gutenberg.org/files/1974/1974.txt",            "aristotle_poetics.txt",     true),
    "aristotle_politics"  => ("https://www.gutenberg.org/ebooks/6762.txt.utf-8",          "aristotle_politics.txt",    true),
    "aristotle_generation" => ("http://classics.mit.edu/Aristotle/gener_corr.mb.txt",     "aristotle_generation.txt",  false),
    # ── Stoics (4 new) ──
    "marcus_meditations"  => ("https://www.gutenberg.org/files/2680/2680-0.txt",          "marcus_meditations.txt",    true),
    "epictetus_discourses" => ("https://www.gutenberg.org/ebooks/10661.txt.utf-8",        "epictetus_discourses.txt",  true),
    "epictetus_enchiridion" => ("https://www.gutenberg.org/ebooks/45109.txt.utf-8",       "epictetus_enchiridion.txt", true),
    "seneca_moral_essays" => ("https://www.gutenberg.org/ebooks/64576.txt.utf-8",         "seneca_moral_essays.txt",   true),
    # ── Roman philosophers (4 new) ──
    "lucretius"           => ("https://www.gutenberg.org/ebooks/785.txt.utf-8",           "lucretius_nature.txt",      true),
    "cicero_duties"       => ("https://www.gutenberg.org/ebooks/47001.txt.utf-8",         "cicero_duties.txt",         true),
    "cicero_nature_gods"  => ("https://www.gutenberg.org/files/14988/14988.txt",          "cicero_nature_gods.txt",    true),
    "cicero_friendship"   => ("https://www.gutenberg.org/ebooks/2808.txt.utf-8",          "cicero_friendship.txt",     true),
    # ── Early Modern (6 new) ──
    "descartes_method"    => ("https://www.gutenberg.org/ebooks/59.txt.utf-8",            "descartes_method.txt",      true),
    "descartes_meditations" => ("https://www.gutenberg.org/ebooks/70091.txt.utf-8",       "descartes_meditations.txt", true),
    "kant_pure_reason"    => ("https://www.gutenberg.org/ebooks/4280.txt.utf-8",          "kant_pure_reason.txt",      true),
    "spinoza_ethics"      => ("https://www.gutenberg.org/ebooks/3800.txt.utf-8",          "spinoza_ethics.txt",        true),
    "hobbes_leviathan"    => ("https://www.gutenberg.org/ebooks/3207.txt.utf-8",          "hobbes_leviathan.txt",      true),
    "locke_government"    => ("https://www.gutenberg.org/ebooks/7370.txt.utf-8",          "locke_government.txt",      true),
    # ── Enlightenment & 19th century (8 new) ──
    "hume_understanding"  => ("https://www.gutenberg.org/ebooks/9662.txt.utf-8",          "hume_understanding.txt",    true),
    "rousseau_social_contract" => ("https://www.gutenberg.org/files/46333/46333-0.txt",   "rousseau_social_contract.txt", true),
    "nietzsche_beyond"    => ("https://www.gutenberg.org/ebooks/4363.txt.utf-8",          "nietzsche_beyond.txt",      true),
    "nietzsche_zarathustra" => ("https://www.gutenberg.org/ebooks/1998.txt.utf-8",        "nietzsche_zarathustra.txt", true),
    "mill_liberty"        => ("https://www.gutenberg.org/ebooks/34901.txt.utf-8",         "mill_liberty.txt",          true),
    "mill_utilitarianism" => ("https://www.gutenberg.org/ebooks/11224.txt.utf-8",         "mill_utilitarianism.txt",   true),
    "machiavelli_prince"  => ("https://www.gutenberg.org/files/57037/57037-0.txt",        "machiavelli_prince.txt",    true),
    "bacon_essays"        => ("https://www.gutenberg.org/ebooks/575.txt.utf-8",           "bacon_essays.txt",          true),
    # ── Essays (2 new) ──
    "montaigne_essays"    => ("https://www.gutenberg.org/ebooks/3600.txt.utf-8",          "montaigne_essays.txt",      true),
    "schopenhauer_essays" => ("https://www.gutenberg.org/ebooks/11945.txt.utf-8",         "schopenhauer_essays.txt",   true)
)

# Download all texts
texts = Dict{String,String}()
for (key, (url, fn, is_gut)) in sources
    texts[key] = download_and_clean(url, fn; is_gutenberg=is_gut)
end
println("Downloaded $(count(!isempty, values(texts))) of $(length(sources)) texts.")

# Unicode normalization, applied to the cleaned in-memory text.
# (Re-reading the raw files here would bring back the Project Gutenberg
# header/footer that download_and_clean stripped.)
for key in collect(keys(texts))
    txt = replace(texts[key],
        "\u201c" => "\"", "\u201d" => "\"",
        "\u2018" => "'", "\u2019" => "'",
        "\u2014" => "--", "\u2013" => "-",
        "\u2026" => "...", "\u00A0" => " "
    )
    txt = replace(txt, r"\n{3,}" => "\n\n")
    texts[key] = strip(txt)
end

# Split all texts into paragraphs -> TRAINING_DATA array
TRAINING_DATA = String[]
for (key, _) in sources
    txt = get(texts, key, "")
    isempty(txt) && continue  # skip failed downloads
    # Normalize to lowercase, keep a-z, space, period, newlines
    txt = lowercase(txt)
    txt = replace(txt, r"[^a-z \.\n]" => " ")
    txt = replace(txt, r"  +" => " ")
    # Split on blank lines into paragraphs
    paragraphs = split(txt, r"\n\n+")
    for p in paragraphs
        cleaned = strip(replace(String(p), r"\n" => " "))
        cleaned = replace(cleaned, r"  +" => " ")
        # Only keep paragraphs with meaningful content (20+ chars)
        if length(cleaned) >= 20
            # Truncate very long paragraphs to block_size-friendly chunks (~512 chars)
            while length(cleaned) > 512
                # Find a sentence break near 512
                cutoff = min(512, length(cleaned))
                dot_pos = findlast('.', cleaned[1:cutoff])
                if dot_pos !== nothing && dot_pos > 100
                    push!(TRAINING_DATA, strip(cleaned[1:dot_pos]))
                    cleaned = strip(cleaned[dot_pos+1:end])
                else
                    push!(TRAINING_DATA, strip(cleaned[1:cutoff]))
                    cleaned = strip(cleaned[cutoff+1:end])
                end
            end
            if length(cleaned) >= 20
                push!(TRAINING_DATA, cleaned)
            end
        end
    end
end

println("TRAINING_DATA: $(length(TRAINING_DATA)) paragraphs from 54 philosophy texts")
println("Total characters: $(sum(length, TRAINING_DATA))")
println("Sample: \"$(TRAINING_DATA[1][1:min(80, end)])...\"")

---
## 3. Neural Network Helpers

Weight matrices are `Param(Matrix{Float32})` -- AutoGrad.jl tracks gradients through array operations automatically.  
Helper functions for parameter initialization, key ordering, and checkpoint support.

In [None]:
# Helper: deterministic parameter key ordering (for checkpoint serialization)
function get_param_keys(n_layer::Int)
    keys = ["wte", "wpe", "lm_head"]
    for i in 0:n_layer-1
        append!(keys, [
            "layer$i.attn_wq", "layer$i.attn_wk", "layer$i.attn_wv", "layer$i.attn_wo",
            "layer$i.mlp_fc1", "layer$i.mlp_fc2"
        ])
    end
    return keys
end

# Helper: initialize a Param-wrapped weight matrix
function init_param(nout::Int, nin::Int; std=0.08f0)
    Param(randn(Float32, nout, nin) .* std)
end

# Helper: collect all Param objects from state_dict in deterministic order
collect_params(state_dict, param_keys) = [state_dict[key] for key in param_keys]

println("Neural network helpers defined (get_param_keys, init_param, collect_params)")

---
## 4. GPT Forward Pass

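Processes one token at a time, using a KV cache for causal masking: each token can attend only to positions already written to the cache.  
All operations are array matrix multiplications that AutoGrad.jl can differentiate.  
The KV cache stores detached (plain `Float32`) keys and values, and the current token's K and V are read back from that same cache. Gradients therefore reach the attention block only through the query path (`attn_wq`) and the output projection (`attn_wo`); as written, `attn_wk` and `attn_wv` receive no gradient.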
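
To make the cache-based masking concrete, here is a single-head sketch in plain Julia without autograd. All names (`attend!`, `softmax`, `K`, `V`) are illustrative and not the notebook's API; the point is only that attending over columns `1:n` of the cache hides future positions without an explicit mask:

```julia
head_dim, block_size = 4, 8
K = zeros(Float32, head_dim, block_size)   # one column per cached position
V = zeros(Float32, head_dim, block_size)
n_cached = 0

function softmax(z)
    e = exp.(z .- maximum(z))              # subtract the max for numerical stability
    return e ./ sum(e)
end

function attend!(K, V, n_cached, q, k, v)
    n = n_cached + 1
    K[:, n] = k                            # append this token's key and value
    V[:, n] = v
    Kv = @view K[:, 1:n]                   # only positions 1..n exist so far,
    Vv = @view V[:, 1:n]                   # so later tokens are never visible
    w = softmax((Kv' * q) ./ sqrt(Float32(size(K, 1))))
    return Vv * w, n
end

q = ones(Float32, head_dim)
k = ones(Float32, head_dim)
out1, n_cached = attend!(K, V, n_cached, q, k, Float32[1, 0, 0, 0])
out2, n_cached = attend!(K, V, n_cached, q, k, Float32[0, 1, 0, 0])
println(out1)   # the first token can attend only to itself
println(out2)   # the second token mixes the two cached values
```

Because both keys are identical here, the second call weights the two cached values equally.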

In [None]:
function gpt(token_id::Int, pos_id::Int,
             kv_key_mats::Vector{Matrix{Float32}},
             kv_val_mats::Vector{Matrix{Float32}},
             kv_lens::Vector{Int},
             params,
             n_layer::Int, n_head::Int, head_dim::Int)

    d_model = n_head * head_dim

    # Embedding lookup: row indexing into Param matrices
    tok_emb = params.wte[token_id, :]
    pos_emb = params.wpe[pos_id, :]
    x = tok_emb .+ pos_emb
    x = rmsnorm_ag(x)

    for li in 1:n_layer
        layer = params.layers[li]
        x_res = x
        x = rmsnorm_ag(x)

        # Linear projections: W * x where W is (d_model, d_model)
        q = layer.attn_wq * x
        k = layer.attn_wk * x
        v = layer.attn_wv * x

        # Detach K, V for cache storage (gradients don't flow through cached entries)
        k_detached = Float32.(value(k))
        v_detached = Float32.(value(v))

        # In-place column write into pre-allocated cache (plain Float32, not tracked)
        idx = kv_lens[li] + 1
        kv_key_mats[li][:, idx] = k_detached
        kv_val_mats[li][:, idx] = v_detached
        kv_lens[li] = idx

        # Zero-allocation view instead of hcat
        n_cached = idx
        K_mat = @view kv_key_mats[li][:, 1:n_cached]
        V_mat = @view kv_val_mats[li][:, 1:n_cached]

        # Multi-head attention with typed ntuple (compiler can unroll)
        head_results = ntuple(n_head) do hh
            h = hh - 1
            hs = h * head_dim + 1
            he = hs + head_dim - 1

            q_h = q[hs:he]                   # (head_dim,) tracked
            K_h = K_mat[hs:he, :]             # (head_dim, n_cached) plain Float32
            V_h = V_mat[hs:he, :]             # (head_dim, n_cached) plain Float32

            # Attention scores: K_h' * q_h / sqrt(head_dim)
            scores = (K_h' * q_h) ./ Float32(sqrt(head_dim))
            attn_w = softmax_ag(scores)

            # Weighted sum: V_h * attn_w
            V_h * attn_w
        end

        x_attn = vcat(head_results...)    # (d_model,) tracked

        # Output projection
        x = layer.attn_wo * x_attn
        x = x .+ x_res

        # MLP block
        x_res = x
        x = rmsnorm_ag(x)
        x = layer.mlp_fc1 * x
        x = relu_ag(x)
        x = layer.mlp_fc2 * x
        x = x .+ x_res
    end

    # LM head: (vocab_size, d_model) * (d_model,) = (vocab_size,)
    logits = params.lm_head * x
    return logits
end

# Helper: build type-stable params NamedTuple from state_dict
function build_params(state_dict, n_layer::Int)
    layers = ntuple(n_layer) do li
        i = li - 1
        (
            attn_wq = state_dict["layer$i.attn_wq"],
            attn_wk = state_dict["layer$i.attn_wk"],
            attn_wv = state_dict["layer$i.attn_wv"],
            attn_wo = state_dict["layer$i.attn_wo"],
            mlp_fc1 = state_dict["layer$i.mlp_fc1"],
            mlp_fc2 = state_dict["layer$i.mlp_fc2"],
        )
    end
    return (
        wte = state_dict["wte"],
        wpe = state_dict["wpe"],
        lm_head = state_dict["lm_head"],
        layers = layers,
    )
end

# Helper: allocate KV cache buffers
function alloc_kv_cache(n_layer::Int, d_model::Int, block_size::Int)
    kv_key_mats = [zeros(Float32, d_model, block_size) for _ in 1:n_layer]
    kv_val_mats = [zeros(Float32, d_model, block_size) for _ in 1:n_layer]
    kv_lens = zeros(Int, n_layer)
    return kv_key_mats, kv_val_mats, kv_lens
end

# Helper: reset KV cache for a new sequence
function reset_kv_cache!(kv_lens::Vector{Int})
    kv_lens .= 0
end

println("GPT forward pass defined (optimized: pre-allocated KV cache, typed params, unrolled heads)")

---
## 5. Checkpoint Save/Load

Save and load model weights + optimizer state as JSON.
Checkpoints are synced to HuggingFace Hub for persistence across Colab sessions.
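
Each weight matrix is stored in the JSON as a list of rows. The round trip can be sketched as follows; this is a minimal illustration that uses `permutedims(reduce(hcat, rows))` as an equivalent way to restack the rows, not the loader's exact code:

```julia
W = Float32[1 2 3; 4 5 6]     # a 2x3 weight matrix

# Save side: one Float64 vector per matrix row.
rows = [Float64.(W[i, :]) for i in 1:size(W, 1)]

# Load side: hcat makes each row a column, so transpose back to 2x3.
W2 = Float32.(permutedims(reduce(hcat, rows)))

println(size(W2))
println(W2 == W)
```

Every `Float32` is exactly representable as a `Float64`, so the row lists hold the weights without loss before they are written out.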

In [None]:
function save_checkpoint(path::String, state_dict, param_keys, uchars, hyperparams;
                         adam_m=nothing, adam_v=nothing, step::Int=0,
                         lr::Float64=0.01, b1::Float64=0.85, b2::Float64=0.99,
                         best_val_loss::Float64=Inf,
                         train_losses::Vector{Float64}=Float64[],
                         val_losses::Vector{Float64}=Float64[],
                         total_steps::Int=0, num_steps_target::Int=0)

    sd_data = Dict{String,Any}()
    for k in param_keys
        W = value(state_dict[k])  # unwrap Param -> Matrix{Float32}
        sd_data[k] = [Float64.(W[i, :]) for i in 1:size(W, 1)]
    end

    # Serialize Adam state per param key
    adam_m_data = Dict{String,Any}()
    adam_v_data = Dict{String,Any}()
    if adam_m !== nothing
        for k in param_keys
            if haskey(adam_m, k)
                adam_m_data[k] = [Float64.(adam_m[k][i, :]) for i in 1:size(adam_m[k], 1)]
                adam_v_data[k] = [Float64.(adam_v[k][i, :]) for i in 1:size(adam_v[k], 1)]
            end
        end
    end

    checkpoint = Dict{String,Any}(
        "format" => "autograd_v2",
        "uchars" => [string(c) for c in uchars],
        "hyperparams" => hyperparams,
        "state_dict" => sd_data,
        "optimizer" => Dict{String,Any}(
            "adam_m" => adam_m_data,
            "adam_v" => adam_v_data,
            "step" => step,
            "lr" => lr,
            "beta1" => b1,
            "beta2" => b2
        ),
        "training" => Dict{String,Any}(
            "best_val_loss" => best_val_loss,
            "train_losses" => train_losses,
            "val_losses" => val_losses,
            "total_steps_completed" => total_steps,
            "num_steps_target" => num_steps_target
        )
    )

    # Replace Inf with large number for JSON compatibility
    if best_val_loss == Inf
        checkpoint["training"]["best_val_loss"] = 1e30
    end
    mkpath(dirname(path))
    open(path, "w") do f
        JSON3.write(f, checkpoint)
    end
    vl_str = best_val_loss == Inf ? "Inf" : @sprintf("%.4f", best_val_loss)
    println("Checkpoint saved: $path (step $step, best_val_loss=$vl_str)")
end

function load_checkpoint(path::String)
    println("Loading checkpoint from $path ...")
    raw = JSON3.read(read(path, String))

    uchars = [only(String(s)) for s in raw["uchars"]]
    BOS = length(uchars) + 1
    vocab_size = BOS

    hp = raw["hyperparams"]
    n_layer = Int(hp["n_layer"])
    n_embd = Int(hp["n_embd"])
    block_size = Int(hp["block_size"])
    n_head = Int(hp["n_head"])
    head_dim = n_embd ÷ n_head

    # Detect format version
    fmt = haskey(raw, "format") ? String(raw["format"]) : "v1"

    state_dict = Dict{String, Any}()
    for (key, matrix_rows) in pairs(raw["state_dict"])
        rows = [Float32.(collect(row)) for row in matrix_rows]
        W = vcat([reshape(r, 1, :) for r in rows]...)  # stack rows into Matrix
        state_dict[string(key)] = Param(W)
    end

    # Load Adam state
    opt_raw = raw["optimizer"]
    adam_m = Dict{String, Matrix{Float32}}()
    adam_v = Dict{String, Matrix{Float32}}()

    if fmt == "autograd_v2" && haskey(opt_raw, "adam_m")
        am_raw = opt_raw["adam_m"]
        av_raw = opt_raw["adam_v"]
        if !isempty(am_raw)
            for (key, matrix_rows) in pairs(am_raw)
                rows = [Float32.(collect(row)) for row in matrix_rows]
                adam_m[string(key)] = vcat([reshape(r, 1, :) for r in rows]...)
            end
            for (key, matrix_rows) in pairs(av_raw)
                rows = [Float32.(collect(row)) for row in matrix_rows]
                adam_v[string(key)] = vcat([reshape(r, 1, :) for r in rows]...)
            end
        end
    end

    step = Int(opt_raw["step"])
    lr = Float64(opt_raw["lr"])
    b1 = Float64(opt_raw["beta1"])
    b2 = Float64(opt_raw["beta2"])

    trn = raw["training"]
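    # save_checkpoint writes Inf as 1e30 (JSON has no Inf); map the sentinel back
    best_val_loss = Float64(trn["best_val_loss"])
    best_val_loss >= 1e30 && (best_val_loss = Inf)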
    train_losses = Float64.(collect(trn["train_losses"]))
    val_losses = Float64.(collect(trn["val_losses"]))
    total_steps = Int(trn["total_steps_completed"])
    num_steps_target = Int(trn["num_steps_target"])

    println("  vocab=$vocab_size, embd=$n_embd, layers=$n_layer, step=$step")

    return (;
        state_dict, uchars, BOS, vocab_size,
        n_layer, n_embd, block_size, n_head, head_dim,
        adam_m, adam_v, step, lr, b1, b2,
        best_val_loss, train_losses, val_losses,
        total_steps, num_steps_target
    )
end

println("Checkpoint save/load defined (autograd_v2 format)")

---
## 6. Setup — Dataset, Tokenizer, Parameters

Character-level tokenizer with a BOS token. 90/10 train/val split.
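
As a small sketch of the tokenizer and split described here, the snippet below runs the same steps on a toy corpus. The three documents and the helper names `encode` and `decode` are assumptions for illustration; the notebook itself works on `TRAINING_DATA` and uses `tokenize_doc`.

```julia
# Toy corpus (assumed for illustration only).
docs = ["ab", "ba", "cab"]

# Ordered 90/10 split, same formula as the setup cell.
split_idx = max(1, Int(floor(0.9 * length(docs))))
train_docs = docs[1:split_idx]
val_docs = docs[split_idx+1:end]

# Vocabulary: sorted unique characters, with BOS as one extra id at the end.
uchars = sort(unique(join(docs)))
BOS = length(uchars) + 1
char_to_id = Dict{Char,Int}(ch => i for (i, ch) in enumerate(uchars))

# Wrap each document in BOS on both sides.
encode(doc) = vcat([BOS], [char_to_id[ch] for ch in doc], [BOS])
# Inverse mapping: drop BOS markers and map ids back to characters.
decode(ids) = String([uchars[i] for i in ids if i != BOS])

ids = encode("cab")
println(ids)          # [4, 3, 1, 2, 4]
println(decode(ids))  # cab
println(length(train_docs), " train / ", length(val_docs), " val")
```

BOS doubles as the end-of-document marker, which is why generation can stop as soon as the model samples it.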

In [None]:
# ── Dataset with train/val split (ordered, not shuffled) ──
docs = copy(TRAINING_DATA)
split_idx = max(1, Int(floor(0.9 * length(docs))))
train_docs = docs[1:split_idx]
val_docs = docs[split_idx+1:end]
if isempty(val_docs)
    val_docs = docs[max(1, end-4):end]
    train_docs = docs[1:max(1, end-5)]
end
println("train: $(length(train_docs)) docs, val: $(length(val_docs)) docs")

# ── Tokenizer ──
uchars = sort(unique(join(docs)))
BOS = length(uchars) + 1
vocab_size = BOS
println("vocab size: $vocab_size ($(length(uchars)) chars + BOS)")

# O(1) character-to-index lookup (replaces O(n) findfirst)
char_to_id = Dict{Char, Int}(ch => i for (i, ch) in enumerate(uchars))

# Pre-tokenize all documents once (avoids re-tokenizing every training step)
function tokenize_doc(doc::String, char_to_id::Dict{Char,Int}, BOS::Int)
    vcat([BOS], [char_to_id[ch] for ch in doc], [BOS])
end

train_tokens = [tokenize_doc(doc, char_to_id, BOS) for doc in train_docs]
val_tokens = [tokenize_doc(doc, char_to_id, BOS) for doc in val_docs]
println("Pre-tokenized: $(length(train_tokens)) train, $(length(val_tokens)) val docs")

# ── Hyperparameters ──
n_layer    = 1
n_embd     = 16
block_size = 256
n_head     = 4
head_dim   = n_embd ÷ n_head

hyperparams = Dict{String,Any}(
    "n_layer" => n_layer, "n_embd" => n_embd,
    "block_size" => block_size, "n_head" => n_head
)

# ── Initialize parameters as Param(Matrix{Float32}) ──
state_dict = Dict{String, Any}()
state_dict["wte"]     = init_param(vocab_size, n_embd)   # (vocab_size, d_model)
state_dict["wpe"]     = init_param(block_size, n_embd)   # (block_size, d_model)
state_dict["lm_head"] = init_param(vocab_size, n_embd)   # (vocab_size, d_model)
for i in 0:n_layer-1
    state_dict["layer$i.attn_wq"]  = init_param(n_embd, n_embd)
    state_dict["layer$i.attn_wk"]  = init_param(n_embd, n_embd)
    state_dict["layer$i.attn_wv"]  = init_param(n_embd, n_embd)
    state_dict["layer$i.attn_wo"]  = init_param(n_embd, n_embd)
    state_dict["layer$i.mlp_fc1"]  = init_param(4 * n_embd, n_embd)
    state_dict["layer$i.mlp_fc2"]  = init_param(n_embd, 4 * n_embd)
end

param_keys = get_param_keys(n_layer)
all_params = collect_params(state_dict, param_keys)
total_num_params = sum(length(value(p)) for p in all_params)
println("num params: $total_num_params ($(length(all_params)) weight matrices)")

# Build type-stable params NamedTuple (avoids Dict lookup in hot loop)
params = build_params(state_dict, n_layer)
println("Type-stable params tuple built")

---
## 7. Training Loop with Validation + Best-Model Checkpointing

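Adam optimizer with linear LR decay.
Validates every 50 steps and saves `best_model.json` when validation loss improves.
Writes `checkpoint_latest.json` every 100 steps, plus a time-based auto-save if 10 minutes pass without a save.
Writes `checkpoint_interrupted.json` if training is interrupted or raises an error.
Every checkpoint is synced to HuggingFace Hub automatically.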
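
The Adam update can be tried on a toy matrix before reading the full loop. This is a minimal sketch under assumed values: `adam_step!` is a hypothetical helper, and the weights and gradient are made up. The update rule itself (moment averages, bias correction, scaled step) follows the one used in the training cell.

```julia
# One Adam update on a toy parameter. Names and values are illustrative.
function adam_step!(W, m, v, g; lr, b1, b2, eps, step)
    m .= b1 .* m .+ (1 - b1) .* g          # first-moment running average
    v .= b2 .* v .+ (1 - b2) .* g .^ 2     # second-moment running average
    m_hat = m ./ (1 - b1^step)             # bias correction for zero init
    v_hat = v ./ (1 - b2^step)
    W .-= lr .* m_hat ./ (sqrt.(v_hat) .+ eps)
    return W
end

W = Float32[0.5 -0.5; 1.0 -1.0]
g = Float32[0.2 -0.1; 0.4 -0.3]            # assumed gradient
m = zeros(Float32, size(W))
v = zeros(Float32, size(W))
W_before = copy(W)

adam_step!(W, m, v, g; lr=0.01f0, b1=0.85f0, b2=0.99f0, eps=1f-8, step=1)
println(W_before .- W)
```

On the first step with zero-initialized moments, the bias-corrected update reduces to roughly `lr * sign(g)`, so every entry moves by about `lr` regardless of the gradient's magnitude.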

In [None]:
function compute_val_loss(val_tokens, params, BOS, block_size, n_layer, n_head, head_dim, n_embd)
    total_loss = 0.0
    total_tokens = 0
    kv_key_mats, kv_val_mats, kv_lens = alloc_kv_cache(n_layer, n_embd, block_size)
    for tokens in val_tokens
        n = min(block_size, length(tokens) - 1)
        reset_kv_cache!(kv_lens)
        for pos in 1:n
            token_id = tokens[pos]
            target_id = tokens[pos + 1]
            # Outside @diff, Params behave like plain arrays
            logits = gpt(token_id, pos, kv_key_mats, kv_val_mats, kv_lens, params, n_layer, n_head, head_dim)
            probs = softmax_ag(logits)
            p_val = Float64.(value(probs))
            total_loss += -log(max(p_val[target_id], 1e-10))
            total_tokens += 1
        end
    end
    return total_loss / max(total_tokens, 1)
end

println("compute_val_loss defined (optimized: pre-allocated KV cache, typed params)")

In [None]:
# ── Adam optimizer state (per-parameter matrices) ──
lr, b1, b2, eps = 0.01, 0.85, 0.99, 1e-8

adam_m = Dict{String, Matrix{Float32}}()
adam_v = Dict{String, Matrix{Float32}}()
for k in param_keys
    sz = size(value(state_dict[k]))
    adam_m[k] = zeros(Float32, sz)
    adam_v[k] = zeros(Float32, sz)
end

best_val_loss = Inf
train_loss_history = Float64[]
val_loss_history = Float64[]

# ── Checkpoint paths ──
LOCAL_CKPT = "checkpoints"
mkpath(LOCAL_CKPT)

function save_and_sync(path, sd, pk, uc, hp; kwargs...)
    save_checkpoint(path, sd, pk, uc, hp; kwargs...)
    hf_sync(path)
end

# ── Initialize W&B logging (if API key is set) ──
if haskey(ENV, "WANDB_API_KEY") && !isempty(ENV["WANDB_API_KEY"])
    wandb_init()
end

# ── Optimized training loop (wrapped in function for type stability) ──
function train_loop!(state_dict, params, param_keys, train_tokens, val_tokens,
                     adam_m, adam_v, uchars, hyperparams;
                     num_steps::Int, lr::Float64, b1::Float64, b2::Float64, eps::Float64,
                     n_layer::Int, n_head::Int, head_dim::Int, n_embd::Int,
                     block_size::Int, BOS::Int,
                     best_val_loss::Float64=Inf,
                     train_loss_history::Vector{Float64}=Float64[],
                     val_loss_history::Vector{Float64}=Float64[],
                     start_step::Int=1)

    # Pre-allocate KV cache (reused across steps)
    kv_key_mats, kv_val_mats, kv_lens = alloc_kv_cache(n_layer, n_embd, block_size)

    hf_status = !isempty(get(ENV, "HF_REPO", "")) ? "HF:$(ENV["HF_REPO"])" : ((@isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)) ? "HF:$HF_REPO_ID" : "HF:(not configured)")
    println("--- training $num_steps steps (steps $start_step..$(start_step + num_steps - 1)) ---")
    println("    Local: checkpoints/  |  $hf_status")
    t_start = time()
    last_save_time = time()
    SAVE_INTERVAL = 600
    completed_steps = start_step - 1

    try
        for step in start_step:(start_step + num_steps - 1)
            completed_steps = step
            tokens = train_tokens[mod1(step, length(train_tokens))]
            n = min(block_size, length(tokens) - 1)

            # Reset KV cache for new sequence
            reset_kv_cache!(kv_lens)

            # Single @diff tape for entire sequence (KV cache mutation is safe -- detached Float32)
            tape = @diff begin
                loss_sum = Float32(0)
                for pos in 1:n
                    token_id  = tokens[pos]
                    target_id = tokens[pos + 1]
                    logits = gpt(token_id, pos, kv_key_mats, kv_val_mats, kv_lens, params, n_layer, n_head, head_dim)
                    probs = softmax_ag(logits)
                    loss_sum = loss_sum + (-log(probs[target_id]))
                end
                loss_sum / Float32(n)
            end

            avg_loss = Float64(value(tape))
            push!(train_loss_history, avg_loss)

            # Extract gradients once and apply Adam update
            lr_t = lr * (1 - (step - 1) / (start_step + num_steps - 1))
            for k in param_keys
                g = grad(tape, state_dict[k])
                if g !== nothing
                    g_dense = to_dense_grad(g)
                    adam_m[k] .= Float32(b1) .* adam_m[k] .+ Float32(1 - b1) .* g_dense
                    adam_v[k] .= Float32(b2) .* adam_v[k] .+ Float32(1 - b2) .* g_dense .^ 2
                    m_hat = adam_m[k] ./ Float32(1 - b1^step)
                    v_hat = adam_v[k] ./ Float32(1 - b2^step)
                    value(state_dict[k]) .-= Float32(lr_t) .* m_hat ./ (sqrt.(v_hat) .+ Float32(eps))
                end
            end

            # Validate every 50 steps; save best_model.json when val loss improves
            if step % 50 == 0
                val_loss = compute_val_loss(val_tokens, params, BOS, block_size, n_layer, n_head, head_dim, n_embd)
                push!(val_loss_history, val_loss)
                elapsed = time() - t_start
                wandb_log(; step=step, train_loss=avg_loss, val_loss=val_loss, lr=lr_t)

                improved = ""
                if val_loss < best_val_loss
                    best_val_loss = val_loss
                    save_and_sync("checkpoints/best_model.json", state_dict, param_keys, uchars, hyperparams;
                        adam_m=adam_m, adam_v=adam_v, step=step,
                        lr=lr, b1=b1, b2=b2,
                        best_val_loss=best_val_loss,
                        train_losses=train_loss_history, val_losses=val_loss_history,
                        total_steps=step, num_steps_target=start_step + num_steps - 1)
                    improved = " << new best!"
                end

                @printf("step %4d / %4d | train %.4f | val %.4f | %.1fs%s\n",
                        step, start_step + num_steps - 1, avg_loss, val_loss, elapsed, improved)
            elseif step % 10 == 0
                elapsed = time() - t_start
                @printf("step %4d / %4d | train %.4f | %.1fs\n", step, start_step + num_steps - 1, avg_loss, elapsed)
            end

            # Periodic checkpoint every 100 steps
            if step % 100 == 0
                save_and_sync("checkpoints/checkpoint_latest.json", state_dict, param_keys, uchars, hyperparams;
                    adam_m=adam_m, adam_v=adam_v, step=step,
                    lr=lr, b1=b1, b2=b2,
                    best_val_loss=best_val_loss,
                    train_losses=train_loss_history, val_losses=val_loss_history,
                    total_steps=step, num_steps_target=start_step + num_steps - 1)
                last_save_time = time()
            end

            # Time-based auto-save every 10 min
            if time() - last_save_time > SAVE_INTERVAL
                save_and_sync("checkpoints/checkpoint_latest.json", state_dict, param_keys, uchars, hyperparams;
                    adam_m=adam_m, adam_v=adam_v, step=step,
                    lr=lr, b1=b1, b2=b2,
                    best_val_loss=best_val_loss,
                    train_losses=train_loss_history, val_losses=val_loss_history,
                    total_steps=step, num_steps_target=start_step + num_steps - 1)
                last_save_time = time()
                println("  [auto-save at step $step]")
            end
        end
    catch e
        if e isa InterruptException
            println("\n\nTraining interrupted at step $completed_steps!")
        else
            println("\n\nTraining error at step $completed_steps: $e")
        end
        println("Saving emergency checkpoint...")
        save_and_sync("checkpoints/checkpoint_interrupted.json", state_dict, param_keys, uchars, hyperparams;
            adam_m=adam_m, adam_v=adam_v, step=completed_steps,
            lr=lr, b1=b1, b2=b2,
            best_val_loss=best_val_loss,
            train_losses=train_loss_history, val_losses=val_loss_history,
            total_steps=completed_steps, num_steps_target=start_step + num_steps - 1)
        if !(e isa InterruptException)
            rethrow(e)
        end
    end

    elapsed = time() - t_start
    @printf("\ntraining complete in %.1f seconds\n", elapsed)

    return best_val_loss, train_loss_history, val_loss_history, completed_steps
end

# ── Run training ──
NUM_EPOCHS = 3
num_steps = clamp(NUM_EPOCHS * length(train_tokens), 1000, 50000)

best_val_loss, train_loss_history, val_loss_history, final_step = train_loop!(
    state_dict, params, param_keys, train_tokens, val_tokens,
    adam_m, adam_v, uchars, hyperparams;
    num_steps=num_steps, lr=lr, b1=b1, b2=b2, eps=eps,
    n_layer=n_layer, n_head=n_head, head_dim=head_dim, n_embd=n_embd,
    block_size=block_size, BOS=BOS,
    best_val_loss=best_val_loss,
    train_loss_history=train_loss_history,
    val_loss_history=val_loss_history)

wandb_finish()

# Final save
save_and_sync("checkpoints/final_model.json", state_dict, param_keys, uchars, hyperparams;
    adam_m=adam_m, adam_v=adam_v, step=final_step,
    lr=lr, b1=b1, b2=b2,
    best_val_loss=best_val_loss,
    train_losses=train_loss_history, val_losses=val_loss_history,
    total_steps=final_step, num_steps_target=num_steps)

println("\nCheckpoints saved locally + synced to HF. Best val loss: $(@sprintf("%.4f", best_val_loss))")

---
## 8. Inference — Hallucinated Philosophy

Generate new philosophy-like text using temperature-controlled sampling.
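
To see what the temperature does, the sketch below applies it to a fixed set of logits and then samples by walking the cumulative distribution. `softmax_demo` and `sample_index` are hypothetical names for this illustration (the notebook uses `softmax_ag` and an inline loop), and the logits are made up.

```julia
# Numerically stable softmax (subtract the max before exponentiating).
function softmax_demo(x::AbstractVector)
    e = exp.(x .- maximum(x))
    return e ./ sum(e)
end

# Inverse-CDF categorical sampling: return the first index whose cumulative
# probability reaches r. Falls back to the last index if rounding leaves a gap.
function sample_index(probs::AbstractVector, r::Real)
    cum = 0.0
    for (idx, p) in enumerate(probs)
        cum += p
        r <= cum && return idx
    end
    return length(probs)
end

logits = [2.0, 1.0, 0.0]                # assumed logits
p_sharp = softmax_demo(logits ./ 0.5)   # low temperature: more peaked
p_flat  = softmax_demo(logits ./ 2.0)   # high temperature: closer to uniform

println(round.(p_sharp, digits=3))
println(round.(p_flat, digits=3))
println(sample_index(p_flat, 0.9))      # r is fixed here; use rand() in practice
```

Dividing the logits by a temperature below 1 concentrates probability on the top token, while a temperature above 1 spreads it out, which is why lower values give more conservative samples.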

In [None]:
function generate_text(params, uchars, BOS, n_layer, n_head, head_dim, n_embd, block_size;
                       temperature=0.8, max_tokens=128)
    kv_key_mats, kv_val_mats, kv_lens = alloc_kv_cache(n_layer, n_embd, block_size)
    token_id = BOS
    sample = Char[]
    limit = min(max_tokens, block_size)
    for pos in 1:limit
        # No @diff needed for inference -- Params act like plain arrays outside @diff
        logits = gpt(token_id, pos, kv_key_mats, kv_val_mats, kv_lens, params, n_layer, n_head, head_dim)
        scaled = logits ./ temperature
        probs = softmax_ag(scaled)
        weights = Float64.(value(probs))

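        # Categorical sampling (inverse CDF). Default to the last index so that
        # floating-point round-off in `cum` cannot silently select token 1.
        r = rand()
        cum = 0.0
        token_id = length(weights)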
        for (idx, w) in enumerate(weights)
            cum += w
            if r <= cum
                token_id = idx
                break
            end
        end
        token_id == BOS && break
        push!(sample, uchars[token_id])
    end
    return String(sample)
end

println("--- inference (hallucinated philosophy) ---")
for i in 1:20
    text = generate_text(params, uchars, BOS, n_layer, n_head, head_dim, n_embd, block_size; temperature=0.8)
    @printf("sample %2d: %s\n", i, text)
end

---
## 8a. Push Model to HuggingFace Hub

Push your trained checkpoint to HuggingFace for persistence across Colab sessions.  
Set `HF_REPO_ID` in the login cell above (e.g. `"yourusername/microgpt-philosophy"`).

In [None]:
# Push checkpoint to HuggingFace Hub
# Make sure HF_REPO_ID is set in the login cell above

if @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
    # Create repo if it doesn't exist yet
    hf_create_repo(HF_REPO_ID)

    # Push best model checkpoint
    if isfile("checkpoints/best_model.json")
        hf_push_checkpoint(HF_REPO_ID; checkpoint_path="checkpoints/best_model.json")
    else
        println("No best_model.json found — train first!")
    end

    # Also push final model if it exists
    if isfile("checkpoints/final_model.json")
        hf_push(HF_REPO_ID, "checkpoints/final_model.json")
    end

    println("\nDone! View your model at: https://huggingface.co/$HF_REPO_ID")
else
    println("Set HF_REPO_ID in the login cell (e.g. \"yourusername/microgpt-philosophy\")")
end

---
## 8b. Pull Checkpoint from HuggingFace Hub

Download a previously pushed checkpoint to resume training in a new Colab session.  
Run this before the Resume Training cell below.

In [None]:
# Pull checkpoint from HuggingFace to resume training
# Make sure HF_REPO_ID is set in the login cell above

if @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
    mkpath("checkpoints")
    hf_pull(HF_REPO_ID, "best_model.json"; local_dir="checkpoints")
    println("\nReady to resume from checkpoints/best_model.json")
    println("Run the 'Resume Training' cell below.")
else
    println("Set HF_REPO_ID in the login cell (e.g. \"yourusername/microgpt-philosophy\")")
end

---
## 9. Resume Training from Checkpoint

Load a saved checkpoint and continue training for more steps.  
Skip this cell if you're training from scratch above.
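
One consequence of resuming is worth checking numerically: the training loop's linear decay divides by the final step index of the current run, so the schedule is recomputed when more steps are added. The sketch below evaluates that formula for an assumed scenario (1000 initial steps, then 500 more); `lr_at` is a hypothetical helper, not part of the notebook.

```julia
# Linear decay as written in the training loop: the denominator is the final
# step index of the current run (start_step + num_steps - 1).
lr_at(step; lr, start_step, num_steps) =
    lr * (1 - (step - 1) / (start_step + num_steps - 1))

lr = 0.01

# Assumed scenario: a first run of 1000 steps, then a resume for 500 more.
last_lr_first_run = lr_at(1000; lr=lr, start_step=1,    num_steps=1000)
first_lr_resumed  = lr_at(1001; lr=lr, start_step=1001, num_steps=500)
last_lr_resumed   = lr_at(1500; lr=lr, start_step=1001, num_steps=500)

println(last_lr_first_run)
println(first_lr_resumed)
println(last_lr_resumed)
```

Under these assumptions the rate climbs from about `1e-5` at the end of the first run to about `3.3e-3` at the first resumed step, then decays again, so a resumed run behaves somewhat like a warm restart rather than a continuation of the old schedule.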

In [None]:
# ── Load checkpoint ──
# Change the path to load a different checkpoint
RESUME_FROM = "checkpoints/best_model.json"
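# ~1 extra epoch over the 90% train split. Computed from TRAINING_DATA because
# train_docs is only rebuilt further down in this cell.
EXTRA_STEPS = clamp(max(1, Int(floor(0.9 * length(TRAINING_DATA)))), 500, 25000)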

if !isfile(RESUME_FROM)
    # Try pulling from HuggingFace
    if @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
        println("Checkpoint not found locally, pulling from HuggingFace...")
        hf_pull(HF_REPO_ID, basename(RESUME_FROM); local_dir="checkpoints")
    end
    isfile(RESUME_FROM) || error("Checkpoint not found: $RESUME_FROM")
end

ckpt = load_checkpoint(RESUME_FROM)
state_dict = ckpt.state_dict
uchars = ckpt.uchars
BOS = ckpt.BOS
n_layer = ckpt.n_layer
n_embd = ckpt.n_embd
block_size = ckpt.block_size
n_head = ckpt.n_head
head_dim = ckpt.head_dim

hyperparams = Dict{String,Any}(
    "n_layer" => n_layer, "n_embd" => n_embd,
    "block_size" => block_size, "n_head" => n_head
)

# Reconstruct dataset and split (ordered, not shuffled)
docs = copy(TRAINING_DATA)
split_idx = max(1, Int(floor(0.9 * length(docs))))
train_docs = docs[1:split_idx]
val_docs = docs[split_idx+1:end]
if isempty(val_docs)
    val_docs = docs[max(1, end-4):end]
    train_docs = docs[1:max(1, end-5)]
end

# Rebuild O(1) tokenizer and pre-tokenize
char_to_id = Dict{Char, Int}(ch => i for (i, ch) in enumerate(uchars))
train_tokens = [tokenize_doc(doc, char_to_id, BOS) for doc in train_docs]
val_tokens = [tokenize_doc(doc, char_to_id, BOS) for doc in val_docs]

param_keys = get_param_keys(n_layer)

# Build type-stable params tuple
params = build_params(state_dict, n_layer)

# Restore optimizer (Adam state per param key)
lr = ckpt.lr; b1 = ckpt.b1; b2 = ckpt.b2; eps = 1e-8
adam_m = if !isempty(ckpt.adam_m)
    ckpt.adam_m
else
    Dict{String, Matrix{Float32}}(k => zeros(Float32, size(value(state_dict[k]))) for k in param_keys)
end
adam_v = if !isempty(ckpt.adam_v)
    ckpt.adam_v
else
    Dict{String, Matrix{Float32}}(k => zeros(Float32, size(value(state_dict[k]))) for k in param_keys)
end

start_step = ckpt.step + 1
best_val_loss = ckpt.best_val_loss
train_loss_history = copy(ckpt.train_losses)
val_loss_history = copy(ckpt.val_losses)

# ── Initialize W&B logging (if API key is set) ──
if haskey(ENV, "WANDB_API_KEY") && !isempty(ENV["WANDB_API_KEY"])
    wandb_init()
end

println("\nResuming from step $(ckpt.step) -> training $EXTRA_STEPS more steps")
println("Best val loss so far: $(round(best_val_loss, digits=4))")

best_val_loss, train_loss_history, val_loss_history, final_step = train_loop!(
    state_dict, params, param_keys, train_tokens, val_tokens,
    adam_m, adam_v, uchars, hyperparams;
    num_steps=EXTRA_STEPS, lr=lr, b1=b1, b2=b2, eps=Float64(eps),
    n_layer=n_layer, n_head=n_head, head_dim=head_dim, n_embd=n_embd,
    block_size=block_size, BOS=BOS,
    best_val_loss=best_val_loss,
    train_loss_history=train_loss_history,
    val_loss_history=val_loss_history,
    start_step=start_step)

wandb_finish()

save_and_sync("checkpoints/final_model.json", state_dict, param_keys, uchars, hyperparams;
    adam_m=adam_m, adam_v=adam_v, step=final_step,
    lr=lr, b1=b1, b2=b2,
    best_val_loss=best_val_loss,
    train_losses=train_loss_history, val_losses=val_loss_history,
    total_steps=final_step, num_steps_target=start_step + EXTRA_STEPS - 1)

---
## 10. Download Checkpoint

Download the best model checkpoint to use with the inference server.
In Colab, use the Files panel (left sidebar) to download, or pull from HuggingFace Hub.

In [None]:
# List saved checkpoints
if isdir("checkpoints")
    files = readdir("checkpoints")
    println("Saved checkpoints:")
    for f in files
        path = joinpath("checkpoints", f)
        size_kb = round(filesize(path) / 1024, digits=1)
        println("  $path ($(size_kb) KB)")
    end
else
    println("No checkpoints directory found. Train first!")
end