## Data Preprocessing
* An epoch is an efficient representation of all of the model's inputs, outputs, residualization, and weights
* We generate one epoch per split and memoize them

In [None]:
import DelimitedFiles: readdlm
import Statistics: mean, std;

### Building blocks

In [None]:
# Build the explicit-rating input block for a single medium.
#
# Loads the "training" / "explicit" split for `task` and `medium`, subtracts the
# `rating` field returned by `read_alpha` for `residual_alphas` from the observed
# ratings, and returns `sparse(filter_users(...))` of the residualized dataset.
#
# NOTE(review): assumes `read_alpha` returns rows aligned one-to-one with the
# rows of `get_split` for the same arguments; the elementwise subtraction relies
# on it -- confirm.
# NOTE(review): `filter_users(_, num_users)` presumably restricts the dataset to
# `num_users` users -- verify against its definition.
function explicit_inputs(task::String, medium::String, num_users::Int, residual_alphas::Vector{String})
    observed = get_split("training", task, "explicit", medium; fields = [:user, :item, :rating])
    baseline = read_alpha(residual_alphas, "training", task, "explicit", medium, false)
    residualized = RatingsDataset(
        user = observed.user,
        item = observed.item,
        rating = observed.rating .- baseline.rating,
        medium = medium,
    )
    return sparse(filter_users(residualized, num_users))
end;

In [None]:
# Build the implicit-feedback input block for a single medium.
#
# Loads the "training" / "implicit" split for `task` and `medium`, passes it
# through `filter_users` with `num_users`, and returns the `sparse` form.
# No residualization is applied to this block.
function implicit_inputs(task::String, medium::String, num_users::Int)
    observed = get_split("training", task, "implicit", medium; fields = [:user, :item, :rating])
    kept = filter_users(observed, num_users)
    return sparse(kept)
end;

In [None]:
# Assemble the "universal" input by vertically stacking every per-medium block.
#
# Order of the stacked blocks: one explicit block per entry of `ALL_MEDIUMS`
# (in that order), followed by one implicit block per entry of `ALL_MEDIUMS`.
#
# `residual_alphas` must hold exactly one alpha name per medium; the i-th name is
# passed (wrapped in a one-element vector) to the explicit block of the i-th
# medium.
function universal_inputs(
    task::String,
    num_users::Int,
    residual_alphas::Vector{String},
)
    @assert length(residual_alphas) == length(ALL_MEDIUMS)
    # Containers are left as Vector{Any}: the concrete block type is whatever
    # `sparse` returns, which is not fixed here.
    explicit_blocks = Any[
        explicit_inputs(task, medium, num_users, [alpha])
        for (medium, alpha) in zip(ALL_MEDIUMS, residual_alphas)
    ]
    implicit_blocks = Any[implicit_inputs(task, medium, num_users) for medium in ALL_MEDIUMS]
    return reduce(vcat, vcat(explicit_blocks, implicit_blocks))
end;

### Dispatch

In [None]:
# Memoized dispatcher for the model inputs.
#
# Only `input_data == "universal"` is supported; any other value trips the
# assertion below. `medium` and `implicit` are not used by the "universal" path
# but are part of the signature (and therefore of the memoization key).
#
# Unlike the other `get_epoch_*` helpers this one uses plain `@memoize` with no
# LRU bound.
@memoize function get_epoch_inputs(
    input_data::String,
    task::String,
    medium::String,
    implicit::Bool,
    num_users::Int,
    input_alphas::Vector{String},
)
    input_data == "universal" && return universal_inputs(task, num_users, input_alphas)
    # Unsupported `input_data` value.
    @assert false
end;

# Memoized (LRU, 2 entries) loader for the target ratings of a split.
#
# Reads the `:user`, `:item` and `:rating` fields of the requested split, applies
# `filter_users` with `num_users`, and returns the `sparse` form.
# `implicit` is accepted but not used in the body.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_outputs(
    split::String,
    task::String,
    content::String,
    medium::String,
    implicit::Bool,
    num_users::Int,
)
    targets = get_split(split, task, content, medium; fields = [:user, :item, :rating])
    kept = filter_users(targets, num_users)
    return sparse(kept)
end

# Memoized (LRU, 2 entries) loader for the residualization alpha of a split.
#
# Calls `read_alpha` with `residual_alphas` for the requested split, applies
# `filter_users` with `num_users`, and returns the `sparse` form.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_residuals(
    split::String,
    task::String,
    content::String,
    medium::String,
    residual_alphas::Vector{String},
    implicit::Bool,
    num_users::Int,
)
    baseline = read_alpha(residual_alphas, split, task, content, medium, implicit)
    kept = filter_users(baseline, num_users)
    return sparse(kept)
end

# Memoized (LRU, 2 entries) builder for the per-interaction weights of a split.
#
# * "training" split: the weight is the elementwise product of three factors:
#     - `powerdecay` of the default `get_counts` with `user_weight_decay`,
#     - `powerdecay` of the `by_item = true` counts with `item_weight_decay`,
#     - `powerlawdecay` of `(1 - max(timestamp, 0)) / year_in_timestamp_units`
#       with `temporal_weight_decay`.
# * any other split: `powerdecay` of the default counts with
#   `weighting_scheme("inverse")`; the three decay arguments are ignored.
#
# The weights are stored in the `rating` field of a `RatingsDataset` alongside
# the split's `:user` / `:item` columns, filtered with `filter_users`, and
# returned in `sparse` form.
#
# NOTE(review): the temporal term presumably treats timestamps as normalized so
# that 1 is the most recent moment, making `1 - timestamp` an age -- confirm.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_weights(
    split::String,
    task::String,
    content::String,
    medium::String,
    user_weight_decay::Real,
    item_weight_decay::Real,
    temporal_weight_decay::Real,
    num_users::Int,
)
    weights = if split == "training"
        user_factor = powerdecay(get_counts(split, task, content, medium), user_weight_decay)
        item_factor = powerdecay(
            get_counts(split, task, content, medium; by_item = true),
            item_weight_decay,
        )
        timestamps = get_split(split, task, content, medium; fields = [:timestamp]).timestamp
        # Negative timestamps are clamped to zero before computing the elapsed time.
        elapsed = (1 .- max.(timestamps, 0.0f0)) ./ year_in_timestamp_units(medium)
        user_factor .* item_factor .* powerlawdecay(elapsed, temporal_weight_decay)
    else
        powerdecay(get_counts(split, task, content, medium), weighting_scheme("inverse"))
    end

    interactions = get_split(split, task, content, medium; fields = [:user, :item])
    weighted = RatingsDataset(
        user = interactions.user,
        item = interactions.item,
        rating = weights,
        medium = medium,
    )
    return sparse(filter_users(weighted, num_users))
end

In [None]:
# Assemble one epoch for `split` from the global configuration `G`.
#
# Returns the tuple (X, Y, Z, W) = (inputs, outputs, residualization alpha, weights).
#
# The "training" split always uses task "all"; "validation" and "test" use
# `G.task`; any other split name trips the assertion. Each component comes from
# its memoized `get_epoch_*` helper, and a garbage collection is forced before
# returning.
function get_epoch(split::String)
    task = if split == "training"
        "all"
    elseif split in ("validation", "test")
        G.task
    else
        @assert false
    end

    inputs = get_epoch_inputs(
        G.input_data, task, G.medium, G.implicit, G.num_users, G.input_alphas,
    )
    outputs = get_epoch_outputs(split, task, G.content, G.medium, G.implicit, G.num_users)
    residuals = get_epoch_residuals(
        split, task, G.content, G.medium, G.residual_alphas, G.implicit, G.num_users,
    )
    weights = get_epoch_weights(
        split,
        task,
        G.content,
        G.medium,
        G.user_weight_decay,
        G.item_weight_decay,
        G.temporal_weight_decay,
        G.num_users,
    )
    # Reclaim temporaries created while building the components.
    GC.gc()
    return (inputs, outputs, residuals, weights)
end;