<a href="https://colab.research.google.com/github/DavinciDreams/JuliaGPT/blob/main/juliaflux.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# JuliaFlux GPT — A Flux.jl Transformer for Philosophy

GPU-accelerated character-level GPT trained on classical philosophy texts.
Uses Flux.jl + Zygote.jl + CUDA.jl for automatic differentiation and GPU compute.

**Architecture** (following GPT-2 with simplifications):
- Multi-head causal self-attention with batched matrix ops
- LayerNorm, GELU activations, residual connections
- Cosine LR schedule with warmup
- JLD2 checkpoint persistence (local + HuggingFace Hub)
- W&B logging + HuggingFace Hub integration
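Multi-head attention runs several heads in parallel, but the causal step inside each head is small enough to show on its own. The sketch below is plain Python (the language of the setup cell) and is an illustration, not the notebook's implementation. Scores are scaled by 1/sqrt(head_size), and future positions get `-inf` before the softmax, so each position attends only to itself and earlier positions.

```python
import math

def softmax(row):
    # Subtract the max for numerical stability; exp(-inf) == 0.0 zeroes masked slots
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(q, k, v):
    """Single-head causal attention over T positions.

    q, k, v: lists of T vectors (each a list of hs floats).
    Returns T output vectors; output i only mixes values from positions j <= i.
    """
    T = len(q)
    hs = len(q[0])
    scale = 1.0 / math.sqrt(hs)
    out = []
    for i in range(T):
        # scores[j] = (q_i . k_j) / sqrt(hs); future keys (j > i) are masked to -inf
        scores = []
        for j in range(T):
            if j > i:
                scores.append(float("-inf"))
            else:
                scores.append(scale * sum(a * b for a, b in zip(q[i], k[j])))
        weights = softmax(scores)
        # Weighted sum of value vectors, one output component at a time
        out.append([sum(weights[j] * v[j][d] for j in range(T)) for d in range(hs)])
    return out
```

In the Julia model the same mask is precomputed once as `CAUSAL_MASK = triu(fill(typemin(Float32), block_size, block_size), 1)` (the strict upper triangle filled with `-Inf32`), added to the `(T, T, nh*B)` score tensor built with `batched_mul`, and normalized with `softmax(wei; dims=2)` over the key dimension.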

Based on: https://github.com/DavinciDreams/JuliaGPT

## 0. Login & Setup

This cell runs in **Python**: it reads your Colab secrets, saves them to a file for Julia, and downloads the dataset.

1. Add secrets via the key icon in the left sidebar:
   - `HF_TOKEN` — your HuggingFace access token
   - `WANDB_KEY` — your Weights & Biases API key
   - `HF_REPO` — your model repo (e.g. `LisaMegaWatts/JuliaFluxGPT`)
   - `HF_DATA_REPO` — your dataset repo (e.g. `LisaMegaWatts/philosophy-corpus`)
2. Run cells 0-1 (login + install Julia, ~3-5 min)
3. **Runtime > Change runtime type > Julia 1.10**
4. Continue with the remaining Julia cells

In [None]:
# ── Minimal Python setup: install tools + save Colab secrets for Julia ──
!pip install -q wandb huggingface_hub datasets

import json, pathlib, os

secrets = {}
try:
    from google.colab import userdata
    for key in ("HF_TOKEN", "WANDB_KEY", "HF_REPO", "HF_DATA_REPO"):
        try: secrets[key] = userdata.get(key)
        except Exception: pass
except ImportError:
    pass

secrets_path = pathlib.Path.home() / ".julia_secrets.json"
secrets_path.write_text(json.dumps(secrets))
secrets_path.chmod(0o600)

found = [k for k in secrets if secrets[k]]
print(f"Secrets saved: {', '.join(found) if found else 'none found'}")
print(f"  -> {secrets_path}")
if not found:
    print("Add HF_TOKEN, WANDB_KEY, HF_REPO via the key icon in the sidebar.")

# ── Download pre-cleaned dataset from HuggingFace ──
DEFAULT_DATA_REPO = "LisaMegaWatts/philosophy-corpus"
data_repo = secrets.get("HF_DATA_REPO", "") or DEFAULT_DATA_REPO
data_dir = pathlib.Path("data")
train_file = data_dir / "train.txt"
val_file = data_dir / "val.txt"

if not (train_file.exists() and val_file.exists()):
    print(f"\nDownloading dataset from HuggingFace: {data_repo}")
    data_dir.mkdir(exist_ok=True)
    if secrets.get("HF_TOKEN"):
        os.environ["HF_TOKEN"] = secrets["HF_TOKEN"]
    try:
        from datasets import load_dataset
        ds = load_dataset(data_repo)
        with open(train_file, "w") as f:
            for row in ds["train"]:
                f.write(row["text"] + "\n")
        split_name = "validation" if "validation" in ds else "val"
        with open(val_file, "w") as f:
            for row in ds[split_name]:
                f.write(row["text"] + "\n")
        print(f"  train.txt: {train_file.stat().st_size:,} bytes")
        print(f"  val.txt:   {val_file.stat().st_size:,} bytes")
    except Exception as e:
        print(f"  Dataset download failed: {e}")
        print("  Copy data/train.txt and data/val.txt manually.")
else:
    print(f"\nDataset already downloaded: {data_dir}/")

print("\nDone! Now run the next cell to install Julia (~3-5 min).")
print("Then: Runtime > Change runtime type > Julia 1.10")

## 1. Install Julia Kernel

This cell downloads and installs Julia + IJulia + Flux packages. **Takes ~5-10 minutes** on first run.

**After it finishes:**
1. Go to **Runtime > Change runtime type**
2. Pick **Julia 1.10**
3. Continue running the cells below

In [None]:
%%shell
set -e

JULIA_VERSION="1.10.5"
JULIA_MINOR="1.10"

if [ ! -d "/usr/local/julia-${JULIA_VERSION}" ]; then
    echo "Downloading Julia ${JULIA_VERSION}..."
    wget -q https://julialang-s3.julialang.org/bin/linux/x64/${JULIA_MINOR}/julia-${JULIA_VERSION}-linux-x86_64.tar.gz
    tar xzf julia-${JULIA_VERSION}-linux-x86_64.tar.gz -C /usr/local/
    rm julia-${JULIA_VERSION}-linux-x86_64.tar.gz
    ln -sf /usr/local/julia-${JULIA_VERSION}/bin/julia /usr/local/bin/julia
    echo "Julia installed."
else
    echo "Julia already installed."
fi

julia -e '
    using Pkg
    Pkg.add("IJulia")
    Pkg.add(["Flux", "Zygote", "Optimisers", "CUDA", "cuDNN",
             "Downloads", "Statistics", "Random", "Printf",
             "LinearAlgebra", "JLD2", "NNlib", "JSON3"])
    using IJulia
    installkernel("Julia")
'

echo ""
echo "==========================================================="
echo "  Julia kernel installed!                                   "
echo "  Now: Runtime -> Change runtime type -> pick Julia 1.10       "
echo "  Then run the cells below.                              "
echo "===========================================================" 


## 1b. Load Credentials + W&B / HuggingFace Helpers (Julia)

Reads tokens from `~/.julia_secrets.json` (written by the Python setup cell).
W&B logging uses a persistent Python subprocess fed JSON lines from Julia.
HuggingFace helpers use `huggingface-cli` to push/pull checkpoints.
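The W&B protocol is deliberately simple: Julia starts `_wandb_log.py` once, writes one JSON object per line to its stdin, and closing stdin ends the child's read loop so it can call `wandb.finish()`. The sketch below reproduces that pattern with Python's `subprocess`. The child program is a hypothetical stand-in that sums the numeric fields of each record instead of calling `wandb.log`, so it runs without a W&B account.

```python
import json
import subprocess
import sys

# Hypothetical stand-in for _wandb_log.py: instead of wandb.log, it prints
# the sum of the numeric fields of each JSON line it receives.
CHILD = r"""
import json, sys
for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    data = json.loads(line)
    print(sum(v for v in data.values() if isinstance(v, (int, float))), flush=True)
"""

def run_demo(records):
    # One long-lived child process with pipes on stdin and stdout
    proc = subprocess.Popen([sys.executable, "-c", CHILD],
                            stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True)
    for rec in records:
        proc.stdin.write(json.dumps(rec) + "\n")  # one JSON object per line
        proc.stdin.flush()                          # like flush(wandb_proc) in Julia
    proc.stdin.close()                              # EOF ends the child's loop
    output = proc.stdout.read().split()
    proc.wait()
    return [float(x) for x in output]
```

Keeping a single process alive avoids paying Python and W&B start-up costs on every `wandb_log` call.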

In [None]:
using JSON3

# ── Ensure pip-installed binaries (huggingface-cli, wandb) are on PATH ──
for p in ["/usr/local/bin", joinpath(homedir(), ".local/bin"), "/root/.local/bin"]
    if isdir(p) && !occursin(p, get(ENV, "PATH", ""))
        ENV["PATH"] = p * ":" * get(ENV, "PATH", "")
    end
end

# ── Read credentials from ~/.julia_secrets.json (written by Python setup cell) ──
function load_secrets()
    path = expanduser("~/.julia_secrets.json")
    if !isfile(path)
        @warn "No secrets file found at $path — run the Python setup cell first"
        return Dict{String,String}()
    end
    raw = JSON3.read(read(path, String))
    # Skip JSON nulls and empty values so missing secrets count as absent
    return Dict{String,String}(string(k) => string(v) for (k, v) in pairs(raw)
                               if v !== nothing && !isempty(string(v)))
end

secrets = load_secrets()

# W&B
if haskey(secrets, "WANDB_KEY")
    ENV["WANDB_API_KEY"] = secrets["WANDB_KEY"]
    println("W&B API key: found")
else
    println("W&B API key: not found (add WANDB_KEY to Colab secrets)")
end

# HuggingFace token
if haskey(secrets, "HF_TOKEN")
    ENV["HF_TOKEN"] = secrets["HF_TOKEN"]
    println("HF token: found")
else
    println("HF token: not found (add HF_TOKEN to Colab secrets)")
end

# HuggingFace repo ID
HF_REPO_ID = get(secrets, "HF_REPO", "")
if !isempty(HF_REPO_ID)
    println("HF repo: ", HF_REPO_ID)
else
    println("HF repo: not set (add HF_REPO to Colab secrets or set HF_REPO_ID manually)")
end

In [None]:
WANDB_PROJECT = "juliaflux-philosophy"
WANDB_RUN_ID = "juliaflux-" * join(rand('a':'z', 6))

# Write a tiny Python helper that reads JSON lines on stdin
write("_wandb_log.py", """
import wandb, json, sys, os
project = os.environ.get("WANDB_PROJECT", "juliaflux-philosophy")
run_id = os.environ.get("WANDB_RUN_ID", None)
run = wandb.init(project=project, id=run_id, resume="allow",
                 config={"model": "juliaflux-gpt", "architecture": "multi-layer transformer"})
print(f"W&B run: {run.url}", flush=True)
for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    try:
        data = json.loads(line)
        wandb.log(data)
    except Exception as e:
        print(f"wandb log error: {e}", file=sys.stderr, flush=True)
wandb.finish()
""")

wandb_proc = nothing

function wandb_init()
    global wandb_proc, WANDB_PROJECT, WANDB_RUN_ID
    if !haskey(ENV, "WANDB_API_KEY") || isempty(ENV["WANDB_API_KEY"])
        println("W&B: skipped (no API key)")
        return
    end
    ENV["WANDB_PROJECT"] = WANDB_PROJECT
    ENV["WANDB_RUN_ID"] = WANDB_RUN_ID
    wandb_proc = open(`python3 _wandb_log.py`, "r+")
    println("W&B: initialized ($WANDB_PROJECT / $WANDB_RUN_ID)")
end

function wandb_log(; kwargs...)
    global wandb_proc
    wandb_proc === nothing && return
    metrics = Dict(string(k) => v for (k, v) in kwargs)
    try
        println(wandb_proc, JSON3.write(metrics))
        flush(wandb_proc)
    catch e
        println("W&B log error: $e")
    end
end

function wandb_finish()
    global wandb_proc
    wandb_proc === nothing && return
    try close(wandb_proc) catch end
    wandb_proc = nothing
    println("W&B: run finished")
end

# ═══════════════════════════════════════════════════════════════
# HuggingFace Hub helpers
# ═══════════════════════════════════════════════════════════════

function hf_push(repo_id::String, local_path::String; remote_path::String="")
    rp = isempty(remote_path) ? basename(local_path) : remote_path
    run(`huggingface-cli upload $repo_id $local_path $rp`)
    println("Pushed $local_path -> $repo_id/$rp")
end

function hf_pull(repo_id::String, remote_path::String; local_dir::String="checkpoints")
    mkpath(local_dir)
    run(`huggingface-cli download $repo_id $remote_path --local-dir $local_dir`)
    println("Pulled $repo_id/$remote_path -> $local_dir/")
end

function hf_push_checkpoint(repo_id::String; checkpoint_path::String="checkpoints/best_model.jld2")
    isfile(checkpoint_path) || error("Checkpoint not found: $checkpoint_path")
    hf_push(repo_id, checkpoint_path)
end

function hf_create_repo(repo_id::String)
    try
        run(`huggingface-cli repo create $repo_id --type model`)
        println("Created HF repo: $repo_id")
    catch
        println("HF repo already exists or creation skipped: $repo_id")
    end
end

# ═══════════════════════════════════════════════════════════════
# Sync helper: push checkpoint to HuggingFace Hub
# ═══════════════════════════════════════════════════════════════

function hf_sync(local_path::String)
    if !@isdefined(HF_REPO_ID) || isempty(HF_REPO_ID)
        return
    end
    try
        hf_push(HF_REPO_ID, local_path)
    catch e
        println("  HF sync failed: $e")
    end
end

println("W&B + HuggingFace helpers defined")

---
## 2. Imports & Setup

Load Flux.jl ecosystem packages. CUDA is auto-detected for GPU acceleration.

In [None]:
using Flux
using Zygote
using Optimisers
using CUDA
using NNlib
using Downloads
using Statistics
using Random
using Printf
using LinearAlgebra
using JLD2

Random.seed!(1337)

device = CUDA.functional() ? gpu : cpu
println("Device: ", device)
println("CUDA functional: ", CUDA.functional())
if CUDA.functional()
    println("GPU: ", CUDA.name(CUDA.device()))
    mem = CUDA.totalmem(CUDA.device())
    println("VRAM: ", round(mem / 1024^3, digits=1), " GB")
    println("CUDA version: ", CUDA.runtime_version())
end

---
## 3. Hyperparameters

In [None]:
# ── Model architecture ──
block_size     = 256       # context window (doubled from 128)
n_embd         = 384       # embedding dim (up from 256 — better A100 utilization)
n_head         = 6         # attention heads (up from 4)
n_layer        = 6         # transformer layers (up from 4)
dropout        = 0.1
bias           = false

# ── Training ──
batch_size     = 64        # batch size (doubled from 32 — A100 has plenty of VRAM)
learning_rate  = 3e-4      # slightly lower LR for larger model
max_iters      = 5000
eval_interval  = 500
eval_iters     = 100
warmup_iters   = 200
min_lr         = 1e-5

println("n_embd=$n_embd, n_layer=$n_layer, n_head=$n_head, block_size=$block_size, batch_size=$batch_size")

---
## 4. Dataset — Load Pre-Cleaned Corpus

Loads the **pre-cleaned philosophy corpus** produced by the
[text-pipeline](https://github.com/DavinciDreams/buildwithbooks).

**Data flow:**
```
text-pipeline/ -> download + clean + chunk -> output/train.txt, val.txt
                -> push to HuggingFace Dataset
                                    |
juliaflux.ipynb -> pulls data/train.txt, val.txt -> tokenize -> train
```

The pipeline applies an 8-stage cleaner (boilerplate stripping, unicode normalization,
character filtering to `a-z .` only) and sentence-boundary chunking (40-256 chars).

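**To add new texts:** add sources via the pipeline's Gradio UI or manifest,
re-run the pipeline, and push to HuggingFace. Then delete the local `data/train.txt` and `data/val.txt`
(the download is skipped when they already exist) and re-run this cell.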
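The pipeline's own 8-stage cleaner is not reproduced here. As a rough, hypothetical approximation of three of its stages (Unicode normalization, filtering to `a-z .`, and sentence-boundary chunking into 40-256 character pieces), Python along these lines yields text in the same alphabet. `clean` and `chunk` are illustrative names, not the pipeline's API.

```python
import re
import unicodedata

ALLOWED = set("abcdefghijklmnopqrstuvwxyz .")

def clean(text):
    # Unicode-normalize and drop accents, e.g. "é" -> "e"
    text = unicodedata.normalize("NFKD", text)
    text = "".join(c for c in text if not unicodedata.combining(c))
    text = text.lower()
    # Replace every character outside "a-z ." with a space, then collapse runs of spaces
    text = "".join(c if c in ALLOWED else " " for c in text)
    return re.sub(r" +", " ", text).strip()

def chunk(text, min_len=40, max_len=256):
    # Split after each period, keeping the period with its sentence
    sentences = [s.strip() for s in re.split(r"(?<=\.)\s+", text) if s.strip()]
    chunks, current = [], ""
    for s in sentences:
        candidate = (current + " " + s).strip()
        if len(candidate) <= max_len:
            current = candidate          # keep packing whole sentences
        else:
            if len(current) >= min_len:  # too-short leftovers are dropped
                chunks.append(current)
            current = s[:max_len]        # an over-long sentence is truncated
    if len(current) >= min_len:
        chunks.append(current)
    return chunks
```

Because the cleaned alphabet is only 28 characters, the tokenizer below can use the characters themselves as the vocabulary.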

In [None]:
# ── Load pre-cleaned data from text pipeline ──
# The Python setup cell downloads from HuggingFace before kernel switch.
# Default dataset: LisaMegaWatts/philosophy-corpus

DATA_DIR = "data"
train_file = joinpath(DATA_DIR, "train.txt")
val_file   = joinpath(DATA_DIR, "val.txt")

# Julia-side fallback: try huggingface-cli if Python download missed
DEFAULT_DATA_REPO = "LisaMegaWatts/philosophy-corpus"
HF_DATA_REPO = let r = get(secrets, "HF_DATA_REPO", ""); isempty(r) ? DEFAULT_DATA_REPO : r end

if !isfile(train_file) || !isfile(val_file)
    println("Data not found locally, trying huggingface-cli download...")
    mkpath(DATA_DIR)
    try
        run(`huggingface-cli download $HF_DATA_REPO train.txt val.txt --repo-type dataset --local-dir $DATA_DIR`)
        println("Downloaded from $HF_DATA_REPO")
    catch e
        @warn "HuggingFace download failed: $e"
    end
end

if !isfile(train_file) || !isfile(val_file)
    error("""
    No pre-cleaned data found in $DATA_DIR/!

    The Python setup cell should download from $HF_DATA_REPO automatically.
    Re-run the Python setup cell, or copy files manually:
        cp text-pipeline/output/train.txt data/train.txt
        cp text-pipeline/output/val.txt   data/val.txt
    """)
end

train_text = read(train_file, String)
val_text   = read(val_file, String)

println("Pre-cleaned data loaded from $DATA_DIR/ ($HF_DATA_REPO)")
println("  train.txt: $(length(train_text)) chars ($(count('\n', train_text)) chunks)")
println("  val.txt:   $(length(val_text)) chars ($(count('\n', val_text)) chunks)")

In [None]:
# ── Build character-level tokenizer from pre-cleaned corpus ──
# Pipeline uses allowed_chars: "a-z ." so vocab is ~28 chars (letters + space + period)

# Join train + val to discover full character set
full_text = train_text * "\n" * val_text
chars = sort(unique(full_text))

# Remove newline from vocab — it's a chunk separator, not content
filter!(c -> c != '\n', chars)
global vocab_size = length(chars)

stoi = Dict(c => i for (i, c) in enumerate(chars))
itos = Dict(i => c for (i, c) in enumerate(chars))

encode(s) = [stoi[c] for c in s if haskey(stoi, c)]
decode(ids) = join(itos[i] for i in ids)

# Encode train and val separately (pipeline already split 90/10)
# Replace newlines (chunk separators) with spaces for continuous text
train_clean = replace(strip(train_text), '\n' => ' ')
val_clean   = replace(strip(val_text), '\n' => ' ')

global train_data = encode(train_clean)
global val_data   = encode(val_clean)

println("Vocab: $(vocab_size) chars -> [$(join(chars))]")
println("Train: $(length(train_data)) tokens")
println("Val:   $(length(val_data)) tokens")
println("Total: $(length(train_data) + length(val_data)) tokens")

---
## 5. Batch Loader & Model Architecture

Mini-batch loader + full GPT model definition using Flux.jl structs.
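The batch loader samples `batch_size` random windows of `block_size` tokens, and the target at each position is simply the next token, so `y` is `x` shifted by one. A 0-based Python sketch of the same idea (the Julia `get_batch` uses 1-based indices, then transposes to `(B, T)` and moves the arrays to the device):

```python
import random

def get_batch(data, block_size, batch_size, rng=random):
    """Sample batch_size random windows; targets are the inputs shifted by one token."""
    # Largest start is len(data) - block_size - 1, so the last target index is in range
    starts = [rng.randrange(0, len(data) - block_size) for _ in range(batch_size)]
    x = [data[i:i + block_size] for i in starts]
    y = [data[i + 1:i + block_size + 1] for i in starts]
    return x, y
```

Every position in a window is a training example, so one batch yields `batch_size * block_size` next-token predictions.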

In [None]:
function get_batch(split="train")
    d = split == "train" ? train_data : val_data
    ix = rand(1:length(d)-block_size, batch_size)
    x = hcat([d[i:i+block_size-1] for i in ix]...)
    y = hcat([d[i+1:i+block_size] for i in ix]...)
    x = permutedims(x)   # now (B, T)
    y = permutedims(y)
    x = x |> device
    y = y |> device
    return x, y
end

In [None]:
# ── All model structs in ONE cell (Julia structs can't be redefined) ──
using NNlib: batched_mul

# Pre-compute causal mask once (avoids recreating on every forward pass)
const CAUSAL_MASK = triu(fill(typemin(Float32), block_size, block_size), 1)
const CAUSAL_MASK_GPU = CUDA.functional() ? cu(CAUSAL_MASK) : CAUSAL_MASK

# Pre-compute position indices (avoids CPU allocation each forward pass)
const POS_RANGE = collect(Int32, 1:block_size)
const POS_RANGE_GPU = CUDA.functional() ? cu(POS_RANGE) : POS_RANGE

struct CausalSelfAttention
    qkv::Dense       # single projection: n_embd -> 3*n_embd
    proj::Dense       # output projection: n_embd -> n_embd
    n_head::Int
end

Flux.@layer CausalSelfAttention trainable=(qkv, proj)

function CausalSelfAttention(n_embd::Int, n_head::Int; bias=false)
    CausalSelfAttention(
        Dense(n_embd => 3 * n_embd; bias),
        Dense(n_embd => n_embd; bias),
        n_head
    )
end

function (attn::CausalSelfAttention)(x)
    C, T, B = size(x)
    hs = C ÷ attn.n_head
    nh = attn.n_head

    qkv = attn.qkv(x)
    q = qkv[1:C, :, :]
    k = qkv[C+1:2C, :, :]
    v = qkv[2C+1:3C, :, :]

    q = reshape(permutedims(reshape(q, hs, nh, T, B), (1, 3, 2, 4)), hs, T, nh * B)
    k = reshape(permutedims(reshape(k, hs, nh, T, B), (1, 3, 2, 4)), hs, T, nh * B)
    v = reshape(permutedims(reshape(v, hs, nh, T, B), (1, 3, 2, 4)), hs, T, nh * B)

    scale = Float32(1 / sqrt(hs))
    wei = batched_mul(permutedims(q, (2, 1, 3)), k) .* scale

    # Use pre-computed causal mask (slice to current sequence length)
    mask = x isa CuArray ? CAUSAL_MASK_GPU[1:T, 1:T] : CAUSAL_MASK[1:T, 1:T]
    wei = wei .+ mask
    wei = softmax(wei; dims=2)

    out = batched_mul(v, permutedims(wei, (2, 1, 3)))
    out = reshape(permutedims(reshape(out, hs, T, nh, B), (1, 3, 2, 4)), C, T, B)

    attn.proj(out)
end

struct FeedForward
    net::Chain
end

Flux.@layer FeedForward

function FeedForward(n_embd::Int; bias=false, dropout=0.0)
    FeedForward(Chain(
        Dense(n_embd => 4 * n_embd; bias),
        gelu,
        Dense(4 * n_embd => n_embd; bias),
        Dropout(dropout)
    ))
end

(ff::FeedForward)(x) = ff.net(x)

struct TransformerBlock
    ln1::LayerNorm
    attn::CausalSelfAttention
    ln2::LayerNorm
    ffwd::FeedForward
end

Flux.@layer TransformerBlock

function TransformerBlock(n_embd::Int, n_head::Int; dropout=0.0)
    TransformerBlock(
        LayerNorm(n_embd),
        CausalSelfAttention(n_embd, n_head),
        LayerNorm(n_embd),
        FeedForward(n_embd; dropout)
    )
end

function (block::TransformerBlock)(x)
    x = x .+ block.attn(block.ln1(x))
    x = x .+ block.ffwd(block.ln2(x))
    x
end

struct GPT
    wte::Embedding
    wpe::Embedding
    drop::Dropout
    blocks::Chain
    ln_f::LayerNorm
    lm_head::Dense
end

Flux.@layer GPT

function GPT(; vocab_size, n_embd, block_size, n_layer, n_head, dropout=0.1)
    GPT(
        Embedding(vocab_size => n_embd),
        Embedding(block_size => n_embd),
        Dropout(dropout),
        Chain([TransformerBlock(n_embd, n_head; dropout) for _ in 1:n_layer]...),
        LayerNorm(n_embd),
        Dense(n_embd => vocab_size; bias=false)
    )
end

function (m::GPT)(idx)
    B, T = size(idx)
    tok = permutedims(m.wte(idx), (1, 3, 2))
    # Use pre-computed position indices (already on correct device)
    pos_src = idx isa CuArray ? POS_RANGE_GPU : POS_RANGE
    pos_ids = repeat(reshape(pos_src[1:T], 1, T), B, 1)
    pos = permutedims(m.wpe(pos_ids), (1, 3, 2))
    x = m.drop(tok .+ pos)
    x = m.blocks(x)
    x = m.ln_f(x)
    m.lm_head(x)
end

println("All model structs defined: CausalSelfAttention, FeedForward, TransformerBlock, GPT")
println("Pre-computed: causal mask ($(size(CAUSAL_MASK))), position indices ($(length(POS_RANGE)))")

In [None]:
model = GPT(;
    vocab_size = vocab_size,
    n_embd     = n_embd,
    block_size = block_size,
    n_layer    = n_layer,
    n_head     = n_head,
    dropout    = dropout
) |> device

n_params = sum(length, Flux.trainables(model))
println("Model created on $device")
println("Parameters: $(n_params) ($(round(n_params/1e6, digits=2))M)")

if CUDA.functional()
    println("GPU memory after model: $(round(CUDA.used_memory() / 1024^2, digits=1)) MB used")
end

# Quick smoke test
x_test, y_test = get_batch("train")
logits_test = model(x_test)
println("Forward pass OK -- logits: $(size(logits_test))")

if CUDA.functional()
    println("GPU memory after forward pass: $(round(CUDA.used_memory() / 1024^2, digits=1)) MB used")
end

---
## 6. Checkpoint Save/Load

Save and load model weights + optimizer state as JLD2 files.
Checkpoints are synced to HuggingFace Hub for persistence across Colab sessions.
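Each checkpoint stores the hyperparameters alongside the weights, which is what lets `load_checkpoint` rebuild a `GPT` of the right shape before restoring its state. The Python sketch below mirrors that record layout with `pickle`. It is illustrative only: the "model" is a plain dict, `build_model` stands in for the `GPT(; ...)` constructor, and `dict.update` stands in for `Flux.loadmodel!`.

```python
import pickle

def save_checkpoint(path, model_state, opt_state, step, best_val_loss, hyperparams,
                    train_losses=(), val_losses=()):
    # One record holds everything needed to rebuild the model and resume training
    record = {
        "model_state": model_state,
        "opt_state": opt_state,
        "step": step,
        "best_val_loss": best_val_loss,
        "train_losses": list(train_losses),
        "val_losses": list(val_losses),
        "hyperparams": dict(hyperparams),
    }
    with open(path, "wb") as f:
        pickle.dump(record, f)

def load_checkpoint(path, build_model):
    with open(path, "rb") as f:
        record = pickle.load(f)
    # Rebuild the architecture from stored hyperparameters, then restore the weights
    model = build_model(**record["hyperparams"])
    model.update(record["model_state"])
    return model, record
```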

In [None]:
LOCAL_CKPT = "checkpoints"
mkpath(LOCAL_CKPT)

function save_checkpoint(path::String, model, opt_state;
                          step::Int=0, best_val_loss::Float64=Inf,
                          train_losses::Vector{Float64}=Float64[],
                          val_losses::Vector{Float64}=Float64[])
    mkpath(dirname(path))
    model_cpu = cpu(model)
    opt_cpu = cpu(opt_state)
    JLD2.jldsave(path;
        model_state = Flux.state(model_cpu),
        opt_state = opt_cpu,
        step = step,
        best_val_loss = best_val_loss,
        train_losses = train_losses,
        val_losses = val_losses,
        hyperparams = Dict(
            "vocab_size" => vocab_size,
            "n_embd" => n_embd,
            "block_size" => block_size,
            "n_layer" => n_layer,
            "n_head" => n_head,
            "dropout" => dropout
        )
    )
    vl_str = best_val_loss == Inf ? "Inf" : @sprintf("%.4f", best_val_loss)
    println("Checkpoint saved: $path (step $step, best_val_loss=$vl_str)")
end

function save_and_sync(path, model, opt_state; kwargs...)
    save_checkpoint(path, model, opt_state; kwargs...)
    hf_sync(path)
end

function load_checkpoint(path::String, device_fn)
    println("Loading checkpoint from $path ...")
    data = JLD2.load(path)

    hp = data["hyperparams"]
    m = GPT(;
        vocab_size = hp["vocab_size"],
        n_embd     = hp["n_embd"],
        block_size = hp["block_size"],
        n_layer    = hp["n_layer"],
        n_head     = hp["n_head"],
        dropout    = get(hp, "dropout", 0.1)
    )
    Flux.loadmodel!(m, data["model_state"])
    m = m |> device_fn

    opt = data["opt_state"]

    println("  step=$(data["step"]), best_val=$(round(data["best_val_loss"], digits=4))")
    return (;
        model = m,
        opt_state = opt |> device_fn,
        step = data["step"],
        best_val_loss = data["best_val_loss"],
        train_losses = get(data, "train_losses", Float64[]),
        val_losses = get(data, "val_losses", Float64[])
    )
end

println("Checkpoint save/load defined (JLD2 + HuggingFace sync)")

---
## 7. Training Loop with Validation + Best-Model Checkpointing

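Adam optimizer with linear warmup followed by cosine LR decay.
Evaluates on both splits at step 1 and every `eval_interval` (500) steps, and saves `best_model.jld2` whenever the validation loss improves.
Checkpoints sync to HuggingFace Hub automatically.
Defensive saves: `checkpoint_latest.jld2` every 1000 steps and every 10 minutes, plus an emergency `checkpoint_interrupted.jld2` if training is interrupted or throws an error.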
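`get_lr` ramps the learning rate linearly from 0 to `learning_rate` over `warmup_iters` steps, then follows half a cosine down to `min_lr` at `max_iters`. The same formula in Python, with defaults copied from the hyperparameter cell:

```python
import math

def get_lr(it, learning_rate=3e-4, min_lr=1e-5, warmup_iters=200, max_iters=5000):
    # Linear warmup from 0 to learning_rate over the first warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # Cosine decay: learning_rate at it == warmup_iters, min_lr at it == max_iters
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)
```

The training loop starts at `iter = 1`, so the first update uses `learning_rate / warmup_iters`; halfway through the decay the rate is the midpoint of `learning_rate` and `min_lr`.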

In [None]:
using Printf

function estimate_loss(model, n_iters=eval_iters)
    model_eval = Flux.testmode!(deepcopy(model))
    losses = Dict("train" => 0.0, "val" => 0.0)
    for split in ["train", "val"]
        total = 0.0
        for _ in 1:n_iters
            x, y = get_batch(split)
            logits = model_eval(x)
            # Compute loss entirely on GPU — reshape targets to match logits
            y_flat = reshape(y, :)
            logits_flat = reshape(logits, vocab_size, :)
            onehot = Flux.onehotbatch(y_flat, 1:vocab_size) |> device
            loss = Flux.logitcrossentropy(logits_flat, onehot)
            total += loss
        end
        losses[split] = total / n_iters
    end
    return losses
end

# LR schedule: warmup + cosine decay
function get_lr(iter)
    if iter < warmup_iters
        return learning_rate * iter / warmup_iters
    end
    decay_ratio = (iter - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + cos(Float64(pi) * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)
end

opt_state = Flux.setup(Adam(learning_rate), model)

best_val = Inf
train_loss_history = Float64[]
val_loss_history = Float64[]

# ── Initialize W&B logging (if API key is set) ──
if haskey(ENV, "WANDB_API_KEY") && !isempty(ENV["WANDB_API_KEY"])
    wandb_init()
end

SAVE_INTERVAL = 600  # auto-save every 10 min
last_save_time = time()
completed_iter = 0

hf_status = !isempty(HF_REPO_ID) ? "HF:$HF_REPO_ID" : "HF:(not configured)"
println("Training for $max_iters steps...")
println("    Local: $LOCAL_CKPT/  |  $hf_status")
t_start = time()

try
    for iter in 1:max_iters
        global completed_iter = iter

        # Update LR
        lr_t = get_lr(iter)
        Flux.adjust!(opt_state, lr_t)

        # Forward + backward + update (loss computed entirely on GPU)
        x, y = get_batch("train")
        loss, grads = Flux.withgradient(model) do m
            logits = m(x)
            y_flat = reshape(y, :)
            logits_flat = reshape(logits, vocab_size, :)
            onehot = Flux.onehotbatch(y_flat, 1:vocab_size) |> device
            Flux.logitcrossentropy(logits_flat, onehot)
        end
        Flux.update!(opt_state, model, grads[1])
        push!(train_loss_history, Float64(loss))

        # Incremental GC to prevent GPU memory fragmentation
        if iter % 100 == 0 && CUDA.functional()
            GC.gc(false)
        end

        # Eval + print + checkpoint
        if iter % eval_interval == 0 || iter == 1
            losses = estimate_loss(model)
            push!(val_loss_history, losses["val"])
            elapsed = round(time() - t_start, digits=1)
            wandb_log(; step=iter, train_loss=Float64(loss), val_loss=losses["val"], lr=lr_t)

            improved = ""
            if losses["val"] < best_val
                best_val = losses["val"]
                save_and_sync("checkpoints/best_model.jld2", model, opt_state;
                    step=iter, best_val_loss=best_val,
                    train_losses=train_loss_history, val_losses=val_loss_history)
                improved = " << new best!"
            end

            @printf("step %5d | train %.4f | val %.4f | lr %.2e | %.1fs%s\n",
                    iter, losses["train"], losses["val"], lr_t, elapsed, improved)
        end

        # Periodic checkpoint every 1000 steps
        if iter % 1000 == 0
            save_and_sync("checkpoints/checkpoint_latest.jld2", model, opt_state;
                step=iter, best_val_loss=best_val,
                train_losses=train_loss_history, val_losses=val_loss_history)
            last_save_time = time()
        end

        # Time-based auto-save every 10 min
        if time() - last_save_time > SAVE_INTERVAL
            save_and_sync("checkpoints/checkpoint_latest.jld2", model, opt_state;
                step=iter, best_val_loss=best_val,
                train_losses=train_loss_history, val_losses=val_loss_history)
            last_save_time = time()
            println("  [auto-save at step $iter]")
        end
    end
catch e
    if e isa InterruptException
        println("\n\nTraining interrupted at step $completed_iter!")
    else
        println("\n\nTraining error at step $completed_iter: $e")
    end
    println("Saving emergency checkpoint...")
    save_and_sync("checkpoints/checkpoint_interrupted.jld2", model, opt_state;
        step=completed_iter, best_val_loss=best_val,
        train_losses=train_loss_history, val_losses=val_loss_history)
    if !(e isa InterruptException)
        rethrow(e)
    end
end

elapsed = round(time() - t_start, digits=1)
println("\nTraining complete in $(elapsed)s. Best val loss: $(round(best_val, digits=4))")
wandb_finish()

# Final save
save_and_sync("checkpoints/final_model.jld2", model, opt_state;
    step=max_iters, best_val_loss=best_val,
    train_losses=train_loss_history, val_losses=val_loss_history)

---
## 8. Inference — Generate Text

Generate new philosophy-like text using temperature-controlled sampling.
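`generate_text` divides the last position's logits by the temperature, applies a softmax, and draws the next id by inverse-CDF sampling. A Python sketch of one sampling step (`sample_next` is a hypothetical name):

```python
import math
import random

def sample_next(logits, temperature=0.8, rng=random):
    # Temperature scaling: < 1 sharpens the distribution, > 1 flattens it
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Inverse-CDF sampling: first index whose cumulative probability reaches r
    r = rng.random()
    cum = 0.0
    for i, p in enumerate(probs):
        cum += p
        if r <= cum:
            return i
    return len(probs) - 1  # guard against rounding leaving cum slightly below r
```

At `temperature = 1` the model's distribution is sampled unchanged; lower values concentrate probability on the highest-scoring character.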

In [None]:
function generate_text(model, max_tokens=200; temperature=0.8)
    model_eval = Flux.testmode!(deepcopy(model))
    # Start with a random token
    idx = reshape([rand(1:vocab_size)], 1, 1) |> device
    generated = Int[]

    for _ in 1:max_tokens
        # Crop to block_size
        idx_cond = idx[:, max(1, end-block_size+1):end]
        logits = model_eval(idx_cond)     # (vocab, T, B)
        logits_last = logits[:, end, 1]   # (vocab,) last token logits

        # Temperature scaling + softmax
        probs = softmax(logits_last ./ Float32(temperature))
        probs_cpu = Float64.(cpu(probs))

        # Categorical sampling
        r = rand()
        cum = 0.0
        next_id = length(probs_cpu)  # fallback if rounding leaves cum just below r
        for (i, p) in enumerate(probs_cpu)
            cum += p
            if r <= cum
                next_id = i
                break
            end
        end

        push!(generated, next_id)
        next_token = reshape([next_id], 1, 1) |> device
        idx = hcat(idx, next_token)
    end

    return decode(generated)
end

println("--- inference (hallucinated philosophy) ---")
for i in 1:5
    text = generate_text(model, 300; temperature=0.8)
    @printf("\nsample %d:\n%s\n", i, text[1:min(end, 500)])
    println("---")
end


---
## 8a. Push Model to HuggingFace Hub

Push your trained checkpoint to HuggingFace for persistence across Colab sessions.
Set the `HF_REPO` Colab secret (read into `HF_REPO_ID` by cell 1b), or assign `HF_REPO_ID` manually.

In [None]:
if @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
    hf_create_repo(HF_REPO_ID)

    if isfile("checkpoints/best_model.jld2")
        hf_push_checkpoint(HF_REPO_ID; checkpoint_path="checkpoints/best_model.jld2")
    else
        println("No best_model.jld2 found -- train first!")
    end

    if isfile("checkpoints/final_model.jld2")
        hf_push(HF_REPO_ID, "checkpoints/final_model.jld2")
    end

    println("\nDone! View your model at: https://huggingface.co/$HF_REPO_ID")
else
    println("Set HF_REPO_ID in the login cell (e.g. \"yourusername/juliaflux-philosophy\")")
end


---
## 8b. Pull Checkpoint from HuggingFace Hub

Download a previously pushed checkpoint to resume training in a new Colab session.

In [None]:
if @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
    mkpath("checkpoints")
    hf_pull(HF_REPO_ID, "best_model.jld2"; local_dir="checkpoints")
    println("\nReady to resume from checkpoints/best_model.jld2")
    println("Run the 'Resume Training' cell below.")
else
    println("Set HF_REPO_ID in the login cell (e.g. \"yourusername/juliaflux-philosophy\")")
end


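---
## 9. Resume Training from Checkpoint

Load a saved checkpoint and continue training for more steps. Skip this cell if you're training from scratch above.
This cell reuses `get_lr`, `estimate_loss`, and `SAVE_INTERVAL` from the section 7 cell, so they must already be defined in the current session.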
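The resume cell restarts at `ckpt.step + 1`, runs `EXTRA_ITERS` more steps, and evaluates the schedule at `min(iter, max_iters)`. A minimal Python sketch of that bookkeeping (`resume_plan` is a hypothetical helper, and `lr_fn` stands in for `get_lr`) shows that every step past `max_iters` runs at the schedule's final value, i.e. `min_lr`:

```python
def resume_plan(ckpt_step, extra_iters, max_iters, lr_fn):
    """Return (start_iter, end_iter, lrs) for continuing training from a checkpoint."""
    start_iter = ckpt_step + 1
    end_iter = ckpt_step + extra_iters
    # Past max_iters the schedule is clamped, so the LR stays at its final value
    lrs = [lr_fn(min(it, max_iters)) for it in range(start_iter, end_iter + 1)]
    return start_iter, end_iter, lrs
```

Since `RESUME_FROM` defaults to `best_model.jld2`, the resumed run starts from the step where validation loss was lowest, not necessarily the last step trained.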

In [None]:
RESUME_FROM = "checkpoints/best_model.jld2"
EXTRA_ITERS = 2000

if !isfile(RESUME_FROM)
    # Try pulling from HuggingFace
    if @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
        println("Checkpoint not found locally, pulling from HuggingFace...")
        hf_pull(HF_REPO_ID, basename(RESUME_FROM); local_dir="checkpoints")
    end
    isfile(RESUME_FROM) || error("Checkpoint not found: $RESUME_FROM")
end

ckpt = load_checkpoint(RESUME_FROM, device)
model = ckpt.model
opt_state = ckpt.opt_state
start_iter = ckpt.step + 1
best_val = ckpt.best_val_loss
train_loss_history = copy(ckpt.train_losses)
val_loss_history = copy(ckpt.val_losses)
end_iter = ckpt.step + EXTRA_ITERS

# ── Initialize W&B logging (if API key is set) ──
if haskey(ENV, "WANDB_API_KEY") && !isempty(ENV["WANDB_API_KEY"])
    wandb_init()
end

println("\nResuming from step $(ckpt.step) -> training to step $end_iter")
println("Best val loss so far: $(round(best_val, digits=4))")
t_start = time()
last_save_time = time()

try
    for iter in start_iter:end_iter
        global completed_iter = iter

        lr_t = get_lr(min(iter, max_iters))
        Flux.adjust!(opt_state, lr_t)

        x, y = get_batch("train")
        loss, grads = Flux.withgradient(model) do m
            logits = m(x)
            y_flat = reshape(y, :)
            logits_flat = reshape(logits, vocab_size, :)
            onehot = Flux.onehotbatch(y_flat, 1:vocab_size) |> device
            Flux.logitcrossentropy(logits_flat, onehot)
        end
        Flux.update!(opt_state, model, grads[1])
        push!(train_loss_history, Float64(loss))

        # Incremental GC to prevent GPU memory fragmentation
        if iter % 100 == 0 && CUDA.functional()
            GC.gc(false)
        end

        if iter % eval_interval == 0
            losses = estimate_loss(model)
            push!(val_loss_history, losses["val"])
            elapsed = round(time() - t_start, digits=1)
            wandb_log(; step=iter, train_loss=Float64(loss), val_loss=losses["val"], lr=lr_t)

            improved = ""
            if losses["val"] < best_val
                best_val = losses["val"]
                save_and_sync("checkpoints/best_model.jld2", model, opt_state;
                    step=iter, best_val_loss=best_val,
                    train_losses=train_loss_history, val_losses=val_loss_history)
                improved = " << new best!"
            end

            @printf("step %5d / %5d | train %.4f | val %.4f | lr %.2e | %.1fs%s\n",
                    iter, end_iter, losses["train"], losses["val"], lr_t, elapsed, improved)
        end

        if iter % 1000 == 0
            save_and_sync("checkpoints/checkpoint_latest.jld2", model, opt_state;
                step=iter, best_val_loss=best_val,
                train_losses=train_loss_history, val_losses=val_loss_history)
            last_save_time = time()
        end

        if time() - last_save_time > SAVE_INTERVAL
            save_and_sync("checkpoints/checkpoint_latest.jld2", model, opt_state;
                step=iter, best_val_loss=best_val,
                train_losses=train_loss_history, val_losses=val_loss_history)
            last_save_time = time()
            println("  [auto-save at step $iter]")
        end
    end
catch e
    if e isa InterruptException
        println("\n\nTraining interrupted at step $completed_iter!")
    else
        println("\n\nTraining error at step $completed_iter: $e")
    end
    println("Saving emergency checkpoint...")
    save_and_sync("checkpoints/checkpoint_interrupted.jld2", model, opt_state;
        step=completed_iter, best_val_loss=best_val,
        train_losses=train_loss_history, val_losses=val_loss_history)
    if !(e isa InterruptException)
        rethrow(e)
    end
end

elapsed = round(time() - t_start, digits=1)
@printf("\nResume training complete in %.1fs\n", elapsed)
wandb_finish()

save_and_sync("checkpoints/final_model.jld2", model, opt_state;
    step=end_iter, best_val_loss=best_val,
    train_losses=train_loss_history, val_losses=val_loss_history)

---
## 10. Download Checkpoint

Download the best model checkpoint to use elsewhere.
In Colab, use the Files panel (left sidebar) to download, or pull from HuggingFace Hub.

In [None]:
if isdir("checkpoints")
    files = readdir("checkpoints")
    println("Saved checkpoints:")
    for f in files
        path = joinpath("checkpoints", f)
        size_kb = round(filesize(path) / 1024, digits=1)
        println("  $path ($(size_kb) KB)")
    end
else
    println("No checkpoints directory found. Train first!")
end
