# Common utilities for all alphas

In [None]:
using CSV
using Dates
using FileIO
using JLD2
using JupyterFormatter
using LinearAlgebra
using LoggingExtras
using Memoize
using Optim
using ProgressMeter
using SparseArrays
using Statistics

# General utils

In [None]:
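# Notebooks that include these utilities must define these globals first:
# `name` is read at top level when the log directory is created, and
# `residual_alphas` is the default for `write_predictions`. For example:
# name = "Alpha"
# residual_alphas = []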

In [None]:
# Keep BLAS single-threaded. NOTE(review): presumably this avoids oversubscribing
# cores when linear algebra runs alongside `Threads.@threads` loops — confirm.
BLAS.set_num_threads(1)

In [None]:
# Attach a ProgressMeter bar to a `for` loop, optionally prefixed by `Threads.@threads`.
# The bar length is the loop range length divided by `Threads.nthreads()` (rounded).
# Only iterations that run with `Threads.threadid() == 1` advance the bar, so the
# displayed progress is an approximation of overall progress.
# NOTE(review): assumes the loop header has the form `for i = range`, since the range
# is read from `loop.args[1].args[end]` — other header shapes are untested.
# NOTE(review): the range expression is evaluated twice (once for `length`, once by
# the loop itself), so it should be side-effect free.
macro tprogress(expr)
    # let the @progress macro work with Threads.@threads
    loop = expr
    # peel off the `Threads.@threads` macrocall to reach the underlying `for` expression
    if loop.head == :macrocall && loop.args[1] == :(Threads.var"@threads")
        loop = loop.args[end]
    end
    
    p = gensym()    
    # range expression from the loop header (`for j = <range>`)
    r = loop.args[1].args[end]
    ex = quote
        # approximate number of iterations that thread 1 will execute
        n = Int(round(length($(esc(r))) / Threads.nthreads()))
        # presumably declared global so the spliced-in loop body can reach the bar
        global $p = Progress(n; showspeed=true)
        $(esc(expr))
    end
    
    update = quote
        if Threads.threadid() == 1
            next!($p)
        end
    end
    # Append the update to the loop body. This mutates the AST that `expr` refers to,
    # so the update is also present in the `$(esc(expr))` spliced into `ex` above.
    push!(loop.args[end].args, update)    
    
    ex    
end;

In [None]:
# Enable JupyterFormatter's autoformatting (presumably reformats cells as they run).
enable_autoformat();

In [None]:
# Regression metrics comparing observed `truth` with predictions `pred`.
# Arguments are combined by broadcasting, so `pred` may be a scalar
# (`r2` relies on this when scoring against `mean(truth)`).

"Mean squared error between `truth` and `pred`."
mse(truth, pred) = mean(@. (truth - pred)^2)

"Root mean squared error: the square root of `mse(truth, pred)`."
rmse(truth, pred) = sqrt(mse(truth, pred))

"Mean absolute error between `truth` and `pred`."
mae(truth, pred) = mean(@. abs(truth - pred))

"Coefficient of determination: one minus the ratio of the prediction MSE to the MSE of always predicting `mean(truth)`."
r2(truth, pred) = 1 - mse(truth, pred) / mse(truth, mean(truth));

In [None]:
# Logger wrapper that forwards every record to an inner `ConsoleLogger` and
# flushes that logger's stream after each message, so lines written to a
# buffered stream (such as a log file) become visible immediately.
struct FlushLogger <: AbstractLogger
    logger::ConsoleLogger
end

# Only `ConsoleLogger` can be wrapped, because `handle_message` flushes its
# `stream` field. Any other logger type is rejected with a clear error. Calling
# `FlushLogger(logger)` here would dispatch back to this same method and recurse
# until the stack overflowed.
function FlushLogger(logger::AbstractLogger)
    throw(ArgumentError("FlushLogger can only wrap a ConsoleLogger, got $(typeof(logger))"))
end

# Emit the record through the wrapped logger, then flush its stream.
function Logging.handle_message(logger::FlushLogger, args...; kwargs...)
    Logging.handle_message(logger.logger, args...; kwargs...)
    flush(logger.logger.stream)
end

# Filtering and exception policy are delegated unchanged to the wrapped logger.
Logging.shouldlog(logger::FlushLogger, arg...) = Logging.shouldlog(logger.logger, arg...)
Logging.min_enabled_level(logger::FlushLogger) = Logging.min_enabled_level(logger.logger)
Logging.catch_exceptions(logger::FlushLogger) = Logging.catch_exceptions(logger.logger)

# `meta_formatter` for `ConsoleLogger`: returns `(color, prefix, suffix)`.
# The color is an ANSI color number chosen by severity. The prefix is the level
# name followed by ':' ("Warning" is spelled out for Warn). The suffix is always
# empty, so module/file/line metadata is not printed.
function logging_meta_formatter(level, _module, group, id, file, line)
    if level < Logging.Info
        color = 4
    elseif level < Logging.Warn
        color = 6
    elseif level < Logging.Error
        color = 3
    else
        color = 1
    end
    label = level == Logging.Warn ? "Warning" : string(level)
    return color, label * ':', ""
end

In [None]:
# Improved logging
const date_format = "yyyymmdd HH:MM:SS"

# Wrap `logger` so that each message is prefixed with the current local time
# (`now()`), formatted with `date_format`. Uses LoggingExtras' TransformerLogger.
function timestamp_logger(logger)
    add_timestamp(log) =
        merge(log, (; message = "$(Dates.format(now(), date_format)) $(log.message)"))
    return TransformerLogger(add_timestamp, logger)
end

outdir = mkpath("../../data/alphas/$name");
# Send every record at Debug level and above to two sinks: stderr and
# `<outdir>/log`. Each sink is flushed after every message and timestamped.
# The log file is opened here with `write = true` (which truncates any
# previous log) and is never explicitly closed.
let
    stderr_sink =
        ConsoleLogger(stderr, Logging.Debug; meta_formatter = logging_meta_formatter)
    file_sink = ConsoleLogger(
        open("$(outdir)/log", write = true),
        Logging.Debug;
        meta_formatter = logging_meta_formatter,
    )
    global_logger(
        TeeLogger(
            timestamp_logger(FlushLogger(stderr_sink)),
            timestamp_logger(FlushLogger(file_sink)),
        ),
    )
end;

# Alpha-specific utils

In [None]:
# Ratings stored as three parallel vectors
# (presumably: `user[i]` gave `item[i]` the rating `rating[i]`).
# Constructible positionally or with keywords (`RatingsDataset(user = ..., ...)`).
Base.@kwdef struct RatingsDataset
    user::Vector{Int64}
    item::Vector{Int64}
    rating::Vector{Float64}
end;

"""
    adjoint(x::RatingsDataset)

Return a dataset whose `user` and `item` vectors are swapped (`x'`).
The result shares its vectors with `x`; nothing is copied.
"""
Base.adjoint(x::RatingsDataset) =
    RatingsDataset(user = x.item, item = x.user, rating = x.rating);

"""
    get_split(split)

Load the `split` entry ("training", "validation" or "test") from
`../../data/splits/splits.jld2`. Any other name fails the `@assert`.
"""
function get_split(split)
    @assert split in ["training", "validation", "test"]
    return load("../../data/splits/splits.jld2", split)
end;

In [None]:
# Concatenate two datasets: the rows of `y` follow the rows of `x`.
Base.cat(x::RatingsDataset, y::RatingsDataset) = RatingsDataset(
    vcat(x.user, y.user),
    vcat(x.item, y.item),
    vcat(x.rating, y.rating),
);

In [None]:
# Number of items: the length of the `uid` column in the anime → uid mapping
# file. Memoized, so the CSV is parsed only on the first call.
@memoize function num_items()
    # Read the column straight from `CSV.File` instead of building a DataFrame:
    # DataFrames is not among this file's `using` lines, so `DataFrame(...)`
    # raises UndefVarError unless some other notebook happens to load it.
    uids = CSV.File("../../data/processed_data/anime_to_uid.csv").uid
    length(uids)
end;

In [None]:
"""
    get_alpha(alpha, split)

Load the saved `split` predictions ("training", "validation" or "test") of
`alpha` from `../../data/alphas/<alpha>/predictions.jld2`.
Any other split name fails the `@assert`.
"""
function get_alpha(alpha, split)
    @assert split in ["training", "validation", "test"]
    return load("../../data/alphas/$(alpha)/predictions.jld2", split)
end;

In [None]:
# Fit weights β by least squares (`X \ y`, no intercept column) so that the
# validation ratings are approximated by a linear combination of the given
# alphas' validation predictions. The result is cached by `@memoize`.
# NOTE(review): assumes each alpha's validation predictions are row-aligned
# with `get_split("validation")` — confirm.
# NOTE(review): the cache is keyed on `alphas`; whether equal-but-distinct
# vectors share an entry depends on Memoize's default cache dict — confirm.
@memoize function get_residual_β(alphas)
    # train a linear model on the validation set
    y = get_split("validation").rating
    X = zeros(length(y), length(alphas))
    # each iteration writes its own column j, so threaded writes do not overlap;
    # the `for j = range` form is what @tprogress parses
    @tprogress Threads.@threads for j = 1:length(alphas)
        X[:, j] = get_alpha(alphas[j], "validation").rating
    end
    X \ y    
end

"""
    get_residuals(split, alphas)

Return the `split` ratings minus the linear combination of `alphas`'
predictions, weighted by the validation-fitted `get_residual_β(alphas)`.
The result is a `RatingsDataset` with the split's original users and items.
"""
function get_residuals(split, alphas)
    weights = get_residual_β(alphas)
    data = get_split(split)
    residual = data.rating
    @showprogress for k in eachindex(alphas)
        # rebinds `residual` to a new array; the loaded ratings are not mutated
        residual = residual - weights[k] * get_alpha(alphas[k], split).rating
    end
    return RatingsDataset(data.user, data.item, residual)
end

In [None]:
"""
    write_predictions(model; save_training = false, outdir = name, residual_alphas = residual_alphas)

Score `model` on the residualized splits and save its raw predictions.

For each split, the ratings are first residualized against `residual_alphas`
(via `get_residuals`), then `model(users, items)` predicts those residuals.
For logged splits, metrics are reported after rescaling the predictions by
the least-squares scalar `β` fitting `truth ≈ β * pred`. The saved
predictions are the unscaled `pred`.

# Keywords
- `save_training`: also save and log the (large) training split.
- `outdir`: directory name under `../../data/alphas/`; defaults to the global `name`.
- `residual_alphas`: alphas removed from the ratings first; defaults to the
  global `residual_alphas`.

Writes `../../data/alphas/<outdir>/predictions.jld2` with one entry per saved split.
"""
function write_predictions(
    model;
    save_training = false, # TODO default to true
    outdir = name,
    residual_alphas = residual_alphas,
)
    splits = ["training", "validation", "test"]
    # don't save training set by default because it's huge
    splits_to_save = ["validation", "test"]
    # don't cheat by peeking at the test set
    splits_to_log = ["validation"]
    if save_training
        push!(splits_to_save, "training")
        push!(splits_to_log, "training")
    end

    predictions = Dict()
    for split in splits
        # Skip splits that are neither saved nor logged. Otherwise the huge
        # training split is residualized and predicted for nothing.
        if !(split in splits_to_save || split in splits_to_log)
            continue
        end
        df = get_residuals(split, residual_alphas)
        truth = df.rating
        pred = model(df.user, df.item)
        if split in splits_to_log
            # rescale once so the metrics ignore a global scale mismatch
            β = pred \ truth
            scaled = β * pred
            @info "$(split) set: RMSE $(rmse(truth, scaled)) MAE $(mae(truth, scaled)) R2 $(r2(truth, scaled))"
        end
        if split in splits_to_save
            predictions[split] = RatingsDataset(df.user, df.item, pred)
        end
    end

    outdir = mkpath("../../data/alphas/$outdir")
    save("$outdir/predictions.jld2", predictions)
end;

In [None]:
# Save `params` to `../../data/alphas/<outdir>/params.jld2`, creating the
# directory if needed. `outdir` defaults to the global `name`.
function write_params(params; outdir = name)
    alpha_dir = mkpath("../../data/alphas/$outdir")
    return save("$alpha_dir/params.jld2", params)
end;

In [None]:
# Load every entry of `../../data/alphas/<alpha>/params.jld2` (the file `write_params` writes).
read_params(alpha) = load("../../data/alphas/$alpha/params.jld2");

In [None]:
# Stopping rule that fires when successive parameter snapshots stop changing.
# `params` is a vector of numbers or arrays. Convergence means the largest
# absolute elementwise change across all of them is below `tolerance`.
# Once more than `max_iters` calls to `stop!` have been made, it always returns true.
Base.@kwdef mutable struct convergence_stopper
    tolerance::AbstractFloat
    max_iters = Inf
    params::AbstractVector
    prev_params::AbstractVector
    iters = 0
end

"""
    convergence_stopper(tolerance; max_iters = Inf)

Create a stopper with the given `tolerance` and optional iteration cap `max_iters`.
"""
function convergence_stopper(tolerance; max_iters = Inf)
    convergence_stopper(
        tolerance = tolerance,
        max_iters = max_iters,
        params = [],
        prev_params = [],
    )
end

"""
    stop!(x::convergence_stopper, params) -> Bool

Record a snapshot of `params` and report whether to stop. It returns `true`
once more than `x.max_iters` calls have been made, or when the maximum absolute
change from the previous snapshot is below `x.tolerance`. The first call only
records the snapshot and returns `false`.
"""
function stop!(x::convergence_stopper, params)
    x.iters += 1
    if x.iters > x.max_iters
        return true
    end

    if x.iters == 1
        # deep copy so later in-place updates by the caller don't alias the snapshot
        x.params = deepcopy(params)
        return false
    end

    # `x.params` is already a private deep copy that is replaced below (never
    # mutated), so it can become `prev_params` without a second deep copy.
    x.prev_params = x.params
    x.params = deepcopy(params)
    maxabs(a) = maximum(abs, a)
    return maximum(maxabs.(x.params - x.prev_params)) < x.tolerance
end;

In [None]:
# Early-stopping rule driven by a scalar loss (lower counts as better).
# `stop!` returns true once more than `max_iters` calls have been made, or once
# `patience` consecutive calls have failed to beat the stored `loss` by the
# relative margin `min_rel_improvement`. `loss` holds the last value that
# counted as an improvement (initially `Inf`).
# NOTE(review): with the default `min_rel_improvement = -Inf`, the threshold
# `loss * (1 - (-Inf))` is `Inf` for any positive stored loss. Every finite loss
# then counts as an improvement and `patience` never fires, so pass an explicit
# margin (e.g. 0) when using `patience`.
Base.@kwdef mutable struct early_stopper
    max_iters::Int
    patience = Inf
    min_rel_improvement = -Inf
    iters = 0
    iters_without_improvement = 0
    loss = Inf
end

function stop!(x::early_stopper, loss)
    x.iters += 1
    x.iters > x.max_iters && return true

    threshold = x.loss * (1 - x.min_rel_improvement)
    if loss < threshold
        x.loss = loss
        x.iters_without_improvement = 0
    else
        x.iters_without_improvement = x.iters_without_improvement + 1
    end
    return x.iters_without_improvement >= x.patience
end;