## Loss Functions
* The `model_loss` is either the crossentropy loss or the squared error, depending on the kind of input data (selected by `G.implicit`)
    * Note that we sum over all items rather than averaging, so a larger batch size yields a larger `model_loss`
* The `regularization_loss` depends on the model architecture, but is commonly an L2 loss
* The `split_loss` is either the weighted-average crossentropy loss or the weighted mean squared error, depending on the input data

In [3]:
"""
    model_loss(m, x, y, z, w)

Return the `w`-weighted *sum* (not the mean) of per-item losses of model `m` on `x`.

The final layer's `β` controls how much of `z` is mixed into the prediction:

- if `G.implicit`: the prediction is `softmax(m(x))` blended with `z` using the
  weight `sigmoid.(β)`, and the loss is the crossentropy `-y .* log.(q .+ eps(Float64))`;
- otherwise: the prediction is `m(x) + z .* β` and the loss is the squared error
  against `y`.
"""
function model_loss(m, x, y, z, w)
    out = m(x)
    raw_β = m[end].β
    if !G.implicit
        # explicit branch: additive residualization, weighted squared error
        pred = out + z .* raw_β
        return sum(w .* (pred - y) .^ 2)
    end
    # implicit branch: convex mix of the softmax output and z, with weight sigmoid(β)
    mix = sigmoid.(raw_β)
    q = softmax(out) .* (1 .- mix) + z .* mix
    # Small floor keeps log finite where q == 0.
    # NOTE(review): if q is Float32, adding eps(Float64) promotes the result to
    # Float64 — confirm this is intended.
    tiny = eps(Float64)
    return sum(w .* -y .* log.(q .+ tiny))
end

"""
    training_loss(m, x, y, z, w)

Return `model_loss(m, x, y, z, w)` plus the model's `regularization_loss(m, x)`.
"""
training_loss(m, x, y, z, w) = model_loss(m, x, y, z, w) + regularization_loss(m, x)

"""
    uncalibrated_split_loss(m, split)

Return, as a `Float32`, the weight-normalized `model_loss` of `m` over every batch
of `split`'s epoch.

The result is the sum of `model_loss` over all batches divided by the sum of the
last element of each batch's data tuple (presumably the weights `w` — confirm).
Unlike `split_loss`, the model's head is not re-fit first. If the split yields no
batches (or zero total weight), this is `0/0`, i.e. `NaN32`.
"""
function uncalibrated_split_loss(m, split)
    epoch = get_epoch(split)
    nbatches = Int(ceil(epoch_size(epoch) / G.batch_size))
    total_loss = 0.0
    total_weight = 0.0
    for i in 1:nbatches
        # NOTE(review): the `false` flag presumably disables shuffling — confirm
        batch, _ = get_batch(epoch, i, G.batch_size, false)
        data = batch[1]
        total_loss += model_loss(m, data...)
        total_weight += sum(data[end])
    end
    return Float32(total_loss / total_weight)
end

"""
    split_loss(m, split; rng=Random.GLOBAL_RNG)

Re-fit the residualization head of `m` on `split`, then return the
`uncalibrated_split_loss` of the re-fitted model.

A new chain is built from every layer of `m` except the last, followed by a fresh
`ScalarLayer(1)` and `StorageLayer(1)` (both passed through `device`). The new
storage layer's `β` is multiplied in place by the `β` of `m`'s final layer. Only
the parameters of these two new layers are given to `train_epoch!`, using `ADAM(0.01)`.
"""
function split_loss(m, split; rng=Random.GLOBAL_RNG)
    old_β = m[end].β
    refit = Chain(m[1:end-1]..., ScalarLayer(1) |> device, StorageLayer(1) |> device)
    # seed the new storage layer with the old model's β
    refit[end].β .*= old_β
    head_params = Flux.params(refit[end-1:end])
    train_epoch!(refit, head_params, ADAM(0.01); split=split, rng=rng)
    return uncalibrated_split_loss(refit, split)
end;