# Ranking
* This is trained to learn the partial ordering implied by each user's watches
* Items that are watched are preferred to items that have not been watched
* If two items have been watched, then the impression metadata determines
  which one, if any, is liked more
* It uses a generalized form of the position-aware maximum likelihood estimation loss that handles posets

In [None]:
import CUDA
import Flux
import Flux: Chain, Dense, Dropout, cpu, gpu, relu, sigmoid
import Optimisers
import Optimisers: Adam, OptimiserChain, WeightDecay
import NBInclude: @nbinclude
import NLopt
import Random
import Setfield: @set
import SparseArrays: AbstractSparseArray, sparse, SparseMatrixCSC, SparseVector
import Statistics: mean, std
import StatsBase: sample
@nbinclude("../Alpha.ipynb")
@nbinclude("../Neural/Helpers/GPU.ipynb");
@nbinclude("EnsembleInputs.ipynb");

## Hyperparameters

In [None]:
# Abstract supertype of the feature containers passed to `get_batch`,
# `get_inference_data` and `inference_model`. No concrete subtype is declared
# in this file; presumably the notebooks that include it provide one.
abstract type Features end

In [None]:
# Training configuration shared by every ranking model built on this file.
@with_kw struct Hyperparams
    # names forwarded to `get_features` in `train_model`
    alphas::Vector{String}
    # forwarded to `get_batch`; also divides the user count to size an epoch
    batch_size::Int32
    # not referenced in this file; presumably consumed by `build_model`
    embedding_size::Int32
    # forwarded to `get_batch`; tuned through a sigmoid in `create_hyperparams`
    holdout::Float32
    # coefficient handed to `WeightDecay` in `train_model`
    l2penalty::Float32
    # step size handed to `Adam` in `train_model`
    learning_rate::Float32
    # forwarded to `get_batch`
    list_size::Int32
    # seeds the RNGs at the start of `train_model`
    seed::UInt64
end

"""
    to_dict(x::Hyperparams)

Return a `Dict` that maps every field name of `x` (as a `String`) to the
value stored in that field.
"""
function to_dict(x::Hyperparams)
    names = fieldnames(Hyperparams)
    return Dict(string(name) => getfield(x, name) for name in names)
end

"""
    Base.string(x::Hyperparams)

Render `x` as a multi-line report: a `Hyperparameters:` header followed by one
`name => value` line per field, with the names right-padded to a common width.
"""
function Base.string(x::Hyperparams)
    names = collect(fieldnames(Hyperparams))
    width = maximum(length ∘ string, names)
    io = IOBuffer()
    print(io, "Hyperparameters:\n")
    for name in names
        print(io, rpad(string(name), width), " => ", getfield(x, name), "\n")
    end
    return String(take!(io))
end;

In [None]:
"""
    create_hyperparams(hyp, λ)

Return a copy of `hyp` whose tunable fields are derived from the search
vector `λ` (at least three entries):

* `learning_rate = 3e-4 * 10^(-λ[1])`
* `holdout       = sigmoid(-1 + λ[2])`
* `l2penalty     = 10^(λ[3] - 5)`

All other fields are copied unchanged. `λ = [0, 0, 0]` yields a learning rate
of `3e-4`, a holdout of `sigmoid(-1)` and an L2 penalty of `1e-5`.
"""
function create_hyperparams(hyp, λ)
    # A float base keeps the powers defined for integer-valued λ; an Int base
    # raised to a negative Int exponent throws a DomainError. For Float64 λ
    # the result is unchanged because the Int base was promoted anyway.
    hyp = @set hyp.learning_rate = 3e-4 * 10.0^(-λ[1])
    hyp = @set hyp.holdout = sigmoid(-1 + λ[2])
    hyp = @set hyp.l2penalty = 10.0^(λ[3] - 5)
    hyp
end;

## Data Preprocessing

In [None]:
"""
    get_priority_size()

Number of entries in a priority vector as produced by `get_priority`.
"""
get_priority_size() = 4

"""
    get_priority(df::RatingsDataset, content::String, i::Integer)

Build the `Float16` priority vector for row `i` of `df`.

The layout is `[watched, rating, status, completion]`:

* `"explicit"`: `[1, rating, status, completion]`
* `"implicit"` / `"ptw"`: `[1, NaN, status, completion]`
* `"negative"`: `[0, NaN, NaN, NaN]`

`completion` is clamped to `[0, 1]`. `NaN` marks a component that carries no
information; `compare` treats it as not comparable.

# Throws
`ArgumentError` if `content` is not one of the values listed above.
"""
function get_priority(df::RatingsDataset, content::String, i::Integer)
    if content == "negative"
        priority = Float16[0, NaN, NaN, NaN]
    else
        # todo remove on next recrunch
        completion = min(max(df.completion[i], 0), 1)
        if content == "explicit"
            priority = Float16[1, df.rating[i], df.status[i], completion]
        elseif content in ("implicit", "ptw")
            # tuple membership avoids allocating an array on every row
            priority = Float16[1, NaN, df.status[i], completion]
        else
            throw(ArgumentError("unknown content type: $content"))
        end
    end
    @assert length(priority) == get_priority_size()
    priority
end

"""
    get_priorities(split::String, task::String, content::String)

Load the raw split identified by `(split, task, content)` and return a
`get_priority_size() × n` `Float16` matrix whose column `i` is the priority
vector of row `i`, where `n` is the number of rows in the split. Columns are
filled in parallel.
"""
function get_priorities(split::String, task::String, content::String)
    @info "getting $split $task $content priorities"
    df = get_raw_split(split, task, content)
    ncols = length(df.user)
    priorities = Matrix{Float16}(undef, get_priority_size(), ncols)
    # every iteration writes its own column, so the iterations are independent
    @tprogress Threads.@threads for col = 1:ncols
        priorities[:, col] = get_priority(df, content, col)
    end
    return priorities
end

"""
    compare(x::Number, y::Number)

Three-way comparison with an "incomparable" outcome.

Returns `NaN` if either argument is `NaN`, `0` if the arguments differ by less
than the machine epsilon of `x`'s floating-point type, `1` if `x > y` and `-1`
otherwise.
"""
function compare(x::Number, y::Number)
    if isnan(x) || isnan(y)
        return NaN
    end
    # `eps` is only defined for floating-point types; going through `float`
    # keeps integer (and other non-float) arguments from raising a MethodError
    # while leaving the tolerance for float inputs unchanged.
    tol = eps(float(typeof(x)))
    if abs(x - y) < tol
        return 0
    elseif x > y
        return 1
    else
        return -1
    end
end

"""
    compare(x::Vector, y::Vector)

Lexicographic comparison of two equally long vectors using the element-wise
`compare`.

The first position whose comparison is decisive (`1` or `-1`) determines the
result. Positions that compare as `NaN` are skipped, but if no decisive
position exists and at least one position was `NaN` the result is `NaN`;
otherwise it is `0`.
"""
function compare(x::Vector, y::Vector)
    @assert length(x) == length(y)
    outcome = 0
    for (a, b) in zip(x, y)
        r = compare(a, b)
        if isnan(r)
            # remember that an incomparable position was seen, keep scanning
            outcome = NaN
            continue
        end
        r == 0 || return r
    end
    return outcome
end;

In [None]:
"""
    get_user_to_indexes(split_content_pairs::Vector, task::String, include::Function)

Map every user to the global row indexes of their entries.

The raw splits named by `split_content_pairs` (a vector of `(split, content)`
pairs) are treated as if concatenated in the given order: row `i` of a split
gets the global index `i + index_base`, where `index_base` is the total number
of rows in all preceding splits. Rows of a split are recorded only when
`include(split, content)` is true, but every split advances `index_base`.

Returns a `Dict{Int32,Vector{Int32}}`; users without any recorded row are
removed. Within a split the indexes of a user are stored in ascending order.
"""
function get_user_to_indexes(split_content_pairs::Vector, task::String, include::Function)
    u_to_xs = Dict{Int32,Vector{Int32}}(u => Int32[] for u = 1:num_users())
    index_base::Int32 = 0
    nshards = Threads.nthreads()
    for (split, content) in split_content_pairs
        df = get_raw_split(split, task, content)
        if include(split, content)
            # Pass 1: cut the rows into contiguous chunks and bucket every row
            # by user shard. Buffers are addressed by chunk number rather than
            # by `Threads.threadid()`, which is neither stable across a task's
            # lifetime nor guaranteed to be <= `Threads.nthreads()`.
            # Progress is reported once per chunk.
            nrows = length(df.user)
            chunks = collect(Iterators.partition(1:nrows, max(1, cld(nrows, nshards))))
            buckets = [[Int[] for _ = 1:nshards] for _ = 1:length(chunks)]
            @tprogress Threads.@threads for c = 1:length(chunks)
                for i in chunks[c]
                    push!(buckets[c][df.user[i]%nshards+1], i)
                end
            end
            # Pass 2: each shard owns a disjoint set of users, so the per-user
            # vectors can be appended to in parallel without locking.
            Threads.@threads for s = 1:nshards
                for bucket in buckets
                    for i in bucket[s]
                        push!(u_to_xs[df.user[i]], i + index_base)
                    end
                end
            end
        end
        index_base += length(df.item)
    end

    # prune unused users
    filter!(kv -> !isempty(kv.second), u_to_xs)
    u_to_xs
end;

In [None]:
"""
    get_preferences(V::Vector{Int32}, E::Function)

Build the `length(V) × length(V)` `Int32` preference matrix `P` of the items
in `V`, where `P[i, j] == 1` means that item `i` is preferred to item `j`.

`E(V[i], V[j])` is evaluated once for every pair `i < j` and interpreted as
follows: `1` sets `P[i, j]`, `-1` sets `P[j, i]`, `0` (a tie) sets both, and
any other value (e.g. `NaN`) leaves the pair unrelated. The diagonal is
always `1`.
"""
function get_preferences(V::Vector{Int32}, E::Function)
    n = length(V)
    P = zeros(Int32, n, n)
    for a in 1:n
        P[a, a] = 1
        for b in (a + 1):n
            order = E(V[a], V[b])
            # a tie (`0`) expands into a preference in both directions
            if order == 1 || order == 0
                P[a, b] = 1
            end
            if order == -1 || order == 0
                P[b, a] = 1
            end
        end
    end
    return P
end;

## Loss Functions

In [None]:
"""
    position_aware_list_mle_loss(m, x, P)

Position-aware ListMLE loss, modified to handle non-comparable items.

`m(x)` is flattened to an `N × batch_size` score matrix. `P` is the preference
relation, indexed as `P[:, :, b]` for list `b`, where `P[i, j, b] = 1` iff item
`i` is preferred to item `j`.

For every list and every item `i`:

* the unweighted term is `-p[i] + log(Σ_j P[i, j] * exp(p[j]) + ϵ)`, with `p`
  the scores shifted by their per-list maximum for numerical stability;
* the weight is `(2^k - 1) / (2^N - 1)`, where `k` is the row sum of `P` for
  item `i`.

Returns the weighted sum averaged over the batch.
"""
function position_aware_list_mle_loss(m, x, P)
    p = Flux.flatten(m(x))
    p = p .- maximum(p; dims = 1)
    q = exp.(p)

    ϵ = Float32(eps(Float64))
    N, batch_size = size(p)
    r = ones(Float32, N) |> device
    # Evaluate 2^N - 1 in Float32: the Int64 expression `2^N - 1` wraps around
    # silently once N >= 63 (N = 64 gives -1), flipping or corrupting every
    # weight. For N <= 62 both forms round to the same Float32 value.
    normalizer = 2.0f0^N - 1

    total = 0.0f0
    for i = 1:batch_size
        w = ((2.0f0 .^ (P[:, :, i] * r) .- 1) ./ normalizer)
        unweighted_loss = -p[:, i] + log.(P[:, :, i] * q[:, i] .+ ϵ)
        total += sum(w .* unweighted_loss)
    end
    total / batch_size
end;

In [None]:
"""
    average_loss(m, f::Features, hyp::Hyperparams, iters::Integer)

Mean of `position_aware_list_mle_loss` over `iters` batches drawn with
`get_batch(f, false, ...)`. Each batch is released with `device_free!` after
it has been scored.
"""
function average_loss(m, f::Features, hyp::Hyperparams, iters::Integer)
    (; list_size, batch_size, holdout) = hyp
    accumulated = 0.0
    @showprogress for _ = 1:iters
        batch = get_batch(f, false, list_size, batch_size, holdout)
        accumulated += position_aware_list_mle_loss(m, batch...)
        device_free!(batch)
    end
    return accumulated / iters
end;

## Training

In [None]:
"""
    train_epoch!(m, opt, f::Features, hyp::Hyperparams, epoch_size::Integer)

Run `epoch_size` optimisation steps on `m`. Every step draws a batch with
`get_batch(f, true, ...)`, differentiates `position_aware_list_mle_loss` with
respect to the model, releases the batch and applies the gradient through
`opt`.
"""
function train_epoch!(m, opt, f::Features, hyp::Hyperparams, epoch_size::Integer)
    (; list_size, batch_size, holdout) = hyp
    @showprogress for _ = 1:epoch_size
        batch = get_batch(f, true, list_size, batch_size, holdout)
        grads = Flux.gradient(model -> position_aware_list_mle_loss(model, batch...), m)
        # the batch is no longer needed once the gradient has been computed
        device_free!(batch)
        Flux.update!(opt, m, first(grads))
    end
end;

In [None]:
"""
    train_model(hyp, task, outdir; max_checkpoints, epochs_per_checkpoint,
                patience, verbose = true)

Train a model with the given hyperparameters and return its validation loss.

The RNGs are seeded from `hyp.seed`, the model is built with `build_model` and
optimised with Adam plus weight decay. Training proceeds in checkpoints of
`epochs_per_checkpoint` epochs; after each one the loss is measured with
`average_loss`, and whenever it equals the best loss seen so far the model
(moved to the CPU), `hyp` and the inference data are written to `outdir`.
The loop ends when the early stopper built from `max_checkpoints` and
`patience` says so.

# Returns
`(best_model, best_loss, f)`: the best model on the CPU, the smallest
checkpoint loss (`Inf` if no checkpoint was run) and the features returned by
`get_features`.
"""
function train_model(
    hyp::Hyperparams,
    task::String,
    outdir::String;
    max_checkpoints::Integer,
    epochs_per_checkpoint::Integer,
    patience::Integer,
    verbose::Bool = true,
)
    function loginfo(x)
        if verbose
            @info x
        end
    end

    loginfo("Initializing model")
    # derive every RNG seed from a single generator so runs are reproducible
    rng = Random.Xoshiro(hyp.seed)
    Random.seed!(rand(rng, UInt64))
    if CUDA.functional()
        Random.seed!(CUDA.default_rng(), rand(rng, UInt64))
        Random.seed!(CUDA.CURAND.default_rng(), rand(rng, UInt64))
    end
    m = build_model(hyp) |> device
    opt = Optimisers.setup(
        OptimiserChain(Adam(hyp.learning_rate, (0.9f0, 0.999f0)), WeightDecay(hyp.l2penalty)),
        m,
    )
    best_model = m |> cpu
    stopper = early_stopper(
        max_iters = max_checkpoints,
        patience = patience,
        min_rel_improvement = 1e-3,
    )

    loginfo("Getting data")
    f = get_features(hyp.alphas, task)
    # At least one batch per epoch: with fewer users than half a batch the
    # rounded quotient is 0, which would skip training entirely and turn the
    # validation loss into 0/0 = NaN.
    epoch_size = max(1, round(Int, length(f.user_to_indexes) / hyp.batch_size))
    loginfo("Training model...")
    curloss = Inf
    losses = Float64[]
    while !stop!(stopper, curloss)
        for _ = 1:epochs_per_checkpoint
            train_epoch!(m, opt, f, hyp, epoch_size)
        end
        curloss = average_loss(m, f, hyp, epoch_size)
        push!(losses, curloss)
        if curloss == minimum(losses)
            # new best checkpoint: snapshot it and persist it right away
            best_model = m |> cpu
            write_params(
                Dict(
                    "m" => best_model,
                    "hyp" => hyp,
                    "inference_data" => get_inference_data(f),
                ),
                outdir,
            )
        end
        loginfo("loss $curloss")
    end

    # `minimum` throws on an empty collection; report `Inf` if nothing ran
    best_loss = isempty(losses) ? Inf : minimum(losses)
    best_model, best_loss, f
end;

## Save Model

In [None]:
"""
    train_alpha(hyp, task::String, outdir::String; tune_hyperparams::Bool = false)

End-to-end entry point: direct logging to `outdir`, optionally tune the
hyperparameters with `optimize_hyperparams`, train a model with `train_model`
and write the model, the hyperparameters and the inference data to `outdir`.
"""
function train_alpha(hyp, task::String, outdir::String; tune_hyperparams::Bool = false)
    set_logging_outdir(outdir)

    if tune_hyperparams
        @info "Optimizing hyperparameters..."
        hyp = create_hyperparams(hyp, optimize_hyperparams(hyp; max_evals = 10))
    end

    @info "Training model..."
    checkpointing = (max_checkpoints = 50, epochs_per_checkpoint = 1, patience = 2)
    model, validation_loss, features = train_model(hyp, task, outdir; checkpointing...)
    @info "Trained model loss: $validation_loss"

    @info "Writing alpha..."
    params = Dict(
        "m" => model,
        "hyp" => hyp,
        "inference_data" => get_inference_data(features),
    )
    write_params(params, outdir)
    # write_alpha(
    #     (split::String, task::String, content::String; raw_splits::Bool = true) ->
    #         inference_model(model, hyp, features, split, task, content; raw_splits = raw_splits),
    #     outdir;
    #     by_split = true,
    #     log = false,
    # )
    @info "Wrote alpha!"
end;

## Methods to override in subclasses

In [None]:
"""
    get_features(alphas::Vector{String}, task::String)
    get_features(alphas::Vector{String})

Placeholder for the feature loader; meant to be overridden by the code that
includes this file. `train_model` calls the two-argument form, so that form is
declared here as well: a missing override then fails with "not implemented"
instead of a `MethodError`. The one-argument form is kept for compatibility.
"""
function get_features(alphas::Vector{String}, task::String)
    Base.error("not implemented")
end

function get_features(alphas::Vector{String})
    Base.error("not implemented")
end

"""
    build_model(hyp::Hyperparams)

Placeholder for the model constructor; meant to be overridden. Always raises
an `ErrorException` with the message "not implemented".
"""
build_model(hyp::Hyperparams) = throw(ErrorException("not implemented"))

"""
    get_inference_data(f::Features)

Placeholder for extracting the data that is saved next to the model; meant to
be overridden. Always raises an `ErrorException` with the message
"not implemented".
"""
get_inference_data(f::Features) = throw(ErrorException("not implemented"))

"""
    get_batch(f::Features, training::Bool, list_size::Integer,
              batch_size::Integer, holdout::Float32)

Placeholder for the batch sampler; meant to be overridden. Callers in this
file splat the result into `position_aware_list_mle_loss(m, batch...)` and
pass it to `device_free!`. Always raises an `ErrorException` with the message
"not implemented".
"""
function get_batch(
    f::Features,
    training::Bool,
    list_size::Integer,
    batch_size::Integer,
    holdout::Float32,
)
    throw(ErrorException("not implemented"))
end

"""
    inference_model(hyp::Hyperparams, f::Features, split::String, task::String,
                    content::String; raw_splits::Bool)

Placeholder for the inference routine; meant to be overridden. Always raises
an `ErrorException` with the message "not implemented".
"""
function inference_model(
    hyp::Hyperparams,
    f::Features,
    split::String,
    task::String,
    content::String;
    raw_splits::Bool,
)
    throw(ErrorException("not implemented"))
end;