# Collaborative filtering

In [1]:
from typing import Tuple
%load_ext autoreload
%autoreload 2

In [44]:
import os
import surprise
from surprise import Dataset, Reader
import pandas as pd
from scipy.sparse import csr_matrix

from sklearn.decomposition import NMF
from sklearn.model_selection import train_test_split

from sklearn.metrics import mean_squared_error

## Load MovieLens dataset

In [34]:
movielens = Dataset.load_builtin('ml-1m', prompt=False)
ratings_file = f"{surprise.get_dataset_dir()}/ml-1m/ml-1m/ratings.dat"
ratings_df = pd.read_csv(ratings_file, sep="::", names=["UserID", "MovieID", "Rating", "Timestamp"], engine='python')

In [35]:
print(f'Users: {ratings_df["UserID"].max()}')
print(f'Movies: {ratings_df["MovieID"].max()}')

Users: 6040
Movies: 3952


## Prepare train/test data

In [31]:
def convert_to_sparse(df: pd.DataFrame) -> csr_matrix:
    rows, cols, vals = zip(*df.values)
    return csr_matrix((vals, (rows, cols)))

In [39]:
ratings_df.sort_values(by='Timestamp', inplace=True)

ratings_df['UserID'] = ratings_df['UserID'].apply(lambda x: x - 1)
ratings_df['MovieID'] = ratings_df['MovieID'].apply(lambda x: x - 1)

train_df, test_df = train_test_split(
    ratings_df.iloc[:, :3], shuffle=False, test_size=0.1)

train_sparse = convert_to_sparse(train_df)
test_sparse = convert_to_sparse(test_df)

print(train_sparse.shape)
print(test_sparse.shape)

(6040, 3952)
(6040, 3952)


## Use Non-negative Matrix Factorization to predict users ratings

In [58]:
nmf = NMF(n_components=100)
res = nmf.fit_transform(train_sparse)
preds = res @ nmf.components_




In [47]:
rmse = mean_squared_error(preds, test_sparse.toarray(), squared=False)
rmse

0.36238612699460515

In [71]:
preds[23, 108]

0.0

In [69]:
print(train_sparse)

  (23, 109)	4.0
  (23, 265)	2.0
  (23, 295)	5.0
  (23, 317)	4.0
  (23, 363)	4.0
  (23, 371)	1.0
  (23, 526)	5.0
  (23, 538)	3.0
  (23, 584)	4.0
  (23, 592)	5.0
  (23, 596)	4.0
  (23, 607)	5.0
  (23, 911)	5.0
  (23, 952)	5.0
  (23, 1056)	4.0
  (23, 1078)	4.0
  (23, 1089)	4.0
  (23, 1196)	4.0
  (23, 1212)	5.0
  (23, 1219)	4.0
  (23, 1220)	5.0
  (23, 1224)	4.0
  (23, 1246)	4.0
  (23, 1258)	4.0
  (23, 1262)	5.0
  :	:
  (6039, 3223)	5.0
  (6039, 3261)	4.0
  (6039, 3288)	5.0
  (6039, 3333)	5.0
  (6039, 3341)	3.0
  (6039, 3358)	4.0
  (6039, 3360)	2.0
  (6039, 3387)	1.0
  (6039, 3417)	3.0
  (6039, 3421)	3.0
  (6039, 3423)	2.0
  (6039, 3448)	3.0
  (6039, 3470)	4.0
  (6039, 3503)	4.0
  (6039, 3504)	4.0
  (6039, 3520)	5.0
  (6039, 3523)	1.0
  (6039, 3542)	4.0
  (6039, 3546)	4.0
  (6039, 3551)	2.0
  (6039, 3682)	4.0
  (6039, 3702)	4.0
  (6039, 3734)	4.0
  (6039, 3750)	4.0
  (6039, 3818)	5.0
