## Write predictions

In [31]:
"""
    user_to_items(users, items)

Return a `Dict` that maps each user to the list of items they have watched.

`users` and `items` are parallel collections: `items[j]` is an item watched by
`users[j]`. Each user's items are stored in order of appearance, duplicates included.

The result is concretely typed as `Dict{eltype(users),Vector{eltype(items)}}` so that
downstream indexing with the per-user item lists does not go through `Vector{Any}`.
"""
function user_to_items(users, items)
    Item = eltype(items)
    utoa = Dict{eltype(users),Vector{Item}}()
    # making this multithreaded is slower
    @showprogress for j in eachindex(users)
        # a single hash lookup: fetch the user's list, creating it on first sight
        watched = get!(utoa, users[j]) do
            Item[]
        end
        push!(watched, items[j])
    end
    return utoa
end;

In [32]:
"""
    evaluate!(hyp, m, users, items, ratings)

Return a `RatingsDataset` of predicted ratings for the given `(user, item)` pairs.

The input buffers are reused for the output: `users`, `items` and `ratings` are
overwritten in place, grouped by user in the order the batches yield them. `ratings`
is first filled with `NaN32`, so any slot that is never written stays `NaN32`.

`hyp` is published as the global `G` while the predictions are computed, and `G` is
reset to `nothing` afterwards, even if an error is thrown part-way through.

# Arguments
- `hyp`: hyperparameters; the fields `input_data`, `implicit`, `num_users`,
  `input_alphas` and `batch_size` are read.
- `m`: the model; it is piped through `device` and called on each batch input.
- `users`, `items`, `ratings`: parallel, mutable buffers describing the pairs to score.
"""
function evaluate!(hyp, m, users, items, ratings)
    global G = hyp
    try
        # get model inputs
        m = m |> device
        utoa = user_to_items(users, items)
        epoch = [
            get_epoch_inputs(G.input_data, G.implicit, G.num_users, G.input_alphas),
        ]
        activation = G.implicit ? softmax : identity

        # reuse the input buffers for space efficiency; `utoa` already holds its own
        # copy of the (user, item) pairs, so overwriting the inputs is safe
        # NOTE(review): slots past the last written index keep their original
        # user/item values with a NaN32 rating -- confirm that every user in `users`
        # is sampled exactly once by the batches.
        ratings .= NaN32
        out_idx = 1

        # compute predictions
        num_batches = ceil(Int, G.num_users / G.batch_size)
        @showprogress for iter in 1:num_batches
            batch, sampled_users = get_batch(epoch, iter, G.batch_size, false)
            alpha = activation(m(batch[1][1])) |> cpu
            for (j, u) in enumerate(sampled_users)
                # users without any requested items are skipped
                item_mask = get(utoa, u, nothing)
                item_mask === nothing && continue
                next_idx = out_idx + length(item_mask)
                span = out_idx:(next_idx - 1)
                users[span] .= u
                items[span] = item_mask
                # column j of alpha holds the predictions for sampled user j
                ratings[span] = alpha[item_mask, j]
                out_idx = next_idx
            end
        end
    finally
        # never leave the hyperparameters behind in global state, even on error
        global G = nothing
    end
    return RatingsDataset(user = users, item = items, rating = ratings)
end;

In [33]:
"""
    write_alpha(hyp::Hyperparams, m, outdir)

Predict ratings with model `m` for every pair in `all_raw_splits` and write them out.

Each split is loaded with `get_split`, the splits are combined with `cat`, scored by
`evaluate!`, packed into a sparse matrix with users as rows and items as columns, and
handed to the `write_alpha(sparse_preds, residual_alphas, implicit, outdir)` method.
"""
function write_alpha(hyp::Hyperparams, m, outdir)
    pieces = [get_split(raw_split, hyp.implicit) for raw_split in all_raw_splits]
    combined = reduce(cat, pieces)
    scored = evaluate!(hyp, m, combined.user, combined.item, combined.rating)
    alphas = sparse(scored.user, scored.item, scored.rating)
    return write_alpha(alphas, hyp.residual_alphas, hyp.implicit, outdir)
end;