# Collaborative Filtering

In [34]:
import pandas as pd
import numpy as np

In [None]:
from sklearn.metrics.pairwise import 

In [2]:
r_cols = ["user_id", "movie_id", "rating", "unix_timestamp"]
ratings = pd.read_csv("data/ml-100k/u.data", names=r_cols, sep="\t", encoding="latin-1")
ratings.head()

Unnamed: 0,user_id,movie_id,rating,unix_timestamp
0,196,242,3,881250949
1,186,302,3,891717742
2,22,377,1,878887116
3,244,51,2,880606923
4,166,346,1,886397596


Fill missing ratings with 0, so that it doesn't affect the similarity

In [14]:
X = ratings.pivot_table(values='rating', index='user_id', columns='movie_id')

In [15]:
X.shape

(943, 1682)

In [36]:
n_users, n_items = X.shape

In [132]:
user_sim = np.zeros((n_users, n_users))
item_sim = np.zeros((n_items, n_items))

In [107]:
X_isna = X.isna()

In [137]:
def calc_dist_nan(curr_id, other_id, calc, metric):
    if calc == "user":
        curr = X.loc[curr_id, :]
        other = X.loc[other_id, :]
        curr_isna = X_isna.loc[curr_id, :]
        other_isna = X_isna.loc[other_id, :]
        
    elif calc == "item":
        curr = X.loc[:, curr_id]
        other = X.loc[:, other_id]
        curr_isna = X_isna.loc[:, curr_id]
        other_isna = X_isna.loc[:, other_id]
        
        
    valid_idx = ~(curr_isna | other_isna)
    other = other.loc[valid_idx].values.reshape(1, -1)
    curr = curr.loc[valid_idx].values.reshape(1, -1)
    
    try:
        res = pairwise_distances(curr, other, metric=metric)[0, 0]
    except ValueError:
        res = np.nan
    
    return res

In [141]:
calc_dist_nan(1, 5, "user", "correlation")

0.5791913817512464

In [142]:
# for user_id in user_ids:
#     for other_id in user_ids:
#         user_sim[user_id-1, other_id-1] = calc_dist_nan(user_id, other_id, "user", "correlation")

* Cosine Similarity is not affected by ZERO padding (X -> X_zf (zero-fill))
* Pearson Correlation is not affected by MEAN padding (X -> X_umf, X_imf (user and item mean-fill))

In [147]:
X_umf

movie_id,1,2,3,4,5,6,7,8,9,10,...,1673,1674,1675,1676,1677,1678,1679,1680,1681,1682
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,5.000000,3.000000,4.000000,3.000000,3.000000,5.000000,4.000000,1.00000,5.000000,3.000000,...,,,,,,,,,,
2,4.000000,3.709677,2.796296,4.333333,2.874286,3.635071,3.965261,3.79661,4.272727,2.000000,...,,,,,,,,,,
3,3.610294,3.709677,2.796296,4.333333,2.874286,3.635071,3.965261,3.79661,4.272727,4.206522,...,,,,,,,,,,
4,3.610294,3.709677,2.796296,4.333333,2.874286,3.635071,3.965261,3.79661,4.272727,4.206522,...,,,,,,,,,,
5,4.000000,3.000000,2.796296,4.333333,2.874286,3.635071,3.965261,3.79661,4.272727,4.206522,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
939,3.610294,3.709677,2.796296,4.333333,2.874286,3.635071,3.965261,3.79661,5.000000,4.206522,...,,,,,,,,,,
940,3.610294,3.709677,2.796296,2.000000,2.874286,3.635071,4.000000,5.00000,3.000000,4.206522,...,,,,,,,,,,
941,5.000000,3.709677,2.796296,4.333333,2.874286,3.635071,4.000000,3.79661,4.272727,4.206522,...,,,,,,,,,,
942,3.610294,3.709677,2.796296,4.333333,2.874286,3.635071,3.965261,3.79661,4.272727,4.206522,...,,,,,,,,,,


In [149]:
X_zf = X.fillna(0)
X_umf = X.T.fillna(X.mean(axis=1), axis=0).T
X_imf = X.fillna(X.mean(axis=0), axis=0)

Here I choose the pearson correlation similarity

In [158]:
user_sim = pairwise_distances(X_umf, X_umf, metric="correlation")
item_sim = pairwise_distances(X_imf.T, X_imf.T, metric="correlation")

In [160]:
user_sim = pd.DataFrame(user_sim, index=X.index, columns=X.index)
item_sim = pd.DataFrame(item_sim, index=X.columns, columns=X.columns)

In [161]:
user_sim

user_id,1,2,3,4,5,6,7,8,9,10,...,934,935,936,937,938,939,940,941,942,943
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,0.000000,9.565892e-01,0.988949,0.940697,8.654862e-01,0.896627,0.889444,0.819109,0.987747,1.000621,...,0.974165,1.047952,0.912776,0.992282,0.925622,0.921286,0.932567,0.971210,1.031270,0.967877
2,0.956589,2.220446e-16,0.986342,1.017016,9.642297e-01,0.905497,0.910592,0.944360,0.972706,0.902154,...,0.987147,1.028798,0.943341,0.802165,0.909991,0.967495,0.984947,1.017344,0.987932,0.960827
3,0.988949,9.863423e-01,0.000000,1.059638,9.839633e-01,1.017158,0.983859,0.958823,1.010093,0.976144,...,0.998385,0.999342,1.006888,0.963843,1.018513,1.006240,1.023907,0.965586,1.009187,0.998511
4,0.940697,1.017016e+00,1.059638,0.000000,9.926269e-01,1.053929,1.025604,0.863954,0.983918,1.013588,...,0.988105,0.997826,1.028000,1.025021,0.977118,1.005960,0.720182,0.741406,0.935496,1.019222
5,0.865486,9.642297e-01,0.983963,0.992627,2.220446e-16,0.961516,0.932126,0.859894,0.989805,0.985665,...,0.929986,1.070821,0.975722,0.961328,0.906433,0.948218,0.970460,0.963766,0.956682,0.900676
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
939,0.921286,9.674949e-01,1.006240,1.005960,9.482181e-01,1.047520,0.986416,0.974974,0.981525,0.985613,...,0.967392,1.004944,0.935964,1.035230,0.980071,0.000000,0.983165,1.030376,1.023190,0.995549
940,0.932567,9.849471e-01,1.023907,0.720182,9.704598e-01,1.012071,0.994156,0.921778,0.995509,0.955428,...,1.029460,0.945354,1.059929,1.032935,0.977354,0.983165,0.000000,0.897992,1.011483,0.934586
941,0.971210,1.017344e+00,0.965586,0.741406,9.637664e-01,0.998441,0.998057,0.942051,0.959252,0.973821,...,1.025764,1.031663,0.942415,0.905817,1.094838,1.030376,0.897992,0.000000,1.019055,0.999334
942,1.031270,9.879321e-01,1.009187,0.935496,9.566816e-01,0.963395,0.893748,0.969391,0.973741,0.937933,...,0.885821,1.035153,1.002613,0.968415,0.992428,1.023190,1.011483,1.019055,0.000000,0.959646
