# Product Recommendation
Reference: https://ieeexplore.ieee.org/document/5430993

In [1]:
import numexpr as ne
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm

In [2]:
# Show the TensorFlow version this notebook is being executed with.
tf.__version__

'2.3.1'

## Data Preprocessing

In [3]:
# Training ratings: headerless CSV whose columns are read as (Rating, Movie, User), all int32.
Y_data = pd.read_csv(
    'data/Y.csv',
    header=None,
    names=['Rating', 'Movie', 'User'],
    dtype=np.int32,
)
# Test ratings (the 'probe-set' mentioned in the paper), same layout as the training file.
P_data = pd.read_csv(
    'data/P.csv',
    header=None,
    names=['Rating', 'Movie', 'User'],
    dtype=np.int32,
)

In [4]:
# Preview the first rows of the training and test frames, then report both shapes.
display(Y_data.head(), P_data.head())
(Y_data.shape, P_data.shape)

Unnamed: 0,Rating,Movie,User
0,5,2,1
1,4,7,1
2,4,8,1
3,4,11,1
4,4,12,1


Unnamed: 0,Rating,Movie,User
0,3,6,1
1,5,96,1
2,3,1,2
3,3,33,3
4,5,42,4


((3399874, 3), (189699, 3))

In [5]:
# Largest rating, movie ID and user ID in the training set, then in the test set.
print(*(Y_data[column].unique().max() for column in ('Rating', 'Movie', 'User')))
print(*(P_data[column].unique().max() for column in ('Rating', 'Movie', 'User')))

5 100 137328
5 100 137328


In [6]:
# k: largest movie ID, n: largest user ID in the training data.
# They are used below as the dimensions of the (movie x user) rating matrix.
k = Y_data['Movie'].unique().max()
n = Y_data['User'].unique().max()
(k, n)

(100, 137328)

In [7]:
# Sparse (movie x user) rating matrix; IDs are shifted by one to become 0-based indices.
Z_sparse = tf.cast(
    tf.SparseTensor(
        indices=Y_data[['Movie', 'User']].values-1,
        values=Y_data['Rating'].values,
        dense_shape=[k, n],
    ),
    tf.float64,
)

In [8]:
# Everything fits in memory, so work with a dense matrix: linear transformations are faster
# on dense tensors. Entries without a rating are filled with 0.
Z = tf.sparse.to_dense(sp_input=Z_sparse, validate_indices=False)
Z

<tf.Tensor: shape=(100, 137328), dtype=float64, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [5., 0., 0., ..., 0., 0., 3.],
       [0., 0., 0., ..., 3., 0., 0.],
       ...,
       [5., 0., 0., ..., 4., 0., 4.],
       [4., 0., 3., ..., 0., 0., 4.],
       [3., 4., 0., ..., 4., 5., 4.]])>

In [9]:
# Build, for every user t, the selection matrices and rating vectors used by the
# estimation and evaluation cells below.
#
# Rows are grouped by user ONCE per frame (stable sort + binary search) instead of
# scanning every row of Y_data and P_data with a boolean mask for each of the n users.
# The stable sort keeps each user's rows in their original file order, so the
# resulting tensors are the same as with per-user masking.
def _group_rows_by_user(user_ids, n_users):
    """Return (order, starts) with order[starts[t]:starts[t + 1]] being the row
    positions of user ID t + 1, in their original order (empty if the user has no rows)."""
    order = np.argsort(user_ids, kind='stable')
    starts = np.searchsorted(user_ids[order], np.arange(1, n_users + 2), side='left')
    return order, starts


Y_order, Y_starts = _group_rows_by_user(Y_data['User'].values, n)
P_order, P_starts = _group_rows_by_user(P_data['User'].values, n)
Y_movies = Y_data['Movie'].values
P_movies = P_data['Movie'].values
P_ratings = P_data['Rating'].values
identity_k = np.identity(k)  # hoisted: rows of this matrix are the building blocks of H_yt / H_xt

data_preprocessed = list()
for t in tqdm(range(n)):
    Y_rows = Y_order[Y_starts[t]:Y_starts[t + 1]]
    P_rows = P_order[P_starts[t]:P_starts[t + 1]]

    # 0-based indices of the movies rated by user t + 1 in the training data
    movie_ids_indices = Y_movies[Y_rows] - 1
    # H_yt selects the rated movies, H_xt selects the remaining (unrated) movies
    H_yt = tf.constant(identity_k[movie_ids_indices], dtype=tf.float64)
    H_xt = tf.constant(np.delete(identity_k, movie_ids_indices, 0), dtype=tf.float64)

    k_t = tf.constant(H_yt.shape[0], dtype=tf.float64)  # number of rated movies
    Z_t = tf.expand_dims(Z[:, t], axis=1)               # column t of Z as a (k, 1) vector
    y_t = tf.matmul(H_yt, Z_t)
    x_t = tf.matmul(H_xt, Z_t)

    # test-set movie IDs (1-based) and ratings of the same user
    movie_ids_t = P_movies[P_rows]
    labels_t = tf.expand_dims(P_ratings[P_rows], axis=1)
    data_preprocessed.append((H_yt, H_xt, tf.transpose(H_yt), tf.transpose(H_xt), k_t, Z_t, y_t, x_t, movie_ids_t, labels_t))

# release the raw frames and the helper arrays, they are not needed any more
del Y_data
del P_data
del Z_sparse
del Y_order, Y_starts, P_order, P_starts
del Y_movies, P_movies, P_ratings, identity_k

100%|██████████| 137328/137328 [10:28<00:00, 218.59it/s]


## Initialization
One initial estimate is provided for $\mu$. <br />
Four alternative initial estimates are provided for $R$.

In [10]:
# Accumulate the two sums needed for the initial estimate of mu:
#   N       = sum_t H_yt^T H_yt
#   H_yty_t = sum_t H_yt^T y_t
N, H_yty_t = 0, 0

for (H_yt, H_xt, H_yt_trans, H_xt_trans, k_t, Z_t, y_t, x_t, movie_ids_t, labels_t) in tqdm(data_preprocessed):
    N = N + H_yt_trans @ H_yt
    H_yty_t = H_yty_t + H_yt_trans @ y_t

100%|██████████| 137328/137328 [00:14<00:00, 9359.09it/s] 


In [11]:
# N is the sum over users of H_yt^T H_yt, where each H_yt consists of the identity rows of the
# movies that user rated. The ith diagonal element of N therefore equals the total number of
# ratings of the ith product.
N

<tf.Tensor: shape=(100, 100), dtype=float64, numpy=
array([[20017.,     0.,     0., ...,     0.,     0.,     0.],
       [    0., 23917.,     0., ...,     0.,     0.,     0.],
       [    0.,     0., 31634., ...,     0.,     0.,     0.],
       ...,
       [    0.,     0.,     0., ..., 60896.,     0.,     0.],
       [    0.,     0.,     0., ...,     0., 61521.,     0.],
       [    0.,     0.,     0., ...,     0.,     0., 64506.]])>

In [12]:
# Initial mean estimate: N^{-1} * sum_t H_yt^T y_t (displayed as a row vector).
N_inv = tf.linalg.inv(N)
mu_hat0 = N_inv @ H_yty_t
tf.transpose(mu_hat0)

<tf.Tensor: shape=(1, 100), dtype=float64, numpy=
array([[3.45266523, 3.57674457, 3.28788645, 3.90478757, 3.79035475,
        3.44415598, 3.19071562, 4.52835008, 3.82013753, 3.6159503 ,
        3.40382731, 3.83725101, 4.07603884, 4.22836664, 3.35395465,
        4.0645276 , 3.72119599, 3.48700861, 4.16388921, 3.40982441,
        3.86926003, 3.43583485, 3.20324443, 4.08487897, 3.23199846,
        3.88664794, 4.33189497, 4.38358165, 4.31638739, 3.86591733,
        4.33975717, 3.89147883, 3.70029269, 3.36247781, 4.32901523,
        4.06706884, 4.56922029, 3.77104091, 3.68586682, 3.84532386,
        4.3454114 , 3.90999207, 3.39949928, 3.60786807, 3.96267104,
        4.14386102, 3.4072049 , 3.7040225 , 4.00350359, 4.64280228,
        3.21623279, 3.77238583, 4.26565116, 4.45377313, 3.83848945,
        3.79374176, 3.7629172 , 3.88698608, 3.80041727, 4.34696995,
        3.80469565, 3.84624795, 3.64122601, 3.27221683, 3.42333499,
        3.71631568, 3.20698918, 4.45410441, 4.26541296, 3.86109184

In [13]:
# Initial estimates of R (4 types available).
# Type 1: the k x k identity matrix.
R_hat0_1 = tf.eye(num_rows=k, dtype=tf.float64)
R_hat0_1

<tf.Tensor: shape=(100, 100), dtype=float64, numpy=
array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]])>

In [14]:
# S = sum_t H_yt^T (y_t - H_yt mu_hat0) (y_t - H_yt mu_hat0)^T H_yt
S = 0
for (H_yt, H_xt, H_yt_trans, H_xt_trans, k_t, Z_t, y_t, x_t, movie_ids_t, labels_t) in tqdm(data_preprocessed):
    Hytmu_hat0 = H_yt @ mu_hat0
    centered_y_t = y_t - Hytmu_hat0
    S = S + H_yt_trans @ centered_y_t @ tf.transpose(centered_y_t) @ H_yt

100%|██████████| 137328/137328 [00:21<00:00, 6323.46it/s]


In [15]:
# Type 2: N^{-1} diag(S), where diag_S keeps only the diagonal elements of S.
S_diagonal = tf.linalg.tensor_diag_part(S)
diag_S = tf.linalg.diag(S_diagonal)
R_hat0_2 = tf.linalg.inv(N) @ diag_S
R_hat0_2

<tf.Tensor: shape=(100, 100), dtype=float64, numpy=
array([[1.72440427, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.94219113, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 1.43659411, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 1.18291506, 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 1.03485685,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        1.26227449]])>

In [16]:
# Type 3: diag(S)^{-1/2} S diag(S)^{-1/2}.
# R_hat0_3 is not a good initializer when rating variances are far from one.
diag_S_inv_sqrt = tf.linalg.sqrtm(tf.linalg.inv(diag_S))
R_hat0_3 = diag_S_inv_sqrt @ (S @ diag_S_inv_sqrt)
R_hat0_3

<tf.Tensor: shape=(100, 100), dtype=float64, numpy=
array([[ 1.        ,  0.07418256, -0.01158277, ..., -0.01462987,
        -0.02215371, -0.01844816],
       [ 0.07418256,  1.        ,  0.03674347, ...,  0.0256191 ,
         0.03563234,  0.03926307],
       [-0.01158277,  0.03674347,  1.        , ...,  0.10955311,
         0.12823359,  0.15560634],
       ...,
       [-0.01462987,  0.0256191 ,  0.10955311, ...,  1.        ,
         0.19781317,  0.15164928],
       [-0.02215371,  0.03563234,  0.12823359, ...,  0.19781317,
         1.        ,  0.18995689],
       [-0.01844816,  0.03926307,  0.15560634, ...,  0.15164928,
         0.18995689,  1.        ]])>

In [17]:
# Type 4: N^{-1/2} S N^{-1/2}.
N_inv_sqrt = tf.linalg.sqrtm(tf.linalg.inv(N))
R_hat0_4 = N_inv_sqrt @ (S @ N_inv_sqrt)
R_hat0_4

<tf.Tensor: shape=(100, 100), dtype=float64, numpy=
array([[ 1.72440427,  0.09455639, -0.01823052, ..., -0.02089473,
        -0.02959417, -0.02721758],
       [ 0.09455639,  0.94219113,  0.04274809, ...,  0.02704644,
         0.03518471,  0.04281842],
       [-0.01823052,  0.04274809,  1.43659411, ...,  0.14281326,
         0.15635399,  0.20954206],
       ...,
       [-0.02089473,  0.02704644,  0.14281326, ...,  1.18291506,
         0.21886288,  0.18530794],
       [-0.02959417,  0.03518471,  0.15635399, ...,  0.21886288,
         1.03485685,  0.21710614],
       [-0.02721758,  0.04281842,  0.20954206, ...,  0.18530794,
         0.21710614,  1.26227449]])>

## Expectation Maximization Algorithm

In [18]:
LOG_2PI = tf.math.log(2*tf.constant(np.pi, dtype=tf.float64))

@tf.function(experimental_relax_shapes=True)
def run_graph_em(mu, R, y_t, H_xt, H_xt_trans, H_yt, H_yt_trans, k_t):
    """Compute one user's contribution to an EM iteration.

    Args:
        mu: current estimate of the mean vector.
        R: current estimate of the covariance matrix.
        y_t: the user's observed ratings (H_yt applied to the user's rating column).
        H_xt, H_xt_trans: selection matrix of the user's unrated movies and its transpose.
        H_yt, H_yt_trans: selection matrix of the user's rated movies and its transpose.
        k_t: number of rated movies of the user (float64 scalar).

    Returns:
        A tuple (R_hat_sum_part, Hyt_trans_Ryt_inv_Hyt_sum_part,
        Hyt_trans_Ryt_inv_yt_sum_part, log_p_hat_part): the user's summands for the
        R update, for the two sums of the mu update, and for the log likelihood
        evaluated at the given (mu, R).
    """
    # for R estimation
    R_xt = H_xt @ R @ H_xt_trans
    R_yt = H_yt @ R @ H_yt_trans
    R_yt_inv = tf.linalg.inv(R_yt)
    R_xtyt = H_xt @ R @ H_yt_trans

    mu_yt = tf.matmul(H_yt, mu)
    mu_xt = tf.matmul(H_xt, mu)
    residual_yt = y_t - mu_yt

    # R_xtyt @ R_yt^{-1} is shared by the conditional mean and the conditional covariance
    gain = R_xtyt @ R_yt_inv

    # conditional mean of the unrated entries given the rated ones, then the completed vector
    X_t_hat = gain @ residual_yt + mu_xt
    Z_t_hat = H_yt_trans @ y_t + H_xt_trans @ X_t_hat

    centered_Z_t_hat = Z_t_hat - mu
    R_hat_sum_part = centered_Z_t_hat @ tf.transpose(centered_Z_t_hat) \
                        + H_xt_trans @ (R_xt - gain @ tf.transpose(R_xtyt)) @ H_xt

    # for mu estimation
    Hyt_trans_Ryt_inv_Hyt_sum_part = H_yt_trans @ R_yt_inv @ H_yt
    Hyt_trans_Ryt_inv_yt_sum_part = H_yt_trans @ R_yt_inv @ y_t

    # for log likelihood calculation
    # slogdet avoids the overflow/underflow of log(det(.)) for large blocks. Adding log(sign)
    # keeps the previous outcome for non-positive determinants: NaN if negative, -inf if zero.
    R_yt_det_sign, R_yt_log_abs_det = tf.linalg.slogdet(R_yt)
    R_yt_log_det = tf.math.log(R_yt_det_sign) + R_yt_log_abs_det
    log_p_hat_part = -1/2*(R_yt_log_det + tf.transpose(residual_yt) @ R_yt_inv @ residual_yt + k_t*LOG_2PI)

    return R_hat_sum_part, Hyt_trans_Ryt_inv_Hyt_sum_part, Hyt_trans_Ryt_inv_yt_sum_part, log_p_hat_part

In [19]:
def expectation_maximization(mu, R):
    """Run one EM iteration over all users in `data_preprocessed`.

    Sums the per-user terms returned by `run_graph_em` and returns
    (mu_hat, R_hat, log_p_hat): the updated mean, the updated covariance
    (sum of the R terms divided by the global `n`) and the summed log
    likelihood term reported by `run_graph_em` for the given (mu, R).
    """
    info_matrix_sum = 0
    info_vector_sum = 0
    scatter_sum = 0
    log_p_hat = 0

    for user_record in tqdm(data_preprocessed):
        H_yt, H_xt, H_yt_trans, H_xt_trans, k_t = user_record[:5]
        y_t = user_record[6]

        scatter_part, info_matrix_part, info_vector_part, log_p_part = run_graph_em(
            mu, R, y_t, H_xt, H_xt_trans, H_yt, H_yt_trans, k_t)

        scatter_sum = scatter_sum + scatter_part
        info_matrix_sum = info_matrix_sum + info_matrix_part
        info_vector_sum = info_vector_sum + info_vector_part
        log_p_hat = log_p_hat + log_p_part

    mu_hat = tf.linalg.inv(info_matrix_sum) @ info_vector_sum
    R_hat = scatter_sum / n
    return mu_hat, R_hat, log_p_hat

In [20]:
# Run EM from (mu_hat0, R_hat0_4) for at most 30 iterations; stop as soon as the
# per-user log likelihood improves by less than delta.
delta = 0.0005
mu, R = mu_hat0, R_hat0_4
log_p = tf.constant(-np.inf, dtype=tf.float64)

for i in range(30):
    if i % 5 == 0:
        print(f'iteration: {i}')

    mu_hat, R_hat, log_p_hat = expectation_maximization(mu, R)
    normalized_gap = log_p_hat/n - log_p/n
    convergence_criterion = normalized_gap < delta

    print('normalized log_p_hat:', (log_p_hat/n).numpy().flatten()[0])
    print('normalized log_p:    ', (log_p/n).numpy().flatten()[0])
    print('convergence gap:     ', normalized_gap.numpy().flatten()[0])

    if convergence_criterion:
        break

    # carry the new estimates over to the next iteration
    mu, R, log_p = mu_hat, R_hat, log_p_hat

  0%|          | 1/137328 [00:00<7:22:28,  5.17it/s]

iteration: 0


100%|██████████| 137328/137328 [02:05<00:00, 1095.41it/s]
  0%|          | 107/137328 [00:00<02:09, 1061.38it/s]

normalized log_p_hat: -32.23267266620053
normalized log_p:     -inf
convergence gap:      inf


100%|██████████| 137328/137328 [02:03<00:00, 1115.24it/s]
  0%|          | 104/137328 [00:00<02:12, 1038.99it/s]

normalized log_p_hat: -31.935708124463915
normalized log_p:     -32.23267266620053
convergence gap:      0.2969645417366138


100%|██████████| 137328/137328 [01:59<00:00, 1145.74it/s]
  0%|          | 124/137328 [00:00<01:51, 1230.70it/s]

normalized log_p_hat: -31.783971765856748
normalized log_p:     -31.935708124463915
convergence gap:      0.15173635860716672


100%|██████████| 137328/137328 [01:55<00:00, 1189.90it/s]
  0%|          | 124/137328 [00:00<01:50, 1239.34it/s]

normalized log_p_hat: -31.68821475729513
normalized log_p:     -31.783971765856748
convergence gap:      0.09575700856161973


100%|██████████| 137328/137328 [01:55<00:00, 1184.97it/s]
  0%|          | 114/137328 [00:00<02:01, 1130.91it/s]

normalized log_p_hat: -31.62162645616368
normalized log_p:     -31.68821475729513
convergence gap:      0.06658830113144987
iteration: 5


100%|██████████| 137328/137328 [01:57<00:00, 1166.77it/s]
  0%|          | 111/137328 [00:00<02:04, 1102.72it/s]

normalized log_p_hat: -31.573243249771807
normalized log_p:     -31.62162645616368
convergence gap:      0.04838320639187188


100%|██████████| 137328/137328 [01:57<00:00, 1164.64it/s]
  0%|          | 239/137328 [00:00<01:54, 1200.43it/s]

normalized log_p_hat: -31.537223741214067
normalized log_p:     -31.573243249771807
convergence gap:      0.0360195085577395


100%|██████████| 137328/137328 [01:57<00:00, 1173.34it/s]
  0%|          | 247/137328 [00:00<01:51, 1231.74it/s]

normalized log_p_hat: -31.509964661098955
normalized log_p:     -31.537223741214067
convergence gap:      0.0272590801151118


100%|██████████| 137328/137328 [02:00<00:00, 1135.07it/s]
  0%|          | 117/137328 [00:00<01:58, 1160.96it/s]

normalized log_p_hat: -31.489087091946207
normalized log_p:     -31.509964661098955
convergence gap:      0.02087756915274852


100%|██████████| 137328/137328 [02:00<00:00, 1138.77it/s]
  0%|          | 114/137328 [00:00<02:01, 1131.31it/s]

normalized log_p_hat: -31.472953845101234
normalized log_p:     -31.489087091946207
convergence gap:      0.01613324684497286
iteration: 10


100%|██████████| 137328/137328 [01:58<00:00, 1155.27it/s]
  0%|          | 237/137328 [00:00<01:55, 1183.88it/s]

normalized log_p_hat: -31.460401700272747
normalized log_p:     -31.472953845101234
convergence gap:      0.012552144828486433


100%|██████████| 137328/137328 [02:00<00:00, 1143.40it/s]
  0%|          | 213/137328 [00:00<02:09, 1062.32it/s]

normalized log_p_hat: -31.450582796059315
normalized log_p:     -31.460401700272747
convergence gap:      0.009818904213432234


100%|██████████| 137328/137328 [01:59<00:00, 1149.23it/s]
  0%|          | 220/137328 [00:00<02:05, 1095.99it/s]

normalized log_p_hat: -31.442866788410797
normalized log_p:     -31.450582796059315
convergence gap:      0.007716007648518541


100%|██████████| 137328/137328 [02:05<00:00, 1098.24it/s]
  0%|          | 116/137328 [00:00<01:58, 1153.35it/s]

normalized log_p_hat: -31.436778194800365
normalized log_p:     -31.442866788410797
convergence gap:      0.006088593610432014


100%|██████████| 137328/137328 [01:58<00:00, 1162.91it/s]
  0%|          | 125/137328 [00:00<01:50, 1241.51it/s]

normalized log_p_hat: -31.431954690094038
normalized log_p:     -31.436778194800365
convergence gap:      0.004823504706326531
iteration: 15


100%|██████████| 137328/137328 [01:59<00:00, 1148.08it/s]
  0%|          | 89/137328 [00:00<02:35, 882.40it/s]

normalized log_p_hat: -31.428118271647552
normalized log_p:     -31.431954690094038
convergence gap:      0.0038364184464860784


100%|██████████| 137328/137328 [02:08<00:00, 1072.25it/s]
  0%|          | 119/137328 [00:00<01:56, 1180.80it/s]

normalized log_p_hat: -31.42505463546097
normalized log_p:     -31.428118271647552
convergence gap:      0.0030636361865816752


100%|██████████| 137328/137328 [01:57<00:00, 1170.58it/s]
  0%|          | 121/137328 [00:00<01:53, 1207.40it/s]

normalized log_p_hat: -31.422598022641022
normalized log_p:     -31.42505463546097
convergence gap:      0.002456612819948134


100%|██████████| 137328/137328 [01:59<00:00, 1153.56it/s]
  0%|          | 116/137328 [00:00<01:59, 1150.22it/s]

normalized log_p_hat: -31.420619870288082
normalized log_p:     -31.422598022641022
convergence gap:      0.001978152352940299


100%|██████████| 137328/137328 [01:56<00:00, 1178.58it/s]
  0%|          | 247/137328 [00:00<01:51, 1226.72it/s]

normalized log_p_hat: -31.419020207835388
normalized log_p:     -31.420619870288082
convergence gap:      0.0015996624526941616
iteration: 20


100%|██████████| 137328/137328 [01:56<00:00, 1183.73it/s]
  0%|          | 243/137328 [00:00<01:53, 1204.99it/s]

normalized log_p_hat: -31.41772108965544
normalized log_p:     -31.419020207835388
convergence gap:      0.0012991181799471008


100%|██████████| 137328/137328 [02:03<00:00, 1114.58it/s]
  0%|          | 122/137328 [00:00<01:52, 1216.58it/s]

normalized log_p_hat: -31.416661564784253
normalized log_p:     -31.41772108965544
convergence gap:      0.001059524871187989


100%|██████████| 137328/137328 [01:49<00:00, 1250.26it/s]
  0%|          | 122/137328 [00:00<01:52, 1219.78it/s]

normalized log_p_hat: -31.4157938181978
normalized log_p:     -31.416661564784253
convergence gap:      0.0008677465864543876


100%|██████████| 137328/137328 [02:05<00:00, 1090.21it/s]
  0%|          | 116/137328 [00:00<01:58, 1156.72it/s]

normalized log_p_hat: -31.415080208803396
normalized log_p:     -31.4157938181978
convergence gap:      0.0007136093944026811


100%|██████████| 137328/137328 [01:54<00:00, 1201.49it/s]
  0%|          | 128/137328 [00:00<01:47, 1278.64it/s]

normalized log_p_hat: -31.414490994453352
normalized log_p:     -31.415080208803396
convergence gap:      0.0005892143500432212
iteration: 25


100%|██████████| 137328/137328 [01:51<00:00, 1230.33it/s]

normalized log_p_hat: -31.414002583121093
normalized log_p:     -31.414490994453352
convergence gap:      0.0004884113322596306





In [21]:
# EM run above: 26 iterations, ~38 min. Persist its final estimates.
np.save(file='results/em_mu.npy', arr=mu_hat)
np.save(file='results/em_R.npy', arr=R_hat)
np.save(file='results/em_log_p.npy', arr=log_p_hat)

## McMichael’s Algorithm

In [22]:
@tf.function(experimental_relax_shapes=True)
def run_graph_mcmichael(mu, R, y_t, H_yt, H_yt_trans, k_t):
    """Compute one user's contribution to an iteration of McMichael's algorithm.

    Args:
        mu: current estimate of the mean vector.
        R: current estimate of the covariance matrix.
        y_t: the user's observed ratings.
        H_yt, H_yt_trans: selection matrix of the user's rated movies and its transpose.
        k_t: number of rated movies of the user (float64 scalar).

    Returns:
        A tuple (log_p_gradient_part, Hyt_trans_Ryt_inv_Hyt_sum_part,
        Hyt_trans_Ryt_inv_yt_sum_part, log_p_hat_part): the user's summands for the
        gradient term of the R update, for the two sums of the mu update, and for the
        log likelihood evaluated at the given (mu, R).
    """
    # for R estimation
    R_yt = H_yt @ R @ H_yt_trans
    R_yt_inv = tf.linalg.inv(R_yt)
    mu_yt = tf.matmul(H_yt, mu)
    residual_yt = y_t - mu_yt
    log_p_gradient_part = H_yt_trans @ (R_yt_inv - R_yt_inv @ residual_yt @ tf.transpose(residual_yt) @ R_yt_inv) @ H_yt

    # for mu estimation
    Hyt_trans_Ryt_inv_Hyt_sum_part = H_yt_trans @ R_yt_inv @ H_yt
    Hyt_trans_Ryt_inv_yt_sum_part = H_yt_trans @ R_yt_inv @ y_t

    # for log likelihood calculation
    # slogdet avoids the overflow/underflow of log(det(.)) for large blocks. Adding log(sign)
    # keeps the previous outcome for non-positive determinants: NaN if negative, -inf if zero.
    R_yt_det_sign, R_yt_log_abs_det = tf.linalg.slogdet(R_yt)
    R_yt_log_det = tf.math.log(R_yt_det_sign) + R_yt_log_abs_det
    log_p_hat_part = -1/2*(R_yt_log_det + tf.transpose(residual_yt) @ R_yt_inv @ residual_yt + k_t*LOG_2PI)

    return log_p_gradient_part, Hyt_trans_Ryt_inv_Hyt_sum_part, Hyt_trans_Ryt_inv_yt_sum_part, log_p_hat_part

In [23]:
def mcmichael(mu, R, gamma=0.00001):
    """Run one iteration of McMichael's algorithm over all users in `data_preprocessed`.

    Args:
        mu: current estimate of the mean vector.
        R: current estimate of the covariance matrix.
        gamma: step size of the R update. Defaults to the value that was
            previously hard-coded, so existing calls behave as before.

    Returns:
        (mu_hat, R_hat, log_p_hat): the updated mean, the updated covariance
        R + gamma * R @ (-1/2 * log_p_gradient) @ R, and the summed log likelihood
        term reported by `run_graph_mcmichael` for the given (mu, R).
    """
    Hyt_trans_Ryt_inv_Hyt_sum = 0
    Hyt_trans_Ryt_inv_yt_sum = 0
    log_p_gradient = 0
    log_p_hat = 0

    for (H_yt, H_xt, H_yt_trans, H_xt_trans, k_t, Z_t, y_t, x_t, movie_ids_t, labels_t) in tqdm(data_preprocessed):
        log_p_gradient_part, Hyt_trans_Ryt_inv_Hyt_sum_part, Hyt_trans_Ryt_inv_yt_sum_part, log_p_hat_part = \
            run_graph_mcmichael(mu, R, y_t, H_yt, H_yt_trans, k_t)

        log_p_gradient += log_p_gradient_part
        Hyt_trans_Ryt_inv_Hyt_sum += Hyt_trans_Ryt_inv_Hyt_sum_part
        Hyt_trans_Ryt_inv_yt_sum += Hyt_trans_Ryt_inv_yt_sum_part
        log_p_hat += log_p_hat_part

    # gradient step on R, pre- and post-multiplied by the current R
    R_hat = R + gamma*(R @ (-1/2*log_p_gradient) @ R)
    mu_hat = tf.matmul(tf.linalg.inv(Hyt_trans_Ryt_inv_Hyt_sum), Hyt_trans_Ryt_inv_yt_sum)
    return mu_hat, R_hat, log_p_hat

In [24]:
# Run McMichael's algorithm from (mu_hat0, R_hat0_4) for at most 40 iterations; stop as
# soon as the per-user log likelihood improves by less than delta.
delta = 0.0005
mu, R = mu_hat0, R_hat0_4
log_p = tf.constant(-np.inf, dtype=tf.float64)

for i in range(40):
    if i % 5 == 0:
        print(f'iteration: {i}')

    mu_hat, R_hat, log_p_hat = mcmichael(mu, R)
    normalized_gap = log_p_hat/n - log_p/n
    convergence_criterion = normalized_gap < delta

    print('normalized log_p_hat:', (log_p_hat/n).numpy().flatten()[0])
    print('normalized log_p:    ', (log_p/n).numpy().flatten()[0])
    print('convergence gap:     ', normalized_gap.numpy().flatten()[0])

    if convergence_criterion:
        break

    # carry the new estimates over to the next iteration
    mu, R, log_p = mu_hat, R_hat, log_p_hat

  0%|          | 0/137328 [00:00<?, ?it/s]

iteration: 0


100%|██████████| 137328/137328 [01:25<00:00, 1605.61it/s]
  0%|          | 168/137328 [00:00<01:21, 1673.87it/s]

normalized log_p_hat: -32.23267266620053
normalized log_p:     -inf
convergence gap:      inf


100%|██████████| 137328/137328 [01:24<00:00, 1630.26it/s]
  0%|          | 165/137328 [00:00<01:23, 1647.38it/s]

normalized log_p_hat: -32.0069453647492
normalized log_p:     -32.23267266620053
convergence gap:      0.2257273014513288


100%|██████████| 137328/137328 [01:21<00:00, 1693.44it/s]
  0%|          | 172/137328 [00:00<01:20, 1712.98it/s]

normalized log_p_hat: -31.87427149044739
normalized log_p:     -32.0069453647492
convergence gap:      0.13267387430180833


100%|██████████| 137328/137328 [01:19<00:00, 1719.26it/s]
  0%|          | 176/137328 [00:00<01:18, 1752.04it/s]

normalized log_p_hat: -31.782812613324058
normalized log_p:     -31.87427149044739
convergence gap:      0.09145887712333334


100%|██████████| 137328/137328 [01:19<00:00, 1723.87it/s]
  0%|          | 170/137328 [00:00<01:21, 1692.11it/s]

normalized log_p_hat: -31.714939723299512
normalized log_p:     -31.782812613324058
convergence gap:      0.067872890024546
iteration: 5


100%|██████████| 137328/137328 [01:20<00:00, 1702.77it/s]
  0%|          | 173/137328 [00:00<01:19, 1729.74it/s]

normalized log_p_hat: -31.66230112168724
normalized log_p:     -31.714939723299512
convergence gap:      0.05263860161227285


100%|██████████| 137328/137328 [01:20<00:00, 1714.45it/s]
  0%|          | 355/137328 [00:00<01:17, 1762.07it/s]

normalized log_p_hat: -31.620417069815367
normalized log_p:     -31.66230112168724
convergence gap:      0.04188405187187172


100%|██████████| 137328/137328 [01:19<00:00, 1720.73it/s]
  0%|          | 354/137328 [00:00<01:17, 1763.20it/s]

normalized log_p_hat: -31.586550523687293
normalized log_p:     -31.620417069815367
convergence gap:      0.03386654612807405


100%|██████████| 137328/137328 [01:19<00:00, 1717.49it/s]
  0%|          | 176/137328 [00:00<01:18, 1754.86it/s]

normalized log_p_hat: -31.558859589778326
normalized log_p:     -31.586550523687293
convergence gap:      0.027690933908967708


100%|██████████| 137328/137328 [01:19<00:00, 1726.76it/s]
  0%|          | 169/137328 [00:00<01:21, 1685.85it/s]

normalized log_p_hat: -31.536027074323755
normalized log_p:     -31.558859589778326
convergence gap:      0.022832515454570768
iteration: 10


100%|██████████| 137328/137328 [01:19<00:00, 1718.00it/s]
  0%|          | 345/137328 [00:00<01:20, 1698.36it/s]

normalized log_p_hat: -31.51707511990148
normalized log_p:     -31.536027074323755
convergence gap:      0.018951954422274042


100%|██████████| 137328/137328 [01:19<00:00, 1720.66it/s]
  0%|          | 170/137328 [00:00<01:21, 1690.48it/s]

normalized log_p_hat: -31.501259448836496
normalized log_p:     -31.51707511990148
convergence gap:      0.015815671064984826


100%|██████████| 137328/137328 [01:20<00:00, 1715.99it/s]
  0%|          | 368/137328 [00:00<01:15, 1815.47it/s]

normalized log_p_hat: -31.488002872866485
normalized log_p:     -31.501259448836496
convergence gap:      0.013256575970011397


100%|██████████| 137328/137328 [01:21<00:00, 1677.21it/s]
  0%|          | 173/137328 [00:00<01:19, 1728.70it/s]

normalized log_p_hat: -31.476850867045204
normalized log_p:     -31.488002872866485
convergence gap:      0.011152005821280397


100%|██████████| 137328/137328 [01:23<00:00, 1638.45it/s]
  0%|          | 176/137328 [00:00<01:17, 1759.81it/s]

normalized log_p_hat: -31.467440782899924
normalized log_p:     -31.476850867045204
convergence gap:      0.009410084145279995
iteration: 15


100%|██████████| 137328/137328 [01:20<00:00, 1714.62it/s]
  0%|          | 177/137328 [00:00<01:17, 1763.61it/s]

normalized log_p_hat: -31.459480001278635
normalized log_p:     -31.467440782899924
convergence gap:      0.007960781621289215


100%|██████████| 137328/137328 [01:19<00:00, 1722.77it/s]
  0%|          | 171/137328 [00:00<01:20, 1706.33it/s]

normalized log_p_hat: -31.452730143592913
normalized log_p:     -31.459480001278635
convergence gap:      0.00674985768572256


100%|██████████| 137328/137328 [01:20<00:00, 1708.95it/s]
  0%|          | 341/137328 [00:00<01:20, 1694.44it/s]

normalized log_p_hat: -31.446995471130396
normalized log_p:     -31.452730143592913
convergence gap:      0.005734672462516244


100%|██████████| 137328/137328 [01:20<00:00, 1711.69it/s]
  0%|          | 177/137328 [00:00<01:17, 1765.32it/s]

normalized log_p_hat: -31.442114220094634
normalized log_p:     -31.446995471130396
convergence gap:      0.004881251035762091


100%|██████████| 137328/137328 [01:20<00:00, 1714.81it/s]
  0%|          | 343/137328 [00:00<01:20, 1700.39it/s]

normalized log_p_hat: -31.43795201995237
normalized log_p:     -31.442114220094634
convergence gap:      0.004162200142264538
iteration: 20


100%|██████████| 137328/137328 [01:20<00:00, 1709.95it/s]
  0%|          | 163/137328 [00:00<01:24, 1626.00it/s]

normalized log_p_hat: -31.434396809623163
normalized log_p:     -31.43795201995237
convergence gap:      0.003555210329206915


100%|██████████| 137328/137328 [01:18<00:00, 1743.09it/s]
  0%|          | 180/137328 [00:00<01:16, 1792.31it/s]

normalized log_p_hat: -31.43135484668383
normalized log_p:     -31.434396809623163
convergence gap:      0.003041962939331455


100%|██████████| 137328/137328 [01:17<00:00, 1770.18it/s]
  0%|          | 181/137328 [00:00<01:15, 1806.50it/s]

normalized log_p_hat: -31.42874752785493
normalized log_p:     -31.43135484668383
convergence gap:      0.002607318828900418


100%|██████████| 137328/137328 [01:15<00:00, 1808.53it/s]
  0%|          | 183/137328 [00:00<01:14, 1829.78it/s]

normalized log_p_hat: -31.426508822915086
normalized log_p:     -31.42874752785493
convergence gap:      0.0022387049398453485


100%|██████████| 137328/137328 [01:15<00:00, 1808.24it/s]
  0%|          | 182/137328 [00:00<01:15, 1813.43it/s]

normalized log_p_hat: -31.424583181429266
normalized log_p:     -31.426508822915086
convergence gap:      0.0019256414858190851
iteration: 25


100%|██████████| 137328/137328 [01:15<00:00, 1808.43it/s]
  0%|          | 364/137328 [00:00<01:15, 1814.25it/s]

normalized log_p_hat: -31.422923810728065
normalized log_p:     -31.424583181429266
convergence gap:      0.001659370701201368


100%|██████████| 137328/137328 [01:16<00:00, 1804.69it/s]
  0%|          | 383/137328 [00:00<01:11, 1924.51it/s]

normalized log_p_hat: -31.421491250329314
normalized log_p:     -31.422923810728065
convergence gap:      0.0014325603987508373


100%|██████████| 137328/137328 [01:15<00:00, 1810.49it/s]
  0%|          | 366/137328 [00:00<01:15, 1818.72it/s]

normalized log_p_hat: -31.420252186446103
normalized log_p:     -31.421491250329314
convergence gap:      0.0012390638832115997


100%|██████████| 137328/137328 [01:15<00:00, 1808.77it/s]
  0%|          | 367/137328 [00:00<01:15, 1825.91it/s]

normalized log_p_hat: -31.41917846315028
normalized log_p:     -31.420252186446103
convergence gap:      0.0010737232958213383


100%|██████████| 137328/137328 [01:16<00:00, 1805.14it/s]
  0%|          | 181/137328 [00:00<01:15, 1807.09it/s]

normalized log_p_hat: -31.418246255975216
normalized log_p:     -31.41917846315028
convergence gap:      0.0009322071750652583
iteration: 30


100%|██████████| 137328/137328 [01:16<00:00, 1783.69it/s]
  0%|          | 146/137328 [00:00<01:34, 1454.81it/s]

normalized log_p_hat: -31.417435380469346
normalized log_p:     -31.418246255975216
convergence gap:      0.0008108755058700012


100%|██████████| 137328/137328 [01:17<00:00, 1779.71it/s]
  0%|          | 183/137328 [00:00<01:15, 1827.70it/s]

normalized log_p_hat: -31.41672871330004
normalized log_p:     -31.417435380469346
convergence gap:      0.0007066671693074511


100%|██████████| 137328/137328 [01:16<00:00, 1791.70it/s]
  0%|          | 181/137328 [00:00<01:15, 1805.77it/s]

normalized log_p_hat: -31.416111707402408
normalized log_p:     -31.41672871330004
convergence gap:      0.0006170058976309178


100%|██████████| 137328/137328 [01:15<00:00, 1814.51it/s]
  0%|          | 182/137328 [00:00<01:15, 1809.47it/s]

normalized log_p_hat: -31.415571985818687
normalized log_p:     -31.416111707402408
convergence gap:      0.0005397215837206204


100%|██████████| 137328/137328 [01:15<00:00, 1807.30it/s]

normalized log_p_hat: -31.415099001376582
normalized log_p:     -31.415571985818687
convergence gap:      0.00047298444210497337





In [25]:
# McMichael run above: 35 iterations, ~38 min. Persist its final estimates.
np.save(file='results/mcmichael_mu.npy', arr=mu_hat)
np.save(file='results/mcmichael_R.npy', arr=R_hat)
np.save(file='results/mcmichael_log_p.npy', arr=log_p_hat)

## Evaluation

In [26]:
@tf.function(experimental_relax_shapes=True)
def run_graph_square_error(mu, R, movie_ids_t, labels_t, y_t, H_xt, H_xt_trans, H_yt, H_yt_trans):
    """Return one user's summed squared error on their test ratings.

    Args:
        mu: estimated mean vector.
        R: estimated covariance matrix.
        movie_ids_t: 1-based IDs of the user's test movies.
        labels_t: the user's test ratings, one row per entry of movie_ids_t.
        y_t: the user's observed (training) ratings.
        H_xt, H_xt_trans: selection matrix of the user's unrated movies and its transpose.
        H_yt, H_yt_trans: selection matrix of the user's rated movies and its transpose.

    Returns:
        (labels_t - predictions_t)^T (labels_t - predictions_t), where the predictions
        are clipped to the range [1, 5].
    """
    # calculate X_t_hat: estimate of the unrated entries given the rated ones
    R_yt = H_yt @ R @ H_yt_trans
    R_yt_inv = tf.linalg.inv(R_yt)
    R_xtyt = H_xt @ R @ H_yt_trans

    mu_yt = tf.matmul(H_yt, mu)
    mu_xt = tf.matmul(H_xt, mu)

    X_t_hat = R_xtyt @ R_yt_inv @ (y_t - mu_yt) + mu_xt

    # scatter the estimates back to movie positions and pick the test movies (IDs are 1-based)
    predictions_t = tf.gather(tf.matmul(H_xt_trans, X_t_hat), indices=movie_ids_t-1)
    # clip ratings
    predictions_t = tf.clip_by_value(predictions_t, 1, 5)

    prediction_errors_t = labels_t - predictions_t
    return tf.matmul(tf.transpose(prediction_errors_t), prediction_errors_t)

In [27]:
def evaluate(mu, R):
    """Return the RMSE of the predictions made with (mu, R) on the test ratings.

    Sums the per-user squared errors from `run_graph_square_error` over
    `data_preprocessed`, divides by the total number of test ratings and
    takes the square root.
    """
    total_square_error = 0
    n_labels = 0
    for user_record in tqdm(data_preprocessed):
        H_yt, H_xt, H_yt_trans, H_xt_trans = user_record[:4]
        y_t, movie_ids_t, labels_t = user_record[6], user_record[8], user_record[9]
        labels_float = tf.cast(labels_t, dtype=tf.float64)
        total_square_error = total_square_error + run_graph_square_error(
            mu, R, movie_ids_t, labels_float, y_t, H_xt, H_xt_trans, H_yt, H_yt_trans)
        n_labels = n_labels + len(labels_t)
    return np.sqrt(total_square_error/n_labels)

In [28]:
# Test-set RMSE of the saved EM estimates.
em_mu, em_R = (np.load(path) for path in ('results/em_mu.npy', 'results/em_R.npy'))
rmse = evaluate(em_mu, em_R)
rmse

100%|██████████| 137328/137328 [01:34<00:00, 1454.87it/s]


array([[0.91700701]])

In [29]:
# Test-set RMSE of the saved McMichael estimates.
mcmichael_mu, mcmichael_R = (np.load(path) for path in ('results/mcmichael_mu.npy', 'results/mcmichael_R.npy'))
rmse = evaluate(mcmichael_mu, mcmichael_R)
rmse

100%|██████████| 137328/137328 [01:36<00:00, 1425.03it/s]


array([[0.91701472]])