-
Notifications
You must be signed in to change notification settings - Fork 0
/
recommend.py
32 lines (23 loc) · 916 Bytes
/
recommend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os

import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import PIL
def get_similar(embedding, k):
    """Return the ``k`` nearest neighbours of every row of ``embedding``.

    A scikit-learn ``NearestNeighbors`` index (ball-tree algorithm) is fitted
    on ``embedding`` and then queried with that very same data.  Because the
    query rows are the training rows, each row is normally its own closest
    neighbour (distance 0) and therefore appears in its own result list.

    Args:
        embedding: Data accepted by ``NearestNeighbors.fit``, one row per
            item (presumably an item-embedding matrix; confirm with caller).
        k: Number of neighbours returned per row, the self-match included.

    Returns:
        ``(distances, indices)`` as produced by
        ``NearestNeighbors.kneighbors``: for each row of ``embedding``, the
        distances to, and row positions of, its ``k`` nearest rows.
    """
    index = NearestNeighbors(n_neighbors=k, algorithm="ball_tree")
    index.fit(embedding)
    neighbour_dist, neighbour_pos = index.kneighbors(embedding)
    return neighbour_dist, neighbour_pos
def show_similar(item_index, item_similar_indices, item_encoder,
                 columns=5, poster_dir='data/posters'):
    """Plot the poster images of the items listed for ``item_index``.

    Row ``item_index`` of ``item_similar_indices`` is mapped back to movie ids
    with ``item_encoder.inverse_transform``.  The file
    ``<poster_dir>/<movie_id>.jpg`` is read for each id, and the posters are
    drawn on a new 20x10-inch figure in a grid ``columns`` wide.  The figure
    is neither shown nor returned, so displaying it is left to the caller or
    to an interactive/inline backend.

    Args:
        item_index: Row of ``item_similar_indices`` to display.
        item_similar_indices: Neighbour-index matrix.  If it comes from
            ``get_similar``, the first entry of each row is normally the item
            itself.
        item_encoder: Object with an ``inverse_transform`` method that maps
            the stored indices back to movie ids (presumably a fitted
            ``sklearn.preprocessing.LabelEncoder``; confirm with caller).
        columns: Posters per grid row.  Defaults to 5, the previous
            hard-coded value.
        poster_dir: Directory holding the ``<movie_id>.jpg`` files.  Defaults
            to ``'data/posters'``, the previous hard-coded location.

    Raises:
        ValueError: If ``columns`` is less than 1.
        Errors raised by ``mpimg.imread`` (e.g. for a missing poster file)
        propagate unchanged.
    """
    if columns < 1:
        raise ValueError(f"columns must be >= 1, got {columns}")

    movie_ids = item_encoder.inverse_transform(item_similar_indices[item_index])

    # Load every poster before creating the figure, so a failed read does not
    # leave an empty figure behind.
    images = [
        mpimg.imread(os.path.join(poster_dir, f"{movie_id}.jpg"))
        for movie_id in movie_ids
    ]

    # Ceiling division yields an integer row count (plt.subplot rejects
    # floats) that is just large enough, without a spare empty row when the
    # count is an exact multiple of `columns`.
    rows = (len(images) + columns - 1) // columns

    plt.figure(figsize=(20, 10))
    for position, image in enumerate(images, start=1):
        plt.subplot(rows, columns, position)
        plt.axis('off')
        plt.imshow(image)