In [None]:
# Notebook setup: load JupyterFormatter and call its autoformat hook.
# NOTE(review): presumably this reformats cells as they are executed — confirm
# against the package documentation. The trailing `;` suppresses cell output.
import JupyterFormatter
JupyterFormatter.enable_autoformat();

In [None]:
import CSV
import DataFrames
import Glob
import JLD2
import Memoize: @memoize
import Setfield
import Setfield: @set

## Static data

In [None]:
"""
    get_data_path(file)

Return the path of `file` inside the `../../data/` directory, resolved relative to
the directory of this source file (`@__DIR__`). The path is not normalized and the
file is not checked for existence.
"""
function get_data_path(file)
    relative = string("../../data/", file)
    return joinpath(@__DIR__, relative)
end;

In [None]:
"""
    read_csv(x; kw...)

Read the CSV source `x` into a `DataFrames.DataFrame` with every column typed as
`String` and with missing-value detection turned off (`missingstring = nothing`).

Any extra keyword arguments are forwarded to `CSV.read`; because they are applied
last, they override the two defaults above.
"""
function read_csv(x; kw...)
    # Later entries win in `merge`, mirroring the `defaults..., kw...` call order.
    options = merge((types = String, missingstring = nothing), values(kw))
    return CSV.read(x, DataFrames.DataFrame; options...)
end;

In [None]:
# Largest `userid` found in `processed_data/relabel_userid_map.csv`, as an Int32.
# The result is memoized, so the file is only read on the first call.
# NOTE(review): the name suggests ids are dense and 1-based so that the maximum
# equals the user count — confirm against the code that writes the map.
@memoize function num_users()
    id_map = read_csv(get_data_path("processed_data/relabel_userid_map.csv"))
    ids = parse.(Int32, id_map.userid)
    return maximum(ids)
end

# Largest `uid` found in `processed_data/<medium>.csv`, as an Int32.
# Memoized per `medium`; the CSV is read single-threaded (`ntasks = 1`).
# NOTE(review): assumes uids are dense and 1-based so the maximum is the item
# count — confirm against the code that writes these files.
@memoize function num_items(medium)
    items = read_csv(get_data_path("processed_data/$medium.csv"), ntasks = 1)
    uids = parse.(Int32, items.uid)
    return maximum(uids)
end;

In [None]:
# Look up the integer encoding of `status` in `processed_data/status.csv`, which
# is read through its `name` and `encoding` columns. Throws `KeyError` when
# `status` is not listed. Memoized per status symbol.
@memoize function get_status(status::Symbol)::Int32
    table = read_csv(get_data_path("processed_data/status.csv"))
    codes = Dict{Symbol,Int32}()
    for (label, code) in zip(table.name, table.encoding)
        codes[Symbol(label)] = parse(Int32, code)
    end
    return codes[status]
end;

## Reading and writing data

In [None]:
"""
    RatingsDataset(; kwargs...)

Column-oriented container: each field is one column stored as a typed vector.

Every field defaults to an empty vector. The helpers in this file treat an empty
column as "not loaded": `subset` leaves empty columns untouched and `get_split`
only fills the columns that were requested. Populated columns are expected to
share one length (checked by `get_split`).

The keyword constructor converts its arguments to the declared field types, so
e.g. a `Vector{Float64}` passed as `rating` is stored as `Vector{Float32}`.
"""
@kwdef struct RatingsDataset
    # Typed empty literals avoid building a `Vector{Any}` that is then converted.
    source::Vector{Int32} = Int32[]
    medium::Vector{Int32} = Int32[]
    userid::Vector{Int32} = Int32[]
    itemid::Vector{Int32} = Int32[]
    status::Vector{Int32} = Int32[]
    rating::Vector{Float32} = Float32[]
    updated_at::Vector{Float64} = Float64[]
    created_at::Vector{Float64} = Float64[]
    started_at::Vector{Float64} = Float64[]
    finished_at::Vector{Float64} = Float64[]
    update_order::Vector{Int32} = Int32[]
    progress::Vector{Float32} = Float32[]
    progress_volumes::Vector{Float32} = Float32[]
    repeat_count::Vector{Int32} = Int32[]
    priority::Vector{Int32} = Int32[]
    sentiment::Vector{Int32} = Int32[]
    alpha::Vector{Float32} = Float32[]
    metric::Vector{Float32} = Float32[]
end

"""
    subset(x::RatingsDataset, ord)

Return a new `RatingsDataset` in which every non-empty column of `x` is indexed
with `ord` (anything accepted by vector indexing, e.g. a Boolean mask or an index
vector). Empty columns are passed through unchanged.
"""
function subset(x::RatingsDataset, ord)
    columns = map(fieldnames(RatingsDataset)) do name
        column = getfield(x, name)
        isempty(column) ? column : column[ord]
    end
    return RatingsDataset(columns...)
end

"""
    cat(x::RatingsDataset, y::RatingsDataset)

Concatenate two datasets column by column, with the rows of `x` first.

Both inputs must have the same set of populated columns: a column that is empty
in one dataset and non-empty in the other trips an assertion
(`AssertionError`), because concatenating them would leave the result with
columns of different lengths.
"""
function cat(x::RatingsDataset, y::RatingsDataset)
    # The helper must inspect the dataset it is given; previously it always
    # looked at `x`, which made the check below compare `x` with itself.
    nonempty(df, f) = !isempty(getfield(df, f))
    for f in fieldnames(RatingsDataset)
        @assert nonempty(x, f) == nonempty(y, f) "cat: missing field $f"
    end
    RatingsDataset(
        [vcat(getfield(x, c), getfield(y, c)) for c in fieldnames(RatingsDataset)]...,
    )
end;

In [None]:
"""
    get_datasets()

Return the dataset names that have at least one `splits/*.jld2` file in the data
directory. A file's dataset name is the part of its basename before the first
`.`. Only the names "training", "streaming", "test" and "causal" are reported,
always in that order.
"""
function get_datasets()
    split_files = Glob.glob("splits/*.jld2", get_data_path(""))
    available = Set([first(split(basename(path), ".")) for path in split_files])
    return filter(name -> name in available, ["training", "streaming", "test", "causal"])
end

# Datasets that have split files on disk; computed once, when this cell runs.
const ALL_DATASETS = get_datasets()
# Split names accepted by `get_split`.
const ALL_SPLITS = ["train", "test"]
# Media names accepted by `get_split`.
const ALL_MEDIUMS = ["manga", "anime"]
# Metric names accepted by `as_metric`.
const ALL_METRICS = ["rating", "watch", "plantowatch", "drop"];

In [None]:
"""
    get_split(dataset, split, medium, fields)

Load the requested columns of one dataset split into a `RatingsDataset`.

Each field `f` in `fields` is read from `splits/<dataset>.<split>.<f>.jld2` in the
data directory, using `medium` as the key inside the file. Columns that were not
requested stay empty.

Assertions check that `dataset`, `split` and `medium` are known names, that every
entry of `fields` is a `RatingsDataset` field, and that all loaded columns have
the same length.
"""
function get_split(dataset::String, split::String, medium::String, fields::Vector{Symbol})
    @assert dataset in ALL_DATASETS
    @assert split in ALL_SPLITS
    @assert medium in ALL_MEDIUMS
    @assert Set(fields) ⊆ Set(fieldnames(RatingsDataset))
    fn = get_data_path("splits/$dataset.$split")
    df = RatingsDataset()
    for field in fields
        column = JLD2.load("$fn.$(field).jld2", medium)
        # The struct is immutable, so each assignment rebuilds it through a lens.
        df = Setfield.set(df, Setfield.PropertyLens{field}(), column)
    end
    @assert [length(getfield(df, f)) for f in fields] |> Set |> length <= 1
    return df
end

"""
    get_split(dataset, split, medium, fields, alpha)

Load a split as in the four-argument method and additionally fill its `alpha`
column from `alphas/<alpha>/alpha.jld2`, using the key `"<dataset>.<split>"`.

The "training" dataset is rejected by an assertion, and every requested column
must have the same length as the loaded `alpha` column.
"""
function get_split(
    dataset::String,
    split::String,
    medium::String,
    fields::Vector{Symbol},
    alpha::String,
)
    @assert dataset != "training"
    df = get_split(dataset, split, medium, fields)
    alpha_file = get_data_path("alphas/$alpha/alpha.jld2")
    stored = JLD2.load(alpha_file, "$dataset.$split")
    df = Setfield.set(df, Setfield.PropertyLens{:alpha}(), stored.alpha)
    @assert all(length(getfield(df, f)) == length(df.alpha) for f in fields)
    return df
end

"""
    as_metric(df::RatingsDataset, metric::String)

Filter `df` to the rows relevant for `metric` and fill its `metric` column.

| `metric`        | rows kept                         | `metric` column                   |
|:----------------|:----------------------------------|:----------------------------------|
| `"rating"`      | `rating != 0`                     | copy of `rating`                  |
| `"watch"`       | `status > get_status(:planned)`   | all ones                          |
| `"plantowatch"` | `status == get_status(:planned)`  | all ones                          |
| `"drop"`        | `status > get_status(:none)`      | `status <= get_status(:dropped)`  |

`metric` must be one of `ALL_METRICS` (checked by assertion).
"""
function as_metric(df::RatingsDataset, metric::String)
    @assert metric in ALL_METRICS
    metric_lens = Setfield.PropertyLens{:metric}()
    result = df
    if metric == "rating"
        kept = subset(df, df.rating .!= 0)
        result = Setfield.set(kept, metric_lens, copy(kept.rating))
    elseif metric == "watch"
        kept = subset(df, df.status .> get_status(:planned))
        result = Setfield.set(kept, metric_lens, ones(Float32, length(kept.status)))
    elseif metric == "plantowatch"
        kept = subset(df, df.status .== get_status(:planned))
        result = Setfield.set(kept, metric_lens, ones(Float32, length(kept.status)))
    elseif metric == "drop"
        kept = subset(df, df.status .> get_status(:none))
        # The Boolean mask is converted to Float32 by the struct constructor.
        result = Setfield.set(kept, metric_lens, kept.status .<= get_status(:dropped))
    else
        @assert false
    end
    return result
end;

In [None]:
"""
    write_alpha(model::Function, medium::String, name::String)

Run `model` on every dataset except "training" and save its predictions to
`alphas/<name>/alpha.jld2` (compressed), creating the directory if needed.

For each dataset, `model` is called with the "train" split (all columns except
`alpha` and `metric`) and the set of user ids that occur in any split. Its result
is indexed as `preds[userid][itemid]` for every row of each split, and the values
are stored as the `alpha` column of a `RatingsDataset` under the key
`"<dataset>.<split>"`.
"""
function write_alpha(model::Function, medium::String, name::String)
    # Input columns are the same for every dataset, so compute them once.
    cols = collect(setdiff(Set(fieldnames(RatingsDataset)), Set([:alpha, :metric])))
    alphas = Dict()
    for dataset in ALL_DATASETS
        dataset == "training" && continue
        users_per_split =
            [Set(get_split(dataset, s, medium, [:userid]).userid) for s in ALL_SPLITS]
        userids = union(users_per_split...)
        preds = model(get_split(dataset, "train", medium, cols), userids)
        for split in ALL_SPLITS
            rows = get_split(dataset, split, medium, [:userid, :itemid])
            alpha = [preds[user][item] for (user, item) in zip(rows.userid, rows.itemid)]
            alphas["$dataset.$split"] = RatingsDataset(alpha = alpha)
        end
    end
    outdir = mkpath(get_data_path("alphas/$name"))
    return JLD2.save("$outdir/alpha.jld2", alphas; compress = true)
end;

In [None]:
"""
    write_params(params, outdir)

Save `params` to `alphas/<outdir>/params.jld2` in the data directory (compressed),
creating the directory if it does not exist.
"""
function write_params(params, outdir)
    target_dir = mkpath(get_data_path("alphas/$outdir"))
    return JLD2.save("$target_dir/params.jld2", params; compress = true)
end

"""
    read_params(outdir)

Load and return the full contents of `alphas/<outdir>/params.jld2` from the data
directory.
"""
function read_params(outdir)
    params_file = get_data_path("alphas/$outdir/params.jld2")
    return JLD2.load(params_file)
end;