# MLE Training

In [None]:
import NBInclude: @nbinclude
@nbinclude("MLE.Base.ipynb");

## Define the `Features` Subtype

In [None]:
# Features for MLE training. `get_features` concatenates the observations of every
# (split, content) pair in one fixed order; an observation's position in that order (its
# "index") addresses the columns of `priorities` and the entries of `index_to_item` and
# `index_to_training`.
@with_kw struct TrainingFeatures <: Features
    # Sparse user features; column u belongs to user u. As built by `get_user_features`,
    # the implicit and explicit training ratings are stacked vertically, each block
    # having num_items() + 1 rows.
    user_features::SparseMatrixCSC{Float32,Int32}
    # One priority vector per observation (column idx); `get_priority` returns a column
    # and `comparator` passes two of them to `compare`.
    priorities::Matrix{Float16}

    # Observation index -> item id.
    index_to_item::Vector{Int32}
    # Observation index -> true if the observation comes from the "training" split.
    index_to_training::Vector{Bool}
    # (item, user) -> observation index; 0 means the pair was not observed.
    item_user_index::SparseMatrixCSC{Int32,Int32}

    # User -> observation indexes selected by a `split == "training"` predicate
    # (built by `get_user_to_indexes`; presumably only users with observations appear).
    user_to_training_indexes::Dict{Int32,Vector{Int32}}
    # User -> observation indexes selected by a `split == "validation"` predicate.
    user_to_validation_indexes::Dict{Int32,Vector{Int32}}
    # User -> set of item ids behind that user's training observations.
    user_to_training_items::Dict{Int32,Set{Int32}}
end;

In [None]:
"""
    get_inference_data(f::Features)

Return an empty, untyped `Dict()`: these features contribute no inference data.
"""
get_inference_data(f::Features) = Dict();

In [None]:
# Sparse (num_items() + 1) × num_users() matrix of the "training" ratings for one content
# type: rows are item ids, columns are user ids.
function _training_rating_matrix(content::AbstractString)
    ratings = get_split("training", content)
    nrows, ncols = num_items() + 1, num_users()
    return sparse(ratings.item, ratings.user, ratings.rating, nrows, ncols)
end

"""Training-split implicit ratings as a sparse item × user matrix."""
get_implicit_features() = _training_rating_matrix("implicit")

"""Training-split explicit ratings as a sparse item × user matrix."""
get_explicit_features() = _training_rating_matrix("explicit")

"""
    get_user_features()

Stack the implicit ratings on top of the explicit ratings, giving a
2(num_items() + 1) × num_users() sparse matrix whose columns are per-user features.
"""
function get_user_features()
    implicit = get_implicit_features()
    explicit = get_explicit_features()
    return vcat(implicit, explicit)
end;

In [None]:
"""
    get_features(alphas::Vector{String})

Build the `TrainingFeatures` used to train this model.

Observations from every (split, content) pair are concatenated in one fixed order:
splits `"training"` then `"validation"`, and within each split the entries of
`all_contents` except `"negative"`. An observation's position in that order is its
*index*. Column `idx` of `priorities`, `index_to_item[idx]`, `index_to_training[idx]` and
the value stored in `item_user_index[item, user]` all describe the same observation.

# Arguments
- `alphas`: must be empty; this model consumes no alpha features.

# Throws
- `ArgumentError` if `alphas` is not empty.
"""
function get_features(alphas::Vector{String})
    isempty(alphas) ||
        throw(ArgumentError("get_features takes no alphas, got $(length(alphas))"))
    contents = filter(x -> x != "negative", all_contents)
    splits = ["training", "validation"]
    # The single ordering every concatenation below relies on.
    split_contents = [(split, content) for split in splits for content in contents]

    user_to_training_indexes =
        get_user_to_indexes(split_contents, (split, content) -> split == "training")
    user_to_validation_indexes =
        get_user_to_indexes(split_contents, (split, content) -> split == "validation")

    # Concatenate f(split, content) over all pairs, in `split_contents` order.
    hreduce(f; agg = hcat) =
        reduce(agg, f(split, content) for (split, content) in split_contents)
    user_features = get_user_features()
    priorities = hreduce(get_priorities)
    index_to_item =
        hreduce((split, content) -> get_raw_split(split, content).item; agg = vcat)
    index_to_training = hreduce(
        (split, content) -> fill(
            split == "training",
            length(get_raw_split(split, content).item),
        );
        agg = vcat,
    )

    # (item, user) -> observation index, 0 if unobserved. It needs num_items() + 1 rows
    # because get_sample may put the placeholder item num_items() + 1 into a list and
    # get_priority then looks it up. Every per-split matrix must have the same shape,
    # otherwise `+=` throws a DimensionMismatch.
    # NOTE(review): a pair observed in more than one (split, content) gets the sum of
    # its indexes here; this assumes each pair appears in at most one of them.
    item_user_index = sparse(Int32[], Int32[], Int32[], num_items() + 1, num_users())
    offset = 0
    for (split, content) in split_contents
        df = get_raw_split(split, content)
        nobs = length(df.item)
        sp = sparse(df.item, df.user, fill(1, nobs), num_items() + 1, num_users())
        # Fresh binding per iteration so the threaded closure does not capture the
        # reassigned `offset` (which would force a boxed, type-unstable variable).
        start = offset
        # Overwrite the placeholder 1s with global observation indexes. Each write hits
        # an already-stored entry, so the sparsity structure is never modified; this
        # assumes (item, user) pairs are unique within one split/content.
        @tprogress Threads.@threads for i = 1:nobs
            sp[df.item[i], df.user[i]] = i + start
        end
        item_user_index += sp
        offset += nobs
    end

    # Every user gets an entry (empty when they have no training observations).
    user_to_training_items = Dict{Int32,Set{Int32}}()
    for u = 1:num_users()
        idxs = get(user_to_training_indexes, u, Int32[])
        user_to_training_items[u] = Set{Int32}(index_to_item[i] for i in idxs)
    end

    TrainingFeatures(
        user_features = user_features,
        priorities = priorities,
        index_to_item = index_to_item,
        index_to_training = index_to_training,
        item_user_index = item_user_index,
        user_to_training_indexes = user_to_training_indexes,
        user_to_validation_indexes = user_to_validation_indexes,
        user_to_training_items = user_to_training_items,
    )
end

"""
    get_user_embedding(u, list, f)

Return user `u`'s column of `f.user_features` with the features of every item in `list`
set to zero. The column is a copy, so `f` is not modified.

An item `a` owns positions `a, a + (num_items() + 1), …`, which is its slot in each
(num_items() + 1)-row block of the stacked user features.
"""
function get_user_embedding(u::Integer, list::Vector{Int32}, f::Features)
    embedding = f.user_features[:, u]
    for item in list, pos in item:(num_items() + 1):length(embedding)
        embedding[pos] = 0
    end
    return embedding
end

"""
    get_item_embedding(q, f)

Items are represented by their integer id. `q` is returned unchanged and is consumed as
an id by the model's item input.
"""
get_item_embedding(q::Integer, f::Features) = q;

In [None]:
"""
    get_priority(f, u, i, training)

Return the priority vector of user `u`'s observation of item `i`.

The sentinel `Float16[0, NaN, NaN, NaN]` is returned when the pair is unobserved. It is
also returned when `training` is true and the observation is not from the training split.
"""
function get_priority(f::Features, u::Integer, i::Integer, training::Bool)
    idx = f.item_user_index[i, u]
    # `idx != 0` is checked first so index_to_training is never indexed with 0.
    visible = idx != 0 && (!training || f.index_to_training[idx])
    return visible ? f.priorities[:, idx] : Float16[0, NaN, NaN, NaN]
end

"""
    comparator(f, u, i, j, training)

Compare user `u`'s priorities for items `i` and `j` using `compare`.
"""
comparator(f::Features, u::Integer, i::Integer, j::Integer, training::Bool) =
    compare(get_priority(f, u, i, training), get_priority(f, u, j, training))

"""
    get_sample(f, training, list_size)

Draw one example. It picks a random user who has observations in the chosen split (the
training split when `training`, validation otherwise), then builds a list of `list_size`
distinct item ids seeded with one item the user observed in that split.

When `training` is false, items the user has in training are replaced by the placeholder
id `num_items() + 1`.

Returns `(u_embs, a_embs, prefs)`:
- `u_embs`: the user embedding repeated in one column per list entry;
- `a_embs`: the list's item ids;
- `prefs`: the result of `get_preferences` over the list.
"""
function get_sample(f::Features, training::Bool, list_size::Integer)
    user_to_idxs = training ? f.user_to_training_indexes : f.user_to_validation_indexes
    # Rejection-sample users until one has observations in this split.
    user = let candidate = rand(1:num_users())
        while !haskey(user_to_idxs, candidate)
            candidate = rand(1:num_users())
        end
        candidate
    end
    observed_item = f.index_to_item[rand(user_to_idxs[user])]
    # Distinct random items (the placeholder id included), seeded with the observed one.
    items = sample(Int32(1):Int32(num_items() + 1), list_size; replace = false)
    observed_item in items || (items[1] = observed_item)
    if !training
        # Hide items the user already has in training behind the placeholder id.
        training_items = f.user_to_training_items[user]
        replace!(q -> q in training_items ? Int32(num_items() + 1) : q, items)
    end
    prefs = get_preferences(items, (i, j) -> comparator(f, user, i, j, training))
    # Batch the input features: one copy of the user column per list entry.
    u_embs = hcat(fill(get_user_embedding(user, items, f), list_size)...)
    a_embs = Int32[get_item_embedding(q, f) for q in items]
    return u_embs, a_embs, prefs
end;

In [None]:
"""
    get_batch(f, training, list_size, batch_size, holdout)

Draw `batch_size` samples with `get_sample` and move them to the device.

Returns `((U, A), P)`:
- `U`: the user features reshaped to (features, list_size, batch_size);
- `A`: the item ids, hcat-ed to (list_size, batch_size);
- `P`: the preferences combined with `Flux.batch`.

When `training`, one random item mask is drawn per batch. Items whose uniform draw is
below `holdout` have their features zeroed in every column of `U`.
"""
function get_batch(
    f::Features,
    training::Bool,
    list_size::Integer,
    batch_size::Integer,
    holdout::Float32,
)
    samples = [get_sample(f, training, list_size) for _ in 1:batch_size]
    user_batch = SparseMatrixCSC{Float32,Int32}[s[1] for s in samples]
    item_batch = Vector{Int32}[s[2] for s in samples]
    pref_batch = Matrix{Int32}[s[3] for s in samples]

    # move to GPU
    U = device(hcat(user_batch...))
    A = device(hcat(item_batch...))
    P = device(Flux.batch(pref_batch))
    nfeatures, ncols = size(U)
    tsize = (nfeatures, ncols ÷ batch_size, batch_size)
    if training
        randfn = CUDA.functional() ? CUDA.rand : rand
        keep = randfn(num_items() + 1) .>= holdout
        # Tile the per-item mask over each stacked block of user features.
        U .*= repeat(keep, nfeatures ÷ length(keep))
    end
    return (reshape(U, tsize), A), P
end;

In [None]:
"""
    build_model(hyp::Hyperparams)

Build the scoring network with hidden width `K = hyp.embedding_size`.

The first stage takes the stacked user features (2(num_items() + 1) rows) and the item
ids. The combined representation then passes through two ReLU layers to a single output
score. All `Dense` layers use the `in => out` constructor; the positional
`Dense(in, out, σ)` form is deprecated in Flux.
"""
function build_model(hyp::Hyperparams)
    K = hyp.embedding_size
    return Chain(
        # NOTE(review): `Join` is defined outside this notebook. Presumably it feeds the
        # user features to the Dense layer and the item ids to the Embedding, then
        # combines the two K-sized outputs with vcat (hence the K * 2 input below).
        Join(
            vcat,
            Dense((num_items() + 1) * 2 => K),
            Embedding((num_items() + 1) => K; init = Flux.glorot_uniform),
        ),
        Dense(K * 2 => K, relu),
        Dense(K => K ÷ 2, relu),
        Dense(K ÷ 2 => 1),
    )
end;

In [None]:
"""
    inference_model(m, hyp, f, split, content; raw_splits = true)

Score every row of the `split`/`content` data with model `m`, in batches of
`hyp.batch_size`. Rows come from `get_raw_split` when `raw_splits`, else from
`get_split`. The training split is not scored; a zero vector of the same length is
returned instead.

Returns a `Vector{Float32}` with one score per row.
"""
function inference_model(
    m,
    hyp::Hyperparams,
    f::Features,
    split::String,
    content::String;
    raw_splits = true,
)
    @info "making predictions for $split $content"
    df = raw_splits ? get_raw_split(split, content) : get_split(split, content)
    split == "training" && return zeros(Float32, length(df.item))

    nrows = length(df.item)
    output = Array{Float32}(undef, nrows)
    batches = collect(Iterators.partition(1:nrows, hyp.batch_size))
    @showprogress for batch in batches
        # No items are masked out of the user embedding at inference time.
        u_embs = SparseVector{Float32,Int32}[
            get_user_embedding(df.user[i], Int32[], f) for i in batch
        ]
        a_embs = Int32[get_item_embedding(df.item[i], f) for i in batch]
        inputs = device(hcat(u_embs...)), device(a_embs)
        output[batch] .= cpu(vec(m(inputs)))
        device_free!(inputs)
    end
    return output
end;

## Train model

In [None]:
# Hyperparameters for this model. `alphas` is empty because `get_features` accepts no
# alphas. NOTE(review): holdout, l2penalty and learning_rate are NaN placeholders;
# presumably `create_hyperparams` fills them from the 3-element vector below — confirm
# against its definition in MLE.Base.
hyp = Hyperparams(
    alphas = String[],
    embedding_size = 256,
    batch_size = 1024,
    holdout = NaN,
    l2penalty = NaN,
    learning_rate = NaN,
    list_size = 16,
    seed = 20220609,
)
hyp = create_hyperparams(hyp, Float32[0, 0, 0])

In [None]:
# NOTE(review): `train_alpha` is defined outside this notebook (presumably MLE.Base);
# "MLE.Training" appears to be the name this model is trained under — confirm.
train_alpha(hyp, "MLE.Training")