## Batching
* Turns an epoch into minibatches
* Each data point will appear in a minibatch with a probability proportional to its sampling weight

In [None]:
import StatsBase: sample, Weights

In [14]:
"""
    SparseArrays.sparse(split::RatingsDataset)

Assemble the ratings in `split` into a sparse matrix.

`split.item` supplies the row indices, `split.user` the column indices and
`split.rating` the stored values. The result has `num_items()` rows and
`G.num_users` columns.
"""
function SparseArrays.sparse(split::RatingsDataset)
    rows, cols, vals = split.item, split.user, split.rating
    return sparse(rows, cols, vals, num_items(), G.num_users)
end;

In [15]:
"""
    slice(x::AbstractVector, range)

Return the elements of `x` selected by `range`.
"""
slice(x::AbstractVector, range) = x[range]

"""
    slice(x::AbstractMatrix, range)

Return the columns of `x` selected by `range`, keeping every row.
"""
slice(x::AbstractMatrix, range) = x[:, range];

In [16]:
"""
    get_sampling_order(split)

Return `G.num_users` user indices giving the order in which users are visited
during an epoch.

The sampling scheme is `G.user_sampling_scheme` when `split == "training"` and
`"constant"` for every other split.

- `"constant"`: a random permutation of `1:G.num_users`.
- anything else: indices are drawn with `sample`, weighted by
  `expdecay(get_counts(split, G.implicit; per_rating = false), weighting_scheme(scheme))`.
  Only the first `G.num_users` weights are used. `sample` is called with its
  default settings, so the same index may be drawn more than once.
"""
function get_sampling_order(split)
    scheme = if split == "training"
        G.user_sampling_scheme
    else
        "constant"
    end

    # Unweighted case: every user exactly once, in random order.
    scheme == "constant" && return shuffle(1:G.num_users)

    counts = get_counts(split, G.implicit; per_rating = false)
    weights = expdecay(counts, weighting_scheme(scheme))
    user_weights = Weights(weights[1:G.num_users])
    return sample(1:G.num_users, user_weights, G.num_users)
end;

In [17]:
"""
    get_batch(epoch, iter, batch_size, sampling_order)

Build the `iter`-th minibatch of `epoch`.

1. Take positions `(iter - 1) * batch_size + 1` through
   `min(iter * batch_size, G.num_users)` of `sampling_order`; these are the
   indices that make up the minibatch.
2. Apply `slice` with those indices to every element of `epoch` and pass each
   result through `device`.

Returns a tuple `(batch, indices)`, where `batch` is a one-element vector
wrapping the processed `epoch` and `indices` are the selected entries of
`sampling_order`.
"""
function get_batch(epoch, iter, batch_size, sampling_order)
    first_pos = (iter - 1) * batch_size + 1
    # The final minibatch may be shorter than batch_size.
    last_pos = min(iter * batch_size, G.num_users)
    batch_idx = sampling_order[first_pos:last_pos]

    to_batch(x) = device(slice(x, batch_idx))
    return ([to_batch.(epoch)], batch_idx)
end;

"""
    get_batch(epoch, iter, batch_size)

Build the `iter`-th minibatch of `epoch` using the natural order
`1:G.num_users` as the sampling order, i.e. without any reordering.
"""
get_batch(epoch, iter, batch_size) = get_batch(epoch, iter, batch_size, 1:G.num_users);