In [None]:
import NBInclude: @nbinclude
@nbinclude("Alpha.ipynb");

In [None]:
import LinearAlgebra
import NNlib: softmax
import Optim
import ProgressMeter: @showprogress, next!
import Random
import SHA
import StatsBase

In [None]:
# # Prefer Julia multithreading to BLAS multithreading
# LinearAlgebra.BLAS.set_num_threads(1);

## Splits

In [None]:
# Seed the global RNG from the first `seed` entry of "rng.csv" combined with `salt`.
#
# The stored seed and `salt` are concatenated and hashed with SHA-256; the first
# UInt64 word of the digest is passed to `Random.seed!`, whose result is returned.
function seed_rng!(salt::String)
    seeds = read_csv(get_data_path("rng.csv")).seed
    digest = SHA.sha256(first(seeds) * salt)
    # Only the leading 8 bytes of the 32-byte digest are used as the seed.
    word = first(reinterpret(UInt64, digest))
    return Random.seed!(word)
end;

In [None]:
# Split `df` into an (input, output) pair of datasets around the timestamp `ts_cutoff`.
#
# Rows are visited user by user, ordered by (updated_at, update_order); the whole
# visiting order is reversed when `output_newest` is true. A row is selected for the
# output when its `updated_at` is strictly greater than `ts_cutoff` and the current
# user has so far fewer than `max_output_items` output rows for that row's medium.
# The input consists of
#   * every row that was not selected for the output, when `output_newest` is true;
#   * only the rows with `updated_at <= ts_cutoff`, otherwise.
# Returns `(subset(df, input_mask), subset(df, output_mask))`.
function input_output_split(
    df::RatingsDataset,
    ts_cutoff::Float64,
    max_output_items::Int,
    output_newest::Bool,
)
    # Every column consulted below must be as long as the longest field of `df`.
    N = maximum([length(getfield(df, x)) for x in fieldnames(typeof(df))])
    for x in [:userid, :medium, :updated_at, :update_order]
        @assert length(getfield(df, x)) == N
    end

    nrows = length(df.userid)
    input_mask = falses(nrows)
    output_mask = falses(nrows)
    # Output rows taken so far for the current user, one counter per medium.
    taken = [0 for _ in ALL_MEDIUMS]

    # The sort key starts with the user id, so each user's rows are contiguous.
    visit = sortperm(collect(zip(df.userid, df.updated_at, df.update_order)))
    visit = output_newest ? reverse(visit) : visit

    current_user = nothing
    @showprogress for i in visit
        if current_user != df.userid[i]
            # First row of a different user: restart the per-medium counters.
            current_user = df.userid[i]
            fill!(taken, 0)
        end
        if df.updated_at[i] > ts_cutoff
            # `medium` values are shifted by one to index into `taken`.
            slot = df.medium[i] + 1
            if taken[slot] < max_output_items
                taken[slot] += 1
                output_mask[i] = true
            end
        end
        input_mask[i] = output_newest ? !output_mask[i] : (df.updated_at[i] <= ts_cutoff)
    end
    return subset(df, input_mask), subset(df, output_mask)
end;

## Loss functions

In [None]:
# Weighted mean loss of predictions `x` against targets `y` with weights `w`.
#
# `metric` selects the per-element loss:
#   * "rating"               -> squared error
#   * "watch", "plantowatch" -> -y * log(x)
#   * "drop"                 -> -(y * log(x) + (1 - y) * log(1 - x))
# Any other `metric` trips the assertion. The result is sum(loss .* w) / sum(w).
function loss(x, y, w, metric)
    # A tiny offset is added before taking the log so that log(0) stays finite.
    offset = Float32(eps(Float64))
    safelog(v) = log(v .+ offset)

    per_item = if metric == "rating"
        (x - y) .^ 2
    elseif metric in ["watch", "plantowatch"]
        -y .* safelog.(x)
    elseif metric == "drop"
        positive = y .* safelog.(x)
        negative = (1 .- y) .* safelog.(1 .- x)
        -(positive + negative)
    else
        @assert false
    end
    return sum(per_item .* w) / sum(w)
end;

In [None]:
# Find β such that loss(X * β, y, w, metric) is minimized.
#
#   * "rating": weighted least squares solved in closed form through the normal
#     equations, with a tiny ridge term added to the Gram matrix.
#   * "watch", "plantowatch", "drop": β is parameterized as the softmax of an
#     unconstrained vector, which is optimized with L-BFGS (forward-mode autodiff,
#     at most 100 iterations); the softmax of the minimizer is returned.
# Any other `metric` trips the assertion.
function regress(X, y, w, metric)
    if metric == "rating"
        # Scaling rows by sqrt(w) turns the weighted problem into an ordinary one.
        root_w = sqrt.(w)
        Xw = X .* root_w
        yw = y .* root_w
        _, ncols = size(Xw)
        # prevent singular matrix
        ridge = eps(Float32) * LinearAlgebra.I(ncols)
        return (Xw' * Xw + ridge) \ (Xw' * yw)
    elseif metric in ["watch", "plantowatch", "drop"]
        objective = β -> loss(X * softmax(β), y, w, metric)
        _, ncols = size(X)
        start = fill(0.0f0, ncols)
        settings = Optim.Options(g_tol = 1e-6, iterations = 100)
        result = Optim.optimize(
            objective,
            start,
            Optim.LBFGS(),
            settings;
            autodiff = :forward,
        )
        return softmax(Optim.minimizer(result))
    else
        @assert false
    end
end;

In [None]:
# Build the regression inputs `(X, y, w)` for one dataset / medium / metric.
#
# Rows come from the "test_output" split converted with `as_metric`. `y` is the
# `metric` column, and each row is weighted by 1 / (number of rows of its user), so
# every user's weights sum to one. `X` has one column per entry of `alphas` (read
# with `read_alpha`), followed by constant columns:
#   * "watch", "plantowatch": a single column equal to 1 / num_items(medium);
#   * "drop": a column of ones and then a column of zeros.
function get_features(
    dataset::String,
    medium::String,
    metric::String,
    alphas::Vector{String},
)
    columns = [:userid, :itemid, :rating, :status]
    df = as_metric(get_split(dataset, "test_output", medium, columns), metric)
    y = df.metric
    nrows = length(y)

    rows_per_user = StatsBase.countmap(df.userid)
    w = Float32[1 / rows_per_user[u] for u in df.userid]

    inputs = [read_alpha(dataset, df.userid, df.itemid, name) for name in alphas]
    if metric in ["watch", "plantowatch"]
        push!(inputs, fill(1.0f0 / num_items(medium), nrows))
    elseif metric == "drop"
        push!(inputs, fill(1.0f0, nrows))
        push!(inputs, fill(0.0f0, nrows))
    end
    return hcat(inputs...), y, w
end;

In [None]:
# Fit ensemble weights on the "streaming" dataset, then log the resulting loss
# for every dataset in `ALL_DATASETS`.
function print_losses(medium::String, metric::String, alphas::Vector{String})
    X_fit, y_fit, w_fit = get_features("streaming", medium, metric, alphas)
    β = regress(X_fit, y_fit, w_fit, metric)
    for dataset in ALL_DATASETS
        X, y, w = get_features(dataset, medium, metric, alphas)
        predictions = X * β
        val = loss(predictions, y, w, metric)
        @info "$dataset $medium $metric loss = $val"
    end
end;