## Batching
* Turns an epoch into minibatches
* Each user is sampled into a minibatch with probability proportional to its sampling weight
* Items within a minibatch can be masked out so that the model learns to predict them

In [None]:
import StatsBase: wsample, Weights

### Sample users to put in a minibatch

In [None]:
"""
    get_sampling_order(epoch, split, rng)

Return the user ids to visit for one pass over `epoch`.

Constant scheme (any non-training split, or training with
`G.user_sampling_scheme == "constant"`):
- Every distinct user in `get_split(split, task, G.content)` is returned exactly
  once, after `filter_users`.
- The order is unspecified (it is the iteration order of a `Set`).

Other schemes:
- `epoch_size(epoch)` ids are drawn from `1:epoch_size(epoch)` with
  `StatsBase.wsample`, with replacement (its default).
- The draw is weighted by `powerdecay` of the per-rating counts.

Temporal batching (training split, when `should_temporal_batch(G.model)`):
- Users with no timestamps at or past `G.temporal_holdout` get zero weight.
- The number of draws is capped at the number of positive weights.
"""
function get_sampling_order(epoch, split::String, rng)
    scheme = split == "training" ? G.user_sampling_scheme : "constant"
    task = split == "training" ? "all" : G.task
    if scheme == "constant"
        return collect(
            Set(filter_users(get_split(split, task, G.content), G.num_users).user),
        )
    else
        weights = powerdecay(
            get_counts(split, task, G.content; per_rating = true),
            weighting_scheme(scheme),
        )
        N = epoch_size(epoch)
        samples = N
        if should_temporal_batch(G.model) && split == "training"
            # Use the same content source as every other lookup in this function.
            # The previous code referenced an undefined local `content`.
            weights = get_temporal_sampling_order(weights, true, split, task, G.content)
            samples = min(samples, count(>(0), weights))
        end
        return wsample(rng, 1:N, weights[1:N], samples)
    end
end;

# Return the timestamps used to restrict temporal sampling; results are cached by
# `@memoize` on the argument values.
#   per_rating = true:  the `:timestamp` column of `get_split(split, task, content)`,
#                       one entry per rating row.
#   per_rating = false: a vector of length `num_users()`. Entry `u` is the largest
#                       timestamp among user `u`'s rows. It is 0.0 if the user has
#                       no rows, or if all of that user's timestamps are <= 0.
@memoize function get_sampling_timestamps(
    split::String,
    task::String,
    content::String,
    per_rating::Bool,
)
    if per_rating
        timestamps = get_split(split, task, content, fields = [:timestamp]).timestamp
    else
        df = get_split(split, task, content, fields = [:user, :timestamp])
        # One column per thread so threads never write to the same slot; the
        # columns are reduced with `maximum` below.
        # NOTE(review): indexing columns by `Threads.threadid()` assumes
        # threadid() <= Threads.nthreads(). It also assumes a task never migrates
        # threads mid-iteration. Neither is guaranteed with interactive threads
        # (Julia >= 1.9) or :dynamic scheduling. Consider `Threads.maxthreadid()`
        # or a per-chunk reduction — confirm against how `@tprogress` expands.
        timestamps = zeros(num_users(), Threads.nthreads())
        @tprogress Threads.@threads for i = 1:length(df.user)
            if df.timestamp[i] > timestamps[df.user[i], Threads.threadid()]
                timestamps[df.user[i], Threads.threadid()] = df.timestamp[i]
            end
        end
        # Per-user maximum across the thread-local columns.
        timestamps = vec(maximum(timestamps; dims = 2))
    end
    timestamps
end

"""
    get_temporal_sampling_order(weights, per_rating, split, task, content)

Return a new array equal to `weights`, with every entry zeroed whose timestamp
is earlier than `G.temporal_holdout`. Timestamps come from
`get_sampling_timestamps(split, task, content, per_rating)`. This removes users
with nothing past the temporal holdout from sampling.
"""
function get_temporal_sampling_order(
    weights,
    per_rating::Bool,
    split::String,
    task::String,
    content::String,
)
    ts = get_sampling_timestamps(split, task, content, per_rating)
    reaches_holdout = ts .>= G.temporal_holdout
    return weights .* reaches_holdout
end

In [None]:
slice(x::Nothing, range) = nothing
slice(x::AbstractVector, range) = x[range]
slice(x::AbstractMatrix, range) = x[:, range]
slice(x::Tuple, range) = slice.(x, (range,));

### Mask out items within a minibatch

In [None]:
# perform emphasized denoising and data augmentation on each minibatch

"""
    holdout(x, mask)

Multiply `x` elementwise by `mask`, zeroing entries where `mask` is zero.

`mask` is tiled along the first dimension `size(x, 1) ÷ size(mask, 1)` times. This
lets a mask with fewer rows cover an `x` built from several stacked blocks of
`size(mask, 1)` rows. When `x` is a tuple, each element is masked independently
and a tuple is returned.
"""
function holdout(x, mask)
    n_copies = div(size(x, 1), size(mask, 1))
    tiled = repeat(mask, n_copies)
    return x .* tiled
end

holdout(x::Tuple, mask) = map(el -> holdout(el, mask), x)

"""
    holdout_allitems(batch, holdout_perc, temporal_perc, training, rng)

Mask items in a training minibatch so the model must reconstruct them.

When `training` is false, `batch` is returned unchanged. Otherwise a 4-tuple is
returned:
`(masked batch[1], batch[2], batch[3], prediction-masked batch[4])`.

- If `temporal_perc` is `NaN`: each (item, user) entry is dropped from the input
  `batch[1]` with probability `holdout_perc`. Exactly the dropped entries are
  kept in the target `batch[4]`.
- Otherwise: input entries are kept only if `batch[5] .<= temporal_perc`, and
  they also survive the same random `holdout_perc` drop. The target keeps only
  the entries with `batch[5] .> temporal_perc`.

Shape assumptions (TODO confirm against callers):
- `batch[4]` has `size(batch[4], 2) == batch size`.
- `batch[4]` has a row count that is a multiple of `num_items()`.
- `batch[5]` (presumably per-entry timestamps) matches the
  `num_items() × batch size` mask shape.

Random numbers come from `CUDA.rand` when CUDA is functional; `rng` is ignored in
that case. Otherwise they come from `rng`.
"""
function holdout_allitems(batch, holdout_perc::Real, temporal_perc::Real, training::Bool, rng)
    training || return batch

    # The CPU fallback must accept the same (rows, cols) call as `CUDA.rand`.
    # A single-argument lambda here raised a MethodError on machines without a GPU.
    randfn = CUDA.functional() ? CUDA.rand : (dims...) -> rand(rng, dims...)
    batch_size = size(batch[4], 2)

    # randomly drop holdout_perc of the entries from each user's list
    random_keep = randfn(num_items(), batch_size) .>= holdout_perc

    if isnan(temporal_perc)
        entries_to_keep = random_keep
        entries_to_predict = 1 .- random_keep
    else
        # use entries up to temporal_perc as input and predict the later ones
        temporal_keep = batch[5] .<= temporal_perc
        entries_to_keep = random_keep .* temporal_keep
        entries_to_predict = 1 .- temporal_keep
    end

    return (
        holdout(batch[1], entries_to_keep),
        batch[2],
        batch[3],
        holdout(batch[4], entries_to_predict),
    )
end

"""
    postprocess_batch(batch, training, rng)

Apply item holdout to a minibatch when the current model uses it.

If `should_holdout_items(G.model)` is false, `batch` is returned as-is. Otherwise
the work is delegated to `holdout_allitems` with `G.holdout` and
`G.temporal_holdout`.
"""
function postprocess_batch(batch, training::Bool, rng)
    should_holdout_items(G.model) || return batch
    return holdout_allitems(batch, G.holdout, G.temporal_holdout, training, rng)
end;

### Construct a minibatch from an epoch

In [None]:
"""
    get_batch(epoch, iter, batch_size, sampling_order, training, rng = Random.GLOBAL_RNG)

Build the `iter`-th (1-based) minibatch of `epoch`.

Steps:
1. Take the `iter`-th window of at most `batch_size` consecutive entries of
   `sampling_order`. The final window may be shorter.
2. Slice every field of `epoch` at those entries with `slice`.
3. Move each sliced field with `device`.
4. Pass the result through `postprocess_batch`.

Returns `(fields, indices)`. `fields` holds the first four processed fields and
`indices` are the selected entries of `sampling_order`.
"""
function get_batch(
    epoch,
    iter::Int,
    batch_size::Int,
    sampling_order,
    training::Bool,
    rng = Random.GLOBAL_RNG,
)
    first_pos = (iter - 1) * batch_size + 1
    last_pos = min(iter * batch_size, length(sampling_order))
    idx = sampling_order[first_pos:last_pos]

    to_device(field) = device(slice(field, idx))
    processed = postprocess_batch(to_device.(epoch), training, rng)

    fields = (processed[1], processed[2], processed[3], processed[4])
    return fields, idx
end;

"""
    get_batch(epoch, iter, batch_size, training, rng = Random.GLOBAL_RNG)

Build the `iter`-th minibatch of `epoch` in its natural order. This is
equivalent to passing `1:epoch_size(epoch)` as the sampling order.
"""
function get_batch(epoch, iter::Int, batch_size::Int, training::Bool, rng = Random.GLOBAL_RNG)
    natural_order = 1:epoch_size(epoch)
    return get_batch(epoch, iter, batch_size, natural_order, training, rng)
end;