# Neural Network Base Class
* This class contains infrastructure to train neural networks
* The following algorithms are implemented:
    * Baseline predictors
* The following algorithms will be implemented:
    * Item-based collaborative filtering
    * Matrix Factorization
    * Autoencoder

In [1]:
using Flux
using Random
using SparseArrays
using Statistics: var

import CUDA
import NBInclude: @nbinclude
import NLopt
import Setfield: @set
@nbinclude("Alpha.ipynb");

In [2]:
# Generic fallback: hand `x` to Flux's `gpu` and return whatever it produces.
device(x) = gpu(x)

# Sparse-array method: when CUDA is usable, pass the matrix through `gpu` and
# wrap the result in `CUDA.CuArray` (a dense CUDA array); otherwise return the
# input unchanged.
function device(x::AbstractSparseArray)
    if CUDA.functional()
        return CUDA.CuArray(gpu(x))
    end
    return x
end;

In [3]:
# Notebook sanity check: display whether CUDA reports itself as functional.
CUDA.functional()

true

## Hyperparameters
* Contains all the information necessary to train a new model
* The important hyperparameters will be tuned via a derivative-free optimizer

In [4]:
# Immutable bundle of every setting needed to train one model.
# `train_model`, `retrain_user_embeddings` and `evaluate` install an instance as
# the global `G`, which the other helpers in this notebook read from.
@with_kw struct Hyperparams
    # model
    # when true, `model_loss` uses the softmax / log-likelihood branch;
    # otherwise it uses weighted squared error
    implicit::Bool
    # input representation selected in `get_epoch_inputs` (only "one_hot" is handled)
    input_data::String
    # architecture name dispatched on in `build_model` / `regularization_loss`
    # (only "user_item_biases" is handled)
    model::String
    # batching
    # number of users sliced into each minibatch by `get_batch`
    batch_size::Int
    # "constant" => plain shuffle; anything else is handed to `weighting_scheme`
    # in `get_sampling_order`
    user_sampling_scheme::String # TODO convert to float32
    # optimizer
    learning_rate::Float32
    # optimizer name understood by `get_optimizer` ("ADAM" or "SGD")
    optimizer::String
    # training
    # passed to `Random.seed!` before the model is built
    seed::UInt64
    # number of users covered by an epoch (also the column count in `sparse`)
    num_users::Int
    # loss
    # passed to `expdecay` together with the per-item counts for training weights
    item_weight_decay::Float32
    # multipliers applied to the variance terms in the regularization loss
    regularization_params::Vector{Float32}
    # forwarded to `read_alpha` to load the residualization alpha
    residual_alphas::Vector{String}
    # coefficient applied to the residual term `z` in `model_loss`
    residual_beta::Float32
    # passed to `expdecay` together with the per-user counts for training weights
    user_weight_decay::Float32
end

# Convert a `Hyperparams` into a dictionary keyed by the field names (as strings).
function to_dict(x::Hyperparams)
    names = fieldnames(Hyperparams)
    Dict(string(name) => getfield(x, name) for name in names)
end

# Render the hyperparameters as a multi-line report: a header line followed by
# one `name => value` line per field, with names right-padded to equal width.
function Base.string(x::Hyperparams)
    names = collect(fieldnames(Hyperparams))
    width = maximum(length(string(name)) for name in names)
    io = IOBuffer()
    print(io, "Hyperparameters:\n")
    for name in names
        print(io, rpad(string(name), width), " => ", getfield(x, name), "\n")
    end
    String(take!(io))
end;

## Models
* To define a new model, add the architecture to `build_model` and the regularization to `regularization_loss`

In [5]:
# Fan-out layer: applies every path in `paths` to the same input and returns
# the collection of results.
struct Split{T}
    paths::T
end

# Convenience constructor: `Split(f, g, h)` stores the paths as a tuple.
Split(paths...) = Split(paths)

Flux.@functor Split

function (m::Split)(x::AbstractArray)
    map(path -> path(x), m.paths)
end

# A layer that takes many inputs and joins them into one
# `Join(combine, paths)` forwards both arguments unchanged to `Parallel`.
# The varargs method collects its trailing arguments into a tuple and calls
# `Join` again with that tuple as the single `paths` argument.
# NOTE(review): a call with exactly one path matches both methods; presumably
# the non-varargs method is selected -- confirm this is the intended dispatch.
Join(combine, paths) = Parallel(combine, paths)
Join(combine, paths...) = Join(combine, paths);

In [6]:
# Layer holding a single bias `b` that is broadcast-added to its input.
struct BiasLayer
    b::Any
end

# Build a bias of length `n`; `init` is called as `init(Float32, n)`.
function BiasLayer(n::Integer; init = randn)
    BiasLayer(init(Float32, n))
end

# Forward pass: broadcast-add the stored bias.
function (layer::BiasLayer)(x)
    x .+ layer.b
end

Flux.@functor BiasLayer

In [7]:
# Baseline predictor R[i, j] = u[i] + a[j]: a one-dimensional embedding per
# user followed by a bias with one entry per item.
function user_item_biases()
    user_bias = Flux.Embedding(G.num_users => 1)
    item_bias = BiasLayer(num_items())
    Chain(user_bias, item_bias)
end

# Regularization for the baseline model: λ_u * variance(u) + λ_a * variance(a),
# where the two λ values are the first two entries of `G.regularization_params`.
function user_item_biases_regularization(m)
    user_penalty = var(m[1].weight) * G.regularization_params[1]
    item_penalty = var(m[2].b) * G.regularization_params[2]
    user_penalty + item_penalty
end;

In [8]:
# Construct the architecture named by `G.model`; unknown names trip an assertion.
function build_model()
    G.model == "user_item_biases" && return user_item_biases()
    @assert false
end

# Regularization term matching the architecture named by `G.model`;
# unknown names trip an assertion.
function regularization_loss(m)
    G.model == "user_item_biases" && return user_item_biases_regularization(m)
    @assert false
end;

# Data Preprocessing
* An epoch is an efficient representation of all of the model's inputs, outputs, residualization, and weights
* We generate one epoch per split and memoize them

In [9]:
# Inputs for the "one_hot" representation: the user indices 1..num_users as a
# vector. `implicit` is accepted but not used.
function one_hot_inputs(implicit, num_users)
    [user for user in 1:num_users]
end;

In [10]:
# Memoized (LRU, 2 entries) accessor for the model inputs of an epoch.
#
# Arguments:
#   input_data - name of the input representation; only "one_hot" is supported
#   implicit   - forwarded to the input builder
#   num_users  - number of users to generate inputs for
#
# Returns whatever the selected input builder returns (for "one_hot", the
# vector produced by `one_hot_inputs`). Any other `input_data` trips an assertion.
#
# The builder's result is returned directly. It used to be destructured into
# two throwaway variables, which only worked because an assignment evaluates
# to its right-hand side, and which threw when fewer than two users were
# requested.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_inputs(
    input_data,
    implicit,
    num_users,
)
    if input_data == "one_hot"
        one_hot_inputs(implicit, num_users)
    else
        @assert false
    end
end

# Memoized (LRU, 2 entries) accessor for the target ratings of a split:
# load the split, pass it through `filter_users`, and convert it with `sparse`.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_outputs(split, implicit, num_users)
    ratings = get_split(split, implicit)
    kept = filter_users(ratings, num_users)
    sparse(kept)
end

# Memoized (LRU, 2 entries) accessor for the residualization alpha of a split:
# read the alpha via `read_alpha`, pass it through `filter_users`, and convert
# it with `sparse`.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_residuals(
    split,
    residual_alphas,
    implicit,
    num_users,
)
    alpha = read_alpha(residual_alphas, split, implicit)
    sparse(filter_users(alpha, num_users))
end

# Memoized (LRU, 2 entries) accessor for the per-rating loss weights of a split.
# * "training": product of the user-count decay and the item-count decay
# * any other split: decay of the user counts with the "inverse" scheme
# The weights replace the ratings in a `RatingsDataset` built from the split's
# users and items, which is then passed through `filter_users` and `sparse`.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_weights(
    split,
    user_weight_decay,
    item_weight_decay,
    implicit,
    num_users,
)
    if split == "training"
        user_weights = expdecay(get_counts(split, implicit), user_weight_decay)
        item_weights =
            expdecay(get_counts(split, implicit; by_item = true), item_weight_decay)
        weights = user_weights .* item_weights
    else
        weights = expdecay(get_counts(split, implicit), weighting_scheme("inverse"))
    end

    ratings = get_split(split, implicit)
    weighted = RatingsDataset(ratings.user, ratings.item, weights)
    sparse(filter_users(weighted, num_users))
end;

In [11]:
# Assemble the epoch for `split` using the current global hyperparameters `G`.
# Returns the tuple (inputs, outputs, residualization alpha, weights).
function get_epoch(split)
    inputs = get_epoch_inputs(G.input_data, G.implicit, G.num_users)
    outputs = get_epoch_outputs(split, G.implicit, G.num_users)
    residuals = get_epoch_residuals(split, G.residual_alphas, G.implicit, G.num_users)
    weights = get_epoch_weights(
        split,
        G.user_weight_decay,
        G.item_weight_decay,
        G.implicit,
        G.num_users,
    )
    (inputs, outputs, residuals, weights)
end;

# Batching
* Turns an epoch into minibatches
* Each data point will appear in a minibatch with a probability proportional to its sampling weight

In [12]:
# Convert a ratings dataset into a sparse matrix with items as rows and users
# as columns, sized `num_items()` x `G.num_users`.
function SparseArrays.sparse(split::RatingsDataset)
    rows, cols, vals = split.item, split.user, split.rating
    sparse(rows, cols, vals, num_items(), G.num_users)
end;

In [13]:
# Select the entries of a vector at the positions given by `range`.
slice(x::AbstractVector, range) = getindex(x, range)

# Select the columns of a matrix at the positions given by `range`, keeping all rows.
slice(x::AbstractMatrix, range) = getindex(x, :, range);

In [14]:
# Return the order in which users are visited during one epoch of `split`.
#
# * The training split uses `G.user_sampling_scheme`; every other split uses
#   the "constant" scheme.
# * "constant" yields a random permutation of 1:G.num_users.
# * Any other scheme draws `G.num_users` user indices via `sample`, weighted by
#   `expdecay` of the per-user counts with the decay given by `weighting_scheme`.
#
# The local variable is called `scheme` on purpose: naming it `weighting_scheme`
# shadowed the `weighting_scheme` function for the whole body, so the weighted
# branch tried to call a String and raised a MethodError.
function get_sampling_order(split)
    scheme = split == "training" ? G.user_sampling_scheme : "constant"
    if scheme == "constant"
        return shuffle(1:G.num_users)
    end
    # NOTE(review): other call sites pass `implicit` to `get_counts`; confirm
    # that this keyword-only form matches a real method.
    weights = expdecay(get_counts(split; per_rating = false), weighting_scheme(scheme))
    return sample(1:G.num_users, Weights(weights), G.num_users)
end;

In [15]:
# Extract the `iter`-th minibatch of an epoch.
#
# Steps:
# 1) take the `iter`-th window of at most `batch_size` entries of `sampling_order`
# 2) slice every element of `epoch` at those user indices and move it to the device
#
# Arguments:
#   epoch          - collection of arrays to slice (vectors by entry, matrices by column)
#   iter           - 1-based minibatch number
#   batch_size     - maximum number of users per minibatch
#   sampling_order - user indices in the order they should be visited
#
# Returns `(batch, users)`, where `batch` is a one-element vector holding the
# sliced epoch and `users` are the user indices in this minibatch.
#
# The caller's `sampling_order` is now honoured. It used to be overwritten
# with `1:G.num_users`, which disabled shuffling and weighted sampling.
function get_batch(epoch, iter, batch_size, sampling_order)
    first_idx = (iter - 1) * batch_size + 1
    # the last minibatch may be shorter than `batch_size`
    last_idx = min(iter * batch_size, length(sampling_order))
    range = sampling_order[first_idx:last_idx]
    process(x) = slice(x, range) |> device
    [process.(epoch)], range
end;

# Convenience method: the `iter`-th minibatch with users visited in their
# natural order 1:G.num_users.
function get_batch(epoch, iter, batch_size)
    get_batch(epoch, iter, batch_size, 1:G.num_users)
end;

## Loss Functions
* The `model_loss` is either the crossentropy loss or squared error, depending on the input data
    * Note that we take the sum over all items, so using a bigger batchsize will have a bigger `model_loss`
* The `regularization_loss` depends on the model architecture, but is commonly an L2 loss
* During training, the `model_loss` is scaled by a function of the weight decays. This keeps the magnitude of the loss function approximately the same, even if the weight decay constants change
* The `split_loss` is either the weighted average crossentropy loss or weighted mean squared error, depending on the input data

In [16]:
# Weighted data loss of model `m` on one batch, summed over all entries.
#   x - model inputs, y - targets, z - residualization alpha, w - weights
# * explicit: squared error of `m(x) + residual_beta * z` against `y`
# * implicit: negative log-likelihood of `y` under a mix of `softmax(m(x))`
#   and `z`, with mixing weight `sigmoid(residual_beta)`
function model_loss(m, x, y, z, w)
    p = m(x)
    if !G.implicit
        q = p + z .* G.residual_beta
        return sum(w .* (q - y) .^ 2)
    end
    β = sigmoid(G.residual_beta)
    q = softmax(p) * (1 - β) + z .* β
    return sum(w .* -y .* log.(q))
end

# Objective minimized during training: the batch data loss multiplied by
# `model_loss_scale`, plus the architecture's regularization term.
function training_loss(m, x, y, z, w; model_loss_scale)
    data_term = model_loss(m, x, y, z, w) * model_loss_scale
    data_term + regularization_loss(m)
end

# Weighted average data loss of `m` over every minibatch of `split`:
# the summed batch losses divided by the summed batch weights, as a Float32.
function split_loss(m, split)
    epoch = get_epoch(split)
    num_batches = ceil(Int, G.num_users / G.batch_size)
    total_loss = 0.0
    total_weight = 0.0
    for iter in 1:num_batches
        batch, _ = get_batch(epoch, iter, G.batch_size)
        fields = batch[1]
        total_loss += model_loss(m, fields...)
        # the weights are the last element of the batch
        total_weight += sum(last(fields))
    end
    Float32(total_loss / total_weight)
end;

## Training
* Trains a neural network with the given hyperparameters

In [17]:
# Map an optimizer name to an optimizer instance.
#   "ADAM" -> ADAMW with betas (0.9, 0.999) and a decay argument of 0
#   "SGD"  -> plain gradient descent
# Any other name trips an assertion.
function get_optimizer(optimizer, learning_rate)
    optimizer == "ADAM" && return ADAMW(learning_rate, (0.9, 0.999), 0)
    optimizer == "SGD" && return Descent(learning_rate)
    @assert false
end;

In [18]:
# Run one pass over the training split, updating the parameters `ps` of `m`
# in place with optimizer `opt`, one `Flux.train!` call per minibatch.
function train_epoch!(m, ps, opt)
    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())
    epoch = get_epoch("training")
    sampling_order = get_sampling_order("training")
    # make the training loss invariant to the scale of the weight decays:
    # epoch[4] holds the weights, so divide the user count by their total
    scale = G.num_users / sum(epoch[4])
    batchloss = (x, y, z, w) -> training_loss(m, x, y, z, w; model_loss_scale = scale)

    num_batches = ceil(Int, length(sampling_order) / G.batch_size)
    for iter in 1:num_batches
        batch, _ = get_batch(epoch, iter, G.batch_size, sampling_order)
        Flux.train!(batchloss, ps, batch, opt)
    end
end;

In [19]:
# Train a fresh model with hyperparameters `hyp`.
# Between checkpoints the model trains for `epochs_per_checkpoint` epochs; at
# each checkpoint the validation loss is measured and the stopper (configured
# with `max_checkpoints` and `patience`) decides whether to continue.
# Returns `(best_model, best_loss)`: a CPU copy of the model at the checkpoint
# with the lowest validation loss, and that loss.
# Side effect: sets the global `G` to `hyp` while training, then to `nothing`.
function train_model(hyp; max_checkpoints = 10, epochs_per_checkpoint = 10, patience = 0)
    global G = hyp
    opt = get_optimizer(G.optimizer, G.learning_rate)
    # seed before building the model so initialization is reproducible
    Random.seed!(G.seed)
    m = device(build_model())
    best_model = cpu(m)
    ps = Flux.params(m)
    stopper = early_stopper(max_iters = max_checkpoints, patience = patience)

    losses = []
    loss = Inf
    while !stop!(stopper, loss)
        for _ in 1:epochs_per_checkpoint
            train_epoch!(m, ps, opt)
        end
        loss = split_loss(m, "validation")
        push!(losses, loss)
        # keep a snapshot whenever this checkpoint is the best one so far
        loss == minimum(losses) && (best_model = cpu(m))
    end
    global G = nothing
    best_model, minimum(losses)
end;

## Hyperparameter Tuning
* A derivative free optimizer is used to find the best hyperparameters

In [20]:
# Number of tuneable hyperparameters for `model`, returned as the tuple
# (model params, sampling params, regularization params).
# Only "user_item_biases" is known; any other name trips an assertion.
function num_tuneable_params(model)
    if model != "user_item_biases"
        @assert false
    end
    num_model_params = 4
    num_sampling_params = 0
    num_regularization_params = 2
    (num_model_params, num_sampling_params, num_regularization_params)
end

# Produce a copy of `hyp` with the tuneable fields filled in from the
# optimizer's parameter vector `λ`:
#   λ[1] -> learning rate, as 0.01 * exp(λ[1])
#   λ[2] -> residual beta (offset by 1 when the model is not implicit)
#   λ[3] -> user weight decay
#   λ[4] -> item weight decay
#   λ[5] -> user sampling scheme, only when the model has one sampling param
#   the trailing entries -> regularization params, exponentiated
function create_hyperparams(hyp, λ)
    _, num_sampling_params, num_regularization_params = num_tuneable_params(hyp.model)
    learning_rate = 0.01 * exp(λ[1])
    residual_beta = hyp.implicit ? λ[2] : 1 + λ[2]
    hyp = @set hyp.learning_rate = learning_rate
    hyp = @set hyp.residual_beta = residual_beta
    hyp = @set hyp.user_weight_decay = λ[3]
    hyp = @set hyp.item_weight_decay = λ[4]
    if num_sampling_params == 1
        hyp = @set hyp.user_sampling_scheme = λ[5]
    end
    regularization = exp.(λ[end-num_regularization_params+1:end])
    hyp = @set hyp.regularization_params = regularization
    hyp
end;

In [21]:
# Tune the hyperparameters of `hyp` with a derivative-free Nelder-Mead search.
#
# Arguments:
#   hyp       - base hyperparameters; the tuneable fields are overwritten per trial
#   max_evals - maximum number of objective evaluations (each one trains a model)
#
# Returns the parameter vector `λ` reported by NLopt. Each evaluation and the
# final result are logged with `@info`.
#
# The final log message no longer ends with a dangling "and saved model at":
# this function does not save anything and no path followed that text.
function optimize_hyperparams(hyp; max_evals)
    # objective: validation loss of a model trained with the candidate λ
    # (`grad` is required by the NLopt callback signature but unused)
    function nlopt_loss(λ, grad)
        # nlopt internally converts to float64 because it calls a c library
        λ = convert.(Float32, λ)
        _, loss = train_model(create_hyperparams(hyp, λ))
        @info "$λ $loss"
        loss
    end
    num_variables = sum(num_tuneable_params(hyp.model))
    opt = NLopt.Opt(:LN_NELDERMEAD, num_variables)
    opt.initial_step = 1
    opt.maxeval = max_evals
    opt.min_objective = nlopt_loss
    # start the search from the all-zeros point
    minf, λ, ret = NLopt.optimize(opt, zeros(Float32, num_variables))
    numevals = opt.numevals

    @info (
        "found minimum $minf at point $λ after $numevals function calls " *
        "(ended because $ret)"
    )
    λ
end;

## Retrain User Embeddings
* To minimize training/serving skew, we train the model the same
  way we will train it during inference
* This means reinitializing the user embeddings, freezing all other layers,
  and fine-tuning the user embeddings
* During serving, we will determine a new user's embedding
  by training with the same hyperparameters and number of epochs

In [22]:
# Re-fit only the user embeddings of a trained model `m`.
# The first layer's weights are overwritten with those of a freshly constructed
# embedding, and only that layer's parameters are handed to the optimizer.
# Training runs one epoch per checkpoint with a stopper configured for at most
# 100 iterations and a patience of 10.
# Returns `(best_model, epochs)`: a CPU copy of the model at the best validation
# checkpoint, and the stopper's iteration count minus its iterations without
# improvement.
# Side effect: sets the global `G` to `hyp` while training, then to `nothing`.
# Only the "user_item_biases" model is supported; others trip an assertion.
function retrain_user_embeddings(hyp, m)
    if hyp.model != "user_item_biases"
        @assert false
    end

    global G = hyp
    opt = get_optimizer(G.optimizer, G.learning_rate)
    Random.seed!(G.seed)
    m = cpu(m)
    # reinitialize the user embeddings in place
    m[1].weight .= Flux.Embedding(hyp.num_users => 1).weight
    m = device(m)
    best_model = cpu(m)
    # restrict the trainable parameters to the first layer
    ps = Flux.params(m[1])
    stopper = early_stopper(max_iters = 100, patience = 10)

    losses = []
    loss = Inf
    while !stop!(stopper, loss)
        train_epoch!(m, ps, opt)
        loss = split_loss(m, "validation")
        push!(losses, loss)
        loss == minimum(losses) && (best_model = cpu(m))
    end
    epochs = stopper.iters - stopper.iters_without_improvement
    global G = nothing
    return best_model, epochs
end;

## Write predictions

In [23]:
# Build a dictionary mapping each user to the list of items paired with it in
# the parallel arrays `users` / `items`.
# Every thread fills its own dictionary; the per-thread dictionaries are then
# merged, concatenating the item lists of users seen by several threads.
# NOTE(review): the buckets are indexed by `Threads.threadid()`; this assumes an
# iteration never changes thread midway -- confirm for the Julia version and
# scheduling in use.
function user_to_items(users, items)
    per_thread = [Dict() for t = 1:Threads.nthreads()]
    @tprogress Threads.@threads for j = 1:length(users)
        bucket = per_thread[Threads.threadid()]
        # create the user's item list on first sight, then append
        push!(get!(() -> [], bucket, users[j]), items[j])
    end
    merge(vcat, per_thread...)
end;

In [24]:
# Predict ratings for the requested (user, item) pairs.
#
# Arguments:
#   hyp   - hyperparameters the model was trained with (installed as global `G`)
#   m     - trained model
#   users - user index of each requested pair
#   items - item index of each requested pair (parallel to `users`)
#
# Returns a `RatingsDataset` of predicted ratings, grouped by user in batch
# order. Predictions pass through `softmax` for implicit models and are left
# as-is otherwise.
#
# Only users visited by the batches (1:G.num_users) receive predictions. The
# output buffers are allocated at `length(users)` and then shrunk to the number
# of entries actually written, so the result never contains uninitialized
# values when some requested users fall outside that range.
#
# Side effect: sets the global `G` to `hyp` while running, then to `nothing`.
function evaluate(hyp, m, users, items)
    # get model inputs
    global G = hyp
    m = m |> device
    utoa = user_to_items(users, items)
    epoch = [get_epoch_inputs(G.input_data, G.implicit, G.num_users)]
    activation = G.implicit ? softmax : identity

    # allocate outputs (upper bound: one entry per requested pair)
    out_users = Array{eltype(users)}(undef, length(users))
    out_items = Array{eltype(items)}(undef, length(users))
    out_ratings = Array{Float32}(undef, length(users))
    out_idx = 1

    # compute predictions
    @showprogress for iter = 1:Int(ceil(G.num_users / G.batch_size))
        batch, sampled_users = get_batch(epoch, iter, G.batch_size)
        # one column of predictions per user in the batch
        alpha = activation(m(batch[1][1])) |> cpu
        for j = 1:length(sampled_users)
            u = sampled_users[j]
            if haskey(utoa, u)
                item_mask = utoa[u]
                next_idx = out_idx + length(item_mask)
                out_users[out_idx:next_idx-1] .= u
                out_items[out_idx:next_idx-1] = item_mask
                out_ratings[out_idx:next_idx-1] = alpha[item_mask, j]
                out_idx = next_idx
            end
        end
    end

    # drop the tail that was never written
    num_written = out_idx - 1
    resize!(out_users, num_written)
    resize!(out_items, num_written)
    resize!(out_ratings, num_written)

    global G = nothing
    RatingsDataset(user = out_users, item = out_items, rating = out_ratings)
end;

In [25]:
# Write the predictions of trained model `m` as an alpha under `outdir`.
# Predictions are computed once for every (user, item) pair found in all splits
# and stored in a sparse lookup table; the closure passed on to the generic
# `write_alpha` method answers queries from that table.
function write_alpha(hyp::Hyperparams, m, outdir)
    all_ratings = reduce(cat, [get_split(split, hyp.implicit) for split in all_splits])
    preds = evaluate(hyp, m, all_ratings.user, all_ratings.item)
    lookup = sparse(preds.user, preds.item, preds.rating)

    # look up the stored prediction for each requested (user, item) pair
    function model(users, items)
        out = zeros(length(users))
        @tprogress Threads.@threads for j = 1:length(out)
            out[j] = lookup[users[j], items[j]]
        end
        out
    end

    write_alpha(model, hyp.residual_alphas, hyp.implicit; outdir = outdir)
end;

In [26]:
# End-to-end pipeline: tune hyperparameters, train, retrain the user
# embeddings, then persist the parameters and the alpha.
#
# Arguments:
#   hyp    - base hyperparameters; the tuneable fields are replaced by the tuned values
#   outdir - output location forwarded to `write_params` and `write_alpha`
#
# Steps:
# 1) tune on a copy of `hyp` whose `num_users` is 10% of `num_users()`
# 2) train with the tuned hyperparameters and looser early-stopping rules
# 3) retrain the user embeddings
# 4) write the model, epoch count, λ and hyperparameters, then the alpha
#
# The training step now uses the `hyp` argument. It previously referenced `hp`,
# a name that is not defined anywhere in this function.
function train_alpha(hyp, outdir)
    # optimize hyperparameters
    @info "Optimizing hyperparameters..."
    hyp_subset = @set hyp.num_users = Int(round(num_users() * 0.1))
    λ = optimize_hyperparams(hyp_subset; max_evals = 100)
    # apply the tuned values to the full-size hyperparameters once, and reuse them
    hyp = create_hyperparams(hyp, λ)

    # train with the full dataset and with looser early-stopping rules
    @info "Training model..."
    m, validation_loss = train_model(
        hyp;
        max_checkpoints = 100,
        epochs_per_checkpoint = 1,
        patience = 10,
    )
    @info "Trained model loss: $validation_loss"

    @info "Retraining user embeddings..."
    m, epochs = retrain_user_embeddings(hyp, m)

    @info "Writing alpha..."
    write_params(
        Dict("m" => m, "epochs" => epochs, "λ" => λ, "hyp" => hyp),
        outdir = outdir,
    )
    write_alpha(hyp, m, outdir)
end;

In [27]:
# hp = Hyperparams(
#     # model
#     implicit = false,
#     model = "user_item_biases",
#     # batching
#     # 1024 is the smallest batch size that saturates the gpu
#     batch_size = 1024,
#     input_data = "one_hot",
#     user_sampling_scheme = "constant",
#     # optimizer
#     learning_rate = 0.01,
#     optimizer = "ADAM",
#     # training
#     seed = 20220524,
#     num_users = num_users(),
#     # loss
#     item_weight_decay = 0,
#     regularization_params = Float32[1, 1],
#     residual_alphas = [],
#     residual_beta = 0,
#     user_weight_decay = 0,
# )

# λ = Float32[
#     -0.45587917457999283,
#     1.59948091680668,
#     -0.4203013509949342,
#     -0.025724981261365888,
#     4.4043160918290605,
#     2.42710246614073,
# ];

In [28]:
# params = read_params(name)
# hyp = params["hyp"]
# m = params["m"]

In [29]:
# write_alpha(hyp, m, name)