# Importing Modules

In [23]:
import torch
from torch import nn
import numpy as np
import polars as pl
import math
from typing import List, Dict, Tuple, Optional, Any
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.utils.data import Dataset, DataLoader

# Data Loading

In [14]:
# Preprocessed artifacts produced by the upstream data-prep step.
movies_df = pl.read_parquet('../data/processed/output.parquet')
sequence_df = pl.read_parquet('../data/processed/train.parquet')

# The vocab file holds pickled vocabulary objects (they expose .get_stoi(), see below),
# not plain tensors. Since PyTorch 2.6, torch.load defaults to weights_only=True and
# refuses to unpickle such objects, so full unpickling is requested explicitly.
# NOTE: unpickling can execute arbitrary code — only load vocab files you produced/trust.
loaded_vocabs = torch.load('../data/processed/all_vocabs.pth', weights_only=False)

# Vocabularies used by this notebook, in a fixed order (also the display order below).
VOCAB_KEYS = (
    "user_vocab",
    "movie_vocab",
    "genres_vocab",
    "prod_comp_vocab",
    "prod_countries_vocab",
    "languages_vocab",
    "words_vocab",
)
vocabs = {key: loaded_vocabs[key] for key in VOCAB_KEYS}

user_vocab = vocabs["user_vocab"]
movie_vocab = vocabs["movie_vocab"]
genres_vocab = vocabs["genres_vocab"]
prod_comp_vocab = vocabs["prod_comp_vocab"]
prod_countries_vocab = vocabs["prod_countries_vocab"]
languages_vocab = vocabs["languages_vocab"]
words_vocab = vocabs["words_vocab"]

# String -> integer lookups used to encode sequences.
movie_vocab_stoi = movie_vocab.get_stoi()
user_vocab_stoi = user_vocab.get_stoi()

In [16]:
# Quick sanity check of what was loaded: frame shapes, vocab sizes, and one example row each.
for label, frame in (("Movies DF", movies_df), ("Sequence DF", sequence_df)):
    print(f"{label} shape: {frame.shape}")
print("Vocab sizes:", {name: len(vocab) for name, vocab in vocabs.items()})
print("\nFirst movie row:", movies_df[0].to_dict(as_series=False))
print("First sequence row:", sequence_df[0].to_dict(as_series=False))

Movies DF shape: (86493, 15)
Sequence DF shape: (13218442, 3)
Vocab sizes: {'user_vocab': 200949, 'movie_vocab': 86494, 'genres_vocab': 21, 'prod_comp_vocab': 45546, 'prod_countries_vocab': 201, 'languages_vocab': 164, 'words_vocab': 270246}

First movie row: {'movieId_idx': [29614], 'genres_idx': [[9, 12, 1]], 'production_companies_idx': [[5787, 10790, 33238]], 'production_countries_idx': [[199, 179]], 'spoken_languages_idx': [[6, 23, 137, 47]], 'keywords_idx': [[16308, 175809, 93150, 146205, 199438, 206752, 33269, 147659, 198442, 112495, 215262, 104288, 10122, 84395, 101318, 37337, 35868, 263713, 52720, 82866]], 'overview_idx': [[176485, 24219, 129031, 130595, 204009, 102295, 187482, 46191, 265111, 75891, 120691, 82866, 7607, 80603, 210357, 60107, 248412, 24219, 171073, 185742, 207323, 80603, 208294, 186120, 230514, 250658, 48946, 24219, 162914, 55906, 185742, 145939, 86280, 230898, 120691, 50712, 7607, 233356, 69050, 257026, 247940, 24219, 99177, 12348]], 'tagline_idx': [[214081, 63

In [17]:
# --- Vocabulary sizes, used to size the model's embedding tables ---
(len_genres, len_prod_comp, len_prod_cont, len_langs, len_words) = map(
    len,
    (genres_vocab, prod_comp_vocab, prod_countries_vocab, languages_vocab, words_vocab),
)
# Size of the movie vocabulary, special tokens such as '<unk>' included; this is the
# number of output classes of the prediction head.
n_items = len(movie_vocab)

In [19]:
# --- Padding constants ---
# The '<unk>' movie index doubles as the padding id in movie sequences.
PADDING_MOVIE_ID_INT = movie_vocab_stoi['<unk>']
# Rating value placed at padded sequence positions.
RATING_PADDING_VALUE = -1.0

assert PADDING_MOVIE_ID_INT == 0, "Padding token for movie vocab should be 0"

In [20]:
# Feature columns consumed by MovieEmbeddings: list-valued id columns (EmbeddingBag
# inputs) and one-value-per-movie scalar columns.
list_feature_cols = ['genres_idx', 'production_companies_idx', 'production_countries_idx',
                     'spoken_languages_idx', 'keywords_idx', 'tagline_idx', 'overview_idx']
scalar_feature_cols = ['revenue', 'budget', 'runtime', 'adult_idx',
                       'vote_average', 'vote_count', 'popularity']
feature_cols = list_feature_cols + scalar_feature_cols

# Integer movie id (movieId_idx) -> {feature column -> value}. If an id appears twice,
# the later row wins.
movie_features_map = {
    row['movieId_idx']: {col: row[col] for col in feature_cols}
    for row in movies_df.iter_rows(named=True)
}

# Neutral features for the padding id: empty id lists, 0.0 for scalars, and False for
# the boolean adult flag.
padding_features = {col: [] for col in list_feature_cols}
padding_features.update(
    {col: (False if col == 'adult_idx' else 0.0) for col in scalar_feature_cols}
)
movie_features_map[PADDING_MOVIE_ID_INT] = padding_features

print(f"Created movie_features_map with {len(movie_features_map)} entries (including padding ID {PADDING_MOVIE_ID_INT}).")

Created movie_features_map with 86494 entries (including padding ID 0).


# Dataset and Collator

In [29]:
class MovieDataset(Dataset):
    """Map-style dataset yielding fixed-length (movie id, rating) sequences.

    Each row of ``sequence_df`` is read from its ``sequence_movie_ids`` (movie-id
    strings) and ``sequence_ratings`` columns. Movie-id strings are converted to
    integer ids via ``movie_vocab_stoi``; strings missing from the vocabulary map to
    ``padding_id``. The id list and the rating list are each truncated or right-padded
    to ``sequence_length`` independently, so both outputs always have that length even
    if the stored lists disagree in length. A null list is treated as empty.

    Args:
        sequence_df: Polars frame with ``sequence_movie_ids`` and ``sequence_ratings``.
        movie_vocab_stoi: Mapping from movie-id string to integer id.
        padding_id: Integer id used for padded and out-of-vocabulary movies.
        rating_padding_value: Rating placed at padded positions.
        sequence_length: Length of every returned sequence (default 5).
    """

    def __init__(self, sequence_df: pl.DataFrame, movie_vocab_stoi: Dict[str, int],
                 padding_id: int = 0, rating_padding_value: float = -1.0,
                 sequence_length: int = 5):
        self.sequence_df = sequence_df
        self.movie_vocab_stoi = movie_vocab_stoi
        self.padding_id = padding_id
        self.rating_padding_value = rating_padding_value
        self.sequence_length = sequence_length

    def __len__(self) -> int:
        return len(self.sequence_df)

    @staticmethod
    def _fit_length(values: List[Any], length: int, pad_value: Any) -> List[Any]:
        """Return a new list truncated, or right-padded with ``pad_value``, to ``length``."""
        fitted = list(values[:length])
        fitted.extend([pad_value] * (length - len(fitted)))
        return fitted

    def __getitem__(self, idx: int) -> Dict[str, List[Any]]:
        """Return ``{'movie_ids': [...], 'ratings': [...]}``, each of length ``sequence_length``."""
        # DataFrame.row(named=True) yields plain Python values for one row, avoiding a
        # single-row DataFrame + to_dict round trip per sample.
        row = self.sequence_df.row(idx, named=True)
        movie_ids_str = row['sequence_movie_ids'] or []
        ratings = row['sequence_ratings'] or []

        movie_ids_int = [self.movie_vocab_stoi.get(uid, self.padding_id) for uid in movie_ids_str]

        return {
            'movie_ids': self._fit_length(movie_ids_int, self.sequence_length, self.padding_id),
            'ratings': self._fit_length(ratings, self.sequence_length, self.rating_padding_value),
        }

In [30]:
class MovieCollator:
    """Collate ``MovieDataset`` samples into the batch format consumed by ``TRXModel``.

    The first ``input_sequence_length`` items of every sample are model inputs and the
    item right after them is the prediction target. Features of all non-padded input
    items in the batch are gathered into one flat "pool" (one pool row per occurrence).
    ``sequence_item_pool_map`` records, for every (batch, position), the pool row to
    use. Padded positions point to index ``actual_items_in_pool``, one past the end of
    the pool.

    Args:
        movie_features_map: Integer movie id -> {feature column -> value}.
        list_feature_cols: Columns holding lists of integer ids (EmbeddingBag inputs).
            A ``None`` value is treated as an empty list.
        scalar_feature_cols: Columns holding a single numeric/bool value per movie.
        padding_id: Movie id marking a padded position.
        rating_padding_value: Stored on the instance; ratings are passed through as-is.
        input_sequence_length: Number of leading items used as input (default 4).
    """

    def __init__(self, movie_features_map: Dict[int, Dict[str, Any]],
                 list_feature_cols: List[str], scalar_feature_cols: List[str],
                 padding_id: int = 0, rating_padding_value: float = -1.0,
                 input_sequence_length: int = 4):
        self.movie_features_map = movie_features_map
        self.list_feature_cols = list_feature_cols
        self.scalar_feature_cols = scalar_feature_cols
        self.padding_id = padding_id
        self.rating_padding_value = rating_padding_value
        self.input_sequence_length = input_sequence_length

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Split each sample into inputs/target and build pooled feature tensors.

        Args:
            batch: Samples whose ``'movie_ids'`` and ``'ratings'`` lists have at least
                ``input_sequence_length + 1`` entries.

        Returns:
            Dict with:
            - ``'input_item_pool_data'``: for each list column a ``(flat_ids, offsets)``
              pair of long tensors; for each scalar column a float32 ``[P, 1]`` tensor,
              where P is the number of non-padded input items in the batch.
            - ``'input_sequence_info'``: ``'padding_mask'`` bool ``[B, L]`` (True =
              padded), ``'sequence_item_pool_map'`` long ``[B, L]``, and
              ``'actual_items_in_pool'`` (the int P).
            - ``'input_ratings'``: float32 ``[B, L]``.
            - ``'target_item_ids'``: long ``[B]`` — the item following the inputs.
        """
        seq_len = self.input_sequence_length
        input_movie_ids = [sample['movie_ids'][:seq_len] for sample in batch]
        input_ratings = [sample['ratings'][:seq_len] for sample in batch]
        target_item_ids = [sample['movie_ids'][seq_len] for sample in batch]

        # Plain Python booleans are reused while pooling, avoiding per-element tensor reads.
        padding_rows = [[movie_id == self.padding_id for movie_id in seq] for seq in input_movie_ids]
        padding_mask = torch.tensor(padding_rows, dtype=torch.bool)

        pooled_features, pool_map, pool_size = self._build_pool(input_movie_ids, padding_rows)

        input_item_pool_data: Dict[str, Tuple[torch.Tensor, torch.Tensor] | torch.Tensor] = {
            col: self._list_feature_tensors(pooled_features[col]) for col in self.list_feature_cols
        }
        for col in self.scalar_feature_cols:
            # view(-1, 1) also turns an empty pool into a [0, 1] tensor; bools become 0.0/1.0.
            input_item_pool_data[col] = torch.tensor(pooled_features[col], dtype=torch.float32).view(-1, 1)

        return {
            'input_item_pool_data': input_item_pool_data,
            'input_sequence_info': {
                'padding_mask': padding_mask,
                'sequence_item_pool_map': torch.tensor(pool_map, dtype=torch.long),
                'actual_items_in_pool': pool_size,
            },
            'input_ratings': torch.tensor(input_ratings, dtype=torch.float32),
            'target_item_ids': torch.tensor(target_item_ids, dtype=torch.long),
        }

    def _features_for(self, movie_id: int) -> Dict[str, Any]:
        """Look up a movie's features, falling back to the padding entry (with a warning)."""
        features = self.movie_features_map.get(movie_id)
        if features is None:
            print(f"Warning: Movie ID {movie_id} not found in movie_features_map. Using padding features.")
            features = self.movie_features_map[self.padding_id]
        return features

    def _build_pool(self, input_movie_ids: List[List[int]],
                    padding_rows: List[List[bool]]) -> Tuple[Dict[str, List[Any]], List[List[int]], int]:
        """Gather features of every non-padded input item into per-column lists.

        Returns ``(pooled_features, pool_map, pool_size)``. ``pooled_features`` maps each
        feature column to one value per pooled item. ``pool_map[b][s]`` is the pool row
        for that position; padded positions get ``pool_size``.
        """
        pooled_features: Dict[str, List[Any]] = {
            col: [] for col in (*self.list_feature_cols, *self.scalar_feature_cols)
        }
        pool_map: List[List[Optional[int]]] = []
        pool_size = 0
        for seq, pad_row in zip(input_movie_ids, padding_rows):
            seq_map: List[Optional[int]] = []
            for movie_id, is_padded in zip(seq, pad_row):
                if is_padded:
                    seq_map.append(None)  # resolved once the final pool size is known
                    continue
                features = self._features_for(movie_id)
                for col in self.list_feature_cols:
                    value = features[col]
                    pooled_features[col].append([] if value is None else value)
                for col in self.scalar_feature_cols:
                    pooled_features[col].append(features[col])
                seq_map.append(pool_size)
                pool_size += 1
            pool_map.append(seq_map)

        resolved_map = [[pool_size if idx is None else idx for idx in seq_map] for seq_map in pool_map]
        return pooled_features, resolved_map, pool_size

    @staticmethod
    def _list_feature_tensors(sublists: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Flatten per-item id lists into ``(flat_ids, offsets)`` long tensors for EmbeddingBag.

        ``offsets[i]`` is where item i's ids start in ``flat_ids``. Empty lists give
        repeated offsets, and an empty ``sublists`` gives two empty tensors.
        """
        flat_ids: List[int] = []
        offsets: List[int] = []
        for ids in sublists:
            offsets.append(len(flat_ids))
            flat_ids.extend(ids)
        return torch.tensor(flat_ids, dtype=torch.long), torch.tensor(offsets, dtype=torch.long)

# Embeddings

In [31]:
class MovieEmbeddings(nn.Module):
    def __init__(self,
                 d_model_internal: int, # Internal embedding dim before FC
                 hidden_size: int,      # Output dim after FC (Item_Feature_Dim)
                 n_genres: int,
                 n_production_companies: int,
                 n_production_countries: int,
                 n_spoken_languages: int,
                 n_words: int):
        super().__init__()
        # EmbeddingBag layers using provided vocab sizes
        self.genres_embedding = nn.EmbeddingBag(n_genres, d_model_internal*2, mode='mean')
        self.prod_comp_embedding = nn.EmbeddingBag(n_production_companies, d_model_internal, mode='mean')
        self.prod_cont_embedding = nn.EmbeddingBag(n_production_countries, d_model_internal, mode='mean')
        self.lang_embedding = nn.EmbeddingBag(n_spoken_languages, d_model_internal, mode='mean')
        self.word_embedding = nn.EmbeddingBag(n_words, d_model_internal*4, mode='mean') # Shared for keywords, tagline, overview

        # Calculate the input dimension for the final linear layer
        total_embedding_dim = (d_model_internal * 2) + (d_model_internal * 1) + (d_model_internal * 1) + (d_model_internal * 1) + \
                              (d_model_internal * 4) + (d_model_internal * 4) + (d_model_internal * 4) # genres, prod_comp, prod_cont, lang, keywords, tagline, overview
        num_scalar_features_actual = 7 # revenue, budget, runtime, adult_idx, vote_average, vote_count, popularity
        fc_input_dim = total_embedding_dim + num_scalar_features_actual

        self.fc = nn.Linear(fc_input_dim, hidden_size) # Output is Item_Feature_Dim

        self._init_weights()

    def _init_weights(self) -> None:
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() > 1:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, item_data: Dict[str, Tuple[torch.Tensor, torch.Tensor] | torch.Tensor]) -> torch.Tensor:
        """
        Processes pre-collated movie features for a pool of items.
        Returns embeddings of shape [Total_Items_in_Pool, hidden_size].
        """
        # The collate_fn ensures item_data contains tensors for the actual non-padded items
        # + possibly empty tensors/offsets if actual_items_in_pool is 0

        # --- Process List Features ---
        genres_flat, genres_offsets = item_data['genres_idx']
        # Use the number of offsets as batch size for EmbeddingBag if flat is empty
        embag_batch_size = genres_offsets.shape[0] if genres_flat.numel() == 0 else None

        genres_e = self.genres_embedding(genres_flat, genres_offsets) if genres_flat.numel() > 0 else torch.zeros(embag_batch_size if embag_batch_size is not None else 0, self.genres_embedding.embedding_dim, device=genres_flat.device)

        comp_flat, comp_offsets = item_data['production_companies_idx']
        embag_batch_size = comp_offsets.shape[0] if comp_flat.numel() == 0 else None
        comp_e = self.prod_comp_embedding(comp_flat, comp_offsets) if comp_flat.numel() > 0 else torch.zeros(embag_batch_size if embag_batch_size is not None else 0, self.prod_comp_embedding.embedding_dim, device=comp_flat.device)

        cont_flat, cont_offsets = item_data['production_countries_idx']
        embag_batch_size = cont_offsets.shape[0] if cont_flat.numel() == 0 else None
        cont_e = self.prod_cont_embedding(cont_flat, cont_offsets) if cont_flat.numel() > 0 else torch.zeros(embag_batch_size if embag_batch_size is not None else 0, self.prod_cont_embedding.embedding_dim, device=cont_flat.device)

        lang_flat, lang_offsets = item_data['spoken_languages_idx']
        embag_batch_size = lang_offsets.shape[0] if lang_flat.numel() == 0 else None
        lang_e = self.lang_embedding(lang_flat, lang_offsets) if lang_flat.numel() > 0 else torch.zeros(embag_batch_size if embag_batch_size is not None else 0, self.lang_embedding.embedding_dim, device=lang_flat.device)

        kw_flat, kw_offsets = item_data['keywords_idx']
        embag_batch_size = kw_offsets.shape[0] if kw_flat.numel() == 0 else None
        kw_e = self.word_embedding(kw_flat, kw_offsets) if kw_flat.numel() > 0 else torch.zeros(embag_batch_size if embag_batch_size is not None else 0, self.word_embedding.embedding_dim, device=kw_flat.device)

        tag_flat, tag_offsets = item_data['tagline_idx']
        embag_batch_size = tag_offsets.shape[0] if tag_flat.numel() == 0 else None
        tag_e = self.word_embedding(tag_flat, tag_offsets) if tag_flat.numel() > 0 else torch.zeros(embag_batch_size if embag_batch_size is not None else 0, self.word_embedding.embedding_dim, device=tag_flat.device)

        ov_flat, ov_offsets = item_data['overview_idx']
        embag_batch_size = ov_offsets.shape[0] if ov_flat.numel() == 0 else None
        ov_e = self.word_embedding(ov_flat, ov_offsets) if ov_flat.numel() > 0 else torch.zeros(embag_batch_size if embag_batch_size is not None else 0, self.word_embedding.embedding_dim, device=ov_flat.device)

        # --- Process Scalar Features ---
        revenue = item_data["revenue"].float()
        budget = item_data["budget"].float()
        runtime = item_data["runtime"].float()
        adult_idx = item_data["adult_idx"].float()
        vote_average = item_data["vote_average"].float()
        vote_count = item_data["vote_count"].float()
        popularity = item_data["popularity"].float()

        # --- Concatenate all embeddings and scalar features ---
        # All these tensors should have actual_items_in_pool as their first dimension (or be empty if 0)
        pooled_embedding = torch.cat([
            genres_e, comp_e, cont_e, lang_e, kw_e, tag_e, ov_e,
            revenue, budget, runtime, adult_idx, vote_average, vote_count, popularity
        ], dim=1)

        # Pass through final linear layer
        return self.fc(pooled_embedding) # Shape: [Total_Items_in_Pool, hidden_size]

# Positional Encoding

In [32]:
class SimplePositionalEncoding(nn.Module):
    """Learned absolute position embeddings for a fixed-length sequence.

    ``forward`` does not look at sequence contents. It returns one ``d_model``-wide
    vector per position, stacked as ``[sequence_length, d_model]``.
    """

    def __init__(self, d_model: int, sequence_length: int = 4):
        super().__init__()
        self.sequence_length = sequence_length
        # One trainable row per position index 0 .. sequence_length - 1.
        self.pe = nn.Embedding(num_embeddings=sequence_length, embedding_dim=d_model)
        self._init_weights()

    def _init_weights(self):
        # Replace nn.Embedding's default normal init with Xavier-uniform.
        nn.init.xavier_uniform_(self.pe.weight)

    def forward(self, device: torch.device) -> torch.Tensor:
        """Look up the embedding of every position, building the index tensor on ``device``."""
        position_ids = torch.arange(0, self.sequence_length, device=device, dtype=torch.long)
        return self.pe(position_ids)

# TRXModel

In [35]:
class TRXModel(nn.Module):
    """Transformer next-item recommender over short, rated movie sequences.

    Pipeline:
    1. Pooled movie features go through ``MovieEmbeddings``.
    2. Pool rows are scattered back to per-position item vectors; padded positions
       get zeros.
    3. Each position's rating is appended as one extra feature.
    4. Learned positional embeddings are added.
    5. The sequence goes through a ``TransformerEncoder``.
    6. The output at the last non-padded position feeds a linear head that scores all
       ``n_items`` items.

    Args:
        movie_embedding_params: kwargs for ``MovieEmbeddings``; ``hidden_size`` is set here.
        item_feature_dim: Output width of ``MovieEmbeddings``.
        transformer_n_heads: Attention heads; must divide ``item_feature_dim + 1``.
        transformer_n_layers: Number of encoder layers.
        transformer_dim_feedforward: Encoder feed-forward width.
        transformer_dropout: Dropout used in the encoder and around it.
        n_items: Number of output classes of the prediction head.
        input_sequence_length: Number of input positions (default 4).
    """

    def __init__(self,
                 movie_embedding_params: Dict[str, Any],
                 item_feature_dim: int,
                 transformer_n_heads: int,
                 transformer_n_layers: int,
                 transformer_dim_feedforward: int,
                 transformer_dropout: float,
                 n_items: int,
                 input_sequence_length: int = 4):
        super().__init__()

        self.n_items = n_items
        # Transformer width = item feature vector + 1 rating value.
        self.transformer_d_model = item_feature_dim + 1
        assert self.transformer_d_model % transformer_n_heads == 0, \
            f"Transformer d_model ({self.transformer_d_model}) must be divisible by n_heads ({transformer_n_heads})"

        self.input_sequence_length = input_sequence_length

        # 1. Item feature encoder; its output width is item_feature_dim.
        self.movie_embeddings = MovieEmbeddings(**movie_embedding_params, hidden_size=item_feature_dim)

        # 2. Learned positional embeddings for the fixed input length.
        self.positional_encoding = SimplePositionalEncoding(self.transformer_d_model, self.input_sequence_length)

        # 3. Transformer encoder over [batch, seq, d_model] inputs.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.transformer_d_model,
            nhead=transformer_n_heads,
            dim_feedforward=transformer_dim_feedforward,
            dropout=transformer_dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_n_layers)

        # 4. Scores over the whole item vocabulary for the next item.
        self.prediction_head = nn.Linear(self.transformer_d_model, n_items)

        self.dropout = nn.Dropout(transformer_dropout)

        self._init_weights()

    def _init_weights(self):
        """Initialise the prediction head; sub-modules initialise themselves."""
        nn.init.xavier_uniform_(self.prediction_head.weight)
        nn.init.zeros_(self.prediction_head.bias)

    def forward(self,
                input_item_pool_data: Dict[str, Tuple[torch.Tensor, torch.Tensor] | torch.Tensor],
                input_sequence_info: Dict[str, torch.Tensor],
                input_ratings: torch.Tensor,
                target_item_ids: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score every item as the next item of each input sequence.

        Args:
            input_item_pool_data: Pooled feature data for the non-padded input items,
                as consumed by ``MovieEmbeddings``.
            input_sequence_info: ``'padding_mask'`` [B, L] (True = padded) and
                ``'sequence_item_pool_map'`` [B, L], where padded positions index one
                past the pool.
            input_ratings: [B, L] ratings of the input items.
            target_item_ids: [B] next-item ids; returned unchanged.

        Returns:
            ``(item_scores [B, n_items], target_item_ids [B])``.
        """
        device = input_ratings.device
        batch_size = input_ratings.shape[0]
        input_seq_len = self.input_sequence_length

        # 1. Embed the pooled (non-padded) items: [P, item_feature_dim].
        pooled_item_features = self.movie_embeddings(input_item_pool_data)

        # 2. Append one zero row; padded positions are mapped to this row (index P).
        dummy_zero_feature = pooled_item_features.new_zeros(1, pooled_item_features.shape[1])
        pooled_with_dummy = torch.cat([pooled_item_features, dummy_zero_feature], dim=0)

        # 3. Gather pool rows back into sequence layout: [B, L, item_feature_dim].
        flat_map = input_sequence_info['sequence_item_pool_map'].long().reshape(-1)
        sequence_item_features = pooled_with_dummy.index_select(0, flat_map).view(batch_size, input_seq_len, -1)

        padding_mask = input_sequence_info['padding_mask'].bool()  # [B, L], True = padded
        sequence_item_features = sequence_item_features.masked_fill(padding_mask.unsqueeze(-1), 0.0)

        # 4. Append the rating as the last feature: [B, L, transformer_d_model].
        sequence_input = torch.cat([sequence_item_features, input_ratings.unsqueeze(-1)], dim=-1)

        # 5. Positional embeddings ([L, d_model], broadcast over the batch), then dropout.
        sequence_input = sequence_input + self.positional_encoding(device).unsqueeze(0)
        sequence_input = self.dropout(sequence_input)

        # 6. Encoder. If every position of a row is padded, all attention keys would be
        #    masked and the softmax yields NaN. For such rows position 0 is left
        #    unmasked; it carries only the zero item vector, so the output stays finite.
        positions = torch.arange(input_seq_len, device=device)
        fully_padded = padding_mask.all(dim=1, keepdim=True)  # [B, 1]
        encoder_padding_mask = padding_mask & ~(fully_padded & (positions == 0))
        transformer_output = self.transformer_encoder(
            sequence_input,
            src_key_padding_mask=encoder_padding_mask,
        )  # [B, L, transformer_d_model]

        # 7. Read the output at the last NON-padded position of each row. Unpadded
        #    sequences use index L-1 (3 for L=4). Right-padded sequences no longer
        #    read a padded slot. Fully padded rows fall back to index 0.
        last_valid = torch.where(~padding_mask, positions, torch.zeros_like(positions)).max(dim=1).values
        batch_index = torch.arange(batch_size, device=device)
        extracted_representation = transformer_output[batch_index, last_valid]  # [B, transformer_d_model]

        extracted_representation = self.dropout(extracted_representation)

        # 8. Scores for every item in the catalogue: [B, n_items].
        item_scores = self.prediction_head(extracted_representation)

        return item_scores, target_item_ids

In [36]:
# --- Hyper-parameters ---
SEED = 42  # seeds weight init and DataLoader shuffling so this cell is reproducible
d_model_internal_me = 16  # internal dim used in MovieEmbeddings
# Transformer width is item_feature_dim + 1 (the rating) and must be divisible by
# transformer_n_heads: 63 + 1 = 64 is divisible by 4.
item_feature_dim = 63
transformer_n_heads = 4
transformer_n_layers = 2
ff_dimension = 128
batch_size = 32

torch.manual_seed(SEED)

# MovieEmbeddings parameters from the actual vocab sizes; hidden_size is set by TRXModel.
movie_embed_params = {
    'd_model_internal': d_model_internal_me,
    'n_genres': len_genres,
    'n_production_companies': len_prod_comp,
    'n_production_countries': len_prod_cont,
    'n_spoken_languages': len_langs,
    'n_words': len_words,
}

# --- Dataset, collator, DataLoader ---
movie_dataset = MovieDataset(sequence_df, movie_vocab_stoi,
                             padding_id=PADDING_MOVIE_ID_INT,
                             rating_padding_value=RATING_PADDING_VALUE)

movie_collator = MovieCollator(movie_features_map, list_feature_cols, scalar_feature_cols,
                               padding_id=PADDING_MOVIE_ID_INT,
                               rating_padding_value=RATING_PADDING_VALUE)

movie_dataloader = DataLoader(
    movie_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=movie_collator,
    num_workers=0,  # >0 is faster; 0 keeps errors in the main process for debugging
)

# --- Model ---
model = TRXModel(
    movie_embedding_params=movie_embed_params,
    item_feature_dim=item_feature_dim,
    transformer_n_heads=transformer_n_heads,
    transformer_n_layers=transformer_n_layers,
    transformer_dim_feedforward=ff_dimension,
    transformer_dropout=0.1,
    n_items=n_items,
)

print(f"Model instantiated with {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters.")
print(f"Transformer D_model: {model.transformer_d_model}")


def move_to_device(obj: Any, device: torch.device) -> Any:
    """Return ``obj`` with every tensor, including those nested in dicts/lists/tuples, moved to ``device``.

    Non-tensor leaves (e.g. the int ``actual_items_in_pool``) are returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(value, device) for value in obj)
    return obj


# --- Smoke test: one batch through the collator and a forward pass ---
# No try/except: a failure should stop this cell (and Restart & Run All) with the real traceback.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"\nIterating through DataLoader on device: {device}")

batch_data = next(iter(movie_dataloader))
print("\nProcessing batch 1")
batch_data_on_device = move_to_device(batch_data, device)

input_item_pool_data = batch_data_on_device['input_item_pool_data']
input_sequence_info = batch_data_on_device['input_sequence_info']
input_ratings = batch_data_on_device['input_ratings']
target_item_ids = batch_data_on_device['target_item_ids']

# Forward pass only; drop torch.no_grad() in a real training step.
with torch.no_grad():
    output_scores, returned_target_ids = model(
        input_item_pool_data=input_item_pool_data,
        input_sequence_info=input_sequence_info,
        input_ratings=input_ratings,
        target_item_ids=target_item_ids,
    )

print(f"Output scores shape: {output_scores.shape}")
print(f"Returned target IDs shape: {returned_target_ids.shape}")

# Targets equal to the padding id ('<unk>') are not real next items, so they are
# excluded from the loss. If every target in a batch is padding, the mean loss is NaN.
criterion = nn.CrossEntropyLoss(ignore_index=PADDING_MOVIE_ID_INT)
loss = criterion(output_scores, returned_target_ids)
print(f"Dummy Loss: {loss.item()}")

Model instantiated with 23737942 parameters.
Transformer D_model: 64

Iterating through DataLoader on device: cpu

Processing batch 1
Output scores shape: torch.Size([32, 86494])
Returned target IDs shape: torch.Size([32])
Dummy Loss: 11.367762565612793
