# Bayesian Personalized Ranking
* Creates a model for pairwise classification

In [1]:
using Flux

import CUDA
import NLopt
import Random
import NBInclude: @nbinclude
import Setfield: @set
@nbinclude("BPRBase.ipynb")
@nbinclude("../Neural/GPU.ipynb")

## Hyperparameters

In [3]:
# Settings for a single training run; consumed by `train_model`.
@with_kw struct Hyperparams
    # number of samples requested from `get_batch` per batch; also used to
    # derive the number of batches per epoch
    batch_size::Int
    # names of the input features; passed to `get_data`, and the model input
    # width is twice the number of features
    features::Vector{String}
    # weight-decay coefficient handed to the ADAMW optimizer
    l2penalty::Float32
    # step size handed to the ADAMW optimizer
    learning_rate::Float32
    # RNG seed applied via `Random.seed!` before the model is built
    seed::UInt64
end

# Converts the hyperparameters into a dictionary that maps each field name
# (as a String) to the value stored in that field.
function to_dict(x::Hyperparams)
    field_names = fieldnames(Hyperparams)
    Dict(string(name) => getfield(x, name) for name in field_names)
end

# Renders the hyperparameters as a "Hyperparameters:" header followed by one
# `name => value` line per field, with the names right-padded to equal width.
function Base.string(x::Hyperparams)
    field_names = collect(fieldnames(Hyperparams))
    width = maximum(length(string(name)) for name in field_names)
    buf = IOBuffer()
    print(buf, "Hyperparameters:\n")
    for name in field_names
        print(buf, rpad(string(name), width), " => ", getfield(x, name), "\n")
    end
    String(take!(buf))
end;

## Models

In [4]:
# Builds the scoring network: a 256-unit relu hidden layer followed by a single
# linear output. The input width is two slots per entry of `features`.
function build_model(features)
    input_width = 2 * length(features)
    hidden_width = 256
    hidden = Dense(input_width => hidden_width, relu)
    output = Dense(hidden_width => 1)
    Chain(hidden, output)
end;

## Loss Functions

In [8]:
# Binary cross-entropy between the model's raw outputs (logits) on `x` and the
# targets `y`.
function model_loss(m, x, y)
    logits = m(x)
    Flux.logitbinarycrossentropy(logits, y)
end

# Average loss of `m` over `iters` batches pulled from `batches`. Only the first
# element of each item taken from the channel is evaluated; it is splatted into
# `model_loss` as the (x, y) arguments.
function split_loss(m, iters, batches::Channel)
    total = foldl(1:iters; init = 0.0) do acc, _
        batch = take!(batches)
        acc + model_loss(m, batch[1]...)
    end
    total / iters
end;

## Training

In [10]:
# Producer loop: keeps pushing freshly generated batches into `c` until the
# channel is closed by the consumer.
#
# A closed channel surfaces as an `InvalidStateException` from `put!` and is the
# normal shutdown signal. Any other exception is logged and rethrown so that a
# failing `get_batch` terminates this task instead of being silently retried in
# a tight loop forever.
function generate_batches(user_priorities, user_features, batch_size, c::Channel)
    # checking `isopen` avoids building a batch nobody will consume
    while isopen(c)
        try
            put!(c, get_batch(user_priorities, user_features, batch_size))
        catch e
            e isa InvalidStateException && break
            @error "Batch generation failed" exception = (e, catch_backtrace())
            rethrow()
        end
    end
end;

In [11]:
# Trains a model with the given hyperparams and returns `(best_model, best_loss)`:
# a CPU copy of the model at the checkpoint with the lowest test loss, and that loss.
#
# Keyword arguments:
#   max_checkpoints       - forwarded to `early_stopper` as `max_iters`
#   epochs_per_checkpoint - training epochs run between two test-loss evaluations
#   patience              - forwarded to `early_stopper`
#   verbose               - emit progress messages via `@info`
#
# Batches are produced by background tasks feeding two channels; both channels
# are closed on exit (including on error) so that the producers terminate.
function train_model(
    hyp;
    max_checkpoints = 100,
    epochs_per_checkpoint = 10,
    patience = 0,
    verbose = false,
)
    if verbose
        @info "Getting data"
    end
    opt = ADAMW(hyp.learning_rate, (0.9, 0.999), hyp.l2penalty)
    Random.seed!(hyp.seed)
    m = build_model(hyp.features) |> device
    # `cpu` may hand back the very same arrays when the model already lives on
    # the CPU, so take an explicit copy to keep the snapshot independent of
    # further training.
    best_model = deepcopy(m |> cpu)
    ps = Flux.params(m)
    stopper = early_stopper(max_iters = max_checkpoints, patience = patience)
    training, test, user_features = get_data(hyp.features)
    batchloss(x, y) = model_loss(m, x, y)
    epoch_size = round(Int, num_users() / hyp.batch_size)

    if verbose
        @info "Setting up batches"
    end
    training_batches = Channel(64)
    test_batches = Channel(64)
    losses = Float64[]
    try
        # one training and one test producer per iteration; always at least one
        num_producers = max(Threads.nthreads() ÷ 2 - 1, 1)
        for _ = 1:num_producers
            Threads.@spawn generate_batches(
                training,
                user_features,
                hyp.batch_size,
                training_batches,
            )
            Threads.@spawn generate_batches(
                test,
                user_features,
                hyp.batch_size,
                test_batches,
            )
        end

        if verbose
            @info "Training..."
        end
        loss = Inf
        while (!stop!(stopper, loss))
            for i = 1:epochs_per_checkpoint
                for _ = 1:epoch_size
                    Flux.train!(batchloss, ps, take!(training_batches), opt)
                end
            end

            loss = split_loss(m, epoch_size, test_batches)
            push!(losses, loss)
            if loss == minimum(losses)
                best_model = deepcopy(m |> cpu)
            end
            if verbose
                @info "loss $loss"
            end
        end
    finally
        # closing the channels is what makes the producer tasks exit; do it even
        # when training throws so they are not left blocked on `put!`
        close(training_batches)
        close(test_batches)
    end
    best_model, minimum(losses)
end;

## Hyperparameter Tuning

In [12]:
# Returns a copy of `hyp` whose learning rate and l2 penalty are derived from the
# log-scale search point `λ`: learning_rate = 10^(λ[1] - 3) and
# l2penalty = 10^(λ[2] - 5).
function create_hyperparams(hyp, λ)
    learning_rate = 10^(λ[1] - 3)
    with_lr = @set hyp.learning_rate = learning_rate
    l2penalty = 10^(λ[2] - 5)
    tuned = @set with_lr.l2penalty = l2penalty
    tuned
end;

In [13]:
# Searches the two-dimensional log-scale point λ (see `create_hyperparams`) with
# NLopt's Nelder-Mead, starting from zeros, using the validation loss returned by
# `train_model` as the objective. At most `max_evals` objective evaluations are
# made. Returns the best λ found; no model is saved here.
function optimize_hyperparams(hyp; max_evals)
    # `grad` is part of NLopt's objective signature; it is unused because the
    # chosen algorithm is derivative-free.
    function nlopt_loss(λ, grad)
        # nlopt internally converts to float64 because it calls a c library
        λ = convert.(Float32, λ)
        _, loss = train_model(create_hyperparams(hyp, λ))
        @info "$λ $loss"
        loss
    end
    opt = NLopt.Opt(:LN_NELDERMEAD, 2)
    opt.initial_step = 1
    opt.maxeval = max_evals
    opt.min_objective = nlopt_loss
    minf, λ, ret = NLopt.optimize(opt, zeros(Float32, 2))
    numevals = opt.numevals

    @info (
        "found minimum $minf at point $λ after $numevals function calls " *
        "(ended because $ret)"
    )
    λ
end;

## Save Model

In [14]:
# End-to-end pipeline: optionally tunes the hyperparameters, trains the final
# model, and writes the model, the search point λ and the resolved
# hyperparameters to `outdir`. When `tune_hyperparams` is false, λ is all zeros.
function train_alpha(hyp, outdir; tune_hyperparams = true)
    set_logging_outdir(outdir)

    λ = if tune_hyperparams
        @info "Optimizing hyperparameters..."
        optimize_hyperparams(hyp; max_evals = 10)
    else
        zeros(2)
    end
    tuned_hyp = create_hyperparams(hyp, λ)

    @info "Training model..."
    m, validation_loss = train_model(
        tuned_hyp;
        max_checkpoints = 1000,
        epochs_per_checkpoint = 1,
        patience = 10,
    )
    @info "Trained model loss: $validation_loss"

    @info "Writing alpha..."
    params = Dict("m" => m, "λ" => λ, "hyp" => tuned_hyp)
    write_params(params, outdir)
    @info "Wrote alpha!"
end;

In [15]:
# Names of the alphas whose outputs are used as input features for the model.
restricted_alphas = [
    "Explicit",
    "LinearImplicit",
    "ErrorExplicit",
    "ErrorImplicit",
    "ExplicitUserItemBiases",
    "NeuralImplicitUserItemBiases",
];

In [16]:
# Tune, train and write the "BPR" alpha on the restricted feature set.
# `l2penalty` and `learning_rate` are NaN placeholders: `train_alpha` always
# passes the hyperparameters through `create_hyperparams`, which overwrites both
# fields from the search point λ.
train_alpha(
    Hyperparams(
        batch_size = 1024,
        features = restricted_alphas,
        l2penalty = NaN,
        learning_rate = NaN,
        seed = 20220609,
    ),
    "BPR";
    tune_hyperparams = true,
)

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220708 22:25:32 Optimizing hyperparameters...
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:23[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:11[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:01[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:04[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:43[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:24[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:02:30[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220708 22:36:11 Float32[0.0, 0.0] 0.11864882431325711
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m