## Training
* Trains a neural network with the given hyperparameters

In [None]:
# Runs one pass over the batches of `split`, applying an optimiser step to `m`
# after every batch, and returns the mean of the per-batch losses.
#
#   m     : model passed to `model_loss` and updated through `Flux.update!`
#   opt   : optimiser state handed to `Flux.update!`
#   split : name of the data split the batches are drawn from
#   rng   : random number generator forwarded to `get_sampling_order`
#
# The batch size is read from the global hyperparameters `G`.
function train_epoch!(m, opt; split::String = "training", rng = Random.GLOBAL_RNG)
    epoch = get_epoch(split)
    order = get_sampling_order(epoch, split, rng)
    nbatches = ceil(Int, length(order) / G.batch_size)
    # last argument of `get_batch`: whether the batch comes from the training split
    is_training = split == "training"
    batch_losses = []
    @showprogress for b in 1:nbatches
        batch, _ = get_batch(epoch, b, G.batch_size, order, is_training)
        batch_loss, grads = Flux.withgradient(model -> model_loss(model, batch...), m)
        # done with the batch: hand it to `device_free!` before the optimiser step
        device_free!(batch)
        Flux.update!(opt, m, first(grads))
        push!(batch_losses, batch_loss)
    end
    return mean(batch_losses)
end;

In [None]:
# Trains a model with the given hyperparameters and returns the tuple
# `(best_model, epochs, training_loss, validation_loss)`.
#
# Keyword arguments:
#   max_checkpoints       : forwarded to `early_stopper` as `max_iters`
#   epochs_per_checkpoint : training epochs run between two validation checks
#   patience              : forwarded to `early_stopper`
#   verbose               : pass "info" to log the losses after every checkpoint
#   init_model            : model to continue training from; when `nothing`, a
#                           fresh one is created with `build_model`
#
# Returned values:
#   best_model      : CPU copy of the model at the checkpoint with the lowest
#                     validation loss
#   epochs          : `stopper.iters - stopper.iters_without_improvement`
#                     (presumably the checkpoint index of the best model — confirm
#                     against `early_stopper`)
#   training_loss   : training loss of the last epoch before the best checkpoint
#   validation_loss : lowest validation loss observed
#
# While the function runs, the global `G` holds `hyp`; it is reset to `nothing`
# on exit, including when training throws or is interrupted.
function train_model(
    hyp::Hyperparams;
    max_checkpoints::Int = 1000,
    epochs_per_checkpoint::Int = 1,
    patience::Int = 0,
    verbose::String = "",
    init_model = nothing,
)
    global G = hyp
    try
        opt_spec = get_optimizer(G.optimizer, G.learning_rate, G.optimizer_weight_decay)

        # Derive every random stream from the single user-provided seed.
        rng = Random.Xoshiro(G.seed)
        Random.seed!(rand(rng, UInt64))
        if CUDA.functional()
            Random.seed!(CUDA.default_rng(), rand(rng, UInt64))
            Random.seed!(CUDA.CURAND.default_rng(), rand(rng, UInt64))
        end

        if isnothing(init_model)
            m = build_model(rng = rng) |> device
        else
            m = init_model |> device
        end
        # NOTE(review): if `cpu` returns `m` itself when the model already lives on
        # the CPU, `best_model` aliases the live model and later updates overwrite
        # the snapshot — confirm that `cpu` copies in that case.
        best_model = m |> cpu
        stopper = early_stopper(
            max_iters = max_checkpoints,
            patience = patience,
            min_rel_improvement = 1e-3,
        )
        opt = Optimisers.setup(opt_spec, m)

        losses = []
        loss = Inf
        training_loss = Inf
        # `tloss` must be defined outside the epoch loop: a variable first assigned
        # inside a `for` body is local to that body and would not be visible to the
        # code that reads it after the loop.
        tloss = Inf
        while !stop!(stopper, loss)
            for _ in 1:epochs_per_checkpoint
                tloss = train_epoch!(m, opt; rng = rng)
            end
            loss = split_loss(m, "validation"; rng = rng)
            push!(losses, loss)
            # Snapshot the model whenever the newest checkpoint is the best so far.
            if loss == minimum(losses)
                best_model = m |> cpu
                training_loss = tloss
            end
            if verbose == "info"
                @info "losses: $tloss $loss"
            end
        end
        epochs = stopper.iters - stopper.iters_without_improvement
        return best_model, epochs, training_loss, minimum(losses)
    finally
        # Always clear the global hyperparameters so a failed or interrupted run
        # does not leave stale state behind.
        global G = nothing
    end
end;