# Neural Network Base Class
* This class contains infrastructure to train neural networks
* The following algorithms are implemented:
    * Baseline predictors
* The following algoirthms will be implemented
    * Item-based collaborative filtering
    * Matrix Factorization
    * Autoencoder

In [1]:
using Flux
using Random
using SparseArrays
using Statistics: var

import BSON
import CUDA
import NBInclude: @nbinclude
import NLopt
import Setfield: @set
@nbinclude("Alpha.ipynb");

In [2]:
function device(x)
    gpu(x)
end

# efficiently convert a sparse cpu matrix into a dense CUDA array
function device(x::AbstractSparseArray)
    CUDA.functional() ? CUDA.CuArray(gpu(x)) : x
end;

In [3]:
CUDA.functional()

true

## Hyperparameters
* Contains all the information necessary to train a new model
* The important hyperparameters will tuned via a derivative-free optimizer

In [4]:
@with_kw struct Hyperparams
    # model
    implicit::Bool
    input_data::String
    model::String
    # batching
    batch_size::Int
    user_sampling_scheme::String # TODO convert to float32
    # optimizer
    learning_rate::Float32
    optimizer::String
    # training
    seed::UInt64
    num_users::Int
    # loss
    item_weight_decay::Float32
    regularization_params::Vector{Float32}
    residual_alphas::Vector{String}
    residual_beta::Float32
    user_weight_decay::Float32
end

function to_dict(x::Hyperparams)
    Dict(string(key) => getfield(x, key) for key ∈ fieldnames(Hyperparams))
end

function Base.string(x::Hyperparams)
    fields = [x for x in fieldnames(Hyperparams)]
    max_field_size = maximum(length(string(k)) for k in fields)
    ret = "Hyperparameters:\n"
    for f in fields
        ret *= "$(rpad(string(f), max_field_size)) => $(getfield(x, f))\n"
    end
    ret
end;

## Models
* To define a new model, add the architecture to `build_model` and the regularization to `regularization_loss`

In [5]:
# A layer that takes one input and splits it into many
struct Split{T}
    paths::T
end
Split(paths...) = Split(paths)
Flux.@functor Split
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)

# A layer that takes many inputs and joins them into one
Join(combine, paths) = Parallel(combine, paths)
Join(combine, paths...) = Join(combine, paths);

In [6]:
# A layer that adds a 1-D vector to the input
struct BiasLayer
    b::Any
end
BiasLayer(n::Integer; init = randn) = BiasLayer(init(Float32, n))
(m::BiasLayer)(x) = x .+ m.b
Flux.@functor BiasLayer

In [7]:
# Implements a baseline predictor given by R[i, j] = u[i] + a[j]
function user_item_biases()
    U = Flux.Embedding(G.num_users => 1)
    A = BiasLayer(num_items())
    m = Chain(U, A) |> device
end

# regularization is λ_u variance(u) + λ_a variance(a)
function user_item_biases_regularization(m)
    var(m[1].weight) * G.regularization_params[1] + var(m[2].b) * G.regularization_params[2]
end;

In [8]:
function build_model()
    if G.model == "user_item_biases"
        return user_item_biases()
    end
    @assert false
end

function regularization_loss(m)
    if G.model == "user_item_biases"
        return user_item_biases_regularization(m)
    end
    @assert false
end;

# Data Preprocessing
* An epoch is an efficient representation of all the models inputs, outputs, residualization, and weights
* We generate one epoch per split and memoize them

In [9]:
function one_hot_inputs(split, implicit, num_users)
    collect(1:num_users)
end;

In [10]:
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_inputs(
    split,
    input_data,
    implicit,
    num_users,
)
    if input_data == "one_hot"
        X, Y = one_hot_inputs(split, implicit, num_users)
    else
        @assert false
    end
end

@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_outputs(split, implicit, num_users)
    sparse(filter_users(get_split(split, implicit), num_users))
end

@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_residuals(
    split,
    residual_alphas,
    implicit,
    num_users,
)
    residuals = filter_users(read_alpha(residual_alphas, split, implicit), num_users)
    sparse(residuals)
end

@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_weights(
    split,
    user_weight_decay,
    item_weight_decay,
    implicit,
    num_users,
)
    if split == "training"
        weights =
            expdecay(get_counts(split, implicit), user_weight_decay) .*
            expdecay(get_counts(split, implicit; by_item = true), item_weight_decay)
    else
        weights = expdecay(get_counts(split, implicit), weighting_scheme("inverse"))
    end

    df = get_split(split, implicit)
    df = filter_users(RatingsDataset(df.user, df.item, weights), num_users)
    sparse(df)
end;

In [11]:
# returns (X, Y, Z, W) = (inputs, outputs, residualization alpha, weights)
function get_epoch(split)
    X = get_epoch_inputs(split, G.input_data, G.implicit, G.num_users)
    Y = get_epoch_outputs(split, G.implicit, G.num_users)
    Z = get_epoch_residuals(split, G.residual_alphas, G.implicit, G.num_users)
    W = get_epoch_weights(
        split,
        G.user_weight_decay,
        G.item_weight_decay,
        G.implicit,
        G.num_users,
    )
    X, Y, Z, W
end;

# Batching
* Turns an epoch into minibatches
* Each data point will appear in a minibatch with a probability proportional to its sampling weight

In [12]:
function SparseArrays.sparse(split::RatingsDataset)
    sparse(split.item, split.user, split.rating, num_items(), G.num_users)
end;

In [13]:
function slice(x::AbstractVector, range)
    x[range]
end

function slice(x::AbstractMatrix, range)
    x[:, range]
end;

In [14]:
function get_sampling_order(split)
    weighting_scheme = split == "training" ? G.user_sampling_scheme : "constant"
    if weighting_scheme == "constant"
        return shuffle(1:G.num_users)
    else
        weights = expdecay(
            get_counts(split; per_rating = false),
            weighting_scheme(G.user_sampling_scheme),
        )
        return sample(1:G.num_users, Weights(weights), G.num_users)
    end
end;

In [15]:
# performs the following steps
# 1) shuffle the epoch by the sampling order
# 2) split the epoch into minibatches of size batch_size
# 3) return the iter-th minibatch
function get_batch(epoch, iter, batch_size, sampling_order)
    sampling_order = 1:G.num_users
    range = sampling_order[(iter-1)*batch_size+1:min(iter * batch_size, G.num_users)]
    process(x) = slice(x, range) |> device
    [process.(epoch)], range
end;

function get_batch(epoch, iter, batch_size)
    sampling_order = 1:G.num_users
    get_batch(epoch, iter, batch_size, sampling_order)
end;

## Loss Functions
* The `model_loss` is either the crossentropy loss or squared error, depending on the input data
    * Note that we take the sum over all items, so using a bigger batchsize will have a bigger `model_loss`
* The `regularization_loss` depends on the model architecture, but is commonly an L2 loss
* During training, the `model_loss` is scaled by a function of the weight decays. This keeps the magnitude of the loss function approximately the same, even if the weight decay constats change
* The `split_loss` is either the weighted average crossentropy loss or weighted mean squared error, depending on the input datadepending on the input data

In [16]:
function model_loss(m, x, y, z, w)
    p = m(x)
    if G.implicit
        β = sigmoid(G.residual_beta)
        q = softmax(p) * (1 - β) + z .* β
        return sum(w .* -y .* log.(q))
    else
        q = p + z .* G.residual_beta
        return sum(w .* (q - y) .^ 2)
    end
end

function training_loss(m, x, y, z, w; model_loss_scale)
    model_loss(m, x, y, z, w) * model_loss_scale + regularization_loss(m)
end

function split_loss(m, split)
    epoch = get_epoch(split)
    loss = 0.0
    weights = 0.0
    for iter = 1:Int(ceil(G.num_users / G.batch_size))
        batch, _ = get_batch(epoch, iter, G.batch_size)
        loss += model_loss(m, batch[1]...)
        weights += sum(batch[1][end])
    end
    Float32(loss / weights)
end;

## Training
* Trains a neural network with the given hyperparameters

In [17]:
function get_optimizer(optimizer, learning_rate)
    if optimizer == "ADAM"
        return ADAMW(learning_rate, (0.9, 0.999), 0)
    elseif optimizer == "SGD"
        return Descent(learning_rate)
    else
        @assert false
    end
end;

In [18]:
function train_epoch!(m, ps, opt)
    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())
    epoch = get_epoch("training")
    sampling_order = get_sampling_order("training")
    # make the training loss invariant to the scale of the weight decays
    model_loss_scale = G.num_users / sum(epoch[4])
    batchloss(x, y, z, w) =
        training_loss(m, x, y, z, w; model_loss_scale = model_loss_scale)

    nbatches = Int(ceil(length(sampling_order) / G.batch_size))
    for iter = 1:nbatches
        batch, _ = get_batch(epoch, iter, G.batch_size, sampling_order)
        Flux.train!(batchloss, ps, batch, opt)
    end
end;

In [19]:
# trains a model with the given hyperparams and returns its validation loss
function train_model(hyp; max_checkpoints = 10, epochs_per_checkpoint = 10, patience = 0)
    global G = hyp
    opt = get_optimizer(G.optimizer, G.learning_rate)
    Random.seed!(G.seed)
    m = build_model()
    best_model = m |> cpu
    ps = Flux.params(m)
    stopper = early_stopper(max_iters = max_checkpoints, patience = patience)

    losses = []
    loss = Inf
    while (!stop!(stopper, loss))
        for i = 1:epochs_per_checkpoint
            train_epoch!(m, ps, opt)
        end
        loss = split_loss(m, "validation")
        push!(losses, loss)
        if loss == minimum(losses)
            best_model = m |> cpu
        end
    end
    global G = nothing
    best_model, minimum(losses)
end;

## Hyperparameter Tuning
* A derivative free optimizer is used to find the best hyperparameters

In [20]:
function num_tuneable_params(model)
    num_model_params = 4
    if model == "user_item_biases"
        num_sampling_params = 0
        num_regularization_params = 2
    else
        @assert false
    end
    num_model_params, num_sampling_params, num_regularization_params
end

function create_hyperparams(hyp, λ)
    _, num_sampling_params, num_regularization_params = num_tuneable_params(hyp.model)
    hyp = @set hyp.learning_rate = 0.01 * exp(λ[1])
    hyp = @set hyp.residual_beta = hyp.implicit ? λ[2] : 1 + λ[2]
    hyp = @set hyp.user_weight_decay = λ[3]
    hyp = @set hyp.item_weight_decay = λ[4]
    if num_sampling_params == 1
        hyp = @set hyp.user_sampling_scheme = λ[5]
    end
    hyp = @set hyp.regularization_params = exp.(λ[end-num_regularization_params+1:end])
    hyp
end;

In [21]:
function optimize_hyperparams(hyp; max_evals)
    function nlopt_loss(λ, grad)
        # nlopt internally converts to float64 because it calls a c library
        λ = convert.(Float32, λ)
        _, loss = train_model(create_hyperparams(hyp, λ))
        @info "$λ $loss"
        loss
    end
    num_variables = sum(num_tuneable_params(hyp.model))
    opt = NLopt.Opt(:LN_NELDERMEAD, num_variables)
    opt.initial_step = 1
    opt.maxeval = max_evals
    opt.min_objective = nlopt_loss
    minf, λ, ret = NLopt.optimize(opt, zeros(Float32, num_variables))
    numevals = opt.numevals

    @info (
        "found minimum $minf at point $λ after $numevals function calls " *
        "(ended because $ret) and saved model at"
    )
    λ
end;

In [22]:
hp = Hyperparams(
    # model
    implicit = false,
    model = "user_item_biases",
    # batching
    # 1024 is the smallest batch size that saturates the gpu
    batch_size = 1024,
    input_data = "one_hot",
    user_sampling_scheme = "constant",
    # optimizer
    learning_rate = 0.01,
    optimizer = "ADAM",
    # training
    seed = 20220524,
    num_users = num_users(),
    # loss
    item_weight_decay = 0,
    regularization_params = Float32[1, 1],
    residual_alphas = [],
    residual_beta = 0,
    user_weight_decay = 0,
);

In [None]:
hp_subset = @set hp.num_users = Int(round(num_users() * 0.1))
λ = optimize_hyperparams(hp_subset; max_evals = 100)
@info "THE BEST HYPERPARAMETERS ARE $λ"

[32mProgress: 100%|███████████████████████████| Time: 0:00:01 ( 0.12 μs/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:04 (31.77 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:04 (36.33 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220604 18:26:43 Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0] 1.8194634
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220604 18:27:16 Float32[1.0, 0.0, 0.0, 0.0, 0.0, 0.0] 1.8205892
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220604 18:28:24 Float32[0.0, 1.0, 0.0, 0.0, 0.0, 0.0] 1.8194634
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220604 18:29:59 Float32[0.0, 1.0, 1.0, 0.0, 0.0, 0.0] 1.834349
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220604 18:31:11 Float32[0.0, 1.0, 0.0, 1.0, 0.0, 0.0] 1.8874848
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220604 18:32:23 Float32[0.0, 1.0, 0.0, 0.0, 1.0, 0.0] 1.8164439
[38;5;6m[1m[ [22m[39m[38;5;6m

In [None]:
# train with the full dataset and with looser early-stopping rules
m, loss = train_model(create_hyperparams(hp, λ); max_checkpoints = 100, epochs_per_checkpoint = 1, patience = 10)
@info loss

## Retrain User Embeddings
* To minimize training/serving skew, we train the model the same
  way we will train it during inference
* This means reinitializing the user embeddings, freezing all other layers,
  and fine-tuning the user embeddings
* During serving, we will determine a new user's embedding
  by training with the same hyperparameters and number of epochs

In [None]:
function retrain_user_embeddings!(hyp, m)
    if hyp.model == "user_item_biases"
        m[1].weight .= Flux.Embedding(hyp.num_users => 1).weight
        ps = Flux.params(m[1])
        stopper = early_stopper(max_iters = 100, patience = 10)
        loss = Inf

        global G = hyp
        while (!stop!(stopper, loss))
            train_epoch!(m, ps, opt)
            loss = split_loss(m, "validation")
            @info loss
        end
        global G = nothing

        epochs = stopper.iters - stopper.iters_without_improvement
    else
        @assert false
    end
end

In [None]:
# train_model(hyp)

## Write predictions

In [None]:
# # returns the preimage of the index -> split.user[index] mapping
# # this is primarily a performance optimization
# @memoize function user_to_output_indices(split)
#     users = filter_users(get_split(split, implicit), G.num_users).user
#     user_to_output_idxs = [Dict() for t = 1:Threads.nthreads()]
#     @tprogress Threads.@threads for j = 1:length(users)
#         u = users[j]
#         t = Threads.threadid()
#         if u ∉ keys(user_to_output_idxs[t])
#             user_to_output_idxs[t][u] = []
#         end
#         push!(user_to_output_idxs[t][u], j)
#     end
#     merge(vcat, user_to_output_idxs...)
# end;

# # returns a ratins dataset of predicted ratings
# function evaluate(m, split)
#     # get model inputs
#     user_to_output_idxs = user_to_output_indices(split)
#     df = filter_users(get_split(split, implicit), G.num_users)
#     users = df.user
#     items = df.item
#     epoch = get_epoch(split)

#     # compute predictions    
#     batch_size = G.batch_size
#     activation = G.implicit ? softmax : identity
#     ratings = zeros(Float32, length(users))
#     @showprogress for iter = 1:Int(ceil(G.num_users / batch_size))
#         batch, sampled_users = get_batch(epoch, iter, batch_size)
#         alpha = activation(m(batch[1][1])) |> cpu

#         for j = 1:length(sampled_users)
#             u = sampled_users[j]
#             if u in keys(user_to_output_idxs)
#                 for output_idx in user_to_output_idxs[u]
#                     ratings[output_idx] = alpha[items[output_idx], j]
#                 end
#             end
#         end
#     end

#     RatingsDataset(user = users, item = items, rating = ratings)
# end;

In [None]:
# function make_prediction(sparse_preds, users, items)
#     preds = zeros(length(users))
#     @tprogress Threads.@threads for j = 1:length(preds)
#         preds[j] = sparse_preds[users[j], items[j]]
#     end
#     preds
# end;

In [None]:
# function save_model(model_path, hyperparams::Hyperparams, outdir)
#     global G = hyperparams
#     BSON.@load model_path m
#     training = evaluate(m, "training")
#     validation = evaluate(m, "validation")
#     test = evaluate(m, "test")
#     df = reduce(cat, [training, validation, test])
#     sparse_preds = sparse(df.user, df.item, df.rating)

#     write_predictions(
#         (users, items) -> make_prediction(sparse_preds, users, items),
#         residual_alphas = G.validation_residuals,
#         outdir = outdir,
#         implicit = G.train_implicit_model,
#     )
#     params = to_dict(G)
#     params["model"] = model_path
#     write_params(params, outdir = outdir)
# end;

In [None]:
# function fit(hyperparams::Hyperparams, outdir)
#     redirect_logging("../../data/alphas/$outdir")
#     @info string(hyperparams)
#     model_path = train_model(hyperparams)
#     save_model(model_path, hyperparams, outdir)
# end;

In [None]:

# hyperparams = Hyperparams(
#     use_derived_features = true,
#     train_implicit_model = true,
#     activation = "relu",
#     autoencode = true,
#     batch_size = 128,
#     dropout_perc = 0.5,
#     dropout_rescale = false,
#     layers = [512, 256, 128, 256, 512],
#     l2penalty = 1e-5,
#     learning_rate = 0.001,
#     optimizer = "ADAM",
#     patience = 10,
#     sampling_weight_scheme = "linear",
#     training_residuals = ["UserItemBiases"],
#     training_weight_scheme = "linear",
#     use_residualized_validation_loss = false,
#     validation_residuals = ["UserItemBiases"],
#     validation_weight_scheme = "constant",
#     seed = 20220501 * hash(name),
# )

# #fit(hyperparams, "GNN.Rating.Test.2")