# Ranking
* This is trained to learn the partial ordering implied by each user's watches
* Items that are watched are preferred to items that have not been watched
* If two items have been watched, then the impression metadata determines
  which one, if any, is liked more
* It uses the position-aware maximum likelihood estimation loss
* The inputs to this model are features generated by other models

In [None]:
using Flux

import CUDA
import SparseArrays: sparse
import Statistics: mean, std
import StatsBase: sample
import NLopt
import Random
import NBInclude: @nbinclude
import Setfield: @set
@nbinclude("../Alpha.ipynb")
@nbinclude("../Neural/Helpers/GPU.ipynb");
@nbinclude("EnsembleInputs.ipynb");

## Hyperparameters

In [None]:
# Settings for one training run of the ranking model.
@with_kw struct Hyperparams
    # not read anywhere in the visible source; presumably controls whether
    # plan-to-watch items are used — TODO confirm against the subclass notebooks
    allow_ptw::Bool
    # handed to get_features by train_model
    alphas::Vector{String}
    # number of samples per batch; also used to derive the epoch size
    batch_size::Int32
    # set by create_hyperparams to the number of alphas plus num_items()
    input_size::Int32
    # third argument of the ADAMW optimizer built in train_model
    l2penalty::Float32
    # first argument of the ADAMW optimizer built in train_model
    learning_rate::Float32
    # forwarded to get_sample as the list size of every sample
    list_size::Int32
    # passed to Random.seed! before the model is built
    seed::UInt64
end

# Converts the hyperparameters into a Dict keyed by the field names (as
# strings) with the corresponding field values.
function to_dict(x::Hyperparams)
    names = fieldnames(Hyperparams)
    Dict(string(name) => getfield(x, name) for name in names)
end

# Renders the hyperparameters as a "Hyperparameters:" header followed by one
# "name => value" line per field, with the names padded to a common width.
function Base.string(x::Hyperparams)
    names = collect(fieldnames(Hyperparams))
    width = maximum(length(string(name)) for name in names)
    lines = ["$(rpad(string(name), width)) => $(getfield(x, name))\n" for name in names]
    "Hyperparameters:\n" * join(lines)
end;

In [None]:
# Derives the concrete hyperparameters from the template `hyp` and the search
# point `λ` (indexable, with at least two entries):
#   input_size    = number of alphas in `hyp` + num_items()
#   learning_rate = 3e-4 * 10^(-λ[1])
#   l2penalty     = 10^(λ[2] - 5)
# Returns the updated hyperparameters.
function create_hyperparams(hyp, λ)
    # take the alphas from `hyp` itself; a bare `alphas` would resolve to a
    # global of that name (or fail if none exists)
    hyp = @set hyp.input_size = length(hyp.alphas) + num_items()
    hyp = @set hyp.learning_rate = 3e-4 * 10^(-λ[1])
    hyp = @set hyp.l2penalty = 10^(λ[2] - 5)
    hyp
end;

## Data Preprocessing

In [None]:
# Builds the priority vector of entry `i` of `df` for the given content type.
# The result is a Float16 vector of length get_priority_size(). Its leading
# entry is 2 for "explicit" and "implicit", 1 for "ptw" and 0 for "negative";
# the remaining entries come from df.rating / df.status / df.completion where
# the content type uses them and are NaN otherwise. Any other content string
# fails an assertion.
function get_priority(df::RatingsDataset, content::String, i::Integer)
    priority = if content == "explicit"
        [2, df.rating[i], df.status[i], df.completion[i]]
    elseif content == "implicit"
        [2, NaN, df.status[i], df.completion[i]]
    elseif content == "ptw"
        [1, NaN, NaN, NaN]
    elseif content == "negative"
        [0, NaN, NaN, NaN]
    else
        @assert false
    end
    @assert length(priority) == get_priority_size()
    convert.(Float16, priority)
end

# Number of entries in every priority vector.
get_priority_size() = 4;

In [None]:
# Computes the priority vector of every entry in the (split, content) dataset.
# Returns a Float16 matrix with get_priority_size() rows and one column per
# entry; the columns are filled in parallel.
function get_priorities(split::String, content::String)
    @info "getting $split $content priorities"
    df = get_raw_split(split, content)
    num_entries = length(df.user)
    priorities = Matrix{Float16}(undef, get_priority_size(), num_entries)
    @tprogress Threads.@threads for i = 1:num_entries
        priorities[:, i] = get_priority(df, content, i)
    end
    priorities
end;

In [None]:
# Maps every user id to the row indexes at which that user appears in the
# given datasets.
#
# `split_content_pairs` is iterated as (split, content) tuples, each loaded
# with get_raw_split. Returns a Dict{Int32,Vector{Int32}}; users without any
# row are removed before returning. Row indexes from different pairs are
# appended to the same per-user vector as-is (no offset is applied).
function get_user_to_indexes(split_content_pairs::Vector)
    # create every user's entry up front so the threaded loops below only look
    # up existing keys and never insert into the Dict (avoids race conditions)
    u_to_xs = Dict{Int32,Vector{Int32}}()
    for u = 1:num_users()
        u_to_xs[u] = Int32[]
    end

    for (split, content) in split_content_pairs
        df = get_raw_split(split, content)
        # multithread by sharding on userid:
        # idxs[a][b] collects the rows handled by thread a whose user id falls
        # into shard b (user id modulo the number of threads, plus one)
        idxs = [[[] for _ = 1:Threads.nthreads()] for _ = 1:Threads.nthreads()]
        # NOTE(review): indexing by Threads.threadid() assumes the id stays in
        # 1:nthreads() and does not change within an iteration — confirm for
        # the Julia version and scheduling policy in use
        @tprogress Threads.@threads for i = 1:length(df.user)
            push!(idxs[Threads.threadid()][df.user[i]%Threads.nthreads()+1], i)
        end
        # iteration t drains shard t from every thread's buckets, so a given
        # user's vector is only ever pushed to from one iteration
        Threads.@threads for t = 1:Threads.nthreads()
            for idx in idxs
                for i in idx[t]
                    push!(u_to_xs[df.user[i]], i)
                end
            end
        end
    end

    # prune unused users
    for u = 1:num_users()
        if length(u_to_xs[u]) == 0
            delete!(u_to_xs, u)
        end
    end
    u_to_xs
end;

## Batching

In [None]:
# Three-way comparison of two numbers in which NaN means "unknown".
# Returns 0 when both are NaN or both are equal, NaN when exactly one of them
# is NaN (the pair is incomparable), 1 when x > y and -1 otherwise.
function compare(x::Number, y::Number)
    x_unknown = isnan(x)
    y_unknown = isnan(y)
    if x_unknown && y_unknown
        return 0
    end
    if x_unknown || y_unknown
        return NaN
    end
    if x == y
        return 0
    end
    return x > y ? 1 : -1
end

# Lexicographic comparison of two equally long vectors using the scalar
# `compare`. The first position whose comparison is not 0 decides the result
# (which may be NaN, i.e. incomparable); 0 is returned when every position
# compares as 0.
function compare(x::Vector, y::Vector)
    @assert length(x) == length(y)
    for (a, b) in zip(x, y)
        r = compare(a, b)
        isnan(r) && return NaN
        r == 0 || return r
    end
    return 0
end;

In [None]:
# Performs a topological sort to get a random permutation of `V` that is
# consistent with the poset implied by the graph G = (V, E), where there is an
# edge u -> v exactly when E(u, v) == 1.
#
# Returns (order, true) on success. When the graph has no edges at all, the
# input is returned untouched as (V, false).
function topological_sort(V::Vector, E::Function)
    T = eltype(V)
    successors = Dict{T,Set{T}}(u => Set{T}() for u in V)
    num_parents = Dict{T,Int32}(u => 0 for u in V)
    edge_count = 0
    for u in V, v in V
        if E(u, v) == 1
            push!(successors[u], v)
            num_parents[v] += 1
            edge_count += 1
        end
    end
    if edge_count == 0
        return V, false
    end

    # vertices whose predecessors have all been emitted already
    rootless = Set{T}(v for v in V if num_parents[v] == 0)
    @assert length(rootless) > 0

    order = T[]
    while !isempty(rootless)
        # pick any ready vertex at random, emit it and release its successors
        v = rand(rootless)
        delete!(rootless, v)
        push!(order, v)
        for u in successors[v]
            num_parents[u] -= 1
            @assert num_parents[u] >= 0
            if num_parents[u] == 0
                push!(rootless, u)
            end
        end
    end

    return order, true
end;

In [None]:
# Draws `batch_size` samples with get_sample(f, users, list_size) and groups
# them by component: the result holds one vector per component of a sample,
# and the k-th vector contains the k-th component of every drawn sample.
function get_batch(f, users::Vector, list_size::Integer, batch_size::Integer)
    columns = []
    for _ = 1:batch_size
        example = get_sample(f, users, list_size)
        if isempty(columns)
            # the first sample determines how many columns there are
            append!(columns, [[] for _ in 1:length(example)])
        end
        push!.(columns, example)
    end
    columns
end;

In [None]:
# Starts one background producer task per thread that keeps filling a buffered
# channel with batches built by get_batch(f, users, ...).
#
# Arguments:
#   f, users     - forwarded unchanged to get_batch
#   hyp          - supplies list_size and batch_size for every batch
#   channel_size - capacity of the returned channel
#
# Returns the channel. Producers stop once the channel has been closed, so
# closing it is how the caller shuts them down. If building a batch fails
# while the channel is still open, the error is logged and the channel is
# closed with that exception, so consumers fail on take! instead of waiting
# forever for batches that will never arrive.
function setup_batch_channel(
    f,
    users::Vector,
    hyp::Hyperparams,
    channel_size::Integer,
)
    batches = Channel(channel_size)
    for _ = 1:Threads.nthreads()
        Threads.@spawn begin
            while true
                try
                    batch = get_batch(
                        f,
                        users,
                        hyp.list_size,
                        hyp.batch_size,
                    )
                    put!(batches, batch)
                catch e
                    if isopen(batches)
                        # a genuine failure rather than a shutdown: report it
                        # and hand it to whoever is reading from the channel
                        @error "batch producer failed" exception =
                            (e, catch_backtrace())
                        close(batches, e isa Exception ? e : ErrorException(string(e)))
                    end
                    # the channel is closed either way, so stop producing
                    break
                end
            end
        end
    end
    batches
end;

In [None]:
# Takes the next batch from the channel and prepares it for the model: every
# element of each column is passed through `device`, and each column is then
# combined with Flux.batch. Returns a tuple with one entry per column.
function get_batch(c::Channel)
    columns = take!(c)
    Tuple(Flux.batch(device.(column)) for column in columns)
end;

## Loss Functions

In [None]:
# Position-aware ListMLE loss of model `m` on input `x`.
#
# The model output is flattened to an (N, batch_size) score matrix. For list
# position i the term is
#     w_i * (-p_i + log(sum_{j >= i} exp(p_j) + ϵ)),
# with w_i = (2^(N - i + 1) - 1) / (2^N - 1). The terms are summed over all
# positions and columns and divided by batch_size.
function position_aware_list_mle_loss(m, x)
    scores = Flux.flatten(m(x))

    # for numerical stability: shift every column so that its maximum is zero
    scores = scores .- maximum(scores; dims = 1)
    ϵ = Float32(eps(Float64))

    N, batch_size = size(scores)
    # upper-triangular ones: row i of (suffix_sum * v) is the sum of v[i:end]
    suffix_sum = device(collect(LinearAlgebra.UpperTriangular(ones(Float32, N, N))))
    weights = device(convert.(Float32, vec(2 .^ (N:-1:1) .- 1)) ./ (2^N - 1))
    log_normalizer = log.(suffix_sum * exp.(scores) .+ ϵ)
    sum(weights .* (-scores + log_normalizer)) / batch_size
end;

In [None]:
# Mean of position_aware_list_mle_loss for model `m` over `iters` batches taken
# from `batches`. `hyp` is accepted but not used.
function average_loss(m, batches::Channel, iters::Integer, hyp::Hyperparams)
    total = 0.0
    @showprogress for _ in 1:iters
        total += position_aware_list_mle_loss(m, get_batch(batches))
    end
    total / iters
end;

## Training

In [None]:
# Trains a model with the given hyperparams and returns its validation loss.
#
# Keyword arguments:
#   max_checkpoints       - forwarded to early_stopper as max_iters
#   epochs_per_checkpoint - epochs trained between two evaluations on the
#                           test batches
#   patience              - forwarded to early_stopper
#   verbose               - whether progress messages are logged via loginfo
#
# Returns (best_model, best_loss, inference_data): the cpu copy of the model
# taken at the checkpoint with the lowest test loss, that loss, and
# get_inference_data(f) for the features that were used.
#
# The batch channels are closed even when training throws, so their producer
# tasks do not outlive a failed run.
function train_model(
    hyp::Hyperparams;
    max_checkpoints::Integer = 100,
    epochs_per_checkpoint::Integer = 10,
    patience::Integer = 0,
    verbose::Bool = true,
)
    if verbose
        @info "Initializing model"
    end
    opt = ADAMW(hyp.learning_rate, (0.9, 0.999), hyp.l2penalty)
    Random.seed!(hyp.seed)
    m = build_model(hyp) |> device
    best_model = m |> cpu
    ps = Flux.params(m)
    stopper = early_stopper(
        max_iters = max_checkpoints,
        patience = patience,
        min_rel_improvement = 1e-3,
    )
    batchloss(x...) = position_aware_list_mle_loss(m, x)
    epoch_size = Int(round(num_users() / hyp.batch_size))
    function loginfo(x)
        if verbose
            @info x
        end
    end

    loginfo("Getting data")
    # NOTE(review): allow_ptw is hard-coded to false here although Hyperparams
    # carries an allow_ptw field — confirm whether hyp.allow_ptw was intended
    f = get_features(hyp.alphas, false)
    training_users, test_users = training_test_split(f)
    setup_channel(users) = setup_batch_channel(
        f,
        users,
        hyp,
        64,
    )
    training_batches = setup_channel(training_users)
    test_batches = setup_channel(test_users)

    # declared outside the try block so they stay visible after it
    loss = Inf
    losses = []
    try
        @info "Testing channels"
        @info size.(get_batch(training_batches))
        @info size.(get_batch(test_batches))

        loginfo(
            "Training model with initial loss $(average_loss(m, test_batches, epoch_size, hyp))",
        )
        while (!stop!(stopper, loss))
            for _ = 1:epochs_per_checkpoint
                @showprogress for _ = 1:epoch_size
                    Flux.train!(batchloss, ps, [get_batch(training_batches)], opt)
                end
            end

            loss = average_loss(m, test_batches, epoch_size, hyp)
            push!(losses, loss)
            # keep a cpu copy whenever the newest checkpoint is the best so far
            if loss == minimum(losses)
                best_model = m |> cpu
            end
            loginfo("loss $loss")
        end
    finally
        # closing the channels is what stops the background batch producers
        close(training_batches)
        close(test_batches)
    end
    best_model, minimum(losses), get_inference_data(f)
end;

## Hyperparameter Tuning

In [None]:
# function optimize_hyperparams(hyp; max_evals)
#     function nlopt_loss(λ, grad)
#         # nlopt internally converts to float64 because it calls a c library
#         λ = convert.(Float32, λ)
#         _, loss = train_model(create_hyperparams(hyp, λ))
#         @info "$λ $loss"
#         loss
#     end
#     opt = NLopt.Opt(:LN_NELDERMEAD, 2)
#     opt.initial_step = 1
#     opt.maxeval = max_evals
#     opt.min_objective = nlopt_loss
#     minf, λ, ret = NLopt.optimize(opt, zeros(Float32, 2))
#     numevals = opt.numevals

#     @info (
#         "found minimum $minf at point $λ after $numevals function calls " *
#         "(ended because $ret) and saved model at"
#     )
#     λ
# end;

## Save Model

In [None]:
# Trains the ranking model for `hyp` and writes the result under `outdir`.
#
# With tune_hyperparams = true the search point λ comes from
# optimize_hyperparams, otherwise λ = zeros(2). The model, λ, the derived
# hyperparameters and the inference data are saved with write_params.
#
# NOTE(review): the visible source only contains a commented-out
# optimize_hyperparams, so tune_hyperparams = true presumably fails unless the
# function is defined elsewhere — confirm.
function train_alpha(hyp, outdir; tune_hyperparams = false)
    set_logging_outdir(outdir)

    λ = if tune_hyperparams
        @info "Optimizing hyperparameters..."
        optimize_hyperparams(hyp; max_evals = 10)
    else
        zeros(2)
    end
    hyp = create_hyperparams(hyp, λ)

    @info "Training model..."
    m, validation_loss, inference_data = train_model(
        hyp;
        max_checkpoints = 100,
        epochs_per_checkpoint = 1,
        patience = 5,
    )
    @info "Trained model loss: $validation_loss"

    @info "Writing alpha..."
    # TODO write the actual alpha value
    params = Dict("m" => m, "λ" => λ, "hyp" => hyp, "inference_data" => inference_data)
    write_params(params, outdir)
    @info "Wrote alpha!"
end;

## Methods to override in subclasses

In [None]:
# Placeholder that always raises "not implemented"; concrete models are
# expected to provide their own method. train_model passes the result on to
# training_test_split, get_sample (via get_batch) and get_inference_data.
get_features(alphas::Vector{String}, allow_ptw::Bool) = Base.error("not implemented")

# Placeholder that always raises "not implemented"; concrete models are
# expected to provide their own method. train_model moves the result to
# `device` and trains it.
build_model(hyp::Hyperparams) = Base.error("not implemented")

# Placeholder that always raises "not implemented"; concrete models are
# expected to provide their own method. train_model returns its result as the
# third element of its return value.
get_inference_data(f) = Base.error("not implemented")

# Placeholder that always raises "not implemented"; concrete models are
# expected to provide their own method. get_batch calls it once per sample and
# groups the components of the returned samples.
get_sample(f, users::Vector, list_size::Integer) = Base.error("not implemented");