<a href="https://colab.research.google.com/github/DavinciDreams/JuliaGPT/blob/main/juliagpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# JuliaGPT — A Minimal GPT in Pure Julia

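Julia port of Karpathy's minimal, dependency-free GPT (gist linked below): the most atomic way to train a GPT.  
Array-based autograd via AutoGrad.jl, transformer, Adam optimizer.  
The model and training code need only Julia stdlib + AutoGrad.jl. JSON3, HTTP and HuggingFaceApi handle data and checkpoint transfer, and CUDA is used only for optional GPU detection.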

**Architecture** (following GPT-2 with simplifications):
- AutoGrad.jl array-based automatic differentiation (`Param` wrapped matrices)
- Transformer with multi-head attention; depth and width (1–4 layers, 64–256 dims) scale with available RAM
- RMSNorm (not LayerNorm), no biases, ReLU (not GELU)
- KV cache for natural causal masking
- Adam optimizer with linear warmup, cosine LR decay, and gradient clipping
- Temperature-controlled generation
- Best-model checkpointing with validation loss tracking (local + HuggingFace Hub)

Based on: https://gist.github.com/karpathy/8627fe009c40f57531cb18360106ce95

## 0. Setup

**Credentials** (needed for checkpoint sync + logging):
- Option A: Upload a `.env` file with `HF_TOKEN`, `WANDB_API_KEY`, and `HF_REPO`
- Option B: Add secrets via Colab sidebar (key icon) with the same names
- Option C: Set environment variables before launching Julia

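**Then run all cells.** Colab's Julia runtime needs no Julia installation. Packages are installed by the `Pkg.add` cell, which comes *after* the first two code cells, and those cells already run `using JSON3`, `HTTP`, …. On a fresh runtime, run that install cell first.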


In [None]:
using JSON3

# ── Load environment variables from .env file ──
"""
    load_dotenv(paths=[".env", joinpath(homedir(), ".env")]) -> Bool

Load `KEY=value` lines from the first file in `paths` that exists into `ENV`.
Existing entries are overwritten. Returns `true`, or `false` if no file in
`paths` exists. Only that first file is read.

- Blank lines and lines starting with `#` are skipped.
- Lines that do not look like `KEY=value` are ignored.
- An optional leading `export ` is accepted.
- One pair of matching single or double quotes around the value is removed.
"""
function load_dotenv(paths=[".env", joinpath(homedir(), ".env")])
    for path in paths
        isfile(path) || continue
        println("Loading secrets from $path")
        for line in eachline(path)
            stripped = strip(line)
            (isempty(stripped) || startswith(stripped, '#')) && continue
            m = match(r"^(?:export\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(.*)", stripped)
            m === nothing && continue
            key, val = string(m[1]), strip(string(m[2]))
            # `chop` works on characters. The previous byte-index slice `val[2:end-1]`
            # threw a StringIndexError when a multi-byte character preceded the quote.
            if length(val) >= 2 && first(val) == last(val) && first(val) in ('"', '\'')
                val = chop(val; head=1, tail=1)
            end
            ENV[key] = val
        end
        return true
    end
    return false
end

load_dotenv()

# ── Read credentials from ENV ──
# Report whether each secret is present and non-empty; the values themselves are never printed.
for (var, found_msg, missing_msg) in (
    ("HF_TOKEN", "HF token: found",
     "HF token: not found (set HF_TOKEN in .env or Colab secrets)"),
    ("WANDB_API_KEY", "W&B API key: found",
     "W&B API key: not found (set WANDB_API_KEY in .env or Colab secrets)"),
)
    println(isempty(get(ENV, var, "")) ? missing_msg : found_msg)
end

# Target model repo for checkpoint sync; empty string disables syncing.
HF_REPO_ID = get(ENV, "HF_REPO", "")
if isempty(HF_REPO_ID)
    println("HF repo: not set (set HF_REPO in .env or Colab secrets)")
else
    println("HF repo: ", HF_REPO_ID)
end


In [None]:
# ═══════════════════════════════════════════════════════════════
# HuggingFace Hub helpers (pure Julia via HTTP.jl)
# ═══════════════════════════════════════════════════════════════

using HTTP, Downloads
import Base64
using SHA
using HuggingFaceApi

# ── Download files from HuggingFace ──
"""
    hf_download(repo_id, filename; local_dir=".", repo_type="dataset") -> String

Return `joinpath(local_dir, filename)`. If the file does not exist yet, it is
downloaded from the HuggingFace Hub first.

Download order:
1. `HuggingFaceApi.hf_hub_download`.
2. If that throws: a warning is logged and a direct HTTPS download from
   `…/resolve/main/filename` is tried. Errors from this fallback propagate.

`ENV["HF_TOKEN"]` is used for authentication when set.
"""
function hf_download(repo_id::String, filename::String;
                     local_dir::String=".", repo_type::String="dataset")
    local_path = joinpath(local_dir, filename)
    isfile(local_path) && return local_path
    # Create the file's own parent directory, which also covers filenames with subfolders.
    target_dir = dirname(abspath(local_path))
    mkpath(target_dir)
    token = get(ENV, "HF_TOKEN", "")
    try
        path = HuggingFaceApi.hf_hub_download(repo_id, filename;
                    repo_type=repo_type, auth_token=token)
        cp(path, local_path; force=true)
        println("  Downloaded: $filename ($(filesize(local_path)) bytes)")
    catch e
        # Surface why the primary path failed instead of discarding the error silently.
        @warn "hf_hub_download failed for $repo_id/$filename; trying direct HTTP download" exception=e
        prefix = repo_type == "dataset" ? "datasets/" : ""
        url = "https://huggingface.co/$(prefix)$(repo_id)/resolve/main/$(filename)"
        headers = isempty(token) ? Pair{String,String}[] : ["Authorization" => "Bearer $token"]
        # Download to a temp file and move it into place only on success. A failed or
        # interrupted download then never leaves a partial file that the `isfile` check
        # above would later treat as complete.
        tmp_path = tempname(target_dir)
        try
            Downloads.download(url, tmp_path; headers)
            mv(tmp_path, local_path; force=true)
        finally
            rm(tmp_path; force=true)
        end
        println("  Downloaded (HTTP fallback): $filename ($(filesize(local_path)) bytes)")
    end
    return local_path
end

# ── Upload files to HuggingFace ──
# Small files (<5MB): JSON commit API with base64 content
# Large files (>=5MB): Git LFS batch API
"""
    hf_upload(repo_id, local_path; remote_path="", repo_type="model")

Upload `local_path` to the HuggingFace repo `repo_id` on branch `main`.
The target path is `remote_path`, defaulting to the file's basename.

- Files under 5,000,000 bytes are sent inline (base64) through the commit API.
- Larger files are first uploaded through the Git LFS batch API and then
  committed as LFS files.

Any `repo_type` other than `"model"` is treated as a dataset. Requires
`ENV["HF_TOKEN"]`; without it, a warning is logged and `nothing` is returned.
HTTP failures are reported with `@warn` rather than thrown.
"""
function hf_upload(repo_id::String, local_path::String;
                   remote_path::String="", repo_type::String="model")
    token = get(ENV, "HF_TOKEN", "")
    isempty(token) && (@warn "Cannot upload: no HF_TOKEN set"; return)
    rp = isempty(remote_path) ? basename(local_path) : remote_path
    prefix = repo_type == "model" ? "models" : "datasets"
    base_url = "https://huggingface.co/api/$(prefix)/$(repo_id)"
    # Git/LFS endpoints have no prefix for models and "datasets/" for datasets.
    # The LFS URL previously ignored repo_type, so dataset uploads hit the model endpoint.
    git_url = repo_type == "model" ? "https://huggingface.co/$(repo_id)" :
                                     "https://huggingface.co/datasets/$(repo_id)"
    data = read(local_path)
    nbytes = length(data)
    size_mb = round(nbytes / 1024 / 1024, digits=1)
    auth_headers = ["Authorization" => "Bearer $token"]
    json_headers = vcat(auth_headers, ["Content-Type" => "application/json"])

    if nbytes < 5_000_000
        body = JSON3.write(Dict(
            "summary" => "Upload $rp",
            "files" => [Dict("path" => rp, "content" => Base64.base64encode(data),
                             "encoding" => "base64")]
        ))
        try
            HTTP.post("$(base_url)/commit/main", json_headers, body)
            println("Pushed $local_path -> $repo_id/$rp ($(nbytes) bytes)")
        catch e
            @warn "Upload failed: $e"
        end
        return
    end

    file_sha = bytes2hex(sha256(data))
    println("Uploading $(size_mb) MB via LFS...")
    lfs_url = "$(git_url).git/info/lfs/objects/batch"
    lfs_body = JSON3.write(Dict(
        "operation" => "upload",
        "transfers" => ["basic"],
        "objects" => [Dict("oid" => file_sha, "size" => nbytes)]
    ))
    lfs_headers = vcat(auth_headers, [
        "Content-Type" => "application/vnd.git-lfs+json",
        "Accept" => "application/vnd.git-lfs+json"
    ])
    try
        resp = HTTP.post(lfs_url, lfs_headers, lfs_body)
        obj = JSON3.read(String(resp.body)).objects[1]
        # Per-object failures are reported in an "error" field. Never commit a pointer
        # to an object that was not stored.
        haskey(obj, :error) && error("LFS batch rejected object: $(obj.error)")
        # Without an "upload" action there is nothing to send for this object.
        if haskey(obj, :actions) && haskey(obj.actions, :upload)
            upload_action = obj.actions.upload
            upload_hdrs = Pair{String,String}[]
            if haskey(upload_action, :header)
                for (k, v) in pairs(upload_action.header)
                    push!(upload_hdrs, string(k) => string(v))
                end
            end
            HTTP.put(string(upload_action.href), upload_hdrs, data)
        end
        commit_body = JSON3.write(Dict(
            "summary" => "Upload $rp ($(size_mb) MB)",
            "lfsFiles" => [Dict("path" => rp, "algo" => "sha256", "oid" => file_sha,
                                "size" => nbytes)]
        ))
        HTTP.post("$(base_url)/commit/main", json_headers, commit_body)
        println("Pushed $local_path -> $repo_id/$rp ($(nbytes) bytes, LFS)")
    catch e
        @warn "LFS upload failed: $e"
    end
    return
end

# Upload `local_path` to `repo_id` via `hf_upload` with its default repo type.
hf_push(repo_id::String, local_path::String; remote_path::String="") =
    hf_upload(repo_id, local_path; remote_path=remote_path)

# Fetch `filename` from model repo `repo_id` into `local_dir`; returns the local path.
hf_pull(repo_id::String, filename::String; local_dir::String="checkpoints") =
    hf_download(repo_id, filename; repo_type="model", local_dir)

# Push a saved checkpoint file to `repo_id`; errors if the file does not exist.
function hf_push_checkpoint(repo_id::String; checkpoint_path::String="checkpoints/best_model.json")
    if !isfile(checkpoint_path)
        error("Checkpoint not found: $checkpoint_path")
    end
    return hf_push(repo_id, checkpoint_path)
end

"""
    hf_create_repo(repo_id)

Create a public model repo `repo_id` (`"namespace/name"` or `"name"`) on the
HuggingFace Hub using `ENV["HF_TOKEN"]`. Does nothing when no token is set.

An HTTP 409 (repo already exists) prints a message. Any other failure is logged
with `@warn` instead of being reported as "already exists". Returns `nothing`.
"""
function hf_create_repo(repo_id::String)
    token = get(ENV, "HF_TOKEN", "")
    isempty(token) && return
    parts = split(repo_id, "/")
    payload = Dict{String,Any}("name" => parts[end], "type" => "model", "private" => false)
    # Send the namespace as well. Previously it was dropped, so "org/name" ids were
    # always created under the token owner's account.
    length(parts) >= 2 && (payload["organization"] = parts[1])
    try
        HTTP.post("https://huggingface.co/api/repos/create",
            ["Authorization" => "Bearer $token", "Content-Type" => "application/json"],
            JSON3.write(payload))
        println("Created HF repo: $repo_id")
    catch e
        if e isa HTTP.StatusError && e.status == 409
            println("HF repo already exists or creation skipped: $repo_id")
        else
            @warn "Could not create HF repo $repo_id" exception=e
        end
    end
    return nothing
end

# Best-effort push of `local_path` to the global HF_REPO_ID. Skips silently when no repo
# is configured and prints (rather than throws) any push failure.
function hf_sync(local_path::String)
    has_repo = @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
    has_repo || return
    try
        hf_push(HF_REPO_ID, local_path)
    catch err
        println("  HF sync failed: $err")
    end
end

# ═══════════════════════════════════════════════════════════════
# W&B logging stand-ins: these accept the usual calls and do nothing,
# so training code can call them unconditionally (no Python needed).
# ═══════════════════════════════════════════════════════════════

function wandb_init()
    return nothing
end

function wandb_log(; kwargs...)
    return nothing
end

function wandb_finish()
    return nothing
end

println("HuggingFace + logging helpers defined (pure Julia, LFS-capable)")

In [None]:
import Pkg
# Install only what is missing. Re-running the cell then skips the registry/resolver
# round-trip for packages (and stdlibs such as SHA and Downloads) that already load.
let required = ["JSON3", "AutoGrad", "HTTP", "HuggingFaceApi", "SHA", "Downloads", "CUDA"]
    missing_pkgs = filter(name -> Base.find_package(name) === nothing, required)
    if isempty(missing_pkgs)
        println("All packages already installed")
    else
        Pkg.add(missing_pkgs)
    end
end

---
## 1. AutoGrad.jl Setup

Array-based automatic differentiation using AutoGrad.jl.  
Parameters are wrapped with `Param()` for gradient tracking; `@diff` builds the computation tape and `grad()` extracts gradients.

In [None]:
using Random
using Printf
using JSON3
using AutoGrad
using LinearAlgebra
using Downloads

# Seed the global RNG so parameter initialisation is reproducible between runs.
let seed = 42
    Random.seed!(seed)
end

println(stdout, "AutoGrad.jl loaded - array-based automatic differentiation")


In [None]:
RESOURCE_LIMIT = 0.50  # Use at most 50% of available resources

# ── CPU Detection ──
total_threads = Sys.CPU_THREADS
# Assume 2-way hyperthreading; clamp to ≥ 1 so single-thread machines don't report 0 cores.
total_cores = max(1, div(total_threads, 2))
max_blas_threads = max(1, floor(Int, total_cores * RESOURCE_LIMIT))

using LinearAlgebra
BLAS.set_num_threads(max_blas_threads)

println("=== Resource Detection ===")
println("CPU: $(total_cores) cores / $(total_threads) threads → BLAS: $max_blas_threads threads")

# ── Memory Detection ──
total_ram_gb = round(Sys.total_memory() / 1024^3, digits=1)
free_ram_gb = round(Sys.free_memory() / 1024^3, digits=1)
ram_limit_gb = round(free_ram_gb * RESOURCE_LIMIT, digits=1)
println("RAM: $(total_ram_gb) GB total, $(free_ram_gb) GB free → limit: $(ram_limit_gb) GB")

# ── GPU Detection (informational — AutoGrad.jl is CPU-only) ──
HAS_GPU = false
GPU_NAME = ""
GPU_VRAM_GB = 0.0
try
    using CUDA
    if CUDA.functional()
        # `try` opens a soft scope. Without `global`, these assignments create locals
        # (leaving the flags unset) whenever the cell runs as a non-interactive script.
        global HAS_GPU = true
        global GPU_NAME = CUDA.name(CUDA.device())
        global GPU_VRAM_GB = round(CUDA.total_memory() / 1024^3, digits=1)
        println("GPU: $GPU_NAME ($(GPU_VRAM_GB) GB) — detected but AutoGrad.jl runs on CPU")
    end
catch
    # CUDA missing or unusable: GPU info is optional, so fall through to "not available".
end
!HAS_GPU && println("GPU: not available")

# ═══════════════════════════════════════════════════════════════
# Dynamic Compute Tier — scale model to available resources
# AutoGrad.jl wraps plain Julia arrays (no CuArray support),
# so we scale by RAM, not GPU VRAM.
# ═══════════════════════════════════════════════════════════════

if free_ram_gb >= 16.0
    COMPUTE_TIER = :large
    n_layer    = 4
    n_embd     = 256
    n_head     = 8
    block_size = 256
elseif free_ram_gb >= 8.0
    COMPUTE_TIER = :medium
    n_layer    = 2
    n_embd     = 128
    n_head     = 4
    block_size = 256
else
    COMPUTE_TIER = :small
    n_layer    = 1
    n_embd     = 64
    n_head     = 4
    block_size = 256
end

head_dim = n_embd ÷ n_head

# ── Training defaults (overridden after tokenization by data size) ──
target_epochs  = 3
lr             = 3e-3
min_lr         = 1e-5
b1             = 0.9
b2             = 0.999
eps            = 1e-8
max_grad_norm  = 1.0
warmup_frac    = 0.05

# ── GC Configuration ──
GC_INTERVAL = 100

println("\n=== Compute Tier: $COMPUTE_TIER ===")
println("  n_layer=$n_layer, n_embd=$n_embd, n_head=$n_head, head_dim=$head_dim")
println("  block_size=$block_size, lr=$lr, grad_clip=$max_grad_norm")
println("===================================")

---
## 1c. Resource Sensing & Configuration

Auto-detects CPU, RAM, and GPU resources and configures training to use at most 50% of available capacity.
Inspired by Unsloth's resource-aware training patterns: sense hardware, set optimal BLAS threads, and configure memory limits automatically.

In [None]:
# ── Array-based neural network primitives (AutoGrad-compatible) ──
# All operations use standard Julia array ops that AutoGrad can differentiate.
# No in-place mutation inside @diff context.

# Element-wise ReLU: every negative entry becomes a Float32 zero.
relu_ag(x) = max.(x, 0.0f0)

# RMSNorm for a vector: scale x by 1/sqrt(mean(x²) + 1e-5). No learned gain.
function rmsnorm_ag(x)
    mean_square = sum(x .* x) / length(x)
    inv_rms = (mean_square + Float32(1e-5))^Float32(-0.5)
    return x .* inv_rms
end

# Softmax for a vector. Shifting by the max first keeps `exp` from overflowing.
function softmax_ag(logits)
    shifted = logits .- maximum(logits)
    weights = exp.(shifted)
    return weights ./ sum(weights)
end

# Densify an AutoGrad gradient into Float32. `AutoGrad.Sparse` gradients are expanded
# with `AutoGrad.full`; anything else is converted element-wise.
to_dense_grad(g) = Float32.(g isa AutoGrad.Sparse ? AutoGrad.full(g) : g)

println("Array-based primitives defined (relu_ag, rmsnorm_ag, softmax_ag, to_dense_grad)")

In [None]:
# Backpropagation is delegated to AutoGrad.jl, so there is no hand-written `backward!`.
# Typical usage:
#   tape = @diff loss_expression    # record the computation on a tape
#   g = grad(tape, some_param)      # gradient with respect to a Param
#   loss_val = value(tape)          # scalar value of the loss
println(stdout, "Using AutoGrad.jl for automatic differentiation (no manual backward! needed)")

---
## 2. Dataset — Download & Load Training Data

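Loads a pre-cleaned corpus of **54 classical philosophy texts** (originally sourced from Project Gutenberg and the MIT Classics Archive). It is downloaded as `train.txt` / `val.txt`, plus an optional `tokenizer.json`, from the HuggingFace dataset repo `HF_DATA_REPO`. The corpus spans: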
- **Plato** (12 works): Republic, Apology, Symposium, Phaedo, Crito, Meno, Phaedrus, Timaeus, Laws, Gorgias, Protagoras, Theaetetus
- **Aristotle** (12 works): Categories, Ethics, Rhetoric, Physics, Metaphysics, Poetics, Politics, On the Soul, On the Heavens, Prior/Posterior Analytics, Topics, On Generation & Corruption
- **Stoics** (4 works): Marcus Aurelius Meditations, Epictetus Discourses & Enchiridion, Seneca Moral Essays
- **Roman** (4 works): Lucretius, Cicero (On Duties, Nature of Gods, On Friendship)
- **Early Modern** (6 works): Descartes, Kant, Spinoza, Hobbes, Locke, Bacon
- **Enlightenment/19th c.** (10 works): Hume, Rousseau, Nietzsche (x2), Mill (x2), Machiavelli, Emerson, Thoreau, Montaigne
- **Other** (6 works): Boethius, Diogenes/Epicurus, Latin Grammar, Schopenhauer

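Each newline-separated line of the corpus is split into chunks of at most `2 × block_size` characters, cutting at a sentence boundary when possible. These chunks form the `TRAINING_DATA` document array.
The vocabulary is either the BPE vocabulary from `tokenizer.json` (when available) or every character that occurs in the data, plus a BOS token.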

In [None]:
# ── Load data from HuggingFace dataset repo ──
# Same pre-cleaned philosophy corpus used by juliaflux_v2
# (produced by the text-pipeline: clean → chunk → train.txt/val.txt)

DATA_DIR = "data"
mkpath(DATA_DIR)

# HF_DATA_REPO was referenced below without ever being defined (UndefVarError).
# Use a value set by an earlier cell if present, otherwise read it from ENV.
if !@isdefined(HF_DATA_REPO)
    HF_DATA_REPO = get(ENV, "HF_DATA_REPO", "")
end

for f in ["train.txt", "val.txt"]
    if !isfile(joinpath(DATA_DIR, f))
        isempty(HF_DATA_REPO) && error(
            "$f not found in $DATA_DIR/ and HF_DATA_REPO is not set " *
            "(set HF_DATA_REPO in .env or Colab secrets)")
        println("Downloading $f from $HF_DATA_REPO ...")
        hf_download(HF_DATA_REPO, f; local_dir=DATA_DIR, repo_type="dataset")
    end
end

# Try to get BPE tokenizer. Optional: without it a char-level tokenizer is used.
tokenizer_file = joinpath(DATA_DIR, "tokenizer.json")
if !isfile(tokenizer_file)
    if !isempty(HF_DATA_REPO)
        try
            hf_download(HF_DATA_REPO, "tokenizer.json"; local_dir=DATA_DIR, repo_type="dataset")
        catch
            # Best effort: the dataset repo may simply not ship a tokenizer.
        end
    end
    isfile(tokenizer_file) || println("  No tokenizer.json available (will use char-level tokenizer)")
end

train_text = read(joinpath(DATA_DIR, "train.txt"), String)
val_text = read(joinpath(DATA_DIR, "val.txt"), String)

println("Data loaded from $(isempty(HF_DATA_REPO) ? DATA_DIR : HF_DATA_REPO):")
println("  train.txt: $(length(train_text)) chars")
println("  val.txt:   $(length(val_text)) chars")

# Data sanity check
unique_chars = sort(unique(filter(c -> c != '\n', train_text)))
n_punct = count(c -> c in ".,;:?!'\"-()[]", train_text)
punct_pct = isempty(train_text) ? 0.0 : round(100 * n_punct / length(train_text), digits=1)
println("  unique chars: $(length(unique_chars))  |  punctuation: $(n_punct) ($(punct_pct)%)")
if length(unique_chars) < 10
    @warn "Very few unique characters ($(length(unique_chars))) — data may be over-cleaned"
end

"""
    chunk_document(line, max_chars; min_chars=20) -> Vector{String}

Split one corpus line into stripped pieces of at most `max_chars` characters.

Each over-long window is cut just after its last '.', provided that period lies
beyond byte offset 50. Otherwise the window is cut at exactly `max_chars`
characters. Pieces shorter than `min_chars` characters are dropped (including the
whole line). Window ends are located by character count, so multi-byte UTF-8 text
is handled safely; the previous byte-index slicing could throw a `StringIndexError`.
"""
function chunk_document(line::AbstractString, max_chars::Integer; min_chars::Integer=20)
    max_chars >= 1 || throw(ArgumentError("max_chars must be ≥ 1, got $max_chars"))
    pieces = String[]
    rest = strip(line)
    length(rest) >= min_chars || return pieces
    while length(rest) > max_chars
        window_end = nextind(rest, 0, max_chars)   # byte index of the max_chars-th char
        dot_pos = findlast('.', SubString(rest, 1, window_end))
        cut = (dot_pos !== nothing && dot_pos > 50) ? dot_pos : window_end
        push!(pieces, strip(SubString(rest, 1, cut)))
        rest = strip(SubString(rest, nextind(rest, cut)))
    end
    length(rest) >= min_chars && push!(pieces, rest)
    return pieces
end

# ── Split into paragraphs/chunks for token-by-token processing ──
# The HF data has newline-separated chunks; each becomes a training document,
# split further so no document exceeds 2 × block_size characters.
TRAINING_DATA = String[]
for text in [train_text, val_text]
    for chunk in split(text, '\n')
        append!(TRAINING_DATA, chunk_document(chunk, block_size * 2))
    end
end

isempty(TRAINING_DATA) &&
    error("No training chunks of ≥ 20 characters found in $DATA_DIR/train.txt or val.txt")
println("\nTRAINING_DATA: $(length(TRAINING_DATA)) chunks")
println("Total characters: $(sum(length, TRAINING_DATA))")
println("Sample: \"$(first(TRAINING_DATA[1], 80))...\"")

---
## 3. Neural Network Helpers

Weight matrices are `Param(Matrix{Float32})` -- AutoGrad.jl tracks gradients through array operations automatically.  
Helper functions for parameter initialization, key ordering, and checkpoint support.

In [None]:
# Helper: deterministic parameter key ordering (for checkpoint serialization).
# The three global matrices come first, then each layer's six matrices in a fixed order.
function get_param_keys(n_layer::Int)
    per_layer = ("attn_wq", "attn_wk", "attn_wv", "attn_wo", "mlp_fc1", "mlp_fc2")
    layer_keys = String["layer$(i).$(name)" for i in 0:n_layer-1 for name in per_layer]
    return vcat(["wte", "wpe", "lm_head"], layer_keys)
end

# Helper: build an (nout × nin) Float32 weight matrix of N(0, std²) draws, wrapped in Param.
function init_param(nout::Int, nin::Int; std=0.08f0)
    W = randn(Float32, nout, nin)
    return Param(W .* std)
end

# Helper: look up every key of `param_keys` in `state_dict`, in that order.
# The result is a `Vector{Any}`; a missing key throws a KeyError.
collect_params(state_dict, param_keys) = Any[state_dict[key] for key in param_keys]

println("Neural network helpers defined (get_param_keys, init_param, collect_params)")

---
## 4. GPT Forward Pass

Processes one token at a time, using a KV cache for causal masking.  
All operations use array matrix multiplications via AutoGrad.jl.  
The KV cache stores detached (plain Float32) copies of keys and values, so no gradient flows back through the cached entries of earlier positions.

In [None]:
# Causal attention for one head at the current position.
#   q_h, k_h, v_h : this head's query/key/value for the current token (may be tape-tracked)
#   K_past, V_past: (head_dim, n_past) untracked cache columns of earlier positions
# Returns the softmax([K_past' q_h; k_h ⋅ q_h] ./ √head_dim)-weighted sum of [V_past v_h].
# It is evaluated piecewise so that the current k_h / v_h stay on the AutoGrad tape.
function _attend_head(q_h, k_h, v_h, K_past, V_past, head_dim::Int)
    # Only the current position is visible: its softmax weight is exactly 1.
    size(K_past, 2) == 0 && return v_h
    scale = Float32(sqrt(head_dim))
    s_past = (K_past' * q_h) ./ scale          # (n_past,)
    s_cur = sum(k_h .* q_h) / scale             # scalar
    # Softmax is shift-invariant, so the stabilising max may use untracked values.
    m = max(maximum(value(s_past)), value(s_cur))
    e_past = exp.(s_past .- m)
    e_cur = exp(s_cur - m)
    return (V_past * e_past .+ v_h .* e_cur) ./ (sum(e_past) + e_cur)
end

"""
    gpt(token_id, pos_id, kv_key_mats, kv_val_mats, kv_lens, params, n_layer, n_head, head_dim)

Run the transformer on one token at position `pos_id` and return its next-token
logits (`params.lm_head * x`).

Side effects: for every layer `li`, an untracked Float32 copy of this token's
key/value is written to column `kv_lens[li] + 1` of `kv_key_mats[li]` /
`kv_val_mats[li]`, and `kv_lens[li]` is incremented. Attending only over filled
columns gives causal masking.

Gradients: cached (earlier) positions are constants to AutoGrad. The current
token's key and value are used in tracked form, so `attn_wk` / `attn_wv` receive
gradients. Previously they were read back from the detached cache and got none.

Call `reset_kv_cache!` between sequences. Writing past the cache's last column
throws a `BoundsError`.
"""
function gpt(token_id::Int, pos_id::Int,
             kv_key_mats::Vector{Matrix{Float32}},
             kv_val_mats::Vector{Matrix{Float32}},
             kv_lens::Vector{Int},
             params,
             n_layer::Int, n_head::Int, head_dim::Int)

    # Embedding lookup: row indexing into Param matrices
    tok_emb = params.wte[token_id, :]
    pos_emb = params.wpe[pos_id, :]
    x = tok_emb .+ pos_emb
    x = rmsnorm_ag(x)

    for li in 1:n_layer
        layer = params.layers[li]
        x_res = x
        x = rmsnorm_ag(x)

        # Linear projections: W * x where W is (d_model, d_model)
        q = layer.attn_wq * x
        k = layer.attn_wk * x
        v = layer.attn_wv * x

        # Store untracked copies for later positions (plain Float32, not on the tape).
        idx = kv_lens[li] + 1
        kv_key_mats[li][:, idx] = Float32.(value(k))
        kv_val_mats[li][:, idx] = Float32.(value(v))
        kv_lens[li] = idx

        # Earlier positions only; the current position is handled with tracked k/v.
        K_past = @view kv_key_mats[li][:, 1:idx-1]
        V_past = @view kv_val_mats[li][:, 1:idx-1]

        # One output per head (n_head is a runtime value, so this is not unrolled).
        head_results = ntuple(n_head) do hh
            hs = (hh - 1) * head_dim + 1
            he = hs + head_dim - 1
            _attend_head(q[hs:he], k[hs:he], v[hs:he],
                         K_past[hs:he, :], V_past[hs:he, :], head_dim)
        end

        x_attn = vcat(head_results...)    # (d_model,) tracked

        # Output projection + residual
        x = layer.attn_wo * x_attn
        x = x .+ x_res

        # MLP block
        x_res = x
        x = rmsnorm_ag(x)
        x = layer.mlp_fc1 * x
        x = relu_ag(x)
        x = layer.mlp_fc2 * x
        x = x .+ x_res
    end

    # LM head: (vocab_size, d_model) * (d_model,) = (vocab_size,)
    logits = params.lm_head * x
    return logits
end

# Helper: regroup the flat `state_dict` into a NamedTuple
# (wte, wpe, lm_head, layers). `layers` holds one NamedTuple of the six
# per-layer matrices for each of the `n_layer` layers, in order.
function build_params(state_dict, n_layer::Int)
    fields = (:attn_wq, :attn_wk, :attn_wv, :attn_wo, :mlp_fc1, :mlp_fc2)
    layer_nt(i) = NamedTuple{fields}(map(f -> state_dict["layer$(i).$(f)"], fields))
    layers = ntuple(li -> layer_nt(li - 1), n_layer)
    return (wte = state_dict["wte"], wpe = state_dict["wpe"],
            lm_head = state_dict["lm_head"], layers = layers)
end

# Helper: allocate KV cache buffers. Returns
#   - one zeroed (d_model × block_size) Float32 matrix per layer for keys,
#   - the same for values,
#   - a per-layer fill count starting at 0.
function alloc_kv_cache(n_layer::Int, d_model::Int, block_size::Int)
    new_buffers() = [zeros(Float32, d_model, block_size) for _ in 1:n_layer]
    return new_buffers(), new_buffers(), zeros(Int, n_layer)
end

# Helper: start a new sequence by zeroing every layer's fill count. The buffers keep
# their stale contents; their columns are overwritten as new tokens are written.
function reset_kv_cache!(kv_lens::Vector{Int})
    return fill!(kv_lens, 0)
end

println("GPT forward pass defined (optimized: pre-allocated KV cache, typed params, unrolled heads)")

---
## 5. Checkpoint Save/Load

Save and load model weights + optimizer state as JSON.
Checkpoints are synced to HuggingFace Hub for persistence across Colab sessions.

In [None]:
# Serialise a matrix as a vector of Float64 rows (the checkpoint's JSON layout).
_matrix_to_rows(W) = [Float64.(W[i, :]) for i in axes(W, 1)]

"""
    save_checkpoint(path, state_dict, param_keys, uchars, hyperparams; kwargs...)

Write a JSON checkpoint (format `"autograd_v2"`) to `path`. It contains:
- the weights of every key in `param_keys` (each `Param` unwrapped with `value`),
- any Adam moments present in `adam_m` / `adam_v` for those keys,
- the optimizer settings,
- the training history.

A non-finite `best_val_loss` is stored as `1e30`, since JSON has no Inf/NaN.
The file is written to `path * ".tmp"` and then renamed over `path`, so an
interrupted save never leaves a truncated checkpoint behind.
"""
function save_checkpoint(path::String, state_dict, param_keys, uchars, hyperparams;
                         adam_m=nothing, adam_v=nothing, step::Int=0,
                         lr::Float64=0.01, b1::Float64=0.85, b2::Float64=0.99,
                         best_val_loss::Float64=Inf,
                         train_losses::Vector{Float64}=Float64[],
                         val_losses::Vector{Float64}=Float64[],
                         total_steps::Int=0, num_steps_target::Int=0)

    sd_data = Dict{String,Any}()
    for k in param_keys
        sd_data[k] = _matrix_to_rows(value(state_dict[k]))  # unwrap Param -> Matrix{Float32}
    end

    # Serialize Adam state per param key. Each moment is handled independently, so a
    # key missing from one dict (or a `nothing` dict) no longer throws.
    adam_m_data = Dict{String,Any}()
    adam_v_data = Dict{String,Any}()
    for k in param_keys
        if adam_m !== nothing && haskey(adam_m, k)
            adam_m_data[k] = _matrix_to_rows(adam_m[k])
        end
        if adam_v !== nothing && haskey(adam_v, k)
            adam_v_data[k] = _matrix_to_rows(adam_v[k])
        end
    end

    checkpoint = Dict{String,Any}(
        "format" => "autograd_v2",
        "uchars" => [string(c) for c in uchars],
        "hyperparams" => hyperparams,
        "state_dict" => sd_data,
        "optimizer" => Dict{String,Any}(
            "adam_m" => adam_m_data,
            "adam_v" => adam_v_data,
            "step" => step,
            "lr" => lr,
            "beta1" => b1,
            "beta2" => b2
        ),
        "training" => Dict{String,Any}(
            # JSON cannot represent Inf/NaN; 1e30 is the "no best yet" sentinel.
            "best_val_loss" => isfinite(best_val_loss) ? best_val_loss : 1e30,
            "train_losses" => train_losses,
            "val_losses" => val_losses,
            "total_steps_completed" => total_steps,
            "num_steps_target" => num_steps_target
        )
    )

    mkpath(dirname(path))
    tmp_path = path * ".tmp"
    try
        open(tmp_path, "w") do f
            JSON3.write(f, checkpoint)
        end
        mv(tmp_path, path; force=true)
    finally
        rm(tmp_path; force=true)
    end
    vl_str = best_val_loss == Inf ? "Inf" : @sprintf("%.4f", best_val_loss)
    println("Checkpoint saved: $path (step $step, best_val_loss=$vl_str)")
end

# Rebuild a Matrix{Float32} from the array-of-rows JSON layout written by save_checkpoint.
function _rows_to_matrix(matrix_rows)
    nrows = length(matrix_rows)
    ncols = nrows == 0 ? 0 : length(first(matrix_rows))
    W = Matrix{Float32}(undef, nrows, ncols)
    for (i, row) in enumerate(matrix_rows)
        length(row) == ncols || throw(DimensionMismatch(
            "checkpoint matrix row $i has $(length(row)) entries, expected $ncols"))
        for (j, x) in enumerate(row)
            W[i, j] = x
        end
    end
    return W
end

"""
    load_checkpoint(path) -> NamedTuple

Read a JSON checkpoint written by `save_checkpoint`. Returns:
- weights as `state_dict`, each matrix wrapped in `Param`;
- tokenizer characters (`uchars`), `BOS` and `vocab_size`;
- model hyperparameters;
- Adam moments (restored only for the `"autograd_v2"` format);
- training history.

`vocab_size` and `BOS` are the number of rows of `wte`. This is correct for both
char-level and BPE checkpoints; BPE checkpoints store an empty `uchars`, so
deriving the size from `uchars` gave 1. The `1e30` sentinel for `best_val_loss`
is mapped back to `Inf`.
"""
function load_checkpoint(path::String)
    println("Loading checkpoint from $path ...")
    raw = JSON3.read(read(path, String))

    uchars = Char[only(String(s)) for s in raw["uchars"]]

    hp = raw["hyperparams"]
    n_layer = Int(hp["n_layer"])
    n_embd = Int(hp["n_embd"])
    block_size = Int(hp["block_size"])
    n_head = Int(hp["n_head"])
    head_dim = n_embd ÷ n_head

    # Detect format version
    fmt = haskey(raw, "format") ? String(raw["format"]) : "v1"

    state_dict = Dict{String, Any}()
    for (key, matrix_rows) in pairs(raw["state_dict"])
        state_dict[string(key)] = Param(_rows_to_matrix(matrix_rows))
    end

    vocab_size = haskey(state_dict, "wte") ? size(value(state_dict["wte"]), 1) :
                                             length(uchars) + 1
    BOS = vocab_size

    # Load Adam state
    opt_raw = raw["optimizer"]
    adam_m = Dict{String, Matrix{Float32}}()
    adam_v = Dict{String, Matrix{Float32}}()
    if fmt == "autograd_v2"
        if haskey(opt_raw, "adam_m")
            for (key, matrix_rows) in pairs(opt_raw["adam_m"])
                adam_m[string(key)] = _rows_to_matrix(matrix_rows)
            end
        end
        if haskey(opt_raw, "adam_v")
            for (key, matrix_rows) in pairs(opt_raw["adam_v"])
                adam_v[string(key)] = _rows_to_matrix(matrix_rows)
            end
        end
    end

    step = Int(opt_raw["step"])
    lr = Float64(opt_raw["lr"])
    b1 = Float64(opt_raw["beta1"])
    b2 = Float64(opt_raw["beta2"])

    trn = raw["training"]
    stored_best = Float64(trn["best_val_loss"])
    best_val_loss = stored_best >= 1e30 ? Inf : stored_best
    train_losses = Float64.(collect(trn["train_losses"]))
    val_losses = Float64.(collect(trn["val_losses"]))
    total_steps = Int(trn["total_steps_completed"])
    num_steps_target = Int(trn["num_steps_target"])

    println("  vocab=$vocab_size, embd=$n_embd, layers=$n_layer, step=$step")

    return (;
        state_dict, uchars, BOS, vocab_size,
        n_layer, n_embd, block_size, n_head, head_dim,
        adam_m, adam_v, step, lr, b1, b2,
        best_val_loss, train_losses, val_losses,
        total_steps, num_steps_target
    )
end

println("Checkpoint save/load defined (autograd_v2 format)")

---
## 6. Setup — Dataset, Tokenizer, Parameters

BPE tokenizer (loaded from `tokenizer.json` when available) with a character-level fallback, plus a BOS token. Ordered (unshuffled) 90/10 train/val split.

In [None]:
# ── Dataset with train/val split (ordered, not shuffled) ──
# NOTE(review): TRAINING_DATA holds the train.txt chunks followed by the val.txt chunks,
# so this 90/10 cut matches the file boundary only if val.txt is ~10% of the chunks.
docs = copy(TRAINING_DATA)
split_idx = max(1, floor(Int, 0.9 * length(docs)))
train_docs, val_docs = docs[1:split_idx], docs[split_idx+1:end]
if isempty(val_docs)
    # Too few docs for a 90/10 split: validate on the last five instead.
    val_docs = docs[max(1, end-4):end]
    train_docs = docs[1:max(1, end-5)]
end
println("train: $(length(train_docs)) docs, val: $(length(val_docs)) docs")

# ── Tokenizer: BPE with character-level fallback ──
TOKENIZER_PATH = joinpath(DATA_DIR, "tokenizer.json")
USE_BPE = isfile(TOKENIZER_PATH)

if USE_BPE
    println("Loading BPE tokenizer from $TOKENIZER_PATH ...")
    tok_raw = read(TOKENIZER_PATH, String)
    tok_json = JSON3.read(tok_raw)

    # HF token ids are 0-based; shift them to 1-based for Julia indexing.
    bpe_vocab = Dict{String, Int}()
    for (tok_str, id) in pairs(tok_json.model.vocab)
        bpe_vocab[string(tok_str)] = Int(id) + 1
    end

    # Merges are stored either as "a b" strings or, in newer tokenizer.json files,
    # as ["a", "b"] pairs. Accept both; the previous `split(string(entry))` only
    # handled the string form.
    bpe_merges = Vector{Tuple{String,String}}()
    for merge_entry in tok_json.model.merges
        parts = merge_entry isa AbstractString ? split(merge_entry, " ", limit=2) :
                                                 collect(merge_entry)
        if length(parts) == 2
            push!(bpe_merges, (String(parts[1]), String(parts[2])))
        end
    end

    bpe_id_to_token = Dict{Int, String}(id => tok for (tok, id) in bpe_vocab)
    # BOS goes just past the largest id (equal to length(bpe_vocab) for the usual
    # contiguous ids), so it can never collide with a real token id.
    BOS = max(length(bpe_vocab), isempty(bpe_vocab) ? 0 : maximum(values(bpe_vocab))) + 1
    vocab_size = BOS

    # ── GPT-2 byte-to-unicode mapping ──
    # HuggingFace ByteLevel BPE maps each byte to a printable unicode char.
    function build_byte_to_unicode()
        bs = UInt8[]
        cs = Char[]
        for r in [0x21:0x7e, 0xa1:0xac, 0xae:0xff]
            for b in r
                push!(bs, b)
                push!(cs, Char(b))
            end
        end
        n = 0
        for b in 0x00:0xff
            if b ∉ bs
                push!(bs, b)
                push!(cs, Char(256 + n))
                n += 1
            end
        end
        return Dict(bs[i] => string(cs[i]) for i in eachindex(bs))
    end

    const _b2u = build_byte_to_unicode()
    const _u2b = Dict(v[1] => k for (k, v) in _b2u)

    # Merge priority = position in the merge list (first occurrence wins); lower merges first.
    const _bpe_ranks = let ranks = Dict{Tuple{String,String},Int}()
        for (r, pair) in enumerate(bpe_merges)
            get!(ranks, pair, r)
        end
        ranks
    end
    # Memo of already-encoded words (natural text repeats words constantly).
    const _bpe_cache = Dict{Vector{String},Vector{String}}()

    # ── BPE merge (per word, not whole corpus!) ──
    # Repeatedly merge the adjacent pair with the lowest rank (every non-overlapping
    # occurrence, left to right) until no ranked pair remains.
    # For a well-formed merge list this gives the same tokens as applying every merge
    # in list order, but costs O(len(word)²) per word instead of O(#merges).
    function bpe_encode_word(word::Vector{String})
        length(word) <= 1 && return copy(word)
        cached = get(_bpe_cache, word, nothing)
        cached === nothing || return copy(cached)
        tokens = copy(word)
        while length(tokens) > 1
            best_rank = typemax(Int)
            best_i = 0
            for j in 1:length(tokens)-1
                r = get(_bpe_ranks, (tokens[j], tokens[j+1]), typemax(Int))
                if r < best_rank
                    best_rank, best_i = r, j
                end
            end
            best_i == 0 && break                      # no mergeable pair left
            a, b = tokens[best_i], tokens[best_i+1]
            merged = String[]
            i = 1
            while i <= length(tokens)
                if i < length(tokens) && tokens[i] == a && tokens[i+1] == b
                    push!(merged, a * b)
                    i += 2
                else
                    push!(merged, tokens[i])
                    i += 1
                end
            end
            tokens = merged
        end
        # Store private copies so later mutation by a caller cannot corrupt the memo.
        _bpe_cache[copy(word)] = tokens
        return copy(tokens)
    end

    # ── GPT-2 pre-tokenization regex ──
    const _GPT2_PAT = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"

    function tokenize_doc(doc::String, _unused1, BOS::Int)
        ids = [BOS]
        for m in eachmatch(_GPT2_PAT, doc)
            word_bytes = Vector{UInt8}(m.match)
            chars = [_b2u[b] for b in word_bytes]
            tokens = bpe_encode_word(chars)
            for tok in tokens
                id = get(bpe_vocab, tok, nothing)
                id !== nothing && push!(ids, id)
            end
        end
        push!(ids, BOS)
        return ids
    end

    # For decode in inference
    uchars = Char[]  # not used with BPE, but checkpoint format expects it
    function decode_ids(ids::Vector{Int})
        text = join(get(bpe_id_to_token, id, "") for id in ids if id != BOS)
        bytes = UInt8[get(_u2b, c, UInt8(c)) for c in text if haskey(_u2b, c)]
        return String(bytes)
    end

    println("BPE tokenizer: vocab_size=$vocab_size ($(length(bpe_vocab)) tokens + BOS)")
    println("  $(length(bpe_merges)) merges")

else
    # ── Fallback: character-level tokenizer ──
    println("No tokenizer.json — using character-level tokenizer")

    uchars = sort(unique(join(docs)))
    BOS = length(uchars) + 1
    vocab_size = BOS
    char_to_id = Dict{Char, Int}(ch => i for (i, ch) in enumerate(uchars))

    function tokenize_doc(doc::String, char_to_id::Dict{Char,Int}, BOS::Int)
        vcat([BOS], [char_to_id[ch] for ch in doc if haskey(char_to_id, ch)], [BOS])
    end

    function decode_ids(ids::Vector{Int})
        join(id <= length(uchars) ? uchars[id] : ' ' for id in ids if id != BOS)
    end

    println("Char-level tokenizer: vocab_size=$vocab_size ($(length(uchars)) chars + BOS)")
end

# ── Pre-tokenize all documents ──
println("Pre-tokenizing documents...")
t_enc = time()
# The BPE `tokenize_doc` ignores its second argument; the char-level one needs the char→id map.
let table = USE_BPE ? nothing : char_to_id
    global train_tokens = [tokenize_doc(doc, table, BOS) for doc in train_docs]
    global val_tokens = [tokenize_doc(doc, table, BOS) for doc in val_docs]
end
println("Tokenization done in $(round(time() - t_enc, digits=1))s")
println("Pre-tokenized: $(length(train_tokens)) train, $(length(val_tokens)) val docs")

total_train_tokens = sum(length, train_tokens)
avg_doc_len = round(total_train_tokens / length(train_tokens), digits=1)
println("Total train tokens: $total_train_tokens (avg doc: $avg_doc_len tokens)")

# ── Hyperparameters (n_layer, n_embd, n_head, head_dim come from the compute-tier cell) ──
hyperparams = Dict{String,Any}(
    "n_layer"      => n_layer,
    "n_embd"       => n_embd,
    "block_size"   => block_size,
    "n_head"       => n_head,
    "use_bpe"      => USE_BPE,
    "compute_tier" => string(COMPUTE_TIER),
)

# ── Initialize parameters as Param(Matrix{Float32}) ──
# Walk the canonical key order so the RNG draws happen in the same sequence:
# wte, wpe, lm_head, then per layer wq, wk, wv, wo, fc1, fc2.
param_keys = get_param_keys(n_layer)
state_dict = Dict{String, Any}()
for key in param_keys
    shape = if key == "wte" || key == "lm_head"
        (vocab_size, n_embd)
    elseif key == "wpe"
        (block_size, n_embd)
    elseif endswith(key, ".mlp_fc1")
        (4 * n_embd, n_embd)
    elseif endswith(key, ".mlp_fc2")
        (n_embd, 4 * n_embd)
    else
        (n_embd, n_embd)
    end
    state_dict[key] = init_param(shape...)
end

all_params = collect_params(state_dict, param_keys)
total_num_params = sum(p -> length(value(p)), all_params)
println("\nModel: $total_num_params params ($(round(total_num_params/1e3, digits=1))K) — tier: $COMPUTE_TIER")

params = build_params(state_dict, n_layer)

# ── Dynamic training schedule ──
# steps_per_epoch is the number of training documents; the total is `target_epochs`
# of those, bounded to [500, 100_000].
steps_per_epoch = length(train_tokens)
num_steps = min(max(target_epochs * steps_per_epoch, 500), 100_000)
warmup_iters = max(20, round(Int, num_steps * warmup_frac))
eval_interval = max(20, div(num_steps, 20))

println("\n── Training schedule (computed from data) ──")
println("  steps_per_epoch: $steps_per_epoch")
println("  num_steps:       $num_steps  ($target_epochs epochs)")
println("  eval_interval:   $eval_interval")
println("  warmup_iters:    $warmup_iters")

---
## 7. Training Loop with Validation + Best-Model Checkpointing

Adam optimizer with linear warmup, cosine LR decay, and gradient clipping.
Validates every `eval_interval` steps (computed in section 6) and saves `best_model.json` when val loss improves.
Checkpoints sync to HuggingFace Hub automatically.
Periodic checkpoints every 200 steps.

In [None]:
"""
    compute_val_loss(val_tokens, params, BOS, block_size, n_layer, n_head, head_dim, n_embd)

Return the mean per-token cross-entropy (natural log) of the model over `val_tokens`.

Each token sequence in `val_tokens` is scored on its first
`min(block_size, length(tokens) - 1)` next-token predictions, with the KV cache reset
before each sequence. The target probability is floored at `1e-10` before the log, so
a zero probability cannot yield `Inf`. `BOS` is accepted for signature parity with the
training loop but is not used.

NOTE(review): when no position is scored (e.g. `val_tokens` is empty), this returns
`0.0`, which a caller comparing against a best loss will treat as a perfect score.
"""
function compute_val_loss(val_tokens, params, BOS, block_size, n_layer, n_head, head_dim, n_embd)
    total_loss = 0.0
    total_tokens = 0
    # One cache for the whole pass; it is reset per sequence below.
    kv_key_mats, kv_val_mats, kv_lens = alloc_kv_cache(n_layer, n_embd, block_size)
    for tokens in val_tokens
        n = min(block_size, length(tokens) - 1)
        reset_kv_cache!(kv_lens)
        for pos in 1:n
            token_id = tokens[pos]
            target_id = tokens[pos + 1]
            # Outside @diff, Params behave like plain arrays
            logits = gpt(token_id, pos, kv_key_mats, kv_val_mats, kv_lens, params, n_layer, n_head, head_dim)
            probs = value(softmax_ag(logits))
            # Only the target entry is needed: convert that one value instead of
            # allocating a Float64 copy of the whole distribution at every position.
            p_target = Float64(probs[target_id])
            total_loss += -log(max(p_target, 1e-10))
            total_tokens += 1
        end
    end
    return total_loss / max(total_tokens, 1)
end

println("compute_val_loss defined (optimized: pre-allocated KV cache, typed params)")

In [None]:
# ── Adam optimizer state: first (m) and second (v) moment buffers, one per parameter ──
# Each buffer is zero-filled with the same size as its parameter's underlying matrix.
adam_m = Dict{String, Matrix{Float32}}(k => zeros(Float32, size(value(state_dict[k]))) for k in param_keys)
adam_v = Dict{String, Matrix{Float32}}(k => zeros(Float32, size(value(state_dict[k]))) for k in param_keys)

# ── Run bookkeeping: seeded here, and train_loop! returns the updated values ──
best_val_loss = Inf
train_loss_history = Float64[]
val_loss_history = Float64[]

# ── Local checkpoint directory (created if it does not exist) ──
LOCAL_CKPT = "checkpoints"
mkpath(LOCAL_CKPT)

"""
    save_and_sync(path, sd, pk, uc, hp; kwargs...)

Write a checkpoint to `path` with `save_checkpoint`, then sync that file with `hf_sync`.

The local write is authoritative: if `save_checkpoint` fails, the error propagates.
A failure in `hf_sync` is reported as a warning instead, so a transient network error
cannot abort a training run. The training loop calls this mid-run and from its error
handler. An `InterruptException` during the sync still propagates.

Returns whatever `hf_sync` returns, or `nothing` if the sync failed.
"""
function save_and_sync(path, sd, pk, uc, hp; kwargs...)
    save_checkpoint(path, sd, pk, uc, hp; kwargs...)
    try
        return hf_sync(path)
    catch e
        e isa InterruptException && rethrow()
        @warn "Hub sync failed; checkpoint kept locally" path exception=(e, catch_backtrace())
        return nothing
    end
end

# ── Cosine LR with warmup ──
"""
    get_lr(step, max_steps, base_lr, min_lr, warmup_steps)

Return the learning rate for `step`.

- While `step < warmup_steps`, the rate ramps up linearly as `base_lr * step / warmup_steps`.
- After that, it follows a cosine decay from `base_lr` (at `warmup_steps`) to `min_lr`
  (at `max_steps`).

Progress through the decay is clamped to `[0, 1]`, so steps past `max_steps` stay at
`min_lr` instead of the cosine wrapping back up toward `base_lr`.
"""
function get_lr(step, max_steps, base_lr, min_lr, warmup_steps)
    step < warmup_steps && return base_lr * step / warmup_steps
    # max(1, …) avoids division by zero when warmup_steps ≥ max_steps.
    progress = clamp((step - warmup_steps) / max(1, max_steps - warmup_steps), 0.0, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + cos(π * progress))
end

# ── Training loop with gradient clipping + cosine LR ──

# Save the full training state (weights, Adam moments, loss histories) to `path` and
# sync it. Shared by the best / latest / interrupted checkpoint writes in train_loop!.
function _save_training_state(path, state_dict, param_keys, uchars, hyperparams,
                              adam_m, adam_v, step, lr, b1, b2, best_val_loss,
                              train_loss_history, val_loss_history, target_step)
    return save_and_sync(path, state_dict, param_keys, uchars, hyperparams;
        adam_m=adam_m, adam_v=adam_v, step=step,
        lr=lr, b1=b1, b2=b2,
        best_val_loss=best_val_loss,
        train_losses=train_loss_history, val_losses=val_loss_history,
        total_steps=step, num_steps_target=target_step)
end

"""
    train_loop!(state_dict, params, param_keys, train_tokens, val_tokens,
                adam_m, adam_v, uchars, hyperparams; kwargs...)

Train for `num_steps` steps, numbered `start_step:start_step+num_steps-1`, using one
token sequence per step (cycling through `train_tokens`). Each step:

1. Runs the forward pass under `@diff`.
2. Clips gradients to global norm `max_grad_norm`.
3. Applies an Adam update, with the learning rate from `get_lr`, in place to the
   parameter values in `state_dict` and to `adam_m` / `adam_v`.

Validation runs on the first step and every `eval_interval` steps; an improvement is
saved to `checkpoints/best_model.json`. `checkpoints/checkpoint_latest.json` is written
every `max(100, 2 * eval_interval)` steps, and also whenever more than 600 s have passed
since the last such save.

If any exception escapes a step, the state is saved to
`checkpoints/checkpoint_interrupted.json`. An `InterruptException` then ends training
normally; any other error is rethrown.

`eval_interval` defaults to `max(20, num_steps ÷ 20)` (≈20 validation passes per run).
`gc_interval` defaults to the global `GC_INTERVAL`.

Returns `(best_val_loss, train_loss_history, val_loss_history, last_completed_step)`.
The two histories are the vectors passed in, appended to in place.
"""
function train_loop!(state_dict, params, param_keys, train_tokens, val_tokens,
                     adam_m, adam_v, uchars, hyperparams;
                     num_steps::Int, lr::Float64, b1::Float64, b2::Float64, eps::Float64,
                     n_layer::Int, n_head::Int, head_dim::Int, n_embd::Int,
                     block_size::Int, BOS::Int,
                     max_grad_norm::Float64=1.0,
                     min_lr::Float64=1e-5,
                     warmup_iters::Int=50,
                     best_val_loss::Float64=Inf,
                     train_loss_history::Vector{Float64}=Float64[],
                     val_loss_history::Vector{Float64}=Float64[],
                     start_step::Int=1,
                     eval_interval::Int=max(20, num_steps ÷ 20),
                     gc_interval=GC_INTERVAL)

    kv_key_mats, kv_val_mats, kv_lens = alloc_kv_cache(n_layer, n_embd, block_size)
    last_step = start_step + num_steps - 1   # inclusive; also the end of the LR schedule

    hf_status = if !isempty(get(ENV, "HF_REPO", ""))
        "HF:$(ENV["HF_REPO"])"
    elseif @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
        "HF:$HF_REPO_ID"
    else
        "HF:(not configured)"
    end
    println("--- training $num_steps steps (steps $start_step..$last_step) ---")
    println("    LR: cosine $lr → $min_lr, warmup: $warmup_iters, grad_clip: $max_grad_norm")
    println("    Local: checkpoints/  |  $hf_status")

    # Adam constants, converted once to Float32 to match the moment buffers.
    b1_f32, one_minus_b1 = Float32(b1), Float32(1 - b1)
    b2_f32, one_minus_b2 = Float32(b2), Float32(1 - b2)
    eps_f32 = Float32(eps)

    t_start = time()
    last_save_time = time()
    SAVE_INTERVAL = 600                 # seconds between time-based auto-saves
    completed_steps = start_step - 1    # last step whose parameter update has finished

    try
        for step in start_step:last_step
            tokens = train_tokens[mod1(step, length(train_tokens))]
            n = min(block_size, length(tokens) - 1)

            reset_kv_cache!(kv_lens)

            tape = @diff begin
                loss_sum = Float32(0)
                for pos in 1:n
                    token_id  = tokens[pos]
                    target_id = tokens[pos + 1]
                    logits = gpt(token_id, pos, kv_key_mats, kv_val_mats, kv_lens, params, n_layer, n_head, head_dim)
                    probs = softmax_ag(logits)
                    loss_sum = loss_sum + (-log(probs[target_id]))
                end
                loss_sum / Float32(n)
            end

            avg_loss = Float64(value(tape))
            push!(train_loss_history, avg_loss)

            # Collect dense gradients for every parameter that received one.
            grads = Dict{String, Matrix{Float32}}()
            for k in param_keys
                g = grad(tape, state_dict[k])
                if g !== nothing
                    grads[k] = to_dense_grad(g)
                end
            end

            # Global-norm gradient clipping (squared norm accumulated in Float64).
            grad_norm_sq = 0.0
            for g in values(grads)
                grad_norm_sq += sum(abs2, g)
            end
            grad_norm = sqrt(grad_norm_sq)
            clip_scale = grad_norm > max_grad_norm ? Float32(max_grad_norm / grad_norm) : Float32(1.0)

            # Adam update with cosine LR. Fused broadcasts update m, v and the weights
            # in place without temporaries. Bias corrections use the global step number.
            lr_t = get_lr(step, last_step, lr, min_lr, warmup_iters)
            lr_f32 = Float32(lr_t)
            bias1 = Float32(1 - b1^step)
            bias2 = Float32(1 - b2^step)
            for k in param_keys
                haskey(grads, k) || continue
                g, m, v = grads[k], adam_m[k], adam_v[k]
                w = value(state_dict[k])
                @. m = b1_f32 * m + one_minus_b1 * (g * clip_scale)
                @. v = b2_f32 * v + one_minus_b2 * (g * clip_scale)^2
                @. w -= lr_f32 * (m / bias1) / (sqrt(v / bias2) + eps_f32)
            end
            # Marked complete only after the update, so an interrupt mid-step does not
            # save a checkpoint claiming a step whose update never happened.
            completed_steps = step

            # Validate + checkpoint
            if step % eval_interval == 0 || step == start_step
                val_loss = compute_val_loss(val_tokens, params, BOS, block_size, n_layer, n_head, head_dim, n_embd)
                push!(val_loss_history, val_loss)
                elapsed = time() - t_start
                wandb_log(; step=step, train_loss=avg_loss, val_loss=val_loss, lr=lr_t)

                improved = ""
                if val_loss < best_val_loss
                    best_val_loss = val_loss
                    _save_training_state("checkpoints/best_model.json", state_dict, param_keys,
                        uchars, hyperparams, adam_m, adam_v, step, lr, b1, b2, best_val_loss,
                        train_loss_history, val_loss_history, last_step)
                    improved = " << best!"
                end

                @printf("step %5d / %5d | train %.4f | val %.4f | lr %.2e | gnorm %.1f | %.1fs%s\n",
                        step, last_step, avg_loss, val_loss, lr_t, grad_norm, elapsed, improved)
            elseif step % max(1, eval_interval ÷ 5) == 0
                elapsed = time() - t_start
                @printf("step %5d / %5d | train %.4f | lr %.2e | %.1fs\n",
                        step, last_step, avg_loss, lr_t, elapsed)
            end

            # Step-based periodic checkpoint.
            if step % max(100, eval_interval * 2) == 0
                _save_training_state("checkpoints/checkpoint_latest.json", state_dict, param_keys,
                    uchars, hyperparams, adam_m, adam_v, step, lr, b1, b2, best_val_loss,
                    train_loss_history, val_loss_history, last_step)
                last_save_time = time()
            end

            # Time-based safety net for long stretches between periodic saves.
            if time() - last_save_time > SAVE_INTERVAL
                _save_training_state("checkpoints/checkpoint_latest.json", state_dict, param_keys,
                    uchars, hyperparams, adam_m, adam_v, step, lr, b1, b2, best_val_loss,
                    train_loss_history, val_loss_history, last_step)
                last_save_time = time()
                println("  [auto-save at step $step]")
            end

            if step % gc_interval == 0
                GC.gc(false)
            end
        end
    catch e
        interrupted = e isa InterruptException
        if interrupted
            println("\n\nTraining interrupted at step $completed_steps!")
        else
            println("\n\nTraining error at step $completed_steps: $e")
        end
        # Best-effort emergency save: a failure here must not hide the original error.
        try
            _save_training_state("checkpoints/checkpoint_interrupted.json", state_dict, param_keys,
                uchars, hyperparams, adam_m, adam_v, completed_steps, lr, b1, b2, best_val_loss,
                train_loss_history, val_loss_history, last_step)
        catch save_err
            save_err isa InterruptException && rethrow()
            @warn "Could not save checkpoints/checkpoint_interrupted.json" exception=(save_err, catch_backtrace())
        end
        # Rethrows `e` (the inner catch has already exited) with its original backtrace.
        interrupted || rethrow()
    end

    elapsed = time() - t_start
    @printf("\ntraining complete in %.1f seconds\n", elapsed)
    return best_val_loss, train_loss_history, val_loss_history, completed_steps
end

# ── Run training from step 1 with the schedule computed above ──
best_val_loss, train_loss_history, val_loss_history, final_step = train_loop!(
    state_dict, params, param_keys, train_tokens, val_tokens,
    adam_m, adam_v, uchars, hyperparams;
    # model shape + data
    n_layer=n_layer, n_head=n_head, head_dim=head_dim, n_embd=n_embd,
    block_size=block_size, BOS=BOS,
    # optimizer + LR schedule
    num_steps=num_steps, lr=lr, b1=b1, b2=b2, eps=eps,
    max_grad_norm=max_grad_norm, min_lr=min_lr, warmup_iters=warmup_iters,
    # run state carried into the loop
    best_val_loss=best_val_loss,
    train_loss_history=train_loss_history,
    val_loss_history=val_loss_history,
)

wandb_finish()

# Always persist the final weights, whether or not they are the best by val loss.
save_and_sync("checkpoints/final_model.json", state_dict, param_keys, uchars, hyperparams;
    adam_m=adam_m, adam_v=adam_v,
    step=final_step, total_steps=final_step, num_steps_target=num_steps,
    lr=lr, b1=b1, b2=b2,
    best_val_loss=best_val_loss,
    train_losses=train_loss_history, val_losses=val_loss_history)

@printf("\nBest val loss: %.4f\n", best_val_loss)

---
## 8. Inference — Hallucinated Philosophy

Generate new philosophy-like text using temperature-controlled sampling.

In [None]:
# Draw an index from softmax(logits / temperature) by inverse-CDF sampling with one rand().
function _sample_index(logits::AbstractVector{Float64}, temperature)
    scaled = logits ./ temperature
    scaled .-= maximum(scaled)   # shift for numerical stability before exp
    probs = exp.(scaled)
    probs ./= sum(probs)

    r = rand()
    cum = 0.0
    for (i, p) in enumerate(probs)
        cum += p
        r <= cum && return i
    end
    # Floating-point rounding can leave the final cumulative sum just below r;
    # fall back to the last index rather than biasing toward token 1.
    return lastindex(probs)
end

"""
    generate_text(state_dict, n_layer, n_head, head_dim, n_embd, block_size, vocab_size, BOS;
                  max_tokens=200, temperature=0.8)

Sample token ids autoregressively starting from `BOS`, and return them decoded with
`decode_ids`.

Generation stops when `BOS` is sampled (it is not included in the output), or after
`min(max_tokens, block_size)` positions. The KV cache is allocated for `block_size`
positions, and the position embedding `wpe` is initialized with `block_size` rows.

`temperature` divides the logits before the softmax; `temperature <= 0` selects greedy
(argmax) decoding. `vocab_size` is accepted but not used.
"""
function generate_text(state_dict, n_layer, n_head, head_dim, n_embd, block_size, vocab_size, BOS;
                       max_tokens=200, temperature=0.8)
    params = build_params(state_dict, n_layer)
    kv_key_mats, kv_val_mats, kv_lens = alloc_kv_cache(n_layer, n_embd, block_size)
    reset_kv_cache!(kv_lens)

    token_id = BOS
    generated_ids = Int[]

    # Never step past the context window the cache and position table were sized for.
    for pos in 1:min(max_tokens, block_size)
        logits = gpt(token_id, pos, kv_key_mats, kv_val_mats, kv_lens, params, n_layer, n_head, head_dim)
        logits_val = Float64.(value(logits))

        next_id = temperature <= 0 ? argmax(logits_val) : _sample_index(logits_val, temperature)

        next_id == BOS && break  # stop at BOS
        push!(generated_ids, next_id)
        token_id = next_id
    end

    return decode_ids(generated_ids)
end

# Print a few samples. Truncate by characters (first), not bytes: byte-range indexing
# of a String throws StringIndexError when the cut falls inside a multi-byte character.
println("--- Generated Philosophy ---")
for i in 1:5
    text = generate_text(state_dict, n_layer, n_head, head_dim, n_embd, block_size, vocab_size, BOS;
                         max_tokens=300, temperature=0.8)
    @printf("\nSample %d:\n%s\n", i, first(text, 500))
    println("---")
end

---
## 8a. Push Model to HuggingFace Hub

Push your trained checkpoint to HuggingFace for persistence across Colab sessions.  
Set `HF_REPO_ID` in the login cell above (e.g. `"yourusername/microgpt-philosophy"`).

In [None]:
# Push checkpoints to HuggingFace Hub.
# Needs HF_REPO_ID from the login cell above; without it, only a hint is printed.
if !(@isdefined(HF_REPO_ID)) || isempty(HF_REPO_ID)
    println("Set HF_REPO_ID in the login cell (e.g. \"yourusername/microgpt-philosophy\")")
else
    # Create the repo first (intended to be a no-op if it already exists)
    hf_create_repo(HF_REPO_ID)

    # The best checkpoint is the primary artifact; report if training never produced one.
    if !isfile("checkpoints/best_model.json")
        println("No best_model.json found — train first!")
    else
        hf_push_checkpoint(HF_REPO_ID; checkpoint_path="checkpoints/best_model.json")
    end

    # The final checkpoint is optional: push it only when present.
    isfile("checkpoints/final_model.json") && hf_push(HF_REPO_ID, "checkpoints/final_model.json")

    println("\nDone! View your model at: https://huggingface.co/$HF_REPO_ID")
end

---
## 8b. Pull Checkpoint from HuggingFace Hub

Download a previously pushed checkpoint to resume training in a new Colab session.  
Run this before the Resume Training cell below.

In [None]:
# Pull the best checkpoint from HuggingFace so training can be resumed.
# Needs HF_REPO_ID from the login cell above; without it, only a hint is printed.
if !(@isdefined(HF_REPO_ID)) || isempty(HF_REPO_ID)
    println("Set HF_REPO_ID in the login cell (e.g. \"yourusername/microgpt-philosophy\")")
else
    mkpath("checkpoints")   # target directory passed as local_dir below
    hf_pull(HF_REPO_ID, "best_model.json"; local_dir="checkpoints")
    foreach(println, ("\nReady to resume from checkpoints/best_model.json",
                      "Run the 'Resume Training' cell below."))
end

---
## 9. Resume Training from Checkpoint

Load a saved checkpoint and continue training for more steps.  
Skip this cell if you're training from scratch above.

In [None]:
# ── Load checkpoint ──
# Change the path to load a different checkpoint
RESUME_FROM = "checkpoints/best_model.json"

if !isfile(RESUME_FROM)
    # Try pulling from HuggingFace
    if @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
        println("Checkpoint not found locally, pulling from HuggingFace...")
        hf_pull(HF_REPO_ID, basename(RESUME_FROM); local_dir="checkpoints")
    end
    isfile(RESUME_FROM) || error("Checkpoint not found: $RESUME_FROM")
end

# Architecture and vocabulary are taken from the checkpoint, not from earlier config cells.
ckpt = load_checkpoint(RESUME_FROM)
state_dict = ckpt.state_dict
uchars = ckpt.uchars
BOS = ckpt.BOS
n_layer = ckpt.n_layer
n_embd = ckpt.n_embd
block_size = ckpt.block_size
n_head = ckpt.n_head
head_dim = ckpt.head_dim

# NOTE(review): unlike the fresh-training cell, this omits "use_bpe" / "compute_tier",
# so checkpoints saved from a resumed run will not record them — confirm this is intended.
hyperparams = Dict{String,Any}(
    "n_layer" => n_layer, "n_embd" => n_embd,
    "block_size" => block_size, "n_head" => n_head
)

# Reconstruct dataset and split (ordered, not shuffled)
# NOTE(review): assumed to mirror the split used for the original run — confirm.
docs = copy(TRAINING_DATA)
split_idx = max(1, Int(floor(0.9 * length(docs))))
train_docs = docs[1:split_idx]
val_docs = docs[split_idx+1:end]
if isempty(val_docs)
    val_docs = docs[max(1, end-4):end]
    train_docs = docs[1:max(1, end-5)]
end

# ~1 extra epoch (the training loop uses one document per step), clamped to [500, 25000].
# Computed only after train_docs is rebuilt above. Reading it earlier picked up a stale
# value from a previous cell, or failed with UndefVarError in a fresh session.
EXTRA_STEPS = clamp(length(train_docs), 500, 25000)

# Rebuild O(1) tokenizer and pre-tokenize
# NOTE(review): this is the character-level tokenizer built from the checkpoint's
# uchars; verify it matches how the checkpoint was trained when USE_BPE is enabled.
char_to_id = Dict{Char, Int}(ch => i for (i, ch) in enumerate(uchars))
train_tokens = [tokenize_doc(doc, char_to_id, BOS) for doc in train_docs]
val_tokens = [tokenize_doc(doc, char_to_id, BOS) for doc in val_docs]

param_keys = get_param_keys(n_layer)

# Build type-stable params tuple
params = build_params(state_dict, n_layer)

# Restore optimizer (Adam state per param key); fall back to zero moments if absent
lr = ckpt.lr; b1 = ckpt.b1; b2 = ckpt.b2; eps = 1e-8
adam_m = if !isempty(ckpt.adam_m)
    ckpt.adam_m
else
    Dict{String, Matrix{Float32}}(k => zeros(Float32, size(value(state_dict[k]))) for k in param_keys)
end
adam_v = if !isempty(ckpt.adam_v)
    ckpt.adam_v
else
    Dict{String, Matrix{Float32}}(k => zeros(Float32, size(value(state_dict[k]))) for k in param_keys)
end

start_step = ckpt.step + 1
best_val_loss = ckpt.best_val_loss
train_loss_history = copy(ckpt.train_losses)
val_loss_history = copy(ckpt.val_losses)

# ── Initialize W&B logging (if API key is set) ──
if haskey(ENV, "WANDB_API_KEY") && !isempty(ENV["WANDB_API_KEY"])
    wandb_init()
end

println("\nResuming from step $(ckpt.step) -> training $EXTRA_STEPS more steps")
println("Best val loss so far: $(round(best_val_loss, digits=4))")

best_val_loss, train_loss_history, val_loss_history, final_step = train_loop!(
    state_dict, params, param_keys, train_tokens, val_tokens,
    adam_m, adam_v, uchars, hyperparams;
    num_steps=EXTRA_STEPS, lr=lr, b1=b1, b2=b2, eps=Float64(eps),
    n_layer=n_layer, n_head=n_head, head_dim=head_dim, n_embd=n_embd,
    block_size=block_size, BOS=BOS,
    best_val_loss=best_val_loss,
    train_loss_history=train_loss_history,
    val_loss_history=val_loss_history,
    start_step=start_step)

wandb_finish()

save_and_sync("checkpoints/final_model.json", state_dict, param_keys, uchars, hyperparams;
    adam_m=adam_m, adam_v=adam_v, step=final_step,
    lr=lr, b1=b1, b2=b2,
    best_val_loss=best_val_loss,
    train_losses=train_loss_history, val_losses=val_loss_history,
    total_steps=final_step, num_steps_target=start_step + EXTRA_STEPS - 1)

---
## 10. Download Checkpoint

Download the best model checkpoint to use with the inference server.
In Colab, use the Files panel (left sidebar) to download, or pull from HuggingFace Hub.

In [None]:
# List saved checkpoints with their sizes in KB (one decimal place)
if !isdir("checkpoints")
    println("No checkpoints directory found. Train first!")
else
    files = readdir("checkpoints")
    println("Saved checkpoints:")
    for fname in files
        fpath = joinpath("checkpoints", fname)
        println("  $fpath ($(round(filesize(fpath) / 1024, digits=1)) KB)")
    end
end