## Data Preprocessing
* An epoch is an efficient representation of all of the model's inputs, outputs, residualization alphas, and weights (plus timestamps for temporally batched models)
* We generate one epoch per split; the components of each epoch are memoized

In [None]:
import DelimitedFiles: readdlm
import Statistics: mean, std;

### Building blocks

In [None]:
"""
    one_hot_inputs(implicit, num_users)

Return the user ids `1:num_users` materialized as a `Vector{Int32}` (used by the
`"one_hot"` input mode). `implicit` is accepted for signature parity but is unused.
"""
function one_hot_inputs(implicit::Bool, num_users::Int)
    user_ids = collect(Int32, 1:num_users)
    return user_ids
end;

In [None]:
"""
    explicit_inputs(task, num_users, residual_alphas)

Build a sparse matrix of the explicit training ratings for `task`, residualized by
subtracting the ratings `read_alpha` returns for `residual_alphas` (training split,
explicit content, `implicit = false`). Rows are restricted with `filter_users`.
"""
function explicit_inputs(task::String, num_users::Int, residual_alphas::Vector{String})
    ratings = get_split("training", task, "explicit")
    baseline = read_alpha(residual_alphas, "training", task, "explicit", false)
    residualized = RatingsDataset(
        user = ratings.user,
        item = ratings.item,
        rating = ratings.rating .- baseline.rating,
    )
    return sparse(filter_users(residualized, num_users))
end;

In [None]:
"""
    explicit_validity_inputs(task, num_users)

Build a sparse indicator matrix with a `1` at every (user, item) pair that has an
explicit training rating for `task`, restricted with `filter_users`.
"""
function explicit_validity_inputs(task::String, num_users::Int)
    ratings = get_split("training", task, "explicit")
    indicator = ones(Int, length(ratings.rating))
    observed = RatingsDataset(user = ratings.user, item = ratings.item, rating = indicator)
    return sparse(filter_users(observed, num_users))
end;

In [None]:
"""
    implicit_inputs(task, num_users)

Build a sparse matrix of the implicit training interactions for `task`, restricted
with `filter_users`.
"""
function implicit_inputs(task::String, num_users::Int)
    interactions = get_split("training", task, "implicit")
    return filter_users(interactions, num_users) |> sparse
end;

In [None]:
"""
    explicit_implicit_inputs(task, num_users, residual_alphas)

Stack, via `vcat`, three input blocks for `task`: residualized explicit ratings,
explicit-rating indicators, and implicit interactions (in that order).
"""
function explicit_implicit_inputs(
    task::String,
    num_users::Int,
    residual_alphas::Vector{String},
)
    explicit_block = explicit_inputs(task, num_users, residual_alphas)
    validity_block = explicit_validity_inputs(task, num_users)
    implicit_block = implicit_inputs(task, num_users)
    return vcat(explicit_block, validity_block, implicit_block)
end;

In [None]:
"""
    field_inputs(task, num_users, field)

Build a sparse matrix over the implicit training interactions for `task` whose values
are the column `field`. When `field` is `:status`, values are asserted to be at most 5
and divided by 5. Rows are restricted with `filter_users`.
"""
function field_inputs(task::String, num_users::Int, field::Symbol)
    raw = get_split("training", task, "implicit"; fields = [:user, :item, field])
    dataset = RatingsDataset(
        user = raw.user,
        item = raw.item,
        rating = getfield(raw, field),
    )
    if field === :status
        # statuses are stored on a 0..5-style scale; rescale so the maximum maps to 1
        status_cap = 5
        @assert all(dataset.rating .<= status_cap)
        dataset = @set dataset.rating = dataset.rating ./ status_cap
    end
    return sparse(filter_users(dataset, num_users))
end;

In [None]:
"""
    get_ordinal_timestamps(task, num_users)

Return the implicit training interactions for `task` (passed through `filter_users`)
with each rating replaced by the interaction's normalized ordinal position within its
user's history: the k-th interaction of a user (0-based, in timestamp order) out of
that user's n interactions gets the value `k / n`, so values lie in `[0, 1)`.
"""
function get_ordinal_timestamps(task::String, num_users::Int)
    df = get_split("training", task, "implicit"; fields = [:user, :item, :timestamp])
    # `filter_users` is only applied at the end, so rows with user ids above
    # `num_users` may be present here (assumption — confirm against `filter_users`);
    # size the counter to cover every id seen to avoid a BoundsError.
    counts = zeros(Float32, max(num_users, maximum(df.user; init = 0)))
    ordinals = zeros(Float32, length(df.timestamp))
    # Visit interactions chronologically; an interaction's ordinal is the number of
    # earlier interactions by the same user.
    for i in sortperm(df.timestamp)
        u = df.user[i]
        ordinals[i] = counts[u]
        counts[u] += 1
    end
    # Normalize every row (not just the last) by its user's total interaction count.
    # `max` is defensive: any user that appears has a count of at least 1.
    for i in eachindex(ordinals)
        ordinals[i] /= max(1, counts[df.user[i]])
    end
    return filter_users(
        RatingsDataset(user = df.user, item = df.item, rating = ordinals),
        num_users,
    )
end;

In [None]:
"""
    get_timestamps(task, num_users)

Return the implicit training interactions for `task` with the raw timestamp stored
as the rating, restricted with `filter_users`.
"""
function get_timestamps(task::String, num_users::Int)
    raw = get_split("training", task, "implicit"; fields = [:user, :item, :timestamp])
    timestamped = RatingsDataset(user = raw.user, item = raw.item, rating = raw.timestamp)
    return filter_users(timestamped, num_users)
end;

### Dispatch

In [None]:
# Build the epoch inputs X according to `input_data`:
#   "one_hot"           -> user ids 1:num_users
#   "implicit"          -> implicit training interactions
#   "explicit"          -> explicit training ratings residualized by `input_alphas`
#   "explicit_implicit" -> explicit, explicit-indicator and implicit blocks stacked
# Results are memoized via `@memoize`. Throws an ArgumentError for any other mode.
@memoize function get_epoch_inputs(
    input_data::String,
    task::String,
    implicit::Bool,
    num_users::Int,
    input_alphas::Vector{String},
)
    if input_data == "one_hot"
        return one_hot_inputs(implicit, num_users)
    elseif input_data == "implicit"
        return implicit_inputs(task, num_users)
    elseif input_data == "explicit"
        return explicit_inputs(task, num_users, input_alphas)
    elseif input_data == "explicit_implicit"
        return explicit_implicit_inputs(task, num_users, input_alphas)
    end
    # `@assert false` may be compiled out, which would silently return `nothing`;
    # fail loudly with the offending value instead.
    throw(ArgumentError("unknown input_data: $(repr(input_data))"))
end;

# Epoch outputs Y: the `split`/`task`/`content` data as a sparse matrix, restricted
# with `filter_users`. `implicit` is not used in the body but is part of the
# memoization key; up to 2 results are kept in an LRU cache.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_outputs(
    split::String,
    task::String,
    content::String,
    implicit::Bool,
    num_users::Int,
)
    targets = get_split(split, task, content)
    kept = filter_users(targets, num_users)
    return sparse(kept)
end

# Epoch residualization Z: the `read_alpha` values for `residual_alphas` on this
# `split`/`task`/`content`, restricted with `filter_users`, as a sparse matrix.
# Up to 2 results are kept in an LRU cache.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_residuals(
    split::String,
    task::String,
    content::String,
    residual_alphas::Vector{String},
    implicit::Bool,
    num_users::Int,
)
    alpha = read_alpha(residual_alphas, split, task, content, implicit)
    kept = filter_users(alpha, num_users)
    return sparse(kept)
end

# Epoch weights W as a sparse matrix aligned with `get_split(split, task, content)`.
# Training: product of a per-user count decay, a per-item count decay, and a
# power-law temporal decay. Other splits: per-user count decay with the "inverse"
# weighting scheme. Up to 2 results are kept in an LRU cache.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_weights(
    split::String,
    task::String,
    content::String,
    user_weight_decay::Real,
    item_weight_decay::Real,
    temporal_weight_decay::Real,
    num_users::Int,
)
    weights = if split == "training"
        user_factor = powerdecay(get_counts(split, task, content), user_weight_decay)
        item_factor = powerdecay(
            get_counts(split, task, content; by_item = true),
            item_weight_decay,
        )
        # Negative timestamps are clamped to 0. Assumes timestamps are normalized so
        # that 1 is the most recent time — TODO confirm.
        timestamps = get_split(split, task, content; fields = [:timestamp]).timestamp
        elapsed = (1 .- max.(timestamps, 0.0f0)) ./ year_in_timestamp_units()
        temporal_factor = powerlawdecay(elapsed, temporal_weight_decay)
        user_factor .* item_factor .* temporal_factor
    else
        powerdecay(get_counts(split, task, content), weighting_scheme("inverse"))
    end

    rows = get_split(split, task, content)
    weighted = RatingsDataset(user = rows.user, item = rows.item, rating = weights)
    return sparse(filter_users(weighted, num_users))
end

# Epoch timestamps T (used when temporally batching): the implicit training-split
# timestamps for `task` from `get_timestamps`, as a sparse matrix.
# NOTE(review): `split` and `content` are not used in the body (they only enter the
# memoization key), so every split receives training-split implicit timestamps.
# The input builders above also read the training split, so this may be intended
# — confirm.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_timestamps(
    split::String,
    task::String,
    content::String,
    num_users::Int,
)
    timestamped = get_timestamps(task, num_users)
    return sparse(timestamped)
end;

In [None]:
"""
    get_epoch(split)

Assemble the epoch for `split` from the memoized builders, reading settings from the
global `G`. Returns `(X, Y, Z, W)` = (inputs, outputs, residualization alpha, weights);
when `should_temporal_batch(G.model)` is true, timestamps `T` are appended, giving
`(X, Y, Z, W, T)`.

The `"training"` split uses task `"all"`; `"validation"` and `"test"` use `G.task`.
Throws an `ArgumentError` for any other `split`.
"""
function get_epoch(split::String)
    task = if split == "training"
        "all"
    elseif split in ("validation", "test")
        G.task
    else
        # `@assert false` could be compiled out, leaving `task` undefined and producing
        # an obscure UndefVarError; report the bad split explicitly instead.
        throw(ArgumentError(
            "unknown split $(repr(split)); expected \"training\", \"validation\" or \"test\"",
        ))
    end
    X = get_epoch_inputs(G.input_data, task, G.implicit, G.num_users, G.input_alphas)
    Y = get_epoch_outputs(split, task, G.content, G.implicit, G.num_users)
    Z = get_epoch_residuals(
        split,
        task,
        G.content,
        G.residual_alphas,
        G.implicit,
        G.num_users,
    )
    W = get_epoch_weights(
        split,
        task,
        G.content,
        G.user_weight_decay,
        G.item_weight_decay,
        G.temporal_weight_decay,
        G.num_users,
    )
    epoch = (X, Y, Z, W)
    if should_temporal_batch(G.model)
        T = get_epoch_timestamps(split, task, G.content, G.num_users)
        epoch = (epoch..., T)
    end
    return epoch
end;

### Utilities

In [None]:
# Size of the last dimension of the epoch's first component (the inputs X);
# presumably the number of users/examples — confirm.
epoch_size(epoch) = last(size(first(epoch)))