In [1]:
import numpy as np
import pandas as pd
import os, json, gc, re, random
from tqdm.notebook import tqdm
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import seaborn as sns

import logging
import warnings

# Silence every Python warning for the rest of the session.
warnings.simplefilter("ignore")

# Root logger reports INFO and above; the "transformers" logger is raised to
# WARNING so only its warnings and errors get through.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger(name="transformers")
transformers_logger.setLevel(level=logging.WARNING)

ModuleNotFoundError: No module named 'plotly'

In [6]:
import torch, transformers, tokenizers

# Display the installed versions of the deep-learning stack as a tuple
# (torch, transformers, tokenizers) so the run can be reproduced later.
tuple(pkg.__version__ for pkg in (torch, transformers, tokenizers))

('1.13.0+cpu', '4.22.2', '0.12.1')

In [11]:
# Read the training CSV (path is relative to the notebook's working directory).
df = pd.read_csv(filepath_or_buffer='kaggle_movie_train.csv')

# Preview the first five rows.
df.head(n=5)

Unnamed: 0,id,text,genre
0,0,"eady dead, maybe even wishing he was. INT. 2ND...",thriller
1,2,"t, summa cum laude and all. And I'm about to l...",comedy
2,3,"up Come, I have a surprise.... She takes him ...",drama
3,4,ded by the two detectives. INT. JEFF'S APARTME...,thriller
4,5,"nd dismounts, just as the other children reach...",drama


In [12]:
# A genre is kept only if it has more than MIN_GENRE_COUNT rows, and every kept
# genre is capped at MAX_ROWS_PER_GENRE rows so no single class dominates.
MIN_GENRE_COUNT = 200
MAX_ROWS_PER_GENRE = 400

# Filter the value_counts() Series directly. The previous
# `.reset_index(name="count")[...]["index"]` lookup raises KeyError on
# pandas >= 2.0, where the reset column is named after the Series ("genre")
# rather than "index". Ordering (most frequent genre first) is unchanged.
genre_counts = df["genre"].value_counts()
shortlisted_genres = genre_counts[genre_counts > MIN_GENRE_COUNT].index.tolist()
assert shortlisted_genres, "no genre has enough rows to be shortlisted"

df = df[df["genre"].isin(shortlisted_genres)].reset_index(drop=True)
# head() per group keeps the first rows of each genre in their original order.
df = df.groupby("genre").head(MAX_ROWS_PER_GENRE).reset_index(drop=True)

# Map genre names to integer ids; label_encoder is reused later to map
# predictions back to genre names.
label_encoder = LabelEncoder()
df["genre_encoded"] = label_encoder.fit_transform(df["genre"].tolist())

movies_df = df[["text", "genre", "genre_encoded"]]
movies_df

Unnamed: 0,text,genre,genre_encoded
0,"eady dead, maybe even wishing he was. INT. 2ND...",thriller,6
1,"t, summa cum laude and all. And I'm about to l...",comedy,1
2,"up Come, I have a surprise.... She takes him ...",drama,2
3,ded by the two detectives. INT. JEFF'S APARTME...,thriller,6
4,"nd dismounts, just as the other children reach...",drama,2
...,...,...,...
2665,tands naked at the bedside. JULIEN stands besi...,other,4
2666,tang and Cadillac follow.. ACROSS THE PLAZA Th...,other,4
2667,by them ... until she comes to one showing Pa...,other,4
2668,We shoulda! KITTLE Stop blaming yourself. Nena...,other,4


In [14]:
%%time

from simpletransformers.classification import ClassificationModel

model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "save_model_every_epoch": False,
    "save_eval_checkpoints": False,
    "max_seq_length": 512,
    "train_batch_size": 16,
    "num_train_epochs": 4,
}

# Create a ClassificationModel
model = ClassificationModel('bert', 'bert-base-cased',use_cuda=False, num_labels=len(shortlisted_genres), args=model_args)

Downloading:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436k [00:00<?, ?B/s]

Wall time: 1min 21s


In [15]:
%%time

train_df, eval_df = train_test_split(movies_df, test_size=0.2, stratify=movies_df["genre"], random_state=42)

# Train the model
model.train_model(train_df[["text", "genre_encoded"]])

# Evaluate the model
result, model_outputs, wrong_predictions = model.eval_model(eval_df[["text", "genre_encoded"]])
print(result)

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.


  0%|          | 0/2136 [00:00<?, ?it/s]

INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_train_bert_512_7_2


Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

Running Epoch 0 of 4:   0%|          | 0/134 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Pick the highest-scoring class for every evaluation row and map the integer
# ids back to genre names.
predicted_genres_encoded = np.argmax(model_outputs, axis=1).tolist()
predicted_genres = list(label_encoder.inverse_transform(predicted_genres_encoded))
eval_gt_labels = eval_df["genre"].tolist()
class_labels = list(label_encoder.classes_)

# confusion_matrix takes (y_true, y_pred) and `labels` must be passed by
# keyword. With the ground truth first, rows are true genres and columns are
# predicted genres, which is what the axis labels below claim.
cf_matrix = confusion_matrix(eval_gt_labels, predicted_genres, labels=class_labels)

fig, ax = plt.subplots(figsize=(22, 18))
# Normalise by the total count so each cell is a fraction of all eval samples.
sns.heatmap(cf_matrix / np.sum(cf_matrix), annot=True, cmap="YlGnBu", ax=ax)
ax.set_xlabel('Predicted Genres', fontsize=20)
ax.set_ylabel('True Genres', fontsize=20)
ax.set_title('Confusion Matrix', fontsize=20)
ax.set_xticklabels(class_labels, rotation=90, fontsize=18)
ax.set_yticklabels(class_labels, rotation=0, fontsize=18)

plt.show()

In [None]:
# Spot-check ten randomly chosen evaluation rows (positions are drawn with
# replacement, so a row may appear more than once).
sample_positions = [random.randint(0, len(eval_df) - 1) for _ in range(10)]
sample_rows = eval_df.iloc[sample_positions]
sample_texts = sample_rows['text'].tolist()
sample_true_genres = sample_rows['genre'].tolist()

# Predict with trained multiclass classification model.
# All samples go through a single predict() call instead of ten one-item calls,
# so feature conversion and model setup happen once.
sample_predictions_encoded, raw_outputs = model.predict(sample_texts)
sample_predicted_genres = label_encoder.inverse_transform(np.array(sample_predictions_encoded))

for text, true_genre, predicted_genre in zip(sample_texts, sample_true_genres, sample_predicted_genres):
    print('\nTrue Genre:'.ljust(16, ' '), f'{true_genre}\n')
    print(f'Predicted Genre: {predicted_genre}\n')
    print(f'Plot: {text}\n')
    print("-------------------------------------------")