In [None]:
import JupyterFormatter
JupyterFormatter.enable_autoformat();

In [None]:
import Glob
import JLD2
import ProgressMeter: @showprogress, next!

In [None]:
include("import_lists.jl");

In [None]:
"""
    load_datasets(source_dir)

Load every `*user_anime_list*jld2` / `*user_manga_list*jld2` file pair matched under
`source_dir` (relative to `get_data_path("")`) and merge each pair into one
`RatingsDataset` with `reduce(cat, ...)`.

Both file lists are sorted. The i-th manga path must equal the i-th anime path with
`user_anime_list` replaced by `user_manga_list`. Pairs are loaded in parallel with
`Threads.@threads`, reading the `"data"` entry of each JLD2 file.

Returns a `Vector{RatingsDataset}` with one entry per pair, in sorted filename order.
Throws an error if the anime and manga file counts differ or if any pair mismatches.
"""
function load_datasets(source_dir)
    anime_files = sort(Glob.glob("$(source_dir)/*user_anime_list*jld2", get_data_path("")))
    manga_files = sort(Glob.glob("$(source_dir)/*user_manga_list*jld2", get_data_path("")))
    # `zip` stops at the shorter list, so an unpaired trailing file would otherwise be
    # dropped silently; check the counts explicitly before pairing.
    if length(anime_files) != length(manga_files)
        error(
            "found $(length(anime_files)) anime files but $(length(manga_files)) " *
            "manga files under $(source_dir)",
        )
    end
    files = collect(zip(anime_files, manga_files))
    for (a, m) in files
        if replace(a, "user_anime_list" => "user_manga_list") != m
            error("mismatched dataset pair: $a and $m")
        end
    end
    dfs = Vector{RatingsDataset}(undef, length(files))
    # Each iteration writes only its own slot, so the threaded loop needs no locking.
    @showprogress Threads.@threads for i in eachindex(dfs)
        dfs[i] = reduce(cat, [JLD2.load(f, "data") for f in files[i]])
    end
    return dfs
end;

In [None]:
"""
    relabel_userids!(dfs)

Replace the `userid` values of every dataset in `dfs` using the table in
`processed_data/relabel_userid_map.csv`, then drop the rows that have no mapping.

The lookup keys are parsed (as `Int32`) from the CSV's `username` column, and the
replacement ids (also `Int32`) come from its `userid` column. Ids absent from the
table are set to `0`, and those rows are filtered out with `subset`. Each element of
`dfs` is overwritten with its filtered dataset; elements are processed with
`Threads.@threads`.
"""
function relabel_userids!(dfs)
    table = read_csv(get_data_path("processed_data/relabel_userid_map.csv"))
    old_to_new = Dict(
        parse(Int32, src) => parse(Int32, dst) for
        (src, dst) in zip(table.username, table.userid)
    )
    # Unmapped ids become 0 so they can be filtered out below.
    relabel(id) = get(old_to_new, id, 0)
    @showprogress Threads.@threads for k = 1:length(dfs)
        ds = dfs[k]
        ds.userid .= relabel.(ds.userid)
        dfs[k] = subset(ds, ds.userid .!= 0)
    end
end;

In [None]:
"""
    get_max_valid_ts(name)

Return the largest `access_timestamp` across all files matching
`raw_<name>_data/*/user_media_facts/user_status.*.csv` (under `get_data_path("")`),
rescaled as `(t - MIN_TS) / (MAX_TS - MIN_TS)`.

Timestamps are parsed as `Int64`. Throws an error if no file matches the pattern.
"""
function get_max_valid_ts(name)
    pattern = "raw_$(name)_data/*/user_media_facts/user_status.*.csv"
    files = Glob.glob(pattern, get_data_path(""))
    # `maximum` over an empty collection fails with an opaque reduction error;
    # report the missing inputs explicitly instead.
    isempty(files) && error("no files match $(pattern) under $(get_data_path(""))")
    maxunixtime = maximum(files) do f
        maximum(parse.(Int64, read_csv(f).access_timestamp))
    end
    return (maxunixtime - MIN_TS) / (MAX_TS - MIN_TS)
end;

In [None]:
"""
    save(dfs::Vector{RatingsDataset}, filepath::String)

Concatenate each field of the datasets in `dfs` and write one uncompressed JLD2 file
per field, at `filepath * string(field) * ".jld2"`.

Each file stores a dictionary keyed by the keys of `MEDIUM_MAP`. The entry for key `m`
holds only the rows whose concatenated `medium` value equals `MEDIUM_MAP[m]`.
"""
function save(dfs::Vector{RatingsDataset}, filepath::String)
    medium = vcat([getfield(df, :medium) for df in dfs]...)
    # The row selection for each medium is the same for every field, so build the
    # masks once rather than once per field.
    masks = Dict(m => medium .== MEDIUM_MAP[m] for m in keys(MEDIUM_MAP))
    @showprogress for c in fieldnames(RatingsDataset)
        column = vcat([getfield(df, c) for df in dfs]...)
        d = Dict(m => column[mask] for (m, mask) in masks)
        fn = filepath * string(c) * ".jld2"
        JLD2.save(fn, d; compress = false)
    end
end;

In [None]:
"""
    save_splits(train_dfs, val_dfs, name)

Write `train_dfs` and `val_dfs` into the `splits` data directory with `save`, using
the path prefixes `<dir>/<name>.train.` and `<dir>/<name>.test.` respectively.
The directory is created if nothing exists at that path.
"""
function save_splits(train_dfs, val_dfs, name)
    splits_dir = get_data_path("splits")
    ispath(splits_dir) || mkpath(splits_dir)
    save(train_dfs, "$splits_dir/$name.train.")
    return save(val_dfs, "$splits_dir/$name.test.")
end;