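# Generalized Neural Network
* A denoising autoencoder: Dropout corrupts each user's input vector (residualized ratings plus implicit interactions), and the network reconstructs that user's rating residuals, or implicit-interaction probabilities, for every item.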

In [1]:
# Identifier for this alpha: names the checkpoint directory used by `train_model` and is
# hashed into the RNG seed of the example hyperparameters below.
const name = "GNN.Rating.Test"

"GNN.Rating.Test"

In [2]:
using Random
import BSON

In [3]:
using NBInclude
@nbinclude("Alpha.ipynb");

In [4]:
# Where models and minibatches live; every `|> device` in this notebook goes through it.
# Flux's `gpu` presumably leaves data on the CPU when no GPU is usable — confirm.
const device = gpu;

## Parameters

In [29]:
# Hyperparameters for one training run. `train_model` and `save_model` store the instance in
# the global `G`, which the functions in this notebook read. `@with_kw` (presumably from
# Parameters.jl, loaded via Alpha.ipynb — confirm) provides the keyword constructor.
@with_kw struct Hyperparams
    # model
    # true: binarize targets, train with `implicit_loss`, softmax the outputs in `evaluate`
    train_implicit_model::Bool
    # append 9 per-user derived input rows (3 features × linear, squared, sqrt)
    use_derived_features::Bool
    # training
    # nonlinearity name; `generate_model` only accepts "relu"
    activation::String
    # `get_epoch` requires this to be true
    autoencode::Bool
    # users per minibatch in `train_epoch!`
    batch_size::Int
    # probability for the Dropout layer applied to the model inputs
    dropout_perc::Float32
    # intended to scale training-split inputs by (1 - dropout_perc) in `get_epoch`
    dropout_rescale::Bool
    # hidden layer widths, from the input side to the output side
    layers::Vector{Int}
    # weight-decay argument passed to ADAMW
    l2penalty::Float32
    # ADAMW learning rate
    learning_rate::Float32
    # optimizer name; `train_model` only accepts "ADAM" (built as ADAMW)
    optimizer::String
    # passed to `early_stopper`; the stop check runs once per training epoch
    patience::Int
    # loss functions
    # `weighting_scheme` used to draw users for each training epoch
    sampling_weight_scheme::String
    # residual models applied to the ratings used as model inputs (`get_residuals`)
    training_residuals::Vector{String}
    # `weighting_scheme` for per-user loss weights on the "training" split
    training_weight_scheme::String
    # `checkpoint` uses `get_residualized_loss` instead of `get_loss`
    use_residualized_validation_loss::Bool
    # residual models applied to the target ratings; also passed to `write_predictions`
    validation_residuals::Vector{String}
    # `weighting_scheme` for per-user loss weights on every other split
    validation_weight_scheme::String
    # misc
    # passed to `Random.seed!` before the model is built
    seed::UInt64
end

# Flatten `x` into a `"field name" => value` dictionary (used when writing run parameters).
to_dict(x::Hyperparams) =
    Dict(string(field) => getfield(x, field) for field in fieldnames(typeof(x)))

# Render `x` as a header line followed by one left-aligned `field => value` line per field.
function Base.string(x::Hyperparams)
    fields = fieldnames(Hyperparams)
    max_field_size = maximum(f -> length(string(f)), fields)
    rows = ("$(rpad(string(f), max_field_size)) => $(getfield(x, f))\n" for f in fields)
    return "Hyperparameters:\n" * join(rows)
end;

In [6]:
# Dimensions of the (item, user) matrices built in this notebook; item and user ids are used
# directly as row and column indices.
const n_items = num_items() + 1 # leave room to map unseen items
const n_users = maximum(get_split("training").user) + 1; # leave room to map unseen users

In [7]:
# Sparse n_items × n_users matrix of `split`'s ratings. Julia is column-major, so storing
# users as columns keeps slicing out a batch of users (`X[:, users]`) cheap.
to_sparse_mat(split) = sparse(split.item, split.user, split.rating, n_items, n_users);

In [8]:
function get_derived_feature(split, agg)
    # Aggregate `split` per user. Returns a 1 × n_users sparse row whose u-th entry is
    # `agg(sum of user u's ratings, number of user u's rows)`.
    n_rows = length(split.rating)
    # One contiguous chunk of rows per thread; each chunk accumulates into its own column.
    # Indexing partials by chunk instead of `Threads.threadid()` stays race-free when tasks
    # migrate between threads or `threadid()` exceeds `nthreads()` (interactive threads).
    chunk_len = max(1, cld(n_rows, Threads.nthreads()))
    chunks = collect(Iterators.partition(1:n_rows, chunk_len))
    sums = zeros(Float32, n_users, length(chunks))
    counts = zeros(Float32, n_users, length(chunks))
    @tprogress Threads.@threads for c = 1:length(chunks)
        for i in chunks[c]
            u = split.user[i]
            sums[u, c] += split.rating[i]
            counts[u, c] += 1
        end
    end
    # Combine the per-chunk partials into n_users × 1 totals.
    user_sums = sum(sums, dims = 2)
    user_counts = sum(counts, dims = 2)
    return sparse(agg.(user_sums, user_counts)')
end;

In [10]:
function get_epoch(split)
    # Build the full-data (X, Y, W) triple for `split`; every matrix has one column per user.
    #   X: model inputs, always built from the training data (residualized ratings stacked
    #      on implicit interactions, plus optional derived features).
    #   Y: targets from `split`, residualized by `G.validation_residuals`; binarized when
    #      training the implicit model.
    #   W: 1 × n_users per-user loss weights derived from each user's row count in `split`.
    # TODO support G.autoencode = false
    G.autoencode || throw(ArgumentError("get_epoch requires G.autoencode = true"))

    # construct inputs
    X = vcat(
        to_sparse_mat(get_residuals("training", G.training_residuals)),
        to_sparse_mat(get_split("implicit_training")),
    )
    if G.use_derived_features
        Xd = vcat(
            # fraction of implicit items
            get_derived_feature(
                get_split("implicit_training"),
                (_, count) -> count / n_items,
            ),
            # fraction of seen items
            get_derived_feature(get_split("training"), (_, count) -> count / n_items),
            # average item rating (the / 10 presumably maps the rating scale to ~[0, 1])
            get_derived_feature(
                get_split("training"),
                (total, count) -> total / max(1, count) / 10,
            ),
        )
        # 3 derived features × (linear, squared, square root) = 9 extra input rows
        X = vcat(X, Xd, Xd .^ 2, sqrt.(Xd))
    end
    if split == "training" && G.dropout_rescale
        # Assign the result: a bare `X .* ...` would be computed and thrown away.
        X = X .* (1 - G.dropout_perc)
    end

    # construct outputs
    Y = to_sparse_mat(get_residuals(split, G.validation_residuals))
    if G.train_implicit_model
        Y.nzval .= 1
    end

    # How much to weight each user in the loss function
    function count_to_weight(x)
        scheme = split == "training" ? G.training_weight_scheme : G.validation_weight_scheme
        weighting_scheme(x, scheme)
    end
    W = get_derived_feature(get_split(split), (_, count) -> count_to_weight(count))

    return X, Y, W
end;

In [11]:
function get_sampling_order(split)
    # Per-user sampling weights for `split`, as a 1 × n_users sparse row.
    user_weights = get_derived_feature(
        get_split(split),
        (_, count) -> weighting_scheme(count, G.sampling_weight_scheme),
    )
    # Draw n_users user ids (with replacement) in proportion to those weights.
    return sample(1:n_users, Weights(vec(collect(user_weights))), n_users)
end;

In [12]:
function get_batch(epoch, iter, batch_size, sampling_order)
    # Columns (users) for minibatch `iter`: the iter-th window of `sampling_order`,
    # truncated at the number of columns in the epoch.
    n_cols = size(epoch[1], 2)
    window = ((iter - 1) * batch_size + 1):min(iter * batch_size, n_cols)
    cols = sampling_order[window]
    # Densify each selected slice and move it to `device`.
    to_device(x) = collect(x[:, cols]) |> device
    # A 1-element vector so the result can be passed to `Flux.train!` as its data.
    return [map(to_device, epoch)]
end;

# Unsampled variant: minibatch `iter` takes the epoch's columns in order.
get_batch(epoch, iter, batch_size) =
    get_batch(epoch, iter, batch_size, 1:size(epoch[1], 2));

## Model

In [13]:
function generate_model()
    # Inputs: residualized ratings for every item (unseen items are zero), implicit
    # interactions for every item, plus 9 derived features when enabled.
    # Outputs: one value per item — rating residuals (trained with a masked MSE) or
    # implicit-preference logits (trained with a softmax cross-entropy).
    n_inputs = 2 * n_items + (G.use_derived_features ? 9 : 0)
    widths = [n_inputs; G.layers; n_items]
    activation = if G.activation == "relu"
        relu
    else
        throw(ArgumentError("unsupported activation: $(G.activation)"))
    end
    n_layers = length(widths) - 1
    # Hidden layers use the configured nonlinearity. The output layer is linear: a ReLU
    # there would make negative rating residuals and negative logits unreachable.
    autoencoder = [
        Dense(widths[i], widths[i+1], i == n_layers ? identity : activation) for
        i = 1:n_layers
    ]
    # Input Dropout is the denoising corruption.
    return Chain(Dropout(G.dropout_perc), autoencoder...) |> device
end;

## Loss Functions

In [14]:
function evaluate(m, split)
    # Predict every (user, item) row of `split` and return them as a RatingsDataset. For
    # the implicit model the prediction is the softmax probability over items; otherwise
    # it is the raw model output.
    BLAS.set_num_threads(1)
    try
        df = get_split(split)
        users = df.user
        items = df.item

        # user id => rows of `df` for that user. Built serially: per-`threadid()` buffers
        # are unsafe when tasks migrate or when `threadid()` exceeds `nthreads()`.
        user_to_output_idxs = Dict{eltype(users),Vector{Int}}()
        for (j, u) in enumerate(users)
            push!(get!(Vector{Int}, user_to_output_idxs, u), j)
        end

        # Only the inputs are needed for prediction, so only X is batched onto the device.
        X = get_epoch(split)[1]
        ratings = zeros(Float32, length(users))
        batch_size = 16
        @tprogress Threads.@threads for iter = 1:cld(n_users, batch_size)
            x = get_batch((X,), iter, batch_size)[1][1]
            alpha = m(x) |> cpu
            if G.train_implicit_model
                # Softmax over items (rows); subtracting the column max avoids exp overflow.
                alpha .= exp.(alpha .- maximum(alpha, dims = 1))
                alpha ./= sum(alpha, dims = 1)
            end
            # Unsampled batches take users in column order, so column j is user
            # `offset + j`; deriving it here avoids scalar reads from the device.
            offset = (iter - 1) * batch_size
            for j = 1:size(alpha, 2)
                rows = get(user_to_output_idxs, offset + j, nothing)
                rows === nothing && continue
                for output_idx in rows
                    ratings[output_idx] = alpha[items[output_idx], j]
                end
            end
        end

        return RatingsDataset(user = users, item = items, rating = ratings)
    finally
        # Restore BLAS threading even if prediction throws.
        BLAS.set_num_threads(Threads.nthreads())
    end
end;

In [15]:
function rating_loss(ŷ, y, weights)
    # Weighted mean over users (columns) of each user's MSE on observed entries. A zero in
    # `y` marks an unrated item; users with no ratings contribute a per-user MSE of 0.
    observed = y .!= 0
    squared_error = sum(((ŷ .- y) .* observed) .^ 2; dims = 1)
    n_observed = max.(one(eltype(weights)), sum(observed; dims = 1))
    per_user_mse = squared_error ./ n_observed
    return dot(per_user_mse, weights) / sum(weights)
end

function implicit_loss(ŷ, y, weights)
    # Softmax cross-entropy between logits ŷ and targets y. `agg` receives the per-column
    # (per-user) losses; weight them per user, then normalize by the total weight.
    weighted_sum = Flux.logitcrossentropy(ŷ, y; agg = per_user -> dot(per_user, weights))
    return weighted_sum / sum(weights)
end

# Minibatch loss of model `m`: softmax cross-entropy for the implicit model, masked MSE for
# the rating model (selected by `G.train_implicit_model`).
function loss(m, x, y, weights)
    ŷ = m(x)
    return G.train_implicit_model ? implicit_loss(ŷ, y, weights) :
           rating_loss(ŷ, y, weights)
end;

In [16]:
function get_loss(m, split)
    # Weighted loss of `m` over every user in `split`. Each batch loss is a weighted mean,
    # so it is scaled back up by the batch's total weight before dividing by the grand
    # total weight.
    BLAS.set_num_threads(1)
    try
        epoch = get_epoch(split)
        batch_size = 16
        nbatches = cld(n_users, batch_size)
        # One slot per batch: race-free regardless of how iterations map to threads.
        batch_losses = zeros(nbatches)
        @tprogress Threads.@threads for iter = 1:nbatches
            x, y, w = get_batch(epoch, iter, batch_size)[1]
            batch_losses[iter] = loss(m, x, y, w) * sum(w)
        end
        return sum(batch_losses) / sum(epoch[3])
    finally
        # Restore BLAS threading even if evaluation throws.
        BLAS.set_num_threads(Threads.nthreads())
    end
end;

In [17]:
function get_residualized_loss(m, split)
    # How well the model's predictions explain `split`'s residualized ratings. Fit one
    # no-intercept coefficient β of the sqrt(weight)-scaled targets on the sqrt(weight)-
    # scaled predictions, log it, and return the weighted MSE of that fit.
    predictions = evaluate(m, split).rating
    df = get_residuals(split, G.validation_residuals)

    # Per-user weight (validation scheme × "inverse" scheme), then copied onto each row.
    user_weight = get_derived_feature(
        df,
        (_, count) ->
            weighting_scheme(count, G.validation_weight_scheme) *
            weighting_scheme(count, "inverse"),
    )
    T = eltype(predictions)
    weights = T[user_weight[u] for u in df.user]

    if G.train_implicit_model
        @assert false
    else
        # NOTE(review): Y and X are already scaled by sqrt(weights); if `mse` applies
        # `weights` again, rows end up weighted by weights² — confirm against `mse`.
        root_w = sqrt.(weights)
        Y = df.rating .* root_w
        X = predictions .* root_w
        β = X \ Y
        @info "beta: $β"
        return mse(Y, X .* β, weights)
    end
end;

## Training

In [18]:
function checkpoint(m)
    # Log the training and validation loss under the configured metric; return the latter.
    loss_fn = if G.use_residualized_validation_loss
        get_residualized_loss
    else
        get_loss
    end
    training_loss, validation_loss = (loss_fn(m, s) for s in ("training", "validation"))
    @info "training loss $training_loss, validation loss $validation_loss"
    return validation_loss
end;

In [19]:
function continue_training(m, stop_criteria, model_path)
    # Evaluate `m`, save it to `model_path` when its validation loss is below
    # `stop_criteria.loss`, and return whether training should continue.
    # NOTE(review): assumes `stop_criteria.loss` is the best loss seen so far and that
    # `stop!` updates it — confirm against `early_stopper`.
    validation_loss = checkpoint(m)
    improved = validation_loss < stop_criteria.loss
    if improved
        BSON.@save model_path m
    end
    should_stop = stop!(stop_criteria, validation_loss)
    return !should_stop
end;

In [20]:
function train_epoch!(m, opt; checkpoint_rate = 0.1)
    # One pass over the training users, drawn by `get_sampling_order`, updating `m` in
    # place. Every `checkpoint_rate` fraction of the epoch, training/validation loss is
    # logged via `checkpoint`.
    BLAS.set_num_threads(Threads.nthreads())
    ps = Flux.params(m)
    train_loss(x, y, w) = loss(m, x, y, w)
    epoch = get_epoch("training")
    sampling_order = get_sampling_order("training")

    nbatches = cld(size(epoch[1], 2), G.batch_size)
    # Clamp to at least one batch: with few batches `round(nbatches * checkpoint_rate)`
    # can be 0, and `iter % 0` would throw a DivideError.
    checkpoint_every = max(1, round(Int, nbatches * checkpoint_rate))
    @showprogress for iter = 1:nbatches
        batch = get_batch(epoch, iter, G.batch_size, sampling_order)
        Flux.train!(train_loss, ps, batch, opt)

        if iter % checkpoint_every == 0
            checkpoint(m)
        end
    end
end;

In [21]:
function train_model(hyperparams::Hyperparams)
    # Train a fresh model under `hyperparams` until early stopping triggers; the best
    # checkpoint is written by `continue_training`. Returns the checkpoint path.
    # The helpers in this notebook read hyperparameters from the global `G`.
    global G = hyperparams
    Random.seed!(G.seed)
    m = generate_model()
    opt = if G.optimizer == "ADAM"
        # `l2penalty` is passed as ADAMW's weight-decay argument.
        ADAMW(G.learning_rate, (0.9, 0.999), G.l2penalty)
    else
        throw(ArgumentError("unsupported optimizer: $(G.optimizer)"))
    end
    stop_criteria = early_stopper(patience = G.patience)
    model_path = "../../data/alphas/$name/model.$(hash(G)).bson"
    # Writing the checkpoint file fails if its directory does not exist yet.
    mkpath(dirname(model_path))

    # Train model
    while continue_training(m, stop_criteria, model_path)
        train_epoch!(m, opt)
    end

    return model_path
end;

## Write predictions

In [22]:
function make_prediction(sparse_preds, users, items)
    # Look up the prediction for each (users[k], items[k]) pair. Every iteration writes its
    # own slot, so the threaded loop needs no synchronization.
    n = length(users)
    out = zeros(n)
    @tprogress Threads.@threads for k = 1:n
        out[k] = sparse_preds[users[k], items[k]]
    end
    return out
end;

In [23]:
function save_model(model_path, hyperparams::Hyperparams, outdir)
    # Load the checkpoint at `model_path`, predict every split, and write the predictions
    # plus the run parameters (including the checkpoint path) to `outdir`.
    global G = hyperparams
    BSON.@load model_path m
    # Keep the model on the same device as the batches that `evaluate` builds.
    m = m |> device
    training = evaluate(m, "training")
    validation = evaluate(m, "validation")
    test = evaluate(m, "test")
    df = reduce(cat, [training, validation, test])
    # Use the full id range so any valid (user, item) lookup stays in bounds, rather than
    # sizing the matrix by the largest ids that happen to appear in `df`.
    sparse_preds = sparse(df.user, df.item, df.rating, n_users, n_items)

    write_predictions(
        (users, items) -> make_prediction(sparse_preds, users, items),
        residual_alphas = G.validation_residuals,
        outdir = outdir,
        implicit = G.train_implicit_model,
    )
    params = to_dict(G)
    params["model"] = model_path
    write_params(params, outdir = outdir)
end;

In [24]:
function fit(hyperparams::Hyperparams, outdir)
    # End-to-end run: log to the output directory, train, then write predictions and
    # parameters.
    redirect_logging("../../data/alphas/$outdir")
    @info string(hyperparams)
    return save_model(train_model(hyperparams), hyperparams, outdir)
end;

In [25]:
hyperparams = Hyperparams(
    # model
    train_implicit_model = true,
    use_derived_features = true,
    # training
    activation = "relu",
    autoencode = true,
    batch_size = 128,
    dropout_perc = 0.5,
    dropout_rescale = false,
    layers = [512, 256, 128, 256, 512],
    l2penalty = 1e-5,
    learning_rate = 0.001,
    optimizer = "ADAM",
    patience = 10,
    # loss functions
    sampling_weight_scheme = "linear",
    training_residuals = ["UserItemBiases"],
    training_weight_scheme = "linear",
    use_residualized_validation_loss = false,
    validation_residuals = ["UserItemBiases"],
    validation_weight_scheme = "constant",
    # misc: the UInt64 product wraps on overflow, giving a name-dependent seed
    seed = 20220501 * hash(name),
)

#fit(hyperparams, "GNN.Rating.Test.2")

Hyperparams
  train_implicit_model: Bool true
  use_derived_features: Bool true
  activation: String "relu"
  autoencode: Bool true
  batch_size: Int64 128
  dropout_perc: Float32 0.5f0
  dropout_rescale: Bool false
  layers: Array{Int64}((5,)) [512, 256, 128, 256, 512]
  l2penalty: Float32 1.0f-5
  learning_rate: Float32 0.001f0
  optimizer: String "ADAM"
  patience: Int64 10
  sampling_weight_scheme: String "linear"
  training_residuals: Array{String}((1,))
  training_weight_scheme: String "linear"
  use_residualized_validation_loss: Bool false
  validation_residuals: Array{String}((1,))
  validation_weight_scheme: String "constant"
  seed: UInt64 0x811af080737da77b
