In [None]:
import NBInclude: @nbinclude
@nbinclude("Alpha.ipynb");

In [None]:
# import Dates
# import JupyterFormatter
import LinearAlgebra
# import Logging
# import LoggingExtras
import NNlib: softmax
import Optim
# import ProgressMeter
# import ProgressMeter: @showprogress
import StatsBase

## Multi-threading

In [None]:
# # Prefer Julia multithreading to BLAS multithreading
# LinearAlgebra.BLAS.set_num_threads(1);

In [None]:
# # let the @progress macro work with Threads.@threads
# macro tprogress(expr)
#     loop = expr
#     if loop.head == :macrocall && loop.args[1] == :(Threads.var"@threads")
#         loop = loop.args[end]
#     end
    
#     p = gensym()    
#     r = loop.args[1].args[end]
#     ex = quote
#         n = Int(round(length($(esc(r))) / Threads.nthreads()))
#         global $p = ProgressMeter.Progress(n; showspeed=true)
#         $(esc(expr))
#         ProgressMeter.finish!($p)
#     end
    
#     update = quote
#         if Threads.threadid() == 1
#             ProgressMeter.next!($p)
#         end
#     end
#     push!(loop.args[end].args, update)    
    
#     ex    
# end;

In [None]:
# # like Threads.@threads except we can specify the number of threads
# function tforeach(f::Function, args, threads::Int)
#     @sync for (t, chunk) in Iterators.enumerate(
#         Iterators.partition(args, div(length(args), threads, RoundUp)),
#     )
#         Threads.@spawn begin
#             @showprogress enabled = (t == 1) for i in chunk
#                 f(i)
#             end
#         end
#     end
# end;

## Formatting

In [None]:
# JupyterFormatter.enable_autoformat();

## Logging

In [None]:
# Logging.disable_logging(Logging.Debug);

In [None]:
# # Logger that flushes after every log statement
# struct FlushLogger <: LoggingExtras.AbstractLogger
#     logger::LoggingExtras.ConsoleLogger
# end

# function FlushLogger(logger::LoggingExtras.AbstractLogger)
#     FlushLogger(logger)
# end

# function Logging.handle_message(logger::FlushLogger, args...; kwargs...)
#     Logging.handle_message(logger.logger, args...; kwargs...)
#     flush(logger.logger.stream)
# end

# Logging.shouldlog(logger::FlushLogger, arg...) = Logging.shouldlog(logger.logger, arg...)
# Logging.min_enabled_level(logger::FlushLogger) = Logging.min_enabled_level(logger.logger)
# Logging.catch_exceptions(logger::FlushLogger) = Logging.catch_exceptions(logger.logger)

# function logging_meta_formatter(level, _module, group, id, file, line)
#     prefix_color = (
#         level < Logging.Info ? 4 : level < Logging.Warn ? 6 : level < Logging.Error ? 3 : 1
#     )
#     prefix = (level == Logging.Warn ? "Warning" : string(level)) * ':'
#     prefix_color, prefix, ""
# end;

In [None]:
# # Log to file and stdout at the same time
# function redirect_logging(outdir; overwrite = true)
#     date_format = "yyyymmdd HH:MM:SS"
#     timestamp_logger(logger) =
#         LoggingExtras.TransformerLogger(logger) do log
#             merge(
#                 log,
#                 (; message = "$(Dates.format(Dates.now(), date_format)) $(log.message)"),
#             )
#         end

#     outdir = mkpath(outdir)
#     suffix = ""
#     if !overwrite
#         tries = 0
#         while ispath("$(outdir)/log$(suffix)")
#             tries += 1
#             suffix = ".$tries"
#         end
#     end
#     Logging.global_logger(
#         LoggingExtras.TeeLogger(
#             FlushLogger(
#                 LoggingExtras.ConsoleLogger(
#                     stderr,
#                     Logging.Info;
#                     meta_formatter = logging_meta_formatter,
#                 ),
#             ) |> timestamp_logger,
#             FlushLogger(
#                 LoggingExtras.ConsoleLogger(
#                     open("$(outdir)/log$(suffix)", write = true),
#                     Logging.Info;
#                     meta_formatter = logging_meta_formatter,
#                 ),
#             ) |> timestamp_logger,
#         ),
#     )
# end;

In [None]:
# function set_logging_outdir(name)
#     redirect_logging(get_data_path("alphas/$name"); overwrite = false)
# end;

## Settings

In [None]:
# @memoize function get_settings()
#     settings = Dict()
#     for f in ["default_settings", "private_settings"]
#         d = YAML.load_file(get_data_path("../environment/$f.yml"))
#         for (k, v) in d
#             settings[k] = v
#         end
#     end
#     settings
# end;

In [None]:
# function training_test_split(df::RatingsDataset)
#     c = get_settings()
#     if c["mode"] == "research"
#         ts_cutoff = days_in_timestamp_units(c["cutoff_days"] * 2)
#     else
#         ts_cutoff = days_in_timestamp_units(c["cutoff_days"])
#     end
#     test_mask =
#         (df.update_order .<= c["cutoff_interactions"]) .&& (df.updated_at .>= 1 - ts_cutoff)
#     filter(df, .!(test_mask)), filter(df, test_mask)
# end;

## Loss functions

In [None]:
"""
    loss(x, y, w, metric)

Return the `w`-weighted mean of the element-wise loss between predictions `x`
and targets `y`, i.e. `sum(ℓ .* w) / sum(w)`.

The element-wise loss `ℓ` depends on `metric`:
- `"rating"`: squared error `(x - y)^2`.
- `"watch"`, `"plantowatch"`: `-y * log(x)`.
- `"drop"`: binary cross-entropy `-(y * log(x) + (1 - y) * log(1 - x))`.

`x`, `y` and `w` are combined with broadcasting, so their shapes must be
broadcast-compatible.

# Throws
- `ArgumentError` if `metric` is not one of the names above.
"""
function loss(x, y, w, metric)
    # A tiny offset keeps log(0) finite: otherwise a zero target multiplied by
    # log(0) evaluates 0 * -Inf = NaN and poisons the whole sum.
    safelog(v) = log(v + Float32(eps(Float64)))
    if metric == "rating"
        ℓ = @. (x - y)^2
    elseif metric in ("watch", "plantowatch")
        ℓ = @. -y * safelog(x)
    elseif metric == "drop"
        ℓ = @. -(y * safelog(x) + (1 - y) * safelog(1 - x))
    else
        # @assert can be compiled out, so reject bad input explicitly.
        throw(ArgumentError("unknown metric: $metric"))
    end
    return sum(ℓ .* w) / sum(w)
end;

In [None]:
"""
    regress(X, y, w, metric)

Find column weights `β` that minimize `loss(X * β, y, w, metric)`.

- `"rating"`: weighted least squares. The rows of `X` and `y` are scaled by
  `sqrt.(w)`, and the normal equations are solved in closed form. A tiny ridge
  term `eps(Float32) * I` is added so the system is never exactly singular.
- `"watch"`, `"plantowatch"`, `"drop"`: `β` is restricted to the probability
  simplex by writing `β = softmax(θ)`. `θ` starts at zero, i.e. uniform
  weights. It is optimized with L-BFGS using forward-mode autodiff, with
  gradient tolerance `1e-6` and at most 100 iterations. The result is
  `softmax(θ)`, so the returned weights are non-negative and sum to one.

# Throws
- `ArgumentError` if `metric` is not one of the names above.
"""
function regress(X, y, w, metric)
    if metric == "rating"
        sw = sqrt.(w)
        Xw = X .* sw
        yw = y .* sw
        # prevent singular matrix
        λ = eps(Float32) * LinearAlgebra.I(size(Xw, 2))
        return (Xw'Xw + λ) \ Xw'yw
    elseif metric in ("watch", "plantowatch", "drop")
        result = Optim.optimize(
            θ -> loss(X * softmax(θ), y, w, metric),
            zeros(Float32, size(X, 2)),
            Optim.LBFGS(),
            Optim.Options(g_tol = 1e-6, iterations = 100);
            autodiff = :forward,
        )
        return softmax(Optim.minimizer(result))
    else
        # @assert can be compiled out, so reject bad input explicitly.
        throw(ArgumentError("unknown metric: $metric"))
    end
end;

In [None]:
"""
    get_features(dataset, medium, metric, alphas)

Load the `"test_output"` split of `dataset` for `medium` and return the tuple
`(X, y, w)`.

- `X`: one column per name in `alphas`, each read with `read_alpha` for the
  rows' user/item pairs (presumably one value per row — TODO confirm). These
  are followed by constant columns that depend on `metric`:
  - `"watch"`/`"plantowatch"`: one column of `1 / num_items(medium)`;
  - `"drop"`: an all-ones column and then an all-zeros column.
- `y`: the `metric` column of the frame returned by `as_metric`.
- `w`: `Float32` row weights `1 / (number of rows for that user)`, so each
  user's rows carry a total weight of one.
"""
function get_features(
    dataset::String,
    medium::String,
    metric::String,
    alphas::Vector{String},
)
    columns = [:userid, :itemid, :rating, :status]
    frame = as_metric(get_split(dataset, "test_output", medium, columns), metric)
    target = frame.metric
    nrows = length(target)

    # Down-weight heavy users so every user contributes equally overall.
    rows_per_user = StatsBase.countmap(frame.userid)
    weights = Float32[1 / rows_per_user[u] for u in frame.userid]

    cols = [read_alpha(dataset, frame.userid, frame.itemid, a) for a in alphas]
    # Constant columns, presumably baselines the fitted mixture may use.
    if metric == "watch" || metric == "plantowatch"
        push!(cols, fill(1.0f0 / num_items(medium), nrows))
    elseif metric == "drop"
        push!(cols, fill(1.0f0, nrows), fill(0.0f0, nrows))
    end
    return hcat(cols...), target, weights
end;

In [None]:
"""
    print_losses(medium, metric, alphas; train_dataset = "streaming")

Fit ensemble weights with `regress` on the features of `train_dataset`. Then
evaluate those weights with `loss` on every dataset in `ALL_DATASETS`, and log
one `@info` line per dataset.

If `train_dataset` also appears in `ALL_DATASETS`, its reported loss is
in-sample.
"""
function print_losses(
    medium::String,
    metric::String,
    alphas::Vector{String};
    train_dataset::String = "streaming",
)
    β = regress(get_features(train_dataset, medium, metric, alphas)..., metric)
    for dataset in ALL_DATASETS
        X, y, w = get_features(dataset, medium, metric, alphas)
        val = loss(X * β, y, w, metric)
        @info "$dataset $medium $metric loss = $val"
    end
    return nothing
end;