## Batching
* Turns an epoch into minibatches
* During training, users are drawn (with replacement) with probability proportional to their sampling weight
* Within each training minibatch, a random subset of items is masked out of the inputs and the model is asked to predict those masked items

In [None]:
import StatsBase: wsample, Weights

### Sample users to put in a minibatch

In [None]:
"""
    get_sampling_order(epoch, split::String, rng)

Return the vector of user indices that minibatches are cut from for `split`.

* `"training"`: `epoch_size(epoch)` indices drawn from `1:epoch_size(epoch)`
  using `wsample` (which samples with replacement by default). The weights are
  `powerdecay` applied to the split's counts under `G.user_sampling_scheme`.
* any other split: every distinct user of the filtered split, once each, in
  `Set` iteration order (NOTE(review): that order is arbitrary, not sorted).
"""
function get_sampling_order(epoch, split::String, rng)
    if split != "training"
        split_users = filter_users(
            get_split(split, G.task, G.content, G.medium),
            G.num_users,
        ).user
        return collect(Set(split_users))
    end

    counts = get_counts(split, "all", G.content, G.medium; per_rating = false)
    user_weights = powerdecay(counts, weighting_scheme(G.user_sampling_scheme))
    n = epoch_size(epoch)
    return wsample(rng, 1:n, user_weights[1:n], n)
end;

In [None]:
# `slice(x, range)` keeps only the entries selected by `range`:
# - `nothing` passes through unchanged;
# - vectors are indexed directly;
# - matrices are indexed along their second dimension (columns);
# - tuples are sliced element by element.
slice(::Nothing, range) = nothing
slice(x::AbstractVector, range) = getindex(x, range)
slice(x::AbstractMatrix, range) = getindex(x, :, range)
slice(x::Tuple, range) = map(elem -> slice(elem, range), x);

### Mask out items within a minibatch

In [None]:
# perform emphasized denoising on each minibatch

"""
    holdout(x, mask)

Multiply `x` elementwise by `mask`, so entries where the mask is `false`/`0`
become zero.

`mask` is tiled vertically `size(x, 1) ÷ size(mask, 1)` times so that its rows
line up with `x`. For a tuple `x`, the same mask is applied to every element.
"""
function holdout(x, mask)
    copies = div(size(x, 1), size(mask, 1))
    return x .* repeat(mask, copies)
end

holdout(x::Tuple, mask) = map(part -> holdout(part, mask), x)

"""
    holdout_allitems(batch, training::Bool, rng)

Randomly mask the items of a minibatch for denoising training.

`batch` is indexed as a 4-element tuple. When `training` is `false` it is
returned unchanged.

Otherwise, a uniform random matrix of size `(num_items(medium), batch_size)` is
drawn per medium. Entries whose draw is below that medium's probability in
`G.holdout` are dropped.
- `batch[1]` (the inputs) is multiplied by the mask of kept entries.
- `batch[4]` is multiplied by the mask of dropped entries for `G.medium`.
The model is thus asked to predict exactly the items it could not see.

`rng` is only used on the CPU path. When `CUDA.functional()`, masks are drawn
with `CUDA.rand`, which does not consume `rng` (NOTE(review): GPU runs are
therefore not reproducible from `rng`).

Returns `(masked batch[1], batch[2], batch[3], masked batch[4])`.
Raises an error if `G.model` is neither a "universal" nor an "autoencoder"
model, or if `G.holdout` has the wrong length for the model.
"""
function holdout_allitems(batch, training::Bool, rng)
    training || return batch

    # Uniform draws of shape (rows, cols). The CPU fallback must accept both
    # dimensions; a one-argument lambda fails with a MethodError here.
    randfn = CUDA.functional() ? CUDA.rand : (dims...) -> rand(rng, dims...)
    batch_size = size(batch[4], 2)

    # Drop each item independently with the configured holdout probability.
    if startswith(G.model, "universal")
        length(G.holdout) == length(ALL_MEDIUMS) || error(
            "G.holdout must have one entry per medium ($(length(ALL_MEDIUMS))), " *
            "got $(length(G.holdout))",
        )
        media_masks = [
            randfn(num_items(m), batch_size) .>= p for (m, p) in zip(ALL_MEDIUMS, G.holdout)
        ]
        entries_to_keep = reduce(vcat, media_masks)
        medium_idx = findfirst(==(G.medium), ALL_MEDIUMS)
        isnothing(medium_idx) && error("medium $(G.medium) is not in ALL_MEDIUMS")
        # Targets are exactly the entries hidden from the input for this medium.
        entries_to_predict = .!media_masks[medium_idx]
    elseif startswith(G.model, "autoencoder")
        length(G.holdout) == 1 ||
            error("G.holdout must have exactly one entry, got $(length(G.holdout))")
        entries_to_keep = randfn(num_items(G.medium), batch_size) .>= G.holdout[1]
        entries_to_predict = .!entries_to_keep
    else
        error("unsupported model $(G.model)")
    end

    return (
        holdout(batch[1], entries_to_keep),
        batch[2],
        batch[3],
        holdout(batch[4], entries_to_predict),
    )
end;

### Construct a minibatch from an epoch

In [None]:
"""
    get_batch(epoch, iter, batch_size, sampling_order, training, rng = Random.GLOBAL_RNG)

Build the `iter`-th (1-based) minibatch of `epoch`.

The minibatch covers positions `(iter - 1) * batch_size + 1` through
`min(iter * batch_size, length(sampling_order))` of `sampling_order`.
- Every component of `epoch` is cut down to those users with `slice`.
- Each component is then passed through `device`.
- `holdout_allitems` applies the training-time masking.

Returns `(batch, users)`:
- `batch` is the first four components after masking.
- `users` is the selected piece of `sampling_order`.
"""
function get_batch(
    epoch,
    iter::Int,
    batch_size::Int,
    sampling_order,
    training::Bool,
    rng = Random.GLOBAL_RNG,
)
    first_pos = (iter - 1) * batch_size + 1
    last_pos = min(iter * batch_size, length(sampling_order))
    users = sampling_order[first_pos:last_pos]
    to_device(x) = device(slice(x, users))
    masked = holdout_allitems(broadcast(to_device, epoch), training, rng)
    return (masked[1], masked[2], masked[3], masked[4]), users
end;