In [None]:
# Cell 1 – Install packages (run once, restart after if needed)
using Pkg

# Everything the notebook imports in Cell 2; stdlib names are added too so they
# appear in the active Project.toml.
let packages = [
        "Flux", "Zygote", "Optimisers",       # model layers, AD, optimisers
        "CUDA", "cuDNN",                      # GPU backend
        "Downloads", "Statistics", "Random",  # standard-library utilities
        "Printf", "LinearAlgebra", "JLD2",
        "NNlib",                              # batched_mul used by attention
    ]
    Pkg.add(packages)
end

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.11/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.11/Manifest.toml`


In [None]:
# Cell 2 – Imports
using Flux
using Zygote
using Optimisers
using CUDA
using NNlib
using Downloads
using Statistics
using Random
using Printf
using LinearAlgebra
using JLD2

Random.seed!(1337)
# `Random.seed!(1337)` only seeds Julia's CPU RNG (e.g. the `rand` calls used to
# sample batches). GPU-side randomness, such as Dropout applied to CuArrays, is
# drawn from CUDA.jl's own default RNG, so seed that too for reproducible runs.
CUDA.functional() && Random.seed!(CUDA.default_rng(), 1337)

device = CUDA.functional() ? gpu : cpu
println("Device: ", device)
println("CUDA functional: ", CUDA.functional())

Device: gpu
CUDA functional: true


In [None]:
# Cell 3 – Hyperparameters
# Model shape
block_size     = 128     # tokens per training window (context length)
n_embd         = 256     # embedding / residual-stream width
n_head         = 4       # attention heads; each head has width n_embd ÷ n_head
n_layer        = 4       # number of transformer blocks
dropout        = 0.1     # dropout probability
bias           = false   # bias flag for Dense layers (NOTE: not passed to GPT(...) below)

# Optimisation
batch_size     = 32
learning_rate  = 4e-4    # peak LR, reached at the end of warmup
max_iters      = 5000
eval_interval  = 500     # evaluate every `eval_interval` steps
eval_iters     = 100     # batches averaged per loss estimate
warmup_iters   = 200     # length of the linear LR warmup
min_lr         = 1e-5    # LR floor reached at the end of the cosine decay

# Fail fast on settings the attention reshape or the LR schedule cannot handle,
# instead of hitting an obscure reshape error or a NaN learning rate later.
n_embd % n_head == 0 ||
    throw(ArgumentError("n_embd ($n_embd) must be divisible by n_head ($n_head)"))
0 <= warmup_iters < max_iters ||
    throw(ArgumentError("need 0 ≤ warmup_iters < max_iters, got $warmup_iters, $max_iters"))
0 <= dropout < 1 ||
    throw(ArgumentError("dropout must be in [0, 1), got $dropout"))

println("n_embd = $n_embd")

n_embd = 256


In [None]:
# Cell 4 – Download texts
"""
    download_and_clean(url, fn; is_gutenberg=true) -> String

Download `url` to the local file `fn` unless `fn` already exists, then return
its cleaned text. The file on disk is not modified.

Cleaning steps:
- When `is_gutenberg` is true, drop everything up to and including the
  `*** START OF THE/THIS PROJECT GUTENBERG ... ***` marker line(s). Also drop
  everything from the `*** END OF THE/THIS PROJECT GUTENBERG` marker onwards.
- Convert CRLF line endings to LF.
- Collapse runs of three or more newlines into one blank line.
- Strip leading and trailing whitespace.

If the download throws, a warning is logged and `""` is returned.
"""
function download_and_clean(url::AbstractString, fn::AbstractString; is_gutenberg::Bool=true)
    if !isfile(fn)
        println("Downloading $fn from $url")
        try
            Downloads.download(url, fn)
        catch e
            @warn "Download failed: $url → $e"
            return ""
        end
    end

    txt = read(fn, String)

    if is_gutenberg
        txt = replace(txt, r"(?is)^.*?\*{3}\s*START OF (THE|THIS) PROJECT GUTENBERG.*?\*{3}[\r\n]*" => "")
        txt = replace(txt, r"(?is)\*{3}\s*END OF (THE|THIS) PROJECT GUTENBERG.*$" => "")
    end

    txt = replace(txt, "\r\n" => "\n")
    txt = replace(txt, r"\n{3,}" => "\n\n")

    # `strip` yields a SubString; convert so every return path is a `String`
    # (the failure path above already returns one).
    return String(strip(txt))
end

# Corpus manifest: key => (source URL, local cache filename, strip Gutenberg boilerplate?)
sources = Dict(
    "grammar"             => ("https://www.gutenberg.org/files/15665/15665-0.txt",  "latin_grammar.txt",        true),
    "categories"          => ("https://www.gutenberg.org/ebooks/2412.txt.utf-8",    "aristotle_categories.txt", true),
    "rhetoric"            => ("http://classics.mit.edu/Aristotle/rhetoric.mb.txt",  "aristotle_rhetoric.txt",   false),
    "prior_analytics"     => ("http://classics.mit.edu/Aristotle/prior.mb.txt",     "prior_analytics.txt",      false),
    "posterior_analytics" => ("http://classics.mit.edu/Aristotle/posterior.mb.txt", "posterior_analytics.txt",  false),
    "topics"              => ("http://classics.mit.edu/Aristotle/topics.mb.txt",    "topics.txt",               false),
    "boethius"            => ("https://www.gutenberg.org/files/14328/14328-0.txt",  "boethius_consolation.txt", true),
    "heavens"             => ("http://classics.mit.edu/Aristotle/heavens.mb.txt",   "aristotle_heavens.txt",    false),
    "republic"            => ("https://www.gutenberg.org/files/1497/1497-0.txt",    "plato_republic.txt",       true),
    "apology"             => ("https://www.gutenberg.org/files/1656/1656-0.txt",    "plato_apology.txt",        true),
    "ethics"              => ("https://www.gutenberg.org/files/8438/8438-0.txt",    "aristotle_ethics.txt",     true),
    "emerson"             => ("https://www.gutenberg.org/files/2944/2944-0.txt",    "emerson_essays.txt",       true),
    "walden"              => ("https://www.gutenberg.org/files/205/205-0.txt",      "thoreau_walden.txt",       true),
    "epicurus"            => ("https://www.gutenberg.org/files/57342/57342-0.txt",  "diogenes_epicurus.txt",    true)
)

# Fetch each source (reusing the cached file when present) and keep its cleaned
# text, keyed like `sources`.
texts = Dict{String,String}(
    key => download_and_clean(url, fn; is_gutenberg=is_gut)
    for (key, (url, fn, is_gut)) in sources
)

println("Downloaded $(length(texts)) texts.")

Downloaded 14 texts.


In [None]:
# Cell 5 – Unicode normalization
# Cell 4 keeps the Gutenberg-stripped text in `texts`, but the files on disk are
# still the raw downloads, with licence header and footer included. Cell 6
# builds the corpus from disk. So normalize the *cleaned* text and write that
# back; otherwise the stripping done in Cell 4 is silently lost.
cleaned_by_file = Dict(fn => texts[key] for (key, (_, fn, _)) in sources)

for fn in [
    "latin_grammar.txt",
    "aristotle_categories.txt",
    "aristotle_rhetoric.txt",
    "prior_analytics.txt",
    "posterior_analytics.txt",
    "topics.txt",
    "boethius_consolation.txt",
    "aristotle_heavens.txt",
    "plato_republic.txt",
    "plato_apology.txt",
    "aristotle_ethics.txt",
    "emerson_essays.txt",
    "thoreau_walden.txt",
    "diogenes_epicurus.txt"
]
    txt = get(cleaned_by_file, fn, "")
    # Empty means the download failed (download_and_clean returned "") or the
    # cleaned text is empty; there is nothing worth writing.
    isempty(txt) && continue

    # Map typographic punctuation and non-breaking spaces to plain ASCII.
    txt = replace(txt,
        "“" => "\"", "”" => "\"",
        "‘" => "'", "’" => "'",
        "—" => "--", "–" => "-",
        "…" => "...",
        "\u00A0" => " "
    )

    txt = replace(txt, r"\n{3,}" => "\n\n")
    txt = strip(txt)

    open(fn, "w") do io
        write(io, txt)
    end

    println("Normalized: $fn")
end

Normalized: latin_grammar.txt
Normalized: aristotle_categories.txt
Normalized: aristotle_rhetoric.txt
Normalized: prior_analytics.txt
Normalized: posterior_analytics.txt
Normalized: topics.txt
Normalized: boethius_consolation.txt
Normalized: aristotle_heavens.txt
Normalized: plato_republic.txt
Normalized: plato_apology.txt
Normalized: aristotle_ethics.txt
Normalized: emerson_essays.txt
Normalized: thoreau_walden.txt
Normalized: diogenes_epicurus.txt


In [None]:
# Cell 6 – Build full_text, vocab, data
# Concatenate the corpus files, each preceded by a "=== filename ===" banner.
# The text goes through an IOBuffer inside `let` for two reasons:
# - repeated `*=` re-copies the growing string for every file;
# - assigning a global inside a top-level `for` is ambiguous (and fails) when
#   this code is run as a script rather than in the notebook.
full_text = let buf = IOBuffer()
    for fn in [
        "latin_grammar.txt",
        "aristotle_categories.txt",
        "aristotle_rhetoric.txt",
        "prior_analytics.txt",
        "posterior_analytics.txt",
        "topics.txt",
        "boethius_consolation.txt",
        "aristotle_heavens.txt",
        "plato_republic.txt",
        "plato_apology.txt",
        "aristotle_ethics.txt",
        "emerson_essays.txt",
        "thoreau_walden.txt",
        "diogenes_epicurus.txt"
    ]
        isfile(fn) || continue
        print(buf, "\n\n=== ", fn, " ===\n\n", read(fn, String))
    end
    String(strip(String(take!(buf))))
end

# Character-level vocabulary: every distinct Char in the corpus, sorted.
chars = sort(unique(full_text))
vocab_size = length(chars)

stoi = Dict(c => i for (i, c) in enumerate(chars))   # Char  -> 1-based token id
itos = Dict(i => c for (i, c) in enumerate(chars))   # token id -> Char

encode(s) = [stoi[c] for c in s]
decode(ids) = join(itos[i] for i in ids)

# 90 % / 10 % contiguous train/validation split of the token stream.
data = encode(full_text)
n = length(data)
train_split = floor(Int, n * 0.9)
train_data = data[1:train_split]
val_data   = data[train_split+1:end]

println("Vocab size: ", vocab_size)
println("Train tokens: ", length(train_data))
println("Train tokens: ", length(train_data))

Vocab size: 254
Train tokens: 6026460


In [None]:
# Cell 7 – Batch loader
"""
    get_batch(split="train"; rng=Random.default_rng()) -> (x, y)

Sample `batch_size` random windows of `block_size` tokens from `train_data`
(`split == "train"`) or `val_data` (`split == "val"`).

`x` holds the windows. `y` holds the same windows shifted one token to the
right, i.e. the next-token targets. Both are `(batch_size, block_size)` integer
matrices passed through `device`.

Throws `ArgumentError` for any other `split`.
"""
function get_batch(split="train"; rng=Random.default_rng())
    d = if split == "train"
        train_data
    elseif split == "val"
        val_data
    else
        throw(ArgumentError("split must be \"train\" or \"val\", got $(repr(split))"))
    end

    # The largest start index still leaves room for the shifted target window.
    starts = rand(rng, 1:length(d)-block_size, batch_size)

    # Build (B, T) directly: row b is the window beginning at starts[b].
    x = [d[s + t - 1] for s in starts, t in 1:block_size]
    y = [d[s + t]     for s in starts, t in 1:block_size]
    return device(x), device(y)
end

get_batch (generic function with 2 methods)

In [None]:
# ── All model structs in ONE cell (Julia structs can't be redefined) ──
using NNlib: batched_mul

"""
    CausalSelfAttention(n_embd, n_head; bias=false)

Multi-head causal self-attention over `(C, T, B)` activations (features first,
the Flux convention). Each position attends only to itself and earlier
positions.

Throws `ArgumentError` unless `n_embd` is divisible by `n_head`.
"""
struct CausalSelfAttention{Q,P}
    qkv::Q        # single projection n_embd → 3*n_embd; rows stacked as [q; k; v]
    proj::P       # output projection n_embd → n_embd
    n_head::Int
end

Flux.@layer CausalSelfAttention trainable=(qkv, proj)

function CausalSelfAttention(n_embd::Int, n_head::Int; bias=false)
    (n_head > 0 && n_embd % n_head == 0) ||
        throw(ArgumentError("n_embd ($n_embd) must be divisible by n_head ($n_head)"))
    return CausalSelfAttention(
        Dense(n_embd => 3 * n_embd; bias),
        Dense(n_embd => n_embd; bias),
        n_head,
    )
end

function (attn::CausalSelfAttention)(x)
    # x: (C, T, B)
    C, T, B = size(x)
    nh = attn.n_head
    hs = C ÷ nh

    qkv = attn.qkv(x)                                        # (3C, T, B)

    # (C, T, B) → (hs, T, nh*B): one batched_mul slice per (head, sequence).
    to_heads(a) = reshape(permutedims(reshape(a, hs, nh, T, B), (1, 3, 2, 4)),
                          hs, T, nh * B)
    q = to_heads(qkv[1:C, :, :])
    k = to_heads(qkv[C+1:2C, :, :])
    v = to_heads(qkv[2C+1:3C, :, :])

    # Scores laid out keys × queries: wei[j, i, :] = (k_j ⋅ q_i) / √hs.
    scale = Float32(1 / sqrt(hs))
    wei = batched_mul(NNlib.batched_transpose(k), q) .* scale    # (T, T, nh*B)

    # Build the causal mask with the same array type/device as `x`. A host
    # Matrix cannot be broadcast against a CuArray. mask[j, i] == (j ≤ i), so
    # query i only sees keys at positions ≤ i.
    mask = NNlib.make_causal_mask(x)
    wei = ifelse.(mask, wei, typemin(eltype(wei)))
    wei = softmax(wei; dims=1)                               # normalize over keys

    out = batched_mul(v, wei)                                # (hs, T, nh*B)

    # (hs, T, nh*B) → (C, T, B)
    out = reshape(permutedims(reshape(out, hs, T, nh, B), (1, 3, 2, 4)), C, T, B)
    return attn.proj(out)
end

"""
    FeedForward(n_embd; bias=false, dropout=0.0)

Position-wise MLP: `Dense(n_embd => 4n_embd, gelu)`, then
`Dense(4n_embd => n_embd)`, then `Dropout(dropout)`.
"""
struct FeedForward{N}
    net::N   # type parameter keeps the field concrete (`Chain` alone is a UnionAll)
end

Flux.@layer FeedForward

function FeedForward(n_embd::Int; bias=false, dropout=0.0)
    # GELU goes inside the Dense layer, which broadcasts it element-wise. A bare
    # `gelu` entry in the Chain would be called on the whole array, and NNlib
    # activations reject array arguments (they must be broadcast).
    return FeedForward(Chain(
        Dense(n_embd => 4 * n_embd, gelu; bias),
        Dense(4 * n_embd => n_embd; bias),
        Dropout(dropout),
    ))
end

(ff::FeedForward)(x) = ff.net(x)

"""
    TransformerBlock(n_embd, n_head; dropout=0.0, bias=false)

Pre-LayerNorm transformer block on `(C, T, B)` activations. It computes
`x + attn(ln1(x))` and then adds `ffwd(ln2(x))` to that result.

`bias` is forwarded to the attention and feed-forward Dense layers. It defaults
to `false`, matching the previous behavior.
"""
struct TransformerBlock{L1,A,L2,F}
    ln1::L1
    attn::A
    ln2::L2
    ffwd::F
end

Flux.@layer TransformerBlock

function TransformerBlock(n_embd::Int, n_head::Int; dropout=0.0, bias=false)
    return TransformerBlock(
        LayerNorm(n_embd),
        CausalSelfAttention(n_embd, n_head; bias),
        LayerNorm(n_embd),
        FeedForward(n_embd; bias, dropout),
    )
end

function (block::TransformerBlock)(x)
    x = x .+ block.attn(block.ln1(x))
    x = x .+ block.ffwd(block.ln2(x))
    return x
end

"""
    GPT(; vocab_size, n_embd, block_size, n_layer, n_head, dropout=0.1)

Decoder-only language model. The forward pass applies, in order:
- token embedding plus learned position embedding;
- dropout;
- `n_layer` `TransformerBlock`s;
- a final LayerNorm;
- a bias-free linear head producing next-token logits.

`m(idx)` takes a `(B, T)` matrix of token ids with `T ≤ block_size` and returns
logits of size `(vocab_size, T, B)`. Longer sequences throw `ArgumentError`.
"""
struct GPT{TE,PE,D,BL,LN,H}
    wte::TE       # token embedding,    vocab_size → n_embd
    wpe::PE       # position embedding, block_size → n_embd
    drop::D
    blocks::BL
    ln_f::LN
    lm_head::H
end

Flux.@layer GPT

function GPT(; vocab_size, n_embd, block_size, n_layer, n_head, dropout=0.1)
    return GPT(
        Embedding(vocab_size => n_embd),
        Embedding(block_size => n_embd),
        Dropout(dropout),
        Chain([TransformerBlock(n_embd, n_head; dropout) for _ in 1:n_layer]...),
        LayerNorm(n_embd),
        Dense(n_embd => vocab_size; bias=false),
    )
end

function (m::GPT)(idx)
    # idx: (B, T) integer token ids
    B, T = size(idx)
    max_T = size(m.wpe.weight, 2)
    T <= max_T ||
        throw(ArgumentError("sequence length $T exceeds block_size $max_T"))

    # Embed the transposed ids so the result is already (C, T, B). Permuting
    # the integer ids is cheaper than permuting the float embeddings.
    tok = m.wte(permutedims(idx))                           # (C, T, B)

    # Slice the first T position vectors straight from the weight matrix. This
    # stays on whatever device the model lives on. The old approach fed a host
    # Array of position ids into the embedding, which makes a GPU model fall
    # back to scalar indexing.
    pos = m.wpe.weight[:, 1:T]                              # (C, T), broadcast over B

    x = m.drop(tok .+ pos)
    x = m.blocks(x)
    x = m.ln_f(x)
    return m.lm_head(x)                                     # (vocab, T, B)
end

println("All model structs defined: CausalSelfAttention, FeedForward, TransformerBlock, GPT")

In [None]:
# ── Create model + move to GPU ──
model = GPT(;
    vocab_size = vocab_size,
    n_embd     = n_embd,
    block_size = block_size,
    n_layer    = n_layer,
    n_head     = n_head,
    dropout    = dropout
) |> device

# Count trainable parameters. This uses the explicit-style
# `Optimisers.trainables` rather than the implicit-style `Flux.params`, which
# recent Flux releases deprecate.
n_params = sum(length, Optimisers.trainables(model))
println("Model created on $device")
println("Parameters: $(n_params) ($(round(n_params/1e6, digits=2))M)")

# Quick smoke test: one forward pass. The logits must be (vocab, T, B) for an
# x of size (B, T).
x_test, y_test = get_batch("train")
logits_test = model(x_test)
expected_size = (vocab_size, size(x_test, 2), size(x_test, 1))
size(logits_test) == expected_size ||
    error("unexpected logits size $(size(logits_test)), expected $expected_size")
println("Forward pass OK — logits: $(size(logits_test))")

In [None]:
# ── Training loop ──
using Printf

"""
    estimate_loss(model, n_iters=eval_iters) -> Dict{String,Float64}

Average the cross-entropy loss of `model` over `n_iters` random batches from
each split, returned as `Dict("train" => …, "val" => …)`. Dropout is disabled
while evaluating.

Afterwards the model is left in Flux's automatic train/test mode (`:auto`),
which is its default state.
"""
function estimate_loss(model, n_iters=eval_iters)
    n_iters > 0 || throw(ArgumentError("n_iters must be positive, got $n_iters"))

    # Switch dropout off in place rather than deep-copying the whole (GPU)
    # model on every evaluation. `finally` restores automatic mode even if a
    # batch throws.
    Flux.testmode!(model)
    try
        losses = Dict("train" => 0.0, "val" => 0.0)
        for split in ("train", "val")
            total = 0.0
            for _ in 1:n_iters
                x, y = get_batch(split)
                logits = model(x)                                   # (vocab, T, B)
                total += Flux.logitcrossentropy(
                    reshape(logits, vocab_size, :),                 # (vocab, T*B)
                    Flux.onehotbatch(reshape(y, :), 1:vocab_size),  # (vocab, T*B)
                )
            end
            losses[split] = total / n_iters
        end
        return losses
    finally
        Flux.testmode!(model, :auto)
    end
end

"""
    get_lr(iter; peak_lr=learning_rate, floor_lr=min_lr,
           warmup=warmup_iters, decay_iters=max_iters) -> Float64

Learning-rate schedule with three phases:
- linear warmup from 0 towards `peak_lr` while `iter < warmup`;
- cosine decay from `peak_lr` down to `floor_lr` at step `decay_iters`;
- constant `floor_lr` from `decay_iters` onwards.

The keyword defaults read the notebook's hyperparameter globals at call time,
so `get_lr(iter)` uses the configured schedule.
"""
function get_lr(iter; peak_lr=learning_rate, floor_lr=min_lr,
                warmup=warmup_iters, decay_iters=max_iters)
    iter < warmup && return peak_lr * iter / warmup
    # Clamp the tail: past decay_iters the cosine would swing back up. This also
    # avoids a 0/0 ratio when decay_iters == warmup.
    iter >= decay_iters && return floor_lr
    decay_ratio = (iter - warmup) / (decay_iters - warmup)
    coeff = 0.5 * (1.0 + cos(π * decay_ratio))
    return floor_lr + coeff * (peak_lr - floor_lr)
end

opt_state = Flux.setup(Adam(learning_rate), model)

println("Training for $max_iters steps...")
t_start = time()
best_val = Inf

for iter in 1:max_iters
    # Declared `global` so this loop also works when the file is run as a
    # script. There, assigning a global inside a top-level loop would otherwise
    # create a new local and raise UndefVarError on the first comparison.
    global best_val

    # Update LR for this step.
    lr_t = get_lr(iter)
    Flux.adjust!(opt_state, lr_t)

    # Forward + backward.
    x, y = get_batch("train")
    loss, grads = Flux.withgradient(model) do m
        logits = m(x)
        Flux.logitcrossentropy(
            reshape(logits, vocab_size, :),
            Flux.onehotbatch(reshape(y, :), 1:vocab_size)
        )
    end

    # Stop before a NaN/Inf gradient poisons the weights and the Adam moments.
    isfinite(loss) || error("non-finite training loss ($loss) at step $iter")
    Flux.update!(opt_state, model, grads[1])

    # Periodic evaluation; " *" marks a new best validation loss.
    if iter % eval_interval == 0 || iter == 1
        losses = estimate_loss(model)
        is_best = losses["val"] < best_val
        is_best && (best_val = losses["val"])
        elapsed = round(time() - t_start, digits=1)
        @printf("step %5d | train %.4f | val %.4f | lr %.2e | %.1fs%s\n",
                iter, losses["train"], losses["val"], lr_t, elapsed, is_best ? " *" : "")
    end
end

elapsed = round(time() - t_start, digits=1)
println("\nTraining complete in $(elapsed)s. Best val loss: $(round(best_val, digits=4))")