# Importing Modules

In [1]:
import os
import time
import math
from typing import Any, Dict, List, Optional, Tuple
from tempfile import TemporaryDirectory
import json

import numpy as np
import polars as pl
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Data Loading

In [2]:
# Processed movie feature table and the train/test interaction sequences.
movies_df = pl.read_parquet('../data/processed/output.parquet')
train_df = pl.read_parquet('../data/processed/train.parquet')
test_df = pl.read_parquet('../data/processed/test.parquet')

# NOTE(review): torch.load unpickles arbitrary Python objects; only load this file
# from a trusted source.
vocabs = torch.load('../data/processed/all_vocabs.pth')

# Keep only the vocabularies used in this notebook, in a fixed key order.
vocabs = {
    name: vocabs[name]
    for name in (
        "user_vocab",
        "movie_vocab",
        "genres_vocab",
        "prod_comp_vocab",
        "prod_countries_vocab",
        "languages_vocab",
        "words_vocab",
    )
}

# Expose each vocabulary under its own name for the cells below.
user_vocab = vocabs["user_vocab"]
movie_vocab = vocabs["movie_vocab"]
genres_vocab = vocabs["genres_vocab"]
prod_comp_vocab = vocabs["prod_comp_vocab"]
prod_countries_vocab = vocabs["prod_countries_vocab"]
languages_vocab = vocabs["languages_vocab"]
words_vocab = vocabs["words_vocab"]

# String -> index lookup tables for users and movies.
user_vocab_stoi = user_vocab.get_stoi()
movie_vocab_stoi = movie_vocab.get_stoi()



In [3]:
# Quick sanity check of what was loaded: frame shapes, vocabulary sizes,
# and the first row of the movie table and of the training sequences.
print(f"Movies DF shape: {movies_df.shape}")
print(f"Train DF shape: {train_df.shape}")
print(f"Test DF shape: {test_df.shape}")
print("Vocab sizes:", {k: len(v) for k, v in vocabs.items()})
# Indexing a frame with [0] and calling to_dict(as_series=False) prints the row as plain Python lists.
print("\nFirst movie row:", movies_df[0].to_dict(as_series=False))
print("First sequence row:", train_df[0].to_dict(as_series=False))

Movies DF shape: (86493, 15)
Train DF shape: (100000, 3)
Test DF shape: (1000, 3)
Vocab sizes: {'user_vocab': 200949, 'movie_vocab': 86494, 'genres_vocab': 21, 'prod_comp_vocab': 45546, 'prod_countries_vocab': 201, 'languages_vocab': 164, 'words_vocab': 270246}

First movie row: {'movieId_idx': [29614], 'genres_idx': [[9, 12, 1]], 'production_companies_idx': [[5787, 10790, 33238]], 'production_countries_idx': [[199, 179]], 'spoken_languages_idx': [[6, 23, 137, 47]], 'keywords_idx': [[16308, 175809, 93150, 146205, 199438, 206752, 33269, 147659, 198442, 112495, 215262, 104288, 10122, 84395, 101318, 37337, 35868, 263713, 52720, 82866]], 'overview_idx': [[176485, 24219, 129031, 130595, 204009, 102295, 187482, 46191, 265111, 75891, 120691, 82866, 7607, 80603, 210357, 60107, 248412, 24219, 171073, 185742, 207323, 80603, 208294, 186120, 230514, 250658, 48946, 24219, 162914, 55906, 185742, 145939, 86280, 230898, 120691, 50712, 7607, 233356, 69050, 257026, 247940, 24219, 99177, 12348]], 'taglin

In [4]:
# Vocabulary cardinalities; the len_* values are passed to the model constructor further down.
n_items = len(movie_vocab)
len_genres, len_prod_comp, len_prod_cont, len_langs, len_words = map(
    len,
    (genres_vocab, prod_comp_vocab, prod_countries_vocab, languages_vocab, words_vocab),
)

In [5]:
# Show the column names and dtypes of the training sequences.
train_df.schema

Schema([('userId', String),
        ('sequence_movie_ids', List(String)),
        ('sequence_ratings', List(Float64))])

In [32]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class MovieSeqDataset(Dataset):
    """Per-user movie sequences turned into (history, user, ratings, target) tensors.

    ``data[idx]`` must unpack into three column-like objects (user id, movie-id
    sequence, rating sequence), each exposing ``to_list()`` whose first element is
    the value for that row. This matches how a one-row slice of the training frame
    is consumed here.

    Args:
        data: Indexable container of sequences (see above).
        movie_vocab_stoi: Mapping from movie id string to index; must contain ``'<unk>'``.
        user_vocab_stoi: Mapping from user id string to index.
        device: Device for the movie-history and target tensors. ``None`` (default)
            selects ``'cuda'`` when available and ``'cpu'`` otherwise.
    """

    def __init__(self, data, movie_vocab_stoi, user_vocab_stoi, device=None):
        self.data = data
        self.movie_vocab_stoi = movie_vocab_stoi
        self.user_vocab_stoi = user_vocab_stoi
        # Resolve the fallback index from *this* dataset's vocabulary rather than
        # from whatever global mapping happens to exist in the kernel.
        self.unk_index = movie_vocab_stoi['<unk>']
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(history, user, ratings, target)`` for row ``idx``.

        The last movie of the sequence is the prediction target; everything
        before it (and the matching ratings) is the model input. Movie ids
        missing from the vocabulary map to the ``'<unk>'`` index.
        """
        user, movie_sequence, rating_sequence = self.data[idx]
        movie_ids = movie_sequence.to_list()[0]
        movie_data = [self.movie_vocab_stoi.get(item, self.unk_index) for item in movie_ids]
        user_data = self.user_vocab_stoi[user.to_list()[0]]
        ratings = rating_sequence.to_list()[0]
        return (
            # Explicit long dtype keeps an empty history from silently becoming float32.
            torch.tensor(movie_data[:-1], dtype=torch.long, device=self.device),
            torch.tensor(user_data),
            torch.tensor(ratings[:-1]),
            torch.tensor(movie_data[-1], dtype=torch.long, device=self.device),
        )

def collate_batch(batch):
    """Merge dataset samples into padded batch tensors.

    Each sample is ``(history, user, ratings, target)``. Histories are right-padded
    with the ``'<unk>'`` index of the global ``movie_vocab_stoi``; rating sequences
    are right-padded with 3. Users and targets are stacked.

    Returns:
        Tuple ``(padded_histories, users, padded_ratings, targets)``.
    """
    histories, users, ratings, targets = (
        [sample[field] for sample in batch] for field in range(4)
    )
    padded_histories = pad_sequence(
        histories, batch_first=True, padding_value=movie_vocab_stoi['<unk>']
    )
    stacked_users = torch.stack(users)
    padded_ratings = pad_sequence(ratings, batch_first=True, padding_value=3)
    stacked_targets = torch.stack(targets)
    return padded_histories, stacked_users, padded_ratings, stacked_targets

In [33]:
BATCH_SIZE = 256

# Training split: reshuffled every epoch.
train_dataset = MovieSeqDataset(train_df, movie_vocab_stoi, user_vocab_stoi)
train_iter = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
)

# Held-out split: fixed order.
val_dataset = MovieSeqDataset(test_df, movie_vocab_stoi, user_vocab_stoi)
val_iter = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_batch,
)

In [35]:
# Peek at the first training batch only: print the shape of each collated tensor
# and the first movie-history row, then stop.
for i, (movie_data, user_data, ratings_data, y_train) in enumerate(train_iter):
    print(movie_data.shape, user_data.shape, ratings_data.shape, y_train.shape)
    print(movie_data[0])
    break

torch.Size([256, 11]) torch.Size([256]) torch.Size([256, 11]) torch.Size([256])
tensor([15837, 27579, 24675, 78108, 33476, 78820, 40442,  2674, 21279, 76577,
        71546], device='cuda:0')


# Embeddings

In [196]:
import torch 
from torch import nn
from typing import Tuple

class MovieEmbeddings(nn.Module):
    """Encode rows of movie metadata into dense vectors of size ``hidden_size``.

    Seven list-valued columns are mean-pooled through ``nn.EmbeddingBag`` tables
    (keywords, tagline and overview share one word table), concatenated with seven
    scalar columns, and projected by a single linear layer.

    Args:
        d_model: Base embedding width. Genres use ``2 * d_model``, the word table
            ``4 * d_model``, the remaining tables ``d_model``.
        hidden_size: Output width of the final projection.
        num_list_features: Number of list-valued columns (see note on ``self.fc``).
        num_scalar_features: Number of scalar columns (see note on ``self.fc``).
        n_genres, n_production_companies, n_production_countries,
        n_spoken_languages, n_words: Sizes of the respective embedding tables.
    """

    def __init__(self,
                 d_model: int,
                 hidden_size: int,
                 num_list_features: int,
                 num_scalar_features: int,
                 n_genres: int,
                 n_production_companies: int,
                 n_production_countries: int,
                 n_spoken_languages: int,
                 n_words: int):
        super().__init__()
        self.genres_embedding = nn.EmbeddingBag(n_genres, d_model*2, mode='mean')
        self.prod_comp_embedding = nn.EmbeddingBag(n_production_companies, d_model, mode='mean')
        self.prod_cont_embedding = nn.EmbeddingBag(n_production_countries, d_model, mode='mean')
        self.lang_embedding = nn.EmbeddingBag(n_spoken_languages, d_model, mode='mean')
        self.word_embedding = nn.EmbeddingBag(n_words, d_model*4, mode='mean')
        # forward() concatenates (2 + 1 + 1 + 1 + 4 + 4 + 4) * d_model = 17 * d_model
        # embedding features plus 7 scalars, so this input width only matches when
        # num_list_features == 7 and num_scalar_features == 7.
        self.fc = nn.Linear(d_model*(10+num_list_features)+num_scalar_features, hidden_size)
        self._init_weights()

    def _init_weights(self) -> None:
        """Xavier-uniform init for every embedding table and the projection; zero bias."""
        for table in (self.genres_embedding,
                      self.prod_comp_embedding,
                      self.prod_cont_embedding,
                      self.lang_embedding,
                      self.word_embedding):
            nn.init.xavier_uniform_(table.weight)
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def _prepare_embedding_inputs(self, list_of_lists) -> Tuple[torch.Tensor, torch.Tensor]:
        """Flatten a column of index lists into ``(indices, offsets)`` for ``EmbeddingBag``.

        ``offsets[i]`` is the position in ``indices`` where bag ``i`` starts. A
        ``None`` entry is treated as an empty bag. Both tensors are created on the
        device that holds this module's parameters.
        """
        device = self.fc.weight.device
        flat_list = []
        offsets = []
        for sublist in list_of_lists:
            offsets.append(len(flat_list))
            if sublist is not None:
                flat_list.extend(sublist)
        offsets = torch.tensor(offsets, dtype=torch.long, device=device)
        flat_list = torch.tensor(flat_list, dtype=torch.long, device=device)
        return flat_list, offsets

    def _embed_bags(self, table: nn.EmbeddingBag, column) -> torch.Tensor:
        """Mean-pool one list-valued column through ``table``; one output row per input row."""
        indices, offsets = self._prepare_embedding_inputs(column)
        return table(indices, offsets)

    def forward(self, row: "pl.DataFrame") -> torch.Tensor:
        """Embed every row of ``row`` and return a ``[n_rows, hidden_size]`` tensor.

        ``row`` is indexed by column name only, so any mapping from column name to
        a per-row sequence works (the notebook passes a Polars frame).
        """
        # Follow the module's own device instead of assuming CUDA.
        device = self.fc.weight.device

        def scalar(column: str, dtype: torch.dtype = torch.float32) -> torch.Tensor:
            # Shape [n_rows, 1]; always float32 so concatenation needs no type promotion.
            values = torch.tensor(row[column], dtype=dtype, device=device)
            return values.to(torch.float32).unsqueeze(1)

        # Order matters: it fixes which input columns of self.fc each feature feeds.
        master_embedding = torch.cat([
            self._embed_bags(self.genres_embedding, row['genres_idx']),
            self._embed_bags(self.prod_comp_embedding, row['production_companies_idx']),
            self._embed_bags(self.prod_cont_embedding, row['production_countries_idx']),
            self._embed_bags(self.lang_embedding, row['spoken_languages_idx']),
            self._embed_bags(self.word_embedding, row['keywords_idx']),
            self._embed_bags(self.word_embedding, row['tagline_idx']),
            self._embed_bags(self.word_embedding, row['overview_idx']),
            scalar("revenue"),
            scalar("budget"),
            scalar("runtime"),
            scalar("adult_idx", dtype=torch.bool),  # read as a flag, fed as 0.0 / 1.0
            scalar("vote_average"),
            scalar("vote_count"),
            scalar("popularity"),
        ], dim=1)

        return self.fc(master_embedding)

# Model

In [226]:
class TRXTransformer(nn.Module):
    """A stack of ``num_layers`` ``nn.TransformerEncoderLayer`` blocks applied one after another.

    Args:
        d_model: Feature width of the input and output.
        n_heads: Number of attention heads in every layer.
        num_layers: Number of encoder layers in the stack.
    """

    def __init__(self, d_model: int, n_heads: int, num_layers: int):
        super().__init__()
        self.d_model, self.n_heads, self.num_layers = d_model, n_heads, num_layers

        layers = []
        for _ in range(num_layers):
            layers.append(
                nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
            )
        self.transformer_blocks = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` of shape (batch_size, seq_len, d_model) through every layer in order."""
        hidden = x
        for layer in self.transformer_blocks:
            hidden = layer(hidden)
        return hidden

In [240]:
class TRXModel(nn.Module):
    """Next-movie scorer: metadata embeddings -> Transformer encoder -> two linear layers.

    Every movie index of an input sequence is looked up in ``movies_df``, encoded by
    ``MovieEmbeddings``, run through ``TRXTransformer``, flattened over the sequence
    axis and projected to one logit per movie.

    Args:
        d_model: Base embedding width forwarded to ``MovieEmbeddings``.
        n_heads: Attention heads per Transformer layer.
        num_layers: Number of Transformer layers.
        hidden_size: Output width of ``MovieEmbeddings`` and Transformer feature width.
        num_list_features, num_scalar_features: Forwarded to ``MovieEmbeddings``.
        n_genres, n_production_companies, n_production_countries,
        n_spoken_languages, n_words: Embedding table sizes.
        movie_vocab_stoi, user_vocab_stoi: Id-string -> index mappings (stored only).
        n_movies: Number of output logits.
        fc_size: Width of the hidden linear layer.
        seq_len: Fixed sequence length the head is built for; ``fc1`` expects
            ``hidden_size * seq_len`` inputs.
        dropout_rate: Dropout applied after ``fc1``.
    """

    def __init__(self,
                 d_model: int,
                 n_heads: int,
                 num_layers: int,
                 hidden_size: int, # This is the output dimension of MovieEmbeddings' FC layer
                 num_list_features: int,
                 num_scalar_features: int,
                 n_genres: int,
                 n_production_companies: int,
                 n_production_countries: int,
                 n_spoken_languages: int,
                 movie_vocab_stoi: Dict[str, int],
                 user_vocab_stoi: Dict[str, int],
                 n_movies: int,
                 n_words: int,
                 fc_size: int = 512,
                 seq_len: int = 11,
                 dropout_rate: float = 0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.num_layers = num_layers
        self.num_list_features = num_list_features
        self.num_scalar_features = num_scalar_features
        self.n_genres = n_genres
        self.n_production_companies = n_production_companies
        self.n_production_countries = n_production_countries
        self.n_spoken_languages = n_spoken_languages
        self.n_movies = n_movies
        self.n_words = n_words
        self.fc_size = fc_size
        self.seq_len = seq_len
        self.dropout_rate = dropout_rate

        self.movie_embeddings = MovieEmbeddings(d_model, hidden_size, num_list_features, num_scalar_features, n_genres, n_production_companies, n_production_countries, n_spoken_languages, n_words)
        self.movie_vocab_stoi = movie_vocab_stoi
        self.user_vocab_stoi = user_vocab_stoi

        self.transformer_encoder = TRXTransformer(hidden_size, n_heads=n_heads, num_layers=num_layers)

        self.fc1 = nn.Linear(hidden_size*seq_len, fc_size)
        self.fc2 = nn.Linear(fc_size, n_movies)
        self.dropout = nn.Dropout(dropout_rate)
        self._init_weights()

    def _init_weights(self) -> None:
        """Xavier-uniform weights and zero biases for the two head layers."""
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor, movies_df: "pl.DataFrame") -> torch.Tensor:
        """Score every movie as the next item of each sequence.

        Args:
            x: ``[batch_size, seq_len]`` tensor of ``movieId_idx`` values.
            movies_df: Movie feature table with a ``movieId_idx`` column.

        Returns:
            Raw logits of shape ``[batch_size, n_movies]`` on the model's device.
            If no id of the batch is present in ``movies_df`` an all-zero tensor
            of that shape is returned.

        Notes:
            * Movies are embedded in the order they appear in ``x``; repeated ids
              are kept. Ids missing from ``movies_df`` (for example the padding
              index) are dropped and the sequence is right-padded with zero
              vectors up to ``self.seq_len``.
            * Sequences longer than ``self.seq_len`` keep their last
              ``self.seq_len`` ids so the input always matches ``fc1``.
            * If ``movies_df`` holds several rows with the same ``movieId_idx``,
              the last such row is used.
        """
        batch_size = x.shape[0]
        device = self.fc1.weight.device
        empty_logits_shape = (batch_size, self.fc2.out_features)

        # Single host transfer for the whole batch.
        id_rows = [ids[-self.seq_len:] for ids in x.detach().cpu().tolist()]
        unique_ids = sorted({movie_id for ids in id_rows for movie_id in ids})
        if not unique_ids:
            return torch.zeros(empty_logits_shape, device=device)

        # One table scan and one embedding pass per batch instead of one per sequence.
        known_rows = movies_df.filter(pl.col('movieId_idx').is_in(unique_ids))
        known_ids = known_rows['movieId_idx'].to_list()
        if not known_ids:
            return torch.zeros(empty_logits_shape, device=device)

        movie_vectors = self.movie_embeddings(known_rows)  # [n_known, hidden_size]

        # Append one all-zero row that padding positions point at.
        pad_row = len(known_ids)
        lookup_table = torch.cat(
            [movie_vectors, movie_vectors.new_zeros(1, self.hidden_size)], dim=0
        )

        # Filtering returns rows in table order, so restore sequence order explicitly.
        row_of = {movie_id: position for position, movie_id in enumerate(known_ids)}
        gather_rows = []
        for ids in id_rows:
            found = [row_of[movie_id] for movie_id in ids if movie_id in row_of]
            gather_rows.append(found + [pad_row] * (self.seq_len - len(found)))
        gather_index = torch.tensor(gather_rows, dtype=torch.long, device=device)

        embeddings = lookup_table[gather_index]
        # embeddings shape: [batch_size, seq_len, hidden_size]

        transformer_output = self.transformer_encoder(embeddings)
        # transformer_output shape: [batch_size, seq_len, hidden_size]

        flattened = transformer_output.reshape(batch_size, -1)
        # flattened shape: [batch_size, seq_len * hidden_size]

        fc1_output = self.dropout(self.fc1(flattened))
        # fc1_output shape: [batch_size, fc_size]

        return self.fc2(fc1_output)  # raw logits, [batch_size, n_movies]

In [241]:
def sample_negatives(positive_ids: torch.Tensor, n_movies: int, num_negatives: int) -> torch.Tensor:
    """
    Samples negative movie indices for each positive ID in the batch.

    Every row holds distinct indices drawn uniformly without replacement from
    ``range(n_movies)`` minus that row's positive ID. The whole batch is sampled in
    one vectorised step (no per-row Python loop).

    Args:
        positive_ids (torch.Tensor): Tensor of shape [batch_size] with positive movie indices.
        n_movies (int): Total number of movies in the vocabulary.
        num_negatives (int): Number of negative samples per positive. Clamped to
            ``n_movies - 1``, the number of available negatives.

    Returns:
        torch.Tensor: Long tensor of shape [batch_size, num_negatives] (after clamping)
        with negative movie indices, on the same device as ``positive_ids``.
    """
    batch_size = positive_ids.size(0)
    device = positive_ids.device
    num_negatives = max(0, min(num_negatives, n_movies - 1))

    if batch_size == 0 or num_negatives == 0:
        return torch.empty(batch_size, num_negatives, dtype=torch.long, device=device)

    # The indices of the k largest of n i.i.d. uniform draws form a uniformly random
    # ordered k-subset of range(n): sampling without replacement for all rows at once.
    # This allocates a temporary [batch_size, n_movies - 1] float tensor.
    noise = torch.rand(batch_size, n_movies - 1, device=device)
    candidates = noise.topk(num_negatives, dim=1).indices  # values in [0, n_movies - 2]

    # Skip over each row's positive: candidates at or above it move up by one, which
    # maps [0, n_movies - 2] one-to-one onto range(n_movies) without the positive.
    positives = positive_ids.to(torch.long).unsqueeze(1)
    return candidates + (candidates >= positives).to(torch.long)

In [242]:
# Smoke test: build a small TRXModel on the GPU and push one training batch through it.
model = TRXModel(
    # architecture
    d_model=8,
    hidden_size=64,
    n_heads=4,
    num_layers=4,
    fc_size=32,
    # feature layout
    num_list_features=7,
    num_scalar_features=7,
    # vocabulary sizes
    n_genres=len_genres,
    n_production_companies=len_prod_comp,
    n_production_countries=len_prod_cont,
    n_spoken_languages=len_langs,
    n_words=len_words,
    n_movies=len(movie_vocab_stoi),
    # lookup tables
    movie_vocab_stoi=movie_vocab_stoi,
    user_vocab_stoi=user_vocab_stoi,
)
model = model.to('cuda')

# Only the first batch is needed to confirm the output shape.
for idx, (movie_data, user_data, ratings_data, y_train) in enumerate(train_iter):
    print(movie_data.shape, user_data.shape, ratings_data.shape, y_train.shape)
    print(movie_data[0])
    y = model(movie_data, movies_df)
    print(y.shape)
    break

torch.Size([256, 11]) torch.Size([256]) torch.Size([256, 11]) torch.Size([256])
tensor([72489, 76417, 50429, 43890, 68666, 85271, 41492, 12780, 65928, 26366,
        54896], device='cuda:0')
torch.Size([256, 86494])


In [243]:
# Count the trainable parameters only (those with requires_grad=True) and report them in millions.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters in the model: {total_params / 1e6:.2f}M")

Total parameters in the model: 13.03M


## Inference

In [286]:
def predict_topk_movies(model: "TRXModel",
                       sequence_movie_ids: torch.Tensor,
                       movies_df: "pl.DataFrame",
                       movie_vocab_itos: Dict[int, str],
                       top_k: int = 10,
                       device: torch.device = torch.device('cuda')) -> List[List[Tuple[str, float]]]:
    """
    Predicts the top-k next movies for a batch of input sequences.

    Args:
        model (TRXModel): The trained TRXModel.
        sequence_movie_ids (torch.Tensor): Input tensor of shape [batch_size, seq_len]
                                           containing movie indices for sequences. A 1-D
                                           tensor is treated as a single sequence.
        movies_df (pl.DataFrame): The Polars DataFrame containing movie features.
                                  Required by the model's forward pass.
        movie_vocab_itos (Dict[int, str]): Index -> original movie ID lookup; anything
                                           indexable by int (dict or list) works.
        top_k (int): The number of top recommendations to return for each sequence.
                     Clamped to the number of movies the model scores.
        device (torch.device): Device (or device string) the input is moved to;
                               should be the device the model is on.

    Returns:
        List[List[Tuple[str, float]]]: A list of lists. Each inner list contains
                                       (movie_id_string, score) tuples for the
                                       top-k recommendations for one input sequence in the batch,
                                       highest score first.

    Raises:
        KeyError / IndexError: If a predicted index is missing from ``movie_vocab_itos``.

    Note:
        The model is switched to eval mode for the prediction and put back into
        training mode afterwards if that is the mode it was in on entry.
    """
    sequence_movie_ids = sequence_movie_ids.to(device)
    if sequence_movie_ids.dim() == 1:
        sequence_movie_ids = sequence_movie_ids.unsqueeze(0)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            all_movie_logits = model(sequence_movie_ids, movies_df)  # [batch_size, n_movies]
            # torch.topk raises if k exceeds the size of the scored dimension.
            k = max(0, min(top_k, all_movie_logits.size(1)))
            topk_scores, topk_indices = torch.topk(all_movie_logits, k=k, dim=1)  # [batch_size, k]
    finally:
        if was_training:
            model.train()

    # Map indices back to the original movie ID strings, one list per input sequence.
    return [
        [(movie_vocab_itos[movie_index], score) for movie_index, score in zip(index_row, score_row)]
        for index_row, score_row in zip(topk_indices.cpu().tolist(), topk_scores.cpu().tolist())
    ]

In [224]:
# Build the index -> movie-id lookup in this cell so it also runs on a fresh kernel
# (it was previously only defined in a later cell).
movie_vocab_itos = movie_vocab.get_itos()
movie_vocab_itos[44916]

'movie_64614'

In [246]:
# One hand-written history of 11 movie indices, used to exercise the inference helper.
dummy_sequences_idx = torch.tensor([
    [44916, 27309,  1330, 34067, 10537, 84627, 43299, 6341, 52722, 1217, 75094]
], dtype=torch.long) # Shape [1, 11]

top_k = 5
# Index -> movie-id lookup used to turn predicted indices back into ids.
movie_vocab_itos = movie_vocab.get_itos()
recommendations = predict_topk_movies(model, dummy_sequences_idx, movies_df, movie_vocab_itos, top_k, device='cuda')

# Print the recommendations
for i, recs in enumerate(recommendations):
    print(f"Recommendations for sequence {i+1}:")
    for movie_id, score in recs:
        print(f"  Movie ID: {movie_id}, Score: {score:.4f}")
    print("-" * 20)

Recommendations for sequence 1:
  Movie ID: movie_1210, Score: 7.0559
  Movie ID: movie_2028, Score: 6.8235
  Movie ID: movie_40815, Score: 6.5902
  Movie ID: movie_4886, Score: 6.4509
  Movie ID: movie_1291, Score: 6.3563
--------------------


# Training Loop

In [302]:
# Training hyper-parameters.
EPOCHS, NUM_NEGATIVES = 3, 20

# Fresh model for the training run, moved to the GPU.
model = TRXModel(
    # architecture
    d_model=8,
    hidden_size=64,
    n_heads=4,
    num_layers=4,
    fc_size=32,
    # feature layout
    num_list_features=7,
    num_scalar_features=7,
    # vocabulary sizes
    n_genres=len_genres,
    n_production_companies=len_prod_comp,
    n_production_countries=len_prod_cont,
    n_spoken_languages=len_langs,
    n_words=len_words,
    n_movies=len(movie_vocab_stoi),
    # lookup tables
    movie_vocab_stoi=movie_vocab_stoi,
    user_vocab_stoi=user_vocab_stoi,
)
model.to('cuda')

TRXModel(
  (movie_embeddings): MovieEmbeddings(
    (genres_embedding): EmbeddingBag(21, 16, mode='mean')
    (prod_comp_embedding): EmbeddingBag(45546, 8, mode='mean')
    (prod_cont_embedding): EmbeddingBag(201, 8, mode='mean')
    (lang_embedding): EmbeddingBag(164, 8, mode='mean')
    (word_embedding): EmbeddingBag(270246, 32, mode='mean')
    (fc): Linear(in_features=143, out_features=64, bias=True)
  )
  (transformer_encoder): TRXTransformer(
    (transformer_blocks): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwi

In [303]:
# Preferred device: the GPU when one is available, the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cross-entropy over the sampled (positive + negatives) logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [304]:
def train(epoch):
    """Run one training epoch over ``train_iter`` with a sampled-softmax loss.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``, ``train_iter``,
    ``movies_df`` and ``NUM_NEGATIVES``. For every batch the logit of the true next
    movie is placed in column 0, followed by the logits of ``NUM_NEGATIVES`` sampled
    negatives, and cross-entropy is computed against class 0.

    Args:
        epoch: Epoch number, used only in the progress message.
    """
    # evaluate() leaves the model in eval mode; switch back so dropout is active
    # for every epoch, not just the first.
    model.train()
    device = next(model.parameters()).device

    for idx, (sequence_movie_ids, _user_data, _ratings_data, target_movie_id) in enumerate(train_iter):
        optimizer.zero_grad()

        sequence_movie_ids = sequence_movie_ids.to(device)
        target_movie_id = target_movie_id.to(device)  # index of the positive target

        all_movie_logits = model(sequence_movie_ids, movies_df)  # [batch_size, n_movies]
        negative_movie_ids = sample_negatives(target_movie_id, model.n_movies, NUM_NEGATIVES)  # [batch_size, NUM_NEGATIVES]

        batch_rows = torch.arange(all_movie_logits.size(0), device=all_movie_logits.device)
        positive_logits = all_movie_logits[batch_rows, target_movie_id]  # [batch_size]
        negative_logits = torch.gather(all_movie_logits, 1, negative_movie_ids)  # [batch_size, NUM_NEGATIVES]
        sampled_logits = torch.cat([positive_logits.unsqueeze(1), negative_logits], dim=1)

        # The positive always sits in column 0, so the correct class is 0 for every row.
        targets = torch.zeros(sampled_logits.size(0), dtype=torch.long, device=sampled_logits.device)
        loss = criterion(sampled_logits, targets)

        # Backpropagation
        loss.backward()
        optimizer.step()

        if idx % 50 == 0:
            print(f"Epoch {epoch}, Batch {idx}, Loss: {loss.item()}")

In [305]:
def evaluate(model: nn.Module, eval_iter) -> float:
    """Return the mean per-batch sampled-softmax loss of ``model`` over ``eval_iter``.

    Uses the notebook-level ``movies_df``, ``criterion`` and ``NUM_NEGATIVES``.
    Negatives are re-sampled on every call, so the value varies slightly between
    calls. The model is left in eval mode.

    Args:
        model: Model exposing ``n_movies`` and callable as ``model(ids, movies_df)``.
        eval_iter: Iterable of ``(sequence_movie_ids, user, ratings, target)`` batches.

    Returns:
        Average of the batch losses, or ``nan`` when ``eval_iter`` yields no batch.
    """
    model.eval()
    device = next(model.parameters()).device
    total_loss = 0.
    n_batches = 0

    with torch.no_grad():
        # Wrapping the loader itself lets tqdm show the total number of batches.
        for sequence_movie_ids, _user_data, _ratings_data, target_movie_id in tqdm(eval_iter):
            sequence_movie_ids = sequence_movie_ids.to(device)
            target_movie_id = target_movie_id.to(device)

            all_movie_logits = model(sequence_movie_ids, movies_df)  # [batch_size, n_movies]

            negative_movie_ids = sample_negatives(target_movie_id, model.n_movies, NUM_NEGATIVES)  # [batch_size, num_negatives]
            batch_rows = torch.arange(all_movie_logits.size(0), device=all_movie_logits.device)
            positive_logits = all_movie_logits[batch_rows, target_movie_id]  # [batch_size]
            negative_logits = torch.gather(all_movie_logits, 1, negative_movie_ids)  # [batch_size, num_negatives]

            sampled_logits = torch.cat([positive_logits.unsqueeze(1), negative_logits], dim=1)  # [batch_size, 1 + num_negatives]
            # The positive is always in column 0.
            targets = torch.zeros(sampled_logits.size(0), dtype=torch.long, device=sampled_logits.device)  # [batch_size]

            total_loss += criterion(sampled_logits, targets).item()
            n_batches += 1

    if n_batches == 0:
        # Nothing to average; signal it instead of dividing by zero.
        return float('nan')
    return total_loss / n_batches


In [None]:
# Train for EPOCHS epochs, checkpoint the weights with the lowest validation loss
# in a temporary directory, and restore them at the end.
with TemporaryDirectory() as tempdir:
    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")
    # Must exist before the first comparison below; any finite loss beats it.
    best_val_loss = float('inf')

    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()

        # Training
        train(epoch)

        # Evaluation
        val_loss = evaluate(model, val_iter)

        # Perplexity of the validation loss
        val_ppl = math.exp(val_loss)
        elapsed = time.time() - epoch_start_time

        # Results
        print('-' * 89)
        print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
            f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
        print('-' * 89)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_params_path)

    # Restore the best checkpoint while the temporary directory still exists.
    # No file is written if no epoch produced a loss below the initial value.
    if os.path.exists(best_model_params_path):
        model.load_state_dict(torch.load(best_model_params_path))

Epoch 1, Batch 0, Loss: 3.04215931892395


In [299]:
# Rebuild the movie-id -> index mapping, then display the vocabulary index of one example movie id.
movie_vocab_stoi = movie_vocab.get_stoi()
movie_vocab['movie_122912']

26812

In [301]:
# Top-5 next-movie recommendations for one hand-picked history (a batch of one sequence).
predict_topk_movies(
    model,
    torch.tensor(
        [[43299, 6341, 78663, 1217, 40518, 1376, 48320, 11031, 84283, 51781, 26812]],
        device='cuda',
    ),
    movies_df,
    movie_vocab_itos,
    top_k=5,
)

[[('movie_47', 7.067640781402588),
  ('<unk>', 7.000823020935059),
  ('movie_78499', 6.7338995933532715),
  ('movie_8368', 6.675086498260498),
  ('movie_8961', 6.662047386169434)]]