In [26]:
import pandas as pd
import numpy as np

import plotly.express as px

In [27]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_auc_score, f1_score, classification_report
from catboost import Pool, CatBoostClassifier, cv
from sklearn.model_selection import GridSearchCV

In [28]:
imdb_df = pd.read_csv('data/imdb_encoded_with_topics.csv')

## Data Preparation

In [29]:
imdb_df.columns

Index(['actor1', 'actor2', 'actor3', 'actor4', 'director_enc', 'action',
       'adult', 'adventure', 'animation', 'biography', 'comedy', 'crime',
       'documentary', 'drama', 'family', 'fantasy', 'film-noir', 'game-show',
       'history', 'horror', 'music', 'musical', 'mystery', 'news',
       'reality-tv', 'romance', 'sci-fi', 'short', 'sport', 'talk-show',
       'thriller', 'unknown', 'war', 'western', 'link', 'genre', 'duration',
       'imdb_rating', 'votes', 'release_start', 'release_month', 'tv_series',
       'title', 'synopsis', 'director', 'actors', 'synopsis_lemmatized',
       'topic'],
      dtype='object')

In [30]:
imdb_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 161602 entries, 0 to 161601
Data columns (total 48 columns):
 #   Column               Non-Null Count   Dtype  
---  ------               --------------   -----  
 0   actor1               161602 non-null  int64  
 1   actor2               161602 non-null  int64  
 2   actor3               161602 non-null  int64  
 3   actor4               161602 non-null  int64  
 4   director_enc         161602 non-null  int64  
 5   action               161602 non-null  int64  
 6   adult                161602 non-null  int64  
 7   adventure            161602 non-null  int64  
 8   animation            161602 non-null  int64  
 9   biography            161602 non-null  int64  
 10  comedy               161602 non-null  int64  
 11  crime                161602 non-null  int64  
 12  documentary          161602 non-null  int64  
 13  drama                161602 non-null  int64  
 14  family               161602 non-null  int64  
 15  fantasy          

In [31]:
model_df = imdb_df.dropna()

In [32]:
model_df.shape

(160422, 48)

In [33]:
px.histogram(model_df, 'imdb_rating')

Binning the countinous variable into groups to be able to achieve classification tasks.

In [34]:
model_df['imdb_rating_cat'] = pd.cut(model_df['imdb_rating'], bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], right=True, labels=False) 
px.histogram(model_df, 'imdb_rating', color = 'imdb_rating_cat')



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



## Model training

In [35]:
imdb_df.columns

Index(['actor1', 'actor2', 'actor3', 'actor4', 'director_enc', 'action',
       'adult', 'adventure', 'animation', 'biography', 'comedy', 'crime',
       'documentary', 'drama', 'family', 'fantasy', 'film-noir', 'game-show',
       'history', 'horror', 'music', 'musical', 'mystery', 'news',
       'reality-tv', 'romance', 'sci-fi', 'short', 'sport', 'talk-show',
       'thriller', 'unknown', 'war', 'western', 'link', 'genre', 'duration',
       'imdb_rating', 'votes', 'release_start', 'release_month', 'tv_series',
       'title', 'synopsis', 'director', 'actors', 'synopsis_lemmatized',
       'topic'],
      dtype='object')

In [36]:
dep_var = f'imdb_rating_cat'
indep_vars = ['genre', 'duration', 'votes', 'release_start', 'release_month', 'tv_series', 'topic',
       'title', 'director', 'actors', 'synopsis_lemmatized']

In [37]:
X = model_df[indep_vars]
y = model_df[dep_var]

In [38]:
X.columns

Index(['genre', 'duration', 'votes', 'release_start', 'release_month',
       'tv_series', 'topic', 'title', 'director', 'actors',
       'synopsis_lemmatized'],
      dtype='object')

In [39]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [40]:
X_train.shape, X_test.shape

((128337, 11), (32085, 11))

In [41]:
cat_features = ['actors', 'director', 'genre']
text_features = ['title', 'synopsis_lemmatized']

model = CatBoostClassifier(iterations=50, loss_function='MultiClass')

grid = {
    'learning_rate': [1, 0.5],
    'depth': [5, 4],
        }

search = GridSearchCV(model, grid, cv=5, scoring='roc_auc_ovr', verbose=10, n_jobs=-1)
search.fit(X_train, y_train, cat_features = cat_features, text_features = text_features)

Fitting 5 folds for each of 4 candidates, totalling 20 fits



One or more of the test scores are non-finite: [       nan 0.82369374 0.81548561 0.8196748 ]



0:	learn: 1.7591427	total: 1.89s	remaining: 1m 32s
1:	learn: 1.6303010	total: 3.57s	remaining: 1m 25s
2:	learn: 1.5451485	total: 5.88s	remaining: 1m 32s
3:	learn: 1.4994254	total: 8.14s	remaining: 1m 33s
4:	learn: 1.4786700	total: 10.2s	remaining: 1m 31s
5:	learn: 1.4661430	total: 12s	remaining: 1m 28s
6:	learn: 1.4580121	total: 13.9s	remaining: 1m 25s
7:	learn: 1.4485435	total: 15.7s	remaining: 1m 22s
8:	learn: 1.4384225	total: 17.9s	remaining: 1m 21s
9:	learn: 1.4366558	total: 19.8s	remaining: 1m 19s
10:	learn: 1.4306725	total: 21.5s	remaining: 1m 16s
11:	learn: 1.4235937	total: 23.5s	remaining: 1m 14s
12:	learn: 1.4223366	total: 25.5s	remaining: 1m 12s
13:	learn: 1.4173215	total: 27.5s	remaining: 1m 10s
14:	learn: 1.4151045	total: 29.2s	remaining: 1m 8s
15:	learn: 1.4121617	total: 31.2s	remaining: 1m 6s
16:	learn: 1.4110258	total: 33s	remaining: 1m 4s
17:	learn: 1.4073166	total: 34.7s	remaining: 1m 1s
18:	learn: 1.4061377	total: 36.2s	remaining: 59.1s
19:	learn: 1.4037924	total: 38.

GridSearchCV(cv=5,
             estimator=<catboost.core.CatBoostClassifier object at 0x0000024C0FB0CF40>,
             n_jobs=-1, param_grid={'depth': [5, 4], 'learning_rate': [1, 0.5]},
             scoring='roc_auc_ovr', verbose=10)

In [42]:
search.score(X_test, y_test)

0.8184286501514796

In [43]:
search.best_score_

0.8236937368690507

In [44]:
import pickle

# save the model
pickle.dump(search, open('models/catboost_encoded_with_topics.sav', 'wb'))