In [None]:
import JupyterFormatter
JupyterFormatter.enable_autoformat();

In [None]:
import Glob
import JLD2
import ProgressMeter: @showprogress, next!
import Random
import SHA
import StatsBase

In [None]:
include("import_lists.jl");

In [None]:
# Users with fewer than this many rows across all datasets are dropped.
const MIN_ITEMS = 5
# Only rows whose `updated_at` lies above `1 - days / (MAX_TS - MIN_TS)` may be held out;
# presumably `updated_at` is normalised to [0, 1] so this means "the last N days" (confirm).
const MAX_VALIDATION_DAYS = 7
# Per-user, per-medium cap on the number of rows moved to the validation split.
const MAX_VALIDATION_ITEMS = 5
# Fraction of all users sampled as validation users.
const VALIDATION_USER_FRAC = 1e-2;

In [None]:
"""
    load_datasets()

Load the per-shard anime and manga rating lists from `training_data/` and concatenate
each anime/manga pair into a single `RatingsDataset` via `reduce(cat, ...)`.

Files are paired after sorting by path: every `*user_anime_list*jld2` file must have a
counterpart whose path differs only by `user_anime_list` -> `user_manga_list`.
Both conditions are asserted.

Returns a `Vector{RatingsDataset}` with one entry per pair.
"""
function load_datasets()
    anime_files = sort(Glob.glob("training_data/*user_anime_list*jld2", get_data_path("")))
    manga_files = sort(Glob.glob("training_data/*user_manga_list*jld2", get_data_path("")))
    # `zip` stops at the shorter input, so without this check an unpaired shard would be
    # dropped silently instead of failing loudly.
    @assert length(anime_files) == length(manga_files) string(
        "found ", length(anime_files), " anime files but ",
        length(manga_files), " manga files",
    )
    files = collect(zip(anime_files, manga_files))
    for (a, m) in files
        @assert replace(a, "user_anime_list" => "user_manga_list") == m
    end
    dfs = Vector{RatingsDataset}(undef, length(files))
    # Each iteration writes only its own slot, so the threaded loop is race-free.
    @showprogress Threads.@threads for i in eachindex(dfs)
        dfs[i] = reduce(cat, [JLD2.load(f, "data") for f in files[i]])
    end
    dfs
end;

In [None]:
"""
    drop_sparse_users!(dfs)

Remove, in place, every row belonging to a user with fewer than `MIN_ITEMS` rows in total.

Each user must appear in at most one element of `dfs`; this is asserted while counting.
Every `dfs[i]` is replaced by the filtered dataset returned by `subset`.
Returns `nothing`.
"""
function drop_sparse_users!(dfs)
    isempty(dfs) && return nothing
    # Key the counts by the concrete userid element type rather than `Any`: the table has
    # one entry per user, and `Dict()` would box every key.
    idtype = mapreduce(df -> eltype(df.userid), promote_type, dfs)
    user_counts = Dict{idtype,Int}()
    @showprogress for df in dfs
        for (k, v) in StatsBase.countmap(df.userid)
            @assert k ∉ keys(user_counts)
            user_counts[k] = v
        end
    end
    valid_userids = Set{idtype}(k for (k, v) in user_counts if v >= MIN_ITEMS)

    # Each iteration replaces only its own slot, so the threaded loop is race-free.
    @showprogress Threads.@threads for i in eachindex(dfs)
        dfs[i] = subset(dfs[i], dfs[i].userid .∈ (valid_userids,))
    end
end;

In [None]:
"""
    relabel_userids!(dfs)

Replace every `userid` in `dfs` with a new id in `1:N`, where `N` is the number of
distinct users across all datasets. The mapping is written to
`processed_data/relabel_userid_map.csv` with columns `username` (the original id) and
`userid` (the new id).

The new ids come from shuffling the sorted original ids. The global RNG is first seeded
from a SHA-256 digest of those sorted ids, so the same user set always yields the same
mapping.

Side effect: seeding the global RNG also fixes the outcome of later unseeded random calls
in the same session, e.g. the validation-user sampling in `create_splits!`.

Asserts that no user id appears in more than one dataset.
"""
function relabel_userids!(dfs)
    userids = Int32[]
    @showprogress for df in dfs
        append!(userids, unique(df.userid))
    end
    @assert length(userids) == length(Set(userids))
    sort!(userids)
    # Derive the seed from the id set itself so the shuffle is reproducible.
    digest = SHA.sha256(reinterpret(UInt8, userids))
    Random.seed!(first(reinterpret(UInt64, digest)))
    Random.shuffle!(userids)
    userid_map = Dict(u => i for (i, u) in enumerate(userids))
    # Build the table column-wise. `keys` and `values` of a Dict iterate in the same
    # order, so each row pairs an original id with its new id, in Dict iteration order.
    CSV.write(
        get_data_path("processed_data/relabel_userid_map.csv"),
        DataFrames.DataFrame(
            :username => collect(keys(userid_map)),
            :userid => collect(values(userid_map)),
        ),
    )
    # Each iteration rewrites only its own dataset, so the threaded loop is race-free.
    @showprogress Threads.@threads for i in eachindex(dfs)
        ids = dfs[i].userid
        ids .= getindex.(Ref(userid_map), ids)
    end
end;

In [None]:
"""
    _validation_mask(userids, updated_at, medium, validation_users, ts_cutoff, nmedia,
                     max_items)

Return a `BitVector` marking the rows held out for validation. Pure helper for
`create_splits!`.

Rows must already be sorted by `(userid, updated_at, update_order)`. The scan runs
backwards, so each user's latest rows are seen first. A row is marked when all of the
following hold:
- its user is in `validation_users`;
- its `updated_at` exceeds `ts_cutoff`;
- fewer than `max_items` rows of the same medium have already been marked for that user.

`medium` values are 0-based and index a counter vector of length `nmedia`.
"""
function _validation_mask(
    userids,
    updated_at,
    medium,
    validation_users,
    ts_cutoff,
    nmedia,
    max_items,
)
    mask = falses(length(userids))
    num_val = zeros(Int, nmedia)
    current_user = nothing
    for i = length(userids):-1:1
        # Rows are grouped by user, so a change of id starts a fresh per-medium budget.
        if current_user != userids[i]
            current_user = userids[i]
            num_val .= 0
        end
        if (userids[i] in validation_users) &&
           (updated_at[i] > ts_cutoff) &&
           (num_val[medium[i]+1] < max_items)
            num_val[medium[i]+1] += 1
            mask[i] = true
        end
    end
    mask
end

"""
    create_splits!(dfs; rng = Random.default_rng())

Split each dataset in `dfs` into a training part and a validation part.

Validation users are drawn without replacement from the `userid` column of
`processed_data/relabel_userid_map.csv`; a fraction `VALIDATION_USER_FRAC` of them is
chosen. For each validation user, up to `MAX_VALIDATION_ITEMS` of their latest rows per
medium are held out, counting only rows whose `updated_at` is above the cutoff derived
from `MAX_VALIDATION_DAYS`.

After a dataset is processed, `dfs[t]` is replaced by an empty `RatingsDataset` to free
memory.

# Keywords
- `rng`: RNG used to sample validation users. The default matches the previous
  behaviour. With it, results are reproducible only if the global RNG was seeded earlier
  in the session, as `relabel_userids!` does. Pass a seeded RNG to make this step
  reproducible on its own.

Returns `(train_dfs, val_dfs)`, two vectors parallel to `dfs`.
"""
function create_splits!(dfs; rng::Random.AbstractRNG = Random.default_rng())
    userid_map = read_csv(get_data_path("processed_data/relabel_userid_map.csv"))
    users = userid_map.userid .|> x -> parse(Int32, x)
    num_val_users = round(Int, length(users) * VALIDATION_USER_FRAC)
    validation_users = Set(StatsBase.sample(rng, users, num_val_users; replace = false))

    # Loop-invariant, so computed once instead of once per shard. The scale suggests
    # `updated_at` is normalised by (MAX_TS - MIN_TS) — TODO confirm.
    ts_cutoff = 1 - (24 * 60 * 60 * MAX_VALIDATION_DAYS) / (MAX_TS - MIN_TS)
    nmedia = length(MEDIUM_MAP)

    train_dfs = Vector{eltype(dfs)}(undef, length(dfs))
    val_dfs = Vector{eltype(dfs)}(undef, length(dfs))
    # Each iteration touches only slot `t` of each vector, so the threaded loop is race-free.
    @showprogress Threads.@threads for t in eachindex(dfs)
        df = dfs[t]
        df = subset(df, sortperm(collect(zip(df.userid, df.updated_at, df.update_order))))
        val_mask = _validation_mask(
            df.userid,
            df.updated_at,
            df.medium,
            validation_users,
            ts_cutoff,
            nmedia,
            MAX_VALIDATION_ITEMS,
        )
        train_dfs[t] = subset(df, .!val_mask)
        val_dfs[t] = subset(df, val_mask)
        dfs[t] = RatingsDataset([[] for _ in fieldnames(RatingsDataset)]...) # free memory
    end
    train_dfs, val_dfs
end;

In [None]:
"""
    save(dfs::Vector{RatingsDataset}, filepath::String)

Concatenate each field of `RatingsDataset` across all of `dfs`. Write each result to its
own uncompressed JLD2 file, `<filepath><fieldname>.jld2`, under the key `"data"`.

`filepath` is used as a raw prefix, so include any trailing separator or dot yourself.
"""
function save(dfs::Vector{RatingsDataset}, filepath::String)
    @showprogress for field in fieldnames(RatingsDataset)
        # A Symbol broadcasts as a scalar, so this collects `getfield(df, field)` per dataset.
        column = vcat(getfield.(dfs, field)...)
        JLD2.save("$(filepath)$(field).jld2", Dict("data" => column); compress = false)
    end
end;

In [None]:
"""
    save_splits(train_dfs, val_dfs)

Write the training and validation splits under the `splits` data directory, using the
file prefixes `training.` and `validation.`. The directory is created if it does not
already exist.

WARNING: after both splits are written, this recursively deletes the `training_data`
data directory (the input that `load_datasets` reads).
"""
function save_splits(train_dfs, val_dfs)
    outdir = get_data_path("splits")
    ispath(outdir) || mkpath(outdir)
    save(train_dfs, "$outdir/training.")
    save(val_dfs, "$outdir/validation.")
    rm(get_data_path("training_data"); recursive = true)
end;

In [None]:
# Load every anime/manga shard pair from `training_data/` into one dataset per pair.
dfs = load_datasets();

In [None]:
# Remove users with fewer than MIN_ITEMS rows in total (mutates `dfs`).
drop_sparse_users!(dfs);

In [None]:
# Map user ids to shuffled ids in 1:N and write the mapping CSV. This also seeds the
# global RNG, which the sampling in the next cell relies on for reproducibility.
relabel_userids!(dfs);

In [None]:
# Hold out recent rows of a sample of users for validation; `dfs` is emptied in the process.
train_dfs, val_dfs = create_splits!(dfs);

In [None]:
# Write the splits to `splits/`, then DELETE the `training_data` directory (irreversible).
save_splits(train_dfs, val_dfs);