# Ranking
* This is trained to learn the partial ordering implied by each user's watches
* Items that are watched are preferred to items that have not been watched
* If two items have been watched, then the impression metadata determines
  which one, if any, is liked more
* It uses a generalized form of the position-aware maximum likelihood estimation loss that handles posets

In [None]:
import CUDA
import Flux
import Flux: Chain, Dense, Dropout, cpu, gpu, relu, sigmoid
import Optimisers
import Optimisers: Adam, OptimiserChain, WeightDecay
import NBInclude: @nbinclude
import NLopt
import Random
import Setfield: @set
import SparseArrays: AbstractSparseArray, sparse, SparseMatrixCSC, SparseVector
import Statistics: mean, std
import StatsBase: sample
@nbinclude("../Alpha.ipynb")
@nbinclude("../Neural/Helpers/GPU.ipynb");
@nbinclude("EnsembleInputs.ipynb");

## Hyperparameters

In [None]:
# Abstract supertype of the feature containers that the batching, loss-averaging
# and training functions below accept. No concrete subtype is defined in this section.
abstract type Features end

In [None]:
# Hyperparameters of the ranking model and of its training loop.
@with_kw struct Hyperparams
    # passed to `get_features` to select the model inputs
    alphas::Vector{String}
    # number of lists per batch (passed to `get_batch`); also sets the epoch size
    batch_size::Int32
    # not read in this section — presumably consumed by `build_model`; TODO confirm
    embedding_size::Int32
    # coefficient given to the `WeightDecay` optimiser rule
    l2penalty::Float32
    # learning rate given to `Adam`
    learning_rate::Float32
    # passed to `get_batch` together with `batch_size`
    list_size::Int32
    # seeds the `Xoshiro` generator that drives model initialisation and batching
    seed::UInt64
    # base of the geometric position weights in the loss; `verify` requires (0, 1]
    ranking_weight::Float32
    # edge weight of each priority row in the partial ordering; `verify` requires [0, 1]
    interaction_weights::Vector{Float32}
end

# Sanity-checks the loss-related hyperparameters; an `AssertionError` is raised when
# a check fails.
function verify(x::Hyperparams)
    # ranking_weight must lie in the half-open interval (0, 1]
    @assert (0 < x.ranking_weight) && (x.ranking_weight <= 1)
    # every interaction weight must lie in the closed interval [0, 1]
    @assert all(x.interaction_weights .>= 0) && all(x.interaction_weights .<= 1)
end

# Converts the hyperparameters into a dictionary keyed by the field names (as strings).
function to_dict(x::Hyperparams)
    names = fieldnames(Hyperparams)
    Dict(string(name) => getfield(x, name) for name in names)
end

# Renders the hyperparameters as a multi-line string: a header line followed by one
# "name => value" line per field, with the names right-padded to a common width.
function Base.string(x::Hyperparams)
    names = collect(fieldnames(Hyperparams))
    width = maximum(length(string(name)) for name in names)
    lines = ["$(rpad(string(name), width)) => $(getfield(x, name))\n" for name in names]
    "Hyperparameters:\n" * join(lines)
end;

In [None]:
# Returns a copy of `hyp` whose learning rate and l2 penalty are derived from the
# search vector `λ`:
#   learning_rate = 3e-4 * 10^(-λ[1])
#   l2penalty     = learning_rate * 10^(λ[2] - 2)
function create_hyperparams(hyp, λ)
    learning_rate = 3e-4 * 10^(-λ[1])
    hyp = @set hyp.learning_rate = learning_rate
    # read the learning rate back from the updated struct so that the penalty is
    # computed from the value as it is actually stored
    l2penalty = hyp.learning_rate * 10^(λ[2] - 2)
    hyp = @set hyp.l2penalty = l2penalty
    hyp
end;

## Data Preprocessing

In [None]:
# Number of entries in the priority vector that `get_priority` builds for an interaction.
get_priority_size() = 4

# Builds the priority vector of the `i`-th interaction of `df`.
# The layout is [watched flag, rating, status, completion]; NaN marks an entry that is
# not set for the given kind of `content`. An unknown `content` fails an assertion.
function get_priority(df::RatingsDataset, content::String, i::Integer)
    priority = if content == "explicit" || content == "ptw"
        Float16[1, df.rating[i], df.status[i], df.completion[i]]
    elseif content == "implicit"
        Float16[1, NaN, df.status[i], df.completion[i]]
    elseif content == "negative"
        Float16[0, NaN, NaN, NaN]
    else
        @assert false
    end
    @assert length(priority) == get_priority_size()
    priority
end

# Loads the raw split and returns a `get_priority_size() × num_rows` Float16 matrix
# whose i-th column is the priority vector of the i-th interaction.
function get_priorities(split::String, task::String, content::String)
    @info "getting $split $task $content priorities"
    df = get_raw_split(split, task, content)
    # ratings equal to zero are overwritten with NaN (presumably zero means "unrated")
    df.rating[df.rating.==0] .= NaN
    num_rows = length(df.user)
    priorities = Matrix{Float16}(undef, get_priority_size(), num_rows)
    # every iteration writes its own column, so the threads never touch the same memory
    @tprogress Threads.@threads for i = 1:num_rows
        priorities[:, i] = get_priority(df, content, i)
    end
    priorities
end;

In [None]:
# Maps each user id to the global row indexes of that user's interactions.
#
# Arguments
# - `split_content_pairs`: `(split, content)` pairs; their raw splits are treated as if
#   they were concatenated in this order, which defines the global row index.
# - `task`: forwarded to `get_raw_split`.
# - `include`: predicate `(split, content) -> Bool`; rows of a pair for which it returns
#   false are not indexed, but they still advance the global row offset.
#
# Returns a `Dict{Int32,Vector{Int32}}` containing only users with at least one index.
function get_user_to_indexes(split_content_pairs::Vector, task::String, include::Function)
    u_to_xs = Dict{Int32,Vector{Int32}}(u => Int32[] for u = 1:num_users())
    index_base::Int32 = 0
    nshards = Threads.nthreads()
    for (split, content) in split_content_pairs
        df = get_raw_split(split, task, content)
        if include(split, content)
            # snapshot the offset so the threaded closures capture a variable that is
            # never reassigned
            offset = index_base
            # multithread by sharding on userid: idxs[thread][shard] holds the rows seen
            # by `thread` whose user falls in `shard`; the buffers are concretely typed
            # so that the row indexes are not boxed
            idxs = [[Int[] for _ = 1:nshards] for _ = 1:nshards]
            # static scheduling is relied upon so that threadid() identifies a buffer
            # owned by exactly one task
            @tprogress Threads.@threads :static for i = 1:length(df.user)
                push!(idxs[Threads.threadid()][df.user[i]%nshards+1], i)
            end
            # each shard is drained by a single iteration, so the vector of any given
            # user is only ever appended to by one task
            Threads.@threads :static for t = 1:nshards
                for idx in idxs
                    for i in idx[t]
                        push!(u_to_xs[df.user[i]], i + offset)
                    end
                end
            end
        end
        index_base += length(df.item)
    end

    # prune unused users
    filter!(kv -> !isempty(kv.second), u_to_xs)
    u_to_xs
end;

## Loss Functions

In [None]:
# Returns the weighted adjacency matrix of the partial ordering graph of list `i`.
# Entry (r, c) is non-zero when item r is preferred to item c. The priority rows are
# consulted in order: a pair that an earlier row already ordered (in either direction)
# is left untouched, and a pair first ordered by row j gets weight
# `interaction_weights[j]`. Diagonal entries are always zero.
function get_partial_ordering(P, i, interaction_weights)
    first_level = P[1, :, i]
    adjacency = (first_level .> first_level') .* interaction_weights[1]
    for level = 2:size(P, 1)
        values = P[level, :, i]
        # pairs that no earlier priority row has ordered in either direction
        undecided = (adjacency .== 0) .&& (adjacency' .== 0)
        newly_ordered = undecided .&& (values .> values')
        adjacency = adjacency + newly_ordered .* interaction_weights[level]
    end
    adjacency
end

# Computes the per-item weights of the list loss from the adjacency matrix `P`
# (P[r, c] > 0 means item r is preferred to item c).
#
# Only items that are preferred to at least one other item receive a non-zero weight.
# With `ranking_weight == 1` those items are weighted uniformly; otherwise an item is
# weighted by `ranking_weight ^ rank`, where `rank` is the number of items preferred to
# it plus half the number of items it is not comparable with.
#
# The weights are normalised to sum to one so that each user is equally weighted in
# the loss. When the list contains no comparable pair, all weights are zero and they
# are returned as is: dividing by the zero sum would produce NaN weights and turn the
# loss of the whole batch into NaN.
function get_position_weights(P, ranking_weight)
    P = convert.(Float32, (P .> 0))
    N = size(P, 1)
    better_than = sum(P, dims = 2)
    if ranking_weight == 1
        w = convert.(Float32, (better_than .> 0))
    else
        worse_than = sum(P, dims = 1)'
        remaining = (N - 1) .- (better_than + worse_than)
        rank = worse_than + remaining ./ 2.0f0
        w = (ranking_weight .^ rank) .* (better_than .> 0)
    end
    total = sum(w)
    total > 0 ? w ./ total : w
end

# Position-aware ListMLE loss, modified to handle non-comparable items and varying
# comparison strengths; see [Position-Aware ListMLE: A Sequential Learning Process
# for Ranking](https://auai.org/uai2014/proceedings/individuals/164.pdf)
#
# For every list in the batch, item r contributes
#   w[r] * (-s[r] + log(sum_c P[r, c] * exp(s[c]) + exp(s[r]) + ϵ))
# where `s` are the model scores, `P` is the weighted partial ordering and `w` are the
# position weights. The result is the mean over the lists of the batch.
function position_aware_list_mle_loss(m, x, priorities, ranking_weight, interaction_weights)
    scores = Flux.flatten(m(x))
    # shift each list by its maximum score before exponentiating, for numerical stability
    scores = scores .- maximum(scores; dims = 1)
    exp_scores = exp.(scores)

    ϵ = Float32(eps(Float64))
    batch_size = size(scores, 2)

    total = 0.0f0
    for i = 1:batch_size
        P = get_partial_ordering(priorities, i, interaction_weights)
        w = get_position_weights(P, ranking_weight)
        list_scores = scores[:, i]
        list_exp = exp_scores[:, i]
        unweighted_loss = -list_scores + log.(P * list_exp + list_exp .+ ϵ)
        total += sum(w .* unweighted_loss)
    end
    total / batch_size
end;

In [None]:
# Averages the list loss of `m` over `iters` batches drawn with a fixed-seed generator,
# so that repeated calls evaluate the model on the same sequence of random draws.
# The batches are requested with the second `get_batch` argument set to `false`
# (presumably selecting the non-training mode — confirm against `get_batch`).
function average_loss(m, f::Features, hyp::Hyperparams, iters::Integer)
    eval_rng = Random.Xoshiro(20230224)
    total = 0.0
    @showprogress for _ = 1:iters
        batch = get_batch(f, false, hyp.list_size, hyp.batch_size; rng = eval_rng)
        batch_loss = position_aware_list_mle_loss(
            m,
            batch...,
            hyp.ranking_weight,
            hyp.interaction_weights,
        )
        total += batch_loss
        device_free!(batch)
    end
    total / iters
end;

## Training

In [None]:
# Runs `epoch_size` optimisation steps on `m`, updating `m` and `opt` in place, and
# returns the mean of the per-batch training losses.
# NOTE(review): the default `rng = rng` resolves to a global named `rng`; callers are
# expected to pass the keyword explicitly.
function train_epoch!(m, opt, f::Features, hyp::Hyperparams, epoch_size::Integer; rng = rng)
    batch_losses = Float32[]
    @showprogress for _ = 1:epoch_size
        batch = get_batch(f, true, hyp.list_size, hyp.batch_size; rng = rng)
        objective =
            model -> position_aware_list_mle_loss(
                model,
                batch...,
                hyp.ranking_weight,
                hyp.interaction_weights,
            )
        batch_loss, grads = Flux.withgradient(objective, m)
        # the batch is released as soon as the gradient has been computed
        device_free!(batch)
        Flux.update!(opt, m, grads[1])
        push!(batch_losses, batch_loss)
    end
    mean(batch_losses)
end;

In [None]:
# Trains a model with the given hyperparams and returns `(best_model, best_loss)`,
# where `best_loss` is the lowest loss reported by `average_loss` over all checkpoints
# and `best_model` is the CPU copy of the model that achieved it.
#
# Arguments
# - `hyp`: hyperparameters; `hyp.seed` seeds every random number generator used here.
# - `task`: forwarded to `get_features` when `features` is not supplied.
# - `outdir`: forwarded to `write_params` each time a new best model is found.
#
# Keywords
# - `max_checkpoints`, `patience`: forwarded to `early_stopper`.
# - `features`: preloaded features, or `nothing` (the default) to load them here.
#   The annotation is `Union{Features,Nothing}` because a keyword annotated as plain
#   `Features` rejects its own `nothing` default with a `TypeError`.
function train_model(
    hyp::Hyperparams,
    task::String,
    outdir::String;
    max_checkpoints::Integer,
    patience::Integer,
    features::Union{Features,Nothing} = nothing,
)
    @info "Initializing model"
    rng = Random.Xoshiro(hyp.seed)
    Random.seed!(rand(rng, UInt64))
    if CUDA.functional()
        Random.seed!(CUDA.default_rng(), rand(rng, UInt64))
        Random.seed!(CUDA.CURAND.default_rng(), rand(rng, UInt64))
    end
    m = build_model(hyp) |> device
    opt = Optimisers.setup(
        OptimiserChain(
            Adam(hyp.learning_rate, (0.9f0, 0.999f0)),
            WeightDecay(hyp.l2penalty),
        ),
        m,
    )
    best_model = m |> cpu
    stopper = early_stopper(
        max_iters = max_checkpoints,
        patience = patience,
        min_rel_improvement = 1e-3,
    )
    if isnothing(features)
        @info "Getting data"
        features = get_features(hyp.alphas, task)
    end
    inference_data = get_inference_data(features)
    epoch_size = Int(round(length(features.user_to_indexes) / hyp.batch_size))
    training_epoch_size = Int(round(epoch_size * 0.9))
    # when sampling x items with replacement, using 3x samples
    # will get 0.95x distinct items
    validation_epoch_size = 3 * Int(round(epoch_size * 0.1))
    @info "Training model..."
    curloss = Inf
    losses = Float64[]
    while (!stop!(stopper, curloss))
        training_loss = train_epoch!(m, opt, features, hyp, training_epoch_size; rng = rng)
        curloss = average_loss(m, features, hyp, validation_epoch_size)
        push!(losses, curloss)
        # checkpoint whenever the current loss is the best seen so far
        if curloss == minimum(losses)
            best_model = m |> cpu
            write_params(
                Dict(
                    "m" => best_model,
                    "hyp" => hyp,
                    "inference_data" => inference_data,
                ),
                outdir,
            )
        end
        @info "training_loss $training_loss validation_loss $curloss"
    end

    best_model, minimum(losses)
end;

## Save Model

In [None]:
# Entry point: validates the hyperparameters, optionally tunes them, and trains the
# model (`train_model` writes the best checkpoint to `outdir`).
#
# Keywords
# - `tune_hyperparams`: when true, `optimize_hyperparams` is run first and its result
#   is applied with `create_hyperparams`.
# - `features`: preloaded features, or `nothing` (the default) to load them here.
#   The annotation is `Union{Features,Nothing}` because a keyword annotated as plain
#   `Features` rejects its own `nothing` default with a `TypeError`.
function train_alpha(
    hyp,
    task::String,
    outdir::String;
    tune_hyperparams::Bool = false,
    features::Union{Features,Nothing} = nothing,
)
    set_logging_outdir(outdir)
    verify(hyp)

    if tune_hyperparams
        @info "Optimizing hyperparameters..."
        λ = optimize_hyperparams(hyp; max_evals = 10)
        hyp = create_hyperparams(hyp, λ)
    end

    # load the features here when none were supplied, so that the training call below
    # always receives a concrete `Features` value
    if isnothing(features)
        @info "Getting data"
        features = get_features(hyp.alphas, task)
    end

    m, validation_loss = train_model(
        hyp,
        task,
        outdir;
        max_checkpoints = 50,
        patience = 1,
        features = features,
    )
    @info "Trained model loss: $validation_loss"

    # TODO fix inference
    # write_alpha(
    #     (split::String, task::String, content::String; raw_splits::Bool = true) ->
    #         inference_model(m, hyp, f, split, task, content; raw_splits = raw_splits),
    #     outdir;
    #     by_split = true,
    #     log = false,
    # )
    @info "Wrote alpha!"
end;