# Ranking
* This is trained to learn the partial ordering implied by each user's watches
* Items that are watched are preferred to items that have not been watched
* If two items have been watched, then the impression metadata determines
  which one, if any, is liked more
* It uses the position-aware maximum likelihood estimation loss
* The inputs to this model are features generated by other models

In [None]:
# Notebook parameter: the task name. It prefixes every model alpha name in
# `get_alphas` and the name passed to `train_alpha`. It is empty here; presumably
# it is overwritten before the notebook is run (TODO confirm how it is injected).
task = ""

In [None]:
import NBInclude: @nbinclude
@nbinclude("MLE.Base.ipynb");

## Define Subclass

In [None]:
# Feature bundle for the ensemble ranking model. An "index" is a row position in
# the concatenation of all raw (split, content) slices built by `get_features`;
# the per-index arrays below share that numbering.
@with_kw struct EnsembleFeatures <: Features
    # normalized alpha predictions: one row per alpha, one column per index
    query_features::Matrix{Float32}
    # normalization statistics Dict("μ" => ..., "σ" => ...) returned by `normalize`
    preprocessing_data::Dict
    # column i is the priority vector of index i, compared with `compare`
    # when building pairwise preferences
    priorities::Matrix{Float16}
    # item id of each index
    index_to_item::Vector{Int32}
    # presumably user id -> that user's indexes across all contents (TODO confirm
    # against `get_user_to_indexes`)
    user_to_indexes::Dict{Int32,Vector{Int32}}
    # sparse (item, user) -> index lookup; reads 0 where the pair is absent
    item_user_index::SparseMatrixCSC{Int32,Int32}
    # presumably user id -> indexes from the "implicit"/"explicit" contents only
    user_to_watched_indexes::Dict{Int32,Vector{Int32}}
    # fraction of each sampled list reserved for watched indexes (see `subsample`)
    sampling_factor::Float32
end

"""
    get_inference_data(f::Features)

Return the `preprocessing_data` stored on `f` (the normalization statistics),
presumably so the same transform can be reapplied at inference time.
"""
get_inference_data(f::Features) = f.preprocessing_data;

In [None]:
"""
    get_query_features(alphas, split, task, content)

Load the raw predictions of every alpha in `alphas` for a single
`split`/`content` slice. Return them as a `Float16` matrix with one row per
alpha and one column per row of the raw split.
"""
function get_query_features(
    alphas::Vector{String},
    split::String,
    task::String,
    content::String,
)
    @info "getting $split $content alphas"
    nrows = length(get_raw_split(split, task, content).user)
    F = Float16
    cols = Matrix{F}(undef, nrows, length(alphas))
    # each iteration fills its own column, so threaded writes never overlap
    @tprogress Threads.@threads for j in eachindex(alphas)
        ratings = read_raw_alpha(alphas[j], split, task, content).rating
        cols[:, j] = convert.(F, ratings)
    end
    return permutedims(cols)
end;

"""
    normalize(x::AbstractArray; dims = 1)

Standardize `x` to zero mean and unit (uncorrected) standard deviation along
`dims`. Statistics are computed in `Float32`, and the result is converted back
to `eltype(x)`.

Return `(z, stats)`, where `stats = Dict("μ" => μ, "σ" => σ)` holds the `Float32`
statistics needed to repeat the transform. Slices with zero spread (constant
features) get `σ = 1`, so they map to all zeros instead of `NaN`.
"""
function normalize(x::AbstractArray; dims = 1)
    T = eltype(x)
    xf = convert.(Float32, x)
    μ = mean(xf; dims = dims)
    σ = std(xf; dims = dims, mean = μ, corrected = false)
    # a constant slice has σ == 0, and (x - μ) / σ would then be 0/0 = NaN;
    # storing σ = 1 in the stats keeps the transform reproducible downstream
    σ[iszero.(σ)] .= 1
    return convert.(T, (xf .- μ) ./ σ), Dict("μ" => μ, "σ" => σ)
end;

In [None]:
"""
    get_features(alphas::Vector{String}, task::String)

Assemble the `EnsembleFeatures` for `task` from the "test" split of every
content in `ALL_CONTENTS`.

All (split, content) slices are concatenated in one fixed order. An "index" is
a row position in that concatenation, and the per-index outputs
(`query_features`, `priorities`, `index_to_item`, `item_user_index`) share it.
The alpha predictions are normalized per alpha (row) across all indexes.
"""
function get_features(alphas::Vector{String}, task::String)
    contents = ALL_CONTENTS
    splits = ["test"]
    # fixed slice order shared by every per-index array below
    slices = [(split, content) for split in splits for content in contents]

    user_to_indexes = get_user_to_indexes(slices, task, (split, content) -> true)
    # only explicit/implicit interactions count as watched items
    user_to_watched_indexes = get_user_to_indexes(
        slices,
        task,
        (split, content) -> content in ["implicit", "explicit"],
    )

    # apply `f` to every slice and concatenate the results in slice order
    hreduce(f; agg = hcat) =
        reduce(agg, f(split, task, content) for (split, content) in slices)
    query_features, preprocessing_data = normalize(
        hreduce((split, task, content) -> get_query_features(alphas, split, task, content));
        dims = 2,
    )
    query_features = convert.(Float32, query_features)
    priorities = hreduce(get_priorities)
    index_to_item = hreduce(
        (split, task, content) -> get_raw_split(split, task, content).item;
        agg = vcat,
    )

    # (item, user) -> index lookup. Each slice's entries are numbered with their
    # global index, i.e. offset by the number of rows in the earlier slices.
    item_user_index = sparse(Int32[], Int32[], Int32[], num_items(), num_users())
    offset = 0
    for (split, content) in slices
        df = get_raw_split(split, task, content)
        n = length(df.item)
        # Build the slice directly with Int32 index values, instead of filling
        # with 1s and overwriting entry by entry. `max` keeps one deterministic
        # index if an (item, user) pair repeats within the slice.
        sp = sparse(
            Int32.(df.item),
            Int32.(df.user),
            Int32.(offset+1:offset+n),
            num_items(),
            num_users(),
            max,
        )
        # NOTE(review): a pair present in several slices would end up with the
        # sum of its indexes; confirm that slices never share an (item, user) pair.
        item_user_index += sp
        offset += n
    end

    return EnsembleFeatures(
        query_features = query_features,
        preprocessing_data = preprocessing_data,
        priorities = priorities,
        index_to_item = index_to_item,
        user_to_indexes = user_to_indexes,
        item_user_index = item_user_index,
        user_to_watched_indexes = user_to_watched_indexes,
        sampling_factor = 0.5,
    )
end;

In [None]:
"""
    random_subsample(a, N)

Draw `min(length(a), N)` elements of `a` at random, without replacement.
"""
random_subsample(a, N) = sample(a, min(length(a), N); replace = false)

"""
    subsample(u::Int32, list_size::Integer, f::Features)

Build a list of `list_size` indexes for user `u`:

1. Up to `round(list_size * f.sampling_factor)` indexes are drawn from the
   user's watched indexes.
2. The rest are filled from a random draw of all the user's indexes, skipping
   ones already picked.
3. Any shortfall is padded with `-1`.

Return `(list, true)`. Return `(Int32[], false)` when `u` is missing from
`f.user_to_indexes` or from `f.user_to_watched_indexes`.
"""
function subsample(u::Int32, list_size::Integer, f::Features)
    # a usable user needs some indexes and some watched indexes
    if !haskey(f.user_to_indexes, u) || !haskey(f.user_to_watched_indexes, u)
        return Int32[], false
    end

    # reserve a share of the list for watched items so they are over-represented
    nwatched = Int(round(list_size * f.sampling_factor))
    picked = random_subsample(f.user_to_watched_indexes[u], nwatched)

    # top up from a random draw of all the user's indexes, skipping repeats
    for idx in random_subsample(f.user_to_indexes[u], list_size)
        length(picked) == list_size && break
        idx in picked || push!(picked, idx)
    end

    # pad with the -1 sentinel up to the requested length
    for _ = length(picked)+1:list_size
        push!(picked, -1)
    end
    return picked, true
end;

In [None]:
"""
    get_query_embedding(f::Features, q::Integer)

Return column `q` of `f.query_features`. For the padding index `-1`, return an
all-zero `Float32` vector of the same length.
"""
get_query_embedding(f::Features, q::Integer) =
    q == -1 ? zeros(Float32, size(f.query_features, 1)) : f.query_features[:, q]

"""
    prio(f::Features, i::Integer)

Return the priority vector of index `i` (column `i` of `f.priorities`). The
padding index `-1` maps to the fixed vector `Float16[0, NaN, NaN, NaN]`.
"""
function prio(f::Features, i::Integer)
    i != -1 && return f.priorities[:, i]
    return Float16[0, NaN, NaN, NaN]
end

"""
    get_sample(f::Features, training::Bool, list_size::Integer)

Draw one example from a random user that `subsample` accepts. Users come from
the training range (ids `1` to `floor(0.9 * num_users())`) or, when `training`
is false, from the remaining ids. Users are redrawn until one is accepted.

Return `(q_embs, prefs)`:
- `q_embs` has one query-embedding column per sampled list entry.
- `prefs` holds the pairwise preferences from `get_preferences`, using
  `compare` on the entries' priority vectors.
"""
function get_sample(f::Features, training::Bool, list_size::Integer)
    cutoff = Int(floor(num_users() * 0.9))
    users = training ? (Int32(1):Int32(cutoff)) : (Int32(cutoff + 1):Int32(num_users()))

    list = Int32[]
    # retry with a fresh random user until subsample accepts one
    while true
        list, ok = subsample(rand(users), list_size, f)
        ok && break
    end

    prefs = get_preferences(list, (i, j) -> compare(prio(f, i), prio(f, j)))
    q_embs = hcat((get_query_embedding(f, q) for q in list)...)
    return q_embs, prefs
end;

In [None]:
"""
    get_batch(f::Features, training::Bool, list_size::Integer, batch_size::Integer)

Draw `batch_size` examples with `get_sample` and stack the query embeddings and
the preferences with `Flux.batch`. Return `(Q, P)` after moving both to `device`.
"""
function get_batch(f::Features, training::Bool, list_size::Integer, batch_size::Integer)
    samples = [get_sample(f, training, list_size) for _ = 1:batch_size]
    # typed comprehensions convert each element to the stated matrix type
    q_embs = Matrix{Float32}[q for (q, _) in samples]
    prefs = Matrix{Int32}[p for (_, p) in samples]
    return device(Flux.batch(q_embs)), device(Flux.batch(prefs))
end;

In [None]:
"""
    build_model(hyp::Hyperparams)

Build the scoring network. It maps `length(hyp.alphas)` input features through
two ReLU hidden layers (`K`, then `div(K, 2)` units, with
`K = hyp.embedding_size`) to a single output.
"""
function build_model(hyp::Hyperparams)
    K = hyp.embedding_size
    # use the `in => out` constructor form consistently; the positional
    # `Dense(in, out, σ)` form is the deprecated spelling of the same layer
    return Chain(
        Dense(length(hyp.alphas) => K, relu),
        Dense(K => div(K, 2), relu),
        Dense(div(K, 2) => 1),
    )
end;

## Train model

In [None]:
"""
    get_alphas(task::String)

Return the names of all alphas fed into the ensemble for `task`, in this order:
1. the four task-prefixed model outputs;
2. the raw explicit alphas;
3. the raw implicit alphas;
4. the nondirectional raw alphas.
"""
function get_alphas(task::String)
    model_alphas = [
        "$task/LinearExplicit",
        "$task/LinearImplicit",
        "$task/Explicit",
        "$task/NonlinearImplicit",
    ]
    return vcat(
        model_alphas,
        explicit_raw_alphas(task),
        implicit_raw_alphas(task),
        nondirectional_raw_alphas,
    )
end;

In [None]:
# Hyperparameters for the ensemble ranker. The l2 penalty and learning rate start
# as NaN placeholders; `create_hyperparams` is then called with `[0.0f0, 0.0f0]`
# (presumably their final values; TODO confirm against MLE.Base).
hyp = create_hyperparams(
    Hyperparams(
        alphas = get_alphas(task),
        batch_size = 1024,
        embedding_size = 256,
        l2penalty = NaN,
        learning_rate = NaN,
        list_size = 64,
        seed = 20220609,
    ),
    [0.0f0, 0.0f0],
)

In [None]:
# Train the ensemble with `hyp`. The last argument names the resulting alpha
# (presumably where the trained model and predictions are saved; TODO confirm
# in MLE.Base).
train_alpha(hyp, task, "$task/MLE.Ensemble.list_size_64")