### Loading Libraries

In [1]:
import numpy as np
import polars as pl
import torch
from torch import nn, functional as F

### Getting Datasets

In [None]:
from zipfile import ZipFile
from urllib.request import urlretrieve
import os

urlretrieve("https://files.grouplens.org/datasets/movielens/ml-32m.zip", "../data/movielens.zip")
ZipFile("../data/movielens.zip", "r").extractall('../data/')
os.remove('../data/movielens.zip')

urlretrieve("https://www.kaggle.com/api/v1/datasets/download/asaniczka/tmdb-movies-dataset-2023-930k-movies", "../data/tmdb.zip")
ZipFile("../data/tmdb.zip", "r").extractall('../data/tmdb_dataset/')
os.remove('../data/tmdb.zip')

In [27]:
movies_df = pl.read_csv('../data/tmdb_dataset/TMDB_movie_dataset_v11.csv')
ratings_df = pl.read_csv('../data/ml-32m/ratings.csv')
links_df = pl.read_csv('../data/ml-32m/links.csv')

In [28]:
print("movies_df:",movies_df.schema)
print("ratings_df:",ratings_df.schema)
print("links_df:",links_df.schema)

movies_df: Schema([('id', Int64), ('title', String), ('vote_average', Float64), ('vote_count', Int64), ('status', String), ('release_date', String), ('revenue', Int64), ('runtime', Int64), ('adult', Boolean), ('backdrop_path', String), ('budget', Int64), ('homepage', String), ('imdb_id', String), ('original_language', String), ('original_title', String), ('overview', String), ('popularity', Float64), ('poster_path', String), ('tagline', String), ('genres', String), ('production_companies', String), ('production_countries', String), ('spoken_languages', String), ('keywords', String)])
ratings_df: Schema([('userId', Int64), ('movieId', Int64), ('rating', Float64), ('timestamp', Int64)])
links_df: Schema([('movieId', Int64), ('imdbId', Int64), ('tmdbId', Int64)])


In [29]:
movies_df = links_df.join(movies_df,left_on="tmdbId",right_on="id")
movies_df = movies_df.drop(["title","status","backdrop_path","homepage","imdb_id","imdbId","poster_path"])
movies_df.head()

movieId,tmdbId,vote_average,vote_count,release_date,revenue,runtime,adult,budget,original_language,original_title,overview,popularity,tagline,genres,production_companies,production_countries,spoken_languages,keywords
i64,i64,f64,i64,str,i64,i64,bool,i64,str,str,str,f64,str,str,str,str,str,str
79132,27205,8.364,34495,"""2010-07-15""",825532764,148,False,160000000,"""en""","""Inception""","""Cobb, a skilled thief who comm…",83.952,"""Your mind is the scene of the …","""Action, Science Fiction, Adven…","""Legendary Pictures, Syncopy, W…","""United Kingdom, United States …","""English, French, Japanese, Swa…","""rescue, mission, dream, airpla…"
109487,157336,8.417,32571,"""2014-11-05""",701729206,169,False,165000000,"""en""","""Interstellar""","""The adventures of a group of e…",140.241,"""Mankind was born on Earth. It …","""Adventure, Drama, Science Fict…","""Legendary Pictures, Syncopy, L…","""United Kingdom, United States …","""English""","""rescue, future, spacecraft, ra…"
58559,155,8.512,30619,"""2008-07-16""",1004558444,152,False,185000000,"""en""","""The Dark Knight""","""Batman raises the stakes in hi…",130.643,"""Welcome to a world without rul…","""Drama, Action, Crime, Thriller""","""DC Comics, Legendary Pictures,…","""United Kingdom, United States …","""English, Mandarin""","""joker, sadism, chaos, secret i…"
72998,19995,7.573,29815,"""2009-12-15""",2923706026,162,False,237000000,"""en""","""Avatar""","""In the 22nd century, a paraple…",79.932,"""Enter the world of Pandora.""","""Action, Adventure, Fantasy, Sc…","""Dune Entertainment, Lightstorm…","""United States of America, Unit…","""English, Spanish""","""future, society, culture clash…"
89745,24428,7.71,29166,"""2012-04-25""",1518815515,143,False,220000000,"""en""","""The Avengers""","""When an unexpected enemy emerg…",98.082,"""Some assembly required.""","""Science Fiction, Action, Adven…","""Marvel Studios""","""United States of America""","""English, Hindi, Russian""","""new york city, superhero, shie…"


In [30]:
# convert ids to categorical or string data type
movies_df = movies_df.with_columns([
    (pl.col("movieId").cast(str)).map_elements(lambda x: f"movie_{x}",return_dtype=pl.String).alias("movieId"),
    (pl.col("revenue")/100000),
    (pl.col("budget")/100000)
])

ratings_df = ratings_df.with_columns([
    (pl.col("movieId").cast(str)).map_elements(lambda x: f"movie_{x}",return_dtype=pl.String).alias("movieId"),
    (pl.col("userId").cast(str)).map_elements(lambda x: f"user_{x}",return_dtype=pl.String).alias("userId"),
])

### Building vocab

In [31]:
from collections import Counter
from torchtext.vocab import vocab

# vocab for movie_ids
movie_ids = movies_df['movieId'].unique()
movie_counter = Counter(movie_ids)
movie_vocab = vocab(movie_counter, specials=['<unk>'])
movie_vocab_stoi = movie_vocab.get_stoi()
movie_title_dict = dict(zip(movies_df['movieId'].to_list(), movies_df['original_title'].to_list()))

# vocab for user_ids
user_ids = ratings_df['userId'].unique()
user_counter = Counter(user_ids)
user_vocab = vocab(user_counter, specials=['<unk>'])
user_vocab_stoi = user_vocab.get_stoi()

# vocab for genres
movies_df = movies_df.with_columns(pl.col('genres').map_elements(lambda x: [s.strip() for s in x.split(',')],return_dtype=pl.List(pl.String)))
genres_counter = Counter(movies_df['genres'].explode().unique())
genres_vocab = vocab(genres_counter, specials=['<unk>'])
genres_vocab_stoi = genres_vocab.get_stoi()
genres_vocab_itos = genres_vocab.get_itos()

# vocab for production companies
movies_df = movies_df.with_columns(pl.col('production_companies').map_elements(lambda x: [s.strip() for s in x.split(',')],return_dtype=pl.List(pl.String)))
prod_comp_counter = Counter(movies_df['production_companies'].explode().unique())
prod_comp_vocab = vocab(prod_comp_counter, specials=['<unk>'])
prod_comp_vocab_stoi = prod_comp_vocab.get_stoi()
prod_comp_vocab_itos = prod_comp_vocab.get_itos()

# vocab for production countries
movies_df = movies_df.with_columns(pl.col('production_countries').map_elements(lambda x: [s.strip() for s in x.split(',')],return_dtype=pl.List(pl.String)))
prod_countries_counter = Counter(movies_df['production_countries'].explode().unique())
prod_countries_vocab = vocab(prod_countries_counter, specials=['<unk>'])
prod_countries_vocab_stoi = prod_countries_vocab.get_stoi()
prod_countries_vocab_itos = prod_countries_vocab.get_itos()

# vocab for spoken languages
movies_df = movies_df.with_columns(pl.col('spoken_languages').map_elements(lambda x: [s.strip() for s in x.split(',')],return_dtype=pl.List(pl.String)))
languages_counter = Counter(movies_df['spoken_languages'].explode().unique())
languages_vocab = vocab(languages_counter, specials=['<unk>'])
languages_vocab_stoi = languages_vocab.get_stoi()
languages_vocab_itos = languages_vocab.get_itos()

# vocab for words
movies_df = movies_df.with_columns([
    pl.col('keywords').map_elements(lambda x: [s.strip() for s in x.split(',')],return_dtype=pl.List(pl.String)),
    pl.col('overview').map_elements(lambda x: [s.strip() for s in x.split(' ')],return_dtype=pl.List(pl.String)),
    pl.col('tagline').map_elements(lambda x: [s.strip() for s in x.split(' ')],return_dtype=pl.List(pl.String))
])

words_counter = Counter(set(movies_df['keywords'].explode().unique().to_list() + movies_df['overview'].explode().unique().to_list() + movies_df['tagline'].explode().unique().to_list()))
words_vocab = vocab(words_counter, specials=['<unk>'])
words_vocab_stoi = words_vocab.get_stoi()
words_vocab_itos = words_vocab.get_itos()

In [33]:
print("Number of user: ",len(user_vocab))
print("Number of movies: ",len(movie_vocab))
print("Number of genres: ",len(genres_vocab))
print("Number of production_companies: ",len(prod_comp_vocab))
print("Number of production_countries: ",len(prod_countries_vocab))
print("Number of spoken_languages: ",len(languages_vocab))
print("Number of words: ",len(words_vocab))

Number of user:  200949
Number of movies:  86494
Number of genres:  21
Number of production_companies:  45546
Number of production_countries:  201
Number of spoken_languages:  164
Number of words:  270246


### Using timestamp to generate sequences

In [34]:
sequence_length = 5
step_size = 2

ratings_df = ratings_df.sort(["userId", "timestamp"])
ratings_df = ratings_df.with_columns([
    pl.col("userId").cum_count().over("userId").alias("idx")
])
for i in range(sequence_length):
    ratings_df = ratings_df.with_columns([
        pl.col("movieId").shift(-i).over("userId").alias(f"movie_{i}"),
        pl.col("rating").shift(-i).over("userId").alias(f"rating_{i}")
    ])

ratings_df = ratings_df.filter(pl.col("idx") % step_size == 0)
ratings_df = ratings_df.filter(
    pl.fold(
        acc=pl.lit(True),
        function=lambda acc, x: acc & x.is_not_null(),
        exprs=[pl.col(f"movie_{i}") for i in range(sequence_length)] +
              [pl.col(f"rating_{i}") for i in range(sequence_length)]
    )
)
ratings_df = ratings_df.with_columns([
    pl.concat_list([pl.col(f"movie_{i}") for i in range(sequence_length)]).alias("sequence"),
    pl.concat_list([pl.col(f"rating_{i}") for i in range(sequence_length)]).alias("rating_sequence")
])
ratings_df = ratings_df.select(["userId", "sequence", "rating_sequence"]).group_by(pl.col("userId"),maintain_order=True).agg(pl.col("sequence"),pl.col("rating_sequence"))

ratings_df = ratings_df.explode(["sequence", "rating_sequence"]).rename({
    "sequence": "sequence_movie_ids",
    "rating_sequence": "sequence_ratings"
})

In [35]:
ratings_df

userId,sequence_movie_ids,sequence_ratings
str,list[str],list[f64]
"""user_1""","[""movie_2966"", ""movie_2890"", … ""movie_541""]","[1.0, 4.0, … 5.0]"
"""user_1""","[""movie_3078"", ""movie_2882"", … ""movie_1136""]","[2.0, 1.0, … 1.0]"
"""user_1""","[""movie_541"", ""movie_838"", … ""movie_1211""]","[5.0, 5.0, … 2.0]"
"""user_1""","[""movie_1136"", ""movie_1236"", … ""movie_2396""]","[1.0, 4.0, … 5.0]"
"""user_1""","[""movie_1211"", ""movie_3030"", … ""movie_232""]","[2.0, 4.0, … 5.0]"
…,…,…
"""user_99999""","[""movie_3864"", ""movie_7842"", … ""movie_4733""]","[3.5, 2.5, … 2.5]"
"""user_99999""","[""movie_53894"", ""movie_42738"", … ""movie_53953""]","[4.5, 3.0, … 3.0]"
"""user_99999""","[""movie_4733"", ""movie_4030"", … ""movie_55820""]","[2.5, 3.0, … 4.5]"
"""user_99999""","[""movie_53953"", ""movie_63113"", … ""movie_8131""]","[3.0, 4.0, … 4.5]"


### Train Test Split

In [36]:
random_selection = np.random.rand(len(ratings_df)) <= 0.85

df_train_data = ratings_df.filter(random_selection)
train_data_raw = df_train_data[["userId", "sequence_movie_ids", "sequence_ratings"]].to_numpy()

df_test_data = ratings_df.filter(~random_selection)
test_data_raw = df_test_data[["userId", "sequence_movie_ids", "sequence_ratings"]].to_numpy()

### Creating DataLoaders

In [37]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class MovieSeqDataset(Dataset):
    def __init__(self, data, movie_vocab_stoi, user_vocab_stoi):
        self.data = data
        self.movie_vocab_stoi = movie_vocab_stoi
        self.user_vocab_stoi = user_vocab_stoi
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        user, movie_sequence, rating_sequence = self.data[idx]
        movie_data = [self.movie_vocab_stoi.get(item,movie_vocab_stoi['<unk>']) for item in movie_sequence]
        user_data = self.user_vocab_stoi[user]
        return torch.tensor(movie_data), torch.tensor(user_data), torch.tensor(rating_sequence)
    
def collate_batch(batch):
    movie_list = [item[0] for item in batch]
    user_list = [item[1] for item in batch]
    rating_list = [item[2] for item in batch]
    return pad_sequence(movie_list, padding_value=movie_vocab_stoi['<unk>'], batch_first=True), torch.stack(user_list), pad_sequence(rating_list, padding_value=3, batch_first=True)

In [38]:
BATCH_SIZE = 16

train_dataset = MovieSeqDataset(train_data_raw, movie_vocab_stoi, user_vocab_stoi)
val_dataset = MovieSeqDataset(test_data_raw, movie_vocab_stoi, user_vocab_stoi)

train_iter = DataLoader(train_dataset, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)
val_iter = DataLoader(val_dataset, batch_size=BATCH_SIZE,shuffle=False, collate_fn=collate_batch)

In [40]:
for i, (movie_data, user_data, ratings_data) in enumerate(train_iter):
    print(movie_data.shape, user_data.shape, ratings_data.shape)
    break

torch.Size([16, 5]) torch.Size([16]) torch.Size([16, 5])


In [41]:
from polars import col, List, Int64

def list_to_idx(c, stoi):
    unk = stoi.get("<unk>", -1)
    return c.map_elements(lambda lst: [stoi.get(x, unk) for x in lst],
                          return_dtype=List(Int64))

movies_df = movies_df.with_columns([
    list_to_idx(col("genres"), genres_vocab_stoi).alias("genres_idx"),
    list_to_idx(col("production_companies"), prod_comp_vocab_stoi).alias("production_companies_idx"),
    list_to_idx(col("production_countries"), prod_countries_vocab_stoi).alias("production_countries_idx"),
    list_to_idx(col("spoken_languages"), languages_vocab_stoi).alias("spoken_languages_idx"),
    list_to_idx(col("keywords"), words_vocab_stoi).alias("keywords_idx"),
    list_to_idx(col("overview"), words_vocab_stoi).alias("overview_idx"),
    list_to_idx(col("tagline"), words_vocab_stoi).alias("tagline_idx"),
])


In [42]:
unk_mid = movie_vocab_stoi.get("<unk>", -1)
movies_df = movies_df.with_columns(
    col("movieId")
      .map_elements(lambda x: movie_vocab_stoi.get(x, unk_mid), return_dtype=Int64)
      .alias("movieId_idx")
)

movies_df = movies_df.with_columns(
    col("adult").cast(Int64).alias("adult_idx")
)

movies_df = movies_df.with_columns([
    col("vote_average").cast(pl.Float64),
    col("vote_count").cast(pl.Int64),
    col("revenue").cast(pl.Float64),
    col("runtime").cast(pl.Int64),
    col("budget").cast(pl.Float64),
    col("popularity").cast(pl.Float64),
])


In [43]:
movies_prepped = movies_df.select([
    "movieId_idx",
    "genres_idx",
    "production_companies_idx",
    "production_countries_idx",
    "spoken_languages_idx",
    "keywords_idx",
    "overview_idx",
    "tagline_idx",
    "adult_idx",
    "vote_average", 
    "vote_count",
    "revenue",
    "runtime",
    "budget",
    "popularity",
])

movies_prepped.head()

movieId_idx,genres_idx,production_companies_idx,production_countries_idx,spoken_languages_idx,keywords_idx,overview_idx,tagline_idx,adult_idx,vote_average,vote_count,revenue,runtime,budget,popularity
i64,list[i64],list[i64],list[i64],list[i64],list[i64],list[i64],list[i64],i64,f64,i64,f64,i64,f64,f64
60004,"[1, 6, 11]","[38600, 44135, 33646]","[159, 36]","[35, 99, … 144]","[243811, 8740, … 178817]","[187595, 109853, … 156388]","[47481, 213059, … 260989]",0,8.364,34495,8255.32764,148,1600.0,83.952
29815,"[11, 2, 6]","[38600, 44135, 35673]","[159, 36]",[35],"[243811, 251008, … 137131]","[115676, 122952, … 172657]","[68669, 266033, … 174268]",0,8.417,32571,7017.29206,169,1650.0,140.241
29132,"[2, 1, … 16]","[33815, 38600, … 33646]","[159, 36]","[35, 48]","[43159, 113706, … 140036]","[204544, 64010, … 144770]","[44438, 86722, … 171955]",0,8.512,30619,10045.58444,152,1850.0,130.643
2832,"[1, 11, … 6]","[18183, 9124, … 40890]","[36, 159]","[35, 65]","[251008, 254205, … 43486]","[234347, 267437, … 185792]","[106904, 267437, … 156736]",0,7.573,29815,29237.06026,162,2370.0,79.932
32424,"[6, 1, 11]",[35908],[36],"[35, 114, 20]","[166015, 205486, … 160073]","[19736, 201312, … 178020]","[244183, 4493, 55910]",0,7.71,29166,15188.15515,143,2200.0,98.082
