# Matrix Factorization
* Prediction is $\tilde R = UA^T$ 
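* Loss function is $L = \lVert (R - \tilde R)^\Omega \rVert _F^2 + \lambda_u \lVert U \rVert _F^2 + \lambda_a \lVert A \rVert _F^2$, where $\lVert \cdot \rVert_F$ is the Frobenius norm (so the first term is the sum of squared errors over observed entries)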
* $\Omega$ is the set of observed pairs $(i, j)$
* $M^\Omega$ is the projection of $M$ onto $\Omega$ for any matrix $M$, that is $M_{ij}^\Omega$ is defined to be $M_{ij}$ when $(i, j) \in \Omega$ and $0$ otherwise
* $U$ is an $m \times k$ matrix, $A$ is an $n \times k$ matrix and $R$ is the $m \times n$ ratings matrix

In [1]:
# Identifier for this model; presumably consumed by the helpers loaded from
# Alpha.ipynb (e.g. when writing predictions/params) — TODO confirm.
name = "MatrixFactorization";
# Alphas this model builds on; presumably the ratings it fits are residuals
# after "UserItemBiases" — confirm in Alpha.ipynb.
residual_alphas = ["UserItemBiases"];

In [2]:
using Random
using SparseArrays

In [3]:
using NBInclude
@nbinclude("Alpha.ipynb");

# Alternating Least Squares Algorithm
* $u_{ik} = \dfrac{\sum_{j \in \Omega_i}(r_{ij} - \tilde r_{ij} + u_{ik}a_{jk})\,a_{jk}}{\sum_{j \in \Omega_i} a_{jk}^2 + \lambda_u}$
* $\Omega$ is the set of (user, item) pairs that we have ratings for
* $\Omega_i$ is the subset of $\Omega$ whose user is the $i$-th user
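* Note that iterating this update to convergence is equivalent to solving $A^{\Omega_i} u_i = R^{\Omega_i}$ in the least-squares sense with $L_2$ regularization $\lambda_u$, where $\Omega_i = \{(i', j) \in \Omega \mid i' = i \}$. That ridge problem has the closed form $u_i = \left((A^{\Omega_i})^T A^{\Omega_i} + \lambda_u I\right)^{-1} (A^{\Omega_i})^T r_i$, which the code below solves directly for each user (and symmetrically for each item).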

In [5]:
"""
    make_prediction(users, items, U, A)

Return a vector whose `k`-th entry is the dot product of row `users[k]` of `U`
and row `items[k]` of `A`.

A pair whose user index exceeds `size(U, 1)` or whose item index exceeds
`size(A, 1)` keeps a prediction of zero. Such pairs are presumably ids that
were not seen when the factors were sized. The loop is spread across threads,
and each iteration writes only its own output slot.
"""
function make_prediction(users, items, U, A)
    nusers = size(U, 1)
    nitems = size(A, 1)
    predictions = zeros(eltype(U), length(users))
    Threads.@threads for idx = 1:length(predictions)
        u = users[idx]
        a = items[idx]
        # Out-of-range ids fall through and keep the zero initial value.
        if u <= nusers && a <= nitems
            predictions[idx] = dot(view(U, u, :), view(A, a, :))
        end
    end
    return predictions
end;

In [6]:
"""
    calc_loss(df, U, A)

Score the factor matrices `U` and `A` on the table `df`, which needs the
columns `user`, `item` and `rating`. The result is
`mse(df.rating, predictions)`, and it is also emitted as a `@debug` message.
`mse` comes from outside this block, presumably from Alpha.ipynb.
"""
function calc_loss(df, U, A)
    predicted = make_prediction(df.user, df.item, U, A)
    loss = mse(df.rating, predicted)
    @debug "loss: $loss"
    return loss
end;

In [7]:
"""
    ridge_regression(X, y, λ)

Return the L2-regularized least-squares coefficients `β` that solve
`(X'X + λI) β = X'y`, where `I` is the identity of size `size(X, 2)`.

`X` and `y` may be sparse. The normal equations are densified before the
solve. `λ` may be any number type; the hyperparameter search passes
ForwardDiff duals.
"""
function ridge_regression(X, y, λ)
    nfeatures = size(X, 2)
    gram = Matrix(X' * X)
    moment = Vector(X' * y)
    return (gram + λ * I(nfeatures)) \ moment
end;

In [8]:
# Build an m × n sparse matrix with entry (i[k], j[k]) = v[k].
# Julia's SparseMatrixCSC is column major, so we store the n × m transpose and
# hand back its adjoint. Row r of the result is then column r of the stored
# matrix, which keeps row slices such as R[r, :] cheap. For real-valued v the
# adjoint equals the transpose.
# `sparse` sums duplicate (i, j) coordinates and keeps explicitly given zeros
# as stored entries.
# @memoize caches the result per argument tuple, so the returned matrix is
# shared between callers and must not be mutated.
@memoize function sparse_csr(i, j, v, m, n)
    stored = sparse(j, i, v, n, m)
    return stored'
end;

# Deterministic Gaussian initialization of a factor matrix with one row per id
# in 1:maximum(source) and K columns.
# The draws are made as a K × maximum(source) matrix and returned as its
# adjoint, so the result is maximum(source) × K.
# Seeding reseeds the global RNG (a side effect). Because of @memoize, the
# seeding and the draws only happen on a cache miss.
@memoize function gaussian_init_csr(source, K, el_type)
    nrows = maximum(source)
    # The seed depends on the id vector and on K, so the draws are reproducible.
    Random.seed!(20211204 * hash(source) * K)
    # Each entry has variance K^(-1/2), so the dot product of two independent
    # rows has variance 1.
    draws = randn(K, nrows) * K^(-1 / 4)
    # Adding to zeros(el_type, …) promotes the draws to el_type, e.g. the
    # ForwardDiff dual numbers used during the hyperparameter search.
    return (zeros(el_type, K, nrows) + draws)'
end;

In [9]:
"""
    sparse_subset(A, rows)

Return a sparse matrix `B` with `size(B) == size(A)`, `B[rows, :] == A[rows, :]`
and all other rows zero.

The CSC arrays are assembled directly. Column `c` stores `A[rows, c]` with row
indices `rows`, so `rows` must be sorted ascending and contain no duplicates.
Every selected entry is stored, even when it is zero.
"""
function sparse_subset(A, rows)
    ncols = size(A, 2)
    per_column = length(rows)
    # Column-major flattening lines the values up column by column, matching CSC.
    values = reshape(A[rows, :], :)
    rowidx = repeat(rows, ncols)
    # Each column holds exactly `per_column` entries. With empty `rows`, every
    # pointer is 1.
    colptr = [1 + c * per_column for c in 0:ncols]
    return SparseMatrixCSC(size(A, 1), size(A, 2), colptr, rowidx, values)
end;

In [10]:
"""
    update_users!(users, items, ratings, U, A, λ_u)

Refit every row of `U` in place while holding `A` fixed. Row `i` becomes the
ridge-regression solution (penalty `λ_u`) of row `i`'s observed ratings
regressed onto the matching rows of `A`.

The triplets `(users[k], items[k], ratings[k])` define the ratings matrix.
The item factors are updated by calling this function with the roles swapped.
Rows are solved in parallel, and each thread writes only its own row of `U`.
"""
function update_users!(users, items, ratings, U, A, λ_u)
    R = sparse_csr(users, items, ratings, size(U, 1), size(A, 1))
    Threads.@threads for row = 1:size(U, 1)
        observed = R[row, :]
        # Keep only the rows of A whose ratings are observed for this row.
        design = sparse_subset(A, rowvals(observed))
        U[row, :] = ridge_regression(design, observed, λ_u)
    end
    return nothing
end;

In [12]:
"""
    train_model(training, validation, λ_u, λ_a, K, stop_criteria)

Fit `K`-dimensional user factors `U` and item factors `A` on `training` by
alternating least squares, with ridge penalties `λ_u` and `λ_a`.

After each sweep the training loss is computed only for its `@debug` output,
and the validation loss is passed to `stop!(stop_criteria, loss)` to decide
whether to continue. Returns `(U, A, loss)`, where `loss` is the final
validation loss.
"""
function train_model(training, validation, λ_u, λ_a, K, stop_criteria)
    @debug "training model with parameters [$λ_u, $λ_a]"
    users = training.user
    items = training.item
    ratings = training.rating

    # gaussian_init_csr is memoized and update_users! mutates its output, so
    # work on copies and keep the cached initial factors untouched.
    U = copy(gaussian_init_csr(users, K, eltype(λ_u)))
    A = copy(gaussian_init_csr(items, K, eltype(λ_a)))

    # Compute the training loss for its @debug log, then return the validation loss.
    function evaluate()
        calc_loss(training, U, A)
        return calc_loss(validation, U, A)
    end

    loss = evaluate()
    while true
        stop!(stop_criteria, loss) && break
        update_users!(users, items, ratings, U, A, λ_u)  # refit U with A fixed
        update_users!(items, users, ratings, A, U, λ_a)  # refit A with U fixed
        loss = evaluate()
    end
    return U, A, loss
end;

## Training

In [13]:
"""
    validation_mse(λ, K)

Objective for the hyperparameter search. `λ` holds the natural logs of
`[λ_u, λ_a]`. The function returns the validation loss reached by a short
`train_model` run. It reads the globals `training` and `validation`, which are
presumably defined by Alpha.ipynb.
"""
function validation_mse(λ, K)
    # Search in log space: exp maps any real input to a strictly positive penalty.
    penalties = exp.(λ)
    # Stop really early (max_iters = 10) so more compute goes to exploring λ.
    stop_criteria = early_stopper(max_iters = 10)
    _, _, loss = train_model(training, validation, penalties..., K, stop_criteria)
    return loss
end;

In [14]:
# Number of latent factors per user/item (the column count of U and A).
K = 20;

In [16]:
# Find the best regularization hyperparameters.
# The search runs over log([λ_u, λ_a]) because validation_mse exponentiates its
# input. Gradients come from forward-mode automatic differentiation through
# the whole training run.
res = optimize(
    λ -> validation_mse(λ, K),
    [5.0, 2.0],  # initial guess for log.([λ_u, λ_a])
    LBFGS(),
    Optim.Options(show_trace = true, extended_trace = true);
    autodiff = :forward,
)
# Map the optimum back from log space to the actual penalties.
λ = exp.(Optim.minimizer(res));

[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 14:58:18 training model with parameters [Dual{ForwardDiff.Tag{var"#16#17", Float64}}(148.4131591025766,148.4131591025766,0.0), Dual{ForwardDiff.Tag{var"#16#17", Float64}}(7.38905609893065,0.0,7.38905609893065)]
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 14:58:22 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(2.662548382165211,0.0,0.0)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 14:58:22 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(2.6848836842873296,0.0,0.0)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 15:01:56 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(1.49616177365311,0.03150636131802458,0.0102092317229034)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 15:01:56 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(1.7444521368957828,-0.034198613633866594,-0.01644497126844936)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m202

Iter     Function value   Gradient norm 
     0     1.422653e+00     4.229810e-03
 * Current step size: 1.0
 * time: 0.02636885643005371
 * g(x): [-0.004229809655142995, -0.003896003263645237]
 * x: [5.0, 2.0]


[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 15:32:08 training model with parameters [Dual{ForwardDiff.Tag{var"#16#17", Float64}}(149.04224804119818,149.04224804119818,0.0), Dual{ForwardDiff.Tag{var"#16#17", Float64}}(7.417900037161294,0.0,7.417900037161294)]
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 15:32:11 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(2.662548382165211,0.0,0.0)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 15:32:11 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(2.6848836842873296,0.0,0.0)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 15:35:32 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(1.496335082476864,0.03157689168480901,0.010270253176915314)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 15:35:33 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(1.7442432582778873,-0.03423304445951837,-0.016487280614071745)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m

     1     1.420918e+00     3.215102e-04
 * Current step size: 92.35606305460917
 * time: 10126.351088047028
 * g(x): [-0.00032151017380974134, -0.00018112589911479493]
 * x: [5.390648567219381, 2.3598195230781824]


[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 18:20:52 training model with parameters [Dual{ForwardDiff.Tag{var"#16#17", Float64}}(226.51188865364492,226.51188865364492,0.0), Dual{ForwardDiff.Tag{var"#16#17", Float64}}(10.776305338705427,0.0,10.776305338705427)]
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 18:20:55 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(2.662548382165211,0.0,0.0)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 18:20:55 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(2.6848836842873296,0.0,0.0)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 18:24:19 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(1.5166302214082847,0.04163816532210679,0.018335375331397917)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 18:24:19 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(1.722641334031789,-0.03587953079204512,-0.019136308720896462)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22

     2     1.420912e+00     8.469477e-05
 * Current step size: 0.8464636068411004
 * time: 14157.538706064224
 * g(x): [-4.705165378167764e-5, 8.469476892217073e-5]
 * x: [5.417861411327414, 2.3746582423125226]


[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 19:28:04 training model with parameters [Dual{ForwardDiff.Tag{var"#16#17", Float64}}(226.65774807549704,226.65774807549704,0.0), Dual{ForwardDiff.Tag{var"#16#17", Float64}}(10.686552095419485,0.0,10.686552095419485)]
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 19:28:06 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(2.662548382165211,0.0,0.0)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 19:28:06 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(2.6848836842873296,0.0,0.0)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 19:31:28 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(1.5165039915330434,0.04153033377668983,0.018251353517739932)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 19:31:28 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(1.7227782340376054,-0.03588084008782181,-0.019123829695802036)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [2

     3     1.420896e+00     8.750445e-06
 * Current step size: 42.53784802518851
 * time: 24256.015596866608
 * g(x): [8.750445472899882e-6, 6.522473018049102e-6]
 * x: [5.655211997549526, 2.133379109950202]


[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 22:16:22 training model with parameters [Dual{ForwardDiff.Tag{var"#16#17", Float64}}(284.5137244285863,284.5137244285863,0.0), Dual{ForwardDiff.Tag{var"#16#17", Float64}}(8.47104057458554,0.0,8.47104057458554)]
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 22:16:24 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(2.662548382165211,0.0,0.0)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 22:16:25 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(2.6848836842873296,0.0,0.0)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 22:19:46 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(1.521748763500984,0.04270935749937766,0.01903162622627912)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 22:19:47 loss: Dual{ForwardDiff.Tag{var"#16#17", Float64}}(1.719080923254866,-0.03527820851812389,-0.018693468836704117)
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20

     4     1.420896e+00     7.707058e-09
 * Current step size: 1.0314567230282798
 * time: 30322.71363902092
 * g(x): [-5.297670290145997e-9, -7.707058005484854e-9]
 * x: [5.650642123768275, 2.1367563522881894]


In [17]:
# Display the Optim.jl result summary of the hyperparameter search.
res

 * Status: success

 * Candidate solution
    Final objective value:     1.420896e+00

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 4.57e-03 ≰ 0.0e+00
    |x - x'|/|x'|          = 8.09e-04 ≰ 0.0e+00
    |f(x) - f(x')|         = 8.98e-09 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 6.32e-09 ≰ 0.0e+00
    |g(x)|                 = 7.71e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   30323  (vs limit Inf)
    Iterations:    4
    f(x) calls:    16
    ∇f(x) calls:   16


In [18]:
# Report the tuned penalties and how many objective evaluations the search used.
let ncalls = repr(Optim.f_calls(res))
    @info "The optimal [λ_u, λ_a] is $λ, found in $ncalls function calls"
end

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20211210 23:57:30 The optimal [λ_u, λ_a] is [284.4740747537367, 8.471913111521085], found in 16 function calls


In [19]:
# Final fit with the tuned λ and a longer run. The stopping settings presumably
# mean: at most 100 ALS sweeps, patience of 1 sweep, and a minimum relative
# validation-loss improvement of 1e-4 — confirm against early_stopper.
stop_criteria = early_stopper(max_iters = 100, patience = 1, min_rel_improvement = 0.0001)
U, A, loss = train_model(training, validation, λ..., K, stop_criteria);

[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 23:57:31 training model with parameters [284.4740747537367, 8.471913111521085]
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 23:57:32 loss: 2.662548382165211
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211210 23:57:33 loss: 2.6848836842873296
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211211 00:00:10 loss: 1.5217447713812076
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211211 00:00:10 loss: 1.7190839145964352
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211211 00:02:47 loss: 1.2824542772316048
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211211 00:02:47 loss: 1.5813365899174856
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211211 00:05:23 loss: 1.2202543700182886
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211211 00:05:23 loss: 1.5135724757124607
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20211211 00:07:59 l

## Inference

In [20]:
"""
    model(users, items)

Predict ratings for the `(users[k], items[k])` pairs using the globally bound,
trained factor matrices `U` and `A`.
"""
function model(users, items)
    return make_prediction(users, items, U, A)
end;

In [21]:
# Write predictions produced by `model`. write_predictions is defined outside
# this notebook, presumably in Alpha.ipynb.
write_predictions(model);

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20211211 01:36:38 training set: RMSE 1.0651358195805292 MAE 0.7900524609219547 R2 0.3222331436128555
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20211211 01:36:39 validation set: RMSE 1.1845196093577437 MAE 0.8746890547569217 R2 0.17377616015977593


In [22]:
# Persist the learned factor matrices and the tuned regularization strengths.
write_params(Dict("U" => U, "A" => A, "λ" => λ));