In [None]:
import JupyterFormatter
JupyterFormatter.enable_autoformat();

In [None]:
import Glob
import JLD2
import ProgressMeter: @showprogress, next!
import Random
import SHA
import StatsBase

In [None]:
include("import_lists.jl");

In [None]:
"""
    seed_rng!(salt::String)

Reseed the global RNG deterministically from the first `seed` entry of `rng.csv`
combined with `salt`. Different salts give different seeds; the same salt always gives
the same seed.
"""
function seed_rng!(salt::String)
    base_seed = read_csv(get_data_path("rng.csv")).seed |> first
    digest = SHA.sha256(base_seed * salt)
    # use the leading 8 bytes of the SHA-256 digest as a 64-bit seed
    Random.seed!(reinterpret(UInt64, digest)[1])
end;

In [None]:
"""
    load_datasets(dataset)

Load every paired anime/manga user-list shard matching
`<dataset>_data/*user_{anime,manga}_list*jld2` under `get_data_path("")`, and return one
`RatingsDataset` per shard pair, built by `cat`-ing the anime and manga parts.

Throws an error if the anime and manga shard files do not pair up one-to-one.
"""
function load_datasets(dataset)
    get_files(d, m) = sort(Glob.glob("$(d)_data/*user_$(m)_list*jld2", get_data_path("")))
    anime_files = get_files(dataset, "anime")
    manga_files = get_files(dataset, "manga")
    # `zip` silently truncates to the shorter input, so an unmatched shard would be
    # dropped without notice; check the counts explicitly first.
    if length(anime_files) != length(manga_files)
        error(
            "found $(length(anime_files)) anime files but $(length(manga_files)) " *
            "manga files for dataset $dataset",
        )
    end
    files = collect(zip(anime_files, manga_files))
    for (a, m) in files
        # both lists are sorted, so each anime shard must line up with its manga twin
        if replace(a, "user_anime_list" => "user_manga_list") != m
            error("mismatched shard files: $a and $m")
        end
    end
    dfs = Vector{RatingsDataset}(undef, length(files))
    @showprogress Threads.@threads for i = 1:length(dfs)
        dfs[i] = reduce(cat, [JLD2.load(f, "data") for f in files[i]])
    end
    dfs
end;

In [None]:
"""
    split_by_user!(dfs, test_frac)

Partition the rows of `dfs` by user. A random `test_frac` fraction of all users (sampled
without replacement using the global RNG) goes to the test split, and every other user
goes to the training split.

Returns `(train_dfs, test_dfs)`, each with one entry per element of `dfs`. Every element
of `dfs` is replaced by an empty `RatingsDataset` to release memory.
"""
function split_by_user!(dfs, test_frac)
    # Sort before sampling: Set iteration order depends on hashing, so sampling from an
    # unsorted collection would make the split depend on more than the RNG seed.
    users = sort(collect(union([Set(df.userid) for df in dfs]...)))
    num_test_users = round(Int, length(users) * test_frac)
    test_userids = Set(StatsBase.sample(users, num_test_users, replace = false))
    train_dfs = Vector{eltype(dfs)}(undef, length(dfs))
    test_dfs = Vector{eltype(dfs)}(undef, length(dfs))
    @showprogress Threads.@threads for i = 1:length(dfs)
        # compute the membership mask once and reuse its complement
        is_test = dfs[i].userid .∈ (test_userids,)
        train_dfs[i] = subset(dfs[i], .!is_test)
        test_dfs[i] = subset(dfs[i], is_test)
        dfs[i] = RatingsDataset([[] for _ in fieldnames(RatingsDataset)]...) # free memory
    end
    train_dfs, test_dfs
end;

In [None]:
"""
    drop_sparse_users!(dfs, min_items)

Filter each element of `dfs` in place, keeping only users with at least `min_items`
rows. Each user is expected to appear in exactly one element of `dfs`, and this is
checked with an assertion.
"""
function drop_sparse_users!(dfs, min_items)
    counts_by_user = Dict()
    @showprogress for df in dfs
        for (user, n) in StatsBase.countmap(df.userid)
            # a user spread over several shards would have its count overwritten here
            @assert !haskey(counts_by_user, user)
            counts_by_user[user] = n
        end
    end
    keep = Set(first(p) for p in counts_by_user if last(p) >= min_items)
    @showprogress Threads.@threads for j = 1:length(dfs)
        dfs[j] = subset(dfs[j], dfs[j].userid .∈ (keep,))
    end
end;

In [None]:
"""
    get_max_valid_ts(dataset)

Read the first `max_ts` entry of `processed_data/<dataset>.timestamps.csv`, parse it as a
`Float64`, and return it min-max rescaled as `(max_ts - MIN_TS) / (MAX_TS - MIN_TS)`.
"""
function get_max_valid_ts(dataset)
    csv_path = get_data_path("processed_data/$dataset.timestamps.csv")
    raw_value = read_csv(csv_path).max_ts |> first
    unix_time = parse(Float64, raw_value)
    return (unix_time - MIN_TS) / (MAX_TS - MIN_TS)
end;

In [None]:
"""
    create_splits!(dfs, ts_cutoff, max_output_items, output_newest)

Split each element of `dfs` into an input part and an output part.

Rows are first sorted by `(userid, updated_at, update_order)`. A row with
`updated_at > ts_cutoff` is marked as output while its user still has fewer than
`max_output_items` output rows of that row's medium. Rows are visited in reverse order
when `output_newest` is true and in forward order otherwise. The visit order decides
which post-cutoff rows take the limited output slots.

Input rows are:
- every non-output row, when `output_newest` is true;
- every row with `updated_at <= ts_cutoff`, otherwise.

Returns `(input_dfs, output_dfs)`. Each element of `dfs` is replaced by an empty
`RatingsDataset` to release memory.
"""
function create_splits!(dfs, ts_cutoff, max_output_items, output_newest)
    input_dfs = Vector{eltype(dfs)}(undef, length(dfs))
    output_dfs = Vector{eltype(dfs)}(undef, length(dfs))
    @showprogress Threads.@threads for t = 1:length(dfs)
        df = dfs[t]
        # group rows by user and order each user's rows chronologically
        df = subset(df, sortperm(collect(zip(df.userid, df.updated_at, df.update_order))))
        n = length(df.userid)
        # visiting in reverse lets each user's newest rows claim the output slots first
        order = output_newest ? reverse(1:n) : (1:n)
        # output rows taken so far for the current user, one counter per MEDIUM_MAP
        # entry; indexed by `medium + 1` below
        num_output = zeros(Int, length(MEDIUM_MAP))
        input_mask = falses(n)
        output_mask = falses(n)
        userid = nothing
        for i in order
            if userid != df.userid[i]
                # rows are contiguous per user after the sort, so a new id starts a new user
                userid = df.userid[i]
                num_output .= 0
            end
            if (df.updated_at[i] > ts_cutoff) &&
               (num_output[df.medium[i]+1] < max_output_items)
                num_output[df.medium[i]+1] += 1
                output_mask[i] = true
            end
            if output_newest
                input_mask[i] = !output_mask[i]
            else
                input_mask[i] = df.updated_at[i] <= ts_cutoff
            end
        end
        input_dfs[t] = subset(df, input_mask)
        output_dfs[t] = subset(df, output_mask)
        dfs[t] = RatingsDataset([[] for _ in fieldnames(RatingsDataset)]...) # free memory
    end
    input_dfs, output_dfs
end;

In [None]:
"""
    filter_input!(input_dfs, output_dfs)

Filter each element of `input_dfs` in place, keeping only users that appear in at least
one element of `output_dfs`.
"""
function filter_input!(input_dfs, output_dfs)
    output_user_sets = [Set(df.userid) for df in output_dfs]
    users_with_output = union(output_user_sets...)
    @showprogress Threads.@threads for k = 1:length(input_dfs)
        keep = in.(input_dfs[k].userid, Ref(users_with_output))
        input_dfs[k] = subset(input_dfs[k], keep)
    end
end;

In [None]:
"""
    relabel_userids!(train_dfs, test_input_dfs, test_output_dfs, dataset)

Replace the user ids in all three collections in place with dense ids `1:N`, assigned
by shuffling the sorted set of all original ids with the global RNG.

The mapping is written to `processed_data/<dataset>.relabel_userid_map.csv`. Its
`username` column holds the original ids and its `userid` column holds the new ids, with
rows ordered by new id.
"""
function relabel_userids!(train_dfs, test_input_dfs, test_output_dfs, dataset)
    datasets = [train_dfs, test_input_dfs, test_output_dfs]
    userids = union([Set(df.userid) for dfs in datasets for df in dfs]...)
    # sort before shuffling so the assignment does not depend on Set iteration order
    userids = Random.shuffle(sort(collect(userids)))
    userid_map = Dict(u => i for (i, u) in enumerate(userids))
    # Build the table column-wise: `userids[i]` is mapped to `i`. A vector of plain
    # tuples is not a valid column list for `DataFrame(columns, names)`.
    CSV.write(
        get_data_path("processed_data/$dataset.relabel_userid_map.csv"),
        DataFrames.DataFrame(username = userids, userid = 1:length(userids)),
    )
    for dfs in datasets
        @showprogress Threads.@threads for i = 1:length(dfs)
            dfs[i].userid .= get.(Ref(userid_map), dfs[i].userid, 0)
            # defensive: drop any row whose id was not in the map
            dfs[i] = subset(dfs[i], dfs[i].userid .!= 0)
        end
    end
end;

In [None]:
"""
    save(dfs::Vector{RatingsDataset}, outdir::String)

Concatenate `dfs` field by field and write each field to `<outdir>/<field>.jld2`,
uncompressed. Each file holds a dict keyed by the keys of `MEDIUM_MAP`. The entry for
key `m` contains the rows whose `medium` equals `MEDIUM_MAP[m]`.
"""
function save(dfs::Vector{RatingsDataset}, outdir::String)
    medium = vcat([getfield(df, :medium) for df in dfs]...)
    # the row masks depend only on `medium`, so compute them once instead of per field
    masks = Dict(m => medium .== MEDIUM_MAP[m] for m in keys(MEDIUM_MAP))
    @showprogress for c in fieldnames(RatingsDataset)
        column = vcat([getfield(df, c) for df in dfs]...)
        d = Dict(m => column[mask] for (m, mask) in masks)
        fn = "$outdir/$(string(c)).jld2"
        JLD2.save(fn, d; compress = false)
    end
end;

In [None]:
"""
    save_dataset(train_dfs, test_input_dfs, test_output_dfs, dataset)

Write the three splits to the `train`, `test_input` and `test_output` subdirectories of
`splits/<dataset>`, creating those directories as needed. Writing is delegated to `save`.
"""
function save_dataset(train_dfs, test_input_dfs, test_output_dfs, dataset)
    outdir = get_data_path("splits/$dataset")
    splits = (
        "train" => train_dfs,
        "test_input" => test_input_dfs,
        "test_output" => test_output_dfs,
    )
    for (name, dfs) in splits
        # `save` writes files directly into this subdirectory, so it must exist first;
        # `mkpath` also creates `outdir` and is a no-op for existing directories.
        split_dir = "$outdir/$name"
        mkpath(split_dir)
        save(dfs, split_dir)
    end
end;