In [2]:
# %pip install scikit-surprise==1.1.0 # pickle

In [3]:
from surprise import accuracy, Reader, Dataset, SVD
from surprise.model_selection import cross_validate

In [4]:
import pandas as pd
import pathlib
DATA_DIR = pathlib.Path().resolve().parent / "data"
print(DATA_DIR)
DATA_DIR.exists()

/home/louis/astraZenith/backend_recommender/recommender/src/data


True

In [5]:
dataset = DATA_DIR / 'ratings_small.csv'
dataset.exists()

True

In [6]:
df = pd.read_csv(dataset)
df['rating'].dropna(inplace=True) # drop NaNs in rating column
df.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,31,2.5,1260759144
1,1,1029,3.0,1260759179
2,1,1061,3.0,1260759182
3,1,1129,2.0,1260759185
4,1,1172,4.0,1260759205


In [7]:
df.rating.max(), df.rating.min()

(5.0, 0.5)

In [8]:
reader = Reader(rating_scale=(0.5, 5))
data = Dataset.load_from_df(df[['userId', 'movieId', 'rating']], reader)

In [9]:
algo = SVD(verbose=True, n_epochs=20)
cross_validate(algo, data, measures=['RMSE', "MAE"], cv=4, verbose=True)

Processing epoch 0
Processing epoch 1
Processing epoch 2
Processing epoch 3
Processing epoch 4
Processing epoch 5
Processing epoch 6
Processing epoch 7
Processing epoch 8
Processing epoch 9
Processing epoch 10
Processing epoch 11
Processing epoch 12
Processing epoch 13
Processing epoch 14
Processing epoch 15
Processing epoch 16
Processing epoch 17
Processing epoch 18
Processing epoch 19
Processing epoch 0
Processing epoch 1
Processing epoch 2
Processing epoch 3
Processing epoch 4
Processing epoch 5
Processing epoch 6
Processing epoch 7
Processing epoch 8
Processing epoch 9
Processing epoch 10
Processing epoch 11
Processing epoch 12
Processing epoch 13
Processing epoch 14
Processing epoch 15
Processing epoch 16
Processing epoch 17
Processing epoch 18
Processing epoch 19
Processing epoch 0
Processing epoch 1
Processing epoch 2
Processing epoch 3
Processing epoch 4
Processing epoch 5
Processing epoch 6
Processing epoch 7
Processing epoch 8
Processing epoch 9
Processing epoch 10
Processing

{'test_rmse': array([0.89822947, 0.90293588, 0.90388827, 0.89507063]),
 'test_mae': array([0.69162933, 0.69393927, 0.69672426, 0.68894889]),
 'fit_time': (1.5768787860870361,
  1.581533670425415,
  1.6010565757751465,
  1.6083545684814453),
 'test_time': (0.10312032699584961,
  0.05711674690246582,
  0.05841660499572754,
  0.08768105506896973)}

In [10]:
trainset = data.build_full_trainset()
algo.fit(trainset)

Processing epoch 0
Processing epoch 1
Processing epoch 2
Processing epoch 3
Processing epoch 4
Processing epoch 5
Processing epoch 6
Processing epoch 7
Processing epoch 8
Processing epoch 9
Processing epoch 10
Processing epoch 11
Processing epoch 12
Processing epoch 13
Processing epoch 14
Processing epoch 15
Processing epoch 16
Processing epoch 17
Processing epoch 18
Processing epoch 19


<surprise.prediction_algorithms.matrix_factorization.SVD at 0x7f8828523730>

In [11]:
testset = trainset.build_testset()
predictions = algo.test(testset)
# RMSE should be low as we are biased
accuracy.rmse(predictions, verbose=True)
accuracy.mae(predictions, verbose=True)

RMSE: 0.6423
MAE:  0.4973


0.49727536114571536

In [19]:
sample_row = df.sample(n=5)
userId   = sample_row['userId'].values[0]
movieId = sample_row['movieId'].values[0]
rating   = sample_row['rating'].values[0]
print(userId, movieId)

610 61240


In [20]:
pred = algo.predict(uid=userId, iid=movieId)
pred.est

3.7442653458815305

In [14]:
import pickle

In [15]:
algo_data = {"model": algo}
with open('model.pkl', 'wb') as f:
    pickle.dump(algo_data, f)

In [16]:
model_algo= None
with open('model.pkl', 'rb') as f:
    model_data_loaded = pickle.load(f)
    model_algo = model_data_loaded.get('model')

In [17]:
model_algo.predict(uid=userId, iid=movieId).est

3.6156030039138636

In [18]:
sample_rows = df.sample(n=10).to_dict('records')
for row in sample_rows:
    userId = row['userId']
    movieId = row['movieId']
    pred = model_algo.predict(uid=userId, iid=movieId).est
    print(userId, movieId, pred)

509 5810 3.02573919438441
195 1276 4.1508391029215375
620 56174 2.7844894791037254
547 1958 3.783149807989825
262 3216 3.1798135480747294
242 1272 4.661879280719291
238 63082 3.8111817338203053
12 3809 2.916885948159241
95 999 3.707453942102831
624 86028 2.1734159007207867
