# Neural Network Base Class

In [1]:
using Flux
using Random
using SparseArrays


import BSON
import NBInclude: @nbinclude
@nbinclude("Alpha.ipynb");

In [2]:
const device = Flux.gpu;

## Parameters

In [3]:
@with_kw struct Hyperparams
    # model
    implicit::Bool
    activation::String
    model::String
    # batching
    autoencode::Bool
    batch_size::Int
    input_data::String
    user_sampling_scheme::String
    # optimizer
    l2penalty::Float32
    learning_rate::Float32
    optimizer::String
    # training
    # dropout_perc::Float32
    # dropout_rescale::Bool
    patience::Int
    seed::UInt64
    # loss
    item_weight_decay::Float32
    loss_function::String
    residual_alphas::Vector{String}
    residual_beta::Float32
    user_weight_decay::Float32
end

function to_dict(x::Hyperparams)
    Dict(string(key) => getfield(x, key) for key ∈ fieldnames(Hyperparams))
end

function Base.string(x::Hyperparams)
    fields = [x for x in fieldnames(Hyperparams)]
    max_field_size = maximum(length(string(k)) for k in fields)
    ret = "Hyperparameters:\n"
    for f in fields
        ret *= "$(rpad(string(f), max_field_size)) => $(getfield(x, f))\n"
    end
    ret
end;

# Batching

In [4]:
function SparseArrays.sparse(split::RatingsDataset)
    sparse(split.item, split.user, split.rating, num_items(), num_users())
end

function SparseArrays.sparse(split::RatingsDataset, ratings)
    sparse(split.item, split.user, ratings, num_items(), num_users())
end;

In [5]:
function one_hot_inputs(split)
    @assert !G.autoencode

    # construct inputs
    X = collect(1:num_users())

    # construct outputs
    Y = sparse(get_split(split; implicit = G.implicit))

    # construct residuals
    residuals = read_alpha(G.residual_alphas, split, G.implicit)
    residuals.rating .*= G.residual_beta
    Z = sparse(residuals)

    # construct loss-function weights
    if split == "training"
        weights =
            expdecay(get_counts(split), G.user_weight_decay) .*
            expdecay(get_counts(split; by_item = true), G.item_weight_decay)
    else
        weights = expdecay(get_counts(split), weighting_scheme("inverse"))
    end
    W = sparse(get_split(split), weights)

    X, Y, Z, W
end;

In [6]:
function get_epoch(split)
    if G.input_data == "one_hot"
        return one_hot_inputs(split)
    end
    @assert false
end;

In [7]:
function get_sampling_order(split)
    weighting_scheme = split == "training" ? G.user_sampling_scheme : "constant"
    if weighting_scheme == "constant"
        return shuffle(1:num_users())
    else
        weights = expdecay(
            get_counts(split; per_rating = false),
            weighting_scheme(G.user_sampling_scheme),
        )
        return sample(1:num_users(), Weights(weights), num_users())
    end
end;

In [8]:
function get_batch(epoch, iter, batch_size, sampling_order)
    range = sampling_order[(iter-1)*batch_size+1:min(iter * batch_size, num_users())]
    function process(x)
        if length(size(x)) == 1
            return collect(x[range]) |> device
        else
            return collect(x[:, range]) |> device
        end
    end
    [process.(epoch)]
end;

function get_batch(epoch, iter, batch_size)
    sampling_order = 1:num_users()
    get_batch(epoch, iter, batch_size, sampling_order)
end;

## Model

In [9]:
# A layer that takes one input and splits it into many
struct Split{T}
    paths::T
end
Split(paths...) = Split(paths)
Flux.@functor Split
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)

# A layer that takes many inputs and joins them into one
Join(combine, paths) = Parallel(combine, paths)
Join(combine, paths...) = Join(combine, paths);

In [10]:
# A layer that adds a 1-D vector to the input
struct BiasLayer
    b::Any
end
BiasLayer(n::Integer) = BiasLayer(randn(n))
(m::BiasLayer)(x) = x .+ m.b
Flux.@functor BiasLayer

In [11]:
function user_item_biases()
    U = Flux.Embedding(num_users() => 1)
    A = BiasLayer(num_items())
    m = Chain(U, A) |> device
end;

In [12]:
function build_model()
    if G.model == "user_item_biases"
        return user_item_biases()
    end
    @assert false
end;

## Loss Functions

In [13]:
function evaluate(m, split)
    LinearAlgebra.BLAS.set_num_threads(1)

    df = get_split(split)
    users = df.user
    items = df.item

    # index users
    user_to_output_idxs = [Dict() for t = 1:Threads.nthreads()]
    @tprogress Threads.@threads for j = 1:length(users)
        u = users[j]
        t = Threads.threadid()
        if u ∉ keys(user_to_output_idxs[t])
            user_to_output_idxs[t][u] = []
        end
        push!(user_to_output_idxs[t][u], j)
    end
    user_to_output_idxs = merge(vcat, user_to_output_idxs...)

    # compute predictions
    ratings = zeros(Float32, length(users))
    epoch = get_epoch(split)
    batch_size = 16
    # TODO singlethread on GPU
    @tprogress Threads.@threads for iter = 1:Int(ceil(num_users() / batch_size))
        batch = get_batch(epoch, iter, batch_size)[1]
        alpha = m(batch[1]) |> cpu
        if G.implicit
            alpha = softmax(alpha)
        end

        for j = 1:size(alpha)[2]
            u = batch[end][1, j]
            if u ∉ keys(user_to_output_idxs)
                continue
            end
            for output_idx in user_to_output_idxs[u]
                ratings[output_idx] = alpha[items[output_idx], j]
            end
        end
    end

    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())

    RatingsDataset(user = users, item = items, rating = ratings)
end;

In [14]:
G = Hyperparams(
    # model
    activation = "relu",
    implicit = false,
    model = "user_item_biases",
    # batching
    autoencode = false,
    batch_size = 128,
    input_data = "one_hot",
    user_sampling_scheme = "constant",
    # optimizer
    l2penalty = 0,
    learning_rate = 0.001,
    optimizer = "ADAM",
    # training
    patience = 5,
    seed = 20220524,
    # loss
    item_weight_decay = 0,
    loss_function = "NONE",
    residual_alphas = [],
    residual_beta = 0,
    user_weight_decay = 0,
);

m = build_model()

[38;5;6m[1m┌ [22m[39m[38;5;6m[1mInfo: [22m[39m20220525 05:13:29 The GPU function is being called but the GPU is not accessible. 
[38;5;6m[1m└ [22m[39mDefaulting back to the CPU. (No action is required if you want to run on the CPU).


Chain(
  Embedding(1320149 => 1),              [90m# 1_320_149 parameters[39m
  BiasLayer([-0.26932675661872696, -2.440773478136208, 0.08021055035203482, 1.4455818403355054, -0.44806555533155173, 0.25289354203518805, 0.46897454426679674, -0.6797234030284606, 1.4793253642505402, -0.6165317561730165  …  -0.9824382362236107, 1.2368613417734458, -2.53506290867229, 0.01469397586525117, -0.24792565093217855, 0.32514373029162075, 0.7109159008368262, -1.670829085918523, 0.1940653778539473, 0.982331464381786]),  [90m# 18_952 parameters[39m
) 

In [15]:
evaluate(m, "training")

[32mProgress: 100%|███████████████████████████| Time: 0:00:12 ( 2.53 μs/it)[39m39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220525 05:15:16 regression coefficients: Float32[0.0]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.64 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.12 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:35 (13.69 ms/it)[39m


RatingsDataset
  user: Array{Int32}((152803783,)) Int32[851625, 851625, 851625, 851625, 851625, 851625, 851625, 851625, 851625, 851625  …  369188, 369188, 369188, 369188, 369188, 369188, 369188, 369188, 369188, 369188]
  item: Array{Int32}((152803783,)) Int32[11528, 498, 805, 41, 12807, 7670, 101, 6464, 11, 3148  …  4961, 14194, 3, 6082, 774, 14254, 15584, 9017, 14423, 369]
  rating: Array{Float32}((152803783,)) Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


In [16]:
function rating_loss(ŷ, y, weights)
    # only compute loss on items the user has seen
    mask = y .!= 0
    per_user_mse =
        sum(((ŷ .- y) .* mask) .^ 2, dims = 1) ./
        max.(one(eltype(weights)), sum(mask, dims = 1))
    dot(per_user_mse, weights) / sum(weights)
end

function implicit_loss(ŷ, y, weights)
    agg(x) = dot(x, weights)
    Flux.logitcrossentropy(ŷ, y, agg = agg) / sum(weights)
end

loss(m, x, y, weights) =
    G.train_implicit_model ? implicit_loss(m(x), y, weights) : rating_loss(m(x), y, weights);

In [17]:
function get_loss(m, split)
    LinearAlgebra.BLAS.set_num_threads(1)

    epoch = get_epoch(split)
    batch_size = 16
    losses = zeros(Threads.nthreads())
    @tprogress Threads.@threads for iter = 1:Int(ceil(n_users / batch_size))
        batch = get_batch(epoch, iter, batch_size)
        losses[Threads.threadid()] += loss(m, batch[1]...) * sum(batch[1][3])
    end

    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())

    sum(losses) / sum(epoch[3])
end;

In [18]:
function get_residualized_loss(m, split)
    rating = evaluate(m, split).rating
    df = get_residuals(split, G.validation_residuals)

    # turn per-user weights into per-item weights
    W = get_derived_feature(
        df,
        (_, count) ->
            weighting_scheme(count, G.validation_weight_scheme) *
            weighting_scheme(count, "inverse"),
    )
    weights = zeros(eltype(rating), length(df.user))
    Threads.@threads for i = 1:length(weights)
        weights[i] = W[df.user[i]]
    end

    if G.train_implicit_model
        @assert false
    else
        Y = df.rating .* sqrt.(weights)
        X = rating .* sqrt.(weights)
        β = X \ Y
        @info "beta: $β"
        return mse(Y, X .* β, weights)
    end
end;

## Training

In [19]:
function checkpoint(m)
    loss_fn = G.use_residualized_validation_loss ? get_residualized_loss : get_loss
    training_loss = loss_fn(m, "training")
    validation_loss = loss_fn(m, "validation")
    @info "training loss $training_loss, validation loss $validation_loss"
    validation_loss
end;

In [20]:
function continue_training(m, stop_criteria, model_path)
    validation_loss = checkpoint(m)
    if validation_loss < stop_criteria.loss
        BSON.@save model_path m
    end
    !stop!(stop_criteria, validation_loss)
end;

In [21]:
function train_epoch!(m, opt; checkpoint_rate = 0.1)
    LinearAlgebra.BLAS.set_num_threads(Threads.nthreads())
    ps = Flux.params(m)
    train_loss(x, y, w) = loss(m, x, y, w)
    epoch = get_epoch("training")
    sampling_order = get_sampling_order("training")

    nbatches = Int(ceil(size(epoch[1])[2] / G.batch_size))
    ProgressMeter.@showprogress for iter = 1:nbatches
        batch = get_batch(epoch, iter, G.batch_size, sampling_order)
        Flux.train!(train_loss, ps, batch, opt)

        if iter % Int(round(nbatches * checkpoint_rate)) == 0
            checkpoint(m)
        end
    end
end;

In [22]:
function train_model(hyperparams::Hyperparams)
    # unpack parameters
    global G = hyperparams
    Random.seed!(G.seed)
    m = build_model()
    if G.optimizer == "ADAM"
        opt = ADAMW(G.learning_rate, (0.9, 0.999), G.l2penalty)
    else
        @assert false
    end
    stop_criteria = early_stopper(patience = G.patience)
    model_path = "../../data/alphas/$name/model.$(hash(G)).bson"

    # Train model
    while continue_training(m, stop_criteria, model_path)
        train_epoch!(m, opt)
    end

    model_path
end;

## Write predictions

In [23]:
function make_prediction(sparse_preds, users, items)
    preds = zeros(length(users))
    @tprogress Threads.@threads for j = 1:length(preds)
        preds[j] = sparse_preds[users[j], items[j]]
    end
    preds
end;

In [24]:
function save_model(model_path, hyperparams::Hyperparams, outdir)
    global G = hyperparams
    BSON.@load model_path m
    training = evaluate(m, "training")
    validation = evaluate(m, "validation")
    test = evaluate(m, "test")
    df = reduce(cat, [training, validation, test])
    sparse_preds = sparse(df.user, df.item, df.rating)

    write_predictions(
        (users, items) -> make_prediction(sparse_preds, users, items),
        residual_alphas = G.validation_residuals,
        outdir = outdir,
        implicit = G.train_implicit_model,
    )
    params = to_dict(G)
    params["model"] = model_path
    write_params(params, outdir = outdir)
end;

In [25]:
function fit(hyperparams::Hyperparams, outdir)
    redirect_logging("../../data/alphas/$outdir")
    @info string(hyperparams)
    model_path = train_model(hyperparams)
    save_model(model_path, hyperparams, outdir)
end;

In [26]:
# hyperparams = Hyperparams(
#     use_derived_features = true,
#     train_implicit_model = true,
#     activation = "relu",
#     autoencode = true,
#     batch_size = 128,
#     dropout_perc = 0.5,
#     dropout_rescale = false,
#     layers = [512, 256, 128, 256, 512],
#     l2penalty = 1e-5,
#     learning_rate = 0.001,
#     optimizer = "ADAM",
#     patience = 10,
#     sampling_weight_scheme = "linear",
#     training_residuals = ["UserItemBiases"],
#     training_weight_scheme = "linear",
#     use_residualized_validation_loss = false,
#     validation_residuals = ["UserItemBiases"],
#     validation_weight_scheme = "constant",
#     seed = 20220501 * hash(name),
# )

# #fit(hyperparams, "GNN.Rating.Test.2")