## Data Preprocessing
* An epoch is an efficient representation of all of the model's inputs, outputs, residualization, and weights
* We generate one epoch per split and memoize them

In [None]:
# Return the user ids 1, 2, ..., num_users as a Vector{Int32}.
# `implicit` is accepted but not used by this function.
function one_hot_inputs(implicit, num_users)
    return collect(Int32, 1:num_users)
end;

In [None]:
# Build the explicit-rating inputs from the "training" split (implicit = false).
# Each rating has the matching `read_alpha(residual_alphas, "training", false)`
# rating subtracted from it; the result is restricted with `filter_users` and
# converted with `sparse`.
function explicit_inputs(num_users, residual_alphas)
    training = get_split("training", false)
    baseline = read_alpha(residual_alphas, "training", false)
    residualized =
        RatingsDataset(training.user, training.item, training.rating .- baseline.rating)
    return sparse(filter_users(residualized, num_users))
end;

In [None]:
# Build an indicator dataset for the "training" split (implicit = false): every
# (user, item) pair present in the split gets the value 1. The result is
# restricted with `filter_users` and converted with `sparse`.
# `residual_alphas` is accepted but not used by this function.
function explicit_validity_inputs(num_users, residual_alphas)
    training = get_split("training", false)
    indicator = ones(Int, length(training.rating))
    present = RatingsDataset(training.user, training.item, indicator)
    return sparse(filter_users(present, num_users))
end;

In [None]:
# Load the "training" split with implicit = true, restrict it with
# `filter_users`, and convert it with `sparse`.
function implicit_inputs(num_users)
    training = get_split("training", true)
    kept = filter_users(training, num_users)
    return sparse(kept)
end;

In [None]:
# Vertically concatenate, in this order: the explicit inputs, the explicit
# validity indicators, and the implicit inputs.
function explicit_implicit_inputs(num_users, residual_alphas)
    ratings = explicit_inputs(num_users, residual_alphas)
    validity = explicit_validity_inputs(num_users, residual_alphas)
    interactions = implicit_inputs(num_users)
    return vcat(ratings, validity, interactions)
end;

In [None]:
# Build the model inputs selected by `input_data`:
#   "one_hot"           -> one_hot_inputs(implicit, num_users)
#   "implicit"          -> implicit_inputs(num_users)
#   "explicit"          -> explicit_inputs(num_users, input_alphas)
#   "explicit_implicit" -> explicit_implicit_inputs(num_users, input_alphas)
# Throws an ArgumentError naming the offending value for any other `input_data`.
function get_epoch_allitem_inputs(
    input_data,
    implicit,
    num_users,
    input_alphas,
)
    if input_data == "one_hot"
        return one_hot_inputs(implicit, num_users)
    elseif input_data == "implicit"
        return implicit_inputs(num_users)
    elseif input_data == "explicit"
        return explicit_inputs(num_users, input_alphas)
    elseif input_data == "explicit_implicit"
        return explicit_implicit_inputs(num_users, input_alphas)
    end
    # An explicit throw (rather than `@assert false`) cannot be compiled out and
    # tells the caller which value was rejected.
    throw(ArgumentError("unknown input_data: $(repr(input_data))"))
end;

In [None]:
# Build the inputs for per-item outputs. Returns `(X, (U, A))` where
#   X = (user, item) columns of the requested split after `filter_users`,
#   U = the inputs produced by `get_epoch_allitem_inputs`,
#   A = a sparse Float32 identity matrix of size num_items().
#
# `split` selects which split the (user, item) pairs are read from. It is a
# keyword so existing four-argument calls keep working.
# NOTE(review): the default of "training" is an assumption; callers that pair X
# with the outputs of another split should pass that split explicitly — confirm.
function get_epoch_item_inputs(
    input_data,
    implicit,
    num_users,
    input_alphas;
    split = "training",
)
    df = filter_users(get_split(split, implicit), num_users)
    X = (df.user, df.item)
    U = get_epoch_allitem_inputs(input_data, implicit, num_users, input_alphas)
    A = convert.(Float32, sparse(LinearAlgebra.I(num_items())))
    return (X, (U, A))
end

In [None]:
# Build the epoch inputs for the requested output layout:
#   "allitems" -> get_epoch_allitem_inputs(...)
#   "item"     -> get_epoch_item_inputs(...)
# Results are memoized in an LRU cache holding at most 2 entries.
# Throws an ArgumentError naming the offending value for any other `output_data`.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_inputs(
    input_data,
    output_data,
    implicit,
    num_users,
    input_alphas,
)
    if output_data == "allitems"
        return get_epoch_allitem_inputs(input_data, implicit, num_users, input_alphas)
    elseif output_data == "item"
        return get_epoch_item_inputs(input_data, implicit, num_users, input_alphas)
    end
    # An explicit throw (rather than `@assert false`) cannot be compiled out and
    # tells the caller which value was rejected.
    throw(ArgumentError("unknown output_data: $(repr(output_data))"))
end

# Build the epoch outputs for `split`:
#   "allitems" -> the filtered split converted with `sparse`
#   "item"     -> the `rating` column of the filtered split
# Results are memoized in an LRU cache holding at most 2 entries.
# Throws an ArgumentError naming the offending value for any other `output_data`.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_outputs(split, output_data, implicit, num_users)
    # Validate before loading so a bad `output_data` fails without touching the data.
    if !(output_data == "allitems" || output_data == "item")
        throw(ArgumentError("unknown output_data: $(repr(output_data))"))
    end
    df = filter_users(get_split(split, implicit), num_users)
    return output_data == "allitems" ? sparse(df) : df.rating
end

# Build the residualization values for `split` from
# `read_alpha(residual_alphas, split, implicit)`, restricted with `filter_users`:
#   "allitems" -> the filtered dataset converted with `sparse`
#   "item"     -> the `rating` column of the filtered dataset
# Results are memoized in an LRU cache holding at most 2 entries.
# Throws an ArgumentError naming the offending value for any other `output_data`.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_residuals(
    split,
    output_data,
    residual_alphas,
    implicit,
    num_users,
)
    # Validate before loading so a bad `output_data` fails without touching the data.
    if !(output_data == "allitems" || output_data == "item")
        throw(ArgumentError("unknown output_data: $(repr(output_data))"))
    end
    df = filter_users(read_alpha(residual_alphas, split, implicit), num_users)
    return output_data == "allitems" ? sparse(df) : df.rating
end

# Build the per-entry weights for `split`.
#
# `output_data` selects both how counts are requested and the final layout:
#   "allitems" -> get_counts(...; per_rating = false), result converted with `sparse`
#   "item"     -> get_counts(...; per_rating = true), result is the `rating` column
#
# For the "training" split the weight is the product of `powerdecay` applied to
# the default counts with `user_weight_decay` and to the `by_item = true` counts
# with `item_weight_decay`. For every other split the default counts are decayed
# with `weighting_scheme("inverse")`.
#
# Results are memoized in an LRU cache holding at most 2 entries.
# Throws an ArgumentError naming the offending value for any other `output_data`.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_weights(
    split,
    output_data,
    user_weight_decay,
    item_weight_decay,
    implicit,
    num_users,
)
    if output_data == "allitems"
        activation = sparse
        per_rating = false
    elseif output_data == "item"
        activation = x -> x.rating
        per_rating = true
    else
        # An explicit throw (rather than `@assert false`) cannot be compiled out,
        # so `activation` and `per_rating` are always defined below.
        throw(ArgumentError("unknown output_data: $(repr(output_data))"))
    end

    # The default counts are needed by both branches.
    counts = get_counts(split, implicit; per_rating = per_rating)
    if split == "training"
        item_counts = get_counts(split, implicit; per_rating = per_rating, by_item = true)
        weights =
            powerdecay(counts, user_weight_decay) .*
            powerdecay(item_counts, item_weight_decay)
    else
        weights = powerdecay(counts, weighting_scheme("inverse"))
    end

    # Reuse the split's (user, item) pairs, storing the weight in the rating slot.
    df = get_split(split, implicit)
    weighted = filter_users(RatingsDataset(df.user, df.item, weights), num_users)
    return activation(weighted)
end;

In [None]:
# Assemble one epoch for `split` from the settings stored in the global `G`.
# Returns the tuple (X, Y, Z, W) = (inputs, outputs, residualization alpha, weights).
function get_epoch(split)
    inputs = get_epoch_inputs(
        G.input_data,
        G.output_data,
        G.implicit,
        G.num_users,
        G.input_alphas,
    )
    outputs = get_epoch_outputs(split, G.output_data, G.implicit, G.num_users)
    residuals = get_epoch_residuals(
        split,
        G.output_data,
        G.residual_alphas,
        G.implicit,
        G.num_users,
    )
    weights = get_epoch_weights(
        split,
        G.output_data,
        G.user_weight_decay,
        G.item_weight_decay,
        G.implicit,
        G.num_users,
    )
    return (inputs, outputs, residuals, weights)
end;