-
Notifications
You must be signed in to change notification settings - Fork 1k
/
k_nearest_neighbors.py
57 lines (44 loc) · 1.85 KB
/
k_nearest_neighbors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""
This module illustrates how to retrieve the k-nearest neighbors of an item. The
same can be done for users with minor changes. There's a lot of boilerplate
because of the id conversions, but it all boils down to the use of
algo.get_neighbors().
"""
# needed because of weird encoding of u.item file
import io # noqa
from surprise import Dataset, get_dataset_dir, KNNBaseline
def read_item_names():
    """Read the MovieLens 100k ``u.item`` file and build two lookup tables.

    The file is looked up under the directory returned by
    ``get_dataset_dir()`` and decoded as ISO-8859-1. Every non-blank line is
    split on ``|``: the first field is taken as the raw item id and the
    second field as the movie name.

    Returns:
        A ``(rid_to_name, name_to_rid)`` tuple of dicts. ``rid_to_name`` maps
        a raw item id (str) to its movie name, ``name_to_rid`` maps a movie
        name to its raw item id (str). If a name appears on several lines,
        the last one read wins in ``name_to_rid``.

    Raises:
        OSError: if the ``u.item`` file cannot be opened.
        IndexError: if a non-blank line has fewer than two ``|``-separated
            fields.
    """
    # Local import keeps this function self-contained; os is stdlib.
    import os

    # Build the path from components instead of gluing strings with "/".
    file_name = os.path.join(get_dataset_dir(), "ml-100k", "ml-100k", "u.item")

    rid_to_name = {}
    name_to_rid = {}
    with open(file_name, encoding="ISO-8859-1") as f:
        for line in f:
            # Drop the line terminator so it can never leak into a field.
            record = line.rstrip("\n")
            if not record.strip():
                # A blank line carries no item; skipping it avoids an
                # IndexError on the field access below.
                continue
            fields = record.split("|")
            raw_id, name = fields[0], fields[1]
            rid_to_name[raw_id] = name
            name_to_rid[name] = raw_id
    return rid_to_name, name_to_rid
# Step 1: fit an item-based KNN model on the full MovieLens 100k data so that
# item-item similarities are computed.
data = Dataset.load_builtin("ml-100k")
trainset = data.build_full_trainset()
sim_options = dict(name="pearson_baseline", user_based=False)
algo = KNNBaseline(sim_options=sim_options)
algo.fit(trainset)

# Step 2: load the lookup tables (raw id -> movie name, movie name -> raw id).
rid_to_name, name_to_rid = read_item_names()

# Step 3: go from the movie title to its raw id, then to the inner id used by
# the fitted trainset.
toy_story_raw_id = name_to_rid["Toy Story (1995)"]
toy_story_inner_id = algo.trainset.to_inner_iid(toy_story_raw_id)

# Step 4: ask the algorithm for the 10 nearest neighbors (as inner ids).
neighbor_inner_ids = algo.get_neighbors(toy_story_inner_id, k=10)

# Step 5: lazily translate inner ids -> raw ids -> movie names.
neighbor_raw_ids = map(algo.trainset.to_raw_iid, neighbor_inner_ids)
toy_story_neighbors = map(rid_to_name.__getitem__, neighbor_raw_ids)

# Step 6: report the result, one movie name per line.
print()
print("The 10 nearest neighbors of Toy Story are:")
for movie in toy_story_neighbors:
    print(movie)