## Data Preprocessing
* An epoch is an efficient representation of all of the model's inputs, outputs, residualization, and weights
* We generate one epoch per split and memoize them

In [None]:
import DelimitedFiles: readdlm
import Statistics: mean, std
@nbinclude("../../Ensemble/ItemMetadata.ipynb");

### Building blocks

In [None]:
"""
    one_hot_inputs(implicit::Bool, num_users::Int)

Return the user indices `1:num_users` as a `Vector{Int32}`.

`implicit` is accepted for signature parity with the other input builders but is not
used here.
"""
function one_hot_inputs(implicit::Bool, num_users::Int)
    return collect(Int32, 1:num_users)
end;

In [None]:
"""
    explicit_inputs(num_users::Int, residual_alphas::Vector{String})

Build the sparse explicit-rating input from the `"training"` / `"explicit"` split.

Each rating has the corresponding `rating` returned by
`read_alpha(residual_alphas, "training", "explicit", false)` subtracted from it. The
resulting dataset is passed through `filter_users(_, num_users)` and converted with
`sparse`.
"""
function explicit_inputs(num_users::Int, residual_alphas::Vector{String})
    ratings = get_split("training", "explicit")
    baseline = read_alpha(residual_alphas, "training", "explicit", false)
    residualized = RatingsDataset(
        user = ratings.user,
        item = ratings.item,
        rating = ratings.rating .- baseline.rating,
    )
    return sparse(filter_users(residualized, num_users))
end;

In [None]:
"""
    explicit_validity_inputs(num_users::Int)

Build a sparse indicator input for the `"training"` / `"explicit"` split.

The dataset keeps the split's `user` and `item` fields and sets every rating to the
integer `1`, then goes through `filter_users(_, num_users)` and `sparse`.
"""
function explicit_validity_inputs(num_users::Int)
    ratings = get_split("training", "explicit")
    indicators = ones(Int, length(ratings.rating))
    observed = RatingsDataset(user = ratings.user, item = ratings.item, rating = indicators)
    return sparse(filter_users(observed, num_users))
end;

In [None]:
"""
    implicit_inputs(num_users::Int)

Load the `"training"` / `"implicit"` split, pass it through
`filter_users(_, num_users)`, and return it converted with `sparse`.
"""
function implicit_inputs(num_users::Int)
    kept = filter_users(get_split("training", "implicit"), num_users)
    return sparse(kept)
end;

In [None]:
"""
    explicit_implicit_inputs(num_users::Int, residual_alphas::Vector{String})

Vertically concatenate (`vcat`) the outputs of `explicit_inputs`,
`explicit_validity_inputs`, and `implicit_inputs`, in that order.
"""
function explicit_implicit_inputs(num_users::Int, residual_alphas::Vector{String})
    ratings = explicit_inputs(num_users, residual_alphas)
    validity = explicit_validity_inputs(num_users)
    interactions = implicit_inputs(num_users)
    return vcat(ratings, validity, interactions)
end;

In [None]:
"""
    explicit_implicit_tuple_inputs(num_users::Int, residual_alphas::Vector{String})

Return the outputs of `explicit_inputs`, `explicit_validity_inputs`, and
`implicit_inputs` as a 3-tuple, in that order, without concatenating them.
"""
function explicit_implicit_tuple_inputs(num_users::Int, residual_alphas::Vector{String})
    ratings = explicit_inputs(num_users, residual_alphas)
    validity = explicit_validity_inputs(num_users)
    interactions = implicit_inputs(num_users)
    return (ratings, validity, interactions)
end;

In [None]:
"""
    get_epoch_allitem_inputs(input_data, implicit, num_users, input_alphas)

Dispatch on the `input_data` name and build the corresponding user-level inputs.

# Arguments
- `input_data::String`: one of `"one_hot"`, `"implicit"`, `"explicit"`,
  `"explicit_implicit"`, or `"explicit_implicit_tuple"`.
- `implicit::Bool`: forwarded to `one_hot_inputs`; unused by the other builders.
- `num_users::Int`: forwarded to the selected builder.
- `input_alphas::Vector{String}`: forwarded as the residual alphas of the explicit
  builders.

# Returns
Whatever the selected builder returns.

# Throws
- `ArgumentError` if `input_data` is not one of the supported names.
"""
function get_epoch_allitem_inputs(
    input_data::String,
    implicit::Bool,
    num_users::Int,
    input_alphas::Vector{String},
)
    if input_data == "one_hot"
        return one_hot_inputs(implicit, num_users)
    elseif input_data == "implicit"
        return implicit_inputs(num_users)
    elseif input_data == "explicit"
        return explicit_inputs(num_users, input_alphas)
    elseif input_data == "explicit_implicit"
        return explicit_implicit_inputs(num_users, input_alphas)
    elseif input_data == "explicit_implicit_tuple"
        return explicit_implicit_tuple_inputs(num_users, input_alphas)
    else
        # An explicit throw (rather than `@assert false`) cannot be compiled out and
        # tells the caller which value was rejected.
        throw(
            ArgumentError(
                "unknown input_data \"$(input_data)\"; expected one of \"one_hot\", " *
                "\"implicit\", \"explicit\", \"explicit_implicit\", " *
                "\"explicit_implicit_tuple\"",
            ),
        )
    end
end;

### Dispatch

In [None]:
"""
    get_epoch_item_inputs(users, items, input_data, implicit, num_users, input_alphas)

Assemble the inputs for the `"item"` output mode.

Returns the 4-tuple `(user, user_inputs, item, item_inputs)` where
- `user` / `item` are the fields of the `(users, items)` dataset after
  `filter_users(_, num_users)`,
- `user_inputs` is the result of `get_epoch_allitem_inputs`,
- `item_inputs` is a tuple made of a `Float32` sparse identity of size `num_items()`,
  the splatted result of `get_neural_item_features()`, and the text and image embedding
  tables read from `processed_data/*.csv`.
"""
function get_epoch_item_inputs(
    users::Vector{Int32},
    items::Vector{Int32},
    input_data::String,
    implicit::Bool,
    num_users::Int,
    input_alphas::Vector{String},
)
    kept = filter_users(RatingsDataset(user = users, item = items), num_users)
    user_inputs = get_epoch_allitem_inputs(input_data, implicit, num_users, input_alphas)

    item_identity = convert.(Float32, sparse(LinearAlgebra.I(num_items())))
    item_metadata = get_neural_item_features()
    text_path = get_data_path("processed_data/text_embedding.csv")
    text_embedding = readdlm(text_path, ',', Float32)
    image_path = get_data_path("processed_data/image_embedding.csv")
    image_embedding = readdlm(image_path, ',', Float32)

    item_inputs = (item_identity, item_metadata..., text_embedding, image_embedding)
    return (kept.user, user_inputs, kept.item, item_inputs)
end;

In [None]:
"""
    get_transformed_output_data(df, output_data::String)

Convert a ratings dataset into the layout required by the given output mode.

# Arguments
- `df`: dataset; must support `sparse(df)` for `"allitems"` and expose a `rating`
  field for `"item"`.
- `output_data::String`: `"allitems"` or `"item"`.

# Returns
- `"allitems"`: `sparse(df)`.
- `"item"`: the adjoint of `df.rating`, materialized with `collect` (a single-row
  matrix when `df.rating` is a vector).

# Throws
- `ArgumentError` if `output_data` is neither `"allitems"` nor `"item"`.
"""
function get_transformed_output_data(df, output_data::String)
    if output_data == "allitems"
        return sparse(df)
    elseif output_data == "item"
        return collect(df.rating')
    else
        # An explicit throw (rather than `@assert false`) cannot be compiled out and
        # tells the caller which value was rejected.
        throw(
            ArgumentError(
                "unknown output_data \"$(output_data)\"; expected \"allitems\" or \"item\"",
            ),
        )
    end
end;

In [None]:
"""
    get_epoch_inputs(users, items, input_data, output_data, implicit, num_users, input_alphas)

Build the model inputs for already-loaded `users` / `items` vectors.

# Arguments
- `users::Vector`, `items::Vector`: only used in `"item"` mode, where they are forwarded
  to `get_epoch_item_inputs`.
- `input_data::String`: input flavour, forwarded to the selected builder.
- `output_data::String`: `"allitems"` or `"item"`; selects the builder.
- `implicit::Bool`, `num_users::Int`, `input_alphas::Vector{String}`: forwarded
  unchanged.

# Returns
- `"allitems"`: the result of `get_epoch_allitem_inputs`.
- `"item"`: the result of `get_epoch_item_inputs`.

# Throws
- `ArgumentError` if `output_data` is neither `"allitems"` nor `"item"`.
"""
function get_epoch_inputs(
    users::Vector,
    items::Vector,
    input_data::String,
    output_data::String,
    implicit::Bool,
    num_users::Int,
    input_alphas::Vector{String},
)
    if output_data == "allitems"
        return get_epoch_allitem_inputs(input_data, implicit, num_users, input_alphas)
    elseif output_data == "item"
        return get_epoch_item_inputs(
            users,
            items,
            input_data,
            implicit,
            num_users,
            input_alphas,
        )
    else
        # An explicit throw (rather than `@assert false`) cannot be compiled out and
        # tells the caller which value was rejected.
        throw(
            ArgumentError(
                "unknown output_data \"$(output_data)\"; expected \"allitems\" or \"item\"",
            ),
        )
    end
end

# Memoized method (LRU cache created with `maxsize = 2`): loads the `(split, content)`
# dataset via `get_split` and forwards its `user` / `item` fields, together with the
# remaining arguments, to the vector-based `get_epoch_inputs` method.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_inputs(
    split::String,
    input_data::String,
    output_data::String,
    content::String,
    implicit::Bool,
    num_users::Int,
    input_alphas::Vector{String},
)
    rows = get_split(split, content)
    forwarded = (input_data, output_data, implicit, num_users, input_alphas)
    return get_epoch_inputs(rows.user, rows.item, forwarded...)
end

# Memoized (LRU cache created with `maxsize = 2`): loads the `(split, content)` dataset,
# passes it through `filter_users(_, num_users)`, and lays it out for `output_data` via
# `get_transformed_output_data`. `implicit` is not read in the body; it only takes part
# in the call signature.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_outputs(
    split::String,
    output_data::String,
    content::String,
    implicit::Bool,
    num_users::Int,
)
    rows = get_split(split, content)
    kept = filter_users(rows, num_users)
    return get_transformed_output_data(kept, output_data)
end

# Memoized (LRU cache created with `maxsize = 2`): reads the dataset returned by
# `read_alpha(residual_alphas, split, content, implicit)`, passes it through
# `filter_users(_, num_users)`, and lays it out for `output_data` via
# `get_transformed_output_data`.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_residuals(
    split::String,
    output_data::String,
    content::String,
    residual_alphas::Vector{String},
    implicit::Bool,
    num_users::Int,
)
    alpha = read_alpha(residual_alphas, split, content, implicit)
    kept = filter_users(alpha, num_users)
    return get_transformed_output_data(kept, output_data)
end

# Memoized (LRU cache created with `maxsize = 2`): builds the weights for a split and
# lays them out for `output_data`.
#
# * For the "training" split the weight is the elementwise product of
#   `powerdecay(get_counts(split, content), user_weight_decay)` and
#   `powerdecay(get_counts(split, content; by_item = true), item_weight_decay)`.
# * For any other split the weight is
#   `powerdecay(get_counts(split, content), weighting_scheme("inverse"))`; the two decay
#   arguments are ignored.
#
# The weights are stored as the `rating` field of a `RatingsDataset` that reuses the
# split's `user` / `item` fields, passed through `filter_users(_, num_users)`, and
# converted with `get_transformed_output_data`.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_weights(
    split::String,
    output_data::String,
    content::String,
    user_weight_decay::Union{String,Real},
    item_weight_decay::Union{String,Real},
    num_users::Int,
)
    if split == "training"
        weights =
            powerdecay(get_counts(split, content), user_weight_decay) .*
            powerdecay(get_counts(split, content; by_item = true), item_weight_decay)
    else
        weights = powerdecay(get_counts(split, content), weighting_scheme("inverse"))
    end

    df = get_split(split, content)
    df = filter_users(RatingsDataset(user=df.user, item=df.item, rating=weights), num_users)
    return get_transformed_output_data(df, output_data)
end;

In [None]:
# Assemble the epoch for `split` from the global configuration `G`.
# Returns (X, Y, Z, W) = (inputs, outputs, residualization alpha, weights).
function get_epoch(split::String)
    output_data, content, num_users = G.output_data, G.content, G.num_users

    inputs = get_epoch_inputs(
        split, G.input_data, output_data, content, G.implicit, num_users, G.input_alphas,
    )
    outputs = get_epoch_outputs(split, output_data, content, G.implicit, num_users)
    residuals = get_epoch_residuals(
        split, output_data, content, G.residual_alphas, G.implicit, num_users,
    )
    weights = get_epoch_weights(
        split, output_data, content, G.user_weight_decay, G.item_weight_decay, num_users,
    )
    return (inputs, outputs, residuals, weights)
end;

In [None]:
"""
    epoch_size(epoch)

Return the number of samples in `epoch`, depending on `G.output_data`.

- `"allitems"`: the size of the last dimension of the inputs `epoch[1]`; when the inputs
  are a tuple, its first element is measured.
- `"item"`: the length of the first element of the inputs, `epoch[1][1]`.

Any other `G.output_data` trips an assertion.
"""
function epoch_size(epoch)
    mode = G.output_data
    if mode == "item"
        return length(epoch[1][1])
    elseif mode == "allitems"
        inputs = epoch[1]
        head = inputs isa Tuple ? inputs[1] : inputs
        return last(size(head))
    else
        @assert false
    end
end;