# Neural Network Base Class

In [2]:
using Flux
using Random
using SparseArrays
using Statistics: var


import BSON
import NBInclude: @nbinclude
@nbinclude("Alpha.ipynb");

In [3]:
# Function used to move models and batches onto the compute device.
# Flux.gpu logs a notice and falls back to the CPU when no GPU is accessible.
const device = Flux.gpu;

## Parameters

In [4]:
# Configuration for one training run. `train_model` stores the instance in the
# global `G`, which the other functions in this file read.
@with_kw struct Hyperparams
    # model
    # `true` switches on the softmax branches in `training_loss` / `evaluate`;
    # also forwarded to `get_split`, `read_alpha`, `loss`, `regress`, `residualized_loss`
    implicit::Bool
    # architecture name; `build_model` only handles "user_item_biases"
    model::String
    # batching
    # number of sampled users per training batch
    batch_size::Int
    # input encoding; `get_epoch` only handles "one_hot"
    input_data::String
    # "constant" shuffles the users; any other value is passed to `weighting_scheme`
    # to weight the user sampling
    user_sampling_scheme::String
    # optimizer
    # third argument given to `ADAMW`
    l2penalty::Float32
    # first argument given to `ADAMW`
    learning_rate::Float32
    # optimizer name; `train_model` only handles "ADAM"
    optimizer::String
    # training
    # forwarded to `early_stopper(patience = ...)`
    patience::Int
    # forwarded to `Random.seed!`
    seed::UInt64
    # loss
    # decay applied (via `expdecay`) to per-item counts for training loss weights
    item_weight_decay::Float32
    # for "user_item_biases": [1] scales the variance of the user embedding,
    # [2] scales the variance of the item bias
    regularization_params::Vector{Float32}
    # alpha names passed to `read_alpha`, `regress`, and `residualized_loss`
    residual_alphas::Vector{String}
    # multiplier applied to the residual ratings; `Base.string` asserts it lies
    # in [0, 1] when `implicit` is true
    residual_beta::Float32
    # decay applied (via `expdecay`) to per-user counts for training loss weights
    user_weight_decay::Float32
end

# Convert a `Hyperparams` into a dictionary that maps each field name
# (as a `String`) to that field's value.
function to_dict(x::Hyperparams)
    names = fieldnames(Hyperparams)
    return Dict(string(name) => getfield(x, name) for name in names)
end

# Render the hyperparameters as a multi-line report: a "Hyperparameters:" header
# followed by one "name => value" line per field, with names padded to equal width.
# Also asserts that `residual_beta` lies in [0, 1] for implicit models.
function Base.string(x::Hyperparams)
    if x.implicit
        @assert 0 <= x.residual_beta && x.residual_beta <= 1
    end
    names = collect(fieldnames(Hyperparams))
    width = maximum(name -> length(string(name)), names)
    io = IOBuffer()
    print(io, "Hyperparameters:\n")
    for name in names
        print(io, rpad(string(name), width), " => ", getfield(x, name), "\n")
    end
    return String(take!(io))
end;

## Model

In [5]:
# A layer that fans a single input out to several paths, returning one output per path
struct Split{T}
    paths::T
end

# Convenience constructor: `Split(a, b, c)` stores the paths as a tuple
function Split(paths...)
    return Split(paths)
end

Flux.@functor Split

# Apply every path to the same input
function (m::Split)(x::AbstractArray)
    return map(path -> path(x), m.paths)
end

# A layer that takes many inputs and joins them into one.
# Thin wrapper around `Parallel(combine, paths)`.
function Join(combine, paths)
    return Parallel(combine, paths)
end

# Varargs form: `Join(combine, a, b, c)` forwards the paths as a tuple
function Join(combine, paths...)
    return Join(combine, paths)
end;

In [6]:
# A layer that adds a 1-D vector to the input
# NOTE(review): the field is untyped (`Any`); parametrizing it would help inference
# but may affect already-saved BSON checkpoints — confirm before changing.
struct BiasLayer
    b::Any
end

# Build a bias of length `n`; `init` is called as `init(Float32, n)`
function BiasLayer(n::Integer; init = randn)
    return BiasLayer(init(Float32, n))
end

# Broadcast-add the bias onto the input
function (m::BiasLayer)(x)
    return x .+ m.b
end

Flux.@functor BiasLayer

In [7]:
# Implements a baseline predictor given by R[i, j] = u[i] + a[j]:
# a 1-dimensional user embedding followed by a per-item bias, both initialised
# with `zeros`, moved to `device`.
function user_item_biases()
    user_bias = Flux.Embedding(num_users() => 1; init = zeros)
    item_bias = BiasLayer(num_items(); init = zeros)
    return device(Chain(user_bias, item_bias))
end

# regularization is λ_u variance(u) + λ_a variance(a), where u is the user
# embedding weight (m[1]) and a is the item bias (m[2])
function user_item_biases_regularization(m)
    user_term = var(m[1].weight) * G.regularization_params[1]
    item_term = var(m[2].b) * G.regularization_params[2]
    return user_term + item_term
end;

In [8]:
# Construct the model named by `G.model`; any unrecognised name trips the assertion
function build_model()
    G.model == "user_item_biases" && return user_item_biases()
    @assert false
end

# Regularization term for the model named by `G.model`;
# any unrecognised name trips the assertion
function regularization_loss(m)
    G.model == "user_item_biases" && return user_item_biases_regularization(m)
    @assert false
end;

## Batching

In [9]:
# Build the num_items() × num_users() sparse matrix holding a dataset's own ratings
# (rows are items, columns are users)
function SparseArrays.sparse(split::RatingsDataset)
    rows, cols = split.item, split.user
    return sparse(rows, cols, split.rating, num_items(), num_users())
end

# Same layout as the one-argument method (rows are items, columns are users), but
# the stored values come from `ratings` instead of `split.rating`
function SparseArrays.sparse(split::RatingsDataset, ratings)
    rows, cols = split.item, split.user
    return sparse(rows, cols, ratings, num_items(), num_users())
end;

In [10]:
# Inputs for the "one_hot" encoding: the vector of user indices 1:num_users()
# and the sparse item × user rating matrix of the requested split
function one_hot_inputs(split)
    user_ids = collect(1:num_users())
    ratings = sparse(get_split(split; implicit = G.implicit))
    return user_ids, ratings
end;

In [11]:
# Assemble everything needed to iterate over one split. Returns `(X, Y, Z, W)`:
#   X - model inputs, Y - sparse target ratings,
#   Z - sparse residuals scaled by `G.residual_beta`, W - sparse loss weights
function get_epoch(split)
    # only the "one_hot" input encoding is supported
    G.input_data == "one_hot" || @assert false
    X, Y = one_hot_inputs(split)

    # construct residuals
    # NOTE(review): `.*=` mutates the object returned by `read_alpha`; if that
    # result is cached, the scaling would compound across calls — confirm.
    residuals = read_alpha(G.residual_alphas, split, G.implicit)
    residuals.rating .*= G.residual_beta
    Z = sparse(residuals)

    # construct loss-function weights: training combines a per-user and a
    # per-item decay, other splits use the "inverse" weighting scheme
    weights = if split == "training"
        expdecay(get_counts(split), G.user_weight_decay) .*
        expdecay(get_counts(split; by_item = true), G.item_weight_decay)
    else
        expdecay(get_counts(split), weighting_scheme("inverse"))
    end
    W = sparse(get_split(split; implicit = G.implicit), weights)

    return X, Y, Z, W
end;

In [12]:
# Return the order in which users are visited during one pass over `split`.
#
# For any split other than "training", or when `G.user_sampling_scheme` is
# "constant", this is a random permutation of 1:num_users(). Otherwise
# num_users() users are drawn with weights derived from their counts.
function get_sampling_order(split)
    # The local must not be called `weighting_scheme`: assigning to that name would
    # make it a local String for the whole function and shadow the
    # `weighting_scheme` function called below.
    scheme = split == "training" ? G.user_sampling_scheme : "constant"
    if scheme == "constant"
        return shuffle(1:num_users())
    end
    weights = expdecay(get_counts(split; per_rating = false), weighting_scheme(scheme))
    return sample(1:num_users(), Weights(weights), num_users())
end;

In [13]:
# Multithreaded conversion of a sparse CSC matrix into a dense `Matrix` with the
# same element type and size. Columns are filled in parallel.
function tcollect(X::SparseMatrixCSC)
    Y = zeros(eltype(X), size(X)...)
    rows = rowvals(X)
    vals = nonzeros(X)
    Threads.@threads for j in axes(X, 2)
        # walk the stored entries of column j directly instead of slicing out a
        # copy of the column and looking every entry up again
        for k in nzrange(X, j)
            Y[rows[k], j] = vals[k]
        end
    end
    return Y
end

# Extract batch number `iter` from `epoch`, taking the users listed in the
# corresponding slice of `sampling_order`. Vectors are indexed by user, matrices
# by column; every piece is densified and moved to `device`.
# The result is a one-element vector wrapping the processed epoch pieces.
function get_batch(epoch, iter, batch_size, sampling_order)
    first_idx = (iter - 1) * batch_size + 1
    last_idx = min(iter * batch_size, num_users())
    selected = sampling_order[first_idx:last_idx]
    function process(x)
        if ndims(x) == 1
            # TODO switch to using tcollect
            return device(collect(x[selected]))
        end
        return device(tcollect(x[:, selected]))
    end
    return [process.(epoch)]
end;

# Batch extraction without resampling: users are taken in their natural
# order 1:num_users()
function get_batch(epoch, iter, batch_size)
    return get_batch(epoch, iter, batch_size, 1:num_users())
end;

## Loss Functions

In [14]:
# Loss for one batch: the data loss of the combined prediction plus the model's
# regularization term.
#   x - model input, y - targets, z - residuals, w - loss weights
# NOTE(review): `get_epoch` already multiplies the residuals by `G.residual_beta`;
# confirm the additional `(1 - G.residual_beta)` factor in the implicit branch is intended.
function training_loss(m, x, y, z, w)
    prediction = m(x)
    combined = if G.implicit
        softmax(prediction) .* G.residual_beta + z .* (1 - G.residual_beta)
    else
        prediction + z
    end
    return loss(combined, y, w, G.implicit) + regularization_loss(m)
end;

## Evaluation

In [15]:
# returns the preimage of the index -> split.user[index] mapping
# this is primarily a performance optimization
#
# The result maps each user id to the list of row indices `j` for which
# `get_split(split; implicit = G.implicit).user[j]` equals that user.
# One dictionary is filled per thread id; the dictionaries are then merged,
# concatenating (`vcat`) the index lists of users that appear in several of them.
# NOTE(review): indexing the accumulators by `Threads.threadid()` assumes an
# iteration stays on one thread and that iterations on the same thread do not
# interleave — confirm for the `@threads` schedule / Julia version in use.
@memoize function user_to_output_indices(split)
    users = get_split(split; implicit = G.implicit).user
    # one accumulator dictionary per thread id
    user_to_output_idxs = [Dict() for t = 1:Threads.nthreads()]
    @tprogress Threads.@threads for j = 1:length(users)
        u = users[j]
        t = Threads.threadid()
        # create the (untyped) index list the first time this thread sees user `u`
        if u ∉ keys(user_to_output_idxs[t])
            user_to_output_idxs[t][u] = []
        end
        push!(user_to_output_idxs[t][u], j)
    end
    # combine the per-thread dictionaries into a single one
    merge(vcat, user_to_output_idxs...)
end;

In [16]:
# Run the model over every user and return a `RatingsDataset` with the same
# (user, item) rows as `split`, whose `rating` column holds the model's prediction
# for each row (softmax-normalised per user when `G.implicit` is set).
function evaluate(m, split)
    # get model inputs
    user_to_output_idxs = user_to_output_indices(split)
    df = get_split(split; implicit = G.implicit)
    users = df.user
    items = df.item
    epoch = get_epoch(split)

    # compute predictions
    # BLAS is limited to one thread while the Julia threads process the batches
    # (presumably to avoid oversubscription) and restored afterwards
    LinearAlgebra.BLAS.set_num_threads(1)
    batch_size = 16
    ratings = zeros(Float32, length(users))
    # TODO singlethread on GPU
    @tprogress Threads.@threads for iter = 1:cld(num_users(), batch_size)
        batch = get_batch(epoch, iter, batch_size)[1]
        alpha = m(batch[1]) |> cpu
        if G.implicit
            alpha = softmax(alpha)
        end

        # The three-argument `get_batch` walks the users in order 1:num_users(),
        # so column j of this batch belongs to user `first_user + j - 1`.
        # (The last element of the batch is the weight matrix, not a user id.)
        first_user = (iter - 1) * batch_size + 1
        for j = 1:size(alpha)[2]
            u = first_user + j - 1
            # users with no rows in this split have nothing to fill in
            if !haskey(user_to_output_idxs, u)
                continue
            end
            for output_idx in user_to_output_idxs[u]
                ratings[output_idx] = alpha[items[output_idx], j]
            end
        end
    end
    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())

    return RatingsDataset(user = users, item = items, rating = ratings)
end;

In [17]:
# Compute `(train_loss, val_loss)` for the current model. The coefficients β are
# fitted by `regress` on the validation predictions and then reused for both losses.
function snapshot_loss(m)
    val_ratings = evaluate(m, "validation").rating
    _, β = regress(val_ratings, G.residual_alphas, G.implicit)
    val_loss =
        residualized_loss(val_ratings, G.residual_alphas, G.implicit, β, "validation")

    train_ratings = evaluate(m, "training").rating
    train_loss =
        residualized_loss(train_ratings, G.residual_alphas, G.implicit, β, "training")

    return train_loss, val_loss
end;

In [18]:
# Evaluate the model, log both losses, and return the validation loss
function checkpoint(m)
    train, validation = snapshot_loss(m)
    @info "training loss $train, validation loss $validation"
    return validation
end;

## Training

In [19]:
# Checkpoint the model and decide whether training should go on.
# The model and optimizer are saved to `model_path` whenever the validation loss
# is below `stop_criteria.loss`; returns `true` unless `stop!` signals a stop.
function continue_training(m, opt, stop_criteria, model_path)
    validation_loss = checkpoint(m)
    # compare against the tracked loss before `stop!` sees the new value
    improved = validation_loss < stop_criteria.loss
    if improved
        # the macro stores the values under the names `m` and `opt`
        BSON.@save model_path m opt
    end
    should_stop = stop!(stop_criteria, validation_loss)
    return !should_stop
end;

In [20]:
# Train `m` in place for one pass over the training users.
#   m               - the model
#   opt             - the optimizer handed to `Flux.train!`
#   checkpoint_rate - call `checkpoint(m)` every `checkpoint_rate` batches
#                     (the default `Inf` never triggers)
function train_epoch!(m, opt; checkpoint_rate = Inf)
    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())
    ps = Flux.params(m)
    epoch = get_epoch("training")
    sampling_order = get_sampling_order("training")
    batchloss(x, y, z, w) = training_loss(m, x, y, z, w)

    # Batches are slices of the user sampling order (`get_batch` clips at
    # num_users()), so the number of batches must be derived from the number of
    # users, not the number of items.
    nbatches = cld(num_users(), G.batch_size)
    ProgressMeter.@showprogress for iter = 1:nbatches
        batch = get_batch(epoch, iter, G.batch_size, sampling_order)
        Flux.train!(batchloss, ps, batch, opt)

        if iter % checkpoint_rate == 0
            checkpoint(m)
        end
    end
end;

In [21]:
# Train a model described by `hyperparams` until `continue_training` says stop,
# and return the path the model checkpoints were written to.
# Side effects: sets the global `G` and seeds the global RNG.
function train_model(hyperparams::Hyperparams)
    # unpack parameters
    global G = hyperparams
    Random.seed!(G.seed)
    m = build_model()

    # only the "ADAM" optimizer setting is supported
    G.optimizer == "ADAM" || @assert false
    opt = ADAMW(G.learning_rate, (0.9, 0.999), G.l2penalty)

    stop_criteria = early_stopper(patience = G.patience)
    model_path = "../../data/alphas/$name/model.$(hash(G)).bson"

    # Train model: checkpoint first, then run an epoch while training should continue
    while continue_training(m, opt, stop_criteria, model_path)
        train_epoch!(m, opt)
    end

    return model_path
end;

In [22]:
# Hyperparameters for a plain user/item bias baseline: explicit ratings, uniform
# user sampling, no weight decay, no regularization, and no residual alphas.
hyp = Hyperparams(
    # model
    implicit = false,
    model = "user_item_biases",
    # batching
    batch_size = 128,
    input_data = "one_hot",
    user_sampling_scheme = "constant",
    # optimizer
    l2penalty = 0.0f0,
    learning_rate = 0.001,
    optimizer = "ADAM",
    # training
    patience = 100000,
    seed = 20220524,
    # loss
    item_weight_decay = 0.0f0,
    regularization_params = Float32[0, 0],
    residual_alphas = String[],
    residual_beta = 0.0f0,
    user_weight_decay = 0.0f0,
);

In [None]:
# Train the baseline configured in `hyp`; the call returns the checkpoint path
train_model(hyp)

[38;5;6m[1m┌ [22m[39m[38;5;6m[1mInfo: [22m[39m20220526 14:24:19 The GPU function is being called but the GPU is not accessible. 
[38;5;6m[1m└ [22m[39mDefaulting back to the CPU. (No action is required if you want to run on the CPU).
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 ( 3.23 μs/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:17 ( 6.94 ms/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:09 ( 2.06 μs/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.01 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (34.39 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:21 ( 8.30 ms/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220526 14:27:30 training loss 65.610374, validation loss 64.24093
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:19:21[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00

In [34]:
# G = hyp;
# m = build_model();
# ps = Flux.params(m)
# epoch = get_epoch("training")
# opt = Descent()
# sampling_order = get_sampling_order("training")
# batchloss(x, y, z, w) = training_loss(m, x, y, z, w)

[32mProgress: 100%|███████████████████████████| Time: 0:00:00 ( 0.15 μs/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (33.69 ns/it)[39m


batchloss (generic function with 1 method)

## Write predictions

In [None]:
# function make_prediction(sparse_preds, users, items)
#     preds = zeros(length(users))
#     @tprogress Threads.@threads for j = 1:length(preds)
#         preds[j] = sparse_preds[users[j], items[j]]
#     end
#     preds
# end;

In [None]:
# function save_model(model_path, hyperparams::Hyperparams, outdir)
#     global G = hyperparams
#     BSON.@load model_path m
#     training = evaluate(m, "training")
#     validation = evaluate(m, "validation")
#     test = evaluate(m, "test")
#     df = reduce(cat, [training, validation, test])
#     sparse_preds = sparse(df.user, df.item, df.rating)

#     write_predictions(
#         (users, items) -> make_prediction(sparse_preds, users, items),
#         residual_alphas = G.validation_residuals,
#         outdir = outdir,
#         implicit = G.train_implicit_model,
#     )
#     params = to_dict(G)
#     params["model"] = model_path
#     write_params(params, outdir = outdir)
# end;

In [None]:
# function fit(hyperparams::Hyperparams, outdir)
#     redirect_logging("../../data/alphas/$outdir")
#     @info string(hyperparams)
#     model_path = train_model(hyperparams)
#     save_model(model_path, hyperparams, outdir)
# end;

In [None]:
# hyperparams = Hyperparams(
#     use_derived_features = true,
#     train_implicit_model = true,
#     activation = "relu",
#     autoencode = true,
#     batch_size = 128,
#     dropout_perc = 0.5,
#     dropout_rescale = false,
#     layers = [512, 256, 128, 256, 512],
#     l2penalty = 1e-5,
#     learning_rate = 0.001,
#     optimizer = "ADAM",
#     patience = 10,
#     sampling_weight_scheme = "linear",
#     training_residuals = ["UserItemBiases"],
#     training_weight_scheme = "linear",
#     use_residualized_validation_loss = false,
#     validation_residuals = ["UserItemBiases"],
#     validation_weight_scheme = "constant",
#     seed = 20220501 * hash(name),
# )

# #fit(hyperparams, "GNN.Rating.Test.2")