In [1]:
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch import nn, optim
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader
from tqdm import tqdm

import wandb
from model.encoder import Encoder, MultiLayerPerceptron
from model.recommender import DeepFM
from utils.data import ContentDataset, DescriptionsDataset, RequestsDataset, train_test_split_requests
from utils.loss import EncoderCriterion, JointCriterion, RecommenderCriterion
from utils.misc import pairwise_cosine_distance
from utils.metric import get_reid_metrics

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
movies = pd.read_csv("data/ml-20m/movies.csv", header=0, names=["movie_id", "movie_title", "genres"])

movies = movies[["movie_id", "movie_title"]]

# Load requests
requests = pd.read_csv('data/ml-20m/requests.csv')
requests = requests.groupby("movie_id").agg({
    "movie_title": "first",
    "request": list,
}).reset_index()
requests.set_index("movie_id", inplace=True, drop=False)

# Load descriptions
descriptions = pd.read_csv("data/ml-20m/descriptions.csv")
descriptions.set_index("movie_id", inplace=True, drop=False)

In [3]:
train_size = 0.8
batch_size = 32

train_requests, test_requests = train_test_split_requests(requests, train_size=train_size)

train_dataset = ContentDataset(descriptions, train_requests)
test_dataset = ContentDataset(descriptions, test_requests)
descriptions_dataset = DescriptionsDataset(descriptions)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
descriptions_dataloader = DataLoader(descriptions_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

encoder = Encoder().to(device)

encoder.eval()

encoder.load_state_dict(torch.load("weights/encoder/encoder.pt", map_location=device))

  encoder.load_state_dict(torch.load("weights/encoder/encoder.pt", map_location=device))


<All keys matched successfully>

In [5]:
request_embeddings = []
request_logits = []
request_item_ids = []

with torch.no_grad():
    for anchor, positive, negative in tqdm(test_dataloader, desc=f"Validation (Epoch {1})"):
        anchor_requests, anchor_ids = anchor 
        positive_descriptions, positive_ids = positive
        negative_requests, negative_ids = negative 

        anchor_embeddings = encoder(anchor_requests)

        request_embeddings.append(anchor_embeddings.cpu())
        request_item_ids.append(anchor_ids)

request_embeddings = torch.cat(request_embeddings)
request_item_ids = torch.cat(request_item_ids)

Validation (Epoch 1): 100%|██████████| 706/706 [00:20<00:00, 33.84it/s]


In [6]:
description_embeddings = []
description_item_ids = []

with torch.no_grad():
    for movie_ids, descriptions in tqdm(descriptions_dataloader):
        description_embeddings.append(encoder(descriptions).cpu())
        description_item_ids.append(movie_ids)

description_embeddings = torch.cat(description_embeddings)
description_item_ids = torch.cat(description_item_ids)

100%|██████████| 353/353 [00:34<00:00, 10.33it/s]


In [7]:
get_reid_metrics((request_embeddings, request_item_ids), (description_embeddings, description_item_ids))

{'reid_map': 0.01666228659451008,
 'rank-1': 0.005402532871812582,
 'rank-5': 0.019750243052840233,
 'rank-10': 0.032326631247997284}