## Hyperparameters
* `Hyperparams` contains all the information necessary to train a new model
* A derivative-free optimizer (Nelder–Mead, via NLopt) is used to find the best hyperparameters

In [None]:
import NLopt

In [None]:
# All the settings needed to train a model, grouped by purpose.
# `@with_kw` (presumably from Parameters.jl; it is not imported in this view)
# provides the keyword constructor that `create_hyperparams` uses. Fields that
# `create_hyperparams` initializes to NaN are filled in from the tuning vector λ.
@with_kw struct Hyperparams
    # model
    # model name; its prefix ("item_biases", "ease", "autoencoder", ...) selects
    # model-family behavior, and a "."-separated "temporal" component turns on
    # temporal batching
    model::String
    # data
    # one of "explicit", "implicit", "ptw" (as accepted by create_hyperparams)
    content::String
    # derived from `content`: false only for "explicit"
    implicit::Bool
    # input data type, as returned by get_input_data_type
    input_data::String
    # NOTE(review): meaning of the alpha names is not visible here
    input_alphas::Vector{String}
    # output data type, as returned by get_output_data_type
    output_data::String
    # batching
    batch_size::Int
    # tuned: taken directly from one entry of λ
    user_sampling_scheme::Float32
    # optimizer
    # tuned: 10^(-λ - 3)
    learning_rate::Float32
    # tuned: 10^(λ - 5)
    optimizer_weight_decay::Float32
    # optimizer name; get_optimizer only supports "ADAMW"
    optimizer::String
    # training
    seed::UInt64
    num_users::Int
    # tuned only for models where should_holdout_items is true
    holdout::Float32
    # tuned only for item-holdout models that also use temporal batching
    temporal_holdout::Float32
    # loss
    item_weight_decay::Float32
    # NOTE(review): meaning of the alpha names is not visible here
    residual_alphas::Vector{String}
    temporal_weight_decay::Float32
    user_weight_decay::Float32
end

# Map each field name of `x` (as a String) to that field's value.
function to_dict(x::Hyperparams)
    names = fieldnames(Hyperparams)
    return Dict(string(name) => getfield(x, name) for name in names)
end

"""
    string(x::Hyperparams)

Render `x` as a `Hyperparameters:` header followed by one `name => value` line
per field, with field names right-padded to a common width.
"""
function Base.string(x::Hyperparams)
    names = fieldnames(Hyperparams)
    width = maximum(length ∘ string, names)
    # write into a buffer instead of repeatedly concatenating strings
    io = IOBuffer()
    println(io, "Hyperparameters:")
    for name in names
        println(io, rpad(string(name), width), " => ", getfield(x, name))
    end
    return String(take!(io))
end

# Display a Hyperparams using its multi-line `string` rendering.
Base.show(io::IO, x::Hyperparams) = print(io, string(x));

In [None]:
"""
    get_epochs_per_checkpoint(model::String)

Return how many epochs to train between checkpoints, based on the model-family
prefix of `model`.

Throws an `ArgumentError` for an unrecognized model family.
"""
function get_epochs_per_checkpoint(model::String)
    has_prefix(prefixes) = any(p -> startswith(model, p), prefixes)
    if has_prefix(("item_biases", "autoencoder", "double_embedding", "metadata_embedding"))
        return 10
    elseif has_prefix(("item_based_collaborative_filtering", "ease"))
        return 1
    else
        # explicit error: `@assert` may be compiled out, which would silently
        # return `nothing` for an unknown model
        throw(ArgumentError("unknown model family: $model"))
    end
end;

In [None]:
"""
    get_subsampling_factor(model::String)

Return the subsampling factor for the model family of `model`.
NOTE(review): presumably the fraction of data kept when subsampling; confirm at
the call site.

Throws an `ArgumentError` for an unrecognized model family.
"""
function get_subsampling_factor(model::String)
    has_prefix(prefixes) = any(p -> startswith(model, p), prefixes)
    if startswith(model, "item_biases")
        return 0.01
    elseif has_prefix(("autoencoder", "double_embedding", "metadata_embedding"))
        return 0.10
    elseif has_prefix(("item_based_collaborative_filtering", "ease"))
        return 0.25
    else
        throw(ArgumentError("unknown model family: $model"))
    end
end;

In [None]:
"""
    should_holdout_items(model::String)

Return `true` for model families that hold out items (and therefore tune the
`holdout` hyperparameter), `false` for those that do not.

Throws an `ArgumentError` for an unrecognized model family.
"""
function should_holdout_items(model::String)
    has_prefix(prefixes) = any(p -> startswith(model, p), prefixes)
    if has_prefix(("item_biases", "item_based_collaborative_filtering", "ease"))
        return false
    elseif has_prefix(("autoencoder", "double_embedding", "metadata_embedding"))
        return true
    else
        throw(ArgumentError("unknown model family: $model"))
    end
end

# True when one of the '.'-separated components of `model` is exactly "temporal".
should_temporal_batch(model::String) = any(==("temporal"), split(model, '.'));

In [None]:
"""
    get_optimizer(optimizer, learning_rate, optimizer_weight_decay)

Construct the optimizer named `optimizer`. Only `"ADAMW"` is supported; it is
built with betas `(0.9, 0.999)` and the given learning rate and weight decay.
NOTE(review): `ADAMW` is presumably Flux's; it is not imported in this view.

Throws an `ArgumentError` for any other optimizer name.
"""
function get_optimizer(
    optimizer::String,
    learning_rate::Float32,
    optimizer_weight_decay::Float32,
)
    if optimizer == "ADAMW"
        return ADAMW(learning_rate, (0.9, 0.999), optimizer_weight_decay)
    end
    throw(ArgumentError("unsupported optimizer: $optimizer"))
end;

In [None]:
"""
    num_tuneable_params(model::String)

Return the length of the λ vector consumed by `create_hyperparams(hyp, λ)` for
`model`.

Every model tunes six parameters: learning rate, temporal weight decay, user
sampling scheme, user weight decay, item weight decay and optimizer weight
decay. Models that hold out items also tune `holdout`. Of those, models that
use temporal batching also tune `temporal_holdout`.
"""
function num_tuneable_params(model::String)
    # Must match the number of λ entries read by create_hyperparams. The
    # previous count of 7 gave the derivative-free optimizer a dimension with no
    # effect on the loss, wasting full training runs on it.
    num_model_params = 6
    if should_holdout_items(model)
        num_model_params += 1
        if should_temporal_batch(model)
            num_model_params += 1
        end
    end
    return num_model_params
end;

In [None]:
"""
    get_input_data_type(model::String, implicit::Bool)

Return the name of the input data representation for the model family of
`model`. Only the `ease` family depends on `implicit`.

Throws an `ArgumentError` for an unrecognized model family.
"""
function get_input_data_type(model::String, implicit::Bool)
    if startswith(model, "item_biases")
        return "one_hot"
    elseif startswith(model, "item_based_collaborative_filtering")
        return "explicit"
    elseif startswith(model, "ease")
        return implicit ? "implicit" : "explicit"
    elseif startswith(model, "autoencoder")
        return "explicit_implicit"
    elseif startswith(model, "double_embedding")
        return "explicit_implicit_tuple"
    elseif startswith(model, "metadata_embedding")
        return "impression_metadata"
    else
        throw(ArgumentError("unknown model family: $model"))
    end
end;

In [None]:
"""
    get_output_data_type(model::String)

Return the name of the output data representation for the model family of
`model`: `"item"` for `double_embedding`, `"allitems"` for the other families.

Throws an `ArgumentError` for an unrecognized model family.
"""
function get_output_data_type(model::String)
    allitems_families = (
        "item_biases",
        "item_based_collaborative_filtering",
        "ease",
        "autoencoder",
        "metadata_embedding",
    )
    if any(p -> startswith(model, p), allitems_families)
        return "allitems"
    elseif startswith(model, "double_embedding")
        return "item"
    else
        throw(ArgumentError("unknown model family: $model"))
    end
end;

In [None]:
"""
    create_hyperparams(hyp::Hyperparams, λ::Vector{Float32})

Return a copy of `hyp` whose tunable fields are derived from `λ`.

The entries of `λ` are consumed strictly in order. Optional entries for
`holdout` and `temporal_holdout` are read only for models that use them.
"""
function create_hyperparams(hyp::Hyperparams, λ::Vector{Float32})
    # λ is normalized so that a step size of 1 is reasonable for the optimizer,
    # and so that λ = 1 is a more promising direction than λ = -1.
    # A Ref cursor avoids boxing a reassigned variable captured by the closure.
    cursor = Ref(0)
    function next_λ()
        cursor[] += 1
        return λ[cursor[]]
    end

    lr = 10^(-next_λ() - 3)
    hyp = @set hyp.learning_rate = lr
    twd = log(0.5) / log(year_in_timestamp_units() * 3 * exp(-next_λ()))
    hyp = @set hyp.temporal_weight_decay = twd
    if should_holdout_items(hyp.model)
        hyp = @set hyp.holdout = sigmoid(-next_λ())
        if should_temporal_batch(hyp.model)
            th = max(1 - exp(next_λ()) * year_in_timestamp_units(), eps(Float32))
            hyp = @set hyp.temporal_holdout = th
        end
    end
    hyp = @set hyp.user_sampling_scheme = next_λ()
    hyp = @set hyp.user_weight_decay = next_λ() - 1
    hyp = @set hyp.item_weight_decay = next_λ()
    hyp = @set hyp.optimizer_weight_decay = 10^(next_λ() - 5)
    return hyp
end

"""
    create_hyperparams(model, content, residual_alphas, input_alphas)

Build the starting `Hyperparams` for `model` trained on `content`, which must
be one of `"explicit"`, `"implicit"` or `"ptw"`. Input and output data types
are derived from the model family. The tunable fields are then filled in by
`create_hyperparams(hyp, λ)` with an all-zero λ.

Throws an `ArgumentError` for an unrecognized `content`.
"""
function create_hyperparams(
    model::String,
    content::String,
    residual_alphas::Vector{String},
    input_alphas::Vector{String},
)
    # "ptw" content is handled as implicit feedback, like "implicit"
    if content == "explicit"
        implicit = false
    elseif content in ("implicit", "ptw")
        implicit = true
    else
        throw(ArgumentError("unknown content type: $content"))
    end
    # NaN fields are placeholders. The λ-based call below overwrites them,
    # except `holdout`/`temporal_holdout` for models that don't tune them.
    hyp = Hyperparams(
        # model
        model = model,
        # data
        content = content,
        implicit = implicit,
        input_data = get_input_data_type(model, implicit),
        input_alphas = input_alphas,
        output_data = get_output_data_type(model),
        # batching
        batch_size = 1024,
        user_sampling_scheme = NaN,
        # optimizer
        learning_rate = NaN,
        optimizer_weight_decay = NaN,
        optimizer = "ADAMW",
        # training
        seed = 20220524,
        num_users = num_users(),
        holdout = NaN,
        temporal_holdout = NaN,
        # loss
        item_weight_decay = NaN,
        residual_alphas = residual_alphas,
        temporal_weight_decay = NaN,
        user_weight_decay = NaN,
    )
    return create_hyperparams(hyp, zeros(Float32, num_tuneable_params(model)))
end;

In [None]:
"""
    nlopt_optimize(lossfn, n; max_evals=100, max_time=86400, ftol_rel=1e-4, xtol_rel=1e-4)

Minimize `lossfn(λ, grad)` over `n` dimensions with NLopt's derivative-free
Nelder–Mead (`:LN_NELDERMEAD`). The search starts at the origin with a unit
initial step.

The search stops after `max_evals` evaluations, after `max_time` (NLopt's
`maxtime`, in seconds), or once the relative tolerances `ftol_rel`/`xtol_rel`
are met. Returns the best point found as a `Vector{Float32}`.
"""
function nlopt_optimize(
    lossfn,
    n::Int;
    max_evals::Int = 100,
    max_time::Int = 86400,
    ftol_rel::Real = 1e-4,
    xtol_rel::Real = 1e-4,
)
    origin = zeros(Float32, n)
    opt = NLopt.Opt(:LN_NELDERMEAD, n)
    # stopping criteria and step size are set before the objective, as before
    opt.initial_step = 1
    opt.maxeval = max_evals
    opt.maxtime = max_time
    opt.ftol_rel = ftol_rel
    opt.xtol_rel = xtol_rel
    opt.min_objective = lossfn
    best_loss, best_point, status = NLopt.optimize(opt, origin)
    evals = opt.numevals
    @info (
        "found minimum $best_loss at point $best_point after $evals function calls " *
        "(ended because $status)"
    )
    return Float32.(best_point)
end;

"""
    optimize_hyperparams(hyp::Hyperparams; max_evals::Int)

Tune `hyp` with `nlopt_optimize` using at most `max_evals` objective
evaluations. Each evaluation builds hyperparameters from λ with
`create_hyperparams`, trains a model with `patience = 0`, logs λ and the loss,
and returns the training loss as the objective value.

Returns the best λ found.
"""
function optimize_hyperparams(hyp::Hyperparams; max_evals::Int)
    # `grad` is unused: the optimizer is derivative-free
    objective = function (λ, grad)
        candidate = create_hyperparams(hyp, convert.(Float32, λ))
        _, _, loss = train_model(
            candidate;
            epochs_per_checkpoint = get_epochs_per_checkpoint(hyp.model),
            patience = 0,
        )
        @info "$λ $loss"
        return loss
    end
    dims = num_tuneable_params(hyp.model)
    return nlopt_optimize(objective, dims; max_evals = max_evals)
end;

In [None]:
# Tune the learning rate of `hyp` through repeated warm-started training runs.
#
# Each iteration multiplies the current learning rate by exp(-1) and continues
# training from the best model found so far (`init_model = m`). A probe is
# accepted if its validation loss equals the running minimum; its model, epoch
# value and hyperparameters then become the new best. Termination is left to
# `early_stopper`/`stop!`, which are defined elsewhere. Presumably they stop
# once the loss fails to improve by a relative 1e-3 (patience 0) or after
# `max_iters` calls; TODO confirm.
#
# Returns `(m, epochs, best_loss, hyp)`: the accepted model, the epoch value
# `train_model` returned with it, the smallest validation loss seen, and the
# hyperparameters carrying the accepted learning rate.
function optimize_learning_rate(hyp::Hyperparams, max_iters::Int)
    # exponentially decay the learning rate whenever we hit a plateau
    learning_rate_decay = exp(-1)
    # Pre-divide so that the first probe, which multiplies by the decay, trains
    # at (up to rounding) the learning rate originally stored in `hyp`.
    hyp = @set hyp.learning_rate /= learning_rate_decay
    m = nothing
    validation_loss = Inf
    epochs = 0
    stopper = early_stopper(patience = 0, min_rel_improvement = 1e-3, max_iters = max_iters)
    losses = []
    verbose = "info"

    while !stop!(stopper, validation_loss)
        # candidate: one decay step below the last accepted learning rate
        probe_hyp = @set hyp.learning_rate *= learning_rate_decay
        probe_m, probe_epochs, validation_loss = train_model(
            probe_hyp;
            max_checkpoints = 1000,
            epochs_per_checkpoint = 1,
            patience = get_epochs_per_checkpoint(probe_hyp.model),
            init_model = m,
            verbose = verbose,
        )
        @info "loss: $validation_loss learning_rate: $(probe_hyp.learning_rate)"
        push!(losses, validation_loss)
        # Accept the probe only if it ties or beats the best loss so far.
        # NOTE(review): on rejection `hyp` is unchanged, so a further iteration
        # would retry the same learning rate from the same model; the stopper
        # (patience = 0) presumably ends the loop first. Confirm.
        if validation_loss == minimum(losses)
            m = probe_m
            epochs = probe_epochs
            hyp = probe_hyp
        end
    end

    # NOTE(review): if `stop!` returned true on its very first call, `losses`
    # would be empty and `minimum(losses)` would throw
    m, epochs, minimum(losses), hyp
end;