# Ranking
* This is trained to learn the partial ordering implied by each user's watches
* Items that are watched are preferred to items that have not been watched
* If two items have been watched, then the impression metadata determines
  which one, if any, is liked more
* It uses the position-aware maximum likelihood estimation loss
* The inputs to this model are features generated by other models

In [None]:
import NBInclude: @nbinclude
@nbinclude("MLE.Base.ipynb");

## Define Subclass

In [None]:
# Feature container for the ensemble ranking model. An "index" below is a position in
# the concatenation of the raw splits assembled by `get_features`.
@with_kw struct EnsembleFeatures <: Features
    # one column per user (read as `user_features[:, u]`); built by `get_user_features`
    user_features::SparseMatrixCSC{Float32,Int32}
    # one column per index (read as `query_features[:, q]`); normalized alpha predictions
    query_features::Matrix{Float32}
    # statistics dict produced by `normalize`; returned by `get_inference_data`
    preprocessing_data::Dict

    # one column per index; a first-row value of 0 is treated as "not a positive example"
    priorities::Matrix{Float16}

    # item id for each index
    index_to_item::Vector{Int32}
    # indexes that belong to each user
    user_to_indexes::Dict{Int32,Vector{Int32}}
    # num_items() × num_users() lookup: entry [item, user] stores the index of that pair
    item_user_index::SparseMatrixCSC{Int32,Int32}
end

# Return the preprocessing data stored on the features object.
function get_inference_data(f::Features)
    return f.preprocessing_data
end;

In [None]:
# Load the ratings of every alpha in `alphas` for the given raw split and content.
#
# Returns a `Float16` matrix with one row per alpha and one column per row of the raw
# split (the per-alpha columns are filled in parallel and then transposed).
function get_query_features(alphas::Vector{String}, split::String, content::String)
    @info "getting $split $content alphas"
    num_rows = length(get_raw_split(split, content).user)
    num_alphas = length(alphas)
    by_row = Matrix{Float16}(undef, num_rows, num_alphas)
    # each iteration writes a distinct column, so the iterations are independent
    @tprogress Threads.@threads for col = 1:num_alphas
        ratings = read_raw_alpha(alphas[col], split, content).rating
        by_row[:, col] = convert.(Float16, ratings)
    end
    # materialize the transpose so that each column holds one row's alpha values
    return collect(by_row')
end;

# Standardize `x` along `dims` to zero mean and unit (population) standard deviation.
#
# The statistics are computed in Float32 and the result is converted back to the element
# type of `x`. Slices with zero variance are divided by 1 instead of 0, so a constant
# feature maps to all zeros rather than NaN/Inf.
#
# Returns `(normalized, stats)` where `stats` is `Dict("μ" => μ, "σ" => σ)`; the returned
# σ is the one actually used for the division (zeros already replaced by one), so that
# re-applying `(x .- μ) ./ σ` later stays finite as well.
function normalize(x::AbstractArray; dims = 1)
    T = eltype(x)
    x32 = convert.(Float32, x)
    μ = mean(x32, dims = dims)
    σ = std(x32, dims = dims, mean = μ, corrected = false)
    # guard against division by zero for constant slices
    σ = map(s -> iszero(s) ? one(s) : s, σ)
    return convert.(T, (x32 .- μ) ./ σ), Dict("μ" => μ, "σ" => σ)
end

# Build a sparse num_items() × num_users() matrix from the (item, user, rating) triplets
# of the training implicit split.
function get_implicit_features()
    ratings = get_split("training", "implicit")
    return sparse(
        ratings.item,
        ratings.user,
        ratings.rating,
        num_items(),
        num_users(),
    )
end

# Build a sparse num_items() × num_users() matrix from the (item, user, rating) triplets
# of the training explicit split.
function get_explicit_features()
    ratings = get_split("training", "explicit")
    return sparse(
        ratings.item,
        ratings.user,
        ratings.rating,
        num_items(),
        num_users(),
    )
end

# Stack the implicit rating matrix on top of the explicit one, giving one column per user
# with 2 × num_items() rows.
function get_user_features()
    implicit = get_implicit_features()
    explicit = get_explicit_features()
    return vcat(implicit, explicit)
end;

In [None]:
# Assemble the `EnsembleFeatures` for the given alphas.
#
# Every per-row quantity (query features, priorities, item ids) is gathered from the raw
# "test" split of each content in `all_contents` and concatenated in the same
# (split, content) order, so a single "index" addresses the same row in all of them.
function get_features(alphas::Vector{String})
    contents = all_contents
    splits = ["test"]

    # the second argument always returns true — presumably a filter that keeps every
    # (split, content) pair; get_user_to_indexes is defined outside this file, confirm
    user_to_indexes = get_user_to_indexes(
        [(split, content) for split in splits for content in contents],
        (split, content) -> true,
    )

    # apply `f` to every (split, content) pair and combine the results with `agg`
    hreduce(f; agg = hcat) =
        reduce(agg, f(split, content) for split in splits for content in contents)
    user_features = get_user_features()
    # get_query_features returns alphas × rows, so dims = 2 normalizes each alpha across
    # all rows; the returned statistics are kept as preprocessing_data
    query_features, preprocessing_data = normalize(
        hreduce((split, content) -> get_query_features(alphas, split, content));
        dims = 2,
    )
    # normalize returns the input element type (Float16 here); widen to Float32
    query_features = convert.(Float32, query_features)
    priorities = hreduce(get_priorities)
    index_to_item =
        hreduce((split, content) -> get_raw_split(split, content).item; agg = vcat)

    # item_user_index[item, user] = index of that (item, user) row in the concatenation
    item_user_index = sparse(Int32[], Int32[], Int32[], num_items(), num_users())
    # 1-based index of the first row of the current (split, content) chunk
    idx = 1
    for split in splits
        for content in contents
            df = get_raw_split(split, content)
            # store a placeholder 1 at every (item, user) first — presumably so the
            # threaded loop below only overwrites already-stored entries and never changes
            # the sparse structure concurrently; confirm
            sp =
                sparse(df.item, df.user, fill(1, length(df.item)), num_items(), num_users())
            @tprogress Threads.@threads for i = 1:length(df.item)
                sp[df.item[i], df.user[i]] = i + (idx - 1)
            end
            # NOTE(review): a pair that appears in more than one chunk would have its
            # indexes summed here — verify that pairs are unique across contents
            item_user_index += sp
            idx += length(df.item)
        end
    end

    EnsembleFeatures(
        user_features = user_features,
        query_features = query_features,
        preprocessing_data = preprocessing_data,
        priorities = priorities,
        index_to_item = index_to_item,
        user_to_indexes = user_to_indexes,
        item_user_index = item_user_index,
    )
end

# Return column `u` of the user feature matrix.
function get_user_embedding(u::Integer, f::Features)
    return f.user_features[:, u]
end

# Return the item id stored for index `q`.
function get_item_embedding(q::Integer, f::Features)
    return f.index_to_item[q]
end;

# Return column `q` of the query feature matrix.
function get_query_embedding(q::Integer, f::Features)
    return f.query_features[:, q]
end;

In [None]:
# Draw one example: a random user plus `list_size` of that user's indexes.
#
# Users are split by id: ids up to `floor(num_users() * 0.9)` are used when `training` is
# true, the remaining ids otherwise. Sampling is retried until a user is found that has
# at least `list_size` indexes and whose sampled list contains a nonzero first-row
# priority.
#
# Returns `(u_embs, a_embs, q_embs, prefs)`:
# - `u_embs`: the user's embedding repeated `list_size` times, one column per list entry
# - `a_embs`: `Int32` item embedding of each sampled index
# - `q_embs`: query embeddings of the sampled indexes, one column each
# - `prefs`: result of `get_preferences` over the list, comparing priority columns
function get_sample(f::Features, training::Bool, list_size::Integer)
    cutoff = floor(Int, num_users() * 0.9)
    user_range = if training
        Int32(1):Int32(cutoff)
    else
        Int32(cutoff + 1):Int32(num_users())
    end

    while true
        # sample a random user; retry if the user has no indexes or too few of them
        u = rand(user_range)
        haskey(f.user_to_indexes, u) || continue
        candidates = f.user_to_indexes[u]
        length(candidates) >= list_size || continue

        # sample random items for the user
        list = sample(candidates, list_size; replace = false)
        # need at least one positive example to train on
        any(f.priorities[1, i] != 0 for i in list) || continue

        prefs = get_preferences(
            list,
            (i, j) -> compare(f.priorities[:, i], f.priorities[:, j]),
        )
        # batch the input features
        u_embs = hcat(fill(get_user_embedding(u, f), list_size)...)
        a_embs = Int32[get_item_embedding(q, f) for q in list]
        q_embs = hcat((get_query_embedding(q, f) for q in list)...)
        return u_embs, a_embs, q_embs, prefs
    end
end;

In [None]:
# Collect `batch_size` samples from `get_sample` and assemble them into model inputs.
#
# Returns `((U, A, Q), P)`: `U` holds the user features reshaped to
# (rows, lists per batch entry, batch_size), `A` the item ids, `Q` the batched query
# features and `P` the batched preferences, each passed through `device`.
#
# When `training` is true, a random per-item mask is drawn and multiplied into `U`: an
# item is kept where its uniform draw is >= `holdout` and zeroed otherwise. The mask has
# num_items() entries and is repeated to cover all rows of `U`.
function get_batch(
    f::Features,
    training::Bool,
    list_size::Integer,
    batch_size::Integer,
    holdout::Float32,
)
    user_blocks = SparseMatrixCSC{Float32,Int32}[]
    item_lists = Vector{Int32}[]
    query_blocks = Matrix{Float32}[]
    pref_blocks = Matrix{Int32}[]
    for _ = 1:batch_size
        drawn = get_sample(f, training, list_size)
        push!(user_blocks, drawn[1])
        push!(item_lists, drawn[2])
        push!(query_blocks, drawn[3])
        push!(pref_blocks, drawn[4])
    end

    # move to the compute device
    U = device(hcat(user_blocks...))
    A = device(hcat(item_lists...))
    Q = device(Flux.batch(query_blocks))
    P = device(Flux.batch(pref_blocks))
    if training
        randfn = CUDA.functional() ? CUDA.rand : rand
        mask = randfn(num_items()) .>= holdout
        U .*= repeat(mask, size(U, 1) ÷ length(mask))
    end
    tsize = (size(U, 1), size(U, 2) ÷ batch_size, batch_size)
    return (reshape(U, tsize), A, Q), P
end;

In [None]:
# Build the ranking network for the given hyperparameters.
#
# The first stage is a `Join` of three branches combined with `vcat`: a dense layer over
# the 2 × num_items() user features, an item embedding table, and `identity` for the query
# features (`Join` is defined outside this file — presumably it applies each branch to the
# matching input; confirm). The head therefore receives
# 2 × embedding_size + length(hyp.alphas) inputs and narrows down to a single score.
function build_model(hyp::Hyperparams)
    width = hyp.embedding_size
    branches = Join(
        vcat,
        Dense(num_items() * 2 => width),
        Embedding(num_items() => width; init = Flux.glorot_uniform),
        identity,
    )
    head_inputs = width * 2 + length(hyp.alphas)
    return Chain(
        branches,
        Dense(head_inputs => width, relu),
        Dense(width => width ÷ 2, relu),
        Dense(width ÷ 2 => 1),
    )
end;

In [None]:
# Score every row of the given split and content with the model, in batches of
# `hyp.batch_size`.
#
# Rows come from `get_raw_split` when `raw_splits` is true and from `get_split` otherwise.
# The "training" and "validation" splits are not scored: a zero vector of matching length
# is returned for them. Returns a `Float32` vector with one prediction per row.
#
# NOTE(review): the model `m` called below is neither a parameter nor a local of this
# function — presumably a global defined elsewhere; confirm it is in scope at call time.
function inference_model(
    hyp::Hyperparams,
    f::Features,
    split::String,
    content::String;
    raw_splits = true,
)
    @info "making predictions for $split $content"

    df = raw_splits ? get_raw_split(split, content) : get_split(split, content)
    num_rows = length(df.item)
    if split in ("training", "validation")
        return zeros(Float32, num_rows)
    end

    output = Array{Float32}(undef, num_rows)
    batches = collect(Iterators.partition(1:num_rows, hyp.batch_size))
    @showprogress for batch in batches
        u_embs = SparseVector{Float32,Int32}[]
        a_embs = Int32[]
        q_embs = Vector{Float32}[]
        for i in batch
            item, user = df.item[i], df.user[i]
            push!(u_embs, get_user_embedding(user, f))
            push!(a_embs, item)
            # look up the feature index of this (item, user) pair
            push!(q_embs, get_query_embedding(f.item_user_index[item, user], f))
        end
        U, A, Q = device(hcat(u_embs...)), device(a_embs), device(hcat(q_embs...))
        output[batch] .= cpu(vec(m((U, A, Q))))
    end
    return output
end;

## Train model

In [None]:
# Return the list of alpha names fed to the ensemble: four named alphas followed by the
# explicit, implicit and nondirectional raw alphas. When `allow_ptw` is true, the two
# named ptw alphas and the raw ptw alphas are appended.
function get_alphas(allow_ptw::Bool)
    base = vcat(
        ["LinearExplicit", "LinearImplicit", "Explicit", "NonlinearImplicit"],
        explicit_raw_alphas,
        implicit_raw_alphas,
        nondirectional_raw_alphas,
    )
    allow_ptw || return base
    return vcat(base, ["LinearPtw", "NonlinearPtw"], ptw_raw_alphas)
end;

In [None]:
# Baseline hyperparameters for the ensemble. `holdout`, `l2penalty` and `learning_rate`
# start as NaN placeholders — presumably filled in by `create_hyperparams` below from the
# supplied vector (Hyperparams and create_hyperparams are defined outside this file;
# confirm).
hyp = Hyperparams(
    alphas = [],
    batch_size = 1024,
    embedding_size = 256,
    holdout = NaN,
    l2penalty = NaN,
    learning_rate = NaN,
    list_size = 16,
    seed = 20220609,
)
# use the alpha list without the ptw alphas (`allow_ptw = false`)
hyp = @set hyp.alphas = get_alphas(false)
hyp = create_hyperparams(hyp, [0.0f0, 0.0f0, 0.0f0])

In [None]:
# Run training with `hyp`; "MLE.Ensemble" is passed as the second argument — presumably
# the name the trained alpha is saved under (train_alpha is defined outside this file;
# confirm).
train_alpha(hyp, "MLE.Ensemble")

In [None]:
# 1.304526987566872 with alphas = ["MLE.Training"]
# 1.1126399600829462 with all non-ptw alphas