In [None]:
import JupyterFormatter
JupyterFormatter.enable_autoformat();

In [None]:
import NBInclude: @nbinclude
@nbinclude("Data.ipynb");

In [None]:
# Per-split width of the validation window, in days. `create_splits!` treats a
# row as a validation candidate when its `updated_at` is later than the split's
# max valid timestamp minus this many days.
const CUTOFF_DAYS = Dict{String,Int}("streaming" => 7, "test" => 1)
# Per-split cap on validation rows per user and per medium.
const CUTOFF_ITEMS = Dict{String,Int}("streaming" => 5, "test" => 5);

In [None]:
# Name of the split to build. When its data directory exists, `save_split`
# indexes `CUTOFF_DAYS[SPLIT]` and `CUTOFF_ITEMS[SPLIT]`, so the value must then
# be one of their keys ("streaming" or "test").
# NOTE(review): presumably this cell is overridden by whatever runs this
# notebook with parameters — TODO confirm.
SPLIT = ""

In [None]:
"""
    create_splits!(dfs, max_valid_ts, cutoff_days, cutoff_items)

Split every dataset in `dfs` into a training part and a validation part.

Within each dataset, rows are ordered by `(userid, updated_at, update_order)`,
and each user's rows are scanned from last to first. A row is assigned to
validation when both of these hold:

- its `updated_at` is greater than the cutoff, which is `max_valid_ts` minus
  `cutoff_days` days rescaled by `MAX_TS - MIN_TS`;
- fewer than `cutoff_items` rows with the same `medium` have already been
  assigned to validation for that user.

Every other row goes to training.

As each dataset is processed, `dfs[t]` is overwritten with an empty
`RatingsDataset` to free memory.

Returns `(train_dfs, val_dfs)`: vectors with the same length and element type
as `dfs`.
"""
function create_splits!(dfs, max_valid_ts, cutoff_days, cutoff_items)
    train_dfs = Vector{eltype(dfs)}(undef, length(dfs))
    val_dfs = Vector{eltype(dfs)}(undef, length(dfs))
    # The cutoff is identical for every shard, so compute it once outside the
    # threaded loop. 24 * 60 * 60 converts days to seconds. The division
    # suggests timestamps are normalized by (MAX_TS - MIN_TS) — TODO confirm units.
    ts_cutoff = max_valid_ts - (24 * 60 * 60 * cutoff_days) / (MAX_TS - MIN_TS)
    @showprogress Threads.@threads for t in 1:length(dfs)
        df = dfs[t]
        # Make each user's rows contiguous and chronological (ties broken by
        # update_order), so that a reverse scan visits newest rows first.
        df = subset(df, sortperm(collect(zip(df.userid, df.updated_at, df.update_order))))
        n = length(df.userid)
        num_val = zeros(Int, length(MEDIUM_MAP))  # validation rows per medium
        val_mask = fill(false, n)
        for i in n:-1:1
            # Entering a new user's block (scanning backwards): reset the counts.
            if i == n || df.userid[i] != df.userid[i+1]
                fill!(num_val, 0)
            end
            if df.updated_at[i] > ts_cutoff
                # `medium` appears to be 0-based; num_val is 1-based.
                m = df.medium[i] + 1
                if num_val[m] < cutoff_items
                    num_val[m] += 1
                    val_mask[i] = true
                end
            end
        end
        train_dfs[t] = subset(df, .!val_mask)
        val_dfs[t] = subset(df, val_mask)
        dfs[t] = RatingsDataset([[] for _ in fieldnames(RatingsDataset)]...) # free memory
    end
    return train_dfs, val_dfs
end;

In [None]:
"""
    filter_userids!(train_dfs, val_dfs)

Restrict every dataset in `train_dfs` and `val_dfs` to users that appear in at
least one dataset of `val_dfs`.

Both vectors are updated in place, index by index; `val_dfs` is assumed to be
at least as long as `train_dfs`. Returns `nothing`.
"""
function filter_userids!(train_dfs, val_dfs)
    # Collect all validation userids in a single pass. The previous
    # `union([Set(...) for ...]...)` allocated one Set per shard, splatted them
    # into varargs, and threw a MethodError (`union()`) for an empty `val_dfs`.
    valid_userids = Set(Iterators.flatten(df.userid for df in val_dfs))
    @showprogress Threads.@threads for i = 1:length(train_dfs)
        train_dfs[i] = subset(train_dfs[i], train_dfs[i].userid .∈ (valid_userids,))
        # NOTE(review): every validation userid is in `valid_userids` by
        # construction, so this mask is all-true. It is kept in case `subset`
        # has other effects; consider dropping it.
        val_dfs[i] = subset(val_dfs[i], val_dfs[i].userid .∈ (valid_userids,))
    end
    return nothing
end;

In [None]:
"""
    save_split(name)

Build and persist the train/validation split called `name`.

Steps:

1. Load the datasets stored under the directory `<name>_data`.
2. Call `relabel_userids!` on them.
3. Split them with `create_splits!`, using `get_max_valid_ts(name)`,
   `CUTOFF_DAYS[name]` and `CUTOFF_ITEMS[name]`.
4. Keep only users that appear in validation (`filter_userids!`).
5. Pass the result to `save_splits`.

Returns `nothing` without doing anything when the source data path does not
exist. Otherwise returns whatever `save_splits` returns.
"""
function save_split(name)
    data_dir = "$(name)_data"
    ispath(get_data_path(data_dir)) || return
    shards = load_datasets(data_dir)
    relabel_userids!(shards)
    latest_valid_ts = get_max_valid_ts(name)
    train_shards, val_shards = create_splits!(
        shards,
        latest_valid_ts,
        CUTOFF_DAYS[name],
        CUTOFF_ITEMS[name],
    )
    filter_userids!(train_shards, val_shards)
    return save_splits(train_shards, val_shards, name)
end;

In [None]:
# Build and save the split selected by `SPLIT`.
save_split(SPLIT)