# User Item Biases With Regularization
* Prediction for user $i$ and item $j$ is $\tilde r_{ij} = u_i + a_j$
* Loss function is $L = \sum_{\Omega} w_{ij}\,\text{loss}(r_{ij}, \tilde r_{ij}) + \lambda_u \sum_i (u_i - \bar u)^2 + \lambda_a \sum_j (a_j - \bar a)^2$
* $\bar u$ is the mean of $u_i$ and $\bar a$ is the mean of $a_j$ 
* $\Omega$ is the set of observed pairs $(i, j)$
* $r_{ij}$ is the rating for user $i$ and item $j$
* $w_{ij}$ is the weight of the observed pair $(i, j)$ and is modeled as a power law in the number of items seen by user $i$ and the number of users that have seen item $j$: $w_{ij} = |\{j' : (i, j') \in \Omega\}|^{\lambda_{wu}} \, |\{i' : (i', j) \in \Omega\}|^{\lambda_{wa}}$
* $\text{loss}$ is the squared error, $\text{loss}(r, \tilde r) = (r - \tilde r)^2$
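
The definitions above can be made concrete with a small sketch. It follows the formulas in this list directly and is not the notebook's own implementation: the notebook builds its weights with `expdecay` and `get_counts`, whose definitions are not shown here. The function names `powerlaw_weights` and `regularized_loss` and the toy data are assumptions made for illustration.

```julia
# Hypothetical toy data: three observed (user, item, rating) triples.
users = [1, 1, 2]
items = [1, 2, 2]
ratings = [4.0, 2.0, 5.0]

# w_ij = (#items seen by user i)^λ_wu * (#users that have seen item j)^λ_wa
function powerlaw_weights(users, items, λ_wu, λ_wa)
    user_counts = Dict{Int,Int}()
    item_counts = Dict{Int,Int}()
    for (i, j) in zip(users, items)
        user_counts[i] = get(user_counts, i, 0) + 1
        item_counts[j] = get(item_counts, j, 0) + 1
    end
    [user_counts[i]^λ_wu * item_counts[j]^λ_wa for (i, j) in zip(users, items)]
end

# Weighted squared error over observed pairs plus the two mean-centered penalties.
function regularized_loss(users, items, ratings, w, u, a, λ_u, λ_a)
    ū = sum(u) / length(u)
    ā = sum(a) / length(a)
    fit = sum(
        w[k] * (ratings[k] - (u[users[k]] + a[items[k]]))^2 for k in eachindex(ratings)
    )
    fit + λ_u * sum((u .- ū) .^ 2) + λ_a * sum((a .- ā) .^ 2)
end

w = powerlaw_weights(users, items, -1.0, 0.0)
L = regularized_loss(users, items, ratings, w, zeros(2), zeros(2), 1.0, 1.0)
```

With a negative $\lambda_{wu}$, as in this toy call, users with many ratings contribute less per rating; the sign and size of the exponents are exactly what the hyperparameter search later in the notebook decides.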

In [1]:
const name = "ExplicitUserItemBiases"
const implicit = false;

In [2]:
import NBInclude: @nbinclude
@nbinclude("Alpha.ipynb");
@nbinclude("ExplicitUserItemBiasesBase.ipynb");

In [3]:
const training = get_split("training", implicit)
const validation = get_split("validation", implicit);

## Alternating Least Squares
* Given hyperparameters $\lambda$, we can solve for the bias vectors $u$ and $a$ via Alternating Least Squares
* This is an iterative algorithm: fix $a$, then solve for the $u$ that minimizes the loss function
* Then fix $u$ and solve for the best $a$
* These two steps are repeated until $u$ and $a$ converge
### More details
* If we fix $a$ and treat $\bar u$ as a constant, then for each user $i$, setting $\partial L / \partial u_i = 0$ gives
* $u_i = \dfrac{\sum_{j \in \Omega_i}(r_{ij} - a_j) w_{ij} + \bar u \lambda_u}{ \sum_{j \in \Omega_i} w_{ij} + \lambda_u}$
* $\Omega$ is the set of (user, item) pairs that we have ratings for
* $\Omega_i$ is the set of items $j$ for which $(i, j) \in \Omega$
* The update for $a_j$ with $u$ fixed is symmetric, with the roles of users and items swapped
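
As a minimal sketch of the closed-form update above, the following single-threaded function accumulates the numerator and denominator per user and then applies the formula. The name `update_biases!` and the toy data are hypothetical; the notebook's own `update_users!` is defined in another notebook and may differ (for example, it receives per-thread buffers).

```julia
# Closed-form update of one side's biases `u` with the other side `a` held fixed.
function update_biases!(u, a, users, items, ratings, w, λ_u)
    ū = sum(u) / length(u)   # mean of the current biases, treated as a constant
    ρ = zeros(length(u))     # numerators:   Σ_j w_ij (r_ij - a_j)
    Ω = zeros(length(u))     # denominators: Σ_j w_ij
    for k in eachindex(ratings)
        i, j = users[k], items[k]
        ρ[i] += w[k] * (ratings[k] - a[j])
        Ω[i] += w[k]
    end
    for i in eachindex(u)
        u[i] = (ρ[i] + ū * λ_u) / (Ω[i] + λ_u)
    end
    u
end

# Hypothetical toy data with unit weights.
users = [1, 1, 2]
items = [1, 2, 2]
ratings = [4.0, 2.0, 5.0]
w = ones(3)
u = zeros(2)
a = zeros(2)

# One ALS sweep: update users with items fixed, then items with users fixed.
update_biases!(u, a, users, items, ratings, w, 1.0)
update_biases!(a, u, items, users, ratings, w, 1.0)
```

Repeating the sweep until `u` and `a` stop changing gives the alternating scheme described above. Swapping the `users` and `items` arguments reuses the same routine for the item update.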

In [4]:
function train_model(training, stop_criteria, λ)
    @info "training model with parameters $λ"
    λ_u, λ_a, λ_wu, λ_wa = λ
    users, items, ratings = training.user, training.item, training.rating
    weights =
        expdecay(get_counts("training", implicit), log(λ_wu)) .*
        expdecay(get_counts("training", implicit; by_item = true), log(λ_wa))
    u = zeros(eltype(λ_u), num_users())
    a = zeros(eltype(λ_a), num_items())

    # Scratch buffers with one column per thread, passed to update_users!
    ρ_u = zeros(eltype(u), length(u), Threads.nthreads())
    Ω_u = zeros(eltype(u), length(u), Threads.nthreads())
    ρ_a = zeros(eltype(a), length(a), Threads.nthreads())
    Ω_a = zeros(eltype(a), length(a), Threads.nthreads())

    # Alternate the two updates until the stopping criterion is met
    while !stop!(stop_criteria, [u, a])
        update_users!(users, items, ratings, weights, u, a, λ_u, ρ_u, Ω_u)
        # Same routine with the roles of users and items swapped
        update_users!(items, users, ratings, weights, a, u, λ_a, ρ_a, Ω_a)
    end
    u, a
end;

In [5]:
function validation_mse(λ)
    λ = exp.(λ) # the optimizer works in log space, so λ is always positive
    stop_criteria = convergence_stopper(1e-6, max_iters = 16)
    u, a = train_model(training, stop_criteria, λ)
    r = make_prediction(validation.user, validation.item, u, a)
    residualized_loss([], implicit, r)
end;
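
`validation_mse` exponentiates its argument, so the optimizer searches over unconstrained values $x = \log \lambda$ while the model always sees a positive $\lambda$. The sketch below illustrates that reparameterization on a toy one-dimensional objective. It deliberately swaps in plain gradient descent with a finite-difference derivative instead of `Optim.NewtonTrustRegion` with forward-mode autodiff, and `toy_loss`, `log_space_loss`, and `minimize_log_space` are hypothetical names.

```julia
# Toy objective over a positive hyperparameter; minimized at λ = 2.
toy_loss(λ) = (λ - 2.0)^2

# Objective as seen by the optimizer: x is unconstrained and λ = exp(x) > 0.
log_space_loss(x) = toy_loss(exp(x))

# Plain gradient descent using a central finite-difference derivative.
function minimize_log_space(f, x; step = 0.1, iters = 200, h = 1e-6)
    for _ in 1:iters
        g = (f(x + h) - f(x - h)) / (2h)
        x -= step * g
    end
    x
end

x_opt = minimize_log_space(log_space_loss, 0.0)
λ_opt = exp(x_opt)   # map back to the original positive scale
```

The same mapping is applied after the search in the notebook, where `λ = exp.(Optim.minimizer(res))` converts the optimizer's log-space solution back to the hyperparameters.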

In [6]:
# Find the best regularization hyperparameters
res = Optim.optimize(
    validation_mse,
    fill(0.0f0, 4),
    Optim.NewtonTrustRegion(),
    autodiff = :forward,
    Optim.Options(show_trace = true, extended_trace = true),
);
λ = exp.(Optim.minimizer(res));

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 19:57:37 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0,1.0,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0,0.0,1.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0,0.0,0.0,1.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0,0.0,0.0,0.0,1.0)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 19:57:58 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{For

Iter     Function value   Gradient norm 
     0     1.778538e+00     7.672866e-02
 * time: 0.007397890090942383
 * g(x): Float32[-0.0065773665, -1.8496082f-6, 0.019701673, 0.07672866]
 * reached_subproblem_solution: true
 * h(x): Float32[0.0055676643 -2.1045787f-7 -0.0015772729 -0.06168165; -2.1045791f-7 -8.33904f-7 1.2805032f-5 1.1528118f-5; -0.0015772716 1.2805035f-5 0.0027728304 0.0076725623; -0.06168165 1.1528117f-5 0.007672585 0.9672913]
 * x: Float32[0.0, 0.0, 0.0, 0.0]
 * lambda: NaN
 * interior: true
 * hard case: false
 * delta: 1.0


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 19:59:14 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0378308,1.0378308,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0009433,0.0,1.0009433,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.36824173,0.0,0.0,0.36824173,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9343703,0.0,0.0,0.0,0.9343703)]


     1     1.778538e+00     7.672866e-02
 * time: 15.757678031921387
 * g(x): Float32[-0.0065773665, -1.8496082f-6, 0.019701673, 0.07672866]
 * reached_subproblem_solution: false
 * h(x): Float32[0.0055676643 -2.1045787f-7 -0.0015772729 -0.06168165; -2.1045791f-7 -8.33904f-7 1.2805032f-5 1.1528118f-5; -0.0015772716 1.2805035f-5 0.0027728304 0.0076725623; -0.06168165 1.1528117f-5 0.007672585 0.9672913]
 * x: Float32[0.0, 0.0, 0.0, 0.0]
 * lambda: 0.016406853
 * interior: false
 * hard case: false
 * delta: 0.25050202


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 19:59:29 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0228587,1.0228587,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0000753,0.0,1.0000753,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78556234,0.0,0.0,0.78556234,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.93201923,0.0,0.0,0.0,0.93201923)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 19:59:44 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0228587,1.0228587,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.02285

     2     1.776582e+00     1.603866e-01
 * time: 97.02475786209106
 * g(x): Float32[0.01269753, -5.4449692f-6, -0.019455392, -0.16038658]
 * reached_subproblem_solution: false
 * h(x): Float32[0.031616744 -1.3931758f-6 -0.07689739 -0.36362463; -1.3931757f-6 4.3427194f-6 -1.7518863f-5 -5.5419046f-6; -0.07689739 -1.7518852f-5 0.24810919 0.87214375; -0.3636246 -5.5418986f-6 0.87214375 4.4505343]
 * x: Float32[0.022601405, 7.527841f-5, -0.24135548, -0.070401825]
 * lambda: 0.07713252
 * interior: false
 * hard case: false
 * delta: 0.25050202


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:00:50 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.85269064,0.85269064,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(1.0025737,0.0,1.0025737,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6653756,0.0,0.0,0.6653756,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.98343474,0.0,0.0,0.0,0.98343474)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:01:05 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.85269064,0.85269064,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.852

     3     1.772745e+00     1.127567e-02
 * time: 178.95110487937927
 * g(x): Float32[0.0011574734, 3.8595313f-6, -0.004101767, -0.011275671]
 * reached_subproblem_solution: false
 * h(x): Float32[0.023535723 -2.6046007f-6 -0.05506038 -0.2707483; -2.6046027f-6 2.5232008f-5 -0.00013850554 -0.000111207744; -0.05506038 -0.00013850554 0.19711536 0.6218829; -0.2707483 -0.00011120776 0.6218829 3.3693933]
 * x: Float32[-0.1593585, 0.0025704808, -0.4074036, -0.016703993]
 * lambda: 0.0010578422
 * interior: false
 * hard case: false
 * delta: 0.50100404


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:02:12 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78124666,0.78124666,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9168468,0.0,0.9168468,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.678142,0.0,0.0,0.678142,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9763836,0.0,0.0,0.0,0.9763836)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:02:27 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78124666,0.78124666,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7812466

     4     1.772695e+00     8.018634e-04
 * time: 261.07906794548035
 * g(x): Float32[6.453184f-5, 4.8586827f-7, -0.0002687249, -0.00080186344]
 * reached_subproblem_solution: true
 * h(x): Float32[0.021701958 -2.296364f-6 -0.048566844 -0.24885547; -2.2963654f-6 1.9337116f-5 -0.000104442304 -8.364114f-5; -0.048566844 -0.00010444231 0.17058572 0.5450123; -0.2488555 -8.364114f-5 0.54501206 3.1162136]
 * x: Float32[-0.24686436, -0.0868149, -0.38839853, -0.023899725]
 * lambda: 0.0
 * interior: true
 * hard case: false
 * delta: 0.50100404


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:03:34 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7854143,0.7854143,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.90588576,0.0,0.90588576,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6795331,0.0,0.0,0.6795331,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.97669953,0.0,0.0,0.0,0.97669953)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:03:49 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7854143,0.7854143,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78541

     5     1.772694e+00     1.810524e-06
 * time: 342.1788549423218
 * g(x): Float32[2.8553362f-7, 1.1079687f-8, 2.4221816f-7, 1.8105244f-6]
 * reached_subproblem_solution: true
 * h(x): Float32[0.021567749 -2.2611277f-6 -0.048071183 -0.24730311; -2.261128f-6 1.8516803f-5 -9.928165f-5 -7.941069f-5; -0.048071187 -9.928165f-5 0.16850336 0.53925675; -0.24730311 -7.9410696f-5 0.5392568 3.0981238]
 * x: Float32[-0.24154392, -0.09884205, -0.38634932, -0.023576241]
 * lambda: 0.0
 * interior: true
 * hard case: false
 * delta: 0.50100404


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:04:55 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7851838,0.7851838,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9051357,0.0,0.9051357,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6795197,0.0,0.0,0.6795197,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9766794,0.0,0.0,0.0,0.9766794)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:05:10 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7851838,0.7851838,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7851838,0

     6     1.772694e+00     2.458161e-06
 * time: 423.8721158504486
 * g(x): Float32[-2.6804674f-8, -6.8044077f-12, -5.201158f-7, -2.4581614f-6]
 * reached_subproblem_solution: true
 * h(x): Float32[0.021567939 -2.2614581f-6 -0.04807282 -0.24730478; -2.2614581f-6 1.8499546f-5 -9.916444f-5 -7.932625f-5; -0.04807281 -9.9164456f-5 0.1685251 0.53931636; -0.24730481 -7.932626f-5 0.53931636 3.0981934]
 * x: Float32[-0.24183749, -0.09967044, -0.38636905, -0.023596844]
 * lambda: 0.0
 * interior: true
 * hard case: false
 * delta: 0.50100404


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:06:17 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7853056,0.7853056,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9052537,0.0,0.9052537,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.67952776,0.0,0.0,0.67952776,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.97669023,0.0,0.0,0.0,0.97669023)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:06:32 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7853056,0.7853056,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78530

     7     1.772694e+00     1.047199e-06
 * time: 505.15867495536804
 * g(x): Float32[1.3075725f-8, -9.620025f-12, -8.799742f-8, -1.0471994f-6]
 * reached_subproblem_solution: true
 * h(x): Float32[0.021567639 -2.2612708f-6 -0.04807157 -0.24730209; -2.2612708f-6 1.8499299f-5 -9.916334f-5 -7.932175f-5; -0.048071574 -9.916334f-5 0.16851448 0.5392933; -0.24730209 -7.932174f-5 0.5392932 3.0981913]
 * x: Float32[-0.24168234, -0.09954001, -0.38635722, -0.02358572]
 * lambda: 0.0
 * interior: true
 * hard case: false
 * delta: 0.50100404


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:07:38 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78533965,0.78533965,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9052808,0.0,0.9052808,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6795289,0.0,0.0,0.6795289,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9766937,0.0,0.0,0.0,0.9766937)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:07:53 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78533965,0.78533965,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78533

     8     1.772694e+00     5.566362e-07
 * time: 588.2163858413696
 * g(x): Float32[3.441435f-8, 4.481524f-11, -7.778611f-8, 5.566362f-7]
 * reached_subproblem_solution: true
 * h(x): Float32[0.02156755 -2.2612617f-6 -0.04807122 -0.24730098; -2.261261f-6 1.849933f-5 -9.916351f-5 -7.9320365f-5; -0.04807122 -9.916351f-5 0.16851342 0.5392743; -0.24730101 -7.932038f-5 0.5392742 3.0981574]
 * x: Float32[-0.24163897, -0.099510096, -0.38635552, -0.023582214]
 * lambda: 0.0
 * interior: true
 * hard case: false
 * delta: 0.50100404


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:09:02 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78530407,0.78530407,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.90525573,0.0,0.90525573,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6795284,0.0,0.0,0.6795284,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9766901,0.0,0.0,0.0,0.9766901)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:09:16 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78530407,0.78530407,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.785

     9     1.772694e+00     7.543752e-07
 * time: 670.2137680053711
 * g(x): Float32[-3.939987f-8, -5.4435262f-11, 1.2571918f-7, -7.5437515f-7]
 * reached_subproblem_solution: true
 * h(x): Float32[0.021567553 -2.2612594f-6 -0.048071235 -0.24730097; -2.2612599f-6 1.8499204f-5 -9.916272f-5 -7.932129f-5; -0.048071235 -9.9162746f-5 0.16851279 0.5392845; -0.24730095 -7.932128f-5 0.5392845 3.0981803]
 * x: Float32[-0.24168429, -0.09953777, -0.38635626, -0.023585882]
 * lambda: 0.0
 * interior: true
 * hard case: false
 * delta: 0.50100404


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:10:24 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7853477,0.7853477,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9052841,0.0,0.9052841,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.6795286,0.0,0.0,0.6795286,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9766946,0.0,0.0,0.0,0.9766946)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:10:38 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7853477,0.7853477,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7853477,0

    10     1.772694e+00     1.327470e-07
 * time: 753.5726640224457
 * g(x): Float32[3.4234148f-8, 5.3027454f-11, -1.3274696f-7, 1.0247344f-7]
 * reached_subproblem_solution: true
 * h(x): Float32[0.021567563 -2.261245f-6 -0.04807131 -0.24730137; -2.2612455f-6 1.8499382f-5 -9.916365f-5 -7.9320846f-5; -0.04807131 -9.916365f-5 0.16851261 0.53928; -0.24730138 -7.932085f-5 0.53928 3.098172]
 * x: Float32[-0.24162872, -0.09950643, -0.38635594, -0.02358126]
 * lambda: 0.0
 * interior: true
 * hard case: false
 * delta: 0.50100404


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:11:47 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7853306,0.7853306,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9052745,0.0,0.9052745,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.67952895,0.0,0.0,0.67952895,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9766928,0.0,0.0,0.0,0.9766928)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:12:02 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7853306,0.7853306,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.7853306

    11     1.772694e+00     1.123424e-06
 * time: 836.1069769859314
 * g(x): Float32[-9.862168f-9, -1.4078085f-11, -1.0403517f-7, 1.1234237f-6]
 * reached_subproblem_solution: true
 * h(x): Float32[0.021567512 -2.2612455f-6 -0.048071153 -0.24730042; -2.2612448f-6 1.8499253f-5 -9.916307f-5 -7.9320154f-5; -0.048071153 -9.916309f-5 0.16851625 0.53927857; -0.24730042 -7.932014f-5 0.53927845 3.0981455]
 * x: Float32[-0.24165049, -0.09951706, -0.3863554, -0.023583125]
 * lambda: 0.0
 * interior: true
 * hard case: false
 * delta: 0.50100404


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:13:09 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78529906,0.78529906,0.0,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9052623,0.0,0.9052623,0.0,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.67952985,0.0,0.0,0.67952985,0.0), Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.9766891,0.0,0.0,0.0,0.9766891)]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:13:24 training model with parameters ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}, Float32, 4}, 4}[Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.78529906,0.78529906,0.0,0.0,0.0),Dual{ForwardDiff.Tag{typeof(validation_mse), Float32}}(0.785

    12     1.772694e+00     4.080317e-07
 * time: 918.4016778469086
 * g(x): Float32[-1.7941112f-8, -6.264748f-11, 4.0803172f-7, -7.340126f-8]
 * reached_subproblem_solution: true
 * h(x): Float32[0.021567527 -2.261239f-6 -0.04807097 -0.2473006; -2.2612378f-6 1.8499133f-5 -9.916272f-5 -7.932059f-5; -0.04807097 -9.916272f-5 0.16851087 0.5392814; -0.24730058 -7.932058f-5 0.53928125 3.0981536]
 * x: Float32[-0.24169067, -0.09953056, -0.38635412, -0.023586921]
 * lambda: 0.0
 * interior: true
 * hard case: false
 * delta: 0.50100404


In [7]:
@info "The optimal λ is $λ, found in " * repr(Optim.f_calls(res)) * " function calls"

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:14:32 The optimal λ is Float32[0.78529906, 0.9052623, 0.67952985, 0.9766891], found in 13 function calls


In [8]:
stop_criteria = convergence_stopper(1e-6, max_iters = 16)
u, a = train_model(training, stop_criteria, λ);

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:14:32 training model with parameters Float32[0.78529906, 0.9052623, 0.67952985, 0.9766891]


In [9]:
validation_mse(Optim.minimizer(res))

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:14:39 training model with parameters Float32[0.78529906, 0.9052623, 0.67952985, 0.9766891]


1.7725984f0

## Inference

In [10]:
model(users, items) = make_prediction(users, items, u, a)
write_alpha(model, [], implicit, name);

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:14:48 validation loss: 1.7725984, β: Float32[1.0016403]
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220625 20:14:50 training loss: 1.6141425, β: Float32[1.0016403]


In [11]:
write_params(Dict("u" => u, "a" => a, "λ" => λ), name);