In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
df = pd.read_csv("data/100x100.csv")
data_matrix = df.pivot(index='u_id', columns='a_id', values='score').fillna(0)
data_matrix_values = data_matrix.values

In [13]:
# example given in the notebook
'''
Original matrix = 0.5 ? 4
                   1  3 5
'''
U = np.array([[0.7461],
              [1.7966]])
P = np.array([[0.758, 2.5431, 4.7999]])
prediction = U@P

In [4]:
def als(matrix, rank, iterations, regularization=0.1):
    num_users, num_items = matrix.shape
    X = np.ones((num_users, rank)) 
    Y = np.ones((num_items, rank)) 
    mask = matrix > 0

    for _ in range(iterations):
        for i in range(num_users):
            Y_i = Y[mask[i]]
            if Y_i.size == 0:
                continue
            A = Y_i.T @ Y_i + regularization * np.eye(rank)
            b = Y_i.T @ matrix[i, mask[i]]
            X[i] = np.linalg.lstsq(A, b, rcond=None)[0]

        for j in range(num_items):
            X_j = X[mask[:, j]]
            if X_j.size == 0:
                continue
            A = X_j.T @ X_j + regularization * np.eye(rank)
            b = X_j.T @ matrix[mask[:, j], j]
            Y[j] = np.linalg.lstsq(A, b, rcond=None)[0]

    return X, Y


In [None]:
# usage
X, Y = als(data_matrix_values, rank=20, iterations=10)

X shape: (98, 20)
Y shape: (97, 20)
Matrix shape (98, 97)
Reconstructed matrix shape (98, 97)


array([[ 6.62841695,  7.68756157,  5.40264905, ...,  7.03429342,
         9.21592991,  7.86476443],
       [ 8.04431393,  8.18118253,  7.98742659, ...,  6.1651632 ,
         5.26046528,  5.00631648],
       [ 9.18245238,  8.44696363,  5.61542585, ...,  6.15504619,
         6.96268395,  8.46482085],
       ...,
       [10.38276835,  9.02836327,  8.92303121, ...,  6.27756589,
         7.72610822,  8.49420705],
       [ 9.28835902,  8.91231395,  8.02109134, ...,  7.15230149,
         9.69991859,  7.95093381],
       [10.48358212,  9.28696434,  8.99515993, ...,  7.19024711,
         6.21426391,  9.15314535]])

In [6]:
# we tested for different rank values between 10 and 100, and found that 20 works the best for not overfitting
rank = 20

In [None]:
X, Y = als(data_matrix_values, rank=rank, iterations=10)
predicted_ratings_10_iters = X @ Y.T

predicted_ratings_10_iters_df = pd.DataFrame(predicted_ratings_10_iters, index=data_matrix.index, columns=data_matrix.columns)

a_id,1,30,32,199,226,227,323,339,356,849,...,34134,34561,34599,34618,34933,35790,35849,37349,37450,38000
u_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,6.628417,7.687562,5.402649,8.780447,5.507075,8.145473,5.433910,5.959895,8.759678,8.558854,...,5.146699,5.936240,7.999051,9.024651,7.993552,7.646334,5.746000,7.034293,9.215930,7.864764
10,8.044314,8.181183,7.987427,9.922563,7.009426,8.967301,7.009512,8.399231,6.498519,7.992086,...,3.683092,5.815052,7.826018,6.956949,5.845903,6.675370,5.711624,6.165163,5.260465,5.006316
16,9.182452,8.446964,5.615426,4.847797,7.709941,6.833527,6.975561,8.412437,6.133734,6.644996,...,5.684588,5.533363,8.025919,6.294171,6.540905,5.568686,6.988648,6.155046,6.962684,8.464821
31,9.979005,7.016263,5.018637,8.958344,6.750463,8.978737,7.595017,9.992796,7.353880,8.212673,...,8.357179,7.993623,8.987070,7.375778,7.216598,6.974011,5.997378,7.585954,8.997802,11.336266
33,9.012604,6.795737,3.798263,6.213745,7.698718,7.196985,5.992718,4.989136,6.340017,5.002813,...,6.879338,2.514162,7.356398,5.405338,3.880746,6.917623,6.328062,5.724771,7.367927,8.486104
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4808,9.604811,10.100569,9.125708,5.289214,9.000467,9.954338,7.729192,8.964086,7.997218,7.984716,...,6.687483,7.311397,8.646190,7.728576,6.975162,8.431555,9.272508,6.467508,8.412898,7.437706
5520,8.112532,7.840883,5.222264,1.680468,8.993221,7.858330,6.228015,7.221919,8.000747,6.295283,...,6.190355,5.647271,6.830615,7.329422,6.539738,7.528291,8.462395,7.057407,8.982968,9.182918
5956,10.382768,9.028363,8.923031,8.073051,6.960473,9.068252,7.638480,8.046400,6.799384,8.730838,...,5.123984,6.930332,10.324481,7.832357,7.720624,8.221867,6.838282,6.277566,7.726108,8.494207
6009,9.288359,8.912314,8.021091,4.166932,9.936562,8.339075,7.727307,7.977334,7.537736,6.544186,...,8.264552,6.650389,8.309616,7.228990,7.787021,9.092286,8.392553,7.152301,9.699919,7.950934


In [None]:
X, Y =  als(data_matrix_values, rank=rank, iterations=100)
predicted_ratings_100_iters = X @ Y.T

predicted_ratings_100_iters_df = pd.DataFrame(predicted_ratings_100_iters, index=data_matrix.index, columns=data_matrix.columns)

a_id,1,30,32,199,226,227,323,339,356,849,...,34134,34561,34599,34618,34933,35790,35849,37349,37450,38000
u_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,9.799286,6.209814,5.134251,9.691858,9.045617,7.751278,6.067830,7.337744,8.041424,7.696107,...,6.584201,5.919802,8.014215,8.984295,7.988165,7.130738,8.398727,5.218742,8.850616,7.137172
10,8.135538,7.733444,6.497250,9.954206,6.983077,8.988919,6.995223,7.639507,5.937502,7.960583,...,5.818651,6.211850,8.527341,6.621685,6.600408,5.889068,5.152743,5.063778,6.167534,6.458517
16,9.109404,7.334765,5.394615,7.474145,5.259301,6.571761,6.618282,7.273938,5.801734,8.051011,...,3.686087,7.047830,9.205129,5.792112,6.472330,7.344952,6.978390,5.118155,7.364363,8.014984
31,9.985370,6.993286,5.020568,8.640313,5.791259,8.996141,6.145709,9.964826,5.297306,8.199904,...,5.708388,7.973317,8.988121,7.317239,7.299086,5.892437,6.006733,4.570227,8.973351,6.285145
33,8.981424,6.229385,4.779015,6.645957,5.364214,4.671210,5.994459,5.012500,4.413752,5.011028,...,3.370971,3.220063,8.457979,6.813336,2.275008,5.187574,4.045186,4.281681,6.611366,7.468022
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4808,8.388236,6.947158,7.910509,7.134604,9.005417,9.969349,7.047952,7.966596,7.997364,8.014402,...,7.093992,6.388685,7.566883,6.435107,4.697656,7.518970,5.715564,4.950643,8.768531,6.251255
5520,7.235031,6.507074,5.334414,4.714400,8.974790,6.460744,5.203004,7.007599,7.974276,5.591701,...,4.675106,4.219928,5.539201,6.669975,4.291201,6.322156,7.375943,5.879305,7.392896,6.689860
5956,9.496504,8.986047,8.972662,8.477463,9.335238,8.982334,7.762898,7.720141,7.208664,8.481604,...,5.904192,6.938307,9.880742,7.345450,6.393735,8.475976,6.321693,5.757966,8.202797,7.753348
6009,8.581916,8.655184,7.985773,7.000023,9.982654,8.440981,7.492994,7.985404,7.963345,7.549696,...,6.755327,5.990340,8.364098,7.170233,5.350733,7.545703,6.322989,6.340096,8.777811,8.179139


In [None]:
X, Y = als(data_matrix_values, rank=rank, iterations=1000)
predicted_ratings_1000_iters = X @ Y.T
predicted_ratings_1000_iters_df = pd.DataFrame(predicted_ratings_1000_iters, index=data_matrix.index, columns=data_matrix.columns)

a_id,1,30,32,199,226,227,323,339,356,849,...,34134,34561,34599,34618,34933,35790,35849,37349,37450,38000
u_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,9.722517,5.943924,4.671962,9.487938,7.787467,7.538301,7.003425,8.117456,7.425531,7.501276,...,7.051081,6.588838,8.009796,8.984622,7.985859,7.330384,9.367366,6.520074,8.798900,7.126216
10,7.630680,7.748732,7.024354,9.964332,6.986110,8.982757,6.997319,7.963682,6.363917,7.981700,...,5.611411,6.404032,8.000449,7.174808,6.102048,4.683355,4.135138,5.251463,6.093876,7.108155
16,9.880541,7.777455,6.624077,9.264008,5.440328,7.743563,8.087513,8.645583,6.460953,8.354615,...,4.778063,7.684557,9.528250,6.810366,7.630969,8.060188,6.986242,5.719411,7.968621,8.386620
31,9.989097,6.995555,5.024237,9.237064,5.987968,8.989608,7.319195,9.970884,5.794415,8.155168,...,7.269933,7.976571,8.987946,7.339855,9.134279,5.489827,6.003199,5.210916,8.985723,6.981356
33,8.980025,5.102578,5.113566,6.178715,5.087214,5.374637,5.997520,5.015249,4.561015,5.018501,...,5.558526,4.497428,7.260592,6.500286,3.232335,5.710462,3.820992,3.571940,8.249099,6.558594
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4808,8.440877,6.999913,8.393720,6.925241,8.991095,9.969244,8.026831,7.902612,8.008445,8.013802,...,8.189527,6.958161,7.945098,7.271438,5.263821,6.944768,5.790718,5.119581,9.429061,7.530345
5520,8.304399,7.362213,5.839307,5.374516,8.993461,7.281695,6.359059,7.554754,7.974924,6.234686,...,6.418203,5.863546,6.698377,6.164832,4.854941,5.615509,7.081614,6.478895,8.530052,8.621976
5956,10.255611,8.993479,8.985433,8.201136,8.937520,8.433962,8.287975,7.942757,7.127582,8.734911,...,6.835212,7.419557,9.873752,7.917017,6.238603,8.265779,6.627914,6.276775,8.924603,8.961490
6009,9.218717,8.970226,8.234007,7.686513,9.985772,7.981421,8.493786,7.993991,7.981500,8.364209,...,6.998027,6.377089,9.120712,8.008261,5.102630,7.471181,7.128175,7.163241,8.830189,9.447999


In [None]:
'''
print("min values for 10, 100, and 1000 iterations")
print(predicted_ratings_10_iters.min())
print(predicted_ratings_100_iters.min())
print(predicted_ratings_1000_iters.min())
print("<------------------>")
print("max values for 10, 100, and 1000 iterations")
print(predicted_ratings_10_iters.max())
print(predicted_ratings_100_iters.max())
print(predicted_ratings_1000_iters.max())
'''

min values for 10, 100, and 1000 iterations
-4.503779663879647
-0.20542168432411986
-0.4981605184747932
<------------------>
max values for 10, 100, and 1000 iterations
13.89144967384726
12.462630826054921
11.807310308930694


In [11]:
mask = data_matrix_values > 0

predicted_ratings_10_iters_df_no_original = predicted_ratings_10_iters_df.where(~mask)
predicted_ratings_100_iters_df_no_original = predicted_ratings_100_iters_df.where(~mask)
predicted_ratings_1000_iters_df_no_original = predicted_ratings_1000_iters_df.where(~mask)

In [12]:
predicted_ratings_10_iters_df_no_original

a_id,1,30,32,199,226,227,323,339,356,849,...,34134,34561,34599,34618,34933,35790,35849,37349,37450,38000
u_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,6.628417,7.687562,5.402649,8.780447,5.507075,8.145473,5.433910,5.959895,8.759678,8.558854,...,5.146699,5.936240,,,,7.646334,5.746000,7.034293,9.215930,7.864764
10,8.044314,8.181183,7.987427,,,,,8.399231,6.498519,,...,3.683092,5.815052,7.826018,6.956949,5.845903,6.675370,5.711624,6.165163,5.260465,5.006316
16,9.182452,8.446964,5.615426,4.847797,7.709941,6.833527,6.975561,8.412437,6.133734,6.644996,...,5.684588,5.533363,8.025919,6.294171,6.540905,5.568686,,6.155046,6.962684,8.464821
31,,,,8.958344,6.750463,,7.595017,,7.353880,8.212673,...,8.357179,,,7.375778,7.216598,6.974011,,7.585954,,11.336266
33,,6.795737,3.798263,6.213745,7.698718,7.196985,,,6.340017,,...,6.879338,2.514162,7.356398,5.405338,3.880746,6.917623,6.328062,5.724771,7.367927,8.486104
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4808,9.604811,10.100569,9.125708,5.289214,,,7.729192,8.964086,,,...,6.687483,7.311397,8.646190,7.728576,6.975162,8.431555,9.272508,6.467508,8.412898,7.437706
5520,8.112532,7.840883,5.222264,1.680468,,7.858330,6.228015,7.221919,,6.295283,...,6.190355,5.647271,6.830615,7.329422,6.539738,7.528291,8.462395,7.057407,8.982968,9.182918
5956,10.382768,,,8.073051,6.960473,9.068252,7.638480,8.046400,6.799384,8.730838,...,5.123984,6.930332,10.324481,7.832357,7.720624,8.221867,6.838282,6.277566,7.726108,8.494207
6009,9.288359,8.912314,8.021091,4.166932,,8.339075,7.727307,,7.537736,6.544186,...,8.264552,6.650389,8.309616,7.228990,7.787021,9.092286,8.392553,7.152301,9.699919,7.950934
