### Importing Modules

In [54]:
import torch
from torch import nn
import numpy as np
import polars as pl
import math

### Positional Encoding for movie sequence

In [45]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of sequences, then apply dropout.

    For position ``pos`` and channel pair ``k``::

        PE[pos, 2k]   = sin(pos / 10000^(2k / d_model))
        PE[pos, 2k+1] = cos(pos / 10000^(2k / d_model))

    Args:
        d_model: Embedding width of the input (odd widths are supported).
        dropout: Dropout probability applied after the encoding is added.
        max_len: Longest sequence length the precomputed table covers.
        batch_first: If False (default, original behaviour) ``forward`` expects
            ``(seq_len, batch, d_model)``; if True it expects
            ``(batch, seq_len, d_model)`` such as the output of ``nn.Embedding``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000,
                 batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first
        positions = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2k / d_model), written as exp(-2k * ln(10000) / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        # Stored as (max_len, 1, d_model) so it broadcasts over the batch dim of seq-first input.
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(positions * div_term)
        # With an odd d_model there is one fewer cosine channel than sine channel.
        pe[:, 0, 1::2] = torch.cos(positions * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + PE)`` with the same shape as ``x``.

        The sequence axis is dim 0 when ``batch_first`` is False and dim 1 otherwise.
        """
        if self.batch_first:
            # (seq, 1, d_model) -> (1, seq, d_model) so it broadcasts over the leading batch dim.
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    

In [53]:
# `nn.Embedding` returns batch-first (batch, seq, d_model) tensors, while
# PositionalEncoding indexes its table along dim 0, i.e. it expects seq-first
# (seq, batch, d_model) input. Feeding the batch-first tensor directly would add
# only the position-0 encoding to every token, so transpose in and back out.
pe = PositionalEncoding(2)
x = torch.tensor([[0, 0, 0, 0, 0]])            # token ids, (batch=1, seq=5)
print(x.shape)
e = nn.Embedding(10, 2)
y_e = e(x)                                     # (1, 5, 2)
print(y_e.shape)
y = pe(y_e.transpose(0, 1)).transpose(0, 1)    # encode as (5, 1, 2), return to (1, 5, 2)
assert y.shape == y_e.shape
print(y.shape)

torch.Size([1, 5])
torch.Size([1, 5, 2])
torch.Size([1, 5, 2])


### Generate embeddings for batch of movies

In [56]:
import torch 
from torch import nn
from typing import Tuple

class MovieEmbeddings(nn.Module):
    """Encode a batch of movies into fixed-size vectors.

    Each variable-length list column is mean-pooled with an ``nn.EmbeddingBag``.
    The list columns are genres, production companies, production countries,
    spoken languages, and the keyword/tagline/overview word indices. The keyword,
    tagline and overview columns share one word table. The pooled vectors are
    concatenated with seven scalar columns and projected to ``hidden_size`` by a
    linear layer.

    Args:
        d_model: Base embedding width. Genres use ``2*d_model``, words use
            ``4*d_model``, and the remaining list features use ``d_model``.
        hidden_size: Output width of the final projection.
        num_list_features: Number of list-valued columns; must equal the 7
            list columns ``forward`` reads.
        num_scalar_features: Number of scalar columns; must equal the 7 scalar
            columns ``forward`` reads.
        n_genres, n_production_companies, n_production_countries,
        n_spoken_languages, n_words: Vocabulary sizes of the embedding tables.

    Raises:
        ValueError: if ``num_list_features`` or ``num_scalar_features`` does not
            match the columns consumed by ``forward``.
    """

    # (column name, embedding attribute), in concatenation order.
    _LIST_COLUMNS = (
        ("genres_idx", "genres_embedding"),
        ("production_companies_idx", "prod_comp_embedding"),
        ("production_countries_idx", "prod_cont_embedding"),
        ("spoken_languages_idx", "lang_embedding"),
        ("keywords_idx", "word_embedding"),
        ("tagline_idx", "word_embedding"),
        ("overview_idx", "word_embedding"),
    )
    # Scalar columns, in concatenation order (after all pooled list features).
    _SCALAR_COLUMNS = (
        "revenue",
        "budget",
        "runtime",
        "adult_idx",
        "vote_average",
        "vote_count",
        "popularity",
    )

    def __init__(self, 
                 d_model: int,
                 hidden_size: int,
                 num_list_features: int,
                 num_scalar_features: int,
                 n_genres: int, 
                 n_production_companies: int,
                 n_production_countries: int,
                 n_spoken_languages: int,
                 n_words: int):
        super().__init__()
        if num_list_features != len(self._LIST_COLUMNS):
            raise ValueError(
                f"num_list_features={num_list_features}, but forward consumes "
                f"{len(self._LIST_COLUMNS)} list columns"
            )
        if num_scalar_features != len(self._SCALAR_COLUMNS):
            raise ValueError(
                f"num_scalar_features={num_scalar_features}, but forward consumes "
                f"{len(self._SCALAR_COLUMNS)} scalar columns"
            )
        self.genres_embedding = nn.EmbeddingBag(n_genres, d_model*2, mode='mean')
        self.prod_comp_embedding = nn.EmbeddingBag(n_production_companies, d_model, mode='mean')
        self.prod_cont_embedding = nn.EmbeddingBag(n_production_countries, d_model, mode='mean')
        self.lang_embedding = nn.EmbeddingBag(n_spoken_languages, d_model, mode='mean')
        self.word_embedding = nn.EmbeddingBag(n_words, d_model*4, mode='mean')
        # Derive the fc input width from the actual bag widths instead of a
        # hand-tuned formula, so it always matches the tensor built in forward.
        pooled_width = sum(getattr(self, attr).embedding_dim for _, attr in self._LIST_COLUMNS)
        self.fc = nn.Linear(pooled_width + num_scalar_features, hidden_size)
        self._init_weights()

    def _init_weights(self) -> None:
        """Xavier-uniform init for every embedding table and the fc weight; zero fc bias."""
        nn.init.xavier_uniform_(self.genres_embedding.weight)
        nn.init.xavier_uniform_(self.prod_comp_embedding.weight)
        nn.init.xavier_uniform_(self.prod_cont_embedding.weight)
        nn.init.xavier_uniform_(self.lang_embedding.weight)
        nn.init.xavier_uniform_(self.word_embedding.weight)
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def _prepare_embedding_inputs(self, list_of_lists, device=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Flatten ragged index lists into the ``(input, offsets)`` pair ``nn.EmbeddingBag`` takes.

        Args:
            list_of_lists: One iterable of integer indices per movie. ``None``
                (a null list cell) is treated as an empty bag.
            device: Device on which to create the returned tensors.

        Returns:
            ``(flat, offsets)``: ``flat`` is a 1-D long tensor of all indices
            concatenated. ``offsets[i]`` is the start of movie ``i``'s bag in ``flat``.
        """
        flat = []
        offsets = []
        for sublist in list_of_lists:
            offsets.append(len(flat))
            if sublist is not None:
                flat.extend(sublist)
        return (
            torch.tensor(flat, dtype=torch.long, device=device),
            torch.tensor(offsets, dtype=torch.long, device=device),
        )

    def forward(self, row: "pl.DataFrame") -> torch.Tensor:
        """Return a ``(batch, hidden_size)`` encoding of the movies in ``row``.

        ``row`` only needs to support ``row[column]`` for every column in
        ``_LIST_COLUMNS`` and ``_SCALAR_COLUMNS``, one entry per movie.
        """
        # Build inputs where the weights live so the module works after .to(device).
        device = self.fc.weight.device

        pooled = []
        for column, attr in self._LIST_COLUMNS:
            indices, offsets = self._prepare_embedding_inputs(row[column], device=device)
            pooled.append(getattr(self, attr)(indices, offsets))

        # NOTE(review): scalar columns are fed in as-is. If they are raw values
        # (e.g. revenue/budget in currency units), their magnitudes will dominate
        # the fc input. Consider log/standard scaling upstream — TODO confirm preprocessing.
        scalars = []
        for column in self._SCALAR_COLUMNS:
            # adult_idx is read as a boolean flag (any non-zero -> 1.0), as before.
            dtype = torch.bool if column == "adult_idx" else torch.float32
            values = torch.tensor(row[column], dtype=dtype, device=device).to(torch.float32)
            scalars.append(values.unsqueeze(1))  # (batch, 1)

        master_embedding = torch.cat(pooled + scalars, dim=1)
        return self.fc(master_embedding)

In [57]:
# Vocabulary sizes of each embedding table (presumably max index + 1 from the
# preprocessing step — TODO confirm against the pipeline that wrote output.parquet).
len_genres = 21
len_prod_comp = 45546
len_prod_cont = 201
len_langs = 164
len_words = 270246
movies_prepped = pl.read_parquet('../data/processed/output.parquet')

torch.manual_seed(0)  # make the randomly initialised weights, and therefore this output, reproducible
me = MovieEmbeddings(
    d_model=16,
    hidden_size=256,
    num_list_features=7,
    num_scalar_features=7,
    n_genres=len_genres,
    n_production_companies=len_prod_comp,
    n_production_countries=len_prod_cont,
    n_spoken_languages=len_langs,
    n_words=len_words,
)
print(f"Number of parameters: {sum(p.numel() for p in me.parameters() if p.requires_grad)}")
y = me(movies_prepped[:10])
assert y.shape == (10, 256), y.shape
print(y.shape)
print(y)

Number of parameters: 18102672
torch.Size([10, 256])
tensor([[-1831.9744,  -866.2841, -2882.7822,  ...,   242.2773, -3159.7954,
         -2309.8333],
        [-1665.1652,  -841.4111, -2762.7041,  ...,   210.1431, -3007.7551,
         -2248.9331],
        [-1839.4342,  -632.6804, -2498.7393,  ...,   245.2319, -2732.8325,
         -1841.5848],
        ...,
        [-1010.3897,  -944.6721, -2421.1853,  ...,   116.8286, -2649.2502,
         -2254.8867],
        [-1520.3438,  -588.8050, -2210.5342,  ...,   201.7019, -2393.4956,
         -1680.5638],
        [-1059.3677,  -875.0412, -2221.3069,  ...,   133.6470, -2497.7415,
         -2035.9862]], grad_fn=<AddmmBackward0>)
