## Loss Functions
* The `model_loss` is either the cross-entropy loss (when `G.implicit` is set) or the squared error (otherwise), depending on the input data
    * Note that we take the weighted sum over all items, so a larger batch size will produce a larger `model_loss`
* The `uncalibrated_split_loss` is the weighted average model loss over the entire training/validation/test split (total loss divided by total weight)
* The `split_loss` is the same as the uncalibrated split loss, except that we first re-train the final calibration layers,
  i.e. we fit a linear regression from the model's output to the ground truth

In [None]:
"""
    model_loss(m, x, y, z, w)

Return the weighted loss of model `m` on one batch, summed over all items.

The model output `m(x)` is combined with `z` using the `β` field of the last
layer of `m`. Which loss is used depends on the global `G.implicit`:

- `G.implicit` true: `β` is passed through `sigmoid` and used to blend
  `softmax(m(x))` with `z`; the result is scored with a weighted cross-entropy
  against `y`.
- otherwise: `z .* β` is added to the raw output and the result is scored with
  a weighted squared error against `y`.

`w` holds the weights applied elementwise to the loss terms before summing.
"""
function model_loss(m, x, y, z, w)
    prediction = m(x)
    mix = m[end].β

    if !G.implicit
        # Squared-error branch: additive correction of the raw output.
        corrected = prediction + z .* mix
        return sum(w .* (corrected - y) .^ 2)
    end

    # Cross-entropy branch: squash the mixing coefficient, then blend the
    # softmax output with `z`.
    gate = sigmoid.(mix)
    blended = softmax(prediction) .* (1 .- gate) + z .* gate
    # Small offset so the logarithm never sees an exact zero.
    tiny = Float32(eps(Float64))
    return sum(w .* -y .* log.(blended .+ tiny))
end

"""
    uncalibrated_split_loss(m, split::String)

Return the weight-normalized loss of model `m` over the whole `split`.

The epoch obtained from `get_epoch(split)` is walked in batches of
`G.batch_size`. For each batch the summed `model_loss` and the sum of the
batch's last element (passed to `model_loss` as the weights `w`) are
accumulated, and the batch is released with `device_free!`. The result is
the total loss divided by the total weight, converted to `Float32`.
"""
function uncalibrated_split_loss(m, split::String)
    epoch = get_epoch(split)
    num_batches = ceil(Int, epoch_size(epoch) / G.batch_size)

    # Accumulate in Float64 and only narrow to Float32 at the very end.
    total_loss = 0.0
    total_weight = 0.0
    for batch_idx in 1:num_batches
        batch, _ = get_batch(epoch, batch_idx, G.batch_size, false)
        total_loss += model_loss(m, batch...)
        total_weight += sum(batch[end])
        device_free!(batch)
    end
    return Float32(total_loss / total_weight)
end;

In [None]:
# Re-fit the residualization (calibration) layers before computing the loss

"""
    train_calibration_layers(m, split::String; rng=Random.GLOBAL_RNG)

Return a copy of the chain `m` whose last layer is replaced by a fresh
`ScalarLayer(1)` followed by a fresh `StorageLayer(1)`, with those two layers
trained for one epoch on `split`.

The new final layer's `β` is multiplied in place by the `β` of the original
final layer before training. Only the parameters of the two new layers are
passed to the optimizer (`ADAM(0.01)`); `rng` is forwarded to `train_epoch!`.
"""
function train_calibration_layers(m, split::String; rng=Random.GLOBAL_RNG)
    prior_β = m[end].β

    # Build the replacement head on the device, in the same order as the chain.
    scalar_layer = device(ScalarLayer(1))
    storage_layer = device(StorageLayer(1))
    calibrated = Chain(m[1:end-1]..., scalar_layer, storage_layer)

    # Carry the original layer's β over into the new final layer.
    # NOTE(review): presumably a fresh StorageLayer starts with β == 1 so this
    # amounts to a copy; confirm against the StorageLayer definition.
    calibrated[end].β .*= prior_β

    # Only the two newly added layers are trainable here.
    trainable = Flux.params(calibrated[end-1:end])
    train_epoch!(calibrated, trainable, ADAM(0.01); split=split, rng=rng)
    return calibrated
end

"""
    train_calibration_layers(m, split::String, hyp::Hyperparams; rng=Random.GLOBAL_RNG)

Train the calibration layers of `m` on `split` using the hyperparameters `hyp`.

The model is moved with `device`, the global `G` is set to `hyp` for the
duration of the call, the two-argument method of `train_calibration_layers`
is invoked, and the resulting model is moved back with `cpu` and returned.

The global `G` is reset to `nothing` when this function exits, whether it
returns normally or training throws.
"""
function train_calibration_layers(m, split::String, hyp::Hyperparams; rng=Random.GLOBAL_RNG)
    m = m |> device
    global G = hyp
    try
        m = train_calibration_layers(m, split; rng=rng)
    finally
        # Always clear the global hyperparameters, even when training throws,
        # so a failed run cannot leak `hyp` into later code that reads `G`.
        global G = nothing
    end
    return m |> cpu
end

"""
    split_loss(m, split::String; rng=Random.GLOBAL_RNG)

Return the loss of `m` on `split` after re-training its calibration layers.

The calibration layers are first trained on `split` via
`train_calibration_layers` (with `rng` forwarded), and the resulting model is
then scored with `uncalibrated_split_loss` on the same split.
"""
function split_loss(m, split::String; rng=Random.GLOBAL_RNG)
    calibrated = train_calibration_layers(m, split; rng=rng)
    return uncalibrated_split_loss(calibrated, split)
end;