# Baseline Helper
* Functions that are shared between training and inference

In [None]:
import Statistics: mean

In [None]:
"""
    num_threads()

Return how many threads parallel work is split across: `Threads.nthreads()`,
but never more than 8, because using too many threads causes OOM errors.
"""
num_threads() = min(8, Threads.nthreads())

In [None]:
"""
    make_prediction(users, items, u, a)

Return a vector with one baseline prediction per row `k`:

    u[users[k] + 1] + a[items[k] + 1]

The ids in `users` and `items` are stored 0-based (hence the `+ 1`). The result
has element type `eltype(u)` and is filled in parallel with `Threads.@threads`.

Throws `DimensionMismatch` if `users` and `items` have different lengths.
"""
function make_prediction(users, items, u, a)
    # Validate up front: the loop below reads `users` and `items` with @inbounds.
    length(users) == length(items) || throw(
        DimensionMismatch(
            "users and items differ in length: $(length(users)) vs $(length(items))",
        ),
    )
    r = Array{eltype(u)}(undef, length(users))
    Threads.@threads for k in eachindex(r)
        # NOTE(review): the ids themselves come from data and are trusted to be
        # in range for `u` and `a`.
        @inbounds r[k] = u[users[k] + 1] + a[items[k] + 1]
    end
    return r
end;

In [None]:
"""
    get_residuals!(users, items, ratings, weights, a, ρ, Ω)

Accumulate weighted residuals per user into `ρ` and the matching weight totals
into `Ω`.

For every row, with `i = users[row] + 1` and `j = items[row] + 1` (ids are stored
0-based), this adds `(ratings[row] - a[j]) * weights[row]` to `ρ[i]` and
`weights[row]` to `Ω[i]`. Existing contents of `ρ` and `Ω` are added to, not
overwritten.

Returns the tuple `(ρ, Ω)`.

Throws `DimensionMismatch` if `users`, `items`, `ratings` and `weights` do not share
the same indices.
"""
function get_residuals!(users, items, ratings, weights, a, ρ, Ω)
    # Multi-array eachindex verifies the four row arrays line up, which makes the
    # @inbounds row reads below safe. It is evaluated outside @inbounds on purpose.
    rows = eachindex(users, items, ratings, weights)
    @inbounds for row in rows
        i = users[row] + 1
        j = items[row] + 1
        w = weights[row]
        # NOTE(review): i and j come from data and are trusted to be in range of
        # ρ/Ω and a respectively.
        ρ[i] += (ratings[row] - a[j]) * w
        Ω[i] += w
    end
    return ρ, Ω
end

"""
    thread_range(tid, n, nt = num_threads())

Return the `tid`-th of `nt` contiguous, non-overlapping chunks that partition `1:n`.

Chunks appear in order and differ in length by at most one. The first `n % nt`
chunks get the extra element. When `n < nt`, the trailing `nt - n` chunks are
empty. `nt` defaults to `num_threads()`, so existing callers get the same
partition as before.

Throws `ArgumentError` unless `1 ≤ tid ≤ nt`.
"""
function thread_range(tid, n, nt = num_threads())
    1 ≤ tid ≤ nt || throw(ArgumentError("tid must be in 1:$nt, got $tid"))
    d, r = divrem(n, nt)
    # Chunks before `tid` contribute (tid - 1) * d elements, plus one extra each
    # for those among the first r chunks.
    from = (tid - 1) * d + min(r, tid - 1) + 1
    to = from + d - 1 + (tid ≤ r ? 1 : 0)
    return from:to
end;

"""
    update_users!(users, items, ratings, weights, u, a, λ_u, ρ, Ω; μ = nothing)

Recompute every user offset `u[i]` in place from the current item offsets `a`:

    u[i] = (ρ[i] + μ * λ_u) / (Ω[i] + λ_u)

Here `ρ[i]` is the sum of `(rating - a[item]) * weight` over user `i`'s rows and
`Ω[i]` is the sum of those rows' weights. The result is the weighted mean of the
residuals, pulled towards `μ` as if by `λ_u` extra observations of value `μ`.

# Arguments
- `users`, `items`, `ratings`, `weights`: one entry per row. User/item ids are
  stored 0-based, as `get_residuals!` expects.
- `u`: user offsets, overwritten.
- `a`: item offsets, only read.
- `λ_u`: regularization strength.
- `ρ`, `Ω`: scratch matrices with one column per task (`num_threads()` columns) and
  at least `length(u)` rows (presumably exactly one row per user; confirm against
  the caller). Their contents are overwritten.

# Keywords
- `μ`: value the offsets are pulled towards. Defaults to `mean(u)`, taken before `u`
  is overwritten.

Returns `nothing`. Throws `DimensionMismatch` if `ρ` or `Ω` has fewer than
`length(u)` rows.
"""
function update_users!(users, items, ratings, weights, u, a, λ_u, ρ, Ω; μ = nothing)
    n = length(u)
    # The final loop reads the reductions with @inbounds, so check their height here.
    (size(ρ, 1) ≥ n && size(Ω, 1) ≥ n) || throw(
        DimensionMismatch("ρ and Ω need at least $n rows, got $(size(ρ, 1)) and $(size(Ω, 1))"),
    )

    # Phase 1: partial sums. Each task owns column t of ρ/Ω, so no locking is needed.
    # ρ and Ω are deliberately never rebound in this function. A variable that is
    # captured by the closure @threads builds and also reassigned gets boxed, which
    # made these loops type-unstable.
    Threads.@threads :static for t = 1:num_threads()
        range = thread_range(t, length(ratings))
        ρ[:, t] .= 0
        Ω[:, t] .= 0
        @views get_residuals!(
            users[range],
            items[range],
            ratings[range],
            weights[range],
            a,
            ρ[:, t],
            Ω[:, t],
        )
    end

    # Phase 2: combine per-task partial sums and solve for every user.
    ρ_total = sum(ρ, dims = 2)
    Ω_total = sum(Ω, dims = 2)
    # Resolve the prior before u is overwritten below; fresh name avoids boxing.
    μ_prior = isnothing(μ) ? mean(u) : μ
    Threads.@threads for i = 1:n
        @inbounds u[i] = (ρ_total[i] + μ_prior * λ_u) / (Ω_total[i] + λ_u)
    end
    return nothing
end;