## Write predictions

In [None]:
"""
    evaluate(hyp::Hyperparams, m, users::Vector{Int32}, items::Vector{Int32})

Return a ratings dataset of model `m`'s predictions for the `(users[i], items[i])`
pairs, choosing the evaluation routine from `hyp.output_data`:

- `"allitems"` → `evaluate_allitems`
- `"item"`     → `evaluate_item`

# Throws
- `ArgumentError` if `hyp.output_data` is any other value.
"""
function evaluate(hyp::Hyperparams, m, users::Vector{Int32}, items::Vector{Int32})
    if hyp.output_data == "allitems"
        return evaluate_allitems(hyp, m, users, items)
    elseif hyp.output_data == "item"
        return evaluate_item(hyp, m, users, items)
    end
    # `@assert` may be compiled out and carries no message; an unsupported
    # configuration must always fail loudly and say what was wrong.
    throw(ArgumentError(
        "unsupported output_data $(repr(hyp.output_data)); " *
        "expected \"allitems\" or \"item\"",
    ))
end;

In [None]:
"""
    evaluate_item(hyp::Hyperparams, m, users::Vector{Int32}, items::Vector{Int32})

Predict one rating per `(users[i], items[i])` pair with model `m`.

Returns a `RatingsDataset` whose `user`/`item` columns are the inputs unchanged and
whose `rating` column holds the predictions. When `hyp.implicit` is true the raw model
output is passed through `sigmoid`; otherwise it is used as-is. Any pair not written by
a batch stays `NaN32`.

The global `G` is set to `hyp` for the duration of the call. NOTE(review): presumably
helpers such as `get_batch` read it — confirm. `G` is always reset to `nothing`
afterwards, even if prediction throws.
"""
function evaluate_item(hyp::Hyperparams, m, users::Vector{Int32}, items::Vector{Int32})
    global G = hyp
    try
        m = m |> device
        # get model inputs
        epoch = (
            get_epoch_inputs(
                users,
                items,
                G.input_data,
                G.output_data,
                G.implicit,
                G.num_users,
                G.input_alphas,
            ),
            nothing,
            nothing,
            nothing,
        )
        activation = G.implicit ? sigmoid : identity
        # One slot per input pair. (Previously sized from `out_users`, a name that
        # is not defined in this function.)
        ratings = fill(NaN32, length(users))

        # compute predictions; `order` gives the positions in `ratings` for this batch
        @showprogress for iter = 1:Int(ceil(epoch_size(epoch) / G.batch_size))
            batch, order = get_batch(epoch, iter, G.batch_size, false)
            alpha = activation(m(batch[1][1])) |> cpu
            ratings[order] .= vec(alpha)
        end

        return RatingsDataset(user = users, item = items, rating = ratings)
    finally
        # Don't leak the hyperparameters into global state if prediction fails.
        global G = nothing
    end
end;

In [None]:
# returns a vector that maps a user to the list of items to predict
"""
    user_to_items(users::Vector, items::Vector)

Group the `(users[i], items[i])` pairs by user.

Returns a `Vector{Vector{Int32}}` of length `num_users()`. Entry `u` holds every
`items[i]` with `users[i] == u`, in reverse order of appearance; users with no pairs
get an empty vector.

# Throws
- `DimensionMismatch` if `users` and `items` have different lengths.
- `BoundsError` if a user id lies outside `1:num_users()`.
"""
function user_to_items(users::Vector, items::Vector)
    # Count pairs per user. This is a cheap sequential pass. The previous version wrote
    # per-`threadid()` columns inside `Threads.@threads`, which can index past
    # `nthreads()` (interactive threads) and race if tasks migrate between threads.
    counts = zeros(Int32, num_users())
    @showprogress for u in users
        counts[u] += 1
    end

    # Preallocate each user's item list to its exact size.
    utoa = [Vector{Int32}(undef, c) for c in counts]

    # Fill each list from the back, using the remaining count as the write position;
    # this is why items end up in reverse order of appearance.
    @showprogress for i in eachindex(users, items)
        u = users[i]
        utoa[u][counts[u]] = items[i]
        counts[u] -= 1
    end
    return utoa
end;

In [None]:
"""
    evaluate_allitems(hyp::Hyperparams, m, users::Vector{Int32}, items::Vector{Int32})

Predict ratings for the `(users[i], items[i])` pairs using a model whose output is
indexed as `alpha[item, j]`: one column per sampled user `j` in a batch, with rows
indexed by item id.

For each sampled user, the scores of that user's requested items (from
`user_to_items`) are copied out. The returned `RatingsDataset` therefore lists pairs
grouped by user in batch-sampling order, not in the order of the inputs. When
`hyp.implicit` is true the model output goes through `softmax` first (presumably over
each column with the default `dims=1` — confirm); otherwise raw scores are used.

The global `G` is set to `hyp` while running and always reset to `nothing` afterwards,
even if prediction throws.
"""
function evaluate_allitems(hyp::Hyperparams, m, users::Vector{Int32}, items::Vector{Int32})
    global G = hyp
    try
        m = m |> device
        utoa = user_to_items(users, items)
        # get model inputs
        epoch = (
            get_epoch_inputs(
                users,
                items,
                G.input_data,
                G.output_data,
                G.implicit,
                G.num_users,
                G.input_alphas,
            ),
            nothing,
            nothing,
            nothing,
        )
        activation = G.implicit ? softmax : identity

        # Every input pair appears in exactly one user's list in `utoa`, so the output
        # holds exactly `length(users)` rows.
        n = length(users)
        out_users = Vector{Int32}(undef, n)
        out_items = Vector{Int32}(undef, n)
        out_ratings = fill(NaN32, n)
        out_idx = 1

        # compute predictions
        @showprogress for iter = 1:Int(ceil(epoch_size(epoch) / G.batch_size))
            batch, sampled_users = get_batch(epoch, iter, G.batch_size, false)
            alpha = activation(m(batch[1][1])) |> cpu
            for (j, u) in enumerate(sampled_users)
                item_mask = utoa[u]
                isempty(item_mask) && continue
                span = out_idx:(out_idx + length(item_mask) - 1)
                out_users[span] .= u
                out_items[span] = item_mask
                out_ratings[span] = @view alpha[item_mask, j]
                out_idx = last(span) + 1
            end
        end

        return RatingsDataset(user = out_users, item = out_items, rating = out_ratings)
    finally
        # Don't leak the hyperparameters into global state if prediction fails.
        global G = nothing
    end
end;

In [None]:
"""
    write_alpha(hyp::Hyperparams, m, outdir::String)

Write model `m`'s predictions to `outdir`.

Builds a scoring closure `(users, items) -> Vector` that looks up `m`'s predicted rating
for each `(users[k], items[k])` pair. It then delegates to the
`write_alpha(model, residual_alphas, implicit, outdir; log_splits)` method defined
elsewhere. `hyp.num_users` is overridden with `num_users()` before predicting.
"""
function write_alpha(hyp::Hyperparams, m, outdir::String)
    # Bind the adjusted hyperparameters to a name that is never reassigned, so the
    # closure below captures a plain binding rather than a boxed variable.
    hyp_full = @set hyp.num_users = num_users()

    function model(users, items)
        # Predictions are turned into a sparse matrix and looked up as [item, user].
        preds = sparse(evaluate(hyp_full, m, users, items))
        scores = zeros(length(users))
        @tprogress Threads.@threads for k = 1:length(scores)
            scores[k] = preds[items[k], users[k]]
        end
        return scores
    end

    return write_alpha(
        model,
        hyp_full.residual_alphas,
        hyp_full.implicit,
        outdir;
        log_splits = hyp_full.content,
    )
end;