# Neural Network Base Class
* This class serves as the basis 

In [1]:
using Flux
using Random
using SparseArrays
using Statistics: var

import BSON
import CUDA
import NBInclude: @nbinclude
import NLopt
@nbinclude("Alpha.ipynb");

In [2]:
function device(x)
    gpu(x)
end

# efficiently convert a sparse cpu matrix into a dense CUDA array
function device(x::AbstractSparseArray)
    CUDA.functional() ? CUDA.CuArray(gpu(x)) : x
end;

## Hyperparameters

In [3]:
@with_kw struct Hyperparams
    # model
    implicit::Bool
    model::String
    # batching
    batch_size::Int
    input_data::String
    user_sampling_scheme::String
    # optimizer
    learning_rate::Float32
    optimizer::String
    # training
    patience::Int
    seed::UInt64
    # loss
    item_weight_decay::Float32
    regularization_params::Vector{Float32}
    residual_alphas::Vector{String}
    residual_beta::Float32
    user_weight_decay::Float32
end

function to_dict(x::Hyperparams)
    Dict(string(key) => getfield(x, key) for key ∈ fieldnames(Hyperparams))
end

function Base.string(x::Hyperparams)
    if x.implicit
        @assert 0 <= x.residual_beta && x.residual_beta <= 1
    end
    fields = [x for x in fieldnames(Hyperparams)]
    max_field_size = maximum(length(string(k)) for k in fields)
    ret = "Hyperparameters:\n"
    for f in fields
        ret *= "$(rpad(string(f), max_field_size)) => $(getfield(x, f))\n"
    end
    ret
end;

## Models
* To define a new model, add the architecture to `build_model` and the regularization to `regularization_loss`

In [4]:
# A layer that takes one input and splits it into many
struct Split{T}
    paths::T
end
Split(paths...) = Split(paths)
Flux.@functor Split
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)

# A layer that takes many inputs and joins them into one
Join(combine, paths) = Parallel(combine, paths)
Join(combine, paths...) = Join(combine, paths);

In [5]:
# A layer that adds a 1-D vector to the input
struct BiasLayer
    b::Any
end
BiasLayer(n::Integer; init = randn) = BiasLayer(init(Float32, n))
(m::BiasLayer)(x) = x .+ m.b
Flux.@functor BiasLayer

In [6]:
# Implements a baseline predictor given by R[i, j] = u[i] + a[j]
function user_item_biases()
    U = Flux.Embedding(num_users() => 1)
    A = BiasLayer(num_items())
    m = Chain(U, A) |> device
end

# regularization is λ_u variance(u) + λ_a variance(a)
function user_item_biases_regularization(m)
    var(m[1].weight) * G.regularization_params[1] + var(m[2].b) * G.regularization_params[2]
end;

In [7]:
function build_model()
    if G.model == "user_item_biases"
        return user_item_biases()
    end
    @assert false
end

function regularization_loss(m)
    if G.model == "user_item_biases"
        return user_item_biases_regularization(m)
    end
    @assert false
end;

# Data Preprocessing

In [8]:
function one_hot_inputs(split, implicit)
    X = collect(1:num_users())
    Y = sparse(get_split(split; implicit = implicit))
    X, Y
end;

In [9]:
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_inputs(split, input_data, implicit)
    if G.input_data == "one_hot"
        X, Y = one_hot_inputs(split, implicit)
    else
        @assert false
    end
end

@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_residuals(
    split,
    residual_alphas,
    residual_beta,
    implicit,
)
    residuals = read_alpha(residual_alphas, split, implicit)
    residuals.rating .*= residual_beta
    sparse(residuals)
end

@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_weights(
    split,
    user_weight_decay,
    item_weight_decay,
    implicit,
)
    if split == "training"
        weights =
            expdecay(get_counts(split), user_weight_decay) .*
            expdecay(get_counts(split; by_item = true), item_weight_decay)
    else
        weights = expdecay(get_counts(split), weighting_scheme("inverse"))
    end
    sparse(get_split(split; implicit = implicit), weights)
end;

In [10]:
# returns (X, Y, Z, W) = (inputs, outputs, residualization alpha, weights)
function get_epoch(split)
    X, Y = get_epoch_inputs(split, G.input_data, G.implicit)
    Z = get_epoch_residuals(split, G.residual_alphas, G.residual_beta, G.implicit)
    W = get_epoch_weights(split, G.user_weight_decay, G.item_weight_decay, G.implicit)
    X, Y, Z, W
end;

# Batching
* Turns an epoch into minibatches
* Each data point will appear in a minibatch with a probability proportional to its sampling weight

In [11]:
function SparseArrays.sparse(split::RatingsDataset)
    sparse(split.item, split.user, split.rating, num_items(), num_users())
end

function SparseArrays.sparse(split::RatingsDataset, ratings)
    sparse(split.item, split.user, ratings, num_items(), num_users())
end;

In [12]:
function slice(x::AbstractVector, range)
    x[range]
end

function slice(x::AbstractMatrix, range)
    x[:, range]
end;

In [13]:
function get_sampling_order(split)
    weighting_scheme = split == "training" ? G.user_sampling_scheme : "constant"
    if weighting_scheme == "constant"
        return shuffle(1:num_users())
    else
        weights = expdecay(
            get_counts(split; per_rating = false),
            weighting_scheme(G.user_sampling_scheme),
        )
        return sample(1:num_users(), Weights(weights), num_users())
    end
end;

In [14]:
# performs the following steps
# 1) shuffle the epoch by the sampling order
# 2) split the epoch into minibatches of size batch_size
# 3) return the iter-th minibatch
function get_batch(epoch, iter, batch_size, sampling_order)
    sampling_order = 1:num_users()
    range = sampling_order[(iter-1)*batch_size+1:min(iter * batch_size, num_users())]
    process(x) = slice(x, range) |> device
    [process.(epoch)], range
end;

function get_batch(epoch, iter, batch_size)
    sampling_order = 1:num_users()
    get_batch(epoch, iter, batch_size, sampling_order)
end;

## Loss Functions

In [15]:
function model_loss(m, x, y, z, w)
    p = m(x)
    if G.implicit
        q = softmax(p) .* G.residual_beta + z .* (1 - G.residual_beta)
        return sum(w .* -y .* log.(x))
    else
        q = p + z .* G.residual_beta
        return sum(w .* (q - y) .^ 2)
    end
end

function training_loss(m, x, y, z, w; model_loss_scale)
    model_loss(m, x, y, z, w) * model_loss_scale + regularization_loss(m)
end

function split_loss(m, split)
    epoch = get_epoch(split)
    loss = 0.0
    weights = 0.0
    for iter = 1:Int(ceil(num_users() / G.batch_size))
        batch, _ = get_batch(epoch, iter, G.batch_size)
        loss += model_loss(m, batch[1]...)
        weights += sum(batch[1][end])
    end
    Float32(loss / weights)
end;

# Evaluation

In [16]:
# returns the preimage of the index -> split.user[index] mapping
# this is primarily a performance optimization
@memoize function user_to_output_indices(split)
    users = get_split(split; implicit = G.implicit).user
    user_to_output_idxs = [Dict() for t = 1:Threads.nthreads()]
    @tprogress Threads.@threads for j = 1:length(users)
        u = users[j]
        t = Threads.threadid()
        if u ∉ keys(user_to_output_idxs[t])
            user_to_output_idxs[t][u] = []
        end
        push!(user_to_output_idxs[t][u], j)
    end
    merge(vcat, user_to_output_idxs...)
end;

function evaluate(m, split)
    # get model inputs
    user_to_output_idxs = user_to_output_indices(split)
    df = get_split(split; implicit = G.implicit)
    users = df.user
    items = df.item
    epoch = get_epoch(split)

    # compute predictions    
    batch_size = G.batch_size
    activation = G.implicit ? softmax : identity
    ratings = zeros(Float32, length(users))
    @showprogress for iter = 1:Int(ceil(num_users() / batch_size))
        batch, sampled_users = get_batch(epoch, iter, batch_size)
        alpha = activation(m(batch[1][1])) |> cpu

        for j = 1:length(sampled_users)
            u = sampled_users[j]
            if u in keys(user_to_output_idxs)
                for output_idx in user_to_output_idxs[u]
                    ratings[output_idx] = alpha[items[output_idx], j]
                end
            end
        end
    end

    RatingsDataset(user = users, item = items, rating = ratings)
end;

In [17]:
function snapshot_loss(m)
    split_loss(m, "training"), split_loss(m, "validation")
end;

In [18]:
function checkpoint(m)
    training_loss, validation_loss = snapshot_loss(m)
    @info "training loss $training_loss, validation loss $validation_loss"
    validation_loss
end;

## Training

In [19]:
function train_epoch!(m, opt)
    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())
    ps = Flux.params(m)
    epoch = get_epoch("training")
    sampling_order = get_sampling_order("training")
    # make the training loss invariant to the scale of the weight decays
    model_loss_scale = num_users() / sum(epoch[4])
    batchloss(x, y, z, w) =
        training_loss(m, x, y, z, w; model_loss_scale = model_loss_scale)

    nbatches = Int(ceil(length(sampling_order) / G.batch_size))
    for iter = 1:nbatches
        batch, _ = get_batch(epoch, iter, G.batch_size, sampling_order)
        Flux.train!(batchloss, ps, batch, opt)
    end
end;

## Hyperparameter Optimization

In [20]:
function num_tuneable_params(model)
    if model == "user_item_biases"
        return 5
    else
        @assert false
    end
end

function create_hyperparams(base::Hyperparams, λ)
    if base.model == "user_item_biases"
        return Hyperparams(
            # model
            implicit = hyp.implicit,
            model = hyp.model,
            # batching
            batch_size = hyp.batch_size,
            input_data = hyp.input_data,
            user_sampling_scheme = hyp.user_sampling_scheme,
            # optimizer
            learning_rate = 0.01 * exp(λ[1]), # tuned parameter
            optimizer = hyp.optimizer,
            # training
            patience = 0, # speed up search by early terminating bad seeds
            seed = hyp.seed,
            # loss
            item_weight_decay = λ[5], # tuned parameter
            regularization_params = exp.(λ[2:3]), # tuned parameter
            residual_alphas = [],
            residual_beta = 0,
            user_weight_decay = λ[4], # tuned parameter
        )
    else
        @assert false
    end
end

function hyperparameter_loss(
    hyp::Hyperparams,
    λ,
    model_path;
    num_checkpoints = 20,
    epochs_per_checkpoint = 10,
)
    global G = create_hyperparams(hyp, λ)

    if G.optimizer == "ADAM"
        opt = ADAMW(G.learning_rate, (0.9, 0.999), 0) # TODO try training the decay rate
    end
    Random.seed!(G.seed)
    m = build_model()
    stopper = early_stopper(patience = G.patience, max_iters = num_checkpoints)

    losses = []
    loss = Inf
    while (!stop!(stopper, loss))
        for i = 1:epochs_per_checkpoint
            train_epoch!(m, opt)
        end
        loss = split_loss(m, "validation")
        push!(losses, loss)
        @info loss
    end

    m = m |> cpu
    BSON.@save model_path m
    @info λ, minimum(losses)
    minimum(losses)
end;

In [21]:
function optimize_hyperparams(hyp; max_evals)
    model_path = "../../data/alphas/$name/$(hash(hyp)).bson"
    function nlopt_loss(λ, grad)
        # nlopt internally converts to float64 because it calls a c library
        hyperparameter_loss(hyp, convert.(Float32, λ), model_path)
    end
    num_variables = num_tuneable_params(hyp.model)
    opt = NLopt.Opt(:LN_BOBYQA, num_variables)
    opt.initial_step = 1
    opt.maxeval = max_evals
    opt.min_objective = nlopt_loss
    minf, λ, ret = NLopt.optimize(opt, zeros(Float32, num_variables))
    numevals = opt.numevals

    @info ("found minimum $minf at point $λ after $numevals function calls "
        * "(ended because $ret) and saved model at $model_path")
    write_params(Dict("λ" => λ, "m" => model_path))
    λ
end;

In [None]:
hyp = Hyperparams(
    # model
    implicit = false,
    model = "user_item_biases",
    # batching
    # 1024 is the smallest batch size that saturates the gpu
    batch_size = 1024,
    input_data = "one_hot",
    user_sampling_scheme = "constant",
    # optimizer
    learning_rate = 0.01,
    optimizer = "ADAM",
    # training
    patience = 10,
    seed = 20220524,
    # loss
    item_weight_decay = 0,
    regularization_params = Float32[1, 1],
    residual_alphas = [],
    residual_beta = 0,
    user_weight_decay = 0,
);

optimize_hyperparams(hyp; max_evals = 100)

[32mProgress: 100%|███████████████████████████| Time: 0:00:00 ( 0.76 μs/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (32.50 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (30.31 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220530 02:34:12 1.8687259
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220530 02:36:01 1.8605917
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220530 02:37:51 1.8554357
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220530 02:39:39 1.8515037
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220530 02:41:29 1.8516275
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220530 02:41:32 (Float32[0.0, 0.0, 0.0, 0.0, 0.0], 1.8515037f0)
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220530 02:43:23 2.006554
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220530 02:45:15 2.0062215
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m2022053

In [None]:
# train_model(hyp)

## Write predictions

In [None]:
# function make_prediction(sparse_preds, users, items)
#     preds = zeros(length(users))
#     @tprogress Threads.@threads for j = 1:length(preds)
#         preds[j] = sparse_preds[users[j], items[j]]
#     end
#     preds
# end;

In [None]:
# function save_model(model_path, hyperparams::Hyperparams, outdir)
#     global G = hyperparams
#     BSON.@load model_path m
#     training = evaluate(m, "training")
#     validation = evaluate(m, "validation")
#     test = evaluate(m, "test")
#     df = reduce(cat, [training, validation, test])
#     sparse_preds = sparse(df.user, df.item, df.rating)

#     write_predictions(
#         (users, items) -> make_prediction(sparse_preds, users, items),
#         residual_alphas = G.validation_residuals,
#         outdir = outdir,
#         implicit = G.train_implicit_model,
#     )
#     params = to_dict(G)
#     params["model"] = model_path
#     write_params(params, outdir = outdir)
# end;

In [None]:
# function fit(hyperparams::Hyperparams, outdir)
#     redirect_logging("../../data/alphas/$outdir")
#     @info string(hyperparams)
#     model_path = train_model(hyperparams)
#     save_model(model_path, hyperparams, outdir)
# end;

In [None]:

# hyperparams = Hyperparams(
#     use_derived_features = true,
#     train_implicit_model = true,
#     activation = "relu",
#     autoencode = true,
#     batch_size = 128,
#     dropout_perc = 0.5,
#     dropout_rescale = false,
#     layers = [512, 256, 128, 256, 512],
#     l2penalty = 1e-5,
#     learning_rate = 0.001,
#     optimizer = "ADAM",
#     patience = 10,
#     sampling_weight_scheme = "linear",
#     training_residuals = ["UserItemBiases"],
#     training_weight_scheme = "linear",
#     use_residualized_validation_loss = false,
#     validation_residuals = ["UserItemBiases"],
#     validation_weight_scheme = "constant",
#     seed = 20220501 * hash(name),
# )

# #fit(hyperparams, "GNN.Rating.Test.2")