# User Item Biases With Regularization
* Prediction for user $i$ and item $j$ is $\tilde r_{ij} = u_i + a_j$
* Loss function is $L = \sum_{\Omega}w_{ij}(r_{ij} - u_i - a_j)^2 + \lambda_u \sum_i (u_i - \bar u) ^2 + \lambda_a \sum_j (a_j - \bar a)^2 $
* $\bar u$ is the mean of $u_i$ and $\bar a$ is the mean of $a_j$ 
* $\Omega$ is the set of observed pairs $(i, j)$
* $r_{ij}$ is the rating for user $i$ and item $j$
* $w_{ij}$ is the weight for the prediction $r_{ij}$ and is modeled as a power law in the number of items seen by user $i$ and the number of users that have seen item $j$: $w_{ij} = |\{j' : (i, j') \in \Omega\}| ^ {\lambda_{wu}} |\{i' : (i', j) \in \Omega\}| ^ {\lambda_{wa}}$

In [1]:
# Identifier of this model. It is not referenced in the visible cells
# (presumably consumed by helpers pulled in from Alpha.ipynb — TODO confirm).
const name = "UserItemBiases"
# Alphas forwarded to `get_residuals` and `write_predictions`; empty for this model.
const residual_alphas = []
# Weighting scheme passed to `get_weights` when scoring the validation set.
const validation_weight_scheme = "inverse";

In [2]:
using NBInclude
@nbinclude("Alpha.ipynb");

In [3]:
# Load the training and validation splits through `get_residuals`
# (not defined in this file; presumably provided by Alpha.ipynb — verify).
# The cells below read the `user`, `item` and `rating` fields of these objects.
const training = get_residuals("training", residual_alphas)
const validation = get_residuals("validation", residual_alphas);

## Alternating Least Squares Algorithm
* $u_i = \dfrac{\sum_{j \in \Omega_i}(r_{ij} - a_j) w_{ij} + \bar u \lambda_u}{ \sum_{j \in \Omega_i} w_{ij} + \lambda_u}$
* $\Omega$ is the set of (user, item) pairs that we have ratings for
* $\Omega_i$ is the subset of $\Omega$ for which the user is the $i$-th user

In [4]:
"""
    get_residuals!(users, items, ratings, weights, a, ρ, Ω)

Accumulate the weighted residual sums needed by the least-squares update.

For every row, adds `(ratings[row] - a[items[row]]) * weights[row]` to
`ρ[users[row]]` and `weights[row]` to `Ω[users[row]]`. The accumulators are
mutated in place (they are not reset here) and returned as `(ρ, Ω)`.
"""
function get_residuals!(users, items, ratings, weights, a, ρ, Ω)
    for row in eachindex(users)
        i = users[row]
        j = items[row]
        w = weights[row]
        ρ[i] += (ratings[row] - a[j]) * w
        Ω[i] += w
    end
    return ρ, Ω
end

"""
    thread_range(n, tid = Threads.threadid(), nt = Threads.nthreads())

Return the `tid`-th of `nt` contiguous, near-equal chunks of `1:n`.

The first `n % nt` chunks receive one extra element; a chunk may be empty when
`n < nt`. The defaults reproduce the original one-argument behaviour, but
callers running inside `Threads.@threads` should pass the chunk index
explicitly, because a task is not guaranteed to run on the thread whose id
equals its loop index.
"""
function thread_range(n, tid = Threads.threadid(), nt = Threads.nthreads())
    d, r = divrem(n, nt)
    from = (tid - 1) * d + min(r, tid - 1) + 1
    to = from + d - 1 + (tid ≤ r ? 1 : 0)
    return from:to
end

"""
    update_users!(users, items, ratings, weights, u, a, λ_u, ρ, Ω)

Perform one alternating-least-squares update of the biases `u` in place,
holding the opposite biases `a` fixed:

    u[i] = (Σ_j (r_ij - a_j) w_ij + mean(u) λ_u) / (Σ_j w_ij + λ_u)

`ρ` and `Ω` are scratch matrices of size `length(u) × k`; the rows of the
rating table are split into `k` chunks and chunk `t` accumulates into column
`t`, so no two tasks write to the same memory. The mean of `u` is taken from
the values before this update. Returns `nothing`.
"""
function update_users!(users, items, ratings, weights, u, a, λ_u, ρ, Ω)
    nchunks = size(ρ, 2)
    size(Ω, 2) == nchunks ||
        throw(DimensionMismatch("ρ and Ω must have the same number of columns"))
    nrows = length(ratings)

    # Index chunks and buffer columns by the loop variable, not by
    # Threads.threadid(): with dynamic scheduling two iterations may share a
    # thread id, which would process one chunk twice and skip another.
    Threads.@threads for t = 1:nchunks
        rows = thread_range(nrows, t, nchunks)
        ρ_t = view(ρ, :, t)
        Ω_t = view(Ω, :, t)
        ρ_t .= 0
        Ω_t .= 0
        @views get_residuals!(
            users[rows],
            items[rows],
            ratings[rows],
            weights[rows],
            a,
            ρ_t,
            Ω_t,
        )
    end

    # Fresh names (instead of rebinding ρ/Ω) keep the variables captured by
    # the threaded closures from being boxed.
    ρ_total = sum(ρ, dims = 2)
    Ω_total = sum(Ω, dims = 2)

    μ = mean(u)
    Threads.@threads for i in eachindex(u)
        u[i] = (ρ_total[i] + μ * λ_u) / (Ω_total[i] + λ_u)
    end
    return nothing
end;

In [5]:
"""
    train_model(training, stop_criteria, λ_u, λ_a, λ_wu, λ_wa)

Fit user biases `u` and item biases `a` by alternating `update_users!` calls
(users first, then items with the roles swapped) until `stop!` reports that
`stop_criteria` is met. Per-rating weights are the product of `safe_exp`
applied to the user counts with `log(λ_wu)` and to the item counts with
`log(λ_wa)`. Returns the tuple `(u, a)`.
"""
function train_model(training, stop_criteria, λ_u, λ_a, λ_wu, λ_wa)
    @info "training model with parameters [$λ_u, $λ_a, $λ_wu, $λ_wa]"
    users = training.user
    items = training.item
    ratings = training.rating

    user_counts = get_counts("training")
    user_power = log(λ_wu)
    item_counts = get_counts("training"; by_item = true)
    item_power = log(λ_wa)
    weights = safe_exp.(user_counts, user_power) .* safe_exp.(item_counts, item_power)

    # Biases start at zero and take the element type of their regularizer.
    u = zeros(eltype(λ_u), maximum(users))
    a = zeros(eltype(λ_a), maximum(items))

    # One scratch column per thread for the residual and weight accumulators.
    scratch(v) = zeros(eltype(v), length(v), Threads.nthreads())
    ρ_u = scratch(u)
    Ω_u = scratch(u)
    ρ_a = scratch(a)
    Ω_a = scratch(a)

    while true
        stop!(stop_criteria, [u, a]) && break
        update_users!(users, items, ratings, weights, u, a, λ_u, ρ_u, Ω_u)
        update_users!(items, users, ratings, weights, a, u, λ_a, ρ_a, Ω_a)
    end
    return u, a
end;

In [6]:
"""
    make_prediction(users, items, u, a)

Predict a rating for each `(users[k], items[k])` pair as the sum of the user
bias and the item bias. An id larger than the corresponding bias vector
falls back to the mean of that vector. Returns a vector with element type
`eltype(u)` and the same length as `users`.
"""
function make_prediction(users, items, u, a)
    # Fallback values are loop-invariant: compute them once rather than
    # re-averaging the whole bias vector for every out-of-range id.
    u_mean = mean(u)
    a_mean = mean(a)
    r = zeros(eltype(u), length(users))
    for i in eachindex(r)
        r[i] += users[i] > length(u) ? u_mean : u[users[i]]
        r[i] += items[i] > length(a) ? a_mean : a[items[i]]
    end
    return r
end;

## Training

In [7]:
"""
    validation_mse(λ)

Objective for the hyperparameter search. `λ` holds the hyperparameters in log
space; they are exponentiated (so every value passed to training is positive),
a model is trained on `training`, and the weighted `mse` of its predictions on
`validation` is returned.
"""
function validation_mse(λ)
    hyperparams = exp.(λ)
    stopper = convergence_stopper(1e-6, max_iters = 16)
    u, a = train_model(training, stopper, hyperparams...)
    predictions = make_prediction(validation.user, validation.item, u, a)
    validation_weights = get_weights("validation", validation_weight_scheme)
    return mse(validation.rating, predictions, validation_weights)
end;

In [8]:
# Find the best regularization hyperparameters
res = optimize(
    validation_mse,
    fill(0.0f0, 4),
    LBFGS(),
    autodiff = :forward,
    Optim.Options(show_trace = true, extended_trace = true),
);
λ = exp.(Optim.minimizer(res));

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:02:32 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0,1.0,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0,0.0,1.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0,0.0,0.0,1.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0,0.0,0.0,0.0,1.0)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:01 ( 0.12 μs/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.72 ns/it)[39m


Iter     Function value   Gradient norm 
     0     1.820281e+00     5.464972e-02
 * Current step size: 1.0
 * time: 0.02705097198486328
 * g(x): Float32[-0.0050086514, -1.0955863f-6, 0.019451264, 0.05464972]
 * x: Float32[0.0, 0.0, 0.0, 0.0]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:03:02 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0050212,1.0050212,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0000011,0.0,1.0000011,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9807367,0.0,0.0,0.9807367,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.94681674,0.0,0.0,0.0,0.94681674)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.67 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.00 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:03:49 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.004674,1.004674,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0000011,0.0,1.0000011,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.98205394,0.0,0.0,0.98205394,0.0), Dua

     1     1.818468e+00     1.585315e-02
 * Current step size: 0.930997
 * time: 96.36286497116089
 * g(x): Float32[-0.0008187882, -1.7059443f-6, 0.015853152, -0.003936903]
 * x: Float32[0.0046630395, 1.0199876f-6, -0.018109068, -0.05087873]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:04:36 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0057998,1.0057998,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0000026,0.0,1.0000026,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9665626,0.0,0.0,0.9665626,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9499285,0.0,0.0,0.0,0.9499285)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.63 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.95 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:05:23 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0103159,1.0103159,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0000092,0.0,1.0000092,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9070028,0.0,0.0,0.9070028,0.0), Dual{

     2     1.817068e+00     3.685882e-02
 * Current step size: 7.2415323
 * time: 287.321496963501
 * g(x): Float32[0.002224842, -7.730692f-7, 0.008650017, -0.03685882]
 * x: Float32[0.012773503, 1.2871211f-5, -0.13325036, -0.054425683]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:07:47 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.031327,1.031327,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0000452,0.0,1.0000452,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6376831,0.0,0.0,0.6376831,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9908382,0.0,0.0,0.0,0.9908382)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.57 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.96 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:08:34 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0238904,1.0238904,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0000323,0.0,1.0000323,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.72389185,0.0,0.0,0.72389185,0.0), Dual{

     3     1.815303e+00     5.056256e-02
 * Current step size: 0.59957314
 * time: 381.53649282455444
 * g(x): Float32[0.0042570867, 1.14318345f-5, -0.0033773936, -0.05056256]
 * x: Float32[0.023609465, 3.22618f-5, -0.32311326, -0.027311988]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:09:21 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0260113,1.0260113,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.99995685,0.0,0.99995685,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.65157425,0.0,0.0,0.65157425,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0156094,0.0,0.0,0.0,1.0156094)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.61 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.03 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:10:08 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0256885,1.0256885,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9999683,0.0,0.9999683,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6620865,0.0,0.0,0.6620865,0.0), D

     4     1.814570e+00     1.824260e-03
 * Current step size: 0.847936
 * time: 475.6453859806061
 * g(x): Float32[0.00034904332, 2.3264034f-5, 0.00048238898, 0.00182426]
 * x: Float32[0.025364114, -3.1709635f-5, -0.4123591, 0.00898036]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:10:56 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0256451,1.0256451,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.99995655,0.0,0.99995655,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.66004425,0.0,0.0,0.66004425,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0087895,0.0,0.0,0.0,1.0087895)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.63 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.00 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:11:42 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0256531,1.0256531,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9999587,0.0,0.9999587,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6604213,0.0,0.0,0.6604213,0.0), D

     5     1.814569e+00     4.978255e-04
 * Current step size: 0.8151318
 * time: 569.7327129840851
 * g(x): Float32[0.0004978255, 2.3935152f-5, -1.6750528f-5, 0.00015245673]
 * x: Float32[0.025329629, -4.1283958f-5, -0.4148773, 0.0087935245]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:12:30 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0253233,1.0253233,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.99994534,0.0,0.99994534,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6610061,0.0,0.0,0.6610061,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0085347,0.0,0.0,0.0,1.0085347)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.56 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.98 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:13:17 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0240049,1.0240049,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.99989176,0.0,0.99989176,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6633503,0.0,0.0,0.6633503,0.0), D

     6     1.814569e+00     5.393652e-04
 * Current step size: 2.743032
 * time: 710.9488139152527
 * g(x): Float32[0.0005393652, 2.3494218f-5, 2.7530128f-5, -0.00050523505]
 * x: Float32[0.024447335, -7.8015204f-5, -0.41244963, 0.007984106]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:14:51 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0234873,1.0234873,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.99986386,0.0,0.99986386,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.661971,0.0,0.0,0.661971,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0079305,0.0,0.0,0.0,1.0079305)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.64 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.98 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:15:38 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0184577,1.0184577,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.99963146,0.0,0.99963146,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6617486,0.0,0.0,0.6617486,0.0), Dua

     7     1.814477e+00     5.985096e-04
 * Current step size: 287.29395
 * time: 993.5735778808594
 * g(x): Float32[4.905001f-5, 3.600637f-5, -7.690468f-6, -0.0005985096]
 * x: Float32[-0.32937917, -0.016775848, -0.43658063, -0.016396757]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:19:33 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7143449,0.7143449,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.958061,0.0,0.958061,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.644876,0.0,0.0,0.644876,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9838099,0.0,0.0,0.0,0.9838099)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.60 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.91 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:20:20 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6945924,0.6945924,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.86319387,0.0,0.86319387,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6394394,0.0,0.0,0.6394394,0.0), Dual{Fo

     8     1.814475e+00     1.301556e-03
 * Current step size: 3.379301
 * time: 1134.243127822876
 * g(x): Float32[-0.000111540496, 3.3978136f-5, -4.839737f-6, 0.0013015565]
 * x: Float32[-0.35306868, -0.10486754, -0.4437331, -0.016146176]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:21:54 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7023489,0.7023489,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.85282475,0.0,0.85282475,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64178145,0.0,0.0,0.64178145,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.98389965,0.0,0.0,0.0,0.98389965)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.61 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.03 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:22:41 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.70162904,0.70162904,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6862344,0.0,0.6862344,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6423608,0.0,0.0,0.6423608,0.0

     9     1.814456e+00     5.075865e-04
 * Current step size: 36.447052
 * time: 1369.4394299983978
 * g(x): Float32[-3.0566647f-5, 6.2033075f-7, 0.00050758646, -0.00019168075]
 * x: Float32[-0.3624118, -2.0851705, -0.43551087, -0.019252166]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:25:49 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6995755,0.6995755,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.116371416,0.0,0.116371416,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64235896,0.0,0.0,0.64235896,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.98277277,0.0,0.0,0.0,0.98277277)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.60 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.99 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:26:36 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7140796,0.7140796,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.08944248,0.0,0.08944248,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6243797,0.0,0.0,0.6243797,0

    10     1.814454e+00     8.956702e-05
 * Current step size: 1.0271554
 * time: 1510.4238839149475
 * g(x): Float32[8.470384f-6, 4.9114254f-7, -2.1622092f-5, -8.956702f-5]
 * x: Float32[-0.3571423, -2.152755, -0.44280073, -0.017326465]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:28:10 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6990393,0.6990393,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.112710185,0.0,0.112710185,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64218974,0.0,0.0,0.64218974,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.982798,0.0,0.0,0.0,0.982798)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.59 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.00 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:28:57 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.69651055,0.69651055,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.09989279,0.0,0.09989279,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6420079,0.0,0.0,0.6420079,0.0

    11     1.814454e+00     2.276308e-04
 * Current step size: 3.6690035
 * time: 1651.4721069335938
 * g(x): Float32[-2.2096765f-5, 3.7551897f-7, 3.954319f-5, 0.00022763084]
 * x: Float32[-0.36046648, -2.2634876, -0.44306046, -0.01741908]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:30:31 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6978182,0.6978182,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.09570978,0.0,0.09570978,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6421168,0.0,0.0,0.6421168,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9827496,0.0,0.0,0.0,0.9827496)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.59 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.98 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:31:18 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6996905,0.6996905,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.06868501,0.0,0.06868501,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64231044,0.0,0.0,0.64231044,0.0),

    12     1.814454e+00     5.978787e-04
 * Current step size: 15.496704
 * time: 1839.483263015747
 * g(x): Float32[5.8979465f-5, 3.5884133f-7, -0.000104989005, -0.0005978787]
 * x: Float32[-0.3500858, -3.5488982, -0.44189212, -0.017137911]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:33:39 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7061485,0.7061485,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.004041991,0.0,0.004041991,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64309466,0.0,0.0,0.64309466,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.983015,0.0,0.0,0.0,0.983015)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.59 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.96 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:34:26 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.712265,0.712265,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.5777747e-6,0.0,1.5777747e-6,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6441985,0.0,0.0,0.6441985,0.0

    13     1.814454e+00     8.775190e-04
 * Current step size: 1.3392977
 * time: 1980.2877449989319
 * g(x): Float32[8.422473f-5, 2.3585695f-8, -0.00013394393, -0.00087751896]
 * x: Float32[-0.3471981, -6.1767597, -0.44131792, -0.017128522]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:36:00 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.70136464,0.70136464,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.0019732495,0.0,0.0019732495,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6425479,0.0,0.0,0.6425479,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.98284876,0.0,0.0,0.0,0.98284876)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.57 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.98 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:36:47 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6805564,0.6805564,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.0016070902,0.0,0.0016070902,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6399932,0.0,0.0,0.639

    14     1.814453e+00     4.975067e-06
 * Current step size: 1.426342
 * time: 2120.9826998710632
 * g(x): Float32[5.15322f-7, 2.123389f-8, -4.9750665f-6, -1.7822676f-6]
 * x: Float32[-0.35793743, -6.249951, -0.4427385, -0.017373122]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:38:21 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6990819,0.6990819,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.0016379341,0.0,0.0016379341,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6423121,0.0,0.0,0.6423121,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.982763,0.0,0.0,0.0,0.982763)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.60 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.00 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:39:08 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.69894236,0.69894236,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.00084870914,0.0,0.00084870914,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64246,0.0,0.0,0.64246,0

    15     1.814453e+00     1.906952e-05
 * Current step size: 4.7035766
 * time: 2261.3871688842773
 * g(x): Float32[-2.6872337f-6, 5.8949126f-9, 1.9069523f-5, 1.474695f-5]
 * x: Float32[-0.35817224, -7.0230713, -0.44246784, -0.017439736]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:40:41 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6980886,0.6980886,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.00070700236,0.0,0.00070700236,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64227563,0.0,0.0,0.64227563,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9826357,0.0,0.0,0.0,0.9826357)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.58 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.95 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:41:28 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.69842976,0.69842976,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.0007746662,0.0,0.0007746662,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6423441,0.0,0.0,0.6

    16     1.814453e+00     3.140748e-05
 * Current step size: 0.6050304
 * time: 2354.981173992157
 * g(x): Float32[4.570264f-7, 4.6751967f-9, 2.230804f-6, -3.1407482f-5]
 * x: Float32[-0.35892066, -7.1630783, -0.44263113, -0.01748638]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:42:15 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6991232,0.6991232,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.00058569026,0.0,0.00058569026,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64228964,0.0,0.0,0.64228964,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.98277193,0.0,0.0,0.0,0.98277193)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.54 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.01 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:43:02 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7019038,0.7019038,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.00019137321,0.0,0.00019137321,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64207184,0.0,0.0,

    17     1.814453e+00     4.547795e-06
 * Current step size: 1.0774752
 * time: 2495.79451584816
 * g(x): Float32[7.4153905f-7, 2.8467821f-9, -4.547795f-6, -3.2921869f-6]
 * x: Float32[-0.35785142, -7.464385, -0.44272247, -0.017369818]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:44:36 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.69914836,0.69914836,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.00046903393,0.0,0.00046903393,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6422994,0.0,0.0,0.6422994,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9827741,0.0,0.0,0.0,0.9827741)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.64 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.97 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:45:25 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.69903415,0.69903415,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.00021037126,0.0,0.00021037126,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6423553,0.0,0.0,0

    18     1.814453e+00     1.540225e-05
 * Current step size: 4.7776413
 * time: 2640.0163099765778
 * g(x): Float32[-1.6312562f-6, 6.459795f-10, 7.660067f-6, 1.540225f-5]
 * x: Float32[-0.35804656, -8.422065, -0.44261852, -0.017399423]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:47:00 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.69904435,0.69904435,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.00017783235,0.0,0.00017783235,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64232135,0.0,0.0,0.64232135,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.98275584,0.0,0.0,0.0,0.98275584)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.58 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.02 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:47:47 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6990597,0.6990597,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(7.597643e-5,0.0,7.597643e-5,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.642198,0.0,0.0,0.64

    19     1.814453e+00     4.696696e-06
 * Current step size: 1.6080762
 * time: 2780.8111758232117
 * g(x): Float32[4.7711893f-7, 3.972438f-10, -2.147419f-6, -4.6966957f-6]
 * x: Float32[-0.35803774, -8.763949, -0.4426957, -0.017391577]


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:49:21 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.69903857,0.69903857,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.00012561135,0.0,0.00012561135,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64230627,0.0,0.0,0.64230627,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9827574,0.0,0.0,0.0,0.9827574)]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.61 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.02 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:50:08 training model with parameters [Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.69900584,0.69900584,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(5.2442563e-5,0.0,5.2442563e-5,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.64232075,0.0,0.0,

    20     1.814453e+00     7.348267e-06
 * Current step size: 6.887216
 * time: 2968.514011859894
 * g(x): Float32[-8.7457266f-7, 6.163955f-11, 3.4997806f-6, 7.3482665f-6]
 * x: Float32[-0.35811824, -10.2679, -0.44265687, -0.017401274]


In [9]:
# Report the optimized hyperparameters and the number of objective evaluations.
# The label lists all four entries of λ, in the order train_model receives them.
@info "The optimal [λ_u, λ_a, λ_wu, λ_wa] is $λ, found in " *
      repr(Optim.f_calls(res)) *
      " function calls"

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:52:28 The optimal [λ_u, λ_a, λ_w] is Float32[0.6989904, 3.4730216f-5, 0.64232755, 0.9827492], found in 64 function calls


In [10]:
# Clear the memoization cache of `get_weights`.
# NOTE(review): the reason is not visible here — presumably to release memory
# held after the hyperparameter search; confirm.
empty!(memoize_cache(get_weights))
# Retrain once with the optimized hyperparameters, using the same stopping
# settings as `validation_mse`.
stop_criteria = convergence_stopper(1e-6, max_iters = 16)
u, a = train_model(training, stop_criteria, λ...);

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:52:29 training model with parameters [0.6989904, 3.4730216e-5, 0.64232755, 0.9827492]
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.58 ns/it)[39m
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (35.95 ns/it)[39m


## Inference

In [11]:
# Prediction function handed to `write_predictions`: scores (user, item) pairs
# with the biases `u` and `a` fitted above.
function model(users, items)
    return make_prediction(users, items, u, a)
end;

In [12]:
# Hand the fitted model to `write_predictions` (defined outside this file).
# Judging by the logged output it reports RMSE/MAE/R2 on the training and
# validation sets — what it persists is not visible here; confirm.
write_predictions(model; residual_alphas = residual_alphas);

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:53:10 training set: RMSE 1.2844313 MAE 0.960797 R2 0.46220547
[32mProgress: 100%|███████████████████████████| Time: 0:00:00 (36.53 ns/it)[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:53:18 training set weighted-loss: RMSE 1.2922145 MAE 0.9602112 R2 0.48654854
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:53:18 validation set: RMSE 1.3195661 MAE 0.9857658 R2 0.40952367
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220514 02:53:19 validation set weighted-loss: RMSE 1.3469539 MAE 1.0024743 R2 0.39468563


In [13]:
# Pass the fitted biases and the optimized hyperparameters to `write_params`
# (defined outside this file; storage location not visible here).
write_params(Dict("u" => u, "a" => a, "λ" => λ));