# Models

In [24]:
import pandas as pd
import numpy as np
from sklearn.naive_bayes import MultinomialNB
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn import metrics

In [4]:
# read in data
imdb_movie = pd.read_csv('data/imdb_multilabel.csv')

In [5]:
imdb_movie.head()

Unnamed: 0.1,Unnamed: 0,title,imdb_id,topRank,bottomRank,metaScore,plot,rating,ratingCount,reviewCount,runningTimeInMinutes,userRatingCount,userScore,year,all_genre,genre,plot_list,genreCount,genre_code,all_genre_encode
0,0,"I, Tonya",tt5580036,930.0,17643.0,77.0,From the proverbial wrong side of the tracks i...,7.6,67667.0,46,120.0,235,7.8,2017.0,"['Biography', 'Comedy', 'Drama', 'Sport']",sport,"['From', 'the', 'proverbial', 'wrong', 'side',...",4,0,[1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
1,1,Cars 3,tt3606752,2256.0,11547.0,59.0,Blindsided by a new generation of blazing-fast...,6.8,41896.0,41,102.0,232,6.9,2017.0,"['Animation', 'Adventure', 'Comedy', 'Family',...",sport,"['Blindsided', 'by', 'a', 'new', 'generation',...",5,0,[1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. ...
2,2,Creed,tt3076658,847.0,17840.0,82.0,Adonis Johnson is the son of the famous boxing...,7.6,193206.0,42,133.0,614,8.0,2015.0,"['Drama', 'Sport']",sport,"['Adonis', 'Johnson', 'is', 'the', 'son', 'of'...",2,0,[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
3,3,Battle of the Sexes,tt4622512,2303.0,11228.0,73.0,In the wake of the sexual revolution and the r...,6.8,27960.0,46,121.0,102,6.3,2017.0,"['Biography', 'Comedy', 'Drama', 'Sport']",sport,"['In', 'the', 'wake', 'of', 'the', 'sexual', '...",4,0,[1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
4,4,Borg McEnroe,tt5727282,,12891.0,57.0,The story of the 1980s tennis rivalry between ...,7.0,9800.0,13,107.0,0,,2017.0,"['Biography', 'Drama', 'Sport']",sport,"['The', 'story', 'of', 'the', '1980s', 'tennis...",3,0,[1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...


In [25]:
imdb_movie["all_genre_encode"] = np.array(imdb_movie["all_genre_encode"])

In [26]:
imdb_movie["all_genre_encode"]

0       [1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
1       [1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. ...
2       [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
3       [1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
4       [1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
5       [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. ...
6       [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
7       [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
8       [1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
9       [1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. ...
10      [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
11      [1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
12      [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
13      [1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. ...
14      [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
15      [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
16      [1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ...
17      [1. 0.

## Baseline Model - Naive Bayes 

In [27]:
naive_X_train, naive_X_test, naive_y_train, naive_y_test = train_test_split(imdb_movie['plot_list'], 
                                                                            imdb_movie['all_genre_encode'],
                                                                            test_size = 0.2,
                                                                            random_state = 209,
                                                                            )

In [28]:
naive_X_train.head()

1118    ['When', 'a', 'meteorite', 'from', 'outer', 's...
2405    ['Details', 'the', 'story', 'of', 'a', 'sober'...
2128    ['After', 'Custer', 'and', 'the', '7th', 'Cava...
2067    ['Armando', 'Alvarez', '(Will', 'Ferrell)', 'h...
3089    ['300', 'years', 'have', 'passed', 'since', 't...
Name: plot_list, dtype: object

In [29]:
vectorizer = CountVectorizer(stop_words='english')

In [30]:
naive_X_train = vectorizer.fit_transform(naive_X_train)
naive_X_test = vectorizer.transform(naive_X_test)

In [31]:
nb_tuning_parameter = {'alpha' : [0.1, 0.5, 1, 2, 5]}
nb = GridSearchCV(MultinomialNB(), nb_tuning_parameter, cv = 5)
nb.fit(naive_X_train, naive_y_train)



GridSearchCV(cv=5, error_score='raise',
       estimator=MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True),
       fit_params={}, iid=True, n_jobs=1,
       param_grid={'alpha': [0.1, 0.5, 1, 2, 5]}, pre_dispatch='2*n_jobs',
       refit=True, return_train_score=True, scoring=None, verbose=0)

In [32]:
naive_train_pred = nb.predict(naive_X_train)
naive_test_pred = nb.predict(naive_X_test)
print('Naive Bayes Accuracy on Train : {}'.format(metrics.accuracy_score(naive_y_train, naive_train_pred)))
print('Naive Bayes Accuracy on Test : {}'.format(metrics.accuracy_score(naive_y_test, naive_test_pred)))

Naive Bayes Accuracy on Train : 0.9720921155347385
Naive Bayes Accuracy on Test : 0.14274570982839313


In [47]:
nb.predict_proba(naive_X_train)

array([[  6.92734262e-036,   6.06975446e-041,   4.72004043e-033, ...,
          2.96255265e-079,   3.39740821e-077,   3.47149655e-079],
       [  9.23438585e-003,   7.15604558e-005,   1.33974719e-001, ...,
          1.11433666e-008,   1.98348992e-008,   1.13214352e-008],
       [  1.01364204e-049,   2.41129329e-044,   1.34541689e-052, ...,
          5.03626124e-055,   2.96372666e-055,   5.67855760e-055],
       ..., 
       [  1.72362417e-223,   5.06201350e-229,   3.46215857e-256, ...,
          0.00000000e+000,   0.00000000e+000,   0.00000000e+000],
       [  2.52941066e-024,   4.75596796e-027,   9.99780204e-028, ...,
          1.60387967e-048,   2.38570165e-046,   1.77194134e-048],
       [  4.81791655e-024,   2.19391180e-023,   1.98801009e-024, ...,
          2.02869105e-032,   1.53307635e-032,   2.16150467e-032]])

In [42]:
naive_train_pred[4]

'[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0.]'

In [46]:
naive_y_train.iloc[4]

'[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0.]'