In [5]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.encoder import Encoder
from utils.data import DescriptionsDataset
from utils.misc import cosine_distance

In [2]:
movies = pd.read_csv("data/ml-20m/movies.csv", header=0, names=["movie_id", "movie_title", "genres"])

movies = movies[["movie_id", "movie_title"]]

descriptions = pd.read_csv("data/ml-20m/descriptions.csv")

descriptions_dataset = DescriptionsDataset(descriptions)

descriptions_loader = DataLoader(descriptions_dataset, batch_size=32, shuffle=False)

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

encoder = Encoder().to(device)

encoder.load_state_dict(torch.load("weights/encoder/encoder.pt", map_location=device))



<All keys matched successfully>

In [6]:
movie_embeddings = []

with torch.no_grad():
    for movie_ids, descriptions in tqdm(descriptions_loader):
        description_embeddings = encoder(descriptions)

        movie_embeddings.append(description_embeddings.cpu())

movie_embeddings = torch.cat(movie_embeddings)

100%|██████████| 353/353 [00:38<00:00,  9.11it/s]


In [17]:
request = "I want to watch a romantic comedy."

request_embedding = encoder(request).cpu()

distances = cosine_distance(request_embedding, movie_embeddings)

_, indices = torch.topk(distances, k=10, largest=False)

movies.iloc[indices]

Unnamed: 0,movie_id,movie_title
4072,4722,All Over the Guy (2001)
1969,2292,Overnight Delivery (1998)
1262,1457,Fools Rush In (1997)
2094,2424,You've Got Mail (1998)
4404,5123,"Touch of Class, A (1973)"
6340,8024,"Thing Called Love, The (1993)"
9861,76147,I Hate Valentine's Day (2009)
2811,3225,Down to You (2000)
6009,7270,Sleep with Me (1994)
10280,86817,Something Borrowed (2011)
