# Neural Network Base Class

In [1]:
using Flux
using Random
using SparseArrays
using Statistics: var

import BSON
import CUDA
import NBInclude: @nbinclude
@nbinclude("Alpha.ipynb");

In [2]:
# Generic fallback: hand the value to Flux's `gpu` and return whatever it produces.
device(x) = gpu(x)

# Sparse inputs: when CUDA is usable, move the matrix over with `gpu` and wrap it in a
# `CUDA.CuArray`; otherwise hand the sparse matrix back untouched.
function device(x::AbstractSparseArray)
    if CUDA.functional()
        return CUDA.CuArray(gpu(x))
    end
    return x
end;

## Hyperparameters

In [3]:
# All settings that define a training run. An instance is stored in the global `G`
# and read by the model, batching, loss and training functions in this notebook.
@with_kw struct Hyperparams
    # model
    implicit::Bool                          # passed to `get_split`/`read_alpha`; selects the softmax branch of the loss
    model::String                           # architecture name dispatched on by `build_model`
    # batching
    batch_size::Int                         # number of users per minibatch
    input_data::String                      # input encoding; `get_epoch` only handles "one_hot"
    user_sampling_scheme::String            # how users are sampled for the training split ("constant" = plain shuffle)
    # optimizer
    l2penalty::Float32                      # third argument of `ADAMW` in `train_model`
    learning_rate::Float32                  # first argument of `ADAMW` in `train_model`
    optimizer::String                       # optimizer name; `train_model` only handles "ADAM"
    # training
    patience::Int                           # forwarded to `early_stopper`
    seed::UInt64                            # forwarded to `Random.seed!`
    # loss
    item_weight_decay::Float32              # decay applied to per-item counts when weighting the training loss
    regularization_params::Vector{Float32}  # model-specific regularization coefficients
    residual_alphas::Vector{String}         # names of the alphas whose predictions are read as residuals
    residual_beta::Float32                  # scale applied to the residuals; must lie in [0, 1] for implicit models
    user_weight_decay::Float32              # decay applied to per-user counts when weighting the training loss
end

# Flattens the hyperparameters into a Dict keyed by the field names (as strings).
function to_dict(x::Hyperparams)
    names = fieldnames(Hyperparams)
    return Dict(string(name) => getfield(x, name) for name in names)
end

# Renders the hyperparameters as a multi-line "name => value" table with the names
# padded to a common width. For implicit models it first asserts that
# `residual_beta` lies in [0, 1].
function Base.string(x::Hyperparams)
    if x.implicit
        @assert 0 <= x.residual_beta && x.residual_beta <= 1
    end
    names = collect(fieldnames(Hyperparams))
    width = maximum(length ∘ string, names)
    rows = ["$(rpad(string(name), width)) => $(getfield(x, name))\n" for name in names]
    return "Hyperparameters:\n" * join(rows)
end;

## Model

In [4]:
# Fan-out layer: feeds the same input to every path and returns all of the outputs
struct Split{T}
    paths::T
end
# convenience constructor: Split(f, g, ...) stores the paths as a tuple
Split(paths...) = Split(paths)
Flux.@functor Split
function (m::Split)(x::AbstractArray)
    return map(path -> path(x), m.paths)
end

# Fan-in layer: runs each input through its own path and merges the results with
# `combine`. Implemented as a thin wrapper around `Parallel`.
function Join(combine, paths)
    return Parallel(combine, paths)
end
# varargs form: Join(combine, f, g, ...) forwards the paths as a tuple
function Join(combine, paths...)
    return Join(combine, paths)
end;

In [5]:
# A layer that adds a 1-D vector to the input.
#
# The struct is parametric in the storage type of `b` so that the field is concretely
# typed (a field declared `::Any` makes every access to `m.b` type-unstable).
# `BiasLayer(b)` still accepts any array, so existing call sites keep working.
struct BiasLayer{T}
    b::T
end
# Creates a bias of length `n`; `init` is called as `init(Float32, n)`.
BiasLayer(n::Integer; init = randn) = BiasLayer(init(Float32, n))
# Broadcast-adds the bias to the input.
(m::BiasLayer)(x) = x .+ m.b
Flux.@functor BiasLayer

In [6]:
# Baseline predictor R[i, j] = u[i] + a[j]: a one-dimensional embedding per user
# followed by a bias vector with one entry per item. The chain is passed through `device`.
function user_item_biases()
    user_bias = Flux.Embedding(num_users() => 1)
    item_bias = BiasLayer(num_items())
    return Chain(user_bias, item_bias) |> device
end

# Penalty λ_u * variance(u) + λ_a * variance(a), where λ_u and λ_a are the first two
# entries of `G.regularization_params`.
function user_item_biases_regularization(m)
    user_term = var(m[1].weight) * G.regularization_params[1]
    item_term = var(m[2].b) * G.regularization_params[2]
    return user_term + item_term
end;

In [7]:
# Constructs the model named by `G.model`; any other name trips the assertion.
function build_model()
    G.model == "user_item_biases" && return user_item_biases()
    @assert false
end

# Regularization term matching the architecture named by `G.model`; any other name
# trips the assertion.
function regularization_loss(m)
    G.model == "user_item_biases" && return user_item_biases_regularization(m)
    @assert false
end;

# Data Preprocessing

In [8]:
# Inputs are the user ids 1:num_users(); labels are the sparse ratings matrix of `split`.
function one_hot_inputs(split)
    user_ids = collect(1:num_users())
    labels = sparse(get_split(split; implicit = G.implicit))
    return user_ids, labels
end;

In [9]:
# Builds the full-epoch data for `split` and returns the tuple (X, Y, Z, W):
#   X - model inputs
#   Y - labels
#   Z - residuals read from `G.residual_alphas`, already scaled by `G.residual_beta`
#   W - loss weights, placed at the (item, user) coordinates of the split
# NOTE(review): the result is memoized, presumably keyed on `split` only, yet it also
# depends on the global `G` - confirm the cache is invalidated when `G` changes.
@memoize function get_epoch(split)
    if G.input_data == "one_hot"
        X, Y = one_hot_inputs(split)
    else
        # unsupported input encoding
        @assert false
    end

    # construct residuals (scaled in place by residual_beta before being made sparse)
    residuals = read_alpha(G.residual_alphas, split, G.implicit)
    residuals.rating .*= G.residual_beta
    Z = sparse(residuals)

    # construct loss-function weights
    if split == "training"
        # training: product of a user-count decay and an item-count decay
        weights =
            expdecay(get_counts(split), G.user_weight_decay) .*
            expdecay(get_counts(split; by_item = true), G.item_weight_decay)
    else
        # other splits: always use the "inverse" weighting scheme on the default counts
        weights = expdecay(get_counts(split), weighting_scheme("inverse"))
    end
    W = sparse(get_split(split; implicit = G.implicit), weights)

    X, Y, Z, W
end;

# Batching
* Turns an epoch into minibatches
* Each data point will appear in a minibatch with a probability proportional to its sampling weight

In [10]:
# Converts a ratings dataset into a num_items() x num_users() sparse matrix
# (rows are items, columns are users) holding the dataset's own ratings.
function SparseArrays.sparse(split::RatingsDataset)
    rows, cols, vals = split.item, split.user, split.rating
    return sparse(rows, cols, vals, num_items(), num_users())
end

# Same layout as the one-argument method, but stores the caller-supplied `ratings`
# at the dataset's (item, user) coordinates instead of `split.rating`.
function SparseArrays.sparse(split::RatingsDataset, ratings)
    nrows, ncols = num_items(), num_users()
    return sparse(split.item, split.user, ratings, nrows, ncols)
end;

In [11]:
# Vectors are sliced by selecting the entries at `range`.
slice(x::AbstractVector, range) = getindex(x, range)

# Matrices are sliced by selecting the columns at `range` (all rows are kept).
slice(x::AbstractMatrix, range) = getindex(x, :, range);

In [12]:
# Returns the order in which users are visited during one epoch of `split`, as a
# collection of num_users() user ids.
#
# The training split uses `G.user_sampling_scheme`; every other split uses "constant".
# "constant" yields a random permutation of the users. Any other scheme draws
# num_users() users via `sample`, weighted by `expdecay` of the per-user counts.
function get_sampling_order(split)
    # The scheme name must not be stored in a local called `weighting_scheme`: that
    # shadows the `weighting_scheme` function used below and turns the call into an
    # attempt to call a String (MethodError).
    scheme = split == "training" ? G.user_sampling_scheme : "constant"
    if scheme == "constant"
        return shuffle(1:num_users())
    end
    weights = expdecay(get_counts(split; per_rating = false), weighting_scheme(scheme))
    return sample(1:num_users(), Weights(weights), num_users())
end;

In [13]:
# performs the following steps
# 1) reorder the epoch by the sampling order
# 2) split the epoch into minibatches of size batch_size
# 3) return the iter-th minibatch
#
# Arguments
#   epoch          - tuple of per-user arrays (vectors are indexed by user, matrices by column)
#   iter           - 1-based minibatch number
#   batch_size     - number of users per minibatch (the last one may be shorter)
#   sampling_order - user ids in the order they should be visited
# Returns `(batch, range)`: `batch` is a one-element vector holding the tuple of sliced
# arrays after `device`, and `range` holds the user ids contained in the minibatch.
function get_batch(epoch, iter, batch_size, sampling_order)
    # Honour the caller's sampling order (it used to be overwritten with 1:num_users())
    # and bound the last minibatch by the length of that order.
    first_idx = (iter - 1) * batch_size + 1
    last_idx = min(iter * batch_size, length(sampling_order))
    range = sampling_order[first_idx:last_idx]
    process(x) = slice(x, range) |> device
    [process.(epoch)], range
end;

# Convenience method: minibatches taken in natural user order 1:num_users().
function get_batch(epoch, iter, batch_size)
    return get_batch(epoch, iter, batch_size, 1:num_users())
end;

## Loss Functions

In [14]:
# Loss of the model on one minibatch, without regularization.
#   implicit models: blend softmax(m(x)) and the residuals z with weights
#                    residual_beta and (1 - residual_beta)
#   explicit models: add the residuals, scaled by residual_beta, to m(x)
# NOTE(review): `get_epoch` already multiplies the residuals by residual_beta before
# they reach this function - confirm the second scaling here is intended.
function model_loss(m, x, y, z, w)
    p = m(x)
    q = if G.implicit
        softmax(p) .* G.residual_beta + z .* (1 - G.residual_beta)
    else
        p + z .* G.residual_beta
    end
    return loss(q, y, w, G.implicit)
end

# Objective minimized during training: minibatch loss plus the regularization term.
function training_loss(m, x, y, z, w)
    data_term = model_loss(m, x, y, z, w)
    penalty = regularization_loss(m)
    return data_term + penalty
end;

In [15]:
# Sums the (unregularized) model loss over every minibatch of `split`, visiting the
# users in natural order.
function split_loss(m, split)
    data = get_epoch(split)
    total = 0
    @showprogress for b in 1:ceil(Int, num_users() / G.batch_size)
        inputs, _ = get_batch(data, b, G.batch_size)
        total += model_loss(m, inputs[1]...)
    end
    return total
end;

# Evaluation

In [16]:
# returns the preimage of the index -> split.user[index] mapping, i.e. a dictionary
# from each user to the list of row indices of the split that belong to that user
# this is primarily a performance optimization
# NOTE(review): the per-thread buffers are indexed by `Threads.threadid()`. Under the
# default dynamic schedule of `Threads.@threads` (Julia >= 1.8) tasks may migrate between
# threads, so two tasks could end up sharing a buffer - confirm this is safe here or
# switch to a static schedule / task-local buffers.
@memoize function user_to_output_indices(split)
    users = get_split(split; implicit = G.implicit).user
    # one accumulator dictionary per thread
    user_to_output_idxs = [Dict() for t = 1:Threads.nthreads()]
    @tprogress Threads.@threads for j = 1:length(users)
        u = users[j]
        t = Threads.threadid()
        # lazily create the index list the first time this buffer sees the user
        if u ∉ keys(user_to_output_idxs[t])
            user_to_output_idxs[t][u] = []
        end
        push!(user_to_output_idxs[t][u], j)
    end
    # combine the per-thread dictionaries, concatenating the lists of users that
    # appear in more than one of them
    merge(vcat, user_to_output_idxs...)
end;

# Runs the model over every user and collects its prediction for each (user, item)
# row of `split`. Implicit models have softmax applied to their outputs.
# Returns a RatingsDataset with the split's users and items and the predicted ratings.
function evaluate(m, split)
    # model inputs
    output_idxs_by_user = user_to_output_indices(split)
    dataset = get_split(split; implicit = G.implicit)
    users, items = dataset.user, dataset.item
    epoch = get_epoch(split)

    # predictions, one minibatch of users at a time
    batch_size = G.batch_size
    activation = G.implicit ? softmax : identity
    ratings = zeros(Float32, length(users))
    @showprogress for iter in 1:ceil(Int, num_users() / batch_size)
        batch, batch_users = get_batch(epoch, iter, batch_size)
        alpha = cpu(activation(m(batch[1][1])))

        # column `col` of alpha holds the predictions for user `u`
        for (col, u) in enumerate(batch_users)
            haskey(output_idxs_by_user, u) || continue
            for idx in output_idxs_by_user[u]
                ratings[idx] = alpha[items[idx], col]
            end
        end
    end

    return RatingsDataset(user = users, item = items, rating = ratings)
end;

In [17]:
# Returns the pair (training loss, validation loss) for the current model.
function snapshot_loss(m)
    train = split_loss(m, "training")
    valid = split_loss(m, "validation")
    return train, valid
end;

In [18]:
# Logs the current training and validation losses and returns the validation loss.
function checkpoint(m)
    losses = snapshot_loss(m)
    training_loss, validation_loss = losses
    @info "training loss $training_loss, validation loss $validation_loss"
    return validation_loss
end;

## Training

In [19]:
# Evaluates the model, saves it (with its optimizer) to `model_path` when the validation
# loss beats `stop_criteria.loss`, and returns true while training should go on.
function continue_training(m, opt, stop_criteria, model_path)
    validation_loss = checkpoint(m)
    improved = validation_loss < stop_criteria.loss
    if improved
        BSON.@save model_path m opt
    end
    should_stop = stop!(stop_criteria, validation_loss)
    return !should_stop
end;

In [20]:
# Trains `m` in place for one pass over the training split.
#   checkpoint_rate - log the losses every `checkpoint_rate` minibatches
#                     (the default `Inf` never triggers)
function train_epoch!(m, opt; checkpoint_rate = Inf)
    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())
    params = Flux.params(m)
    data = get_epoch("training")
    order = get_sampling_order("training")
    objective(x, y, z, w) = training_loss(m, x, y, z, w)

    @showprogress for iter in 1:ceil(Int, length(order) / G.batch_size)
        minibatch, _ = get_batch(data, iter, G.batch_size, order)
        Flux.train!(objective, params, minibatch, opt)

        iter % checkpoint_rate == 0 && checkpoint(m)
    end
end;

In [21]:
# Trains a model from scratch with the given hyperparameters and returns the path
# that checkpoints are saved to.
# Side effects: sets the global `G` and reseeds the global RNG.
# NOTE(review): `name` is not defined in this notebook section - presumably a global
# set by whichever notebook includes this one; confirm.
function train_model(hyperparams::Hyperparams)
    # publish the hyperparameters and set up model, optimizer and stopping rule
    global G = hyperparams
    Random.seed!(G.seed)
    m = build_model()
    opt = if G.optimizer == "ADAM"
        ADAMW(G.learning_rate, (0.9, 0.999), G.l2penalty)
    else
        @assert false
    end
    stop_criteria = early_stopper(patience = G.patience)
    model_path = "../../data/alphas/$name/model.$(hash(G)).bson"

    # always train at least one epoch, then keep going until the stopper says otherwise
    while true
        train_epoch!(m, opt)
        continue_training(m, opt, stop_criteria, model_path) || break
    end

    return model_path
end;

In [22]:
# Hyperparameters for the experiments below: an explicit user/item-bias model trained
# with ADAM on one-hot inputs, with no residuals, no regularization and no weight decay.
hyp = Hyperparams(
    # model
    implicit = false,
    model = "user_item_biases",
    # batching
    batch_size = 128,
    input_data = "one_hot",
    user_sampling_scheme = "constant",
    # optimizer
    l2penalty = 0,
    learning_rate = 0.001,
    optimizer = "ADAM",
    # training
    patience = 100000,
    seed = 20220524,
    # loss
    item_weight_decay = 0,
    regularization_params = Float32[0, 0],
    residual_alphas = [],
    residual_beta = 0,
    user_weight_decay = 0,
);

In [23]:
# publish the hyperparameters through the global `G` (read by the functions above),
# then build the model they describe
G = hyp
m = build_model();

In [24]:
# ADAM with its default arguments (not the ADAMW configuration that train_model uses)
opt = ADAM();

In [25]:
# build the training epoch up front; get_epoch is memoized, so later calls reuse it
epoch = get_epoch("training");

[32mProgress: 100%|███████████████████████████| Time: 0:00:01 ( 1.40 μs/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (33.11 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (30.20 ns/it)[39m


In [26]:
# time a single training epoch (the first call also pays for compilation)
@time train_epoch!(m, opt)

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:02:50[39m:07[39m


176.285394 seconds (224.32 M allocations: 15.359 GiB, 7.24% gc time, 58.07% compilation time)


In [27]:
# time a full pass computing the unregularized loss on the training split
@time split_loss(m, "training")

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:16[39m


 16.466511 seconds (11.28 M allocations: 3.983 GiB, 14.79% gc time, 16.40% compilation time)


48495.31f0

In [28]:
# run up to 100 epochs, timing each training epoch and each training-loss evaluation
for i in 1:100
    @time train_epoch!(m, opt)
    @time split_loss(m, "training")
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:34[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:13[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:34[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:13[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:34[39m
[32mProgress:  24%|█████████▊                               |  ETA: 0:00:10[39m

 34.450639 seconds (31.66 M allocations: 5.222 GiB, 8.16% gc time)
 13.217715 seconds (6.68 M allocations: 3.740 GiB, 12.45% gc time)
 34.101590 seconds (31.67 M allocations: 5.222 GiB, 8.04% gc time)
 13.206094 seconds (6.67 M allocations: 3.740 GiB, 12.44% gc time)
 34.401832 seconds (31.67 M allocations: 5.222 GiB, 8.11% gc time)


LoadError: InterruptException:

In [29]:
#using BenchmarkTools

In [30]:
# ProgressMeter.@showprogress for j = 1:10
#     train_epoch!(m, opt)
#     @info split_loss(m, "training")
# end

In [31]:
# split_loss(m, "validation")

In [32]:
# train_model(hyp)

## Write predictions

In [33]:
# function make_prediction(sparse_preds, users, items)
#     preds = zeros(length(users))
#     @tprogress Threads.@threads for j = 1:length(preds)
#         preds[j] = sparse_preds[users[j], items[j]]
#     end
#     preds
# end;

In [34]:
# function save_model(model_path, hyperparams::Hyperparams, outdir)
#     global G = hyperparams
#     BSON.@load model_path m
#     training = evaluate(m, "training")
#     validation = evaluate(m, "validation")
#     test = evaluate(m, "test")
#     df = reduce(cat, [training, validation, test])
#     sparse_preds = sparse(df.user, df.item, df.rating)

#     write_predictions(
#         (users, items) -> make_prediction(sparse_preds, users, items),
#         residual_alphas = G.validation_residuals,
#         outdir = outdir,
#         implicit = G.train_implicit_model,
#     )
#     params = to_dict(G)
#     params["model"] = model_path
#     write_params(params, outdir = outdir)
# end;

In [35]:
# function fit(hyperparams::Hyperparams, outdir)
#     redirect_logging("../../data/alphas/$outdir")
#     @info string(hyperparams)
#     model_path = train_model(hyperparams)
#     save_model(model_path, hyperparams, outdir)
# end;

In [36]:

# hyperparams = Hyperparams(
#     use_derived_features = true,
#     train_implicit_model = true,
#     activation = "relu",
#     autoencode = true,
#     batch_size = 128,
#     dropout_perc = 0.5,
#     dropout_rescale = false,
#     layers = [512, 256, 128, 256, 512],
#     l2penalty = 1e-5,
#     learning_rate = 0.001,
#     optimizer = "ADAM",
#     patience = 10,
#     sampling_weight_scheme = "linear",
#     training_residuals = ["UserItemBiases"],
#     training_weight_scheme = "linear",
#     use_residualized_validation_loss = false,
#     validation_residuals = ["UserItemBiases"],
#     validation_weight_scheme = "constant",
#     seed = 20220501 * hash(name),
# )

# #fit(hyperparams, "GNN.Rating.Test.2")