In [1]:
from collections import defaultdict
from surprise import SVD
from surprise import Dataset

In [2]:
# First train an SVD algorithm on the movielens dataset.
data = Dataset.load_builtin('ml-100k')

In [3]:
trainset = data.build_full_trainset()

In [4]:
trainset

<surprise.trainset.Trainset at 0x260ded92d48>

In [5]:
trainset.ur

defaultdict(list,
            {0: [(0, 3.0),
              (528, 4.0),
              (377, 4.0),
              (522, 3.0),
              (431, 5.0),
              (834, 5.0),
              (380, 4.0),
              (329, 4.0),
              (550, 5.0),
              (83, 4.0),
              (632, 2.0),
              (86, 4.0),
              (289, 5.0),
              (363, 3.0),
              (438, 5.0),
              (389, 5.0),
              (649, 4.0),
              (947, 4.0),
              (423, 3.0),
              (291, 3.0),
              (10, 2.0),
              (1006, 4.0),
              (179, 3.0),
              (751, 3.0),
              (487, 3.0),
              (665, 3.0),
              (92, 4.0),
              (512, 5.0),
              (1045, 3.0),
              (672, 4.0),
              (656, 4.0),
              (221, 5.0),
              (432, 2.0),
              (365, 3.0),
              (321, 2.0),
              (466, 4.0),
              (302, 4.0),
              (491, 3

In [6]:
algo = SVD()
algo.fit(trainset)

<surprise.prediction_algorithms.matrix_factorization.SVD at 0x260dfacdb48>

In [7]:
# Than predict ratings for all pairs (u, i) that are NOT in the training set.
testset = trainset.build_anti_testset()

In [8]:
testset

[('196', '302', 3.52986),
 ('196', '377', 3.52986),
 ('196', '51', 3.52986),
 ('196', '346', 3.52986),
 ('196', '474', 3.52986),
 ('196', '265', 3.52986),
 ('196', '465', 3.52986),
 ('196', '451', 3.52986),
 ('196', '86', 3.52986),
 ('196', '1014', 3.52986),
 ('196', '222', 3.52986),
 ('196', '40', 3.52986),
 ('196', '29', 3.52986),
 ('196', '785', 3.52986),
 ('196', '387', 3.52986),
 ('196', '274', 3.52986),
 ('196', '1042', 3.52986),
 ('196', '1184', 3.52986),
 ('196', '392', 3.52986),
 ('196', '486', 3.52986),
 ('196', '144', 3.52986),
 ('196', '118', 3.52986),
 ('196', '1', 3.52986),
 ('196', '546', 3.52986),
 ('196', '95', 3.52986),
 ('196', '768', 3.52986),
 ('196', '277', 3.52986),
 ('196', '234', 3.52986),
 ('196', '246', 3.52986),
 ('196', '98', 3.52986),
 ('196', '193', 3.52986),
 ('196', '88', 3.52986),
 ('196', '194', 3.52986),
 ('196', '1081', 3.52986),
 ('196', '603', 3.52986),
 ('196', '796', 3.52986),
 ('196', '32', 3.52986),
 ('196', '16', 3.52986),
 ('196', '304', 3.5

In [9]:
predictions = algo.test(testset)

In [10]:
predictions

[Prediction(uid='196', iid='302', r_ui=3.52986, est=3.987192275999595, details={'was_impossible': False}),
 Prediction(uid='196', iid='377', r_ui=3.52986, est=2.649174610043701, details={'was_impossible': False}),
 Prediction(uid='196', iid='51', r_ui=3.52986, est=3.452988064548957, details={'was_impossible': False}),
 Prediction(uid='196', iid='346', r_ui=3.52986, est=3.5542385496456674, details={'was_impossible': False}),
 Prediction(uid='196', iid='474', r_ui=3.52986, est=4.173017098426478, details={'was_impossible': False}),
 Prediction(uid='196', iid='265', r_ui=3.52986, est=3.839601808635949, details={'was_impossible': False}),
 Prediction(uid='196', iid='465', r_ui=3.52986, est=3.51711484973258, details={'was_impossible': False}),
 Prediction(uid='196', iid='451', r_ui=3.52986, est=3.327100371520821, details={'was_impossible': False}),
 Prediction(uid='196', iid='86', r_ui=3.52986, est=3.8019095243072476, details={'was_impossible': False}),
 Prediction(uid='196', iid='1014', r_u

In [11]:
def get_top_n(predictions, n=10):
    # First map the predictions to each user.
    top_n = defaultdict(list)
    for uid, iid, true_r, est, _ in predictions:
        top_n[uid].append((iid, est))

    # Then sort the predictions for each user and retrieve the k highest ones.
    for uid, user_ratings in top_n.items():
        user_ratings.sort(key=lambda x: x[1], reverse=True)
        top_n[uid] = user_ratings[:n]

    return top_n

In [12]:
top_n = get_top_n(predictions, n=10)

In [13]:
top_n

defaultdict(list,
            {'196': [('427', 4.522658015467568),
              ('357', 4.498307384991643),
              ('50', 4.493004948165097),
              ('483', 4.479719934836907),
              ('316', 4.463809115159127),
              ('169', 4.463800242226851),
              ('318', 4.463525139493012),
              ('132', 4.455844624833148),
              ('178', 4.442360528087856),
              ('315', 4.369813164025638)],
             '186': [('515', 4.765418586925616),
              ('318', 4.709841343786735),
              ('427', 4.662710320586977),
              ('178', 4.601150654564895),
              ('512', 4.573888202551311),
              ('185', 4.56123389918873),
              ('205', 4.544244785536863),
              ('648', 4.528382821180502),
              ('480', 4.523413424763061),
              ('496', 4.506427254930999)],
             '22': [('169', 4.921499764847399),
              ('357', 4.903161087022933),
              ('98', 4.849706875835157

In [14]:
# Print the recommended items for each user
for uid, user_ratings in top_n.items():
    print(uid, [iid for (iid, _) in user_ratings])

196 ['427', '357', '50', '483', '316', '169', '318', '132', '178', '315']
186 ['515', '318', '427', '178', '512', '185', '205', '648', '480', '496']
22 ['169', '357', '98', '182', '100', '56', '480', '603', '22', '12']
244 ['483', '515', '205', '511', '272', '510', '479', '302', '408', '55']
166 ['174', '98', '169', '50', '302', '603', '64', '520', '12', '318']
298 ['313', '272', '64', '251', '480', '12', '657', '515', '316', '114']
115 ['179', '134', '135', '474', '285', '488', '114', '483', '480', '91']
253 ['174', '515', '172', '480', '520', '178', '169', '357', '1142', '963']
305 ['514', '513', '213', '641', '124', '185', '603', '498', '57', '116']
6 ['603', '48', '198', '179', '654', '657', '705', '190', '434', '919']
62 ['169', '427', '656', '661', '48', '178', '480', '234', '515', '60']
286 ['12', '429', '64', '488', '612', '485', '318', '166', '479', '136']
200 ['181', '272', '64', '12', '251', '963', '316', '114', '302', '427']
210 ['408', '169', '496', '12', '511', '64', '178

695 ['178', '480', '127', '603', '661', '187', '357', '483', '427', '98']
675 ['64', '169', '408', '496', '1449', '191', '483', '603', '313', '515']
708 ['173', '64', '69', '272', '12', '318', '530', '28', '496', '511']
709 ['114', '194', '408', '272', '169', '483', '511', '479', '313', '513']
711 ['474', '9', '285', '192', '14', '603', '178', '57', '242', '661']
710 ['178', '408', '169', '427', '114', '98', '191', '493', '132', '519']
712 ['408', '272', '313', '134', '275', '318', '114', '603', '923', '251']
715 ['187', '357', '127', '169', '178', '515', '496', '1142', '315', '603']
713 ['169', '50', '318', '178', '8', '285', '603', '498', '527', '12']
716 ['170', '169', '272', '657', '100', '963', '313', '694', '408', '114']
681 ['318', '169', '408', '512', '483', '520', '89', '479', '197', '14']
678 ['408', '318', '89', '12', '480', '169', '187', '474', '511', '199']
719 ['923', '479', '603', '317', '192', '408', '166', '272', '191', '197']
702 ['169', '172', '272', '50', '114', '98