## Retrain User Embeddings
* To minimize training/serving skew, we retrain the model the same
  way we will retrain it at serving time
* This means reinitializing the user embeddings, freezing all other layers,
  and fine-tuning only the user embeddings
* During serving, we will determine a new user's embedding
  by training with the same hyperparameters and number of epochs

In [28]:
"""
    build_retrain_model(hyp, m)

Build a model for retraining user embeddings: a freshly initialized user
embedding layer followed by the remaining layers of `m`, moved to the CPU.

Side effects: rebinds the global `G` to `hyp` and reseeds the global RNG with
`hyp.seed`, so the fresh embedding initialization is reproducible.

# Arguments
- `hyp`: hyperparameters; this function reads `seed`, `model` and `num_users`.
- `m`: a trained model whose first layer is replaced; only `m[2:end]` is reused
  (presumably a `Flux.Chain` whose first layer is the user embedding — TODO confirm).

# Throws
- `ArgumentError` if `hyp.model` is neither `"user_item_biases"` nor a
  `"matrix_factorization..."` variant.
"""
function build_retrain_model(hyp, m)
    # NOTE(review): other notebook cells presumably read the global `G`; keep it in sync.
    global G = hyp
    Random.seed!(hyp.seed)
    if hyp.model == "user_item_biases"
        # Bias-only model: one zero-initialized scalar per user.
        embedding_size = 1
        initfn = (x...) -> zeros(Float32, x...)
    elseif startswith(hyp.model, "matrix_factorization")
        # The embedding dimension is the final "_"-separated token of the model name.
        embedding_size = parse(Int, split(hyp.model, "_")[end])
        initfn = Flux.glorot_uniform
    else
        throw(ArgumentError("unsupported model type: $(hyp.model)"))
    end
    m = m |> cpu
    user_embedding = Flux.Embedding(hyp.num_users => embedding_size; init = initfn)
    return Chain(user_embedding, m[2:end]...)
end;

In [29]:
"""
    optimize_retraining_hyperparams(hyp, m, max_iters)

Search for the learning rate and regularization strength used when retraining
the user embeddings on top of the frozen layers of `m`.

Each probe maps a 2-element search point `λ` to
`learning_rate = 0.001 * 10^λ[1]` and `regularization_params = [1e-5 * 10^λ[2], 0]`.
It then rebuilds the retrain model, fine-tunes only its first layer, and reports
the resulting loss.

# Arguments
- `hyp`: base hyperparameters; must have exactly two `regularization_params`.
- `m`: the trained model whose non-embedding layers are reused.
- `max_iters`: maximum number of objective evaluations passed to `nlopt_optimize`.

# Returns
The result of `nlopt_optimize`. Callers index it as the best `λ`
(`λ[1]`, `λ[2]`).

# Throws
- `ArgumentError` if `hyp.regularization_params` does not have length 2.
"""
function optimize_retraining_hyperparams(hyp, m, max_iters)
    # Validate once, before any (expensive) probe runs.
    nreg = length(hyp.regularization_params)
    nreg == 2 || throw(ArgumentError("expected 2 regularization_params, got $nreg"))

    function nlopt_loss(λ, grad)
        # `grad` is never filled in; the objective is treated as derivative-free
        # (presumably nlopt_optimize uses a gradient-free algorithm — TODO confirm).
        probe_hyp = @set hyp.learning_rate = 0.001 * 10^λ[1]
        probe_hyp = @set probe_hyp.regularization_params = [1e-5 * 10^λ[2], 0]
        probe_m = build_retrain_model(probe_hyp, m)
        # Each probe trains for one checkpoint interval with no early-stopping patience.
        _, _, loss = train_model(
            probe_hyp;
            epochs_per_checkpoint = get_epochs_per_checkpoint(hyp.model),
            init_model = probe_m,
            fine_tune_layers = 1,
            patience = 0,
        )
        @info "$λ $loss"
        return loss
    end
    return nlopt_optimize(nlopt_loss, 2; max_evals = max_iters, max_time = 3600)
end;

In [30]:
"""
    retrain_user_embeddings(hyp, m, max_iters)

Retrain the user embeddings of `m` from a fresh initialization while the other
layers stay frozen, mirroring how a new user's embedding is fit during serving.

The learning rate and regularization are first tuned on a copy of `hyp` whose
`num_users` is scaled by the model's subsampling factor, using at most
`max_iters` probes. A full retraining with the tuned values follows.

# Returns
A tuple of the values returned by `train_model`, followed by the hyperparameters
used for the final retraining.
"""
function retrain_user_embeddings(hyp, m, max_iters)
    subsampled_users = round(Int, num_users() * get_subsampling_factor(hyp.model))
    search_hyp = @set hyp.num_users = subsampled_users
    best_λ = optimize_retraining_hyperparams(search_hyp, m, max_iters)

    tuned_hyp = @set hyp.learning_rate = 0.001 * 10^best_λ[1]
    tuned_hyp = @set tuned_hyp.regularization_params = [1e-5 * 10^best_λ[2], 0]

    # Final run: checkpoint every epoch, with patience equal to the probe length.
    results = train_model(
        tuned_hyp;
        init_model = build_retrain_model(tuned_hyp, m),
        fine_tune_layers = 1,
        epochs_per_checkpoint = 1,
        patience = get_epochs_per_checkpoint(hyp.model),
    )
    return (results..., tuned_hyp)
end;