# Ranking
* This model is trained to learn the partial ordering implied by each user's watches
* Items that have been watched are preferred to items that have not been watched
* If two items have both been watched, the impression metadata determines
  which one, if any, is liked more
* It uses a position-aware maximum likelihood estimation (MLE) loss
* The inputs to this model are features generated by other models

In [None]:
# Task name used throughout this notebook; `get_interaction_weights` handles
# "random" and "temporal". The empty default is expected to be overridden before
# the cells below run (NOTE(review): presumably injected by papermill or by the
# notebook that includes this one — confirm).
task = ""

In [None]:
import NBInclude: @nbinclude
import Statistics: quantile
@nbinclude("MLE.Base.ipynb");

In [None]:
# Inputs for the ensemble ranking model, as assembled by `get_features`.
# Columns of `query_features` and `priorities` correspond to rows of the raw
# (split, content) tables, concatenated split-major then content-major. The
# integer "indexes" stored in the other fields refer to these columns.
@with_kw struct EnsembleFeatures <: Features
    # normalized alpha values: one row per alpha, one column per index
    query_features::Matrix{Float32}
    # statistics returned by `normalize!` ("μ", "σ", "clip_std"), needed to
    # reproduce the same normalization later
    preprocessing_data::Dict
    # output of `get_priorities`, one column per index; `subsample` treats a zero
    # in row 1 as "not watched" (NOTE(review): row count presumably 4, matching the
    # NaN padding in `get_priority_embedding` — confirm)
    priorities::Matrix{Float16}
    # item id for each index
    index_to_item::Vector{Int32}
    # presumably user id => indexes of that user's rows across all contents;
    # verify against `get_user_to_indexes`
    user_to_indexes::Dict{Int32,Vector{Int32}}
    # sparse (item, user) => index lookup built in `get_features`
    item_user_index::SparseMatrixCSC{Int32,Int32}
    # like `user_to_indexes`, restricted to the "implicit" and "explicit" contents
    user_to_watched_indexes::Dict{Int32,Vector{Int32}}
end

# Return the preprocessing metadata stored on `f` (its `preprocessing_data` field).
# For features built by `get_features` this is the dict produced by `normalize!`.
get_inference_data(f::Features) = f.preprocessing_data;

In [None]:
# Pick the elementwise transform applied to the raw ratings of `alpha`.
# "NonlinearImplicit" must be tested before the generic "implicit" match, because
# its lowercase form also contains "implicit" and it should stay untransformed.
function _alpha_transform(alpha::String)
    if occursin("NonlinearImplicit", alpha)
        return identity
    elseif occursin("ItemCount", alpha)
        return x -> log(x + 1)
    elseif occursin("Variance", alpha) || occursin("implicit", lowercase(alpha))
        # the epsilon keeps the log finite for zero-valued ratings
        return x -> log(x + Float32(eps(Float64)))
    else
        return identity
    end
end

"""
    get_query_features(alphas, split, task, content)

Load the ratings of every alpha in `alphas` for the given `split`/`task`/`content`
and return them as a `length(alphas) × nrows` `Float16` matrix, one row per alpha.

Each alpha's ratings are transformed first. The transform is `log(x + 1)` for
"ItemCount" alphas and `log(x + eps)` for "Variance" or implicit alphas.
"NonlinearImplicit" alphas and all other alphas are left unchanged.
"""
function get_query_features(
    alphas::Vector{String},
    split::String,
    task::String,
    content::String,
)
    @info "getting $split $content alphas"
    df = get_raw_split(split, task, content)
    T = Float16
    A = Matrix{T}(undef, length(df.user), length(alphas))
    @tprogress Threads.@threads for i = 1:length(alphas)
        transform = _alpha_transform(alphas[i])
        ratings = read_raw_alpha(alphas[i], split, task, content).rating
        # Transform in Float32 and narrow to Float16 only at the end. Converting
        # first would overflow counts above 65504 to Inf and flush small
        # probabilities to subnormals/zero before the log is taken.
        A[:, i] = convert.(T, transform.(convert.(Float32, ratings)))
    end
    permutedims(A)
end;

"""
    normalize!(x; clip_std = 3)

Standardize each row of `x` in place, then clamp all entries to
`[-clip_std, clip_std]`.

- Rows whose values all lie in `[0, 1]` are only mean-centered (σ = 1).
- Other rows are z-scored using the uncorrected standard deviation.
- A row with zero spread is also only mean-centered. Dividing by σ = 0 would
  turn it into NaN and make the `quantile` logging throw.

Returns `(x, Dict("μ" => μ, "σ" => σ, "clip_std" => clip_std))`, where `μ` and `σ`
are the per-row `Float32` statistics needed to apply the same transform later.
"""
function normalize!(x::AbstractArray; clip_std = 3)
    T = eltype(x)
    N = size(x, 1)
    μ = zeros(Float32, N)
    σ = ones(Float32, N)
    for i = 1:N
        y = convert.(Float32, x[i, :])
        μ[i] = mean(y)
        if !all(v -> 0 <= v <= 1, y)
            σ[i] = std(y, mean = μ[i], corrected = false)
        end
        # a constant row has σ = 0 (or NaN); keep σ = 1 so it maps to zeros
        if !(σ[i] > 0)
            σ[i] = 1
        end
        q = (y .- μ[i]) ./ σ[i]
        @info "normalization metrics for alpha $i: $(μ[i]), $(σ[i]), $(minimum(q)), " *
              "$(quantile(q, 0.1)), $(quantile(q, 0.9)), $(maximum(q))"
        if (abs(maximum(q)) > clip_std) || (abs(minimum(q)) > clip_std)
            @info "clipping values to [-$clip_std, $clip_std]"
        end
        x[i, :] = convert.(T, (x[i, :] .- μ[i]) ./ σ[i])
    end
    clamp!(x, -clip_std, clip_std)
    x, Dict("μ" => μ, "σ" => σ, "clip_std" => clip_std)
end;

In [None]:
"""
    get_features(alphas, task)

Assemble the `EnsembleFeatures` for `task` from the "test" split of every content
in `ALL_CONTENTS`.

Per-(split, content) results are concatenated split-major, then content-major.
The query features are normalized with `normalize!` and stored as `Float32`.
"""
function get_features(alphas::Vector{String}, task::String)
    @info "using alphas $alphas"
    contents = ALL_CONTENTS
    splits = ["test"]

    # every (split, content) combination in split-major order (fresh vector per call)
    split_content_pairs() = [(s, c) for s in splits for c in contents]
    # evaluate `fn(split, task, content)` per combination and join with `agg`
    concat_all(fn; agg = hcat) =
        reduce(agg, fn(s, task, c) for s in splits for c in contents)

    raw = concat_all((s, t, c) -> get_query_features(alphas, s, t, c))
    normalized, preprocessing_data = normalize!(raw)
    query_features = convert.(Float32, normalized)

    user_to_indexes = get_user_to_indexes(split_content_pairs(), task, (s, c) -> true)
    user_to_watched_indexes = get_user_to_indexes(
        split_content_pairs(),
        task,
        (s, c) -> c in ["implicit", "explicit"],
    )

    priorities = concat_all(get_priorities)
    index_to_item = concat_all((s, t, c) -> get_raw_split(s, t, c).item; agg = vcat)

    # (item, user) => running column index across all concatenated tables
    item_user_index = sparse(Int32[], Int32[], Int32[], num_items(), num_users())
    offset = 0
    for (s, c) in split_content_pairs()
        df = get_raw_split(s, task, c)
        n = length(df.item)
        item_user_index += sparse(
            df.item,
            df.user,
            (offset+1):(offset+n),
            num_items(),
            num_users(),
        )
        offset += n
    end

    return EnsembleFeatures(;
        query_features,
        preprocessing_data,
        priorities,
        index_to_item,
        user_to_indexes,
        item_user_index,
        user_to_watched_indexes,
    )
end;

In [None]:
# Draw min(length(a), N) elements of `a` without replacement using `rng`.
random_subsample(a, N; rng = rng) = sample(rng, a, min(length(a), N); replace = false)

"""
    subsample(u, list_size, f; rng = rng)

Build a list of up to `list_size` indexes for user `u`, padded with `-1`.

Returns `(Int32[], false)` when `u` has no entry in `f.user_to_indexes` or in
`f.user_to_watched_indexes`. If no sampled index has a nonzero first priority, the
first slot is replaced with a random watched index.
"""
function subsample(u::Int32, list_size::Integer, f::Features; rng = rng)
    if !haskey(f.user_to_indexes, u) || !haskey(f.user_to_watched_indexes, u)
        return Int32[], false
    end
    watched = f.user_to_watched_indexes[u]
    candidates = random_subsample(f.user_to_indexes[u], list_size; rng = rng)

    # guarantee the list contains at least one watched item
    if !any(i -> f.priorities[1, i] != 0, candidates)
        candidates[1] = rand(rng, watched)
    end

    # pad with the sentinel -1 up to list_size
    shortfall = list_size - length(candidates)
    if shortfall > 0
        append!(candidates, fill(-1, shortfall))
    end
    return candidates, true
end;

In [None]:
# Return column `q` of `f.query_features`; the padding index -1 maps to a zero
# vector of the same length.
get_query_embedding(f::Features, q::Integer) =
    q == -1 ? zeros(Float32, size(f.query_features, 1)) : f.query_features[:, q]

"""
    get_priority_embedding(f, i)

Return column `i` of `f.priorities`. The padding index `i == -1` yields a
`Float16` NaN vector with as many entries as `f.priorities` has rows, so padded
and real columns always batch together.
"""
function get_priority_embedding(f::Features, i::Integer)
    if i == -1
        # sized from the matrix instead of a hard-coded 4 so padding stays in sync
        return fill(Float16(NaN), size(f.priorities, 1))
    end
    return f.priorities[:, i]
end

"""
    get_sample(f, training, list_size; rng = rng)

Draw one ranking list for a random user.

- Users `1:floor(0.9 * num_users())` are drawn when `training` is true; the
  remaining users are drawn otherwise.
- Users for which `subsample` fails are redrawn. This loops forever if no user in
  the range has both candidate and watched items.

Returns `(q_embs, prefs)`: the query-feature matrix (one column per list slot)
and the `Flux.batch`ed priority embeddings for the same slots.
"""
function get_sample(f::Features, training::Bool, list_size::Integer; rng = rng)
    max_training_user = floor(Int, num_users() * 0.9)
    if training
        user_range = Int32(1):Int32(max_training_user)
    else
        user_range = Int32(max_training_user + 1):Int32(num_users())
    end

    while true
        u = rand(rng, user_range)
        list, ok = subsample(u, list_size, f; rng = rng)
        ok || continue
        prefs = Flux.batch(get_priority_embedding(f, i) for i in list)
        # reduce(hcat, …) avoids splatting `list_size` vectors as varargs
        q_embs = reduce(hcat, [get_query_embedding(f, q) for q in list])
        return q_embs, prefs
    end
end;

In [None]:
"""
    get_batch(f, training, list_size, batch_size; rng)

Draw `batch_size` samples with `get_sample` and return `(Q, P)`:
- `Q`: the query-feature matrices combined with `Flux.batch`;
- `P`: the priority matrices combined the same way.

Both are passed through `device`, which is defined elsewhere.
"""
function get_batch(
    f::Features,
    training::Bool,
    list_size::Integer,
    batch_size::Integer;
    rng,
)
    samples = [get_sample(f, training, list_size; rng = rng) for _ = 1:batch_size]
    query_lists = Matrix{Float32}[first(s) for s in samples]
    pref_lists = Matrix{Float16}[last(s) for s in samples]

    # move to the configured device (presumably GPU)
    Q = device(Flux.batch(query_lists))
    P = device(Flux.batch(pref_lists))
    Q, P
end;

In [None]:
"""
    build_model(hyp)

Build a three-layer `Chain`:
- input width `length(hyp.alphas)`;
- hidden widths `hyp.embedding_size` and `hyp.embedding_size ÷ 2`, both with
  `relu`;
- a single linear output.
"""
function build_model(hyp::Hyperparams)
    K = hyp.embedding_size
    # use the `in => out` constructor form consistently for every layer
    return Chain(
        Dense(length(hyp.alphas) => K, relu),
        Dense(K => div(K, 2), relu),
        Dense(div(K, 2) => 1),
    )
end;

## Train model

In [None]:
"""
    get_alphas(task)

Return the names of the alphas used as ranking features for `task`. The list
excludes `"\$task/ExplicitUserItemBiases"`.
"""
function get_alphas(task::String)
    candidates = vcat(
        "$task/LinearImplicit",
        "$task/Explicit",
        nondirectional_raw_alphas,
        "$task/LinearExplicit",
        "$task/NonlinearImplicit",
        implicit_raw_alphas(task),
        explicit_raw_alphas(task),
    )
    # this alpha is non-personalized and overfits to the objective function
    excluded = ("$task/ExplicitUserItemBiases",)
    return [name for name in candidates if !(name in excluded)]
end;

In [None]:
"""
    get_interaction_weights(task)

Return the `Float32` interaction weights for `task` ("random" or "temporal").

Throws an `ArgumentError` for any other task. The previous behavior silently
returned `nothing`.
"""
function get_interaction_weights(task::String)
    # both supported tasks currently share the same weights
    if task in ("random", "temporal")
        return Float32[2.0^-10, 1, 0, 0]
    end
    throw(ArgumentError("unknown task \"$task\"; expected \"random\" or \"temporal\""))
end;

In [None]:
# Hyperparameters for the ensemble ranker. The NaN regularization and learning-rate
# entries are passed through `create_hyperparams` together with `Float32[0, 0]`
# (NOTE(review): presumably that call fills them in — confirm).
hyp = create_hyperparams(
    Hyperparams(;
        alphas = get_alphas(task),
        batch_size = 64,
        embedding_size = 256,
        l2penalty = NaN,
        learning_rate = NaN,
        list_size = 1024,
        seed = 20220609,
        ranking_weight = 1.0,
        interaction_weights = get_interaction_weights(task),
    ),
    Float32[0, 0],
)

In [None]:
# Build the ensemble features for the selected alphas (see `get_features`).
features = get_features(hyp.alphas, task);

In [None]:
# Train the ensemble ranker, passing "<task>/MLE.Ensemble" as its name.
let name = "$task/MLE.Ensemble"
    train_alpha(hyp, task, name; features)
end