In [1]:
!pip install simpletransformers



In [2]:
import numpy as np
import pandas as pd
from simpletransformers.classification import MultiLabelClassificationModel, MultiLabelClassificationArgs
import json
from sklearn.model_selection import train_test_split
from sklearn.metrics import label_ranking_average_precision_score

In [3]:
genreToIndex = {
    "Drama": 0,
    "Comedy": 1,
    "Thriller": 2,
    "Action": 3,
    "Romance": 4,
    "Adventure": 5,
    "Crime": 6,
    "Science Fiction": 7,
    "Horror": 8,
    "Family": 9,
    "Fantasy": 10,
    "Mystery": 11,
    "Animation": 12,
    "History": 13,
    "Music": 14,
    "War": 15,
    "Documentary": 16,
    "Western": 17,
    "Foreign": 18,
    "TV Movie": 19
}

indexToGenre = [
    "Drama",
    "Comedy",
    "Thriller",
    "Action",
    "Romance",
    "Adventure",
    "Crime",
    "Science Fiction",
    "Horror",
    "Family",
    "Fantasy",
    "Mystery",
    "Animation",
    "History",
    "Music",
    "War",
    "Documentary",
    "Western",
    "Foreign",
    "TV Movie"
]

In [4]:
csv_path = '/content/drive/MyDrive/EDA/Project/data/tmdb.csv'

In [5]:
df = pd.read_csv(csv_path)
df.head()

Unnamed: 0.1,Unnamed: 0,budget,genres,homepage,id,keywords,original_language,overview,popularity,production_companies,production_countries,release_date,revenue,runtime,spoken_languages,status,tagline,title,vote_average,vote_count,cast,crew
0,0,237000000,"[{""id"": 28, ""name"": ""Action""}, {""id"": 12, ""nam...",http://www.avatarmovie.com/,19995,"[{""id"": 1463, ""name"": ""culture clash""}, {""id"":...",en,"In the 22nd century, a paraplegic Marine is di...",150.437577,"[{""name"": ""Ingenious Film Partners"", ""id"": 289...","[{""iso_3166_1"": ""US"", ""name"": ""United States o...",2009-12-10,2787965087,162.0,"[{""iso_639_1"": ""en"", ""name"": ""English""}, {""iso...",Released,Enter the World of Pandora.,Avatar,7.2,11800,"[{""cast_id"": 242, ""character"": ""Jake Sully"", ""...","[{""credit_id"": ""52fe48009251416c750aca23"", ""de..."
1,1,300000000,"[{""id"": 12, ""name"": ""Adventure""}, {""id"": 14, ""...",http://disney.go.com/disneypictures/pirates/,285,"[{""id"": 270, ""name"": ""ocean""}, {""id"": 726, ""na...",en,"Captain Barbossa, long believed to be dead, ha...",139.082615,"[{""name"": ""Walt Disney Pictures"", ""id"": 2}, {""...","[{""iso_3166_1"": ""US"", ""name"": ""United States o...",2007-05-19,961000000,169.0,"[{""iso_639_1"": ""en"", ""name"": ""English""}]",Released,"At the end of the world, the adventure begins.",Pirates of the Caribbean: At World's End,6.9,4500,"[{""cast_id"": 4, ""character"": ""Captain Jack Spa...","[{""credit_id"": ""52fe4232c3a36847f800b579"", ""de..."
2,2,245000000,"[{""id"": 28, ""name"": ""Action""}, {""id"": 12, ""nam...",http://www.sonypictures.com/movies/spectre/,206647,"[{""id"": 470, ""name"": ""spy""}, {""id"": 818, ""name...",en,A cryptic message from Bond’s past sends him o...,107.376788,"[{""name"": ""Columbia Pictures"", ""id"": 5}, {""nam...","[{""iso_3166_1"": ""GB"", ""name"": ""United Kingdom""...",2015-10-26,880674609,148.0,"[{""iso_639_1"": ""fr"", ""name"": ""Fran\u00e7ais""},...",Released,A Plan No One Escapes,Spectre,6.3,4466,"[{""cast_id"": 1, ""character"": ""James Bond"", ""cr...","[{""credit_id"": ""54805967c3a36829b5002c41"", ""de..."
3,3,250000000,"[{""id"": 28, ""name"": ""Action""}, {""id"": 80, ""nam...",http://www.thedarkknightrises.com/,49026,"[{""id"": 849, ""name"": ""dc comics""}, {""id"": 853,...",en,Following the death of District Attorney Harve...,112.31295,"[{""name"": ""Legendary Pictures"", ""id"": 923}, {""...","[{""iso_3166_1"": ""US"", ""name"": ""United States o...",2012-07-16,1084939099,165.0,"[{""iso_639_1"": ""en"", ""name"": ""English""}]",Released,The Legend Ends,The Dark Knight Rises,7.6,9106,"[{""cast_id"": 2, ""character"": ""Bruce Wayne / Ba...","[{""credit_id"": ""52fe4781c3a36847f81398c3"", ""de..."
4,4,260000000,"[{""id"": 28, ""name"": ""Action""}, {""id"": 12, ""nam...",http://movies.disney.com/john-carter,49529,"[{""id"": 818, ""name"": ""based on novel""}, {""id"":...",en,"John Carter is a war-weary, former military ca...",43.926995,"[{""name"": ""Walt Disney Pictures"", ""id"": 2}]","[{""iso_3166_1"": ""US"", ""name"": ""United States o...",2012-03-07,284139100,132.0,"[{""iso_639_1"": ""en"", ""name"": ""English""}]",Released,"Lost in our world, found in another.",John Carter,6.1,2124,"[{""cast_id"": 5, ""character"": ""John Carter"", ""c...","[{""credit_id"": ""52fe479ac3a36847f813eaa3"", ""de..."


In [6]:
def one_hot_encode_genres(df):
  labels = []
  for genres in df["genres"]:
    genreList = json.loads(genres)
    label = [0 for _ in range(len(indexToGenre))]
    for genre in genreList:
      genreName = genre["name"]
      genreIndex = genreToIndex[genreName]
      label[genreIndex] = 1
    labels.append(label)
  return labels

In [7]:
def get_overviews(df):
  overviews = df["overview"].fillna("")
  return list(overviews)

In [8]:
def get_dataframe(data_list, column_list):
  assert len(data_list) == len(column_list)

  newDf = pd.DataFrame()
  for i in range(len(data_list)):
    column = column_list[i]
    data = data_list[i]
    newDf[column] = data
    
  return newDf

In [9]:
labels = one_hot_encode_genres(df)
overviews = get_overviews(df)
newDf = get_dataframe(
    [overviews, labels],
    ["text", "labels"]
)
newDf.head()

Unnamed: 0,text,labels
0,"In the 22nd century, a paraplegic Marine is di...","[0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, ..."
1,"Captain Barbossa, long believed to be dead, ha...","[0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, ..."
2,A cryptic message from Bond’s past sends him o...,"[0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,Following the death of District Attorney Harve...,"[1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,"John Carter is a war-weary, former military ca...","[0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, ..."


In [10]:
print(labels[:10])

[[0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]


In [11]:
train_df , test_df = train_test_split(newDf, test_size=0.2)

In [12]:
# Model configuration
model_args = MultiLabelClassificationArgs(num_train_epochs=5, train_batch_size=16)

In [13]:
model  = MultiLabelClassificationModel('roberta', 'roberta-base', num_labels=len(indexToGenre), args=model_args)

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/478M [00:00<?, ?B/s]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForMultiLabelSequenceClassification: ['lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.bias']
- This IS expected if you are initializing RobertaForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForMultiLabelSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifier.out_proj.bias', 'c

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

In [14]:
model.train_model(train_df)

  0%|          | 0/3842 [00:00<?, ?it/s]

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Running Epoch 0 of 5:   0%|          | 0/241 [00:00<?, ?it/s]

Running Epoch 1 of 5:   0%|          | 0/241 [00:00<?, ?it/s]

Running Epoch 2 of 5:   0%|          | 0/241 [00:00<?, ?it/s]

Running Epoch 3 of 5:   0%|          | 0/241 [00:00<?, ?it/s]

Running Epoch 4 of 5:   0%|          | 0/241 [00:00<?, ?it/s]

(1205, 0.21507023468428133)

In [15]:
def evaluate_model(model, df):
  result, model_outputs, _ = model.eval_model(df)
  print(result)
  return model_outputs

In [16]:
# Evaluate model on test data
test_preds = evaluate_model(model, test_df)

  0%|          | 0/961 [00:00<?, ?it/s]

Running Evaluation:   0%|          | 0/121 [00:00<?, ?it/s]

{'LRAP': 0.7967673965581334, 'eval_loss': 0.20743565760121857}


In [17]:
# Evaluate model on train data
evaluate_model(model, train_df)

  0%|          | 0/3842 [00:00<?, ?it/s]

Running Evaluation:   0%|          | 0/481 [00:00<?, ?it/s]

{'LRAP': 0.9300198777546257, 'eval_loss': 0.13160332515371068}


array([[0.96777344, 0.06744385, 0.03689575, ..., 0.01159668, 0.01000977,
        0.00530243],
       [0.14245605, 0.13232422, 0.73095703, ..., 0.00881195, 0.0115509 ,
        0.00760651],
       [0.91943359, 0.04602051, 0.72167969, ..., 0.02827454, 0.02461243,
        0.01589966],
       ...,
       [0.67480469, 0.30395508, 0.03753662, ..., 0.02069092, 0.02575684,
        0.02029419],
       [0.20397949, 0.04544067, 0.22399902, ..., 0.02484131, 0.00963593,
        0.00967407],
       [0.93994141, 0.17211914, 0.01461792, ..., 0.06164551, 0.01986694,
        0.01233673]])

In [18]:
# Check some predictions
test_preds_rounded = np.rint(test_preds)
test_preds_rounded = test_preds_rounded.astype('i')
test_labels = np.array(list(test_df["labels"]))
for i in range(0, 20):
  print(test_labels[i])
  print(test_preds_rounded[i])
  print()

[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0]

[1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 0 1 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0]

[0 1 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 1 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]
[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[1 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[1 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
[1 0 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0]

[0 1 0 0 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0 0]
[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 0 1 0 0 0 1 0 0 0 0 