## Training
* Train a neural network with the given hyperparameters, evaluating the validation loss at each checkpoint and keeping the best-scoring model.

In [20]:
"""
    train_epoch!(m, ps, opt; split="training")

Run one pass over the samples of `split`, updating the parameters `ps` of model `m`
in place with the optimiser `opt` via `Flux.train!`.

Reads the batch size from the global hyperparameter object `G` (`G.batch_size`).
Batches are produced by `get_batch`; the sample order comes from `get_sampling_order`.
"""
function train_epoch!(m, ps, opt; split = "training")
    epoch = get_epoch(split)
    sampling_order = get_sampling_order(split)
    batchloss(x, y, z, w) = training_loss(m, x, y, z, w)
    # exact integer ceiling division: the last batch may be partial
    nbatches = cld(length(sampling_order), G.batch_size)
    # passed through to `get_batch`; presumably a "training mode" flag — confirm there
    is_training = split == "training"
    for iter = 1:nbatches
        batch, _ = get_batch(epoch, iter, G.batch_size, sampling_order, is_training)
        Flux.train!(batchloss, ps, batch, opt)
    end
    return nothing
end;

In [21]:
"""
    apply_zero_gradient!(m, ps, opt, apply)

Simulate training over users that are not in the current training data by feeding
`opt` a zero gradient once per "missing" batch. Does nothing unless `apply` is true.

The number of zero-gradient steps is the number of batches for `num_users()` users
minus the number of batches for `G.num_users` users (batch size `G.batch_size`).
NOTE(review): with a zero gradient, only optimiser state/regularisation (e.g. momentum,
weight decay) can still move `ps` — presumably that is the intent; confirm against
`get_optimizer`.
"""
function apply_zero_gradient!(m, ps, opt, apply)
    apply || return nothing
    zerobatches = cld(num_users(), G.batch_size) - cld(G.num_users, G.batch_size)
    # no batches to simulate: skip the forward/backward pass whose result would be unused
    zerobatches > 0 || return nothing
    # multiplying the output by 0 is intended to give a zero gradient shaped like `ps`
    zerograd = gradient(ps) do
        sum(0 * m(1))
    end
    for _ = 1:zerobatches
        Flux.update!(opt, ps, zerograd)
    end
    return nothing
end;

In [22]:
# trains a model with the given hyperparams and returns its validation loss
"""
    train_model(hyp; max_checkpoints=1000, epochs_per_checkpoint=1, patience=0,
                verbose=false, init_model=nothing, fine_tune_layers=nothing)

Train a model with hyperparameters `hyp` and return `(best_model, epochs, best_loss)`:
the CPU copy of the model with the lowest validation loss, the number of checkpoints
until that best model (checkpoints run minus checkpoints without improvement), and
the minimum validation loss observed.

# Keywords
- `max_checkpoints`, `patience`: forwarded to `early_stopper`.
- `epochs_per_checkpoint`: training epochs between validation evaluations.
- `verbose`: log training and validation loss at each checkpoint.
- `init_model`: start from (a copy of) this model instead of `build_model()`;
  the caller's model is never mutated.
- `fine_tune_layers`: if given, only the parameters of `m[fine_tune_layers]` are
  trained, and zero-gradient steps are applied after each epoch.

`hyp` is installed as the global `G` (read by the helper functions) for the duration of
the call and reset to `nothing` afterwards, even if training throws.
"""
function train_model(
    hyp;
    max_checkpoints = 1000,
    epochs_per_checkpoint = 1,
    patience = 0,
    verbose = false,
    init_model = nothing,
    fine_tune_layers = nothing,
)
    global G = hyp
    try
        opt = get_optimizer(G.optimizer, G.learning_rate, G.regularization_params)
        Random.seed!(G.seed)
        # `device` may be a no-op (e.g. on CPU), in which case piping alone would alias
        # the caller's model and training would mutate `init_model` in place
        if isnothing(init_model)
            m = build_model() |> device
        else
            m = deepcopy(init_model) |> device
        end
        # `cpu` returns the very same arrays when `m` is already on the CPU, and training
        # updates them in place, so a real copy is needed for the snapshot
        best_model = deepcopy(cpu(m))
        if isnothing(fine_tune_layers)
            ps = Flux.params(m)
        else
            ps = Flux.params(m[fine_tune_layers])
        end
        stopper = early_stopper(max_iters = max_checkpoints, patience = patience)

        losses = []
        loss = Inf
        while !stop!(stopper, loss)
            for _ = 1:epochs_per_checkpoint
                train_epoch!(m, ps, opt)
                apply_zero_gradient!(m, ps, opt, !isnothing(fine_tune_layers))
            end
            loss = split_loss(m, "validation")
            push!(losses, loss)
            if loss == minimum(losses)
                best_model = deepcopy(cpu(m))
            end
            if verbose
                # named so it does not shadow the `training_loss` function
                train_loss = split_loss(m, "training")
                @info "losses: $train_loss $loss"
            end
        end
        epochs = stopper.iters - stopper.iters_without_improvement
        return best_model, epochs, minimum(losses)
    finally
        # do not leak the hyperparameters into later calls, even on error
        global G = nothing
    end
end;