In [1]:
#!kill -9 -1

In [2]:
import numpy as np
import pandas as pd
import os, json, gc, re, random
from tqdm.notebook import tqdm
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns

import warnings
import logging

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings(action="ignore")

# Root logger reports INFO and above; the "transformers" logger is limited
# to WARNING and above so it does not flood the cell outputs.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger(name="transformers")
transformers_logger.setLevel(level=logging.WARNING)

In [3]:

#%%time


!pip uninstall -q torch -y
# 
!python3 -m pip install -q torch==1.6.0 -f https://download.pytorch.org/whl/torch_stable.html
!python3 -m pip install -q -U tokenizers==0.7.0 > /dev/null
!python3 -m pip install -q -U transformers==3.0.2 > /dev/null
!python3 -m pip install -q -U simpletransformers==0.46.0 > /dev/null


[K     |████████████████████████████████| 552.8MB 29kB/s 
[31mERROR: torchvision 0.9.1+cu101 has requirement torch==1.8.1, but you'll have torch 1.6.0+cu92 which is incompatible.[0m
[31mERROR: torchtext 0.9.1 has requirement torch==1.8.1, but you'll have torch 1.6.0+cu92 which is incompatible.[0m
[?25h[31mERROR: torchtext 0.9.1 has requirement torch==1.8.1, but you'll have torch 1.6.0+cu92 which is incompatible.[0m


In [4]:
#!pip install torch==1.6.0 -f https://download.pytorch.org/whl/torch_stable.html

In [5]:
# Import the deep-learning stack and display the versions the kernel actually
# loaded, so any mismatch with the versions requested at install time is visible.
import torch, transformers, tokenizers
torch.__version__, transformers.__version__, tokenizers.__version__



('1.6.0+cu92', '3.0.2', '0.8.1.rc1')

In [6]:
# Location of the movie-plots CSV on the mounted Drive.
MOVIE_PLOTS_CSV = "/content/drive/MyDrive/Colab Notebooks/LTP/wiki_movie_plots_deduped.csv"

# Load the raw dataset into a DataFrame.
movies_df = pd.read_csv(MOVIE_PLOTS_CSV)

In [7]:
# Rows per genre in the raw data: non-null counts of every column grouped by
# genre, keeping only the first column of that table.
counts_by_genre = movies_df.groupby("Genre").count()
genre_counts = counts_by_genre.iloc[:, 0]


In [8]:
# Data pre-processing: filter the raw plots, merge near-duplicate genres, keep
# only frequent genres, balance the classes and integer-encode the labels.
#
# NOTE: this cell rebinds `movies_df` to a frame without "Origin/Ethnicity",
# so re-running it requires re-running the cell that loads the CSV first.

# Keep American/British films whose genre is not the placeholder "unknown",
# and only the two columns the classifier needs. `.copy()` yields an
# independent frame so the assignments below never write into a view.
is_us_or_uk = movies_df["Origin/Ethnicity"].isin(["American", "British"])
has_known_genre = movies_df["Genre"] != "unknown"
movies_df = movies_df.loc[is_us_or_uk & has_known_genre, ["Plot", "Genre"]].copy()

# Combine genres: 1) "sci-fi" with "science fiction" & 2) "romantic comedy" with "romance".
# Assign the result back instead of using `inplace=True` on a column selection,
# which is chained assignment and may silently leave the frame unchanged.
movies_df["Genre"] = movies_df["Genre"].replace(
    {"sci-fi": "science fiction", "romantic comedy": "romance"}
)

# Choose genres by frequency. Filtering the value_counts Series directly avoids
# depending on the column name produced by `reset_index`, which differs between
# pandas versions.
genre_frequencies = movies_df["Genre"].value_counts()
shortlisted_genres = genre_frequencies[genre_frequencies > 200].index.tolist()
movies_df = movies_df[movies_df["Genre"].isin(shortlisted_genres)].reset_index(drop=True)

# Shuffle with a fixed seed so the per-genre sample below is the same on every
# run (same seed value as the train/eval split later in the notebook).
movies_df = movies_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Keep at most 400 plots per genre (to reduce class imbalance issues).
movies_df = movies_df.groupby("Genre").head(400).reset_index(drop=True)

# Integer-encode the genre names; `label_encoder` is kept to map predictions
# back to genre names.
label_encoder = LabelEncoder()
movies_df["genre_encoded"] = label_encoder.fit_transform(movies_df["Genre"].tolist())

movies_df = movies_df[["Plot", "Genre", "genre_encoded"]]

# The number of encoded classes must equal the number of shortlisted genres,
# because the latter is used as the classifier's label count.
assert len(label_encoder.classes_) == len(shortlisted_genres)

INFO:numexpr.utils:NumExpr defaulting to 2 threads.


In [9]:
# Rows per genre after pre-processing: non-null counts of every column grouped
# by genre, keeping only the first column of that table.
counts_by_genre = movies_df.groupby("Genre").count()
genre_counts = counts_by_genre.iloc[:, 0]

In [10]:
#!python3 -m pip install simpletransformers

In [11]:
from simpletransformers.classification import ClassificationModel

# Training options passed to simpletransformers through `args`.
model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "save_model_every_epoch": False,
    "save_eval_checkpoints": False,
    "max_seq_length": 1024,
    "train_batch_size": 8,
    "num_train_epochs": 4,
}

# Create a ClassificationModel with one output label per shortlisted genre.
# `use_cuda` is a constructor keyword of ClassificationModel rather than a
# training option, so it is passed here instead of inside `model_args`.
# Alternatives that were considered:
#   ClassificationModel("roberta", "roberta-base", ...)
#   bigbird, google/bigbird-roberta-base
model = ClassificationModel(
    "longformer",
    "allenai/longformer-base-4096",
    num_labels=len(shortlisted_genres),
    args=model_args,
    use_cuda=True,
)

INFO:filelock:Lock 139898371602512 acquired on /root/.cache/torch/transformers/ea6da73c9ecf0c07b29510e423dfd58f1b8b59cba3c89383189ec8b0b6d00e82.eac3423747595976f3a9edff0a11b53ff7005ed92d70f39106dcb04b8871ee74.lock


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=694.0, style=ProgressStyle(description_…

INFO:filelock:Lock 139898371602512 released on /root/.cache/torch/transformers/ea6da73c9ecf0c07b29510e423dfd58f1b8b59cba3c89383189ec8b0b6d00e82.eac3423747595976f3a9edff0a11b53ff7005ed92d70f39106dcb04b8871ee74.lock





INFO:filelock:Lock 139898371761040 acquired on /root/.cache/torch/transformers/dfc92dbbf5c555abf807425ebdb22b55de7a17e21fe1c48cbaa5764982c1d9c0.cd65234711d2e83d420aa696eb9186cdec6ab79ef8bf090b442cf249443dfa92.lock


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=597257159.0, style=ProgressStyle(descri…

INFO:filelock:Lock 139898371761040 released on /root/.cache/torch/transformers/dfc92dbbf5c555abf807425ebdb22b55de7a17e21fe1c48cbaa5764982c1d9c0.cd65234711d2e83d420aa696eb9186cdec6ab79ef8bf090b442cf249443dfa92.lock





- This IS expected if you are initializing LongformerForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing LongformerForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:filelock:Lock 139898374714768 acquired on /root/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b.lock


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…

INFO:filelock:Lock 139898374714768 released on /root/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b.lock





INFO:filelock:Lock 139898414646032 acquired on /root/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda.lock


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…

INFO:filelock:Lock 139898414646032 released on /root/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda.lock





In [None]:
#torch.cuda.empty_cache()

# Stratified 80/20 split on genre with a fixed seed.
train_df, eval_df = train_test_split(
    movies_df,
    test_size=0.2,
    stratify=movies_df["Genre"],
    random_state=42,
)

# Columns handed to the model: the plot text and its integer genre label.
model_input_columns = ["Plot", "genre_encoded"]

# Fine-tune on the training portion.
model.train_model(train_df[model_input_columns])

# Score the held-out portion; `model_outputs` is reused below to decode the
# predicted genres.
result, model_outputs, wrong_predictions = model.eval_model(eval_df[model_input_columns])
print(result)



INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=4750.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 4', max=594.0, style=ProgressStyle(des…

In [None]:
# Decode the evaluation outputs into genre names.
# assumes model_outputs holds one row of per-class scores per evaluation example — TODO confirm
predicted_genres_encoded = np.argmax(model_outputs, axis=1).tolist()
predicted_genres = list(label_encoder.inverse_transform(predicted_genres_encoded))
eval_gt_labels = eval_df["Genre"].tolist()
class_labels = list(label_encoder.classes_)

plt.figure(figsize=(22,18))
# confusion_matrix takes (y_true, y_pred): rows are the true genres and columns
# the predicted genres, matching the axis labels set below. `labels` is passed
# by keyword because it is keyword-only in current scikit-learn.
cf_matrix = confusion_matrix(eval_gt_labels, predicted_genres, labels=class_labels)
# Each cell is shown as a fraction of all evaluation examples. Passing the
# class names to the heatmap binds every tick to its own row/column.
ax = sns.heatmap(
    cf_matrix/np.sum(cf_matrix),
    annot=True,
    cmap="YlGnBu",
    xticklabels=class_labels,
    yticklabels=class_labels,
)
ax.set_xlabel('Predicted Genres', fontsize=20)
ax.set_ylabel('True Genres', fontsize=20)
ax.set_title('Confusion Matrix', fontsize=20)
ax.tick_params(axis="x", labelrotation=90, labelsize=18)
ax.tick_params(axis="y", labelrotation=0, labelsize=18)

plt.show()

In [None]:
# Draw one evaluation example at random and look up its plot and true genre.
random_idx = random.randrange(len(eval_df))
sample_row = eval_df.iloc[random_idx]
text = sample_row['Plot']
true_genre = sample_row['Genre']

# Predict with trained multiclass classification model, then map the encoded
# prediction back to its genre name.
encoded_prediction, raw_outputs = model.predict([text])
predicted_genre_encoded = np.array(encoded_prediction)
decoded_genres = label_encoder.inverse_transform(predicted_genre_encoded)
predicted_genre = decoded_genres[0]

# Report the true genre, the predicted genre and the plot itself.
print(f'\nTrue Genre:'.ljust(16,' '), f'{true_genre}\n')
print(f'Predicted Genre: {predicted_genre}\n')
print(f'Plot: {text}\n')
print("-------------------------------------------")