# Ranking
* Learns the preference relation implied by future watches
* Uses a modified form of the position-aware ListMLE loss

In [None]:
# Name of the medium to rank. It is interpolated into the alpha names
# ("$medium/Linear/...") and into the output path ("$medium/Ranking") below.
# NOTE(review): empty here — presumably filled in by whatever executes this notebook; confirm.
medium = ""

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Utility.ipynb");

In [None]:
import CUDA
import Flux
import MLUtils
import Optimisers
import Random
import StatsBase

# Data

In [None]:
"""
    get_features(alphas, split, medium)

Load every alpha named in `alphas` for one `split` of `medium` and return them as a
`Float16` matrix with one row per alpha and one column per entry of the split.

Each alpha is normalized into `[0, 1]` and made monotonic: ratings are divided by 10 and
clamped, drop scores are replaced by `1 - x`, and watch / plantowatch scores are used
as-is. Any other alpha name, or a value outside `[0, 1]`, trips an assertion.
"""
function get_features(alphas::Vector{String}, split::String, medium::String)
    @info "getting $split $medium $alphas"
    num_entries = length(get_raw_split(split, medium, [:userid], nothing).userid)
    T = Float16
    columns = Matrix{T}(undef, num_entries, length(alphas))
    @assert length(alphas) == get_feature_size()
    # alphas that are already in [0, 1] and oriented correctly
    passthrough = ["$medium/Linear/watch", "$medium/Linear/plantowatch"]
    @showprogress for col in eachindex(alphas)
        alpha = alphas[col]
        x = get_raw_split(split, medium, Symbol[], alpha).alpha
        # normalize and make monotonic
        if alpha == "$medium/Linear/rating"
            x = clamp.(x / 10, 0, 1)
        elseif alpha == "$medium/Linear/drop"
            x = 1 .- x
        elseif alpha ∉ passthrough
            @assert false
        end
        @assert minimum(x) >= 0 && maximum(x) <= 1
        columns[:, col] = convert.(T, x)
    end
    # callers expect (alpha, entry) layout, so materialize the transpose
    permutedims(columns)
end

"""
    get_features(alphas, medium)

Return the feature matrix for the "test" split followed by the "negative" split,
concatenated column-wise (test columns first).
"""
function get_features(alphas::Vector{String}, medium::String)
    test = get_features(alphas, "test", medium)
    negative = get_features(alphas, "negative", medium)
    hcat(test, negative)
end

"""
    get_feature_size()

Number of alphas (feature rows) the ranking model consumes per item.
"""
get_feature_size() = 4;

In [None]:
"""
    get_priority_size()

Number of priority rows stored per item (see `get_priorities`).
"""
get_priority_size() = 3

"""
    get_priorities(split, medium)

Build the `Float16` priority matrix for one split, with one column per entry.

For the "test" split a column is `[1, rating, status]`, where a rating of `0` is stored
as `NaN`. For the "negative" split every column is `[0, NaN, NaN]`. Any other split
name trips an assertion.
"""
function get_priorities(split::String, medium::String)
    @info "getting $split $medium priorities"
    is_test = split == "test"
    if is_test
        fields = [:userid, :itemid, :rating, :status]
    elseif split == "negative"
        fields = [:userid, :itemid]
    else
        @assert false
    end
    df = get_raw_split(split, medium, fields, nothing)
    num_entries = length(df.userid)
    A = Matrix{Float16}(undef, get_priority_size(), num_entries)
    @showprogress for i = 1:num_entries
        if is_test
            rating = df.rating[i]
            # a zero rating is stored as NaN so it never wins a `>` comparison
            A[:, i] = Float16[1, rating != 0 ? rating : NaN, df.status[i]]
        else
            A[:, i] = Float16[0, NaN, NaN]
        end
    end
    A
end

"""
    get_priorities(medium)

Return the priority matrix for the "test" split followed by the "negative" split,
concatenated column-wise (test columns first).
"""
function get_priorities(medium::String)
    test = get_priorities("test", medium)
    negative = get_priorities("negative", medium)
    hcat(test, negative)
end;

In [None]:
"""
    get_user_to_indexes(medium, splits)

Map each user id to the positions of that user's entries.

Positions are numbered consecutively from 1 across `splits`, in the order the splits
are given and the order entries appear within each split.
"""
function get_user_to_indexes(medium::String, splits::Vector{String})
    u_to_xs = Dict{Int32,Vector{Int32}}()
    next_index::Int32 = 1
    for split in splits
        userids = get_raw_split(split, medium, [:userid], nothing).userid
        @showprogress for u in userids
            # create the user's list on first sight, then record this position
            positions = get!(u_to_xs, u) do
                Int32[]
            end
            push!(positions, next_index)
            next_index += 1
        end
    end
    u_to_xs
end;

In [None]:
# Everything needed to draw training and evaluation lists for the ranking model.
@kwdef struct Features
    # one column per item index; read column-wise by `get_feature`
    features::Matrix{Float32}
    # one column per item index; read column-wise by `get_priority`
    priorities::Matrix{Float16}
    # user id -> column indexes of all of that user's items ("test" and "negative" splits)
    user_to_indexes::Dict{Int32,Vector{Int32}}
    # user id -> column indexes of that user's items from the "test" split only
    user_to_watched_indexes::Dict{Int32,Vector{Int32}}
    # users whose lists are used for training (`get_epoch` with `training = true`)
    training_users::Vector{Int32}
    # users whose lists are held out for evaluation (`get_epoch` with `training = false`)
    test_users::Vector{Int32}
end

"""
    load_features()

Load features, priorities and the per-user index maps for the global `medium`, then
randomly hold out 10% of the users as test users. Returns a `Features`.
"""
function load_features()
    alphas = ["$medium/Linear/$metric" for metric in ALL_METRICS]
    features = get_features(alphas, medium)
    priorities = get_priorities(medium)
    user_to_indexes = get_user_to_indexes(medium, ["test", "negative"])
    user_to_watched_indexes = get_user_to_indexes(medium, ["test"])

    users = collect(keys(user_to_indexes))
    num_test = round(Int, length(users) * 0.1)
    test_users = Set(StatsBase.sample(users, num_test; replace = false))
    training_users = Set(u for u in users if u ∉ test_users)
    Features(;
        features = features,
        priorities = priorities,
        user_to_indexes = user_to_indexes,
        user_to_watched_indexes = user_to_watched_indexes,
        training_users = collect(training_users),
        test_users = collect(test_users),
    )
end;

# Batching

In [None]:
"""
    subsample(f, u, list_size)

Draw a list of exactly `list_size` item indexes for user `u`.

Up to `list_size` of the user's items are sampled without replacement. If none of the
sampled items has a non-zero first priority row, the first slot is replaced by a random
item from the user's watched items. Short lists are padded with the sentinel `-1`.
"""
function subsample(f::Features, u::Int32, list_size::Int32)
    candidates = f.user_to_indexes[u]
    watched = f.user_to_watched_indexes[u]
    num_sampled = min(length(candidates), list_size)
    list = StatsBase.sample(candidates, num_sampled; replace = false)
    # ensure at least one item is watched
    if !any(f.priorities[1, i] != 0 for i in list)
        list[1] = rand(watched)
    end
    # pad to list_size
    num_padding = max(list_size - length(list), 0)
    append!(list, fill(Int32(-1), num_padding))
    list
end

"""
    get_feature(f, i)

Return the feature column for item index `i`, or an all-zero `Float32` vector of the
same length when `i` is the padding sentinel `-1`.
"""
function get_feature(f::Features, i::Int32)
    i == -1 ? zeros(Float32, size(f.features, 1)) : f.features[:, i]
end

"""
    get_priority(f, i)

Return the priority column for item index `i`, or `[0, NaN, NaN]` when `i` is the
padding sentinel `-1`.
"""
function get_priority(f::Features, i::Int32)
    i == -1 ? Float16[0, NaN, NaN] : f.priorities[:, i]
end

"""
    get_sample(f, user, list_size)

Sample one list of `list_size` items for `user` and return `(features, prios)`:
the per-item feature columns and the per-item priority columns, both in list order.
Padding slots produced by `subsample` contribute zero features and `[0, NaN, NaN]`
priorities.
"""
function get_sample(f::Features, user::Int32, list_size::Int32)
    list = subsample(f, user, list_size)
    # `reduce(hcat, ...)` concatenates the columns in one pass; splatting the list into
    # `hcat(...)` would pass one argument per item (thousands for a typical list_size).
    features = reduce(hcat, [get_feature(f, q) for q in list])
    prios = MLUtils.batch(get_priority(f, i) for i in list)
    features, prios
end

"""
    get_epoch(f, training, list_size)

Sample one list per user and stack the results into `(Q, P)`.

Training users are used when `training` is true, test users otherwise; the users are
shuffled first. `Q` stacks the per-user feature matrices and `P` the per-user priority
matrices with `MLUtils.batch`, so users are indexed along the trailing dimension.
"""
function get_epoch(f::Features, training::Bool, list_size::Int32)
    users = Random.shuffle(training ? f.training_users : f.test_users)
    feats = Vector{Matrix{Float32}}(undef, length(users))
    prios = Vector{Matrix{Float16}}(undef, length(users))
    @showprogress for i in eachindex(users)
        feats[i], prios[i] = get_sample(f, users[i], list_size)
    end
    MLUtils.batch(feats), MLUtils.batch(prios)
end;

# Model

In [None]:
"""
    build_model()

Build the ranking model: the sum of a "log" `Utility` and a "linear" `Utility`, each
over `get_feature_size()` inputs.
"""
function build_model()
    # we constrain the model to be monotonic
    # TODO replace with monotonic networks paper
    num_features = get_feature_size()
    log_branch = Utility(num_features, "log")
    linear_branch = Utility(num_features, "linear")
    Join(sum, Split(log_branch, linear_branch))
end;

In [None]:
# Shorthands used with `|>` below: `device` is Flux's GPU mover and `cpu` is Flux's
# CPU mover. Models and batches are piped through `device` before training and
# evaluation, and the best model is piped through `cpu` before it is returned.
device = Flux.gpu;
cpu = Flux.cpu;

# Training

In [None]:
"""
    get_partial_ordering(P, i, interaction_weights)

Return the weighted adjacency matrix of the partial-ordering graph for list `i` of `P`.

Entry `(r, c)` is `interaction_weights[1]` when item `r` beats item `c` on priority
row 1. For each later row `j`, a pair that still has a zero entry in both directions is
given `interaction_weights[j]` if `r` beats `c` on row `j`. Comparisons involving `NaN`
are false, so such pairs stay unordered on that row. Diagonal entries are zero.
"""
function get_partial_ordering(P, i, interaction_weights)
    top = P[1, :, i]
    adjacency = (top .> top') .* interaction_weights[1]
    for level = 2:size(P, 1)
        prio = P[level, :, i]
        # only pairs not already ordered (in either direction) may be ordered at this level
        newly_ordered = (adjacency .== 0) .&& (adjacency' .== 0) .&& (prio .> prio')
        adjacency = adjacency + newly_ordered .* interaction_weights[level]
    end
    adjacency
end

"""
    get_position_weights(P, z)

Compute one weight per row of `P`: the `z`-weighted count of strictly positive entries
in that row, divided by the `z`-weighted sum of those counts. When that sum is zero the
unnormalized counts are returned as they are.
"""
function get_position_weights(P, z)
    has_edge = convert.(Float32, P .> 0)
    w = has_edge * z
    total = z' * w
    if total == 0
        return w
    end
    w ./ total
end

"""
    position_aware_list_mle_loss(m, x, priorities, interaction_weights)

Position-aware ListMLE loss, modified to handle non-comparable items and varying
comparison strengths. See [Position-Aware ListMLE: A Sequential Learning Process for
Ranking](https://auai.org/uai2014/proceedings/individuals/164.pdf).

The scores `m(x)` are flattened to one column per list and shifted by the column
maximum before exponentiation. For every list the partial ordering over its items is
built from `priorities`, and each item's term is weighted by `get_position_weights`.
Returns the mean loss over the lists in the batch.
"""
function position_aware_list_mle_loss(m, x, priorities, interaction_weights)
    scores = Flux.flatten(m(x))
    # subtract the per-list maximum so that exp cannot overflow
    scores = scores .- maximum(scores; dims = 1)
    exp_scores = exp.(scores)
    ϵ = Float32(eps(Float64))
    list_length, batch_size = size(scores)
    z = ones(Float32, list_length) |> device

    total = 0.0f0
    for b = 1:batch_size
        ordering = get_partial_ordering(priorities, b, interaction_weights)
        weights = get_position_weights(ordering, z)
        s = scores[:, b]
        e = exp_scores[:, b]
        # each item competes against itself plus the items it is ordered above
        unweighted_loss = -s + log.(ordering * e + e .+ ϵ)
        total += sum(weights .* unweighted_loss)
    end
    total / batch_size
end;

In [None]:
"""
    evalute_metrics(m, epoch, batch_size, interaction_weights)

Return the mean per-batch loss of model `m` over `epoch`, a `(features, priorities)`
pair as produced by `get_epoch`. Lists are processed in order, `batch_size` at a time.
"""
function evalute_metrics(m, epoch, batch_size::Int32, interaction_weights::Vector{Float32})
    # Lists are stacked along dimension 3 (the slices below index it). Dimension 2 is
    # the position within a list, so it must not be used as the list count.
    N = size(epoch[1], 3)
    loss = 0.0
    iters = 0
    @showprogress for idx in Iterators.partition(1:N, batch_size)
        df = epoch[1][:, :, idx] |> device
        dp = epoch[2][:, :, idx] |> device
        loss += position_aware_list_mle_loss(m, df, dp, interaction_weights)
        iters += 1
    end
    loss / iters
end;

In [None]:
"""
    train_epoch!(m, opt, epoch, batch_size, interaction_weights)

Run one pass of gradient updates over `epoch`, a `(features, priorities)` pair as
produced by `get_epoch`, visiting the lists in random order `batch_size` at a time.
Updates `m` and `opt` and returns the mean per-batch training loss.
"""
function train_epoch!(
    m,
    opt,
    epoch,
    batch_size::Int32,
    interaction_weights::Vector{Float32},
)
    # Lists are stacked along dimension 3 (the slices below index it). Dimension 2 is
    # the position within a list, so it must not be used as the list count.
    N = size(epoch[1], 3)
    loss = 0.0
    iters = 0
    indexes = Random.shuffle(collect(1:N))
    @showprogress for idx in Iterators.partition(indexes, batch_size)
        df = epoch[1][:, :, idx] |> device
        dp = epoch[2][:, :, idx] |> device
        tloss, grads = Flux.withgradient(m) do model
            position_aware_list_mle_loss(model, df, dp, interaction_weights)
        end
        Flux.update!(opt, m, grads[1])
        loss += tloss
        iters += 1
    end
    loss / iters
end;

In [None]:
"""
    train_model(f, learning_rate, weight_decay, batch_size, interaction_weights, list_size;
                max_epochs = 16, max_patience = 2)

Train a ranking model on the training users of `f` with early stopping on the test
users, and return a CPU copy of the model with the lowest test loss.

The optimiser chains Adam with a weight-decay rule scaled by `learning_rate`. One
training epoch and one test epoch are sampled up front and reused for every pass.
Training stops after `max_epochs` passes, or once the test loss has failed to improve
`max_patience` times in a row.
"""
function train_model(
    f::Features,
    learning_rate::Float32,
    weight_decay::Float32,
    batch_size::Int32,
    interaction_weights::Vector{Float32},
    list_size::Int32;
    max_epochs::Integer = 16,
    max_patience::Integer = 2,
)
    m = build_model() |> device
    opt = Optimisers.setup(
        Optimisers.OptimiserChain(
            Optimisers.Adam(learning_rate, (0.9f0, 0.999f0)),
            Optimisers.WeightDecay(learning_rate * weight_decay),
        ),
        m,
    )
    training = get_epoch(f, true, list_size)
    test = get_epoch(f, false, list_size)
    best_loss = Inf
    # Snapshot with a deep copy: if the model already lives on the CPU, `m |> cpu`
    # can share parameter arrays with `m`, and later in-place updates would silently
    # overwrite the "best" model.
    best_model = deepcopy(m |> cpu)
    patience = 0
    for _ = 1:max_epochs
        training_loss = train_epoch!(m, opt, training, batch_size, interaction_weights)
        test_loss = evalute_metrics(m, test, batch_size, interaction_weights)
        @info "Losses: $training_loss $test_loss"
        if test_loss < best_loss
            best_loss = test_loss
            best_model = deepcopy(m |> cpu)
            patience = 0
        else
            patience += 1
            if patience >= max_patience
                break
            end
        end
    end
    best_model
end;

# Run

In [None]:
# Seed the default RNG with a fixed value, then seed both CUDA RNGs with values
# drawn from it.
Random.seed!(20240213)
Random.seed!(CUDA.default_rng(), rand(UInt64))
Random.seed!(CUDA.CURAND.default_rng(), rand(UInt64))

In [None]:
# Load features, priorities and the train/test user split for `medium`.
f = load_features();

In [None]:
# Hyperparameters passed to `train_model`.
learning_rate = 1f-3
# multiplied by `learning_rate` before being handed to the weight-decay rule
weight_decay = 1f-2
# number of user lists per gradient step
batch_size = Int32(8)
# one weight per priority row, indexed by row in `get_partial_ordering`
interaction_weights = Float32[1, 1f-3, 0] # TODO ablate
# number of items per user list; shorter lists are padded in `subsample`
list_size = Int32(10240);

In [None]:
# Train with early stopping; `m` is the best model, moved back to the CPU.
m = train_model(f, learning_rate, weight_decay, batch_size, interaction_weights, list_size);

In [None]:
# Hand the trained model to `write_params` under the name "$medium/Ranking".
# NOTE(review): `write_params` is defined outside this notebook; presumably it persists
# the model to disk, and the meaning of the trailing `true` should be confirmed there.
write_params(Dict("m" => m), "$medium/Ranking", true)