# Helper functions for training alphas

In [1]:
import CSV
import DataFrames: DataFrame
import JLD2
import Memoize: @memoize
import NBInclude: @nbinclude
import NNlib: softmax
import Optim
import ProgressMeter: @showprogress
import Setfield
import Setfield: @set
import SparseArrays

## General utilities

In [2]:
@nbinclude("AlphaUtils.ipynb");

In [3]:
"""
    get_data_path(file)

Return the path of `file` inside the project's data directory.

Walks up from the current working directory until it reaches a directory
named `notebooks`; the data directory is the `data` folder next to it.

# Throws
- `ErrorException` if no ancestor of the working directory is named `notebooks`.
"""
function get_data_path(file)
    start = pwd()
    path = start
    while basename(path) != "notebooks"
        parent = dirname(path)
        # dirname of the filesystem root is the root itself; without this
        # check the walk would spin forever when "notebooks" is not an ancestor
        if parent == path
            error("get_data_path: no \"notebooks\" directory found above $start")
        end
        path = parent
    end
    path = dirname(path)
    "$path/data/$file"
end;

In [4]:
"""
    set_logging_outdir(name)

Redirect logging to the `alphas/<name>` folder of the data directory.
Delegates to `redirect_logging` with `overwrite = false`.
"""
function set_logging_outdir(name)
    logdir = get_data_path("alphas/$name")
    redirect_logging(logdir; overwrite = false)
end

# Notebooks that define `name` before including this file get logging
# configured automatically.
if @isdefined(name)
    set_logging_outdir(name)
end

## Static data

In [5]:
# Number of rows in processed_data/username_to_uid.csv (one per uid entry).
# Memoized, so the file is only read once.
@memoize function num_users()
    uid_file = get_data_path("processed_data/username_to_uid.csv")
    uid_table = DataFrame(CSV.File(uid_file))
    length(uid_table.uid)
end

# Number of rows in processed_data/<medium>_to_uid.csv (one per uid entry).
# Memoized per `medium`.
@memoize function num_items(medium)
    uid_file = get_data_path("processed_data/$(medium)_to_uid.csv")
    uid_table = DataFrame(CSV.File(uid_file))
    length(uid_table.uid)
end;

In [6]:
# Read `(min_timestamp, max_timestamp)` from processed_data/timestamps.csv.
# The file is expected to start with the two lines
#     min_timestamp,<value>
#     max_timestamp,<value>
# in that order; anything else trips an assertion.
@memoize function get_timestamp_encodings()
    # consume one "<name>,<value>" line and return the parsed value
    function read_entry(io::IO, expected::String, T::Type = Int)
        parts = split(strip(readline(io)), ",")
        @assert length(parts) == 2
        @assert parts[1] == expected
        return parse(T, parts[2])
    end

    open(get_data_path("processed_data/timestamps.csv")) do io
        lo = read_entry(io, "min_timestamp")
        hi = read_entry(io, "max_timestamp")
        return lo, hi
    end
end


# Length of one day expressed in normalized timestamp units, i.e. 86400
# divided by the span between the minimum and maximum raw timestamps.
# NOTE(review): presumes the raw timestamps are unix seconds - confirm.
@memoize function day_in_timestamp_units()
    lo, hi = get_timestamp_encodings()
    span = Float64(hi - lo)
    return (24 * 60 * 60) / span
end

"""
    timestamp_to_unix(ts)

Map a normalized timestamp back to the raw scale and round it to an `Int64`.
Inverse of [`unix_to_timestamp`](@ref) up to rounding. Asserts `ts > 0`.
"""
function timestamp_to_unix(ts)
    @assert ts > 0
    lo, hi = get_timestamp_encodings()
    raw = ts * (hi - lo) + lo
    return Int64(round(raw))
end

"""
    unix_to_timestamp(unix_time)

Rescale a raw timestamp so that the stored minimum maps to 0 and the stored
maximum maps to 1, returned as a `Float32`.
"""
function unix_to_timestamp(unix_time)
    lo, hi = get_timestamp_encodings()
    offset = unix_time - lo
    span = hi - lo
    return convert(Float32, offset / span)
end

"""
    timestamp_to_date(ts)

Convert a normalized timestamp to a `DateTime` via `Dates.unix2datetime`.
"""
function timestamp_to_date(ts)
    # NOTE(review): `Dates` is not imported in this file; presumably it is
    # brought into scope by AlphaUtils.ipynb - confirm.
    unix_time = timestamp_to_unix(ts)
    return Dates.unix2datetime(unix_time)
end;

In [7]:
"""
    get_status(status::Symbol)

Return the integer code for a list status. Codes run from `0` (`:none`) up to
`7` (`:rewatching`). Throws a `KeyError` for an unknown status.
"""
function get_status(status::Symbol)
    # listed in ascending code order; the code is the zero-based position
    ordered = [
        :none,
        :wont_watch,
        :dropped,
        :on_hold,
        :plan_to_watch,
        :watching,
        :completed,
        :rewatching,
    ]
    codes = Dict(name => position - 1 for (position, name) in enumerate(ordered))
    return codes[status]
end;

## Reading and writing data

In [8]:
# Column-oriented container for list entries: one vector per attribute.
# Every column defaults to an empty vector, which the rest of this file
# treats as "column not loaded".
@kwdef struct RatingsDataset
    source::Vector{Int32} = Int32[]
    medium::Vector{Int32} = Int32[]
    userid::Vector{Int32} = Int32[]
    itemid::Vector{Int32} = Int32[]
    status::Vector{Int32} = Int32[]
    rating::Vector{Float32} = Float32[]
    update_order::Vector{Int32} = Int32[]
    updated_at::Vector{Float32} = Float32[]
    created_at::Vector{Float32} = Float32[]
    started_at::Vector{Float32} = Float32[]
    finished_at::Vector{Float32} = Float32[]
    progress::Vector{Float32} = Float32[]
    repeat_count::Vector{Int32} = Int32[]
    priority::Vector{Float32} = Float32[]
    sentiment::Vector{Int32} = Int32[]
    sentiment_score::Vector{Float32} = Float32[]
    owned::Vector{Float32} = Float32[]
    # target values; filled in by get_split / write_alpha
    metric::Vector{Float32} = Float32[]
end;

"""
    cat(x::RatingsDataset, y::RatingsDataset)

Append two datasets column by column (`x`'s rows first).

If both datasets have at least one populated column, they must have the same
set of populated columns; otherwise an assertion fails with
`"cat: missing field <name>"`. An entirely empty dataset can be concatenated
with anything.

`medium` is a per-row column like every other field, so it is concatenated
rather than required to be identical in both datasets.
"""
function Base.cat(x::RatingsDataset, y::RatingsDataset)
    columns = fieldnames(RatingsDataset)
    is_populated(d) = any(c -> !isempty(getfield(d, c)), columns)

    # argument validation
    if is_populated(x) && is_populated(y)
        for c in columns
            x_has = !isempty(getfield(x, c))
            y_has = !isempty(getfield(y, c))
            @assert x_has == y_has "cat: missing field $c"
        end
    end

    RatingsDataset(map(c -> vcat(getfield(x, c), getfield(y, c)), columns)...)
end

"""
    filter(x::RatingsDataset, mask::BitVector)

Return a new dataset keeping only the rows where `mask` is `true`.
Columns that are empty (not loaded) stay empty.
"""
function Base.filter(x::RatingsDataset, mask::BitVector)
    keep_rows(column) = isempty(column) ? [] : column[mask]
    kept = map(c -> keep_rows(getfield(x, c)), fieldnames(RatingsDataset))
    return RatingsDataset(kept...)
end;

In [9]:
"""
    SparseArrays.sparse(x::RatingsDataset)

Build an items-by-users sparse matrix whose entries are `x.metric`, with rows
indexed by `x.itemid` and columns by `x.userid`.
"""
function SparseArrays.sparse(x::RatingsDataset)
    # NOTE(review): `x.medium` is a Vector{Int32} in RatingsDataset, whereas
    # num_items interpolates its argument into a file name. The mapping from
    # the stored medium column to a medium name is not visible here, so this
    # call still needs to be reconciled with the struct - confirm and fix.
    n_items = num_items(x.medium)
    SparseArrays.sparse(x.itemid, x.userid, x.metric, n_items, num_users())
end;

In [10]:
# Accepted values for the `split`, `metric` and `medium` arguments of
# `get_split`, which asserts membership in these lists.
const ALL_SPLITS = ["training", "validation", "test", "negative"]
const ALL_METRICS = ["rating", "watch", "plantowatch", "drop"]
const ALL_MEDIUMS = ["manga", "anime"];

In [40]:
"""
    get_split(split, metric, medium, fields)

Load a split, keep only the rows relevant to `metric`, and populate the
`metric` column with the target for that metric:

- `"rating"`: rows with a nonzero rating; target is a copy of the rating.
- `"watch"`: rows whose status is above `:wont_watch`; target is all ones.
- `"plantowatch"`: rows whose status is `:plan_to_watch`; target is all ones.
- `"drop"`: rows whose status is above `:wont_watch` and not `:plan_to_watch`;
  target is 1 where the status is at most `:dropped`, else 0.

Every column not listed in `fields` is emptied before returning, so include
`:metric` in `fields` to keep the target.
"""
function get_split(split::String, metric::String, medium::String, fields::Vector{Symbol})
    known = split in ALL_SPLITS && metric in ALL_METRICS && medium in ALL_MEDIUMS
    @assert known "($split, $metric, $medium) ∉ (split, metric, medium)"

    # the column each metric is derived from; it has to be loaded even when
    # the caller did not ask for it
    derived_from = Dict(
        "rating" => :rating,
        "watch" => :status,
        "plantowatch" => :status,
        "drop" => :status,
    )
    to_load = Set(fields)
    push!(to_load, derived_from[metric])
    # :metric is computed below, not read from the raw split
    delete!(to_load, :metric)
    df = get_raw_split(split, medium, collect(to_load))

    wont_watch = get_status(:wont_watch)
    plan_to_watch = get_status(:plan_to_watch)
    if metric == "rating"
        df = filter(df, df.rating .!= 0)
        df = @set df.metric = copy(df.rating)
    elseif metric == "watch"
        df = filter(df, df.status .> wont_watch)
        df = @set df.metric = ones(Float32, length(df.status))
    elseif metric == "plantowatch"
        df = filter(df, df.status .== plan_to_watch)
        df = @set df.metric = ones(Float32, length(df.status))
    elseif metric == "drop"
        started = (df.status .> wont_watch) .& (df.status .!= plan_to_watch)
        df = filter(df, started)
        dropped = get_status(:dropped)
        df = @set df.metric = Float32.(df.status .<= dropped)
    else
        @assert false
    end

    # clear any columns you are not interested in
    for column in fieldnames(RatingsDataset)
        column in fields && continue
        df = Setfield.set(df, Setfield.PropertyLens{column}(), [])
    end
    return df
end

"""
    get_raw_split(split, medium, fields)

Load the requested columns of `splits/<split>.<medium>.jld2` into a
`RatingsDataset`. Columns not listed in `fields` are left empty.
"""
function get_raw_split(split::String, medium::String, fields::Vector{Symbol})
    file = get_data_path("splits/$(split).$(medium).jld2")
    # Read each dataset individually. `JLD2.load(file, names...)` returns the
    # bare value instead of a tuple when exactly one name is passed, which
    # made `data[i]` index into the column itself for single-field requests.
    columns = JLD2.jldopen(file, "r") do f
        [f[String(field)] for field in fields]
    end
    df = RatingsDataset()
    for (field, column) in zip(fields, columns)
        df = Setfield.set(df, Setfield.PropertyLens{field}(), column)
    end
    df
end;

In [15]:
"""
    write_alpha(model, medium, outdir; splits)

Evaluate `model` on every raw split in `splits` and save the predictions to
`alphas/<outdir>/alpha.jld2` in the data directory.

# Arguments
- `model::Function`: called as `model(userids, itemids)`; must return one
  prediction per row.
- `medium::String`: medium whose splits are scored.
- `outdir::String`: output folder name under `alphas/` (created if missing).

# Keywords
- `splits::Vector{String}`: names of the splits to score.

Predictions are stored under the key `"<split>.<medium>"` as a
`RatingsDataset` whose only populated column is `metric`.
"""
function write_alpha(
    model::Function,
    medium::String,
    outdir::String;
    splits::Vector{String},
)
    alpha = Dict{String,RatingsDataset}()
    for split in splits
        # RatingsDataset names these columns `userid` / `itemid`
        df = get_raw_split(split, medium, [:userid, :itemid])
        x = model(df.userid, df.itemid)
        @assert length(x) == length(df.userid)
        alpha["$split.$medium"] = RatingsDataset(metric = x)
    end
    outdir = mkpath(get_data_path("alphas/$outdir"))
    JLD2.save("$outdir/alpha.jld2", alpha; compress = true)
end;

In [16]:
"""
    read_alpha(alpha, split, metric, medium)

Load the predictions saved by `write_alpha` for `split` and `medium`, and
return them as the `metric` column of a dataset that also carries the
matching `userid` and `itemid` columns.
"""
function read_alpha(alpha::String, split::String, metric::String, medium::String)
    file = get_data_path("alphas/$(alpha)/alpha.jld2")
    a = JLD2.load(file, "$split.$medium")
    df = get_split(split, metric, medium, [:userid, :itemid])
    # NOTE(review): write_alpha scores every row of the raw split, while
    # get_split drops rows depending on `metric`. Whenever that filtering
    # removes rows this assertion fails; the predictions would need to be
    # filtered with the same mask - confirm the intended alignment.
    @assert length(df.userid) == length(a.metric)
    # `a` is the stored RatingsDataset; only its metric column holds data
    @set df.metric = a.metric
end;

In [17]:
"""
    write_params(params, outdir)

Save `params` to `alphas/<outdir>/params.jld2` in the data directory,
creating the folder if needed.
"""
function write_params(params, outdir)
    target_dir = mkpath(get_data_path("alphas/$outdir"))
    target_file = "$target_dir/params.jld2"
    JLD2.save(target_file, params)
end

"""
    read_params(alpha)

Load everything stored in `alphas/<alpha>/params.jld2` in the data directory.
"""
function read_params(alpha)
    params_file = get_data_path("alphas/$alpha/params.jld2")
    return JLD2.load(params_file)
end;

## Loss functions

In [18]:
"""
    weighted_loss(x, y, w, lossfn)

Weighted mean of the elementwise losses `lossfn(x, y)` using weights `w`.
"""
function weighted_loss(x, y, w, lossfn)
    pointwise = lossfn(x, y)
    weighted_total = sum(pointwise .* w)
    return weighted_total / sum(w)
end

"""
    weighted_loss_multithreaded(x, y, w, lossfn)

Deprecated: always throws an `AssertionError`. Kept so callers that still
request the multithreaded evaluator fail loudly.
"""
function weighted_loss_multithreaded(x, y, w, lossfn)
    # The previous implementation split the data into one chunk per thread,
    # accumulated the weighted loss and the weight total of each chunk, and
    # divided the two sums at the end.
    throw(AssertionError("multithreaded loss is deprecated"))
end

"""
    loss(x, y, w, metric; multithreaded = false)

Weighted loss between predictions `x` and targets `y` with weights `w`.

The elementwise loss depends on `metric`: squared error for `"rating"`,
`-y * log(x)` for `"watch"` and `"plantowatch"`, and binary cross-entropy for
`"drop"`. Any other metric trips an assertion.

With `multithreaded = true` the evaluation is routed to
`weighted_loss_multithreaded`.
"""
function loss(x, y, w, metric; multithreaded = false)
    lossfn = if metric == "rating"
        (x, y) -> (x - y) .^ 2
    elseif metric in ["watch", "plantowatch"]
        (x, y) -> -y .* log.(x)
    elseif metric == "drop"
        (x, y) -> -(y .* log.(x) + (1 .- y) .* log.(1 .- x))
    else
        @assert false
    end
    evaluator = multithreaded ? weighted_loss_multithreaded : weighted_loss
    return evaluator(x, y, w, lossfn)
end;

In [19]:
# find β s.t. loss(X * β, y, w) is minimized
function regress(X, y, w, metric)
    if metric == "rating"
        Xw = (X .* sqrt.(w))
        yw = (y .* sqrt.(w))
        # prevent singular matrix
        λ = convert(eltype(y), 1f-9) * LinearAlgebra.I(size(Xw)[2])
        β = (Xw'Xw + λ) \ Xw'yw
    elseif metric in ["watch", "plantowatch", "drop"]
        β = softmax(
            Optim.minimizer(
                Optim.optimize(
                    β -> loss(X * softmax(β), y, w, metric),
                    fill(0.0f0, size(X)[2]),
                    Optim.LBFGS(),
                    autodiff = :forward,
                    Optim.Options(g_tol = 1e-6, iterations = 100),
                ),
            ),
        )
    else
        @assert false
    end
    X * β, β
end;

In [20]:
# function get_split(
#     split::String,
#     task::String,
#     content::String,
#     medium::String;
#     transpose::Bool = false,
#     fields::Union{Nothing,Vector{Symbol}} = nothing,
# )
#     raw_split(split, task, content, medium) =
#         get_raw_split(split, task, content, medium; fields = fields)
#     if content == "explicit"
#         df = raw_split(split, task, content, medium)
#     elseif content == "implicit"
#         df = cat(
#             raw_split(split, task, "explicit", medium),
#             raw_split(split, task, "implicit", medium),
#         )
#         df.rating .= 1
#     elseif content == "ptw"
#         df = raw_split(split, task, content, medium)
#         df.rating .= 1
#     elseif content == "negative"
#         df = raw_split(split, task, content, medium)
#         N = length(get_raw_split(split, task, content, medium; fields = [:user]).user)
#         df = @set df.rating = fill(0.0f0, N)
#     else
#         @assert false
#     end
#     transpose ? df' : df
# end

# function get_raw_split(
#     split::String,
#     task::String,
#     content::String,
#     medium::String;
#     fields::Union{Nothing,Vector{Symbol}} = nothing,
# )
#     if task == "all"
#         return reduce(
#             cat,
#             [
#                 get_raw_split(split, task, content, medium; fields = fields) for
#                 task in ALL_TASKS
#             ],
#         )
#     end
#     @assert split in ALL_SPLITS &&
#             content in ALL_CONTENTS &&
#             task in ALL_TASKS &&
#             medium in ALL_MEDIUMS "($split $content $task $medium)"
#     if split != "test" && content == "negative"
#         return RatingsDataset(medium = medium)
#     end
#     file = get_data_path("splits/$(content).$(task).$(split).$(medium).jld2")
#     df = JLD2.load(file, "dataset")
#     # clear any columns you are not interested in
#     if !isnothing(fields)
#         for f in fieldnames(RatingsDataset)
#             if f ∉ fields && f ∉ [:medium]
#                 df = Setfield.set(df, Setfield.PropertyLens{f}(), [])
#             end
#         end
#     end
#     df
# end

# const ALL_SPLITS = ["training", "validation", "test"]
# const ALL_CONTENTS = ["explicit", "implicit", "ptw", "negative"]
# const ALL_TASKS = ["temporal_causal"]
# const ALL_MEDIUMS = ["manga", "anime"];

In [21]:
# function log_split_loss(
#     model,
#     alphas::Vector{String},
#     task::String,
#     content::String,
#     medium::String,
#     implicit::Bool,
#     splits_to_log::Vector{String};
#     by_split = false,
# )
#     β = nothing
#     for split in splits_to_log
#         if by_split
#             x = model(split, task, content, medium; raw_splits = false)
#         else
#             df = get_split(split, task, content, medium; fields = [:user, :item])
#             x = model(df.user, df.item)
#         end
#         @assert length(x) ==
#                 length(get_split(split, task, content, medium; fields = [:user]).user)
#         if isnothing(β)
#             @assert split == "validation" || splits_to_log == ["test"]
#             _, β = regress(alphas, split, task, content, medium, implicit, x)
#         end
#         @info "$split loss: $(residualized_loss(alphas, task, content, medium, implicit, x, β, split)), β: $β"
#     end
# end

# function write_alpha(
#     model::Function,
#     medium::String,
#     outdir::String;
#     task::String = "all",
#     splits::Vector{String} = ALL_SPLITS, # TODO make this a required arg
#     by_split::Bool = false,
#     log::Bool = true,
#     log_task::Union{String,Nothing} = nothing,
#     log_content::Union{String,Nothing} = nothing,
#     log_alphas::Vector{String} = String[],
#     log_splits::Vector{String} = ["validation", "test"],
# )
#     if log
#         if log_content == "explicit"
#             implicit = false
#         elseif log_content in ["implicit", "ptw"]
#             implicit = true
#         else
#             @assert false
#         end
#         log_split_loss(
#             model,
#             log_alphas,
#             log_task,
#             log_content,
#             medium,
#             implicit,
#             log_splits;
#             by_split = by_split,
#         )
#     else
#         @info "not logging split losses"
#     end

#     tasks = task == "all" ? ALL_TASKS : [task]
#     predictions = Dict()
#     for split in splits
#         for content in ALL_CONTENTS
#             for task in tasks
#                 if by_split
#                     x = model(split, task, content, medium; raw_splits = true)
#                 else
#                     df =
#                         get_raw_split(split, task, content, medium; fields = [:user, :item])
#                     x = model(df.user, df.item)
#                 end
#                 @assert length(x) == length(
#                     get_raw_split(split, task, content, medium; fields = [:user]).user,
#                 )
#                 predictions["$content.$task.$split.$medium"] =
#                     RatingsDataset(rating = x, medium = medium)
#             end
#         end
#     end

#     outdir = mkpath(get_data_path("alphas/$outdir"))
#     JLD2.save("$outdir/predictions.jld2", predictions)
# end

# function read_raw_alpha_impl(
#     alpha::String,
#     split::String,
#     task::String,
#     content::String,
#     medium::String,
# )
#     file = get_data_path("alphas/$(alpha)/predictions.jld2")
#     JLD2.load(file, "$content.$task.$split.$medium")
# end


# function read_raw_alpha(
#     alpha::String,
#     split::String,
#     task::String,
#     content::String,
#     medium::String,
# )
#     # allow read_raw_alpha to be overridden by hiding the implementation
#     if task == "all"
#         return reduce(
#             cat,
#             [read_raw_alpha(alpha, split, task, content, medium) for task in ALL_TASKS],
#         )
#     end
#     @assert split in ALL_SPLITS &&
#             content in ALL_CONTENTS &&
#             task in ALL_TASKS &&
#             medium in ALL_MEDIUMS "($split $content $task $medium)"
#     read_raw_alpha_impl(alpha, split, task, content, medium)
# end

# function read_alpha(
#     alpha::String,
#     split::String,
#     task::String,
#     content::String,
#     medium::String,
# )
#     if content == "explicit"
#         p = read_raw_alpha(alpha, split, task, content, medium)
#     elseif content == "implicit"
#         p = cat(
#             read_raw_alpha(alpha, split, task, "explicit", medium),
#             read_raw_alpha(alpha, split, task, "implicit", medium),
#         )
#     elseif content == "ptw"
#         p = read_raw_alpha(alpha, split, task, "ptw", medium)
#     elseif content == "negative"
#         p = read_raw_alpha(alpha, split, task, content, medium)
#     else
#         @assert false
#     end
#     df = get_split(split, task, content, medium; fields = [:user, :item])
#     @assert length(df.user) == length(p.rating)
#     RatingsDataset(user = df.user, item = df.item, rating = p.rating, medium = medium)
# end;

## Weight decays

In [22]:
# function powerdecay(x, a)
#     x == 0 ? zero(eltype(a)) : sign(x) * abs(x)^a
# end

# function powerdecay(x::Vector, a)
#     y = Array{eltype(a)}(undef, length(x))
#     Threads.@threads for i = 1:length(x)
#         @inbounds y[i] = powerdecay(x[i], a)
#     end
#     y
# end

# function powerlawdecay(x, a)
#     a^x
# end

# function powerlawdecay(x::Vector, a)
#     a .^ x
# end

# function weighting_scheme(scheme::Number)
#     scheme
# end

# function weighting_scheme(scheme::String)
#     if scheme == "constant"
#         return 0.0f0
#     elseif scheme == "inverse"
#         return -1.0f0
#     else
#         @assert false
#         return 0.0f0
#     end
# end;

# function get_user_counts(
#     split::String,
#     task::String,
#     content::String,
#     medium::String,
#     by_item::Bool,
# )
#     df = get_split(
#         split,
#         task,
#         content,
#         medium;
#         transpose = by_item,
#         fields = [:user, :item],
#     )
#     len = by_item ? num_items(medium) : num_users()
#     counts = zeros(eltype(df.rating), len, Threads.nthreads())
#     Threads.@threads for i = 1:length(df.user)
#         @inbounds counts[df.user[i], Threads.threadid()] += 1
#     end
#     vec(sum(counts, dims = 2))
# end

# @memoize function get_counts(
#     split::String,
#     task::String,
#     content::String,
#     medium::String;
#     per_rating::Bool = true,
#     by_item::Bool = false,
# )
#     user_counts = get_user_counts(split, task, content, medium, by_item)
#     if !per_rating
#         return user_counts
#     end

#     df = get_split(
#         split,
#         task,
#         content,
#         medium;
#         transpose = by_item,
#         fields = [:user, :item],
#     )
#     counts = Array{eltype(user_counts)}(undef, length(df.user))
#     Threads.@threads for i = 1:length(counts)
#         @inbounds counts[i] = user_counts[df.user[i]]
#     end
#     counts
# end

# function get_weights(
#     split::String,
#     task::String,
#     content::String,
#     medium::String,
#     scheme::String,
# )
#     powerdecay(get_counts(split, task, content, medium), weighting_scheme(scheme))
# end;

## Regressions

In [23]:
# # given a matrix of features X, a vector of true labels y, and
# # a vector of weights w, a regression will find the β
# # that minimizes the weighted loss between X * β and y

# function regress(X, y, w, implicit::Bool)
#     if implicit
#         β = softmax(
#             Optim.minimizer(
#                 Optim.optimize(
#                     β -> loss(X * softmax(β), y, w, implicit; multithreaded = false),
#                     fill(0.0f0, size(X)[2]),
#                     Optim.LBFGS(),
#                     autodiff = :forward,
#                     Optim.Options(g_tol = 1e-6, iterations = 100),
#                 ),
#             ),
#         )
#     else
#         λ = Xw = (X .* sqrt.(w))
#         yw = (y .* sqrt.(w))
#         # prevent singular matrix
#         λ = convert(eltype(y), 1f-9) * LinearAlgebra.I(size(Xw)[2])
#         β = (Xw'Xw + λ) \ Xw'yw
#     end
#     X * β, β
# end;

In [24]:
# # regress the given features on the validation set
# function regress(
#     alphas::Vector{String},
#     task::String,
#     content::String,
#     medium::String,
#     implicit::Bool,
#     x = nothing,
# )
#     regress(alphas, "validation", task, content, medium, implicit, x)
# end

# function regress(
#     alphas::Vector{String},
#     split::String,
#     task::String,
#     content::String,
#     medium::String,
#     implicit::Bool,
#     x = nothing,
# )
#     X = regression_features(alphas, split, task, content, medium, implicit, x)
#     y = get_split(split, task, content, medium; fields = [:rating]).rating
#     w = get_weights(split, task, content, medium, "inverse")
#     regress(X, y, w, implicit)
# end

# function regress(
#     alphas::Vector{String},
#     task::String,
#     content::String,
#     medium::String,
#     implicit::Bool,
# )
#     regress(alphas, task, content, medium, implicit, nothing)
# end

# # concatenates x, if given, with the alphas
# function regression_features(
#     alphas::Vector{String},
#     split::String,
#     task::String,
#     content::String,
#     medium::String,
#     implicit::Bool,
#     x = nothing,
# )
#     ncols = length(alphas) + (isnothing(x) ? 0 : 1) + implicit
#     shape =
#         isnothing(x) ? get_split(split, task, content, medium; fields = [:rating]).rating :
#         x
#     X = Array{eltype(shape),2}(undef, length(shape), ncols)
#     @showprogress for j = 1:length(alphas)
#         @inbounds X[:, j] = read_alpha(alphas[j], split, task, content, medium).rating
#     end

#     if implicit
#         # add a baseline feature for non-degeneracy
#         X[:, length(alphas)+1] .= 1.0f0 / num_items(medium)
#     end
#     if !isnothing(x)
#         X[:, end] = x
#     end
#     X
# end

# # linearly combine the given alphas
# function read_alpha(
#     alphas::Vector{String},
#     split::String,
#     task::String,
#     content::String,
#     medium::String,
#     implicit::Bool,
# )
#     _, β = regress(alphas, task, content, medium, implicit)
#     X = regression_features(alphas, split, task, content, medium, implicit)
#     df = get_split(split, task, content, medium; fields = [:user, :item])
#     RatingsDataset(user = df.user, item = df.item, rating = X * β, medium = medium)
# end

# # performs a regression on the validation set and then 
# # calculates the validation loss of that linear combination
# function residualized_loss(
#     alphas::Vector{String},
#     task::String,
#     content::String,
#     medium::String,
#     implicit::Bool,
#     x,
# )
#     split = "validation"
#     x, β = regress(alphas, task, content, medium, implicit, x)
#     y = get_split(split, task, content, medium; fields = [:rating]).rating
#     loss(
#         x,
#         y,
#         get_weights(split, task, content, medium, "inverse"),
#         implicit;
#         multithreaded = true,
#     )
# end

# function residualized_loss(
#     alphas::Vector{String},
#     split::String,
#     task::String,
#     content::String,
#     medium::String,
#     implicit::Bool,
#     x,
#     β,
# )
#     X = regression_features(alphas, split, task, content, medium, implicit, x)
#     x = X * β
#     y = get_split(split, task, content, medium; fields = [:rating]).rating
#     loss(
#         x,
#         y,
#         get_weights(split, task, content, medium, "inverse"),
#         implicit;
#         multithreaded = true,
#     )
# end;