## Batching
* Turns an epoch into minibatches
* Each data point will appear in a minibatch with a probability proportional to its sampling weight

In [None]:
import StatsBase: sample, Weights

In [None]:
# Converts a RatingsDataset into a sparse matrix. Row indices come from
# `split.item`, column indices from `split.user` and stored values from
# `split.rating`; the result has num_items() rows and G.num_users columns.
function SparseArrays.sparse(split::RatingsDataset)
    rows, cols, vals = split.item, split.user, split.rating
    return sparse(rows, cols, vals, num_items(), G.num_users)
end;

In [None]:
# Picks the entries of the vector `x` located at the positions in `range`.
function slice(x::AbstractVector, range)
    return getindex(x, range)
end

# Picks the columns of the matrix `x` located at the positions in `range`;
# every row is kept.
function slice(x::AbstractMatrix, range)
    return getindex(x, :, range)
end;

In [None]:
# Produces the order in which users are visited during an epoch, as a
# collection of G.num_users indices drawn from 1:G.num_users.
# * The "training" split follows G.user_sampling_scheme; any other split
#   falls back to the "constant" scheme.
# * "constant" returns a random permutation of 1:G.num_users.
# * Every other scheme draws G.num_users indices with `sample`, weighted by
#   the split's counts after they are passed through `expdecay`.
#   NOTE(review): StatsBase's `sample` draws with replacement by default, so
#   a user may appear more than once (or not at all) — confirm this is intended.
function get_sampling_order(split)
    scheme = if split == "training"
        G.user_sampling_scheme
    else
        "constant"
    end
    scheme == "constant" && return shuffle(1:G.num_users)

    counts = get_counts(split, G.implicit; per_rating = false)
    weights = expdecay(counts, weighting_scheme(scheme))
    return sample(1:G.num_users, Weights(weights[1:G.num_users]), G.num_users)
end;

In [None]:
# holdout is a generalization of unscaled dropout for autoencoders
# we randomly drop some percentage of the input and then try to reconstruct the input
# we give a weight of 1 to items that were dropped and a weight of ϵ to items
# that were not dropped.
# dropout is the special case where ϵ = 1
# emphasized denoising is the special case where ϵ = 0

# Multiplies `x` elementwise by `mask`. Before multiplying, `mask` is tiled
# along its first dimension size(x, 1) ÷ size(mask, 1) times, so a mask that
# is shorter than `x` is reused for each consecutive group of rows.
function holdout(x, mask)
    tiles = size(x, 1) ÷ size(mask, 1)
    tiled_mask = repeat(mask, tiles)
    return x .* tiled_mask
end

# Applies holdout to a 4-element batch and returns a 4-tuple.
# One random mask of length num_items() is drawn per call; an item is marked
# as dropped when its random draw is <= `holdout_perc`.
# * batch[1] is multiplied by the complement of the mask (dropped items -> 0)
# * batch[2] and batch[3] are passed through untouched
# * batch[4] is multiplied by 1 for dropped items and by `identity_weight`
#   for items that were kept
function holdout(batch, holdout_perc, identity_weight)
    dropped = CUDA.rand(num_items()) .<= holdout_perc
    kept = 1 .- dropped
    weights = dropped + identity_weight .* kept
    return (holdout(batch[1], kept), batch[2], batch[3], holdout(batch[4], weights))
end

# Final per-batch transformation. When should_holdout_items(G.model) holds and
# `training` is true, the batch goes through `holdout` using the last entry of
# G.regularization_params as the holdout percentage and an identity weight
# of 0. In every other case the batch is returned as is.
function postprocess_batch(batch, training::Bool)
    apply_holdout = should_holdout_items(G.model) && training
    if !apply_holdout
        return batch
    end
    return holdout(batch, G.regularization_params[end], 0)
end;

In [None]:
# Builds the iter-th minibatch of an epoch:
# 1) take entries (iter-1)*batch_size+1 through iter*batch_size of
#    `sampling_order`, with the upper end clamped to G.num_users
# 2) slice every component of `epoch` at those indices and pass it through `device`
# 3) run the sliced components through `postprocess_batch`
# Returns a tuple of (one-element vector holding the batch, selected indices).
function get_batch(epoch, iter, batch_size, sampling_order, training::Bool)
    first_pos = (iter - 1) * batch_size + 1
    last_pos = min(iter * batch_size, G.num_users)
    batch_idxs = sampling_order[first_pos:last_pos]
    sliced = broadcast(x -> device(slice(x, batch_idxs)), epoch)
    return [postprocess_batch(sliced, training)], batch_idxs
end;

# Convenience method: builds the iter-th minibatch using the unshuffled
# order 1:G.num_users as the sampling order.
function get_batch(epoch, iter, batch_size, training::Bool)
    identity_order = 1:G.num_users
    return get_batch(epoch, iter, batch_size, identity_order, training)
end;