## Hyperparameters
* A `Hyperparams` value holds all the information needed to train a new model
* A derivative-free optimizer (NLopt's Nelder–Mead) searches for the best hyperparameters

In [None]:
import NLopt
import Setfield: @set

In [3]:
# All settings needed to train one model. Values are normally built through the
# keyword constructor generated by `@with_kw` (see `create_hyperparams`) and then
# updated immutably with Setfield's `@set`.
# NOTE: field order is significant — it defines the positional constructor that
# `@set` rebuilds the struct through, and the order used by `to_dict` / `string`.
@with_kw struct Hyperparams
    # model
    # forwarded to `get_input_type`; in this file it only changes the input for "ease"
    implicit::Bool
    # input encoding chosen by `get_input_type`:
    # "one_hot", "explicit", "implicit" or "explicit_implicit"
    input_data::String
    # presumably the names of alphas fed to the model as inputs — TODO confirm
    input_alphas::Vector{String}
    # model name, e.g. "user_item_biases", "matrix_factorization...",
    # "item_based_collaborative_filtering", "autoencoder", "ease"
    model::String
    # batching
    batch_size::Int
    # either the string "constant" or a tuned Float32 (λ[4] in `create_hyperparams`)
    user_sampling_scheme::Union{String,Float32}
    # optimizer
    # `create_hyperparams` sets this to 10^(λ[1] - 3)
    learning_rate::Float32
    # optimizer name understood by `get_optimizer`: "ADAM" or "ADAMW"
    optimizer::String
    # training
    seed::UInt64
    # `create_hyperparams` fills this from `num_users()`
    num_users::Int
    # loss
    # `create_hyperparams` sets this to λ[3] (used as-is)
    item_weight_decay::Float32
    # model-specific regularization terms; length is `num_tuneable_params(model)[3]`.
    # The first entry is also passed to `ADAMW` by `get_optimizer`.
    regularization_params::Vector{Float32}
    # presumably names of baseline alphas whose residuals are modeled — TODO confirm
    residual_alphas::Vector{String}
    # `create_hyperparams` sets this to λ[2] (used as-is)
    user_weight_decay::Float32
end

"""
    to_dict(x::Hyperparams)

Return a `Dict` that maps each field name of `x` (as a `String`) to that field's value.
"""
to_dict(x::Hyperparams) =
    Dict(String(name) => getfield(x, name) for name in fieldnames(Hyperparams))

"""
    string(x::Hyperparams)

Render `x` as a multi-line listing: a `"Hyperparameters:"` header followed by one
`name => value` line per field (in declaration order), with the field names padded
to the width of the longest name.
"""
function Base.string(x::Hyperparams)
    names = fieldnames(Hyperparams)
    width = maximum(length ∘ string, names)
    # accumulate into a buffer instead of repeated `*=`, which re-copies the
    # whole string on every field
    io = IOBuffer()
    print(io, "Hyperparameters:\n")
    for name in names
        print(io, rpad(string(name), width), " => ", getfield(x, name), "\n")
    end
    return String(take!(io))
end

"""
    show(io::IO, x::Hyperparams)

Write the multi-line listing produced by `string(x)` to `io`.
"""
Base.show(io::IO, x::Hyperparams) = print(io, string(x));

In [23]:
"""
    get_epochs_per_checkpoint(model)

Return the number of training epochs to run between validation checkpoints for
`model`: 10 for bias, matrix-factorization and autoencoder models, and 1 for
"item_based_collaborative_filtering" and "ease".

Throws an `ArgumentError` for an unrecognized model name.
"""
function get_epochs_per_checkpoint(model)
    if model == "user_item_biases" ||
       startswith(model, "matrix_factorization") ||
       model == "autoencoder"
        return 10
    elseif model == "item_based_collaborative_filtering" || model == "ease"
        return 1
    else
        # explicit error instead of `@assert false`, which may be elided and
        # gives no hint about which model was rejected
        throw(ArgumentError("unknown model: $(repr(model))"))
    end
end;

In [11]:
"""
    get_subsampling_factor(model)

Return the subsampling factor used for `model`: 0.01 for bias and
matrix-factorization models, 0.1 for "autoencoder", and 0.25 for
"item_based_collaborative_filtering" and "ease".

Throws an `ArgumentError` for an unrecognized model name.
"""
function get_subsampling_factor(model)
    if model == "user_item_biases" || startswith(model, "matrix_factorization")
        return 0.01
    elseif model == "autoencoder"
        return 0.1
    elseif model == "item_based_collaborative_filtering" || model == "ease"
        return 0.25
    else
        throw(ArgumentError("unknown model: $(repr(model))"))
    end
end;

In [None]:
"""
    should_holdout_items(model)

Return `true` if `model` should be trained with held-out items (only
"autoencoder"), `false` for the bias, matrix-factorization, item-based
collaborative filtering and "ease" models.

Throws an `ArgumentError` for an unrecognized model name.
"""
function should_holdout_items(model)
    if model == "user_item_biases" ||
       startswith(model, "matrix_factorization") ||
       model == "item_based_collaborative_filtering" ||
       model == "ease"
        return false
    elseif model == "autoencoder"
        return true
    else
        throw(ArgumentError("unknown model: $(repr(model))"))
    end
end;

In [26]:
"""
    should_retrain_user_embeddings(model)

Return `true` for models that keep per-user parameters needing retraining (bias
and matrix-factorization models), `false` for "item_based_collaborative_filtering",
"autoencoder" and "ease".

Throws an `ArgumentError` for an unrecognized model name.
"""
function should_retrain_user_embeddings(model)
    if model == "user_item_biases" || startswith(model, "matrix_factorization")
        return true
    elseif model == "item_based_collaborative_filtering" ||
           model == "autoencoder" ||
           model == "ease"
        return false
    else
        throw(ArgumentError("unknown model: $(repr(model))"))
    end
end;

In [19]:
"""
    get_optimizer(optimizer, learning_rate, regularization_params)

Construct the optimizer named by `optimizer`:

  * `"ADAM"`  → `ADAM(learning_rate)`
  * `"ADAMW"` → `ADAMW(learning_rate, (0.9, 0.999), regularization_params[1])`

`ADAM`/`ADAMW` are defined elsewhere (presumably Flux, where the third `ADAMW`
argument is the weight decay — confirm against the imported version).

Throws an `ArgumentError` for any other optimizer name.
"""
function get_optimizer(optimizer, learning_rate, regularization_params)
    if optimizer == "ADAM"
        return ADAM(learning_rate)
    elseif optimizer == "ADAMW"
        return ADAMW(learning_rate, (0.9, 0.999), regularization_params[1])
    else
        throw(ArgumentError("unknown optimizer: $(repr(optimizer))"))
    end
end;

In [None]:
"""
    num_tuneable_params(model)

Return `(num_model_params, num_sampling_params, num_regularization_params)`, the
sizes of the three sections of the optimizer vector `λ` for `model`. There are
always 3 model parameters; the sampling and regularization counts depend on the model.

Throws an `ArgumentError` for an unrecognized model name.
"""
function num_tuneable_params(model)
    num_model_params = 3
    if model == "user_item_biases" || startswith(model, "matrix_factorization")
        num_sampling_params, num_regularization_params = 0, 2
    elseif model == "item_based_collaborative_filtering" || model == "ease"
        num_sampling_params, num_regularization_params = 1, 1
    elseif model == "autoencoder"
        num_sampling_params, num_regularization_params = 1, 2
    else
        throw(ArgumentError("unknown model: $(repr(model))"))
    end
    return num_model_params, num_sampling_params, num_regularization_params
end;

In [None]:
"""
    get_input_type(model, implicit)

Return the input encoding `model` consumes:

  * bias / matrix-factorization models → `"one_hot"`
  * `"item_based_collaborative_filtering"` → `"explicit"`
  * `"ease"` → `"implicit"` if `implicit` is true, otherwise `"explicit"`
  * `"autoencoder"` → `"explicit_implicit"`

Throws an `ArgumentError` for an unrecognized model name.
"""
function get_input_type(model, implicit)
    if model == "user_item_biases" || startswith(model, "matrix_factorization")
        return "one_hot"
    elseif model == "item_based_collaborative_filtering"
        return "explicit"
    elseif model == "ease"
        return implicit ? "implicit" : "explicit"
    elseif model == "autoencoder"
        return "explicit_implicit"
    else
        throw(ArgumentError("unknown model: $(repr(model))"))
    end
end;

In [None]:
"""
    get_optimizer_type(model)

Return `"ADAM"` for bias and matrix-factorization models and `"ADAMW"` for every
other model name (no validation of the name is performed here).
"""
function get_optimizer_type(model)
    uses_plain_adam = model == "user_item_biases" || startswith(model, "matrix_factorization")
    return uses_plain_adam ? "ADAM" : "ADAMW"
end

In [24]:
"""
    create_hyperparams(hyp::Hyperparams, λ)

Return a copy of `hyp` whose tunable fields are decoded from the optimizer-space
vector `λ`; all other fields are copied unchanged.

The layout of `λ` follows `num_tuneable_params(hyp.model)`:

  * `λ[1]` → `learning_rate = 10^(λ[1] - 3)`
  * `λ[2]` → `user_weight_decay` (used as-is)
  * `λ[3]` → `item_weight_decay` (used as-is)
  * `λ[4]` → `user_sampling_scheme` when the model has one sampling parameter;
    otherwise `user_sampling_scheme = "constant"`
  * the last `num_regularization_params` entries → `regularization_params`, each
    mapped through `10^(λᵢ - 5)`; when `should_holdout_items(hyp.model)` is true,
    the final one is instead `sigmoid(λ[end])`

Throws an `ArgumentError` if `length(λ)` differs from the model's total number of
tunable parameters (otherwise entries would be decoded from the wrong slots).
"""
function create_hyperparams(hyp::Hyperparams, λ)
    num_model_params, num_sampling_params, num_regularization_params =
        num_tuneable_params(hyp.model)
    num_params = num_model_params + num_sampling_params + num_regularization_params
    if length(λ) != num_params
        throw(ArgumentError(
            "model $(hyp.model) has $num_params tunable parameters, got $(length(λ))",
        ))
    end
    hyp = @set hyp.learning_rate = 10^(λ[1]-3)
    hyp = @set hyp.user_weight_decay = λ[2]
    hyp = @set hyp.item_weight_decay = λ[3]
    if num_sampling_params == 1
        hyp = @set hyp.user_sampling_scheme = λ[4]
    else
        hyp = @set hyp.user_sampling_scheme = "constant"
    end
    # the broadcast allocates a fresh vector, so the in-place write below cannot
    # alias the caller's `regularization_params`
    hyp = @set hyp.regularization_params = 10 .^ (λ[end-num_regularization_params+1:end] .- 5)
    if should_holdout_items(hyp.model)
        hyp.regularization_params[end] = sigmoid(λ[end])
    end
    return hyp
end

"""
    create_hyperparams(model::String, implicit, residual_alphas, input_alphas = [])

Build the default `Hyperparams` for `model`. It has these fixed settings:

  * batch size 1024
  * seed 20220524
  * `num_users()`

The input type and optimizer are derived from `model`. Every tunable field is
initialized as if the optimizer had proposed the all-zeros point of
`create_hyperparams(hyp, λ)`.
"""
function create_hyperparams(model::String, implicit, residual_alphas, input_alphas = [])
    # computed in the same order the keyword arguments were originally evaluated
    input_data = get_input_type(model, implicit)
    optimizer = get_optimizer_type(model)
    users = num_users()
    param_counts = num_tuneable_params(model)
    # NaN-valued fields are placeholders, overwritten from the zero vector below
    template = Hyperparams(
        implicit = implicit,
        model = model,
        batch_size = 1024,
        input_alphas = input_alphas,
        input_data = input_data,
        user_sampling_scheme = NaN32,
        learning_rate = NaN,
        optimizer = optimizer,
        seed = 20220524,
        num_users = users,
        item_weight_decay = NaN,
        regularization_params = fill(NaN, last(param_counts)),
        residual_alphas = residual_alphas,
        user_weight_decay = NaN,
    )
    return create_hyperparams(template, zeros(Float32, sum(param_counts)))
end;

In [25]:
"""
    nlopt_optimize(lossfn, n; max_evals = 100, max_time = 86400,
                   ftol_rel = 1e-4, xtol_rel = 1e-4)

Minimize `lossfn` over `n` variables with NLopt's derivative-free Nelder–Mead
algorithm (`:LN_NELDERMEAD`). The search starts from the zero vector with an
initial step of 1.

`lossfn` is installed as NLopt's objective. It is called as `lossfn(λ, grad)` and
must return the scalar loss.

The search stops at whichever limit is hit first:

  * `max_evals` evaluations
  * `max_time` seconds
  * a relative change below `ftol_rel` (objective) or `xtol_rel` (point)

The outcome is logged with `@info`, and the best point is returned as a
`Vector{Float32}`.
"""
function nlopt_optimize(
    lossfn,
    n;
    max_evals = 100,
    max_time = 86400,
    ftol_rel = 1e-4,
    xtol_rel = 1e-4,
)
    solver = NLopt.Opt(:LN_NELDERMEAD, n)
    solver.initial_step = 1
    # stopping criteria: evaluation budget, wall-clock budget, relative tolerances
    solver.maxeval = max_evals
    solver.maxtime = max_time
    solver.ftol_rel = ftol_rel
    solver.xtol_rel = xtol_rel
    solver.min_objective = lossfn
    start = zeros(Float32, n)
    minf, λ, ret = NLopt.optimize(solver, start)
    numevals = solver.numevals
    @info (
        "found minimum $minf at point $λ after $numevals function calls " *
        "(ended because $ret)"
    )
    return Float32.(λ)
end;

"""
    optimize_hyperparams(hyp; max_evals)

Tune the hyperparameters of `hyp` with `nlopt_optimize`, using at most `max_evals`
objective evaluations.

Each evaluation does the following:

  * decodes the candidate point with `create_hyperparams(hyp, λ)`;
  * trains it via `train_model` (checkpointing every
    `get_epochs_per_checkpoint(hyp.model)` epochs, `patience = 0`);
  * logs the point and its loss.

Returns the best point in optimizer space (a `Float32` vector). Pass it to
`create_hyperparams(hyp, λ)` to recover the corresponding `Hyperparams`.
"""
function optimize_hyperparams(hyp; max_evals)
    dimension = sum(num_tuneable_params(hyp.model))
    # `grad` is part of NLopt's objective signature; Nelder–Mead never uses it
    objective = function (λ, grad)
        candidate = create_hyperparams(hyp, convert.(Float32, λ))
        _, _, loss = train_model(
            candidate;
            epochs_per_checkpoint = get_epochs_per_checkpoint(hyp.model),
            patience = 0,
        )
        @info "$λ $loss"
        return loss
    end
    return nlopt_optimize(objective, dimension; max_evals = max_evals)
end;

In [34]:
"""
    optimize_learning_rate(hyp, max_iters)

Search downward over learning rates, warm-starting each probe from the best model
found so far.

The first probe runs at `hyp.learning_rate`. Each later probe uses the learning
rate of the best probe so far times `exp(-1)`, and is trained with `train_model`
(`init_model` set to the best model so far).

A probe is kept when its validation loss is `<=` the best loss seen. NaN losses
are never kept. The loop runs until `stop!` on an
`early_stopper(patience = 0, min_rel_improvement = 1e-4, max_iters = max_iters)`
signals a stop.

Returns `(m, epochs, loss, hyp)` for the best probe. If no probe was kept (every
loss was NaN, or the loop body never ran), the result is:

  * `m === nothing`
  * `epochs == 0`
  * `loss == Inf`
  * `hyp` still carries the pre-divided learning rate
"""
function optimize_learning_rate(hyp, max_iters)
    # exponentially decay the learning rate whenever we hit a plateau
    learning_rate_decay = exp(-1)
    # pre-divide so that the first probe below runs at the caller's learning rate
    hyp = @set hyp.learning_rate /= learning_rate_decay
    m = nothing
    validation_loss = Inf
    epochs = 0
    # Best loss among kept probes, tracked explicitly rather than via
    # `minimum(losses)`. `minimum` propagates NaN, so one diverged probe would
    # otherwise reject every later probe and be reported as the final loss.
    best_loss = Inf
    stopper = early_stopper(patience = 0, min_rel_improvement = 1e-4, max_iters = max_iters)
    verbose = max_iters == 1 ? "info" : false

    while !stop!(stopper, validation_loss)
        probe_hyp = @set hyp.learning_rate *= learning_rate_decay
        probe_m, probe_epochs, validation_loss = train_model(
            probe_hyp;
            max_checkpoints = 1000,
            epochs_per_checkpoint = 1,
            patience = get_epochs_per_checkpoint(probe_hyp.model),
            init_model = m,
            verbose = verbose,
        )
        @info "loss: $validation_loss learning_rate: $(probe_hyp.learning_rate)"
        # `<=` keeps the newest probe on ties and is false for NaN
        if validation_loss <= best_loss
            best_loss = validation_loss
            m = probe_m
            epochs = probe_epochs
            hyp = probe_hyp
        end
    end

    return m, epochs, best_loss, hyp
end;