In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.encoder import Encoder
from utils.data import DescriptionsDataset
from utils.data import train_test_split_requests
from utils.misc import pairwise_cosine_distance

In [2]:
# Movie table: read all three columns by name, then keep only id and title.
movies = pd.read_csv(
    "data/ml-20m/movies.csv",
    header=0,
    names=["movie_id", "movie_title", "genres"],
)[["movie_id", "movie_title"]]

# Text descriptions to be embedded, wrapped in the project's dataset class.
descriptions = pd.read_csv("data/ml-20m/descriptions.csv")
descriptions_dataset = DescriptionsDataset(descriptions)

# No shuffling: batches come out in dataset order, so embedding rows stay in that order.
descriptions_loader = DataLoader(
    descriptions_dataset,
    shuffle=False,
    num_workers=2,
    batch_size=32,
)

In [3]:
# Requests table, indexed by movie_id while still keeping movie_id as a regular column.
requests = pd.read_csv("data/ml-20m/requests.csv").set_index("movie_id", drop=False)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the encoder on the chosen device and put it in evaluation mode,
# then restore the trained weights. map_location remaps stored tensors onto `device`.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
encoder = Encoder().to(device)
encoder.eval()
encoder.load_state_dict(
    torch.load("weights/encoder/encoder.pt", map_location=device)
)

In [None]:
# Embed every description yielded by `descriptions_loader`, one batch at a time.
# Loop variables carry a `batch_` prefix so they do not overwrite the global
# `descriptions` DataFrame or the `movie_ids` name used by later cells.
# Row i of `movie_embeddings` is the i-th item yielded by the loader; later cells index
# these rows by position in `movies` - NOTE(review): confirm descriptions.csv row order
# matches movies.csv row order, since the loader's movie ids are not checked here.
description_batches = []

with torch.no_grad():
    for _batch_movie_ids, batch_descriptions in tqdm(descriptions_loader):
        batch_embeddings = encoder(batch_descriptions)
        # Move each batch to CPU so the accumulated embeddings do not occupy device memory.
        description_batches.append(batch_embeddings.cpu())

if not description_batches:
    raise ValueError("descriptions_loader yielded no batches; nothing to embed")

movie_embeddings = torch.cat(description_batches)

In [8]:
# Evaluation settings: how many requests to score and how deep the CMC curve goes.
num_requests = min(256, len(requests))  # never ask for more requests than exist
num_movies = len(movie_embeddings)
max_rank = 500

movie_ids = requests["movie_id"].to_list()[:num_requests]
sampled_requests = requests["request"].to_list()[:num_requests]

# Row position of each movie_id in `movies` (first occurrence wins), built in one pass
# so each lookup is O(1) instead of a full column scan per request.
first_rows = movies["movie_id"].reset_index(drop=True).drop_duplicates(keep="first")
movie_position_by_id = dict(zip(first_rows, first_rows.index))

missing_ids = [mid for mid in movie_ids if mid not in movie_position_by_id]
if missing_ids:
    raise KeyError(
        f"{len(missing_ids)} requested movie_id(s) not found in movies: {missing_ids[:10]}"
    )
movie_indices = [movie_position_by_id[mid] for mid in movie_ids]

# Positions index columns of the embedding matrix, so they must be < num_movies.
if movie_indices and max(movie_indices) >= num_movies:
    raise IndexError(
        f"movie position {max(movie_indices)} is outside movie_embeddings "
        f"({num_movies} rows); movies.csv and descriptions.csv appear misaligned"
    )

In [None]:
# Score the sampled requests against every movie embedding and report retrieval metrics.
# Each request has exactly one relevant movie: the one at position movie_indices[i].
n_requests = len(sampled_requests)  # may be smaller than num_requests for a short CSV
assert len(movie_indices) == n_requests, "movie_indices and sampled_requests differ in length"

# Inference only: no autograd graph is needed for the request embeddings or distances.
with torch.no_grad():
    request_embeddings = encoder(sampled_requests).cpu()
    distances = pairwise_cosine_distance(request_embeddings, movie_embeddings)

# One-hot relevance matrix: matches[i, j] == 1 iff movie j is the target of request i.
matches = torch.zeros(n_requests, num_movies)
matches[torch.arange(n_requests), movie_indices] = 1

# Rank movies per request from closest to farthest and reorder relevance to match.
distances, indices = torch.sort(distances, dim=-1)
matches = matches.gather(-1, indices)

# Average Precision per request, then the mean over requests
# (https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision)
cumulative_sum = torch.cumsum(matches, dim=-1)
precision = cumulative_sum / torch.arange(1, num_movies + 1).unsqueeze(0)
average_precision = (precision * matches).sum(dim=-1) / matches.sum(dim=-1)
average_precision = average_precision.mean()

# Cumulative Matching Characteristics: fraction of requests whose target is within the
# top-k results for k = 1..max_rank. Beyond the last movie every target has been seen,
# so ranks past num_movies (if any) keep the pre-filled value n_requests -> 1.0.
hit_within_rank = cumulative_sum.clamp(max=1)  # new tensor; cumulative_sum is left intact
ranked = min(num_movies, max_rank)
cmc_curve = torch.full((max_rank,), float(n_requests), dtype=torch.float32)
cmc_curve[:ranked] = hit_within_rank[:, :ranked].sum(dim=0)
cmc_curve = cmc_curve / n_requests

# Plot the CMC curve
fig, ax = plt.subplots()
ax.plot(range(1, max_rank + 1), cmc_curve.numpy())
ax.set_xlabel("Rank")
ax.set_ylabel("Cumulative Match Characteristics")
ax.set_title(f"CMC curve ({n_requests} requests)")
plt.show()

print(f"Average Precision: {average_precision:.4f}")
print(f"Rank 1: {cmc_curve[0]:.4f}")
print(f"Rank 5: {cmc_curve[4]:.4f}")
print(f"Rank 10: {cmc_curve[9]:.4f}")