<a href="https://colab.research.google.com/github/DavinciDreams/JuliaGPT/blob/main/juliaflux_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# JuliaFlux v2 — SOTA Small Language Model

GPU-accelerated small language model trained on classical texts using the trivium/quadrivium curriculum.
Uses Flux.jl + CUDA.jl with a modern LLaMA-style architecture.

**Architecture and training (LLaMA-style):**
- RMSNorm (replaces LayerNorm)
- Rotary Position Embeddings (RoPE, replaces learned position embeddings)
- SwiGLU activation (replaces GELU FFN)
- Grouped Query Attention (GQA: 6 Q heads, 2 KV heads)
- Weight-tied output projection
- Gradient clipping, cosine LR with warmup
- BPE tokenization with char-level fallback
- Curriculum learning (trivium → quadrivium → philosophy)

Based on: https://github.com/DavinciDreams/JuliaGPT

## 0. Login & Setup

This cell runs in **Python** to read your Colab secrets and save them for Julia.

1. Add secrets via the key icon in the left sidebar:
   - `HF_TOKEN` — your HuggingFace access token
   - `WANDB_KEY` — your Weights & Biases API key
   - `HF_REPO` — your model repo (e.g. `LisaMegaWatts/JuliaFluxGPT`)
   - `HF_DATA_REPO` — your dataset repo (e.g. `LisaMegaWatts/philosophy-corpus`)
2. Run cells 0-1 (login + install Julia, ~3-5 min)
3. **Runtime > Change runtime type > Julia 1.10**
4. Continue with the remaining Julia cells

In [None]:
# ── Minimal Python setup: install tools + save Colab secrets for Julia ──
!pip install -q wandb huggingface_hub datasets tokenizers

import json, pathlib, os

secrets = {}
try:
    from google.colab import userdata
    for key in ("HF_TOKEN", "WANDB_KEY", "HF_REPO", "HF_DATA_REPO"):
        try: secrets[key] = userdata.get(key)
        except Exception: pass
except ImportError:
    pass

secrets_path = pathlib.Path.home() / ".julia_secrets.json"
secrets_path.write_text(json.dumps(secrets))
secrets_path.chmod(0o600)

found = [k for k in secrets if secrets[k]]
print(f"Secrets saved: {', '.join(found) if found else 'none found'}")

# ── Download pre-cleaned dataset from HuggingFace ──
DEFAULT_DATA_REPO = "LisaMegaWatts/philosophy-corpus"
data_repo = secrets.get("HF_DATA_REPO", "") or DEFAULT_DATA_REPO
data_dir = pathlib.Path("data")
train_file = data_dir / "train.txt"
val_file = data_dir / "val.txt"

if not (train_file.exists() and val_file.exists()):
    print(f"\nDownloading dataset from HuggingFace: {data_repo}")
    data_dir.mkdir(exist_ok=True)
    if secrets.get("HF_TOKEN"):
        os.environ["HF_TOKEN"] = secrets["HF_TOKEN"]
    try:
        from datasets import load_dataset
        ds = load_dataset(data_repo)
        with open(train_file, "w") as f:
            for row in ds["train"]:
                f.write(row["text"] + "\n")
        split_name = "validation" if "validation" in ds else "val"
        with open(val_file, "w") as f:
            for row in ds[split_name]:
                f.write(row["text"] + "\n")
        print(f"  train.txt: {train_file.stat().st_size:,} bytes")
        print(f"  val.txt:   {val_file.stat().st_size:,} bytes")
    except Exception as e:
        print(f"  Dataset download failed: {e}")
else:
    print(f"\nDataset already downloaded: {data_dir}/")

# ── Download tokenizer.json if available ──
tokenizer_file = data_dir / "tokenizer.json"
if not tokenizer_file.exists():
    try:
        from huggingface_hub import hf_hub_download
        hf_hub_download(repo_id=data_repo, filename="tokenizer.json",
                       local_dir=str(data_dir), repo_type="dataset")
        print(f"  tokenizer.json downloaded")
    except Exception:
        print("  No tokenizer.json on HuggingFace (will use char-level tokenizer)")

# ── Download curriculum phase files if available ──
for phase in ["trivium", "quadrivium", "philosophy"]:
    phase_file = data_dir / f"train_{phase}.txt"
    if not phase_file.exists():
        try:
            from huggingface_hub import hf_hub_download
            hf_hub_download(repo_id=data_repo, filename=f"train_{phase}.txt",
                           local_dir=str(data_dir), repo_type="dataset")
            print(f"  train_{phase}.txt downloaded")
        except Exception:
            pass

print("\nDone! Now run the next cell to install Julia (~3-5 min).")
print("Then: Runtime > Change runtime type > Julia 1.10")

## 1. Install Julia Kernel
This cell downloads and installs Julia + IJulia + Flux packages. **Takes ~5-10 minutes** on first run.

**After it finishes:**
1. Go to **Runtime > Change runtime type**
2. Pick **Julia 1.10**
3. Continue running the cells below

In [None]:
%%shell
set -e

JULIA_VERSION="1.10.5"
JULIA_MINOR="1.10"

if [ ! -d "/usr/local/julia-${JULIA_VERSION}" ]; then
    echo "Downloading Julia ${JULIA_VERSION}..."
    wget -q https://julialang-s3.julialang.org/bin/linux/x64/${JULIA_MINOR}/julia-${JULIA_VERSION}-linux-x86_64.tar.gz
    tar xzf julia-${JULIA_VERSION}-linux-x86_64.tar.gz -C /usr/local/
    rm julia-${JULIA_VERSION}-linux-x86_64.tar.gz
    ln -sf /usr/local/julia-${JULIA_VERSION}/bin/julia /usr/local/bin/julia
    echo "Julia installed."
else
    echo "Julia already installed."
fi

julia -e '
    using Pkg
    Pkg.add("IJulia")
    Pkg.add(["Flux", "Zygote", "Optimisers", "CUDA", "cuDNN",
             "Downloads", "Statistics", "Random", "Printf",
             "LinearAlgebra", "JLD2", "NNlib", "JSON3"])
    using IJulia
    installkernel("Julia")
'

echo ""
echo "==========================================================="
echo "  Julia kernel installed!                                   "
echo "  Now: Runtime -> Change runtime type -> pick Julia 1.10   "
echo "  Then run the cells below.                                 "
echo "==========================================================="

## 1b. Load Credentials + W&B / HuggingFace Helpers (Julia)

Reads tokens from `~/.julia_secrets.json` (written by the Python setup cell).

In [None]:
using JSON3

# ── Ensure pip-installed binaries (huggingface-cli, wandb) are on PATH ──
for p in ["/usr/local/bin", joinpath(homedir(), ".local/bin"), "/root/.local/bin"]
    if isdir(p) && !(p in split(get(ENV, "PATH", ""), ':'))
        ENV["PATH"] = p * ":" * get(ENV, "PATH", "")
    end
end

# ── Read credentials from ~/.julia_secrets.json (written by Python setup cell) ──
function load_secrets()
    path = expanduser("~/.julia_secrets.json")
    if !isfile(path)
        @warn "No secrets file found at $path — run the Python setup cell first"
        return Dict{String,String}()
    end
    raw = JSON3.read(read(path, String))
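    # Skip JSON nulls as well as empty strings (string(nothing) would otherwise be stored as "nothing")
    return Dict{String,String}(string(k) => string(v) for (k, v) in pairs(raw) if v !== nothing && !isempty(string(v)))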
end

secrets = load_secrets()

# W&B
if haskey(secrets, "WANDB_KEY")
    ENV["WANDB_API_KEY"] = secrets["WANDB_KEY"]
    println("W&B API key: found")
else
    println("W&B API key: not found (add WANDB_KEY to Colab secrets)")
end

# HuggingFace token
if haskey(secrets, "HF_TOKEN")
    ENV["HF_TOKEN"] = secrets["HF_TOKEN"]
    println("HF token: found")
else
    println("HF token: not found (add HF_TOKEN to Colab secrets)")
end

# HuggingFace repo ID
HF_REPO_ID = get(secrets, "HF_REPO", "")
if !isempty(HF_REPO_ID)
    println("HF repo: ", HF_REPO_ID)
else
    println("HF repo: not set (add HF_REPO to Colab secrets or set HF_REPO_ID manually)")
end

In [None]:
WANDB_PROJECT = "juliaflux-v2-philosophy"
WANDB_RUN_ID = "juliaflux-" * join(rand('a':'z', 6))

# Write a tiny Python helper that reads JSON lines on stdin
write("_wandb_log.py", """
import wandb, json, sys, os
project = os.environ.get("WANDB_PROJECT", "juliaflux-v2-philosophy")
run_id = os.environ.get("WANDB_RUN_ID", None)
run = wandb.init(project=project, id=run_id, resume="allow",
                 config={"model": "juliaflux-gpt", "architecture": "llama-style transformer"})
print(f"W&B run: {run.url}", flush=True)
for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    try:
        data = json.loads(line)
        wandb.log(data)
    except Exception as e:
        print(f"wandb log error: {e}", file=sys.stderr, flush=True)
wandb.finish()
""")

wandb_proc = nothing

function wandb_init()
    global wandb_proc, WANDB_PROJECT, WANDB_RUN_ID
    if !haskey(ENV, "WANDB_API_KEY") || isempty(ENV["WANDB_API_KEY"])
        println("W&B: skipped (no API key)")
        return
    end
    ENV["WANDB_PROJECT"] = WANDB_PROJECT
    ENV["WANDB_RUN_ID"] = WANDB_RUN_ID
    wandb_proc = open(`python3 _wandb_log.py`, "r+")
    println("W&B: initialized ($WANDB_PROJECT / $WANDB_RUN_ID)")
end

function wandb_log(; kwargs...)
    global wandb_proc
    wandb_proc === nothing && return
    metrics = Dict(string(k) => v for (k, v) in kwargs)
    try
        println(wandb_proc, JSON3.write(metrics))
        flush(wandb_proc)
    catch e
        println("W&B log error: $e")
    end
end

function wandb_finish()
    global wandb_proc
    wandb_proc === nothing && return
    try close(wandb_proc) catch end
    wandb_proc = nothing
    println("W&B: run finished")
end

# ═══════════════════════════════════════════════════════════════
# HuggingFace Hub helpers
# ═══════════════════════════════════════════════════════════════

function hf_push(repo_id::String, local_path::String; remote_path::String="")
    rp = isempty(remote_path) ? basename(local_path) : remote_path
    run(`huggingface-cli upload $repo_id $local_path $rp`)
    println("Pushed $local_path -> $repo_id/$rp")
end

function hf_pull(repo_id::String, remote_path::String; local_dir::String="checkpoints")
    mkpath(local_dir)
    run(`huggingface-cli download $repo_id $remote_path --local-dir $local_dir`)
    println("Pulled $repo_id/$remote_path -> $local_dir/")
end

function hf_push_checkpoint(repo_id::String; checkpoint_path::String="checkpoints/best_model.jld2")
    isfile(checkpoint_path) || error("Checkpoint not found: $checkpoint_path")
    hf_push(repo_id, checkpoint_path)
end

function hf_create_repo(repo_id::String)
    try
        run(`huggingface-cli repo create $repo_id --type model`)
        println("Created HF repo: $repo_id")
    catch
        println("HF repo already exists or creation skipped: $repo_id")
    end
end

function hf_sync(local_path::String)
    if !@isdefined(HF_REPO_ID) || isempty(HF_REPO_ID)
        return
    end
    try
        hf_push(HF_REPO_ID, local_path)
    catch e
        println("  HF sync failed: $e")
    end
end

println("W&B + HuggingFace helpers defined")

---
## 2. Imports & Setup
Load Flux.jl ecosystem packages. CUDA is auto-detected for GPU acceleration.

In [None]:
using Flux
using Zygote
using Optimisers
using CUDA
using NNlib
using Downloads
using Statistics
using Random
using Printf
using LinearAlgebra
using JLD2

Random.seed!(1337)

device = CUDA.functional() ? gpu : cpu
println("Device: ", device)
println("CUDA functional: ", CUDA.functional())
if CUDA.functional()
    println("GPU: ", CUDA.name(CUDA.device()))
    mem = CUDA.totalmem(CUDA.device())
    println("VRAM: ", round(mem / 1024^3, digits=1), " GB")
    println("CUDA version: ", CUDA.runtime_version())
end

---
## 3. Hyperparameters

Defaults for the LLaMA-style components (RoPE, SwiGLU, GQA, RMSNorm), the training run, and the curriculum.
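
The sizes these defaults imply can be checked with a little arithmetic. The sketch below is plain Python, separate from the notebook's Julia cells. It assumes bias-free projections and the SwiGLU rounding rule used by the model cell later in the notebook; the variable names `attn_params`, `mha_params`, `ffn_params`, and `block_params` are illustrative and do not appear in the notebook.

```python
# Hyperparameters copied from the notebook's defaults.
n_embd, n_head, n_kv_head, n_layer = 384, 6, 2, 6

head_dim = n_embd // n_head          # width of one attention head
kv_dim = head_dim * n_kv_head        # width of the K (or V) projection under GQA
groups = n_head // n_kv_head         # Q heads sharing each KV head

# SwiGLU inner width: about (8/3) * n_embd, rounded to the nearest multiple of 64.
raw_inner = (4 * n_embd * 2) // 3
inner_dim = max(64, 64 * ((raw_inner + 32) // 64))

# Bias-free weight counts per transformer block.
attn_params = n_embd * n_embd + n_embd * 2 * kv_dim + n_embd * n_embd  # wq + wkv + proj
mha_params = 4 * n_embd * n_embd     # the same block with full multi-head K and V
ffn_params = 3 * n_embd * inner_dim  # gate + up + down
norm_params = 2 * n_embd             # two RMSNorm weight vectors
block_params = attn_params + ffn_params + norm_params

print(f"head_dim={head_dim}, kv_dim={kv_dim}, groups={groups}, inner_dim={inner_dim}")
print(f"attention: {attn_params:,} weights with GQA vs {mha_params:,} with full MHA")
print(f"per block: {block_params:,}; all {n_layer} blocks: {n_layer * block_params:,}")
```

The token embedding (`vocab_size * n_embd`) and the final RMSNorm come on top of the block totals; the embedding is counted once because the output projection reuses it.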

In [None]:
# ── Model architecture (LLaMA-style) ──
block_size     = 256       # context window
n_embd         = 384       # embedding dim
n_head         = 6         # Q attention heads
n_kv_head      = 2         # KV heads for GQA (each KV head serves 3 Q heads)
n_layer        = 6         # transformer layers
dropout        = 0.1
bias           = false
rope_base      = 10000.0f0 # RoPE frequency base

# ── Training ──
batch_size     = 64
learning_rate  = 3e-4
max_iters      = 5000
eval_interval  = 500
eval_iters     = 100
warmup_iters   = 200
min_lr         = 1e-5
max_grad_norm  = 1.0f0     # gradient clipping threshold

# ── Curriculum learning ──
curriculum_enabled = true
curriculum_warmup  = 1000   # steps to ramp from trivium-only to the full corpus (phase switches at 33% and 66%)

println("Architecture: n_embd=$n_embd, n_layer=$n_layer, n_head=$n_head (Q), n_kv_head=$n_kv_head (KV)")
println("GQA ratio: $(n_head ÷ n_kv_head) Q heads per KV head")
println("Training: batch=$batch_size, lr=$learning_rate, iters=$max_iters, clip=$max_grad_norm")
println("Curriculum: enabled=$curriculum_enabled, warmup=$curriculum_warmup steps")

---
## 4. Dataset — Classical Curriculum

Loads the pre-cleaned philosophy corpus produced by the text pipeline, organized by
the classical trivium/quadrivium curriculum:

1. **Trivium** (language arts): grammar, rhetoric, logic
2. **Quadrivium** (mathematical arts): arithmetic, geometry, music, astronomy
3. **Philosophy**: ethics, metaphysics, politics

The curriculum learning approach trains on simpler texts first (trivium),
then progressively adds harder material — inspired by "Textbooks Are All You Need"
(Microsoft Phi research).
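
As a rough illustration of how such a schedule can pick its data source, the Python sketch below mirrors the 0.33 / 0.66 progress thresholds used by the notebook's `get_batch`. It is a simplification: the function name `curriculum_phase` is hypothetical, the step counter is assumed to advance once per training step, and the fallbacks for missing phase files are left out.

```python
def curriculum_phase(step: int, warmup: int = 1000) -> str:
    """Return the data mix for a training step, using the 0.33 / 0.66 thresholds."""
    progress = min(step / warmup, 1.0)   # fraction of the curriculum ramp completed
    if progress < 0.33:
        return "trivium"
    if progress < 0.66:
        return "trivium+quadrivium"
    return "full corpus"

for step in (0, 200, 400, 700, 5000):
    print(step, curriculum_phase(step))
```

With the default `curriculum_warmup = 1000`, the trivium-only phase therefore covers roughly the first 330 steps, and the full corpus is in use from about step 660 onward.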

**Data flow:**
```
text-pipeline/ → clean → chunk → train.txt + train_trivium.txt + train_quadrivium.txt + train_philosophy.txt
                                → push to HuggingFace
                                                    ↓
juliaflux_v2.ipynb → pulls data/ → BPE tokenize → curriculum batches → train
```
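
The "BPE tokenize" step in the flow above applies an ordered list of merge rules to a sequence of symbols. The Python sketch below shows that idea on a toy input. The merge table and vocabulary are invented for illustration and are not taken from the project's `tokenizer.json`; like the notebook's Julia tokenizer, it starts from single characters and drops any token that is missing from the vocabulary.

```python
def apply_merges(symbols, merges):
    """Apply each merge rule in priority order across the whole symbol list."""
    tokens = list(symbols)
    for a, b in merges:
        i = 0
        while i < len(tokens) - 1:
            if tokens[i] == a and tokens[i + 1] == b:
                tokens[i:i + 2] = [a + b]   # fuse the pair; i is not advanced
            else:
                i += 1
    return tokens

# Hypothetical merge table and vocabulary, for illustration only.
merges = [("l", "o"), ("lo", "w"), ("e", "r")]
vocab = {"low": 1, "er": 2, "l": 3, "o": 4, "w": 5, "e": 6, "r": 7}

tokens = apply_merges(list("lower"), merges)
ids = [vocab[t] for t in tokens if t in vocab]   # unknown tokens are skipped
print(tokens, ids)
```

Each rule is scanned over the whole sequence, so the cost grows with the number of merges times the sequence length; production tokenizers avoid this by pre-splitting text into words and caching per-word results.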

In [None]:
# ── Load pre-cleaned data from text pipeline ──
DATA_DIR = "data"
train_file = joinpath(DATA_DIR, "train.txt")
val_file   = joinpath(DATA_DIR, "val.txt")

DEFAULT_DATA_REPO = "LisaMegaWatts/philosophy-corpus"
HF_DATA_REPO = let r = get(secrets, "HF_DATA_REPO", ""); isempty(r) ? DEFAULT_DATA_REPO : r end

# Julia-side fallback: try huggingface-cli if Python download missed
if !isfile(train_file) || !isfile(val_file)
    println("Data not found locally, trying huggingface-cli download...")
    mkpath(DATA_DIR)
    try
        # The corpus lives in a dataset repo, so the CLI needs --repo-type dataset
        run(`huggingface-cli download $HF_DATA_REPO train.txt val.txt --repo-type dataset --local-dir $DATA_DIR`)
    catch e
        @warn "Download failed: $e"
    end
end

if !isfile(train_file) || !isfile(val_file)
    error("No data found in $DATA_DIR/. Re-run the Python setup cell.")
end

train_text = read(train_file, String)
val_text   = read(val_file, String)

println("Data loaded from $DATA_DIR/ ($HF_DATA_REPO)")
println("  train.txt: $(length(train_text)) chars ($(count('\n', train_text)) chunks)")
println("  val.txt:   $(length(val_text)) chars ($(count('\n', val_text)) chunks)")

# ── Load curriculum phase files (optional) ──
phase_data = Dict{String, String}()
for phase in ["trivium", "quadrivium", "philosophy"]
    phase_file = joinpath(DATA_DIR, "train_$(phase).txt")
    if isfile(phase_file)
        phase_data[phase] = read(phase_file, String)
        n_chunks = count('\n', phase_data[phase])
        println("  train_$(phase).txt: $(length(phase_data[phase])) chars ($n_chunks chunks)")
    end
end

if isempty(phase_data)
    println("\n  No curriculum phase files found — will use full corpus for all phases")
    curriculum_enabled = false
end

# ── Try to download tokenizer.json ──
tokenizer_file = joinpath(DATA_DIR, "tokenizer.json")
if !isfile(tokenizer_file)
    try
        run(`huggingface-cli download $HF_DATA_REPO tokenizer.json --repo-type dataset --local-dir $DATA_DIR`)
        println("  Downloaded tokenizer.json")
    catch
        println("  No tokenizer.json available (will use char-level tokenizer)")
    end
end

In [None]:
# ── Tokenizer: BPE with character-level fallback ──
using JSON3

TOKENIZER_PATH = joinpath(DATA_DIR, "tokenizer.json")
USE_BPE = isfile(TOKENIZER_PATH)

if USE_BPE
    println("Loading BPE tokenizer from $TOKENIZER_PATH ...")
    tok_raw = read(TOKENIZER_PATH, String)
    tok_json = JSON3.read(tok_raw)

    # Parse vocabulary: token_string -> id (convert to 1-indexed for Julia)
    bpe_vocab = Dict{String, Int}()
    for (tok_str, id) in pairs(tok_json.model.vocab)
        bpe_vocab[string(tok_str)] = Int(id) + 1
    end

    # Parse merges: ordered list of (a, b) pairs
    bpe_merges = Vector{Tuple{String,String}}()
    for merge_str in tok_json.model.merges
        parts = split(string(merge_str), " ", limit=2)
        if length(parts) == 2
            push!(bpe_merges, (String(parts[1]), String(parts[2])))
        end
    end

    # Reverse vocab: id -> token_string
    bpe_id_to_token = Dict{Int, String}(id => tok for (tok, id) in bpe_vocab)

    global vocab_size = length(bpe_vocab)

    # BPE encode: apply merges in priority order
    function bpe_encode_word(word::Vector{String})
        tokens = copy(word)
        for (a, b) in bpe_merges
            i = 1
            while i < length(tokens)
                if tokens[i] == a && tokens[i+1] == b
                    tokens = vcat(tokens[1:i-1], [a * b], tokens[i+2:end])
                else
                    i += 1
                end
            end
        end
        return tokens
    end

    function encode(s::String)
        # Start from individual Unicode characters (not bytes); merged tokens
        # that are missing from the vocabulary are silently dropped below
        chars = [string(c) for c in s]
        tokens = bpe_encode_word(chars)
        ids = Int[]
        for tok in tokens
            id = get(bpe_vocab, tok, nothing)
            if id !== nothing
                push!(ids, id)
            end
        end
        return ids
    end

    function decode(ids::Vector{Int})
        tokens = [get(bpe_id_to_token, id, "") for id in ids]
        return join(tokens)
    end

    println("BPE tokenizer: vocab_size=$vocab_size, $(length(bpe_merges)) merges")

else
    # ── Fallback: character-level tokenizer ──
    println("No tokenizer.json found — using character-level tokenizer")

    full_text = train_text * "\n" * val_text
    chars = sort(unique(full_text))
    filter!(c -> c != '\n', chars)
    global vocab_size = length(chars)

    stoi = Dict(c => i for (i, c) in enumerate(chars))
    itos = Dict(i => c for (i, c) in enumerate(chars))

    encode(s::String) = [stoi[c] for c in s if haskey(stoi, c)]
    decode(ids::Vector{Int}) = join(itos[i] for i in ids)

    println("Char-level tokenizer: vocab_size=$vocab_size -> [$(join(chars))]")
end

# ── Encode training and validation data ──
train_clean = replace(strip(train_text), '\n' => ' ')
val_clean   = replace(strip(val_text), '\n' => ' ')

global train_data = encode(train_clean)
global val_data   = encode(val_clean)

# ── Encode curriculum phase data ──
global phase_encoded = Dict{String, Vector{Int}}()
for (phase, text) in phase_data
    clean = replace(strip(text), '\n' => ' ')
    phase_encoded[phase] = encode(clean)
    println("  Phase $phase: $(length(phase_encoded[phase])) tokens")
end

println("\nTrain: $(length(train_data)) tokens")
println("Val:   $(length(val_data)) tokens")
println("Total: $(length(train_data) + length(val_data)) tokens")

---
## 5. Model Architecture

LLaMA-style transformer. The table below lists what changed from v1.
All structs are defined in one cell because Julia does not allow a struct to be redefined within a session.

| Component | Old (v1) | New (v2) |
|-----------|----------|----------|
| Normalization | LayerNorm | **RMSNorm** |
| Position encoding | Learned absolute | **RoPE** (rotary) |
| FFN activation | GELU (2 matrices) | **SwiGLU** (3 matrices) |
| Attention | Standard MHA | **GQA** (grouped query) |
| Output head | Separate Dense | **Weight-tied** with embedding |
| Gradient | No clipping | **ClipNorm(1.0)** |
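
To make the RoPE row concrete, here is a small Python sketch of the rotation applied to one head vector. It uses the same split-half layout as the notebook's `apply_rope` (first half paired with second half) and the same frequency formula, but it is a scalar illustration with made-up vectors, not the batched Julia implementation. The helper names `rope` and `dot` are hypothetical.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate a head vector by position-dependent angles (split-half layout)."""
    d = len(vec) // 2
    out = [0.0] * (2 * d)
    for i in range(d):
        theta = pos / (base ** (2 * i / (2 * d)))   # angle for frequency pair i
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = vec[i], vec[i + d]
        out[i] = x1 * c - x2 * s
        out[i + d] = x1 * s + x2 * c
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Made-up query and key vectors with head_dim = 4.
q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.3, -0.7, 1.0]

# Same relative offset (3 positions) at two different absolute positions.
score_a = dot(rope(q, 5), rope(k, 2))
score_b = dot(rope(q, 105), rope(k, 102))
print(abs(score_a - score_b) < 1e-9)
```

The two scores agree because each pair of coordinates is rotated by an angle proportional to its position, so the query-key dot product depends only on the distance between the two positions. This is why no learned position table is needed.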

In [None]:
# ── Curriculum learning state ──
global curriculum_step = 0

function get_batch(split="train")
    global curriculum_step

    if split == "val"
        d = val_data
    elseif curriculum_enabled && !isempty(phase_encoded)
        # Curriculum: start with trivium, progressively add harder material
        progress = min(curriculum_step / curriculum_warmup, 1.0)

        if progress < 0.33 && haskey(phase_encoded, "trivium") && !isempty(phase_encoded["trivium"])
            d = phase_encoded["trivium"]
        elseif progress < 0.66
            # Mix trivium + quadrivium
            sources = Vector{Int}[]
            haskey(phase_encoded, "trivium") && push!(sources, phase_encoded["trivium"])
            haskey(phase_encoded, "quadrivium") && push!(sources, phase_encoded["quadrivium"])
            d = isempty(sources) ? train_data : vcat(sources...)
        else
            d = train_data  # full corpus
        end
    else
        d = train_data
    end

    # Ensure data is long enough
    if length(d) <= block_size + 1
        d = train_data
    end

    ix = rand(1:length(d) - block_size, batch_size)
    x = hcat([d[i:i+block_size-1] for i in ix]...)
    y = hcat([d[i+1:i+block_size] for i in ix]...)
    x = permutedims(x)   # (B, T)
    y = permutedims(y)
    x = x |> device
    y = y |> device
    return x, y
end

In [None]:
# ══════════════════════════════════════════════════════════════════
# ALL MODEL STRUCTS IN ONE CELL (Julia structs cannot be redefined)
# Architecture: RMSNorm + RoPE + SwiGLU + GQA + Weight Tying
# ══════════════════════════════════════════════════════════════════

using NNlib: batched_mul

# ── Pre-compute causal mask ──
const CAUSAL_MASK = triu(fill(typemin(Float32), block_size, block_size), 1)
const CAUSAL_MASK_GPU = CUDA.functional() ? cu(CAUSAL_MASK) : CAUSAL_MASK

# ──────────────────────────────────────────────────────────────────
# RoPE: Rotary Position Embeddings
# ──────────────────────────────────────────────────────────────────

const HEAD_DIM = n_embd ÷ n_head

function precompute_rope_freqs(head_dim::Int, max_seq_len::Int; base::Float32 = 10000.0f0)
    half_dim = head_dim ÷ 2
    freqs = Float32[1.0f0 / (base ^ (Float32(2 * (i - 1)) / Float32(head_dim))) for i in 1:half_dim]
    positions = Float32.(collect(0:max_seq_len-1))
    angles = freqs * positions'
    return cos.(angles), sin.(angles)
end

const ROPE_COS, ROPE_SIN = precompute_rope_freqs(HEAD_DIM, block_size; base=rope_base)
const ROPE_COS_GPU = CUDA.functional() ? cu(ROPE_COS) : ROPE_COS
const ROPE_SIN_GPU = CUDA.functional() ? cu(ROPE_SIN) : ROPE_SIN

function apply_rope(x, cos_f, sin_f, T::Int)
    d = size(x, 1) ÷ 2
    x1 = x[1:d, :, :]
    x2 = x[d+1:2d, :, :]
    c = cos_f[:, 1:T]
    s = sin_f[:, 1:T]
    return vcat(x1 .* c .- x2 .* s, x1 .* s .+ x2 .* c)
end

# ──────────────────────────────────────────────────────────────────
# RMSNorm (replaces LayerNorm)
# ──────────────────────────────────────────────────────────────────

struct RMSNorm{W <: AbstractVector}
    weight::W
    eps::Float32
end

Flux.@layer RMSNorm

function RMSNorm(dim::Int; eps::Float32 = 1.0f-6)
    RMSNorm(ones(Float32, dim), eps)
end

function (rn::RMSNorm)(x)
    rms = sqrt.(mean(x .^ 2, dims=1) .+ rn.eps)
    return (x ./ rms) .* rn.weight
end

# ──────────────────────────────────────────────────────────────────
# SwiGLU Feed-Forward Network (replaces GELU FFN)
# ──────────────────────────────────────────────────────────────────

struct SwiGLUFFN
    w_gate::Dense
    w_up::Dense
    w_down::Dense
    drop::Dropout
end

Flux.@layer SwiGLUFFN

function SwiGLUFFN(n_embd::Int; bias=false, dropout=0.0)
    raw_inner = Int(floor(4 * n_embd * 2 / 3))
    inner_dim = max(64, 64 * div(raw_inner + 32, 64))
    SwiGLUFFN(
        Dense(n_embd => inner_dim; bias),
        Dense(n_embd => inner_dim; bias),
        Dense(inner_dim => n_embd; bias),
        Dropout(dropout)
    )
end

function (ff::SwiGLUFFN)(x)
    ff.drop(ff.w_down(NNlib.swish(ff.w_gate(x)) .* ff.w_up(x)))
end

# ──────────────────────────────────────────────────────────────────
# GQA-capable Causal Self-Attention
# ──────────────────────────────────────────────────────────────────

struct CausalSelfAttention
    wq::Dense
    wkv::Dense
    proj::Dense
    n_head::Int
    n_kv_head::Int
end

Flux.@layer CausalSelfAttention trainable=(wq, wkv, proj)

function CausalSelfAttention(n_embd::Int, n_head::Int, n_kv_head::Int; bias=false)
    head_dim = n_embd ÷ n_head
    kv_dim = head_dim * n_kv_head
    CausalSelfAttention(
        Dense(n_embd => n_embd; bias),
        Dense(n_embd => 2 * kv_dim; bias),
        Dense(n_embd => n_embd; bias),
        n_head,
        n_kv_head
    )
end

function (attn::CausalSelfAttention)(x)
    C, T, B = size(x)
    nh = attn.n_head
    nkv = attn.n_kv_head
    hs = C ÷ nh
    kv_dim = hs * nkv
    groups = nh ÷ nkv

    q = attn.wq(x)
    kv = attn.wkv(x)
    k = kv[1:kv_dim, :, :]
    v = kv[kv_dim+1:2*kv_dim, :, :]

    # Reshape to per-head tensors
    q = reshape(permutedims(reshape(q, hs, nh, T, B), (1, 3, 2, 4)), hs, T, nh * B)
    k = reshape(permutedims(reshape(k, hs, nkv, T, B), (1, 3, 2, 4)), hs, T, nkv * B)
    v = reshape(permutedims(reshape(v, hs, nkv, T, B), (1, 3, 2, 4)), hs, T, nkv * B)

    # Apply RoPE to Q and K
    cos_f = x isa CuArray ? ROPE_COS_GPU : ROPE_COS
    sin_f = x isa CuArray ? ROPE_SIN_GPU : ROPE_SIN
    q = apply_rope(q, cos_f, sin_f, T)
    k = apply_rope(k, cos_f, sin_f, T)

    # GQA: repeat KV heads to match Q heads
    if groups > 1
        k_4d = reshape(k, hs, T, nkv, B)
        v_4d = reshape(v, hs, T, nkv, B)
        k_rep = repeat(reshape(k_4d, hs, T, nkv, 1, B), 1, 1, 1, groups, 1)
        v_rep = repeat(reshape(v_4d, hs, T, nkv, 1, B), 1, 1, 1, groups, 1)
        k = reshape(permutedims(k_rep, (1, 2, 4, 3, 5)), hs, T, nh * B)
        v = reshape(permutedims(v_rep, (1, 2, 4, 3, 5)), hs, T, nh * B)
    end

    # Attention scores
    scale = Float32(1 / sqrt(hs))
    wei = batched_mul(permutedims(q, (2, 1, 3)), k) .* scale

    mask = x isa CuArray ? CAUSAL_MASK_GPU[1:T, 1:T] : CAUSAL_MASK[1:T, 1:T]
    wei = wei .+ mask
    wei = softmax(wei; dims=2)

    out = batched_mul(v, permutedims(wei, (2, 1, 3)))
    out = reshape(permutedims(reshape(out, hs, T, nh, B), (1, 3, 2, 4)), C, T, B)

    attn.proj(out)
end

# ──────────────────────────────────────────────────────────────────
# TransformerBlock (pre-norm residual with RMSNorm)
# ──────────────────────────────────────────────────────────────────

struct TransformerBlock
    ln1::RMSNorm
    attn::CausalSelfAttention
    ln2::RMSNorm
    ffwd::SwiGLUFFN
end

Flux.@layer TransformerBlock

function TransformerBlock(n_embd::Int, n_head::Int, n_kv_head::Int; dropout=0.0)
    TransformerBlock(
        RMSNorm(n_embd),
        CausalSelfAttention(n_embd, n_head, n_kv_head),
        RMSNorm(n_embd),
        SwiGLUFFN(n_embd; dropout)
    )
end

function (block::TransformerBlock)(x)
    x = x .+ block.attn(block.ln1(x))
    x = x .+ block.ffwd(block.ln2(x))
    x
end

# ──────────────────────────────────────────────────────────────────
# TiedDense: weight-tied output projection
# ──────────────────────────────────────────────────────────────────

struct TiedDense{W <: AbstractMatrix}
    weight_ref::W
end

Flux.@layer TiedDense trainable=()

function (td::TiedDense)(x)
    C, T, B = size(x)
    W = td.weight_ref
    x_flat = reshape(x, C, T * B)
    out = W' * x_flat
    reshape(out, size(W, 2), T, B)
end

# ──────────────────────────────────────────────────────────────────
# GPT Model (LLaMA-style: no wpe, RoPE, weight-tied lm_head)
# ──────────────────────────────────────────────────────────────────

struct GPT
    wte::Embedding
    drop::Dropout
    blocks::Chain
    ln_f::RMSNorm
    lm_head::TiedDense
end

Flux.@layer GPT

function GPT(; vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head, dropout=0.1)
    wte = Embedding(vocab_size => n_embd)
    GPT(
        wte,
        Dropout(dropout),
        Chain([TransformerBlock(n_embd, n_head, n_kv_head; dropout) for _ in 1:n_layer]...),
        RMSNorm(n_embd),
        TiedDense(wte.weight)
    )
end

function (m::GPT)(idx)
    B, T = size(idx)
    tok = permutedims(m.wte(idx), (1, 3, 2))
    x = m.drop(tok)
    x = m.blocks(x)
    x = m.ln_f(x)
    m.lm_head(x)
end

println("Model structs defined: RMSNorm, SwiGLUFFN, CausalSelfAttention (GQA), TransformerBlock, TiedDense, GPT")
println("SOTA: RoPE (head_dim=$(HEAD_DIM)), SwiGLU, GQA ($(n_head)Q/$(n_kv_head)KV), RMSNorm, weight tying")
println("Pre-computed: causal mask $(size(CAUSAL_MASK)), RoPE tables $(size(ROPE_COS))")

In [None]:
model = GPT(;
    vocab_size = vocab_size,
    n_embd     = n_embd,
    block_size = block_size,
    n_layer    = n_layer,
    n_head     = n_head,
    n_kv_head  = n_kv_head,
    dropout    = dropout
) |> device

n_params = sum(length, Flux.trainables(model))
println("Model created on $device")
println("Parameters: $(n_params) ($(round(n_params/1e6, digits=2))M)")
println("  Weight tying saves $(vocab_size * n_embd) params = $(round(vocab_size * n_embd / 1e3, digits=1))K")

if CUDA.functional()
    println("GPU memory: $(round(CUDA.used_memory() / 1024^2, digits=1)) MB")
end

# Smoke test
x_test, y_test = get_batch("train")
logits_test = model(x_test)
println("Forward pass OK — logits: $(size(logits_test))")
@assert size(logits_test, 1) == vocab_size
@assert size(logits_test, 2) == block_size
@assert size(logits_test, 3) == batch_size

---
## 6. Checkpoint Save/Load

In [None]:
LOCAL_CKPT = "checkpoints"
mkpath(LOCAL_CKPT)

function save_checkpoint(path::String, model, opt_state;
                          step::Int=0, best_val_loss::Float64=Inf,
                          train_losses::Vector{Float64}=Float64[],
                          val_losses::Vector{Float64}=Float64[])
    mkpath(dirname(path))
    model_cpu = cpu(model)
    opt_cpu = cpu(opt_state)
    JLD2.jldsave(path;
        model_state = Flux.state(model_cpu),
        opt_state = opt_cpu,
        step = step,
        best_val_loss = best_val_loss,
        train_losses = train_losses,
        val_losses = val_losses,
        hyperparams = Dict(
            "vocab_size" => vocab_size,
            "n_embd" => n_embd,
            "block_size" => block_size,
            "n_layer" => n_layer,
            "n_head" => n_head,
            "n_kv_head" => n_kv_head,
            "dropout" => dropout,
            "use_bpe" => USE_BPE,
            "rope_base" => rope_base
        )
    )
    vl_str = best_val_loss == Inf ? "Inf" : @sprintf("%.4f", best_val_loss)
    println("Checkpoint saved: $path (step $step, best_val_loss=$vl_str)")
end

function save_and_sync(path, model, opt_state; kwargs...)
    save_checkpoint(path, model, opt_state; kwargs...)
    hf_sync(path)
end

function load_checkpoint(path::String, device_fn)
    println("Loading checkpoint from $path ...")
    data = JLD2.load(path)

    hp = data["hyperparams"]
    m = GPT(;
        vocab_size = hp["vocab_size"],
        n_embd     = hp["n_embd"],
        block_size = hp["block_size"],
        n_layer    = hp["n_layer"],
        n_head     = hp["n_head"],
        n_kv_head  = get(hp, "n_kv_head", hp["n_head"]),
        dropout    = get(hp, "dropout", 0.1)
    )
    Flux.loadmodel!(m, data["model_state"])
    m = m |> device_fn

    opt = data["opt_state"]

    # Quotes inside $(...) must not be backslash-escaped; escaping them is a parse error
    println("  step=$(data["step"]), best_val=$(round(data["best_val_loss"], digits=4))")
    return (;
        model = m,
        opt_state = opt |> device_fn,
        step = data["step"],
        best_val_loss = data["best_val_loss"],
        train_losses = get(data, "train_losses", Float64[]),
        val_losses = get(data, "val_losses", Float64[])
    )
end

println("Checkpoint save/load defined (JLD2 + HuggingFace sync)")

---
## 7. Training Loop

Adam optimizer with cosine LR + warmup + gradient clipping.
Reports both loss and perplexity. Curriculum learning advances automatically.
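
The schedule and the perplexity figure can be sketched in a few lines of Python. The function mirrors the notebook's Julia `get_lr` with the default hyperparameters from section 3; the loss value used for the perplexity line is a made-up number, not a training result.

```python
import math

learning_rate, min_lr = 3e-4, 1e-5
warmup_iters, max_iters = 200, 5000

def get_lr(it: int) -> float:
    """Linear warmup to learning_rate, then cosine decay to min_lr."""
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)

for it in (0, 100, 200, 2600, 5000):
    print(f"iter {it:>4}: lr = {get_lr(it):.2e}")

# Perplexity is the exponential of the mean cross-entropy (in nats).
val_loss = 2.0   # hypothetical loss value, for illustration only
print(f"perplexity = {math.exp(val_loss):.2f}")
```

The rate starts at zero, reaches `learning_rate` at the end of warmup, and lands on `min_lr` at `max_iters`. Past `max_iters` the cosine term would start rising again, so a longer run should clamp the decay ratio at 1.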

In [None]:
using Printf

# ── Generate text helper (defined here so training loop can call it) ──
function generate_text(model, max_tokens=200; temperature=0.8, prompt="")
    model_eval = Flux.testmode!(deepcopy(model))
    if !isempty(prompt)
        prompt_ids = encode(prompt)
        idx = reshape(prompt_ids, 1, :) |> device
    else
        idx = reshape([rand(1:vocab_size)], 1, 1) |> device
    end
    generated = Int[]
    for _ in 1:max_tokens
        idx_cond = idx[:, max(1, end-block_size+1):end]
        logits = model_eval(idx_cond)
        logits_last = logits[:, end, 1]
        probs = softmax(logits_last ./ Float32(temperature))
        probs_cpu = Float64.(cpu(probs))
        r = rand()
        cum = 0.0
        next_id = length(probs_cpu)  # fallback to the last id if rounding leaves cum below r
        for (i, p) in enumerate(probs_cpu)
            cum += p
            if r <= cum
                next_id = i
                break
            end
        end
        push!(generated, next_id)
        next_token = reshape([next_id], 1, 1) |> device
        idx = hcat(idx, next_token)
    end
    return decode(generated)
end

function estimate_loss(model, n_iters=eval_iters)
    model_eval = Flux.testmode!(deepcopy(model))
    losses = Dict{String, Float64}()
    for split in ["train", "val"]
        total = 0.0
        for _ in 1:n_iters
            x, y = get_batch(split)
            logits = model_eval(x)
            y_flat = reshape(y, :)
            logits_flat = reshape(logits, vocab_size, :)
            onehot = Flux.onehotbatch(y_flat, 1:vocab_size) |> device
            loss = Flux.logitcrossentropy(logits_flat, onehot)
            total += loss
        end
        losses[split] = total / n_iters
    end
    losses["train_ppl"] = exp(losses["train"])
    losses["val_ppl"] = exp(losses["val"])
    return losses
end

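# Word-level diversity of a sample: distinct1 and distinct2 are the fractions of
# unique unigrams and bigrams; rep_rate is the fraction of repeated trigrams.
function compute_diversity(text::String)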
    words = split(text)
    isempty(words) && return (distinct1=0.0, distinct2=0.0, rep_rate=0.0)
    distinct1 = length(Set(words)) / length(words)
    bigrams = [words[i] * " " * words[i+1] for i in 1:length(words)-1]
    distinct2 = isempty(bigrams) ? 0.0 : length(Set(bigrams)) / length(bigrams)
    if length(words) >= 3
        trigrams = [join(words[i:i+2], " ") for i in 1:length(words)-2]
        rep_rate = 1.0 - length(Set(trigrams)) / length(trigrams)
    else
        rep_rate = 0.0
    end
    return (distinct1=round(distinct1, digits=3), distinct2=round(distinct2, digits=3), rep_rate=round(rep_rate, digits=3))
end

# Linear warmup to learning_rate over warmup_iters, then cosine decay to min_lr at max_iters.
function get_lr(iter)
    if iter < warmup_iters
        return learning_rate * iter / warmup_iters
    end
    decay_ratio = (iter - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + cos(Float64(pi) * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)
end

# ── Optimizer with gradient clipping ──
opt_state = Flux.setup(
    OptimiserChain(ClipNorm(max_grad_norm), Adam(learning_rate)),
    model
)

best_val = Inf
train_loss_history = Float64[]
val_loss_history = Float64[]

if haskey(ENV, "WANDB_API_KEY") && !isempty(ENV["WANDB_API_KEY"])
    wandb_init()
end

SAVE_INTERVAL = 600
last_save_time = time()
completed_iter = 0

println("Training for $max_iters steps (curriculum=$(curriculum_enabled))...")
t_start = time()

try
    for iter in 1:max_iters
        global completed_iter = iter

        # Advance curriculum
        if curriculum_enabled
            global curriculum_step = iter
        end

        lr_t = get_lr(iter)
        Flux.adjust!(opt_state, lr_t)

        x, y = get_batch("train")
        loss, grads = Flux.withgradient(model) do m
            logits = m(x)
            y_flat = reshape(y, :)
            logits_flat = reshape(logits, vocab_size, :)
            onehot = Flux.onehotbatch(y_flat, 1:vocab_size) |> device
            Flux.logitcrossentropy(logits_flat, onehot)
        end
        Flux.update!(opt_state, model, grads[1])
        push!(train_loss_history, Float64(loss))

        if iter % 100 == 0 && CUDA.functional()
            GC.gc(false)
        end

        if iter % eval_interval == 0 || iter == 1
            losses = estimate_loss(model)
            push!(val_loss_history, losses["val"])
            elapsed = round(time() - t_start, digits=1)
            wandb_log(; step=iter, train_loss=losses["train"], val_loss=losses["val"],
                       train_ppl=losses["train_ppl"], val_ppl=losses["val_ppl"], lr=lr_t)

            improved = ""
            if losses["val"] < best_val
                best_val = losses["val"]
                save_and_sync("checkpoints/best_model.jld2", model, opt_state;
                    step=iter, best_val_loss=best_val,
                    train_losses=train_loss_history, val_losses=val_loss_history)
                improved = " << best!"
            end

            # Label the current curriculum phase for the progress line.
            phase_str = if !curriculum_enabled
                ""
            elseif curriculum_step < curriculum_warmup * 0.33
                " [trivium]"
            elseif curriculum_step < curriculum_warmup * 0.66
                " [tri+quad]"
            else
                " [full]"
            end
            @printf("step %5d | train %.4f (ppl %.1f) | val %.4f (ppl %.1f) | lr %.2e | %.1fs%s%s\n",
                    iter, losses["train"], losses["train_ppl"],
                    losses["val"], losses["val_ppl"], lr_t, elapsed, phase_str, improved)

            # Diversity check every 5th eval
            if iter % (eval_interval * 5) == 0
                sample = generate_text(model, 200; temperature=0.8)
                div = compute_diversity(sample)
                @printf("  diversity: D1=%.3f D2=%.3f rep=%.3f\n", div.distinct1, div.distinct2, div.rep_rate)
                wandb_log(; step=iter, distinct1=div.distinct1, distinct2=div.distinct2, rep_rate=div.rep_rate)
            end
        end

        if iter % 1000 == 0
            save_and_sync("checkpoints/checkpoint_latest.jld2", model, opt_state;
                step=iter, best_val_loss=best_val,
                train_losses=train_loss_history, val_losses=val_loss_history)
            last_save_time = time()
        end

        if time() - last_save_time > SAVE_INTERVAL
            save_and_sync("checkpoints/checkpoint_latest.jld2", model, opt_state;
                step=iter, best_val_loss=best_val,
                train_losses=train_loss_history, val_losses=val_loss_history)
            last_save_time = time()
            println("  [auto-save at step $iter]")
        end
    end
catch e
    if e isa InterruptException
        println("\n\nInterrupted at step $completed_iter!")
    else
        println("\n\nError at step $completed_iter: $e")
    end
    save_and_sync("checkpoints/checkpoint_interrupted.jld2", model, opt_state;
        step=completed_iter, best_val_loss=best_val,
        train_losses=train_loss_history, val_losses=val_loss_history)
    e isa InterruptException || rethrow(e)
end

elapsed = round(time() - t_start, digits=1)
println("\nTraining finished at step $completed_iter in $(elapsed)s. Best val loss: $(round(best_val, digits=4)) (ppl $(round(exp(best_val), digits=1)))")
wandb_finish()

# Record the step actually reached; after an interrupt it is below max_iters.
save_and_sync("checkpoints/final_model.jld2", model, opt_state;
    step=completed_iter, best_val_loss=best_val,
    train_losses=train_loss_history, val_losses=val_loss_history)

---
## 8. Inference — Generate Text

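Temperature-controlled sampling with an optional prompt. The cell below draws five unprompted samples of 300 tokens at temperature 0.8 and prints the first 500 characters of each.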
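
Temperature sampling divides the logits by a temperature before the softmax, so low values concentrate probability on the top token and high values flatten the distribution. The sketch below shows the idea on a plain CPU vector; `softmax_vec` and `sample_token` are hypothetical helpers written for this illustration and the logits are made-up values, not model output.

```julia
using Random

# Numerically stable softmax over a vector of logits.
function softmax_vec(logits::AbstractVector{<:Real})
    shifted = logits .- maximum(logits)
    e = exp.(shifted)
    return e ./ sum(e)
end

# Sample an index from temperature-scaled logits by inverting the CDF.
function sample_token(rng::AbstractRNG, logits::AbstractVector{<:Real}; temperature=0.8)
    probs = softmax_vec(logits ./ temperature)
    r = rand(rng)
    cum = 0.0
    for (i, p) in enumerate(probs)
        cum += p
        r <= cum && return i
    end
    # Rounding can leave the cumulative sum just below 1; use the last index.
    return length(probs)
end

rng = MersenneTwister(42)
logits = [2.0, 1.0, 0.1, -1.0]   # illustrative values only
cold = [sample_token(rng, logits; temperature=0.05) for _ in 1:1000]
warm = [sample_token(rng, logits; temperature=5.0) for _ in 1:1000]
println("distinct tokens at T=0.05: ", length(unique(cold)))
println("distinct tokens at T=5.0:  ", length(unique(warm)))
```

At a very low temperature the sampler behaves almost like `argmax`, while a high temperature spreads draws across the whole vocabulary.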

In [None]:
println("--- Generated Philosophy ---")
for i in 1:5
    text = generate_text(model, 300; temperature=0.8)
    # first(text, 500) counts characters, so it is safe for multi-byte UTF-8 text.
    @printf("\nSample %d:\n%s\n", i, first(text, 500))
    println("---")
end

---
## 8a. Push Model to HuggingFace Hub
Push your trained checkpoint to HuggingFace for persistence across Colab sessions.
Set `HF_REPO_ID` in the login cell above.

In [None]:
if @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
    hf_create_repo(HF_REPO_ID)

    if isfile("checkpoints/best_model.jld2")
        hf_push_checkpoint(HF_REPO_ID; checkpoint_path="checkpoints/best_model.jld2")
    else
        println("No best_model.jld2 found -- train first!")
    end

    if isfile("checkpoints/final_model.jld2")
        hf_push(HF_REPO_ID, "checkpoints/final_model.jld2")
    end

    println("\nDone! View your model at: https://huggingface.co/$HF_REPO_ID")
else
    println("Set HF_REPO_ID in the login cell (e.g. \"yourusername/juliaflux-philosophy\")")
end

---
## 8b. Pull Checkpoint from HuggingFace Hub
Download a previously pushed checkpoint to resume training in a new Colab session.

In [None]:
if @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
    mkpath("checkpoints")
    hf_pull(HF_REPO_ID, "best_model.jld2"; local_dir="checkpoints")
    println("\nReady to resume from checkpoints/best_model.jld2")
    println("Run the 'Resume Training' cell below.")
else
    println("Set HF_REPO_ID in the login cell (e.g. \"yourusername/juliaflux-philosophy\")")
end

---
## 9. Resume Training from Checkpoint
Load a saved checkpoint and continue training for more steps.
Skip this cell if you're training from scratch above.
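
The step bookkeeping behind a resume can be sketched without a model. All values below are illustrative, and `sketch_get_lr` is a hypothetical stand-in for the notebook's `get_lr`. The point is that steps past `max_iters` are clamped before the schedule is evaluated, so the extra iterations run at the minimum learning rate.

```julia
# Illustrative values; the notebook defines the real ones in earlier cells.
max_iters, warmup_iters = 1000, 100
learning_rate, min_lr = 3e-4, 3e-5

# Hypothetical stand-in for get_lr (linear warmup, then cosine decay).
function sketch_get_lr(iter)
    iter < warmup_iters && return learning_rate * iter / warmup_iters
    ratio = (iter - warmup_iters) / (max_iters - warmup_iters)
    return min_lr + 0.5 * (1 + cos(pi * ratio)) * (learning_rate - min_lr)
end

ckpt_step   = 1000                     # step stored in the checkpoint (assumed)
extra_iters = 2000
start_iter  = ckpt_step + 1
end_iter    = ckpt_step + extra_iters

# Clamping the argument keeps the schedule from running past the end of the cosine curve.
resumed_lrs = [sketch_get_lr(min(it, max_iters)) for it in start_iter:end_iter]
println("steps to run: ", length(resumed_lrs))
println("lr range:     ", extrema(resumed_lrs))
```

If the checkpoint was saved before `max_iters`, the resumed steps first follow the remainder of the cosine curve and only then settle at the minimum rate.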

In [None]:
RESUME_FROM = "checkpoints/best_model.jld2"
EXTRA_ITERS = 2000

if !isfile(RESUME_FROM)
    if @isdefined(HF_REPO_ID) && !isempty(HF_REPO_ID)
        println("Checkpoint not found locally, pulling from HuggingFace...")
        hf_pull(HF_REPO_ID, basename(RESUME_FROM); local_dir="checkpoints")
    end
    isfile(RESUME_FROM) || error("Checkpoint not found: $RESUME_FROM")
end

ckpt = load_checkpoint(RESUME_FROM, device)
model = ckpt.model
opt_state = ckpt.opt_state
start_iter = ckpt.step + 1
best_val = ckpt.best_val_loss
train_loss_history = copy(ckpt.train_losses)
val_loss_history = copy(ckpt.val_losses)
end_iter = ckpt.step + EXTRA_ITERS
completed_iter = ckpt.step  # last finished step; updated inside the loop

if haskey(ENV, "WANDB_API_KEY") && !isempty(ENV["WANDB_API_KEY"])
    wandb_init()
end

println("\nResuming from step $(ckpt.step) -> training to step $end_iter")
println("Best val loss so far: $(round(best_val, digits=4))")
t_start = time()
last_save_time = time()

try
    for iter in start_iter:end_iter
        global completed_iter = iter

        if curriculum_enabled
            global curriculum_step = iter
        end

        lr_t = get_lr(min(iter, max_iters))
        Flux.adjust!(opt_state, lr_t)

        x, y = get_batch("train")
        loss, grads = Flux.withgradient(model) do m
            logits = m(x)
            y_flat = reshape(y, :)
            logits_flat = reshape(logits, vocab_size, :)
            onehot = Flux.onehotbatch(y_flat, 1:vocab_size) |> device
            Flux.logitcrossentropy(logits_flat, onehot)
        end
        Flux.update!(opt_state, model, grads[1])
        push!(train_loss_history, Float64(loss))

        if iter % 100 == 0 && CUDA.functional()
            GC.gc(false)
        end

        if iter % eval_interval == 0
            losses = estimate_loss(model)
            push!(val_loss_history, losses["val"])
            elapsed = round(time() - t_start, digits=1)
            wandb_log(; step=iter, train_loss=losses["train"], val_loss=losses["val"],
                       train_ppl=losses["train_ppl"], val_ppl=losses["val_ppl"], lr=lr_t)

            improved = ""
            if losses["val"] < best_val
                best_val = losses["val"]
                save_and_sync("checkpoints/best_model.jld2", model, opt_state;
                    step=iter, best_val_loss=best_val,
                    train_losses=train_loss_history, val_losses=val_loss_history)
                improved = " << best!"
            end

            @printf("step %5d / %5d | train %.4f (ppl %.1f) | val %.4f (ppl %.1f) | lr %.2e | %.1fs%s\n",
                    iter, end_iter, losses["train"], losses["train_ppl"],
                    losses["val"], losses["val_ppl"], lr_t, elapsed, improved)
        end

        if iter % 1000 == 0
            save_and_sync("checkpoints/checkpoint_latest.jld2", model, opt_state;
                step=iter, best_val_loss=best_val,
                train_losses=train_loss_history, val_losses=val_loss_history)
            last_save_time = time()
        end

        if time() - last_save_time > SAVE_INTERVAL
            save_and_sync("checkpoints/checkpoint_latest.jld2", model, opt_state;
                step=iter, best_val_loss=best_val,
                train_losses=train_loss_history, val_losses=val_loss_history)
            last_save_time = time()
            println("  [auto-save at step $iter]")
        end
    end
catch e
    if e isa InterruptException
        println("\n\nTraining interrupted at step $completed_iter!")
    else
        println("\n\nTraining error at step $completed_iter: $e")
    end
    save_and_sync("checkpoints/checkpoint_interrupted.jld2", model, opt_state;
        step=completed_iter, best_val_loss=best_val,
        train_losses=train_loss_history, val_losses=val_loss_history)
    e isa InterruptException || rethrow(e)
end

elapsed = round(time() - t_start, digits=1)
@printf("\nResume training finished at step %d in %.1fs\n", completed_iter, elapsed)
wandb_finish()

# Record the step actually reached; after an interrupt it is below end_iter.
save_and_sync("checkpoints/final_model.jld2", model, opt_state;
    step=completed_iter, best_val_loss=best_val,
    train_losses=train_loss_history, val_losses=val_loss_history)

---
## 10. Download Checkpoint

The cell below lists the saved checkpoints and their sizes so you can pick one to use elsewhere.
To download a file in Colab, use the Files panel (left sidebar), or pull it from HuggingFace Hub.

In [None]:
if isdir("checkpoints")
    files = readdir("checkpoints")
    println("Saved checkpoints:")
    for f in files
        path = joinpath("checkpoints", f)
        size_kb = round(filesize(path) / 1024, digits=1)
        println("  $path ($(size_kb) KB)")
    end
else
    println("No checkpoints directory found. Train first!")
end