-
Notifications
You must be signed in to change notification settings - Fork 1k
/
precision_recall_at_k.py
63 lines (44 loc) · 2.22 KB
/
precision_recall_at_k.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
This module illustrates how to compute Precision at k and Recall at k metrics.
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from collections import defaultdict
from surprise import Dataset
from surprise import SVD
from surprise.model_selection import KFold
def precision_recall_at_k(predictions, k=10, threshold=3.5):
    """Compute precision@k and recall@k separately for every user.

    Args:
        predictions: Iterable of 5-field records. The first field is the
            user id, the third the true rating and the fourth the
            estimated rating; the second and fifth fields are ignored.
        k: Number of top-ranked items (by estimated rating) considered
            for each user.
        threshold: An item counts as relevant when its true rating is
            at least this value, and as recommended when its estimated
            rating is at least this value.

    Returns:
        A ``(precisions, recalls)`` pair of dicts keyed by user id.
        A metric whose denominator is zero is reported as ``0``.
    """
    # Group (estimated, true) rating pairs under their user id.
    ratings_by_user = defaultdict(list)
    for prediction in predictions:
        uid, _, true_r, est, _ = prediction
        ratings_by_user[uid].append((est, true_r))

    precisions = {}
    recalls = {}
    for uid, pairs in ratings_by_user.items():
        # Rank this user's items from highest to lowest estimate. The sort
        # is stable, so equal estimates keep their original order.
        pairs.sort(key=lambda pair: pair[0], reverse=True)
        top_k = pairs[:k]

        # Relevant items are counted over ALL of the user's ratings.
        n_rel = sum((actual >= threshold) for (_, actual) in pairs)

        # Recommended items are counted only inside the top k.
        n_rec_k = sum((guess >= threshold) for (guess, _) in top_k)

        # Items in the top k that are both relevant and recommended.
        n_rel_and_rec_k = sum(
            ((actual >= threshold) and (guess >= threshold))
            for (guess, actual) in top_k
        )

        # Precision@k: share of recommended items that are relevant.
        # Undefined when nothing was recommended; reported as 0.
        if n_rec_k != 0:
            precisions[uid] = n_rel_and_rec_k / n_rec_k
        else:
            precisions[uid] = 0

        # Recall@k: share of relevant items that were recommended.
        # Undefined when the user has no relevant item; reported as 0.
        if n_rel != 0:
            recalls[uid] = n_rel_and_rec_k / n_rel
        else:
            recalls[uid] = 0

    return precisions, recalls
# Set up the builtin 'ml-100k' dataset, a 5-fold splitter and an SVD model.
data = Dataset.load_builtin('ml-100k')
kf = KFold(n_splits=5)
algo = SVD()

for trainset, testset in kf.split(data):
    # Fit on this fold's training part, then predict its held-out part.
    algo.fit(trainset)
    predictions = algo.test(testset)
    precisions, recalls = precision_recall_at_k(predictions, k=5, threshold=4)

    # Report the fold's mean precision and mean recall across all users.
    print(sum(precisions.values()) / len(precisions))
    print(sum(recalls.values()) / len(recalls))