# Common utilities for all alphas

In [19]:
# using CSV
using DataFrames
using FileIO
using Flux
using JLD2
using LinearAlgebra
using LoggingExtras
using Memoize
using Optim
using Parameters
using ProgressMeter
using SparseArrays
using Statistics

import Dates
import LineSearches
import JupyterFormatter

# General utils
* TODO: split into its own file

In [2]:
# TODO upstream this into the ProgressMeter
# Progress bars are enabled unless the environment variable
# JULIA_SHOW_PROGRESS_BARS holds a string that `parse(Bool, ...)` reads as
# false (e.g. "false" or "0"); it defaults to "true".
const SHOW_PROGRESS_BARS = parse(Bool, get(ENV, "JULIA_SHOW_PROGRESS_BARS", "true"))

"""
    @tprogress for i = range ... end
    @tprogress Threads.@threads for i = range ... end

Show a ProgressMeter bar for a `for` loop, optionally parallelised with
`Threads.@threads`.

The bar length is `round(length(range) / Threads.nthreads())`, even for a
plain unthreaded loop. Only iterations that run on thread 1 advance the bar,
so progress is estimated from thread 1's share of the work. NOTE(review): this
assumes iterations are spread roughly evenly over threads, which the scheduler
does not guarantee. `finish!` is called after the loop, so the bar always ends
in the finished state.

The caller's loop body is mutated in place (a `next!` call is appended). Each
expansion stores its `Progress` object in a fresh gensym-named `global`.
"""
macro tprogress(expr)
    # let the @progress macro work with Threads.@threads
    loop = expr
    # Unwrap `Threads.@threads for ...` to reach the `for` expression itself.
    if loop.head == :macrocall && loop.args[1] == :(Threads.var"@threads")
        loop = loop.args[end]
    end
    
    p = gensym()    
    # Iteration range: the right-hand side of the `i = range` loop header.
    r = loop.args[1].args[end]
    ex = quote
        n = Int(round(length($(esc(r))) / Threads.nthreads()))
        global $p = Progress(n; showspeed=true, enabled=SHOW_PROGRESS_BARS)
        $(esc(expr))
        finish!($p)
    end
    
    # Appended to the end of the loop body; only thread 1 ticks the bar.
    update = quote
        if Threads.threadid() == 1
            next!($p)
        end
    end
    push!(loop.args[end].args, update)    
    
    ex    
end;

In [3]:
# Turn on JupyterFormatter's automatic formatting for this notebook session.
JupyterFormatter.enable_autoformat();

In [4]:
# Parallelism in this file comes from Julia threads (`Threads.@threads`), so
# BLAS is pinned to a single thread rather than competing with them for cores.
LinearAlgebra.BLAS.set_num_threads(1);

In [5]:
# Stop training once the parameters have converged. `stop!` returns `true`
# once no single parameter entry moved by `tolerance` or more since the
# previous call, or once it has been called more than `max_iters` times.
@with_kw mutable struct convergence_stopper
    tolerance::AbstractFloat
    max_iters = Inf
    params::AbstractVector       # snapshot recorded by the latest `stop!` call
    prev_params::AbstractVector  # snapshot recorded by the call before that
    iters = 0                    # number of `stop!` calls so far
end

"""
    convergence_stopper(tolerance; max_iters = Inf)

Create a `convergence_stopper` that has not recorded any parameters yet.
"""
function convergence_stopper(tolerance; max_iters = Inf)
    return convergence_stopper(
        tolerance = tolerance,
        max_iters = max_iters,
        params = [],
        prev_params = [],
    )
end

"""
    stop!(x::convergence_stopper, params)

Record a deep copy of `params` and return whether training should stop.

`params` is compared entry by entry with the previous snapshot. Presumably it
is a vector of parameter arrays, but plain numbers also work. The first call
only records the snapshot and returns `false`, unless `max_iters` has already
been exceeded.
"""
function stop!(x::convergence_stopper, params)
    x.iters += 1
    if x.iters > x.max_iters
        return true
    end

    if x.iters == 1
        x.params = deepcopy(params)
        return false
    end

    # `x.params` is already a private deep copy, so it can become the previous
    # snapshot as-is; only the incoming `params` needs copying.
    x.prev_params = x.params
    x.params = deepcopy(params)
    # Largest absolute change of any single entry across all parameter arrays.
    largest_change = maximum(d -> maximum(abs, d), x.params - x.prev_params)
    return largest_change < x.tolerance
end;

In [6]:
# Early stopping: `stop!` returns `true` once the loss has failed to beat the
# best loss seen so far by a relative margin of `min_rel_improvement` for more
# than `patience` consecutive calls, or once it has been called more than
# `max_iters` times.
@with_kw mutable struct early_stopper
    max_iters = Inf
    patience = Inf
    min_rel_improvement = 0
    iters = 0                      # number of `stop!` calls so far
    iters_without_improvement = 0  # consecutive calls without improvement
    loss = NaN                     # best loss recorded so far
end

"""
    stop!(x::early_stopper, loss)

Record `loss` and return whether training should stop. The first call only
records the loss and returns `false`, unless `max_iters` has already been
exceeded.
"""
function stop!(x::early_stopper, loss)
    x.iters += 1
    x.iters > x.max_iters && return true

    if x.iters == 1
        x.loss = loss
        return false
    end

    threshold = x.loss * (1 - x.min_rel_improvement)
    if !(loss < threshold)
        x.iters_without_improvement += 1
    else
        x.loss = loss
        x.iters_without_improvement = 0
    end
    return x.iters_without_improvement > x.patience
end;

In [7]:
# Logger that forwards every record to a wrapped logger and then flushes that
# logger's output stream, so each log line is written out immediately.
# Any `AbstractLogger` can be wrapped. Only `ConsoleLogger` and `SimpleLogger`
# expose a stream to flush; other loggers handle their own buffering.
struct FlushLogger{L<:AbstractLogger} <: AbstractLogger
    logger::L
end

# Flush the output stream of loggers whose stream is known; no-op otherwise.
flush_logger_stream(logger::Union{ConsoleLogger,Logging.SimpleLogger}) = flush(logger.stream)
flush_logger_stream(::AbstractLogger) = nothing

function Logging.handle_message(logger::FlushLogger, args...; kwargs...)
    Logging.handle_message(logger.logger, args...; kwargs...)
    flush_logger_stream(logger.logger)
    return nothing
end

# Filtering decisions are delegated unchanged to the wrapped logger.
Logging.shouldlog(logger::FlushLogger, arg...) = Logging.shouldlog(logger.logger, arg...)
Logging.min_enabled_level(logger::FlushLogger) = Logging.min_enabled_level(logger.logger)
Logging.catch_exceptions(logger::FlushLogger) = Logging.catch_exceptions(logger.logger)

# `meta_formatter` for `ConsoleLogger`. It returns `(color, prefix, suffix)`:
# the color is 4/6/3/1 for below-Info/Info/Warn/Error-and-above levels, the
# prefix is the level name (with "Warning" for Warn) followed by ':', and the
# suffix is always empty, so no source location is printed.
function logging_meta_formatter(level, _module, group, id, file, line)
    color = if level < Logging.Info
        4
    elseif level < Logging.Warn
        6
    elseif level < Logging.Error
        3
    else
        1
    end
    label = level == Logging.Warn ? "Warning" : string(level)
    return color, label * ':', ""
end;

In [8]:
# Log to stderr and to `<outdir>/log` at the same time.
# `outdir` is created if missing; an existing log file is truncated (the file
# is opened with `write = true` and stays open for the session). Every message
# is prefixed with a "yyyymmdd HH:MM:SS" timestamp, and both sinks flush after
# each record. Returns whatever `global_logger` returns when setting the logger.
function redirect_logging(outdir)
    stamp_format = "yyyymmdd HH:MM:SS"

    function with_timestamp(inner)
        return TransformerLogger(inner) do record
            stamp = Dates.format(Dates.now(), stamp_format)
            return merge(record, (; message = "$(stamp) $(record.message)"))
        end
    end

    function flushing_console(io)
        console = ConsoleLogger(io, Logging.Info; meta_formatter = logging_meta_formatter)
        return with_timestamp(FlushLogger(console))
    end

    logdir = mkpath(outdir)
    to_stderr = flushing_console(stderr)
    to_file = flushing_console(open("$(logdir)/log", write = true))
    return global_logger(TeeLogger(to_stderr, to_file))
end;

In [9]:
# A custom split layer for Flux: every path receives the same input, and the
# outputs are returned in the container type of `paths` (a tuple when built
# with `Split(a, b, ...)`).
# NOTE(review): `Split(layer)` with a single non-tuple argument dispatches to
# the default constructor, so `paths` is the layer itself rather than a
# 1-tuple; pass `Split((layer,))` in that case.
struct Split{T}
    paths::T
end

Split(paths...) = Split(paths)

Flux.@functor Split

function (m::Split)(x::AbstractArray)
    return map(path -> path(x), m.paths)
end

# Alpha-specific utils

In [10]:
# Send this notebook's logs to stderr and to ../../data/alphas/<name>/log.
# NOTE(review): `name` is not defined in this file; presumably each alpha
# notebook defines it before these utilities run — confirm.
redirect_logging("../../data/alphas/$name");

In [11]:
# Ratings stored as three parallel vectors; entry `i` of `user`, `item` and
# `rating` together describe one rating (presumably "user rated item").
@with_kw struct RatingsDataset
    user::Vector{Int32}
    item::Vector{Int32}
    rating::Vector{Float32}
end

# `x'` swaps the roles of users and items; the vectors are reused, not copied.
Base.adjoint(x::RatingsDataset) = RatingsDataset(x.item, x.user, x.rating)

"""
    get_split(split; implicit = false, transpose = false)

Load the dataset stored under key `split` in `../../data/splits/splits.jld2`.

- `implicit = true`: every rating is overwritten with 1 (in the loaded copy).
- `transpose = true`: return the adjoint of the dataset (users and items
  swapped for a `RatingsDataset`).

Throws `ArgumentError` unless `split` is one of "training", "validation",
"test", "implicit" or "implicit_training".
"""
function get_split(split; implicit = false, transpose = false)
    valid_splits = ("training", "validation", "test", "implicit", "implicit_training")
    # Explicit check instead of `@assert`, which may be disabled.
    split in valid_splits || throw(
        ArgumentError("unknown split $(repr(split)); expected one of $(join(valid_splits, ", "))"),
    )
    df = load("../../data/splits/splits.jld2", split)
    if implicit
        df.rating .= 1
    end
    return transpose ? df' : df
end

"""
    get_alpha(alpha, split)

Load the predictions stored under key `split` in
`../../data/alphas/<alpha>/predictions.jld2`.

Throws `ArgumentError` unless `split` is "training", "validation" or "test".
"""
function get_alpha(alpha, split)
    valid_splits = ("training", "validation", "test")
    # Explicit check instead of `@assert`, which may be disabled.
    split in valid_splits || throw(
        ArgumentError("unknown split $(repr(split)); expected one of $(join(valid_splits, ", "))"),
    )
    return load("../../data/alphas/$(alpha)/predictions.jld2", split)
end;

In [12]:
# Same layout as `RatingsDataset`, but with Int64 ids, because some sparse
# matrix operations require Int64 indices.
@with_kw struct RatingsDataset64
    user::Vector{Int64}
    item::Vector{Int64}
    rating::Vector{Float32}
end;

# Widen a `RatingsDataset`'s ids to Int64; all three vectors are freshly allocated.
RatingsDataset64(x::RatingsDataset) =
    RatingsDataset64(Int64.(x.user), Int64.(x.item), Float32.(x.rating));

In [13]:
# Concatenate two datasets: the ratings of `y` follow those of `x`.
Base.cat(x::RatingsDataset, y::RatingsDataset) =
    RatingsDataset(vcat(x.user, y.user), vcat(x.item, y.item), vcat(x.rating, y.rating));

In [14]:
# Largest item id in the training split; memoized, so it is computed only once.
@memoize function num_items()
    training = get_split("training")
    return maximum(training.item)
end

# Largest user id in the training split; memoized, so it is computed only once.
@memoize function num_users()
    training = get_split("training")
    return maximum(training.user)
end;

## Weight decays

In [15]:
"""
    safe_exp(x, a)

Return `x^a`, except that `x == 0` maps to zero. This means a zero count gets
zero weight, and avoids `0^a == Inf` for negative `a`.

The zero has the promoted type of `x` and `a` (e.g. `Float32` for a
`Float32` count and an `Int` exponent). This keeps the result type-stable, so
broadcasting over a vector of counts yields a concretely typed vector.
"""
function safe_exp(x, a)
    return iszero(x) ? zero(promote_type(typeof(x), eltype(a))) : x^a
end;

"""
    weighting_scheme(scheme::String)

Return the exponent applied to rating counts for a weighting scheme:
"linear" → 1, "constant" → 0, "inverse" → -1.

Throws `ArgumentError` for any other name. Previously this was an `@assert`
followed by a silent `return 0`.
"""
function weighting_scheme(scheme::String)
    if scheme == "linear"
        return 1
    elseif scheme == "constant"
        return 0
    elseif scheme == "inverse"
        return -1
    end
    throw(
        ArgumentError(
            "unknown weighting scheme $(repr(scheme)); expected \"linear\", \"constant\" or \"inverse\"",
        ),
    )
end;

"""
    get_user_counts(split::RatingsDataset)

Return a vector whose `u`-th entry is the number of ratings with user id `u`
in `split`. Its length is `maximum(split.user)` and its element type is that
of `split.rating`. User ids are used directly as indices, so they are assumed
to be positive. An empty dataset yields an empty vector.
"""
function get_user_counts(split::RatingsDataset)
    T = eltype(split.rating)
    users = split.user
    isempty(users) && return zeros(T, 0)
    nusers = maximum(users)

    # Each task counts one contiguous chunk into its own private buffer. The
    # previous `counts[:, threadid()]` pattern is unsafe once tasks can migrate
    # between threads, and `threadid()` may exceed `nthreads()`.
    chunk_size = cld(length(users), Threads.nthreads())
    tasks = map(Iterators.partition(eachindex(users), chunk_size)) do chunk
        Threads.@spawn begin
            local_counts = zeros(T, nusers)
            for i in chunk
                local_counts[users[i]] += 1
            end
            local_counts
        end
    end

    counts = zeros(T, nusers)
    for task in tasks
        counts .+= fetch(task)
    end
    return counts
end

# Rating counts for the split named `split` (see `get_split`), memoized per
# argument combination.
# - `by_item = true` transposes the split first, so counts are per item
#   instead of per user.
# - `per_rating = false` returns one count per user (or item), indexed by id.
# - `per_rating = true` returns one entry per rating, holding the count of
#   that rating's user (or item).
@memoize function get_counts(split; per_rating = true, by_item = false)
    data = get_split(split; transpose = by_item)
    user_counts = get_user_counts(data)
    per_rating || return user_counts
    # Gather each rating's user total.
    return user_counts[data.user]
end

# Per-rating weights for the split named `split`. Each rating's user count is
# raised to the exponent of `scheme` (see `weighting_scheme`) via `safe_exp`.
function get_weights(split, scheme::String)
    counts = get_counts(split)
    exponent = weighting_scheme(scheme)
    return safe_exp.(counts, exponent)
end;

## Loss functions and regressions

In [23]:
# Sum `f(i)` over `indices` in parallel. The indices are cut into one
# contiguous chunk per thread, and each chunk is reduced by its own task, so
# no accumulator is shared between tasks (indexing buffers by `threadid()` is
# unsafe once tasks can migrate). Returns `init` when `indices` is empty,
# otherwise `init` plus the chunk sums.
function threaded_sum(f, indices, init)
    isempty(indices) && return init
    chunk_size = cld(length(indices), Threads.nthreads())
    tasks = [
        Threads.@spawn(mapreduce(f, +, chunk)) for
        chunk in Iterators.partition(indices, chunk_size)
    ]
    total = init
    for task in tasks
        total += fetch(task)
    end
    return total
end

# Throw `DimensionMismatch` unless predictions, targets and weights have equal lengths.
function check_loss_lengths(x, y, w)
    length(x) == length(y) == length(w) || throw(
        DimensionMismatch(
            "lengths of x, y and w differ: $(length(x)), $(length(y)), $(length(w))",
        ),
    )
    return nothing
end

"""
    weighted_crossentropy_loss(x, y, w)

Return `Σ -y[i] * log(x[i]) * w[i] / Σ w[i]`. Presumably `x` holds predicted
probabilities. Returns NaN for empty input (0/0) and throws `DimensionMismatch`
if the lengths differ.
"""
function weighted_crossentropy_loss(x, y, w)
    check_loss_lengths(x, y, w)
    numerator = threaded_sum(i -> -y[i] * log(x[i]) * w[i], eachindex(x), zero(eltype(x)))
    denominator = threaded_sum(i -> w[i], eachindex(x), zero(eltype(w)))
    return numerator / denominator
end

"""
    weighted_mean_squared_error(x, y, w)

Return `Σ (x[i] - y[i])^2 * w[i] / Σ w[i]`. Returns NaN for empty input (0/0)
and throws `DimensionMismatch` if the lengths differ.
"""
function weighted_mean_squared_error(x, y, w)
    check_loss_lengths(x, y, w)
    numerator = threaded_sum(i -> (x[i] - y[i])^2 * w[i], eachindex(x), zero(eltype(x)))
    denominator = threaded_sum(i -> w[i], eachindex(x), zero(eltype(w)))
    return numerator / denominator
end

# Cross-entropy for implicit data, weighted MSE for explicit data.
loss(x, y, w, implicit) =
    implicit ? weighted_crossentropy_loss(x, y, w) : weighted_mean_squared_error(x, y, w)

"""
    regress(X, y, w, implicit)

Find column weights `β` for `X` that minimise `loss(X * β, y, w, implicit)`.
Returns `(X * β, β)`: the fitted predictions first, then the coefficients.
Both branches now use this order; the implicit branch previously returned
`(β, X * β)`, which broke callers that destructure predictions first.

- Explicit data: closed-form weighted least squares, solving
  `(X .* √w) \\ (y .* √w)`.
- Implicit data: `β = softmax(θ)`, which is non-negative and sums to one.
  `θ` is unconstrained and optimised from zeros with BFGS using
  forward-mode autodiff.
"""
function regress(X, y, w, implicit)
    if implicit
        objective = θ -> loss(X * softmax(θ), y, w, implicit)
        θ0 = zeros(Float32, size(X, 2))
        result = optimize(objective, θ0, BFGS(); autodiff = :forward)
        β = softmax(Optim.minimizer(result))
    else
        sqrt_w = sqrt.(w)
        β = (X .* sqrt_w) \ (y .* sqrt_w)
    end
    return X * β, β
end

# Fit `x` jointly with the stored validation predictions of `alphas`.
# The design matrix has one column per alpha (loaded in parallel) plus `x` as
# its last column. Targets and "inverse" weights come from the validation
# split. Returns whatever `regress(X, y, w, implicit)` returns.
# NOTE(review): `x` must have as many entries as the validation split, since
# the alpha columns are validation predictions.
function regress(x, alphas, implicit)
    split = "validation"
    ncols = length(alphas) + 1
    design = zeros(eltype(x), length(x), ncols)
    @tprogress Threads.@threads for j = 1:length(alphas)
        design[:, j] = get_alpha(alphas[j], split).rating
    end
    design[:, ncols] .= x
    targets = get_split(split; implicit = implicit).rating
    weights = get_weights(split, "inverse")
    return regress(design, targets, weights, implicit)
end

# Loss on `split` of the best linear combination of `x` and `alphas`.
# It takes the first value returned by `regress(x, alphas, implicit)` and
# scores it against the `split` targets with "inverse" weights.
# NOTE(review): `regress` fits on the validation split, so this only lines up
# when `split` has as many ratings as the validation split — confirm for
# "training".
function residualized_loss(x, alphas, implicit, split)
    fitted = first(regress(x, alphas, implicit))
    truth = get_split(split; implicit = implicit).rating
    weights = get_weights(split, "inverse")
    return loss(fitted, truth, weights, implicit)
end

# Linearly combine the given alphas.
# Starting from a constant baseline (1 / num_items() for implicit data, 0
# otherwise), fit the combination via `regress` and return the `split`
# dataset with its ratings replaced by the first value `regress` returns.
# NOTE(review): the fit is over validation-length vectors, so `split` must
# have as many ratings as the validation split.
function get_alpha(alphas::Vector{String}, split::String, implicit)
    df = get_split(split; implicit = implicit)
    baseline = implicit ? 1.0f0 / num_items() : 0.0f0
    nvalidation = length(get_split("validation").rating)
    combined, _ = regress(fill(baseline, nvalidation), alphas, implicit)
    df.rating .= combined
    return df
end;

## Saving data

In [17]:
# Run `model(users, items)` on the training, validation and test splits, log
# the residualized loss for training and validation, and save all three
# prediction sets to ../../data/alphas/<outdir>/predictions.jld2.
# NOTE(review): `residualize` and the default `outdir = name` are defined
# outside this file — confirm.
function write_predictions(model, alphas, implicit; outdir = name)
    logged_splits = ["training", "validation"]
    predictions = Dict()
    for split in ["training", "validation", "test"]
        df = residualize(split, alphas, implicit)
        preds = model(df.user, df.item)
        predictions[split] = RatingsDataset(df.user, df.item, preds)
        split in logged_splits || continue
        @info "$(split) loss: $(residualized_loss(preds, alphas, implicit, split))"
    end

    dir = mkpath("../../data/alphas/$outdir")
    return save("$dir/predictions.jld2", predictions)
end;

In [18]:
# Save `params` to ../../data/alphas/<outdir>/params.jld2, creating the
# directory if needed.
# NOTE(review): the default `outdir = name` relies on a global defined outside
# this file.
function write_params(params; outdir = name)
    dir = mkpath("../../data/alphas/$outdir")
    return save("$dir/params.jld2", params)
end

# Load the contents of ../../data/alphas/<alpha>/params.jld2.
read_params(alpha) = load("../../data/alphas/$alpha/params.jld2");