In [2]:
from collections import defaultdict

from surprise import SVD
from surprise import Dataset

In [3]:
# Load the built-in MovieLens 100k ('ml-100k') ratings dataset.
# This cell only loads the data; the SVD model is trained on it in a later cell.
data = Dataset.load_builtin('ml-100k')

In [6]:
# Build the training set from the whole dataset (build_full_trainset), so no
# ratings are held out here.
trainset = data.build_full_trainset()

In [7]:
# Sanity check: display the Trainset object (shows only its repr).
trainset

<surprise.trainset.Trainset at 0x7f6bc9adfda0>

In [11]:
# Inspect the `ur` attribute: per the output it is a defaultdict(list) keyed by
# integer ids whose values are lists of (int, float) pairs -- presumably
# (inner item id, rating) for each inner user id; TODO confirm.
# NOTE(review): this dumps the entire mapping; slice it to keep the output short.
trainset.ur

defaultdict(list,
            {0: [(0, 3.0),
              (528, 4.0),
              (377, 4.0),
              (522, 3.0),
              (431, 5.0),
              (834, 5.0),
              (380, 4.0),
              (329, 4.0),
              (550, 5.0),
              (83, 4.0),
              (632, 2.0),
              (86, 4.0),
              (289, 5.0),
              (363, 3.0),
              (438, 5.0),
              (389, 5.0),
              (649, 4.0),
              (947, 4.0),
              (423, 3.0),
              (291, 3.0),
              (10, 2.0),
              (1006, 4.0),
              (179, 3.0),
              (751, 3.0),
              (487, 3.0),
              (665, 3.0),
              (92, 4.0),
              (512, 5.0),
              (1045, 3.0),
              (672, 4.0),
              (656, 4.0),
              (221, 5.0),
              (432, 2.0),
              (365, 3.0),
              (321, 2.0),
              (466, 4.0),
              (302, 4.0),
              (491, 3

In [12]:
# Fit an SVD model with default hyper-parameters on the full training set.
# fit() returns the algorithm object, which is why its repr is shown as output.
# NOTE(review): no random seed is set; if SVD initialises its factors randomly,
# the estimates below will differ between runs -- confirm and consider seeding.
algo = SVD()
algo.fit(trainset)

<surprise.prediction_algorithms.matrix_factorization.SVD at 0x7f6bc84047b8>

In [13]:
# Then predict ratings for all pairs (u, i) that are NOT in the training set.
# Per the output, each entry is a (user id, item id, rating) triple with string
# ids and the same placeholder rating for every pair -- presumably the global
# mean rating; TODO confirm.
testset = trainset.build_anti_testset()

In [14]:
# Peek at the anti-testset.
# NOTE(review): this displays every (user, item) pair; slice it to keep the
# output short.
testset

[('196', '302', 3.52986),
 ('196', '377', 3.52986),
 ('196', '51', 3.52986),
 ('196', '346', 3.52986),
 ('196', '474', 3.52986),
 ('196', '265', 3.52986),
 ('196', '465', 3.52986),
 ('196', '451', 3.52986),
 ('196', '86', 3.52986),
 ('196', '1014', 3.52986),
 ('196', '222', 3.52986),
 ('196', '40', 3.52986),
 ('196', '29', 3.52986),
 ('196', '785', 3.52986),
 ('196', '387', 3.52986),
 ('196', '274', 3.52986),
 ('196', '1042', 3.52986),
 ('196', '1184', 3.52986),
 ('196', '392', 3.52986),
 ('196', '486', 3.52986),
 ('196', '144', 3.52986),
 ('196', '118', 3.52986),
 ('196', '1', 3.52986),
 ('196', '546', 3.52986),
 ('196', '95', 3.52986),
 ('196', '768', 3.52986),
 ('196', '277', 3.52986),
 ('196', '234', 3.52986),
 ('196', '246', 3.52986),
 ('196', '98', 3.52986),
 ('196', '193', 3.52986),
 ('196', '88', 3.52986),
 ('196', '194', 3.52986),
 ('196', '1081', 3.52986),
 ('196', '603', 3.52986),
 ('196', '796', 3.52986),
 ('196', '32', 3.52986),
 ('196', '16', 3.52986),
 ('196', '304', 3.5

In [28]:
# Estimate a rating for every (user, item) pair in the anti-testset using the
# fitted model.
predictions = algo.test(testset)

In [29]:
# Peek at the predictions: per the output, each element is a
# Prediction(uid, iid, r_ui, est, details) record.
# NOTE(review): this displays the whole list; slice it to keep the output short.
predictions

[Prediction(uid='196', iid='377', r_ui=3.52986, est=2.9756941091058904, details={'was_impossible': False}),
 Prediction(uid='196', iid='51', r_ui=3.52986, est=3.422448110920973, details={'was_impossible': False}),
 Prediction(uid='196', iid='346', r_ui=3.52986, est=3.661662362191269, details={'was_impossible': False}),
 Prediction(uid='196', iid='474', r_ui=3.52986, est=4.184302272447449, details={'was_impossible': False}),
 Prediction(uid='196', iid='265', r_ui=3.52986, est=4.007719614016683, details={'was_impossible': False}),
 Prediction(uid='196', iid='465', r_ui=3.52986, est=3.608625098577172, details={'was_impossible': False}),
 Prediction(uid='196', iid='451', r_ui=3.52986, est=3.5795488476512554, details={'was_impossible': False}),
 Prediction(uid='196', iid='86', r_ui=3.52986, est=4.115601291587445, details={'was_impossible': False}),
 Prediction(uid='196', iid='1014', r_ui=3.52986, est=3.0224204336280964, details={'was_impossible': False}),
 Prediction(uid='196', iid='222', r

In [34]:
def get_top_n(predictions, n=10):
    """Return the ``n`` highest-estimated items for every user.

    Args:
        predictions: Iterable of 5-field records that unpack as
            ``(uid, iid, true_r, est, details)``, e.g. the list returned by
            ``algo.test``.
        n: Maximum number of recommendations kept per user. Must be
            non-negative; ``None`` keeps every prediction (still sorted).

    Returns:
        A ``defaultdict(list)`` mapping each uid to a list of ``(iid, est)``
        tuples sorted by ``est`` in descending order and truncated to ``n``
        entries. Items with equal estimates keep their order of appearance.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n is not None and n < 0:
        # A negative slice bound would silently drop the *lowest* rated items
        # instead of limiting the list length, so reject it explicitly.
        raise ValueError("n must be non-negative, got {!r}".format(n))

    # First map the predictions to each user.
    top_n = defaultdict(list)
    for uid, iid, _true_r, est, _details in predictions:
        top_n[uid].append((iid, est))

    # Then sort the predictions for each user and retrieve the n highest ones.
    # Only values of existing keys are replaced, so the dict size is unchanged
    # during iteration.
    for uid, user_ratings in top_n.items():
        user_ratings.sort(key=lambda x: x[1], reverse=True)
        top_n[uid] = user_ratings[:n]

    return top_n

In [35]:
# Keep the 10 highest-estimated unseen items for each user.
top_n = get_top_n(predictions, n=10)

In [36]:
# Peek at the result: per the output, a defaultdict mapping each user id to a
# list of (item id, estimated rating) pairs, highest estimate first.
# NOTE(review): this displays every user; slice it to keep the output short.
top_n

defaultdict(list,
            {'196': [('408', 4.7417190486869965),
              ('169', 4.627054863786147),
              ('483', 4.625900820473185),
              ('515', 4.606966350167425),
              ('64', 4.585969732457447),
              ('272', 4.558812409891243),
              ('357', 4.492776457734352),
              ('178', 4.492490986519911),
              ('114', 4.487780517063499),
              ('513', 4.481873381124203)],
             '186': [('169', 4.788906916721228),
              ('408', 4.783645238267988),
              ('483', 4.665808448015972),
              ('318', 4.494090581720603),
              ('513', 4.47965420899484),
              ('97', 4.462923445843806),
              ('114', 4.460632005800639),
              ('1449', 4.438497940913597),
              ('527', 4.435679706902444),
              ('313', 4.4318850712354685)],
             '22': [('114', 4.793059182757055),
              ('56', 4.747266283140541),
              ('98', 4.73824078489403

In [37]:
# Show each user's recommended item ids, highest estimated rating first.
for uid, user_ratings in top_n.items():
    recommended_ids = [pair[0] for pair in user_ratings]
    print(uid, recommended_ids)

196 ['408', '169', '483', '515', '64', '272', '357', '178', '114', '513']
186 ['169', '408', '483', '318', '513', '97', '114', '1449', '527', '313']
22 ['114', '56', '98', '483', '64', '1142', '272', '169', '313', '496']
244 ['515', '127', '12', '427', '14', '285', '654', '480', '483', '489']
166 ['204', '408', '483', '169', '133', '210', '480', '513', '199', '114']
298 ['169', '408', '272', '114', '64', '316', '83', '515', '659', '963']
115 ['514', '134', '488', '199', '483', '156', '168', '1194', '197', '408']
253 ['313', '114', '515', '657', '520', '408', '302', '176', '272', '603']
305 ['515', '114', '124', '19', '606', '57', '603', '488', '1194', '745']
6 ['654', '603', '179', '48', '657', '190', '114', '428', '659', '251']
62 ['169', '408', '484', '647', '657', '175', '1137', '963', '478', '272']
286 ['320', '178', '302', '318', '480', '474', '603', '657', '8', '1142']
200 ['603', '657', '194', '133', '408', '12', '511', '923', '520', '1142']
210 ['318', '480', '178', '408', '48'