## Data Preprocessing
* An epoch is an efficient representation of all the model's inputs, outputs, residualization, and weights
* We generate one epoch per split and memoize them

In [None]:
import DelimitedFiles: readdlm
import Statistics: mean, std
@nbinclude("../../Ensemble/ItemMetadata.ipynb");

### RatingsDataset manipulation

In [None]:
# Convert a RatingsDataset into a sparse matrix with one row per item and one
# column per user. `field` selects which field of `x` supplies the stored values.
function SparseArrays.sparse(x::RatingsDataset; field::Symbol = :rating)
    # Use the user count from the global `G` when it is defined and not `nothing`;
    # otherwise fall back to num_users().
    has_global = (@isdefined G) && !isnothing(G)
    n_users = has_global ? G.num_users : num_users()
    return sparse(x.item, x.user, getfield(x, field), num_items(), n_users)
end;

### Building blocks

In [None]:
# Return the user ids 1, 2, ..., num_users as a Vector{Int32}.
# `implicit` is accepted but not used by the body.
function one_hot_inputs(implicit::Bool, num_users::Int)
    return collect(Int32, 1:num_users)
end;

In [None]:
# Sparse matrix of the explicit training ratings after subtracting the ratings
# returned by read_alpha for `residual_alphas`, restricted via filter_users.
function explicit_inputs(num_users::Int, residual_alphas::Vector{String})
    training = get_split("training", "explicit")
    baseline = read_alpha(residual_alphas, "training", "explicit", false).rating
    residualized = RatingsDataset(
        user = training.user,
        item = training.item,
        rating = training.rating .- baseline,
    )
    return sparse(filter_users(residualized, num_users))
end;

In [None]:
# Sparse matrix built from the explicit training split in which every row's
# rating is replaced by the integer 1.
function explicit_validity_inputs(num_users::Int)
    training = get_split("training", "explicit")
    indicator = RatingsDataset(
        user = training.user,
        item = training.item,
        rating = ones(Int, length(training.rating)),
    )
    return sparse(filter_users(indicator, num_users))
end;

In [None]:
# Sparse matrix of the implicit training split, restricted via filter_users.
function implicit_inputs(num_users::Int)
    training = get_split("training", "implicit")
    restricted = filter_users(training, num_users)
    return sparse(restricted)
end;

In [None]:
# Vertically stack the explicit inputs, the explicit validity inputs, and the
# implicit inputs (in that order) into a single matrix.
function explicit_implicit_inputs(num_users::Int, residual_alphas::Vector{String})
    ratings = explicit_inputs(num_users, residual_alphas)
    validity = explicit_validity_inputs(num_users)
    implicit = implicit_inputs(num_users)
    return vcat(ratings, validity, implicit)
end;

In [None]:
# Same three components as explicit_implicit_inputs, but returned as a 3-tuple
# (explicit, explicit validity, implicit) rather than stacked.
function explicit_implicit_tuple_inputs(num_users::Int, residual_alphas::Vector{String})
    ratings = explicit_inputs(num_users, residual_alphas)
    validity = explicit_validity_inputs(num_users)
    implicit = implicit_inputs(num_users)
    return (ratings, validity, implicit)
end;

In [None]:
# Sparse matrix whose stored values come from `field` of the training split for
# `content`. The :status field is divided by 5 before the matrix is built.
function field_inputs(num_users::Int, content::String, field::Symbol)
    raw = get_split("training", content; fields = [:user, :item, field])
    df = RatingsDataset(user = raw.user, item = raw.item, rating = getfield(raw, field))
    if field == :status
        # Statuses are asserted to be at most max_status, then scaled by it.
        max_status = 5
        @assert all(df.rating .<= max_status)
        df = @set df.rating = df.rating ./ max_status
    end
    return sparse(filter_users(df, num_users))
end;

In [None]:
# Vertically stack the explicit inputs with the :status, :completion and
# :timestamp field inputs of `content` (in that order).
function impression_metadata(num_users::Int, content::String, residual_alphas::Vector{String})
    ratings = explicit_inputs(num_users, residual_alphas)
    status = field_inputs(num_users, content, :status)
    completion = field_inputs(num_users, content, :completion)
    timestamp = field_inputs(num_users, content, :timestamp)
    return vcat(ratings, status, completion, timestamp)
end;

In [None]:
# RatingsDataset whose `rating` holds the :order field of the implicit training
# split, restricted via filter_users.
function get_ordinal_timestamps(num_users::Int)
    training = get_split("training", "implicit"; fields = [:user, :item, :order])
    ordered = RatingsDataset(
        user = training.user,
        item = training.item,
        rating = training.order,
    )
    return filter_users(ordered, num_users)
end;

In [None]:
# RatingsDataset whose `rating` holds the :order field of the implicit training
# split, restricted via filter_users.
# NOTE(review): this reads :order, exactly like get_ordinal_timestamps; given the
# name it may have been meant to read a raw timestamp field — confirm intent.
function get_timestamps(num_users::Int)
    training = get_split("training", "implicit"; fields = [:user, :item, :order])
    stamped = RatingsDataset(
        user = training.user,
        item = training.item,
        rating = training.order,
    )
    return filter_users(stamped, num_users)
end;

### Dispatch

In [None]:
# Build the model inputs selected by `input_data`.
#
# Supported values: "one_hot", "implicit", "explicit", "explicit_implicit",
# "explicit_implicit_tuple", "impression_metadata", "impression_metadata_ptw".
# `implicit` is only forwarded to one_hot_inputs; `input_alphas` is forwarded to
# the builders that residualize explicit ratings.
#
# Throws an ArgumentError for any other `input_data`. This is an explicit throw
# rather than `@assert false` so the failure names the offending value and cannot
# be skipped if assertions are disabled.
function get_epoch_allitem_inputs(
    input_data::String,
    implicit::Bool,
    num_users::Int,
    input_alphas::Vector{String},
)
    if input_data == "one_hot"
        return one_hot_inputs(implicit, num_users)
    elseif input_data == "implicit"
        return implicit_inputs(num_users)
    elseif input_data == "explicit"
        return explicit_inputs(num_users, input_alphas)
    elseif input_data == "explicit_implicit"
        return explicit_implicit_inputs(num_users, input_alphas)
    elseif input_data == "explicit_implicit_tuple"
        return explicit_implicit_tuple_inputs(num_users, input_alphas)
    elseif input_data == "impression_metadata"
        return impression_metadata(num_users, "implicit", input_alphas)
    elseif input_data == "impression_metadata_ptw"
        return impression_metadata(num_users, "ptw", input_alphas)
    else
        throw(ArgumentError("unsupported input_data: \"$input_data\""))
    end
end;

In [None]:
# Assemble the inputs for per-item outputs as the tuple
# (users, T, U, items, A), where
#   U = the all-item inputs for `input_data`,
#   T = sparse matrix of get_ordinal_timestamps(num_users),
#   A = (Float32 sparse identity of size num_items(), get_neural_item_features()...),
# and users/items are the given (user, item) pairs after filter_users.
function get_epoch_item_inputs(
    users::Vector{Int32},
    items::Vector{Int32},
    input_data::String,
    implicit::Bool,
    num_users::Int,
    input_alphas::Vector{String},
)
    user_inputs = get_epoch_allitem_inputs(input_data, implicit, num_users, input_alphas)
    timestamps = sparse(get_ordinal_timestamps(num_users))
    item_identity = convert.(Float32, sparse(LinearAlgebra.I(num_items())))
    item_features = (item_identity, get_neural_item_features()...)
    kept = filter_users(RatingsDataset(user = users, item = items), num_users)
    return (kept.user, timestamps, user_inputs, kept.item, item_features)
end;

In [None]:
# Reshape a dataset into the layout required by `output_data`:
#   "allitems" -> sparse(df)
#   "item"     -> df.rating materialized as a 1×N matrix
#
# Throws an ArgumentError for any other `output_data`. This is an explicit throw
# rather than `@assert false` so the failure names the offending value and cannot
# be skipped if assertions are disabled.
function get_transformed_output_data(df, output_data::String)
    if output_data == "allitems"
        return sparse(df)
    elseif output_data == "item"
        return collect(df.rating')
    else
        throw(ArgumentError("unsupported output_data: \"$output_data\""))
    end
end;

In [None]:
# Build the epoch inputs for the given (user, item) pairs.
#   output_data == "allitems" -> get_epoch_allitem_inputs (users/items are unused)
#   output_data == "item"     -> get_epoch_item_inputs
#
# Throws an ArgumentError for any other `output_data`. This is an explicit throw
# rather than `@assert false` so the failure names the offending value and cannot
# be skipped if assertions are disabled.
function get_epoch_inputs(
    users::Vector,
    items::Vector,
    input_data::String,
    output_data::String,
    implicit::Bool,
    num_users::Int,
    input_alphas::Vector{String},
)
    if output_data == "allitems"
        return get_epoch_allitem_inputs(input_data, implicit, num_users, input_alphas)
    elseif output_data == "item"
        return get_epoch_item_inputs(
            users,
            items,
            input_data,
            implicit,
            num_users,
            input_alphas,
        )
    else
        throw(ArgumentError("unsupported output_data: \"$output_data\""))
    end
end

# Memoized entry point (LRU cache, maxsize = 2): load the (split, content) data
# and build the epoch inputs from its user and item columns.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_inputs(
    split::String,
    input_data::String,
    output_data::String,
    content::String,
    implicit::Bool,
    num_users::Int,
    input_alphas::Vector{String},
)
    rows = get_split(split, content)
    split_users = rows.user
    split_items = rows.item
    get_epoch_inputs(
        split_users,
        split_items,
        input_data,
        output_data,
        implicit,
        num_users,
        input_alphas,
    )
end

# Memoized (LRU cache, maxsize = 2): the (split, content) data, restricted via
# filter_users and laid out according to `output_data`.
# `implicit` is not used by the body.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_outputs(
    split::String,
    output_data::String,
    content::String,
    implicit::Bool,
    num_users::Int,
)
    rows = get_split(split, content)
    kept = filter_users(rows, num_users)
    get_transformed_output_data(kept, output_data)
end

# Memoized (LRU cache, maxsize = 2): the read_alpha output for `residual_alphas`
# on (split, content), restricted via filter_users and laid out according to
# `output_data`.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_residuals(
    split::String,
    output_data::String,
    content::String,
    residual_alphas::Vector{String},
    implicit::Bool,
    num_users::Int,
)
    alpha = read_alpha(residual_alphas, split, content, implicit)
    kept = filter_users(alpha, num_users)
    get_transformed_output_data(kept, output_data)
end

# Memoized (LRU cache, maxsize = 2): per-row weights for (split, content),
# restricted via filter_users and laid out according to `output_data`.
#
# For the "training" split the weight is the elementwise product of three
# powerdecay terms: per-user counts, per-item counts, and timestamps, each with
# its own decay parameter. Every other split uses
# powerdecay(get_counts(split, content), weighting_scheme("inverse")).
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_weights(
    split::String,
    output_data::String,
    content::String,
    user_weight_decay::Union{String,Real},
    item_weight_decay::Union{String,Real},
    temporal_weight_decay::Union{String,Real},
    num_users::Int,
)
    if split == "training"
        # Each line but the last must end in `.*`; without it the following
        # term is parsed as a separate statement and its value is discarded.
        # NOTE(review): get_timestamps(split, content) must be provided by a
        # method defined outside this section — confirm it exists.
        weights =
            powerdecay(get_counts(split, content), user_weight_decay) .*
            powerdecay(get_counts(split, content; by_item = true), item_weight_decay) .*
            powerdecay(get_timestamps(split, content), temporal_weight_decay)
    else
        weights = powerdecay(get_counts(split, content), weighting_scheme("inverse"))
    end

    df = get_split(split, content)
    df = filter_users(
        RatingsDataset(user = df.user, item = df.item, rating = weights),
        num_users,
    )
    get_transformed_output_data(df, output_data)
end

# Memoized (LRU cache, maxsize = 2): get_ordinal_timestamps(num_users) laid out
# according to `output_data`.
# `split` and `content` are not used by the body.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_ordinal_timestamps(
    split::String,
    output_data::String,
    content::String,
    num_users::Int,
)
    ordinals = get_ordinal_timestamps(num_users)
    get_transformed_output_data(ordinals, output_data)
end;

In [None]:
# Assemble the epoch for `split` from the settings stored in the global `G`.
# Returns (X, Y, Z, W) = (inputs, outputs, residualization alpha, weights);
# when G.temporal_batching is true, the ordinal timestamps T are appended as a
# fifth element.
function get_epoch(split::String)
    inputs = get_epoch_inputs(
        split,
        G.input_data,
        G.output_data,
        G.content,
        G.implicit,
        G.num_users,
        G.input_alphas,
    )
    outputs = get_epoch_outputs(split, G.output_data, G.content, G.implicit, G.num_users)
    residuals = get_epoch_residuals(
        split,
        G.output_data,
        G.content,
        G.residual_alphas,
        G.implicit,
        G.num_users,
    )
    weights = get_epoch_weights(
        split,
        G.output_data,
        G.content,
        G.user_weight_decay,
        G.item_weight_decay,
        G.temporal_weight_decay,
        G.num_users,
    )
    if G.temporal_batching
        timestamps =
            get_epoch_ordinal_timestamps(split, G.output_data, G.content, G.num_users)
        return (inputs, outputs, residuals, weights, timestamps)
    end
    return (inputs, outputs, residuals, weights)
end;

### Utilities

In [None]:
# Number of samples in `epoch`, determined from its first element (the inputs):
#   G.output_data == "allitems" -> size of the last dimension of the inputs
#                                  (of their first component when they are a Tuple)
#   G.output_data == "item"     -> length of the first component of the inputs
#
# Throws an ArgumentError for any other G.output_data. This is an explicit throw
# rather than `@assert false` so the failure names the offending value and cannot
# be skipped if assertions are disabled.
function epoch_size(epoch)
    if G.output_data == "allitems"
        X = epoch[1]
        if X isa Tuple
            # Tuple inputs are only expected for "double_embedding*" models.
            @assert startswith(G.model, "double_embedding")
            X = X[1]
        end
        return size(X)[end]
    elseif G.output_data == "item"
        return length(epoch[1][1])
    else
        throw(ArgumentError("unsupported output_data: \"$(G.output_data)\""))
    end
end;