# Common utilities for all alphas

In [None]:
using CSV
using DataFrames
using Dates
using FileIO
using Flux
using JLD2
using JupyterFormatter
using LinearAlgebra
using LoggingExtras
using Memoize
using Optim
using Parameters
using ProgressMeter
using SparseArrays
using Statistics

# General utils

In [None]:
# TODO upstream this into the ProgressMeter
#
# `@tprogress loop` attaches a ProgressMeter bar to a `for` loop, which may be
# wrapped in `Threads.@threads` (or an imported `@threads`).
#   - threaded loops: only iterations running on thread 1 tick the bar, so the
#     bar length is that thread's approximate share, length(range) / nthreads()
#   - plain loops: every iteration ticks and the bar length is length(range)
# The loop body is extended in place with the tick call.
macro tprogress(expr)
    loop = expr
    threaded = false
    if loop.head == :macrocall &&
       loop.args[1] in (:(Threads.var"@threads"), Symbol("@threads"))
        # the `for` expression is the last argument of the macrocall
        loop = loop.args[end]
        threaded = true
    end

    p = gensym()
    # `loop.args[1]` is the iteration spec `i = range`; take the range
    r = loop.args[1].args[end]
    nticks = if threaded
        :(Int(round(length($(esc(r))) / Threads.nthreads())))
    else
        :(length($(esc(r))))
    end
    ex = quote
        n = $nticks
        global $p = Progress(n; showspeed = true)
        $(esc(expr))
        finish!($p)
    end

    # Only one thread ticks in a threaded loop, so concurrent iterations never
    # update the bar at the same time; a plain loop ticks on every iteration.
    update = if threaded
        quote
            if Threads.threadid() == 1
                next!($p)
            end
        end
    else
        :(next!($p))
    end
    # `expr` shares this loop node, so the escaped copy in `ex` sees the update
    push!(loop.args[end].args, update)

    ex
end;

In [None]:
# Auto-format notebook code cells. `enable_autoformat` presumably comes from
# JupyterFormatter (imported at the top of this file); the trailing `;`
# suppresses the cell's output.
enable_autoformat();

In [None]:
# Keep BLAS single-threaded: this file parallelizes with Julia threads
# (`Threads.@threads`) instead of relying on BLAS's own thread pool.
LinearAlgebra.BLAS.set_num_threads(1);

In [None]:
# Loss functions
#
# Each metric takes `truth` and `pred` plus optional per-element `weights`;
# the two-argument forms weight every element equally.

# Weighted mean of `x` with weights `w`.
function wmean(x, w)
    sum(x .* w) / sum(w)
end

# Weights of one for every element of `x`, with `x`'s element type.
function equal_weight(x)
    ones(eltype(x), length(x))
end

# (Weighted) mean squared error.
function mse(truth, pred, weights)
    wmean((truth .- pred) .^ 2, weights)
end

function mse(truth, pred)
    mse(truth, pred, equal_weight(truth))
end

# (Weighted) root mean squared error.
function rmse(truth, pred, weights)
    sqrt(mse(truth, pred, weights))
end

function rmse(truth, pred)
    rmse(truth, pred, equal_weight(truth))
end

# (Weighted) mean absolute error.
function mae(truth, pred, weights)
    wmean(abs.(truth .- pred), weights)
end

function mae(truth, pred)
    mae(truth, pred, equal_weight(truth))
end

# (Weighted) coefficient of determination. The baseline is the best constant
# prediction under the same weights, i.e. the weighted mean of `truth`. The
# unweighted mean is not the weighted-MSE minimizer, so using it would
# overstate R² whenever the weights are unequal. With equal weights this is
# the ordinary R².
function r2(truth, pred, weights)
    1 - mse(truth, pred, weights) / mse(truth, wmean(truth, weights), weights)
end

function r2(truth, pred)
    r2(truth, pred, equal_weight(truth))
end

# Cross-entropy loss where `p` holds the predicted probability of the true
# label for every example. Probabilities are clamped to at least `ϵ` before
# taking the log, so a zero probability gives a large but finite loss.
function sparse_crossentropy(p, weights, ϵ = 1e-16)
    -wmean(log.(clamp.(p, ϵ, Inf)), weights)
end;

# Unweighted variant. `ϵ` is restricted to `Real` so this method does not
# collide with the `(p, weights)` method generated by the default argument
# above. Without the restriction, the later definition would overwrite it, and
# `sparse_crossentropy(p, weights)` would treat `weights` as the clamp floor.
function sparse_crossentropy(p, ϵ::Real = 1e-16)
    sparse_crossentropy(p, equal_weight(p), ϵ)
end;

In [None]:
# stop training if the parameters have converged
#
# State used by `stop!(::convergence_stopper, params)`:
#   tolerance   - stop once the largest absolute element-wise change between
#                 the parameters of two consecutive calls is below this value
#   max_iters   - stop unconditionally once `stop!` has been called more than
#                 this many times
#   params      - deep copy of the parameters passed on the latest call
#   prev_params - deep copy of the parameters passed on the call before that
#   iters       - number of `stop!` calls so far
# NOTE(review): `params` is presumably a vector of parameter arrays; `stop!`
# subtracts the two snapshots element-wise. Confirm against callers.
@with_kw mutable struct convergence_stopper
    tolerance::AbstractFloat
    max_iters = Inf
    params::AbstractVector
    prev_params::AbstractVector
    iters = 0
end

# Build a stopper with the given tolerance and no recorded parameters yet.
function convergence_stopper(tolerance; max_iters=Inf)
    convergence_stopper(tolerance = tolerance, max_iters=max_iters, params = [], prev_params = [])
end

# Record a snapshot of `params` and report whether training should stop.
# Returns true once the call count exceeds `max_iters`, or when no parameter
# element moved by `tolerance` or more since the previous call. The first call
# only records a baseline and returns false.
function stop!(x::convergence_stopper, params)
    x.iters += 1
    x.iters > x.max_iters && return true

    snapshot = deepcopy(params)
    if x.iters == 1
        x.params = snapshot
        return false
    end

    # the old snapshot is never mutated, so it can become `prev_params` as-is
    x.prev_params = x.params
    x.params = snapshot
    deltas = x.params - x.prev_params
    return maximum(d -> maximum(abs, d), deltas) < x.tolerance
end;

In [None]:
# stop training if the loss function has stopped decreasing
#
# State used by `stop!(::early_stopper, loss)`:
#   max_iters                 - stop unconditionally once `stop!` has been
#                               called more than this many times
#   patience                  - stop once more than this many consecutive
#                               calls fail to improve on the recorded loss
#   min_rel_improvement       - relative margin by which a new loss must beat
#                               the recorded loss to count as an improvement
#   iters                     - number of `stop!` calls so far
#   iters_without_improvement - consecutive calls without an improvement
#   loss                      - recorded (best) loss; NaN until the first call
@with_kw mutable struct early_stopper
    max_iters = Inf
    patience = Inf
    min_rel_improvement = 0
    iters = 0
    iters_without_improvement = 0
    loss = NaN
end

# Record `loss` and report whether training should stop. Returns true once the
# call count exceeds `max_iters`, or once more than `patience` consecutive calls
# have failed to beat the recorded best loss by the relative margin
# `min_rel_improvement`. The first call only records a baseline.
function stop!(x::early_stopper, loss)
    x.iters += 1
    if x.iters > x.max_iters
        return true
    end

    if x.iters == 1
        x.loss = loss
        return false
    end

    # Move the best loss toward -Inf by the required relative margin. For a
    # negative best loss, `x.loss * (1 - r)` would move the threshold toward
    # zero and accept a worse loss as an improvement. For non-negative losses
    # this equals `x.loss * (1 - r)`.
    threshold = if x.loss >= 0
        x.loss * (1 - x.min_rel_improvement)
    else
        x.loss * (1 + x.min_rel_improvement)
    end
    if loss < threshold
        x.loss = loss
        x.iters_without_improvement = 0
    else
        x.iters_without_improvement += 1
    end
    x.iters_without_improvement > x.patience
end;

In [None]:
# Logger that flushes after every log statement
#
# Wraps any `AbstractLogger` and forwards every logging call to it. After each
# message it flushes the wrapped logger's output stream, when that logger is a
# `ConsoleLogger`; other loggers are forwarded to without flushing. The type
# parameter lets the default constructor accept any logger directly, e.g.
# `FlushLogger(ConsoleLogger(io))` or `FlushLogger(SimpleLogger(io))`.
struct FlushLogger{L<:AbstractLogger} <: AbstractLogger
    logger::L
end

function Logging.handle_message(logger::FlushLogger, args...; kwargs...)
    Logging.handle_message(logger.logger, args...; kwargs...)
    _flush_log_stream(logger.logger)
end

# Flush the stream a logger writes to, for logger types whose stream is known.
_flush_log_stream(logger::ConsoleLogger) = flush(logger.stream)
_flush_log_stream(::AbstractLogger) = nothing

Logging.shouldlog(logger::FlushLogger, arg...) = Logging.shouldlog(logger.logger, arg...)
Logging.min_enabled_level(logger::FlushLogger) = Logging.min_enabled_level(logger.logger)
Logging.catch_exceptions(logger::FlushLogger) = Logging.catch_exceptions(logger.logger)

# `meta_formatter` for ConsoleLogger: returns (color, prefix, suffix).
# Uses ANSI colors 4/6/3/1 for below-Info / Info / Warn / Error-and-above,
# prints "Warning" instead of "Warn", and leaves the suffix empty (no
# module/file/line trailer).
function logging_meta_formatter(level, _module, group, id, file, line)
    color = if level < Logging.Info
        4
    elseif level < Logging.Warn
        6
    elseif level < Logging.Error
        3
    else
        1
    end
    label = level == Logging.Warn ? "Warning" : string(level)
    return color, label * ':', ""
end;

In [None]:
# Send all log output both to stderr and to `<outdir>/log` (the directory is
# created if needed; the file is opened with `write = true`). Every message is
# prefixed with a "yyyymmdd HH:MM:SS" timestamp, each sink flushes after every
# message, and the tee is installed as the global logger.
# NOTE(review): the log file handle is never closed explicitly; it lives as
# long as the global logger holds it.
function redirect_logging(outdir)
    stamp_format = "yyyymmdd HH:MM:SS"
    with_timestamp(inner) = TransformerLogger(inner) do record
        stamped = "$(Dates.format(now(), stamp_format)) $(record.message)"
        merge(record, (; message = stamped))
    end
    sink(io) = with_timestamp(
        FlushLogger(
            ConsoleLogger(io, Logging.Info; meta_formatter = logging_meta_formatter),
        ),
    )

    logdir = mkpath(outdir)
    global_logger(TeeLogger(sink(stderr), sink(open("$(logdir)/log", write = true))))
end;

# Install the tee logger for this alpha's output directory.
# NOTE(review): `name` is not defined in this file; presumably the notebook
# that includes these utilities sets it to the alpha's name — confirm.
redirect_logging("../../data/alphas/$name");

In [None]:
# A custom split layer for Flux: feeds the same input to every path and
# returns the collection of outputs, e.g. `Split(a, b)(x) == (a(x), b(x))`.
# NOTE(review): `Split(layer)` with a single non-tuple argument hits the
# default one-argument constructor and stores the bare layer, not `(layer,)`;
# pass at least two paths or an explicit tuple.
struct Split{T}
    paths::T
end

Split(paths...) = Split(paths)

Flux.@functor Split

function (m::Split)(x::AbstractArray)
    return map(path -> path(x), m.paths)
end

# Alpha-specific utils

In [None]:
# Parallel vectors describing (user, item, rating) triples: entry `i` of each
# vector belongs to the same triple.
# NOTE(review): equal lengths are assumed by callers but not enforced here.
@with_kw struct RatingsDataset
    user::Vector{Int32}
    item::Vector{Int32}
    rating::Vector{Float32}
end;

# `x'` swaps the roles of users and items; the ratings vector is reused as-is.
Base.adjoint(x::RatingsDataset) =
    RatingsDataset(user = x.item, item = x.user, rating = x.rating);

# Load the named split from the shared splits file.
function get_split(split)
    @assert split in ["training", "validation", "test", "implicit", "implicit_training"]
    splits_file = "../../data/splits/splits.jld2"
    return load(splits_file, split)
end;

# As above, but return the adjoint (`'`) when `transpose` is true; for a
# RatingsDataset that swaps users and items.
function get_split(split, transpose)
    dataset = get_split(split)
    transpose || return dataset
    return dataset'
end;

In [None]:
# Some sparse matrix operations require indices to be Int64
#
# Same layout as RatingsDataset (parallel user/item/rating vectors), but with
# Int64 user and item indices.
@with_kw struct RatingsDataset64
    user::Vector{Int64}
    item::Vector{Int64}
    rating::Vector{Float32}
end;

# Widen a RatingsDataset's indices to Int64. All three vectors are freshly
# allocated by the broadcasts, so the result does not alias `x`.
function RatingsDataset64(x::RatingsDataset)
    widen_index(v) = convert.(Int64, v)
    users = widen_index(x.user)
    items = widen_index(x.item)
    ratings = convert.(Float32, x.rating)
    return RatingsDataset64(users, items, ratings)
end;

In [None]:
# Concatenate two datasets row-wise: all of `x`'s entries, then all of `y`'s.
function Base.cat(x::RatingsDataset, y::RatingsDataset)
    users = vcat(x.user, y.user)
    items = vcat(x.item, y.item)
    ratings = vcat(x.rating, y.rating)
    return RatingsDataset(users, items, ratings)
end;

In [None]:
# Number of rows in the anime-to-uid mapping, read once and memoized.
@memoize function num_items()
    mapping = DataFrame(CSV.File("../../data/processed_data/anime_to_uid.csv"))
    return length(mapping.uid)
end;

In [None]:
# Load the saved predictions of `alpha` for one of the evaluation splits.
function get_alpha(alpha, split)
    @assert split in ["training", "validation", "test"]
    return load("../../data/alphas/$(alpha)/predictions.jld2", split)
end;

In [None]:
# Weight for a count `x` under exponent `scheme`: `x^scheme`, or zero when
# `x == 0` (so a zero count never produces 0^negative = Inf). The zero has the
# same type `x^scheme` would have, which keeps the function type-stable.
function weighting_scheme(x, scheme::Number)
    return x == 0 ? zero(promote_type(typeof(x), typeof(scheme))) : x^scheme
end;

# Map a scheme name to its exponent: "linear" => 1, "constant" => 0,
# "inverse" => -1. Unknown names throw an AssertionError naming the bad input.
function weighting_scheme(scheme::String)
    if scheme == "linear"
        return 1
    elseif scheme == "constant"
        return 0
    elseif scheme == "inverse"
        return -1
    end
    throw(AssertionError(
        "unknown weighting scheme \"$scheme\" (expected \"linear\", \"constant\" or \"inverse\")",
    ))
end;

# Count how many entries each user id has in `split`.
#
# Returns a `maximum(split.user) × 1` matrix with element type
# `eltype(split.rating)`; entry `u` is the number of indices `i` (over
# `eachindex(split.rating)`) with `split.user[i] == u`.
#
# Counted serially on purpose. Accumulating into per-thread columns indexed by
# `Threads.threadid()` inside `Threads.@threads` is not safe when tasks can
# migrate between threads, and `threadid()` can exceed `nthreads()` when
# interactive threads exist. A single pass over the indices is cheap anyway.
function get_counts(split)
    counts = zeros(eltype(split.rating), maximum(split.user), 1)
    for i in eachindex(split.rating)
        counts[split.user[i], 1] += 1
    end
    counts
end

# Weights derived from how many entries each user has in the named split:
# weight = count^scheme, or 0 for a zero count (see `weighting_scheme`).
# With scheme = -1, each user's entries sum to a total weight of 1.
#
# by_item    - count per item instead of per user (uses the transposed split)
# per_rating - NOTE(review): despite the name, returns one weight per user (or
#              item) id, i.e. per row of the counts; by default one weight per
#              entry of the split is returned
function get_weights(split, scheme::Number; per_rating = false, by_item = false)
    df = get_split(split, by_item)
    counts = get_counts(df)

    if per_rating
        return weighting_scheme.(vec(counts), scheme)
    end

    # Promote rather than use `typeof(scheme)`: an integer exponent applied to
    # Float32 counts yields fractional weights that an Int buffer cannot hold.
    outtype = promote_type(typeof(scheme), eltype(counts))
    weights = zeros(outtype, length(df.user))
    Threads.@threads for i in eachindex(weights)
        weights[i] = weighting_scheme(counts[df.user[i]], scheme)
    end
    weights
end

# Memoized convenience form taking a scheme name ("linear", "constant",
# "inverse"); the exponent is converted to the split's rating element type.
@memoize function get_weights(split, scheme::String)
    rating_type = eltype(get_split(split).rating)
    exponent = convert(rating_type, weighting_scheme(scheme))
    return get_weights(split, exponent)
end;

In [None]:
# Memoized coefficients of a no-intercept linear model that predicts the
# validation ratings from the validation predictions of each alpha in `alphas`.
@memoize function get_residual_beta(alphas)
    split = "validation"
    target = get_split(split).rating
    features = zeros(length(target), length(alphas))
    @tprogress Threads.@threads for j = 1:length(alphas)
        features[:, j] = get_alpha(alphas[j], split).rating
    end

    # weight each user equally: scaling rows by sqrt(w) makes `\` solve the
    # least-squares problem weighted by the "inverse" per-user weights w
    sqrt_w = sqrt.(get_weights(split, "inverse"))
    return (features .* sqrt_w) \ (target .* sqrt_w)
end

# Ratings of `split` minus the fitted linear combination of the alphas'
# predictions (coefficients from `get_residual_beta`). Each subtraction builds
# a new vector, so the loaded split's ratings are not modified.
function get_residuals(split, alphas)
    coefs = get_residual_beta(alphas)
    data = get_split(split)
    residual = data.rating
    @showprogress for j in eachindex(alphas)
        residual = residual - coefs[j] * get_alpha(alphas[j], split).rating
    end
    return RatingsDataset(data.user, data.item, residual)
end;

In [None]:
# Predict every saved split with `model`, log diagnostics on the non-test
# splits, and save the predictions to ../../data/alphas/<outdir>/predictions.jld2.
#
# model           - callable as `model(users, items)`, one prediction per entry
# outdir          - directory under ../../data/alphas; defaults to the global
#                   `name`, which this file does not define
# residual_alphas - alphas whose fitted combination is subtracted from the
#                   ratings before predicting (see `get_residuals`)
# implicit        - score predictions with cross-entropy instead of
#                   RMSE/MAE/R2, and skip saving the training split
function write_predictions(model; outdir = name, residual_alphas, implicit = false)
    splits_to_save = ["training", "validation", "test"]
    if implicit
        # the implicit predictions for the training set are not
        # worth storing because we don't use them for residualization
        splits_to_save = splits_to_save[2:end]
    end
    # don't cheat by peeking at the test set
    splits_to_log = splits_to_save[1:end-1]

    predictions = Dict()
    for split in splits_to_save
        df = get_residuals(split, residual_alphas)
        pred = model(df.user, df.item)
        if split in splits_to_log
            if !implicit
                truth = df.rating
                # rescale by the least-squares factor before scoring
                fitted = (pred \ truth) * pred
                @info "$(split) set: RMSE $(rmse(truth, fitted)) " *
                      "MAE $(mae(truth, fitted)) R2 $(r2(truth, fitted))"
                weights = get_weights(split, "inverse")
                sqrt_w = sqrt.(weights)
                fitted = ((pred .* sqrt_w) \ (truth .* sqrt_w)) * pred
                @info "$(split) set weighted-loss: RMSE $(rmse(truth, fitted, weights)) " *
                      "MAE $(mae(truth, fitted, weights)) R2 $(r2(truth, fitted, weights))"
            else
                @info "$(split) set: Cross-Entropy loss $(sparse_crossentropy(pred))"
                weights = get_weights(split, "inverse")
                # Pass ϵ (the default, 1e-16) explicitly so the weighted
                # three-argument method is selected; the two-argument form may
                # dispatch to the `(p, ϵ)` method and misuse `weights` as ϵ.
                weighted_ce = sparse_crossentropy(pred, weights, 1e-16)
                @info "$(split) set weighted-loss: Cross-Entropy loss $(weighted_ce)"
            end
        end
        predictions[split] = RatingsDataset(df.user, df.item, pred)
    end

    outdir = mkpath("../../data/alphas/$outdir")
    save("$outdir/predictions.jld2", predictions)
end;

In [None]:
# Save `params` to ../../data/alphas/<outdir>/params.jld2, creating the
# directory if needed. `outdir` defaults to the global `name`, which this file
# does not define. NOTE(review): `params` is presumably a Dict with String
# keys, as FileIO's `save` expects — confirm.
function write_params(params; outdir = name)
    dir = mkpath("../../data/alphas/$outdir")
    return save("$dir/params.jld2", params)
end;

In [None]:
# Load everything `write_params` saved for `alpha`.
read_params(alpha) = load("../../data/alphas/$alpha/params.jld2");