# Helper functions that are useful for generating alphas

In [None]:
import Flux: softmax
import JLD2
import LRUCache: LRU
import Memoize: @memoize
import NBInclude: @nbinclude
import Optim
import ProgressMeter: @showprogress
import Setfield
import Setfield: @set
import SparseArrays

## General utilities

In [None]:
# Evaluate the shared helper notebook in this session before anything below runs.
# NOTE(review): names used below but not defined in this file (e.g. `redirect_logging`,
# `@tprogress`, `thread_range`, `@with_kw`) presumably come from here — verify.
@nbinclude("AlphaUtils.ipynb");

In [None]:
# Return "<repo>/data/<file>", where <repo> is the parent of the nearest ancestor of the
# current working directory named "notebooks".
# Throws an ErrorException if no such ancestor exists (previously this looped forever,
# because dirname("/") == "/").
function get_data_path(file)
    path = pwd()
    while basename(path) != "notebooks"
        parent = dirname(path)
        if parent == path
            # reached the filesystem root without finding "notebooks"; use ErrorException
            # directly so this does not depend on how `error` is bound in this module
            throw(ErrorException("get_data_path: no \"notebooks\" directory above $(pwd())"))
        end
        path = parent
    end
    "$(dirname(path))/data/$file"
end;

## Logging

In [None]:
# default the run name to "Alpha" unless an including notebook already defined `name`
@isdefined(name) || (name = "Alpha")
# send this run's log output to the alpha's data directory (data/alphas/<name>)
set_logging_outdir(name) = redirect_logging(get_data_path("alphas/$name"))
set_logging_outdir(name);

## Structs for handling ratings

In [None]:
# Column-oriented interaction table: user[i] has seen item[i] and given it a score of
# rating[i]; the other columns hold per-row metadata aligned with the same index i.
# Every column defaults to an empty vector, and loaders may leave columns they were not
# asked for empty (see `fields` in get_raw_split), so columns can be zero-length.
@with_kw struct RatingsDataset
    user::Vector{Int32} = Int32[]
    item::Vector{Int32} = Int32[]
    rating::Vector{Float32} = Float32[]
    timestamp::Vector{Float32} = Float32[]
    user_timestamp::Vector{Float32} = Float32[]
    item_timestamp::Vector{Float32} = Float32[]
    status::Vector{Int32} = Int32[]
    completion::Vector{Float32} = Float32[]
    rewatch::Vector{Int32} = Int32[]
    source::Vector{Int32} = Int32[]
end;

# swap users with items; every other column is passed through unchanged
# (the column vectors are shared with `x`, not copied)
Base.adjoint(x::RatingsDataset) = RatingsDataset(
    user = x.item,
    item = x.user,
    rating = x.rating,
    timestamp = x.timestamp,
    user_timestamp = x.user_timestamp,
    item_timestamp = x.item_timestamp,
    status = x.status,
    completion = x.completion,
    rewatch = x.rewatch,
    source = x.source,
)

# append two datasets row-wise (x's rows first, then y's); each column must be
# populated in both inputs or empty in both
function Base.cat(x::RatingsDataset, y::RatingsDataset)
    columns = fieldnames(RatingsDataset)
    for field in columns
        a = !isempty(getfield(x, field))
        b = !isempty(getfield(y, field))
        @assert a == b "cat: mismatched sizes in $field ($a != $b)"
    end
    # fieldnames order matches the positional constructor's argument order
    RatingsDataset((vcat(getfield(x, f), getfield(y, f)) for f in columns)...)
end

# keep only the rows selected by `mask`; columns that are empty stay empty
function Base.filter(x::RatingsDataset, mask::BitVector)
    keep(column) = isempty(column) ? [] : column[mask]
    RatingsDataset(map(f -> keep(getfield(x, f)), fieldnames(RatingsDataset))...)
end;

# returns the subset of the data whose user id is at most `max_userid`
function filter_users(x::RatingsDataset, max_userid)
    filter(x, x.user .<= max_userid)
end;

In [None]:
# Some sparse matrix operations require indices to be Int64
# (user, item, rating) columns with 64-bit index types; built from a RatingsDataset
# by the RatingsDataset64(::RatingsDataset) constructor, which drops the other columns.
@with_kw struct RatingsDataset64
    user::Vector{Int64}
    item::Vector{Int64}
    rating::Vector{Float32}
end

# widen the user/item index columns to Int64 (rating stays Float32); every other
# RatingsDataset column is dropped. All three output vectors are fresh copies.
RatingsDataset64(x::RatingsDataset) =
    RatingsDataset64(Int64.(x.user), Int64.(x.item), Float32.(x.rating));

## Reading and writing data

In [None]:
# Read line `lineno` of processed_data/uid_encoding.csv, expected to be "<key>,<value>",
# and return value + 1 (the stored value is presumably the largest zero-based id, so
# this is the id count — NOTE(review): confirm against the encoder).
function _read_uid_encoding_count(lineno::Integer, key::AbstractString)
    lines = readlines(get_data_path("processed_data/uid_encoding.csv"))
    fields = split(lines[lineno], ',')
    @assert fields[1] == key "uid_encoding.csv line $lineno: expected $key, got $(fields[1])"
    parse(Int, fields[2]) + 1
end

# number of user ids: the file's `max_userid` (line 1) + 1; memoized
@memoize function num_users()
    _read_uid_encoding_count(1, "max_userid")
end

# number of item ids: the file's `max_itemid` (line 2) + 1; memoized
@memoize function num_items()
    _read_uid_encoding_count(2, "max_itemid")
end;

In [None]:
# a split is a collection of (user, item) interactions that are stored as a RatingsDataset.
# "implicit" is the union of the explicit and implicit files, and "ptw" is loaded as is;
# both have their ratings set to 1. "negative" gets an all-zero rating column.
# `transpose = true` swaps users and items; `fields` limits which columns are loaded.
function get_split(
    split::String,
    content::String;
    transpose::Bool = false,
    fields::Union{Nothing,Vector{Symbol}} = nothing,
)
    @assert split in all_splits && content in all_contents
    load(c) = get_raw_split(split, c; fields = fields)
    ds = if content == "explicit"
        load("explicit")
    elseif content == "implicit"
        merged = cat(load("explicit"), load("implicit"))
        merged.rating .= 1
        merged
    elseif content == "ptw"
        watched = load("ptw")
        watched.rating .= 1
        watched
    elseif content == "negative"
        unrated = load("negative")
        # row count comes from the user column, which `fields` may have left empty above
        n = length(get_raw_split(split, "negative"; fields = [:user]).user)
        @set unrated.rating = fill(0f0, n)
    else
        @assert false
    end
    transpose ? ds' : ds
end

# Load the object stored under key "dataset" in splits/<content>_<split>.jld2
# (presumably a RatingsDataset). If `fields` is given, every RatingsDataset column not
# listed there is replaced with an empty column.
function get_raw_split(
    split::String,
    content::String;
    fields::Union{Nothing,Vector{Symbol}} = nothing,
)
    @assert split in all_splits && content in all_contents
    dataset = JLD2.load(get_data_path("splits/$(content)_$(split).jld2"), "dataset")
    isnothing(fields) && return dataset
    for column in fieldnames(RatingsDataset)
        column in fields && continue
        dataset = Setfield.set(dataset, Setfield.PropertyLens{column}(), [])
    end
    dataset
end

# the split and content names accepted by get_split / get_raw_split / read_alpha
# (checked there with @assert); write_alpha iterates over every combination
const all_splits = ["training", "validation", "test"]
const all_contents = ["explicit", "implicit", "negative", "ptw"];

In [None]:
# an alpha is a model that is used to predict whether a user will like an item.
# it's often useful to know an alpha model's value for a given (user, item) pair.
# alphas can be expensive to compute, so we precompute the model's values on
# (user, item) pairs and store the resulting RatingsDatasets to disk.
# storing the model values for all (user, item) pairs would be prohibitively
# large, so we only store values for the pairs in our splits.

# Log the blended loss of `model` on each split in `splits_to_log`.
# The blend coefficients β (alphas + model) are fit once, on the first split, which must
# be "validation"; they are then reused for every later split.
# With by_split the model is called as model(split, content; raw_splits = false),
# otherwise as model(users, items) on the split's (user, item) pairs.
function log_split_loss(
    model,
    alphas::Vector{String},
    content::String,
    implicit::Bool,
    splits_to_log::Vector{String};
    by_split = false,
)
    β = nothing
    for split in splits_to_log
        x = if by_split
            model(split, content; raw_splits = false)
        else
            interactions = get_split(split, content)
            model(interactions.user, interactions.item)
        end
        @assert length(x) == length(get_split(split, content; fields = [:user]).user)
        if β === nothing
            @assert split == "validation"
            β = last(regress(alphas, content, implicit, x))
        end
        split_loss = residualized_loss(alphas, content, implicit, x, β, split)
        @info "$split loss: $(split_loss), β: $β"
    end
end

# Deprecated positional form: optionally log split losses, then score `model` on every
# (split, content) pair and save the predictions to data/alphas/<outdir>/predictions.jld2
# under keys "<content>_<split>".
# `log_splits` is `false` (no logging) or the content name to log losses for.
function write_alpha(
    model::Function,
    alphas::Vector{String}, # TODO make an optional argument b/c logging is optional
    implicit::Bool, # TODO make an optional argument b/c logging is optional
    outdir::String;
    log_splits::Union{Bool,String} = false, # TODO rename to log_contents
    splits_to_log::Vector{String} = ["validation", "training"], # TODO rename to log_splits
)
    @info "deprecated write_alpha"
    # handle the String case first: `!log_splits` on an unknown String would otherwise
    # raise a MethodError instead of a clear assertion failure
    if log_splits isa String
        @assert log_splits in all_contents "write_alpha: unknown log content \"$log_splits\""
        log_split_loss(model, alphas, log_splits, implicit, splits_to_log)
    elseif log_splits
        @assert false "write_alpha: log_splits must be false or one of $all_contents"
    else
        @info "not logging split losses"
    end

    predictions = Dict{String,RatingsDataset}()
    for split in all_splits, content in all_contents
        df = get_raw_split(split, content; fields = [:user, :item])
        x = model(df.user, df.item)
        # the model must return exactly one prediction per (user, item) pair
        @assert length(x) == length(df.user)
        predictions["$(content)_$(split)"] = RatingsDataset(rating = x)
    end

    dir = mkpath(get_data_path("alphas/$outdir"))
    JLD2.save("$dir/predictions.jld2", predictions)
end

# Optionally log split losses, then score `model` on every (split, content) pair and save
# the predictions to data/alphas/<outdir>/predictions.jld2 under keys "<content>_<split>".
# With by_split the model is called as model(split, content; raw_splits = true),
# otherwise as model(users, items) on the raw split's (user, item) pairs.
# When `log` is true, `log_content` must be "explicit", "implicit" or "ptw".
function write_alpha(
    model::Function,
    outdir::String;
    by_split::Bool = false,
    log::Bool = true,
    log_content::Union{String,Nothing} = nothing,
    log_alphas::Vector{String} = String[],
    log_splits::Vector{String} = ["validation", "training"],
)
    if log
        # explicit content is scored with squared error, implicit/ptw with cross-entropy
        if log_content == "explicit"
            implicit = false
        elseif log_content in ["implicit", "ptw"]
            implicit = true
        else
            @assert false "write_alpha: log_content must be \"explicit\", \"implicit\" or \"ptw\" when log=true (got $(repr(log_content)))"
        end
        log_split_loss(
            model,
            log_alphas,
            log_content,
            implicit,
            log_splits;
            by_split = by_split,
        )
    else
        @info "not logging split losses"
    end

    predictions = Dict{String,RatingsDataset}()
    for split in all_splits, content in all_contents
        if by_split
            x = model(split, content; raw_splits = true)
            n = length(get_raw_split(split, content; fields = [:user]).user)
        else
            df = get_raw_split(split, content; fields = [:user, :item])
            x = model(df.user, df.item)
            # the split is already loaded; no need to read the file again for its length
            n = length(df.user)
        end
        @assert length(x) == n
        predictions["$(content)_$(split)"] = RatingsDataset(rating = x)
    end

    dir = mkpath(get_data_path("alphas/$outdir"))
    JLD2.save("$dir/predictions.jld2", predictions)
end

# Deprecated: write the predictions of a precomputed (user × item) sparse score matrix.
# Each (user, item) pair is scored as p[user, item]; unstored entries read as 0.
# `implicit` is only needed when `log_splits` names a content to log.
function write_alpha(
    p::SparseArrays.AbstractSparseMatrix,
    outdir;
    alphas::Vector{String} = String[],
    implicit::Union{Bool,Nothing} = nothing,
    log_splits::Union{Bool,String} = false, # TODO rename to log_contents
    splits_to_log::Vector{String} = ["validation", "training"], # TODO rename to log_splits
)
    @info "deprecated write_alpha"
    function model(users, items)
        r = zeros(length(users))
        @tprogress Threads.@threads for j = 1:length(r)
            r[j] = p[users[j], items[j]]
        end
        r
    end
    # the positional write_alpha requires `implicit::Bool`; forwarding the default
    # `nothing` used to raise a MethodError even when no logging was requested
    if isnothing(implicit)
        @assert log_splits === false "write_alpha: `implicit` must be set when logging split losses"
    end
    write_alpha(
        model,
        alphas,
        something(implicit, false), # unused when not logging
        outdir;
        log_splits = log_splits,
        splits_to_log = splits_to_log,
    )
end

# Load the predictions stored for `alpha` under key "<content>_<split>" in
# data/alphas/<alpha>/predictions.jld2.
read_raw_alpha_impl(alpha::String, split::String, content::String) =
    JLD2.load(get_data_path("alphas/$(alpha)/predictions.jld2"), "$(content)_$(split)")

# Indirection kept separate from the loader so that read_raw_alpha can be redefined
# (overridden) elsewhere while read_raw_alpha_impl still reaches the on-disk data.
read_raw_alpha(alpha::String, split::String, content::String) =
    read_raw_alpha_impl(alpha, split, content)

# Return alpha `alpha`'s stored predictions for (split, content) as a RatingsDataset
# with the split's user/item columns and the predictions in `rating`.
# "implicit" predictions are the explicit and implicit files concatenated in that order,
# matching how get_split builds the implicit split; other contents use their own key.
function read_alpha(alpha::String, split::String, content::String)
    @assert split in all_splits && content in all_contents
    p = if content == "implicit"
        cat(
            read_raw_alpha(alpha, split, "explicit"),
            read_raw_alpha(alpha, split, "implicit"),
        )
    else
        read_raw_alpha(alpha, split, content)
    end
    df = get_split(split, content; fields = [:user, :item])
    @assert length(df.user) == length(p.rating)
    RatingsDataset(user = df.user, item = df.item, rating = p.rating)
end;

In [None]:
# params consist of two things:
# 1) the hyperparameters that are used to train an alpha
# 2) the trained parameters of an alpha 
# in general, params should contain all information necessary to
# efficiently train the alpha model for a new user

# save `params` to data/alphas/<outdir>/params.jld2, creating the directory if needed
function write_params(params, outdir)
    dir = mkpath(get_data_path("alphas/$outdir"))
    JLD2.save("$dir/params.jld2", params)
end

# load everything stored in data/alphas/<alpha>/params.jld2
read_params(alpha) = JLD2.load(get_data_path("alphas/$alpha/params.jld2"));

## Weight decays

In [None]:
# the number of items a user has seen spans orders of magnitude.
# if we place an equal weight on each (user, item) pair, then highly
# active users will skew the loss function. it is generally best practice
# to weight each user equally (see "Deep Neural Networks for YouTube Recommendations",
# https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45530.pdf).
# we achieve this by weighting our validation and test loss functions such that each
# (user, item) pair has a weight of 1 / (number of items the user has seen).
#
# during training, the same skew issue appears. instead of weighting each (user, item)
# pair by 1 / (number of items the user has seen), we take the more general approach
# of weighting each (user, item) pair by
#   (number of items the user has seen)^w_u * (number of users that have seen the item)^w_a.
# when w_u = -1 and w_a = 0, we recover the equal-user weighting used for the
# validation and test sets.

# Signed power transform sign(x) * |x|^a, with x == 0 mapped to zero(eltype(a)) so that
# negative exponents (e.g. a = -1 for "inverse" weighting) do not produce Inf.
function powerdecay(x, a)
    x == 0 ? zero(eltype(a)) : sign(x) * abs(x)^a
end

# Elementwise, multithreaded version for any AbstractVector (including views); the
# result has the same indices as `x` and element type eltype(a).
function powerdecay(x::AbstractVector, a)
    y = similar(x, eltype(a))
    Threads.@threads for i in eachindex(x, y)
        @inbounds y[i] = powerdecay(x[i], a)
    end
    y
end

# a numeric scheme is already the exponent
weighting_scheme(scheme::Number) = scheme

# map a named scheme to the Float32 exponent that get_weights applies to per-user counts:
# "linear" -> 1, "constant" -> 0, "inverse" -> -1
function weighting_scheme(scheme::String)
    scheme == "linear" && return 1.0f0
    scheme == "constant" && return 0.0f0
    scheme == "inverse" && return -1.0f0
    @assert false
    return 0.0f0 # only reachable if assertions are disabled
end;

# Count rows per user id: returns a vector of length num_users() (eltype of the rating
# column) whose entry u is the number of rows of `split` with user == u.
# The rows are split into one contiguous chunk per thread, and each chunk accumulates into
# its own column. The column is indexed by the loop variable rather than
# Threads.threadid(), which is not a stable per-task slot under dynamic scheduling.
# The columns are summed at the end.
function get_user_counts(split::RatingsDataset)
    users = split.user
    n = length(users)
    nchunks = max(1, Threads.nthreads())
    chunksize = cld(n, nchunks)
    counts = zeros(eltype(split.rating), num_users(), nchunks)
    Threads.@threads for c = 1:nchunks
        for i = ((c-1)*chunksize+1):min(c * chunksize, n)
            # bounds-checked on purpose: an id outside 1:num_users() throws instead of
            # silently corrupting memory
            counts[users[i], c] += 1
        end
    end
    vec(sum(counts, dims = 2))
end

# Number of ratings by each rating's user in the given split/content (memoized via
# @memoize). With per_rating = false, returns the per-user count vector instead.
# With by_item = true the dataset is transposed first, so the counts are per item.
# NOTE(review): get_user_counts sizes its output by num_users(), so the by_item path
# assumes every item id is <= num_users() — verify.
@memoize function get_counts(
    split::String,
    content::String;
    per_rating::Bool = true,
    by_item::Bool = false,
)
    # a separate name, instead of rebinding the `split` String to a dataset
    ds = get_split(split, content; transpose = by_item)
    user_counts = get_user_counts(ds)
    per_rating || return user_counts

    # look up each row's user count
    per_row = Array{eltype(user_counts)}(undef, length(ds.user))
    Threads.@threads for k = 1:length(per_row)
        @inbounds per_row[k] = user_counts[ds.user[k]]
    end
    per_row
end

# per-rating weights: (number of ratings by that rating's user) ^ exponent(scheme),
# with zero counts mapped to zero weight by powerdecay
function get_weights(split::String, content::String, scheme::String)
    per_rating_counts = get_counts(split, content)
    power = weighting_scheme(scheme)
    powerdecay(per_rating_counts, power)
end;

## Loss functions

In [None]:
# most alphas can be classified as one of two types:
# 1) explicit alphas predict what rating a user will give to
#    a given show, conditional on having watched the show. these
#    alphas are trained using mean squared error.
# 2) implicit alphas predict whether a user will watch
#    a given show. these alphas are trained using
#    cross-entropy loss.
# both loss functions are weighted according to the weight-decay
# logic described above.

# weighted mean of the elementwise losses lossfn(x, y) under weights w
function weighted_loss(x, y, w, lossfn)
    weighted_sum = sum(lossfn(x, y) .* w)
    total_weight = sum(w)
    weighted_sum / total_weight
end

# Multithreaded weighted_loss: sum(lossfn(x, y) .* w) / sum(w).
# The index range is split into one contiguous chunk per thread, and the per-chunk sums
# are combined at the end. Each task writes only to its own slot t (the loop variable):
# Threads.threadid() is not a stable per-task slot, so indexing by it can overwrite one
# slot and leave another uninitialized. Empty chunks contribute 0 rather than 0/0 = NaN.
function weighted_loss_multithreaded(x, y, w, lossfn)
    n = length(x)
    nchunks = max(1, Threads.nthreads())
    chunksize = cld(n, nchunks)
    a = zeros(eltype(x), nchunks) # per-chunk weighted loss sums
    b = zeros(eltype(w), nchunks) # per-chunk weight sums
    Threads.@threads for t = 1:nchunks
        r = ((t-1)*chunksize+1):min(t * chunksize, n)
        # Base.sum uses pairwise summation which is important for accuracy
        @views a[t] = sum(lossfn(x[r], y[r]) .* w[r])
        @views b[t] = sum(w[r])
    end
    sum(a) / sum(b)
end

# Weighted loss of predictions x against labels y with weights w, using the cross-entropy
# term -y * log(x) when `implicit` and squared error otherwise.
# NOTE(review): `weighted_error` is not defined in this file (compare `weighted_loss`,
# which `loss` uses) — verify that it exists, e.g. in AlphaUtils.
# NOTE(review): unless `error` was already bound to Base.error in this module (in which
# case this definition would be rejected), defining `error` here creates a new function
# that shadows Base.error. Plain `error("msg")` calls in this module would then raise a
# MethodError; the vararg method below forwards such calls to Base.error.
function error(x, y, w, implicit)
    lossfn = implicit ? ((x, y) -> -y .* log.(x)) : ((x, y) -> (x - y) .^ 2)
    weighted_error(x, y, w, lossfn)
end

error(args...) = Base.error(args...)

# Weighted loss of predictions x against labels y with weights w: the cross-entropy term
# -y * log(x) when `implicit`, squared error otherwise.
# normalize = true divides by the total weight (weighted_loss, or its multithreaded
# variant); normalize = false hands off to weighted_unnormalized_loss.
function loss(x, y, w, implicit; normalize = true, multithreaded = false)
    lossfn = implicit ? ((p, t) -> -t .* log.(p)) : ((p, t) -> (p - t) .^ 2)
    evaluator = if !normalize
        weighted_unnormalized_loss
    elseif multithreaded
        weighted_loss_multithreaded
    else
        weighted_loss
    end
    evaluator(x, y, w, lossfn)
end;

## Regressions

In [None]:
# given a matrix of features X, a vector of true labels y, and
# a vector of weights w, a regression will find the β
# that minimizes the weighted loss between X * β and y. for explicit
# alphas, the loss is mean squared error and there is a closed
# form solution. for implicit alphas, the loss is cross-entropy
# and we solve for β numerically, constraining it to the probability
# simplex via β = softmax(θ).

# Fit β for features X, labels y and weights w; returns (X * β, β).
function regress(X, y, w, implicit::Bool)
    if !implicit
        # weighted least squares: scale rows by sqrt(w), then solve the normal equations
        Xw = X .* sqrt.(w)
        yw = y .* sqrt.(w)
        β = (Xw' * Xw) \ (Xw' * yw)
        return X * β, β
    end
    # cross-entropy has no closed form: optimise unconstrained logits θ and map them
    # through softmax, so the returned β is non-negative and sums to 1
    objective(θ) = loss(X * softmax(θ), y, w, implicit; multithreaded = true)
    θ0 = fill(0.0f0, size(X, 2))
    result = Optim.optimize(
        objective,
        θ0,
        Optim.NewtonTrustRegion(),
        Optim.Options(g_tol = 1e-6, iterations = 100);
        autodiff = :forward,
    )
    β = softmax(Optim.minimizer(result))
    return X * β, β
end;

In [None]:
# Regress the given features on the validation set, with inverse-count ("inverse")
# weights. The features are the `alphas`, optionally followed by an extra prediction
# column `x`. Returns (fitted values, β).
# The default `x = nothing` already provides the 3-argument form
# regress(alphas, content, implicit); a separate 3-argument method would only overwrite it.
function regress(alphas::Vector{String}, content::String, implicit::Bool, x = nothing)
    split = "validation"
    X = regression_features(alphas, split, content, implicit, x)
    y = get_split(split, content).rating
    w = get_weights(split, content, "inverse")
    regress(X, y, w, implicit)
end

# Build the design matrix for `split`: one column per alpha in `alphas` (in order), then,
# for implicit models, a constant 1/num_items() baseline column, then `x` if given.
# Its eltype and row count follow `x`, or the split's rating column when `x` is nothing.
function regression_features(
    alphas::Vector{String},
    split::String,
    content::String,
    implicit::Bool,
    x = nothing,
)
    has_x = !isnothing(x)
    template = has_x ? x : get_split(split, content).rating
    nalphas = length(alphas)
    X = Array{eltype(template),2}(undef, length(template), nalphas + implicit + has_x)
    @showprogress for j = 1:nalphas
        @inbounds X[:, j] = read_alpha(alphas[j], split, content).rating
    end
    if implicit
        # constant baseline column keeps the implicit regression non-degenerate
        X[:, nalphas+1] .= 1.0f0 / num_items()
    end
    if has_x
        X[:, end] = x
    end
    X
end

# linearly combine the given alphas: fit blend weights β on the validation split, then
# apply them to the features of `split`
function read_alpha(alphas::Vector{String}, split::String, content::String, implicit::Bool)
    interactions = get_split(split, content)
    β = last(regress(alphas, content, implicit))
    blended = regression_features(alphas, split, content, implicit) * β
    RatingsDataset(user = interactions.user, item = interactions.item, rating = blended)
end

# fit a blend of `alphas` plus the extra prediction column `x` on the validation split,
# then return that blend's inverse-weighted loss on the same validation split
function residualized_loss(alphas::Vector{String}, content::String, implicit::Bool, x)
    split = "validation"
    fitted = first(regress(alphas, content, implicit, x))
    labels = get_split(split, content).rating
    weights = get_weights(split, content, "inverse")
    loss(fitted, labels, weights, implicit; multithreaded = true)
end

# apply previously fitted blend weights β to this split's features (alphas plus `x`)
# and return the inverse-weighted loss on `split`
function residualized_loss(
    alphas::Vector{String},
    content::String,
    implicit::Bool,
    x,
    β,
    split::String,
)
    blended = regression_features(alphas, split, content, implicit, x) * β
    truth = get_split(split, content).rating
    weights = get_weights(split, content, "inverse")
    loss(blended, truth, weights, implicit; multithreaded = true)
end;