# Helper functions for training alphas

In [None]:
import CSV
import DataFrames: DataFrame
import JLD2
import Memoize: @memoize
import NBInclude: @nbinclude
import NNlib: softmax
import Optim
import ProgressMeter: @showprogress
import Setfield
import Setfield: @set
import SparseArrays

## General utilities

In [None]:
@nbinclude("AlphaUtils.ipynb");

In [None]:
"""
    get_data_path(file)

Return the path of `file` inside the repository `data` directory.

Walks up from the current working directory until it reaches a directory named
`notebooks`; `data` is the sibling of that directory. Throws an error if no
ancestor of `pwd()` is named `notebooks` (instead of looping forever at the root).
"""
function get_data_path(file)
    start = pwd()
    path = start
    while basename(path) != "notebooks"
        parent = dirname(path)
        # dirname is a fixed point at the filesystem root; without this check the
        # loop would spin forever when run outside the notebooks tree
        if parent == path
            error("get_data_path: no ancestor of $start is named \"notebooks\"")
        end
        path = parent
    end
    root = dirname(path)
    return "$root/data/$file"
end;

In [None]:
# Redirect logging into the `alphas/<name>` data directory, without overwriting.
# `redirect_logging` is not defined here; presumably it comes from AlphaUtils.ipynb.
set_logging_outdir(name) =
    redirect_logging(get_data_path("alphas/$name"); overwrite = false);

## Static data

In [None]:
# Number of rows in `processed_data/username_to_uid.csv` (cached after the first call).
# Used as the user dimension of the sparse matrix built from a RatingsDataset.
@memoize function num_users()
    table = DataFrame(CSV.File(get_data_path("processed_data/username_to_uid.csv")))
    return length(table.uid)
end

# Number of rows in `processed_data/<medium>_to_uid.csv` (cached per medium).
@memoize function num_items(medium)
    csvpath = get_data_path("processed_data/$(medium)_to_uid.csv")
    return length(DataFrame(CSV.File(csvpath)).uid)
end;

In [None]:
# Return `(min_timestamp, max_timestamp)` read from `processed_data/timestamps.csv`
# (cached after the first call). The file must begin with a `min_timestamp,<int>`
# line followed by a `max_timestamp,<int>` line; any other layout trips an assertion.
@memoize function get_timestamp_encodings()
    # Read one "<field>,<value>" line from `file`, check its field name, parse the value.
    function parse_line(file::IO, field::String, format::Type = Int)
        fields = split(strip(readline(file)), ",")
        @assert length(fields) == 2
        @assert fields[1] == field
        return parse(format, fields[2])
    end

    csvpath = get_data_path("processed_data/timestamps.csv")
    return open(csvpath) do f
        min_ts = parse_line(f, "min_timestamp")
        max_ts = parse_line(f, "max_timestamp")
        (min_ts, max_ts)
    end
end


# Length of one day (in seconds) expressed in normalized timestamp units, where the
# span `max_timestamp - min_timestamp` corresponds to 1.0. Cached after the first call.
@memoize function day_in_timestamp_units()
    lo, hi = get_timestamp_encodings()
    return (24 * 60 * 60) / Float64(hi - lo)
end

# Convert a duration of `d` days into normalized timestamp units.
function days_in_timestamp_units(d)
    return d * day_in_timestamp_units()
end

# Map a normalized timestamp back to unix seconds, rounded to the nearest Int64.
# Rejects ts <= 0; presumably 0 is used as a "missing" sentinel — TODO confirm.
function timestamp_to_unix(ts)
    @assert ts > 0
    lo, hi = get_timestamp_encodings()
    return Int64(round(ts * (hi - lo) + lo))
end

# Map unix seconds onto the normalized scale (min_timestamp -> 0, max_timestamp -> 1)
# and return it as a Float32.
function unix_to_timestamp(unix_time)
    lo, hi = get_timestamp_encodings()
    return Float32((unix_time - lo) / (hi - lo))
end

# Convert a normalized timestamp to a `DateTime`. `Dates` is not imported in this
# notebook; presumably it is brought into scope by AlphaUtils.ipynb — TODO confirm.
timestamp_to_date(ts) = Dates.unix2datetime(timestamp_to_unix(ts));

In [None]:
# Integer codes for list statuses. Callers filter with `<`, `>` and `==` on these
# codes, so their relative order is significant.
const STATUS_ENCODING = Dict(
    :rewatching => 7,
    :completed => 6,
    :watching => 5,
    :plan_to_watch => 4,
    :on_hold => 3,
    :dropped => 2,
    :wont_watch => 1,
    :none => 0,
)

"""
    get_status(status::Symbol)

Return the integer code for `status`. Throws `KeyError` for an unknown symbol.
"""
function get_status(status::Symbol)
    # look up in a table built once, instead of allocating a new Dict on every call
    return STATUS_ENCODING[status]
end;

## Reading and writing data

In [None]:
# Column-oriented ratings table: each field is one column vector, and populated columns
# are expected to share the same length. An empty vector means "column not loaded";
# every column defaults to empty.
@kwdef struct RatingsDataset
    source::Vector{Int32} = Int32[]
    medium::Vector{Int32} = Int32[]
    userid::Vector{Int32} = Int32[]
    itemid::Vector{Int32} = Int32[]
    status::Vector{Int32} = Int32[]
    rating::Vector{Float32} = Float32[]
    update_order::Vector{Int32} = Int32[]
    updated_at::Vector{Float32} = Float32[]
    created_at::Vector{Float32} = Float32[]
    started_at::Vector{Float32} = Float32[]
    finished_at::Vector{Float32} = Float32[]
    progress::Vector{Float32} = Float32[]
    repeat_count::Vector{Int32} = Int32[]
    priority::Vector{Float32} = Float32[]
    sentiment::Vector{Int32} = Int32[]
    sentiment_score::Vector{Float32} = Float32[]
    owned::Vector{Float32} = Float32[]
    metric::Vector{Float32} = Float32[]
    alpha::Vector{Float32} = Float32[]
end;

# Append the rows of `y` after the rows of `x`, column by column.
# The `medium` columns of the two inputs must be equal. When both inputs hold data in
# at least one column other than `medium`, they must also have the same set of
# populated (non-empty) columns.
function Base.cat(x::RatingsDataset, y::RatingsDataset)
    columns = fieldnames(RatingsDataset)

    # `medium` is compared for equality rather than checked for presence
    for field in columns
        field in [:medium] || continue
        @assert getfield(x, field) == getfield(y, field)
    end

    # does any column other than `medium` hold data?
    has_data(d) = any(f -> !(f in [:medium]) && length(getfield(d, f)) != 0, columns)
    if has_data(x) && has_data(y)
        for field in columns
            a = length(getfield(x, field)) != 0
            b = length(getfield(y, field)) != 0
            @assert a == b "cat: missing field $field"
        end
    end

    merged = (vcat(getfield(x, c), getfield(y, c)) for c in columns)
    return RatingsDataset(merged...)
end

# Keep the rows selected by `mask` in every populated column; empty columns stay empty.
function Base.filter(x::RatingsDataset, mask::BitVector)
    columns = map(fieldnames(RatingsDataset)) do c
        col = getfield(x, c)
        isempty(col) ? [] : col[mask]
    end
    return RatingsDataset(columns...)
end;

In [None]:
# Build a `num_items(medium)` × `num_users()` sparse matrix holding `x.metric` at
# (itemid, userid). Ids are shifted by one, presumably because they are 0-based.
function SparseArrays.sparse(x::RatingsDataset, medium)
    rows = x.itemid .+ 1
    cols = x.userid .+ 1
    return SparseArrays.sparse(rows, cols, x.metric, num_items(medium), num_users())
end;

In [None]:
# Accepted values for the `split`, `metric` and `medium` arguments of `get_split`.
const ALL_SPLITS = String["training", "validation", "test", "negative"]
const ALL_METRICS = String["rating", "watch", "plantowatch", "drop"]
const ALL_MEDIUMS = String["manga", "anime"];

In [None]:
# Load the requested columns of `splits/<split>.<medium>` into a RatingsDataset.
# Each column is stored in its own file `<stem>.<field>.jld2` under the key `<field>`;
# the files are read in parallel. When `alpha` is not `nothing`, the `alpha` column is
# taken from `alphas/<alpha>/alpha.jld2` (key `<split>.<medium>`), and its length must
# match the loaded columns.
function get_raw_split(split::String, medium::String, fields::Vector{Symbol}, alpha)
    stem = get_data_path("splits/$(split).$(medium)")
    columns = Vector{Any}(nothing, length(fields))
    Threads.@threads for i in eachindex(fields)
        columns[i] = JLD2.load("$stem.$(fields[i]).jld2", String(fields[i]))
    end

    df = RatingsDataset()
    for (field, column) in zip(fields, columns)
        df = Setfield.set(df, Setfield.PropertyLens{field}(), column)
    end

    isnothing(alpha) && return df

    alpha_df = JLD2.load(get_data_path("alphas/$(alpha)/alpha.jld2"), "$split.$medium")
    if !isempty(fields)
        @assert length(getfield(df, fields[1])) == length(alpha_df.alpha)
    end
    return @set df.alpha = alpha_df.alpha
end;

In [None]:
"""
    get_split(split, metric, medium, fields, alpha = nothing)

Load `split` for `medium` and attach the label column `metric` for the given `metric`:

- `"rating"`: keeps rows with a nonzero rating; `metric` is a copy of `rating`.
- `"watch"`: keeps rows whose status is above `:plan_to_watch`; `metric` is 1.
- `"plantowatch"`: keeps rows whose status equals `:plan_to_watch`; `metric` is 1.
- `"drop"`: keeps rows whose status is above `:none`; `metric` is 1 when the status is
  at or below `:dropped`, else 0.

Only the columns named in `fields` (plus `alpha`, when an `alpha` name is given) are
returned populated; every other column is emptied. `fields` itself is not modified.
"""
function get_split(
    split::String,
    metric::String,
    medium::String,
    fields::Vector{Symbol},
    alpha = nothing,
)
    validargs = split in ALL_SPLITS && metric in ALL_METRICS && medium in ALL_MEDIUMS
    @assert validargs "($split, $metric, $medium) ∉ (split, metric, medium)"

    # columns each metric needs for its row filter and label
    extrafields = Dict(
        "rating" => [:rating],
        "watch" => [:status],
        "plantowatch" => [:status],
        "drop" => [:status],
    )
    fetch = union(Set(fields), Set(extrafields[metric]))
    delete!(fetch, :metric) # computed below rather than loaded
    df = get_raw_split(split, medium, collect(fetch), alpha)

    if metric == "rating"
        df = filter(df, df.rating .!= 0)
        df = @set df.metric = copy(df.rating)
    elseif metric == "watch"
        df = filter(df, df.status .> get_status(:plan_to_watch))
        df = @set df.metric = ones(Float32, length(df.status))
    elseif metric == "plantowatch"
        df = filter(df, df.status .== get_status(:plan_to_watch))
        df = @set df.metric = ones(Float32, length(df.status))
    elseif metric == "drop"
        df = @set df.metric = zeros(Float32, length(df.status))
        df = filter(df, df.status .> get_status(:none))
        # label 1 for statuses at or below :dropped; the code is looked up once, not per row
        df.metric[df.status .<= get_status(:dropped)] .= 1
    else
        @assert false
    end

    # clear any columns the caller is not interested in; build a new list rather than
    # push!-ing onto `fields`, which belongs to the caller
    keep = isnothing(alpha) ? fields : vcat(fields, :alpha)
    for f in fieldnames(RatingsDataset)
        if f ∉ keep
            df = Setfield.set(df, Setfield.PropertyLens{f}(), [])
        end
    end
    return df
end;

In [None]:
# Evaluate `model(userid, itemid)` on each split in `splits` (which must include
# "negative") and save the predictions, compressed, to `alphas/<outdir>/alpha.jld2`.
# Each split is stored under the key "<split>.<medium>" as a RatingsDataset in which
# only the `alpha` column is populated.
function write_alpha(
    model::Function,
    medium::String,
    outdir::String,
    splits::Vector{String},
)
    @assert "negative" in splits
    predictions = Dict()
    for split in splits
        df = get_raw_split(split, medium, [:userid, :itemid], nothing)
        x = model(df.userid, df.itemid)
        @assert length(x) == length(df.userid)
        predictions["$split.$medium"] = RatingsDataset(alpha = x)
    end
    dir = mkpath(get_data_path("alphas/$outdir"))
    JLD2.save("$dir/alpha.jld2", predictions; compress = true)
end;

In [None]:
# Load `split` for `metric`/`medium` with the userid, itemid and alpha columns populated;
# alpha comes from the predictions saved under `alphas/<alpha>`.
read_alpha(alpha::String, split::String, metric::String, medium::String) =
    get_split(split, metric, medium, [:userid, :itemid], alpha);

In [None]:
# Save `params` (compressed) to `alphas/<outdir>/params.jld2`, creating the directory
# if needed.
function write_params(params, outdir)
    dir = mkpath(get_data_path("alphas/$outdir"))
    return JLD2.save("$dir/params.jld2", params; compress = true)
end

# Load what `write_params` stored in `alphas/<alpha>/params.jld2`.
read_params(alpha) = JLD2.load(get_data_path("alphas/$alpha/params.jld2"));

## Loss functions

In [None]:
# Weighted mean loss of predictions `x` against labels `y` with weights `w`:
#   "rating"                → squared error
#   "watch", "plantowatch"  → -y * log(x)
#   "drop"                  → binary cross-entropy
# log is offset by eps(Float64) (as a Float32) so that log(0) stays finite.
function loss(x, y, w, metric)
    ε = Float32(eps(Float64))
    clog(v) = log(v .+ ε)
    if metric == "rating"
        per_row = (x - y) .^ 2
    elseif metric == "watch" || metric == "plantowatch"
        per_row = -y .* clog.(x)
    elseif metric == "drop"
        per_row = -(y .* clog.(x) + (1 .- y) .* clog.(1 .- x))
    else
        @assert false
    end
    return sum(per_row .* w) / sum(w)
end;

In [None]:
# Fit blending weights β so that `loss(X * β, y, w, metric)` is minimized.
# - "rating": weighted least squares via the normal equations, with a tiny ridge term
#   (1f-9 * I) to prevent a singular matrix.
# - "watch", "plantowatch", "drop": β is constrained to the probability simplex by
#   optimizing unconstrained logits with L-BFGS (forward-mode AD, at most 100 iterations,
#   g_tol 1e-6) starting from all zeros, and returning their softmax.
function regress(X, y, w, metric)
    if metric == "rating"
        sw = sqrt.(w)
        Xw = X .* sw
        yw = y .* sw
        ridge = 1f-9 * LinearAlgebra.I(size(Xw, 2))
        return (Xw' * Xw + ridge) \ (Xw' * yw)
    elseif metric in ["watch", "plantowatch", "drop"]
        objective = logits -> loss(X * softmax(logits), y, w, metric)
        init = fill(0.0f0, size(X, 2))
        result = Optim.optimize(
            objective,
            init,
            Optim.LBFGS(),
            Optim.Options(g_tol = 1e-6, iterations = 100);
            autodiff = :forward,
        )
        return softmax(Optim.minimizer(result))
    else
        @assert false
    end
end;

In [None]:
# Signed power sign(x) * |x|^a. Zero maps to zero(eltype(a)), so negative exponents
# do not blow up at 0.
function powerdecay(x, a)
    iszero(x) && return zero(eltype(a))
    return sign(x) * abs(x)^a
end

# Element-wise `powerdecay` over a vector, computed with threads; the result has
# element type `eltype(a)`.
function powerdecay(x::Vector, a)
    out = Vector{eltype(a)}(undef, length(x))
    Threads.@threads for i in eachindex(x)
        @inbounds out[i] = powerdecay(x[i], a)
    end
    return out
end

"""
    get_counts(data::Vector{Int32})

For every element of `data`, return how many times that value occurs in `data`.
The result is aligned with `data` (same length and order). The counting pass shows
a progress bar.
"""
function get_counts(data::Vector{Int32})
    # first pass: number of occurrences of each distinct value
    counts = Dict{Int32,Int32}()
    @showprogress for i = 1:length(data)
        u = data[i]
        counts[u] = get(counts, u, 0) + 1
    end
    # second pass: replace every element by its value's count
    return Int32[counts[u] for u in data]
end;

In [None]:
# Build the regression inputs for one split and return `(X, y, w)`:
#   X: one column per name in `alphas` (its saved predictions), plus a constant column
#      for classification metrics (1/num_items for watch/plantowatch, 1 for drop);
#   y: the `metric` labels produced by `get_split`;
#   w: 1 / (number of rows of that user), so each user's rows sum to a weight of 1.
function get_features(split::String, metric::String, medium::String, alphas::Vector{String})
    labels = get_split(split, metric, medium, [:userid, :metric])
    y = labels.metric
    w = powerdecay(get_counts(labels.userid), -1.0f0)

    columns = [read_alpha(a, split, metric, medium).alpha for a in alphas]
    n = length(y)
    if metric == "watch" || metric == "plantowatch"
        push!(columns, fill(1.0f0 / num_items(medium), n))
    elseif metric == "drop"
        push!(columns, fill(1.0f0, n))
    end
    X = hcat(columns...)

    return X, y, w
end;

In [None]:
# Fit blending weights on `regression_split` (defaults to `eval_split`), then return
# `(loss, β)`: the loss of the blended prediction X * β on `eval_split`, and the weights.
function compute_loss(
    metric::String,
    medium::String,
    alphas::Vector{String},
    eval_split::String,
    regression_split::Union{String,Nothing} = nothing,
)
    fit_split = something(regression_split, eval_split)
    β = regress(get_features(fit_split, metric, medium, alphas)..., metric)
    X, y, w = get_features(eval_split, metric, medium, alphas)
    return loss(X * β, y, w, metric), β
end;