In [1]:
from pathlib import Path
from urllib.parse import unquote

import numpy as np
import pandas as pd
import plotly.express as px

# Introduction

Our [dataset](https://www.kaggle.com/jrobischon/wikipedia-movie-plots) contains plot descriptions for 34,886 movies scraped from Wikipedia. Our goal is to predict movie genre based on just the plot description. However, we may also include some of the other variables if needed, so during data cleaning we try to clean all the variables.

In [2]:
# Read data
df = pd.read_csv("../input/wikipedia-movie-plots/wiki_movie_plots_deduped.csv")
df

Unnamed: 0,Release Year,Title,Origin/Ethnicity,Director,Cast,Genre,Wiki Page,Plot
0,1901,Kansas Saloon Smashers,American,Unknown,,unknown,https://en.wikipedia.org/wiki/Kansas_Saloon_Sm...,"A bartender is working at a saloon, serving dr..."
1,1901,Love by the Light of the Moon,American,Unknown,,unknown,https://en.wikipedia.org/wiki/Love_by_the_Ligh...,"The moon, painted with a smiling face hangs ov..."
2,1901,The Martyred Presidents,American,Unknown,,unknown,https://en.wikipedia.org/wiki/The_Martyred_Pre...,"The film, just over a minute long, is composed..."
3,1901,"Terrible Teddy, the Grizzly King",American,Unknown,,unknown,"https://en.wikipedia.org/wiki/Terrible_Teddy,_...",Lasting just 61 seconds and consisting of two ...
4,1902,Jack and the Beanstalk,American,"George S. Fleming, Edwin S. Porter",,unknown,https://en.wikipedia.org/wiki/Jack_and_the_Bea...,The earliest known adaptation of the classic f...
...,...,...,...,...,...,...,...,...
34881,2014,The Water Diviner,Turkish,Director: Russell Crowe,Director: Russell Crowe\r\nCast: Russell Crowe...,unknown,https://en.wikipedia.org/wiki/The_Water_Diviner,"The film begins in 1919, just after World War ..."
34882,2017,Çalgı Çengi İkimiz,Turkish,Selçuk Aydemir,"Ahmet Kural, Murat Cemcir",comedy,https://en.wikipedia.org/wiki/%C3%87alg%C4%B1_...,"Two musicians, Salih and Gürkan, described the..."
34883,2017,Olanlar Oldu,Turkish,Hakan Algül,"Ata Demirer, Tuvana Türkay, Ülkü Duru",comedy,https://en.wikipedia.org/wiki/Olanlar_Oldu,"Zafer, a sailor living with his mother Döndü i..."
34884,2017,Non-Transferable,Turkish,Brendan Bradley,"YouTubers Shanna Malcolm, Shira Lazar, Sara Fl...",romantic comedy,https://en.wikipedia.org/wiki/Non-Transferable...,The film centres around a young woman named Am...


Several observations can be made from the printed 10 rows:

1. There are missing values, potentially a lot of them. The most obvious ones are the `NaN`s in `Cast`, but we also have "Unknown" in `Director` and "unknown" in `Genre`.
2. A movie could have multiple directors, cast members and genres. We have a multi-label classification problem.
3. Text needs to be cleaned! For example, "Director: " should be removed from entry 34881. Only one separator should be used, as now we have `,`, `\r\n`, `;`, etc.
4. `Wiki Page` seems like metadata that we could simply drop from the table.
5. Special characters (or rather, non-alphabetical characters) exist in the text, so models trained specifically on English text may not suffice.
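
To put a number on the first observation, one option is to count every placeholder at once. The sketch below is an illustration on a made-up toy table, not part of the original notebook; `count_missing` is a hypothetical helper that treats `NaN`, blank strings and any casing of "unknown" as missing.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the real table; the values are made up for illustration.
toy = pd.DataFrame(
    {
        "Director": ["Unknown", "Hakan Algül", "Unknown"],
        "Cast": [np.nan, "Ata Demirer", " "],
        "Genre": ["unknown", "comedy", "unknown"],
    }
)


def count_missing(frame: pd.DataFrame) -> pd.Series:
    """Count entries per column that are NaN, blank, or 'unknown' in any casing."""
    as_text = (
        frame.fillna("")
        .astype(str)
        .apply(lambda col: col.str.strip().str.lower())
    )
    return as_text.isin(["", "unknown"]).sum()


missing_counts = count_missing(toy)
print(missing_counts)
```

On the real data the same helper could be called as `count_missing(df[["Director", "Cast", "Genre"]])`.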

The package [pandas-profiling](https://github.com/pandas-profiling/pandas-profiling) could be used to explore the data, although it's more suitable for continuous variables.

```python
from pandas_profiling import ProfileReport
profile = ProfileReport(df)
profile.to_notebook_iframe()
```

# Data cleaning

Variables `Release Year` and `Title` can be used as-is (assuming we even use them in our model). Note that `Title` may contain non-English characters.

## Origin/Ethnicity

We will start with the easiest variable `Origin/Ethnicity`. By easiest I mean it has the lowest cardinality and is mostly clean. 

In [3]:
# Check unique values in Origin/Ethnicity
print(
    f"Total missing entries in Origin/Ethnicity: {df['Origin/Ethnicity'].isna().sum()}"
)
print(f"Number of unique values: {df['Origin/Ethnicity'].nunique()}")
df["Origin/Ethnicity"].value_counts()

Total missing entries in Origin/Ethnicity: 0
Number of unique values: 24


American        17377
British          3670
Bollywood        2931
Tamil            2599
Telugu           1311
Japanese         1188
Malayalam        1095
Hong Kong         791
Canadian          723
Australian        576
South_Korean      522
Chinese           463
Kannada           444
Bengali           306
Russian           232
Marathi           141
Filipino          128
Bangladeshi        87
Punjabi            84
Malaysian          70
Turkish            70
Egyptian           67
Assamese            9
Maldivian           2
Name: Origin/Ethnicity, dtype: int64

It seems the dataset contains predominantly (49.81%) American movies. Indian movies also make up a large proportion, but they are separated into different ethnic groups or languages in this dataset, including `Bollywood`, `Tamil`, `Telugu`, `Malayalam`, `Kannada`, `Marathi`, and `Assamese`. We could consider combining them since there are still 24 unique values for this variable. We may also consider pooling some of the levels with very few observations together into an "Other" group.
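
Pooling rare levels could look like the sketch below. It runs on a small made-up series; `pool_rare_levels` and its `min_count` threshold are hypothetical names, and the notebook does not say which threshold (if any) was used.

```python
import pandas as pd


def pool_rare_levels(col: pd.Series, min_count: int, other: str = "Other") -> pd.Series:
    """Replace levels that occur fewer than `min_count` times with `other`."""
    counts = col.value_counts()
    rare = counts[counts < min_count].index
    # Keep values that are not rare, replace the rest
    return col.where(~col.isin(rare), other)


# Made-up sample, not the real counts
origin = pd.Series(["American"] * 5 + ["British"] * 3 + ["Assamese", "Maldivian"])
pooled = pool_rare_levels(origin, min_count=2)
print(pooled.value_counts())
```

Combining the Indian film industries into one level would be a separate, manual mapping, since that decision depends on domain knowledge rather than on counts.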

## Director

Next up we have the `Director` variable, which is a lot harder to clean. A good way to extract the entries with dirty data is to filter for characters that we don't expect to see in this column. We rerun the filter in cell `In [4]` after each round of cleaning to find the bad data that is left. The following steps were applied sequentially to clean this column:

1. Strip whitespace from the beginning and end of the strings.
2. Remove additional information given in parentheses. These are mostly the role of the person.
3. Replace the alias "[Alan Smithee](https://en.wikipedia.org/wiki/Alan_Smithee)" with "Unknown".
4. Use `|` as separators for multiple directors. There's a mixture of `,`, `&`, `/`, `;` among others in the original dataset.
5. Remove the "Director: " string that comes before names in some cases.
6. Remove two awards strings that accompany names in very few cases.
7. Replace entries like "3 directors" with "Unknown".
8. Remove footnotes, e.g. in `Ernst Lubitsch[19]`.
9. Replace all "Unknown" entries with `NaN`.
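
When deciding which rule to write next, it can help to tally the unexpected characters themselves instead of reading the offending rows. The sketch below is an illustration on a handful of director strings taken from the outputs in this notebook; it reuses the character class from the filter in this section.

```python
import pandas as pd

directors = pd.Series(
    [
        "Otis Turner (unconfirmed)",
        "Victor Fleming & Theodore Reed",
        "Ernst Lubitsch[19]",
        "Director: Kaan Müjdeci",
        "Hakan Algül",
    ]
)

# Anything that is not a word character, space, comma, period, hyphen or quote
unexpected = r"[^\w ,.\-\"\']"

# One row per offending character, then count how often each one occurs
char_counts = directors.str.findall(unexpected).explode().dropna().value_counts()
print(char_counts)
```

Each character in the tally points at one of the cleaning steps: parentheses at step 2, `&` at step 4, `:` at step 5, and square brackets at step 8.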

In [4]:
df[df.Director.str.contains(r"[^\w ,.\-\"\']")]

Unnamed: 0,Release Year,Title,Origin/Ethnicity,Director,Cast,Genre,Wiki Page,Plot
35,1910,The Wonderful Wizard of Oz,American,Otis Turner (unconfirmed),Bebe Daniels,unknown,https://en.wikipedia.org/wiki/The_Wonderful_Wi...,"In Kansas, Dorothy and Imogene the cow are cha..."
304,1919,When the Clouds Roll By,American,Victor Fleming & Theodore Reed,"Douglas Fairbanks, Kathleen Clifford",comedy,https://en.wikipedia.org/wiki/When_the_Clouds_...,"As described in a film magazine,[3] Daniel Boo..."
2130,1936,San Francisco,American,W. S. Van Dyke (Best Director nominee),"Clark Gable, Jeanette MacDonald, Spencer Tracy...","drama, adventure",https://en.wikipedia.org/wiki/San_Francisco_(1...,The film opens with two men in boxing gloves a...
2853,1940,Knock Knock,American,Walter Lantz (producer),,animated short,https://en.wikipedia.org/wiki/Knock_Knock_(194...,The cartoon ostensibly stars Andy Panda (voice...
2933,1940,The Shop Around the Corner,American,Ernst Lubitsch[19],"James Stewart, Margaret Sullavan",romantic comedy,https://en.wikipedia.org/wiki/The_Shop_Around_...,Alfred Kralik (James Stewart) is the top sales...
...,...,...,...,...,...,...,...,...
34849,2010,Black Dogs Barking,Turkish,Mehmet Bahadır Er & Maryna Gorbach,"Cemal Toktaş, Volga Sorgu & Erkan Can",drama,https://en.wikipedia.org/wiki/Black_Dogs_Barking,Selim’s family has migrated from Anatolia to İ...
34878,2014,Mandıra Filozofu,Turkish,Director: Müfit Can Saçıntı,Director: Müfit Can Saçıntı\r\nCast: Rasim Özt...,unknown,https://en.wikipedia.org/wiki/Mand%C4%B1ra_Fil...,Cavit an ambitious industralist in İstanbul pl...
34879,2014,Winter Sleep,Turkish,Director: Nuri Bilge Ceylan,Director: Nuri Bilge Ceylan\r\nCast: Haluk Bil...,unknown,https://en.wikipedia.org/wiki/Winter_Sleep_(film),"Aydın, a former actor, owns a mountaintop hote..."
34880,2014,Sivas,Turkish,Director: Kaan Müjdeci,Director: Kaan Müjdeci\r\nCast: Dogan Izci,unknown,https://en.wikipedia.org/wiki/Sivas_(film),The film follows an eleven-year-old boy named ...


In [5]:
df["Director"] = df["Director"].str.strip()

# Info in parentheses: unconfirmed, producer, co-director, uncredited, old...
df["Director"] = df["Director"].str.replace(r"\s+\((.*?)\)", "", regex=True)
df["Director"] = df["Director"].str.replace(r"\s+\([dD]irector", "", regex=True)
df.loc[df.Director.str.match(r'"?Alan Smithee"?'), "Director"] = "Unknown"

# Uniform separators for multiple directors
df["Director"] = df["Director"].str.replace(r"\s*([,&;/]| and )\s*", " | ", regex=True)
df["Director"] = df["Director"].str.replace(
    r" \| $", "", regex=True
)  # trailing separator

# Extra information in the names
# Director(s): xxx
df["Director"] = df["Director"].str.replace(r"^Director[s]?:\s*", "", regex=True)
# Awards: Academy Award for Best Director, Best Director nominee
df["Director"] = df["Director"].str.replace(r" (Best|Academy).*$", "", regex=True)
# Another special case that I don't want to write a regex
df.loc[
    df.Title == "Veeraadhi Veeran Dubbed from Telugu", "Director"
] = "B. Vittalacharya"

# 3 directors, Six directors, ...
df.loc[df.Director.str.contains(r"[dD]irectors$"), "Director"] = "Unknown"

# Footnotes
df["Director"] = df["Director"].str.replace(r"\[\d+\]", "", regex=True)

# "Unknown" is an indicator for missing values
df.loc[df.Director == "Unknown", "Director"] = np.nan

After all this cleaning, let's see if there are any other special characters that we missed.

In [6]:
df[df.Director.notna() & df.Director.str.contains(r"[^\w \|.\-\"\']")]

Unnamed: 0,Release Year,Title,Origin/Ethnicity,Director,Cast,Genre,Wiki Page,Plot
21171,2011,Dimensions,British,Sloane U’Ren,Director: Sloane U’Ren\r\nCast: Henry Lloyd-Hu...,unknown,https://en.wikipedia.org/wiki/Dimensions_(film),"The film follows Stephen, a brilliant young sc..."
28223,2011,Kottarathil Kuttibhootham (കൊട്ടാരത്തിൽ കുട്ടി...,Malayalam,Kumar Nanda – Basheer,"Mukesh, Jagadish","comedy, fantasy",https://en.wikipedia.org/wiki/Kottarathil_Kutt...,Kottarathil Kutty Bhootham tells the story of ...
29410,1968,Uyarndha Manithan,Tamil,Krishnan–Panju,"Sivaji Ganesan, Sowcar Janaki, Vanisri, Sivaku...",unknown,https://en.wikipedia.org/wiki/Uyarndha_Manithan,Rajalingam (alias Raju) (Sivaji Ganesan) is th...
29432,1970,Engal Thangam,Tamil,Krishnan–Panju,"M. G. Ramachandran, Jayalalithaa, A. V. M. Raj...",unknown,https://en.wikipedia.org/wiki/Engal_Thangam,The capital of the Tamil Country at the end of...
29459,1971,Rangarattinam,Tamil,Krishnan–Panju,"Gemini Ganesan, Sowcar Janaki, K. A. Thangavelu",unknown,https://en.wikipedia.org/wiki/Rangarattinam,A girl loses her mental balance while riding a...
29471,1972,Idhaya Veenai,Tamil,Krishnan–Panju,"M. G. Ramachandran, Manjula, Sivakumar, Lakshmi",unknown,https://en.wikipedia.org/wiki/Idhaya_Veenai,"Somewhere in Chennai, several years previously..."
29481,1972,Pillaiyo Pillai,Tamil,Krishnan–Panju,"M. K. Muthu, Lakshmi, R. S. Manohar, C. R. Vij...",unknown,https://en.wikipedia.org/wiki/Pillaiyo_Pillai,The story deals with a villain Ganngatharan(R....
29494,1973,Pookari,Tamil,Krishnan–Panju,"M. K. Muthu, Manjula, Vennira Aadai Nirmala, J...",unknown,https://en.wikipedia.org/wiki/Pookari,Valli (Manjula) is a flower seller and her bro...
29505,1974,Kaliyuga Kannan,Tamil,Krishnan–Panju,"Jaishankar, Sowcar Janaki, Jayachitra, Thengai...",unknown,https://en.wikipedia.org/wiki/Kaliyuga_Kannan,Kaliyuga Kannan is a drama of faith and disbel...
29544,1977,Chakravarthy,Tamil,Krishnan–Panju,"Jaishankar, Sharada, Sripriya, Sreekanth, Then...",unknown,https://en.wikipedia.org/wiki/Chakravarthy_(19...,"Chakravarthy, a rich boy; Ranjit, a criminal's..."


Only `–`, `=` and `’` remain, and they don't affect our analyses, so we leave them as they are.

## Cast

The `Cast` column is very similar to the `Director` column, with a few additional cases.

1. There's the Windows newline character `\r\n` that's sometimes used as a separator, and sometimes just as a space.
2. There's also the `U+00a0` Unicode space character.
3. Both `unknown` and blank strings (␣) were used as placeholders for missing values.
4. Parentheses and square brackets were appearing randomly. Sometimes not even paired.

In [7]:
# Misplaced information
df.loc[df.Title == "Marwencol", "Cast"] = "Unknown"  # plot in cast

# Remove newline characters
df["Cast"] = df["Cast"].str.strip()
df["Cast"] = df["Cast"].str.replace(r",?\r\n", ", ", regex=True)
df["Cast"] = df["Cast"].str.replace(u"\xa0", " ", regex=True)

# Role inside parentheses: director, screenplay, cameo, in dual role, ...
df["Cast"] = df["Cast"].str.replace(r"\(+(.*?)\)+", "", regex=True)

# Footnotes
df["Cast"] = df["Cast"].str.replace(r"\[\d+\]", "", regex=True)
df["Cast"] = df["Cast"].str.replace(r"[\[\]]", "", regex=True)  # stray brackets

# Normalize separators
df["Cast"] = df["Cast"].str.replace(r"\s*([&;/,]| and )\s*", " | ", regex=True)
df["Cast"] = df["Cast"].str.replace(r"( \| )+$", "", regex=True)

# Annotation of people
df["Cast"] = df["Cast"].str.replace(
    r"(Director|Cast|Introducing)s?:\s*", "", regex=True
)

# Fix missing values and split cast
df.loc[(df.Cast.str.match(r"^\s*$")) | (df.Cast == "Unknown"), "Cast"] = np.nan
# dat["Cast"] = dat["Cast"].str.split(", ")

Finally there's a **big** difference -- in some entries the `Genre` of the movie was given in the `Cast` column! Often these entries have `unknown` genres or some random values.

In [8]:
df[df.Cast.notna() & df.Cast.str.contains(r"[Cc]omedy|[Rr]omantic|[Mm]usical")].head()

Unnamed: 0,Release Year,Title,Origin/Ethnicity,Director,Cast,Genre,Wiki Page,Plot
80,1914,Charlie Chaplin,American,Charlie Chaplin | Mabel Normand,Comedy,unknown,https://en.wikipedia.org/wiki/His_Trysting_Place,Charlie and his friend Ambrose meet in a resta...
17255,2017,Roadside Attractions,American,,Comedy | Drama,"usa, can",https://en.wikipedia.org/wiki/Beatriz_at_Dinner,The film opens with Beatriz (Salma Hayek) rowi...
17294,2017,Bleecker Street / FilmNation Entertainment,American,,Heist | Comedy,usa,https://en.wikipedia.org/wiki/Logan_Lucky,Jimmy Logan is laid off from his construction ...
21089,2010,Edgar Wright,British,Michael Cera | Mary Elizabeth Winstead | Ellen...,Romantic comedy,august 25,https://en.wikipedia.org/wiki/Scott_Pilgrim_vs...,"In Toronto, 23-year-old Scott Pilgrim is a bas..."
23242,1988,Stanley Kwan,Hong Kong,Leslie Cheung | Anita Mui,Romantic fantasy,unknown,https://en.wikipedia.org/wiki/Rouge_(film),"In 1980s Hong Kong, newspaperman Yuen (Alex Ma..."


And the problem is bigger than we thought. This probably explains why we have movies that were directed by so many people. The entire row was shifted -- `Director` is in `Title`, `Cast` is in `Director`, and `Genre` is in `Cast`.

## Wiki Page

One possible way to mitigate this issue is to parse the last part of the URL in `Wiki Page`. If the extracted value doesn't match the value in `Title`, then we *may* have a shift in columns. We first separate the dataset into two tables -- the "good" table where `Title` and `Wiki Page` are exact matches, and the "bad" table where there is a mismatch. Then we investigate the mismatches and hopefully find a fix for some of the movies with unknown genres.

In [9]:
# Parse the percent-encoded string in the URLs
df["Wiki Title"] = df["Wiki Page"].str.replace("_", " ")
df["Wiki Title"] = df["Wiki Title"].apply(lambda x: unquote(x.split("/")[-1]))

# Remove (YYYY film) in links and titles
df["Wiki Title"] = df["Wiki Title"].str.replace(r"\s*\((\d+ )?film\)", "", regex=True)
df["Title"] = df["Title"].str.replace(r"\s*\((\d+ )?film\)", "", regex=True)
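
The split into a "good" and a "bad" table is not shown in the cell above. A minimal sketch of it, on two rows copied from the outputs in this notebook, might look like this; the names `good`, `bad` and `is_match` are assumptions made for the illustration.

```python
from urllib.parse import unquote

import pandas as pd

toy = pd.DataFrame(
    {
        "Title": ["Olanlar Oldu", "Stanley Kwan"],
        "Wiki Page": [
            "https://en.wikipedia.org/wiki/Olanlar_Oldu",
            "https://en.wikipedia.org/wiki/Rouge_(film)",
        ],
    }
)

# Same parsing steps as in the cell above
toy["Wiki Title"] = toy["Wiki Page"].str.replace("_", " ")
toy["Wiki Title"] = toy["Wiki Title"].apply(lambda x: unquote(x.split("/")[-1]))
toy["Wiki Title"] = toy["Wiki Title"].str.replace(
    r"\s*\((\d+ )?film\)", "", regex=True
)

# Exact matches go to the "good" table, everything else needs a closer look
is_match = toy["Title"] == toy["Wiki Title"]
good, bad = toy[is_match], toy[~is_match]
print(bad[["Title", "Wiki Title"]])
```

A mismatch is only a hint: titles can legitimately differ from the page name, so the "bad" table still has to be inspected before any columns are moved.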

# Cleaning genres

We've learned a lot about the potential problems we may encounter in parsing the `Genre` column. The first thing we notice is that there are **way** too many genres in this dataset. Fortunately, we will soon see that a lot of these are not actually genres -- some are misplaced strings, some are detailed descriptions, and some are simply uninterpretable.

In [10]:
df.Genre.unique().shape[0]

2265

## Misplaced genres in Cast

We first fix the genres stored in the `Cast` column. Recall that when the genre is in `Cast`, the corresponding value in `Genre` is not always `unknown`, so we can't simply take the approach below:

```sql
SELECT * FROM df WHERE
    Genre = 'unknown' AND
    Cast IS NOT NULL
```

Instead, we need a way to figure out all the values in `Genre` that are probably not, well, genres. A simple way to do this is to start from a genre that's known to be in the `Cast` column, e.g. "comedy". Because certain movies are tagged "comedy" along with some other genre (e.g. "romance"), we can repeat this process iteratively, each time searching for the newly found tokens, until no new tokens turn up.

In [11]:
def find_related_genres(cast_col: pd.Series, regex: str):
    related_genres = (
        cast_col[cast_col.str.contains(regex, regex=True, case=False)]
        .apply(lambda x: x.split(" | "))
        .explode()
    )
    related_genres = related_genres[
        ~related_genres.str.contains(regex, regex=True, case=False)
    ]
    return related_genres.unique()


genres_in_cast = {"comedy"}
all_casts = df.Cast[(df.Cast.notna()) & (df.Title != df["Wiki Title"])]
# Filter out some edge cases
all_casts = all_casts[~all_casts.str.contains("Thriller Manju", regex=False)]
all_casts = all_casts[~all_casts.str.contains("Love", regex=False)]

while True:
    regex = rf"\b(?:{'|'.join(genres_in_cast)})\b"
    related_genres = find_related_genres(all_casts, regex)
    related_genres = [x.lower() for x in related_genres]
    if len(related_genres) == 0:
        break

    genres_in_cast = genres_in_cast.union(related_genres)

genres_in_cast

{'action',
 'comedy',
 'crime',
 'drama',
 'entered into the 11th moscow international film festival',
 'family',
 'fantasy',
 'heist',
 'historical',
 'horror',
 'musical',
 'romance',
 'social',
 'sports',
 'suspense',
 'thriller',
 'war'}

This isn't bad! We just need to remove the one false positive and get all the entries matching those genres. There's a total of 68 entries in this dataset matching the following criteria.

In [12]:
genres_in_cast.remove("entered into the 11th moscow international film festival")
regex = rf"\b(?:{'|'.join(genres_in_cast)})\b"

bad_genre = (
    df.Cast.notna()
    & df.Cast.str.contains(regex, regex=True, case=False)
    & (~df.Title.isin(["Jaihind", "51 Birch Street"]))
)

df.loc[bad_genre, "Genre"] = df["Cast"]
df.loc[bad_genre, "Cast"] = df["Director"]
df.loc[bad_genre, "Director"] = df["Title"]
df.loc[bad_genre, "Title"] = df["Wiki Title"]
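
The order of the four assignments matters: each column has to be read before it is overwritten, so the shift starts at `Genre` and works back towards `Title`. The sketch below replays the shift on a toy table built from two rows shown earlier; the boolean mask `shifted` is a stand-in for `bad_genre`.

```python
import pandas as pd

toy = pd.DataFrame(
    {
        "Title": ["Stanley Kwan", "Olanlar Oldu"],
        "Director": ["Leslie Cheung | Anita Mui", "Hakan Algül"],
        "Cast": ["Romantic fantasy", "Ata Demirer"],
        "Genre": ["unknown", "comedy"],
        "Wiki Title": ["Rouge", "Olanlar Oldu"],
    }
)
shifted = pd.Series([True, False])  # only the first row is misaligned

# Move each value one column over, filling the last target first
moves = [
    ("Genre", "Cast"),
    ("Cast", "Director"),
    ("Director", "Title"),
    ("Title", "Wiki Title"),
]
for target, source in moves:
    toy.loc[shifted, target] = toy.loc[shifted, source]

print(toy.loc[0].tolist())
# ['Rouge', 'Stanley Kwan', 'Leslie Cheung | Anita Mui', 'Romantic fantasy', 'Rouge']
```

Running the moves in the opposite order would copy the wiki title into every column of the shifted row.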

## Simple cleaning

Now we run some basic cleaning on the `Genre` column, fixing problems that can be caught by glancing at sample rows.

In [13]:
df["Genre"] = df["Genre"].str.replace(u"\xa0", " ", regex=False)
df["Genre"] = df["Genre"].str.strip()

# Remove footnotes
df["Genre"] = df.Genre.str.replace(r"\s*\[.*\]", "", regex=True)

# One dash to rule them all
df["Genre"] = df["Genre"].str.replace("–|—", "-", regex=True)
df.loc[df.Genre.isin(["-", ""]), "Genre"] = "unknown"

# Trailing punctuation
df["Genre"] = df["Genre"].str.replace(r"[.,;\s]+$", "", regex=True)

# I don't care who produced it
df["Genre"] = df["Genre"].str.replace(
    r"(warner bros|paramount|united artists)\. ", "", regex=True
)
df["Genre"] = df["Genre"].str.replace("(co-)?produced .*$", "", regex=True)

df.loc[df.Genre.str.contains("films$"), "Genre"] = "unknown"
df["Genre"] = df["Genre"].str.replace(r"\s*film\s*", "", regex=True)

# or what the movie is based on
df["Genre"] = df["Genre"].str.replace(" (about|based on).*$", "", regex=True)

# or info in parentheses. There's some cases where the sport is annotated,
# but I don't think dividing the sports genre into multiple groups would help
df["Genre"] = df["Genre"].str.replace(r"\s*\(.*?\)\s*", "", regex=True)

# Different name same meanings
df["Genre"] = df["Genre"].str.replace("3-d", "3d", regex=False)
df["Genre"] = df["Genre"].str.replace("science[- ]fiction", "sci-fi", regex=True)
df["Genre"] = df["Genre"].str.replace("sci fi", "sci-fi", regex=False)
df["Genre"] = df["Genre"].str.replace(r"bio-?pic", "biographical", regex=True)
df["Genre"] = df["Genre"].str.replace("biography", "biographical", regex=False)
df["Genre"] = df["Genre"].str.replace("ww|(world war)", " worldwar", regex=True)
df["Genre"] = df["Genre"].str.replace(r"(kung|gun)( |-)fu", "kungfu", regex=True)
df["Genre"] = df["Genre"].str.replace("007", "james bond", regex=False)
df["Genre"] = df["Genre"].str.replace(
    r"rom(antic|ance)? com(edy)?", "rom-com", regex=True
)
df["Genre"] = df["Genre"].str.replace(
    r"(action|comedy|horror) masala", r"\1 | masala", regex=True
)

# Typos
df["Genre"] = df["Genre"].str.replace("family. ", "family | ", regex=False)
df["Genre"] = df["Genre"].str.replace("familya", "family", regex=False)
df["Genre"] = df["Genre"].str.replace("supeheroes", "superheroes", regex=False)
df["Genre"] = df["Genre"].str.replace(r" in 3d.?", ", 3d", regex=True)

# Special cases
df.loc[df.Title == "Tales of Manhattan", "Genre"] = "drama | comedy"
df.loc[df.Title == "Aaram", "Genre"] = "romance"
df.loc[df.Title == "Roadside Attractions", "Genre"] = "comedy | drama"

# Separators
df["Genre"] = df["Genre"].str.replace(r"\s*([;,/]| and )\s*", " | ", regex=True)
df["Genre"] = df["Genre"].str.replace(r"\s+[&-]\s+", " | ", regex=True)
df["Genre"] = df["Genre"].str.replace(r"( \| )+$", "", regex=True)
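
To see what the three separator rules do, the sketch below applies them to a few made-up genre strings. Note that `&` and `-` are only treated as separators when surrounded by whitespace, which keeps hyphenated genres such as `sci-fi` intact.

```python
import pandas as pd

# Made-up examples covering each separator
genres = pd.Series(
    [
        "drama, adventure",
        "comedy/romance",
        "action & thriller",
        "crime - drama",
        "sci-fi; horror,",
    ]
)

genres = genres.str.replace(r"\s*([;,/]| and )\s*", " | ", regex=True)
genres = genres.str.replace(r"\s+[&-]\s+", " | ", regex=True)
genres = genres.str.replace(r"( \| )+$", "", regex=True)  # trailing separator

print(genres.tolist())
# ['drama | adventure', 'comedy | romance', 'action | thriller',
#  'crime | drama', 'sci-fi | horror']
```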

## Embeddings

There are many more edge cases, and since nobody wants to go through thousands of entries manually, we may project the values into some embedding space and find clusters, in the hope that actual genres will be close to each other. To make visualization easier, we first used `spaCy` to convert the unique tokens into vectors, and then used `UMAP` to reduce the 768-dimensional vectors to two dimensions. Spectral clustering was then performed for coloring the points.

Since UMAP is involved and it is not deterministic, we load the projections from a previous run. The code to generate the data is attached below.

```python
f_umap = Path("./genre_umap_projections.csv")

# Load English tokenizer, tagger, parser and NER
# TODO: may require multiple languages
import spacy
from sklearn.cluster import SpectralClustering
from umap import UMAP

spacy.prefer_gpu()
nlp = spacy.load("en_core_web_trf")

# Get unique genre tokens (df_genres is built in cell In [14])
genres = df_genres.Genre.unique().tolist()

# Lemmatize tokens
doc = nlp.pipe(genres)
genre_lemma = [" ".join([x.lemma_.strip() for x in token]) for token in doc]
genre_lemma = pd.Series(genre_lemma)
genre_lemma = genre_lemma.str.replace(r"\s*([\-.'])\s*", r"\1", regex=True)

# Get embeddings using default SpaCy transformer model
doc = nlp.pipe(genre_lemma)
genre_vecs = [token._.trf_data.tensors[-1].flatten() for token in doc]

# Reduce the dimensions of embedding tensors with UMAP
umap_2d = UMAP(n_components=2, learning_rate=0.5, init="spectral", random_state=42)
proj_2d = umap_2d.fit_transform(genre_vecs)

# Spectral clustering on the projections
sc = SpectralClustering(n_clusters=10, random_state=42)
genre_clust = sc.fit_predict(proj_2d).astype(str)

# Concatenate results in a data frame
genre_proj = pd.DataFrame(
    {
        "Genre": genres,
        "GenreLemma": genre_lemma,
        "Cluster": genre_clust,
        "UMAP0": proj_2d[:, 0],
        "UMAP1": proj_2d[:, 1],
    }
)
genre_proj = genre_proj.sort_values(["Cluster", "GenreLemma"])
genre_proj.to_csv(f_umap, index=False)
```

In [14]:
df_genres = (
    df.reset_index()
    .rename(columns={"index": "movieID"})
    .assign(
        Genre=lambda dat: dat["Genre"].apply(lambda x: x.lower().split(" | ")),
    )
    .explode("Genre")
    .filter(["movieID", "Genre"])
)

In [15]:
# To replicate these results, we download the UMAP projections from a previous run
# This is done because UMAP isn't deterministic
genre_proj = pd.read_csv("https://raw.githubusercontent.com/y1zhou/wiki-movie-plots/e953ac189fb603a391e45566804c4ef0fd8835a7/EDA/genre_umap_projections.csv")
genre_proj.Cluster = genre_proj.Cluster.astype(str)

In [16]:
# Visualize projections
fig = px.scatter(
    genre_proj,
    x="UMAP0",
    y="UMAP1",
    hover_name="GenreLemma",
    color="Cluster",
    width=1000,
    height=800,
)
fig.show()

We have three big clusters, namely $A = \{0\}$, $B = \{2, 7, 8\}$, and $C = \{1, 3, 4, 5, 6, 9\}$. Cluster 0 contains mostly names and unrecognizable strings, so **we may drop these entries**. Cluster $B$ contains mostly single-word strings. Cluster 8 contains a lot of typos (**fix**), but also some valid genres (**keep**) and some names (**replace with unknown**). Clusters 2 and 7 contain mostly valid single-word genres, with a few exceptions like `sword` and `p.o.w`.

As for Cluster $C$, things are a bit more complicated. These are typically longer strings ranging from two to five words. Some are multiple genres separated by spaces, so our previous regexes didn't separate them successfully. We also have some companies and very specific descriptions mixed in here, e.g. `biographical of pioneering american photographer eadweard muybridge`.

## Dimension reduction on genres

So let's try to fix this! We first join the table with the full dataset to get the lemmatized genres for each movie.

In [17]:
df_genres = (
    df_genres.merge(genre_proj, on="Genre")
    .assign(Genre=lambda x: x["GenreLemma"].str.strip())
    .drop(columns="GenreLemma")
)
df_genres.head()

Unnamed: 0,movieID,Genre,Cluster,UMAP0,UMAP1
0,0,unknown,2,6.368839,1.837754
1,1,unknown,2,6.368839,1.837754
2,2,unknown,2,6.368839,1.837754
3,3,unknown,2,6.368839,1.837754
4,4,unknown,2,6.368839,1.837754


Upon inspection, only three strings from Cluster 0 are genres, so we drop everything else.

In [18]:
# Drop all those genres from Cluster 0 that don't make sense
clust0_unknown_movies = df_genres[df_genres.Cluster == "0"]
clust0_unknown_movies = clust0_unknown_movies.loc[
    ~clust0_unknown_movies.Genre.isin(["james bond", "buddy cop", "crime ttriller"]),
    "movieID",
]
df_genres.loc[df_genres.movieID.isin(clust0_unknown_movies), "Genre"] = "unknown"

Cluster 8 is more complicated - we have a mixture of genres, typos and random names. Luckily a lot of the names cluster together, so we can remove them based on the UMAP projections. Interestingly, the projection of `unknown` also falls here. Finally, we manually filter out some invalid genres.

In [19]:
invalid_genres = [
    "ram",
    "sada",
    "jyothika",
    "sumanth",
    "mammootty",
    "ajay",
    "suhasini",
    "sunil",
    "arya",
    "pooja",
    "sneha",
    "tabu",
    "keiji",
    "stoner",
    "harem",
    "jeet",
    "sf",
    "dev",
    "pink",
    "sentiment",
    "000 year ago in the canadian arctic",
    "16 mm",
    "adapt from the play by alexandre goyette",
    "adaptation of a play by michel marc bouchard",
    "british-german co-production",
    "direct-to-dvd",
    "enter into the 11th moscow international film festival",
    "fiction make with the nationalboard",
    "find footage",
    "modern day passion play",
    "ravi teja",
    "critically acclaim",
    "imax",
]

In [20]:
# Fix typos from Cluster 8
df_genres["Genre"] = df_genres["Genre"].str.replace(
    "ttriller|thriler|triller", "thriller", regex=True, case=False
)
df_genres["Genre"] = df_genres["Genre"].str.replace(
    "biogtaphy", "biography", regex=False
)
df_genres["Genre"] = df_genres["Genre"].str.replace("slahser", "slasher", regex=False)
df_genres["Genre"] = df_genres["Genre"].str.replace("fantay", "fantasy", regex=False)
df_genres["Genre"] = df_genres["Genre"].str.replace("tragerdy", "tragedy", regex=False)
df_genres["Genre"] = df_genres["Genre"].str.replace("comedey", "comedy", regex=False)
df_genres["Genre"] = df_genres["Genre"].str.replace("famil ", "family ", regex=False)
df_genres["Genre"] = df_genres["Genre"].str.replace("romcom", "rom-com", regex=False)
df_genres.loc[df_genres.Genre.isin(["erotica", "ero"]), "Genre"] = "erotic"

names_mask = (
    (df_genres.Cluster == "8")
    & (df_genres.UMAP0 > 8.5)
    & (df_genres.UMAP1 > 2.7)
    & (df_genres.Genre != "operetta")
)
df_genres.loc[names_mask, "Genre"] = "unknown"
df_genres.loc[df_genres.Genre.isin(invalid_genres), "Genre"] = "unknown"  # not just 8

In [21]:
# Just a few special cases in Cluster 7
df_genres["Genre"] = df_genres["Genre"].str.replace("period", "historical", regex=False)
df_genres["Genre"] = df_genres["Genre"].str.replace("p.o.w", "pow", regex=False)
df_genres.loc[df_genres.Genre.str.contains("eadweard"), "Genre"] = "biographical"

Now for Cluster $C$, we need to figure out a way to extract the space-separated genres without breaking the genres that contain spaces, i.e. multi-word genres.
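
One way to do this is to glue each multi-word genre together with hyphens first, and only then split on spaces. The sketch below shows the idea on made-up strings; the `multi_word` mapping is a hypothetical two-entry stand-in for the much longer `regex_fixes` list used in this notebook.

```python
import pandas as pd

# Hypothetical subset of multi-word genres to protect before splitting
multi_word = {"martial art": "martial-art", "time travel": "time-travel"}

genres = pd.Series(["martial art comedy", "time travel romance drama", "horror"])

# Step 1: hyphenate the multi-word genres so they survive the split
for phrase, token in multi_word.items():
    genres = genres.str.replace(phrase, token, regex=False)

# Step 2: every remaining space separates two genres
tokens = genres.str.split(" ").explode()
print(tokens.tolist())
# ['martial-art', 'comedy', 'time-travel', 'romance', 'drama', 'horror']
```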

> There's probably a much more automated way to do this, but I found that in most cases the multi-word genres can be broken into two genres, so we may only deal with the special cases and then break everything else.

In [22]:
# Clusters 1, 3, 4, 5, 6, 9
regex_fixes = [
    # columns=["original", "fixed", "is_regex"],
    ["romantic|romanctic", "romance", True],
    [
        r"(action|adventure|animation|comedy|crime|drama|epic|fantasy|gangster|heist|historical|horror|musical|romance|suspense|war|western)-",
        r"\1 ",
        True,
    ],
    [r"(?:computer[- ])?(?:animate[d]?|animation)", "animation", True],
    [r"\s*feature\s*", "", True],
    ["american football", "american-football", False],
    ["art house", "art", False],
    ["bio-", "biographical ", False],
    ["biography", "biographical", False],
    [r"\bbiographic\b", "biographical", True],
    ["biographical fim", "biographical", False],
    [r"\s*adapt(ation)? (from|of).*$", "", True],
    ["black ", "black-", False],
    ["blaxploitation", "exploitation", False],
    ["bruceploitation", "exploitation", False],
    ["buddy cop", "buddy-cop", False],
    ["cartoon", "animation", False],
    ["for child", "child", False],
    ["child's", "child", False],
    ["cold war", "war", False],
    ["come of age", "come-of-age", False],
    ["comedic", "comedy", False],
    [r"(cyber|steam)punk", "punk", True],
    ["road movie", "road", False],
    ["devotional-", "devotional ", False],
    ["docudrama", "documentary drama", False],
    ["pseudo-documentary", "documentary", False],
    ["docufiction", "documentary fiction", False],
    ["dramatic", "drama", False],
    [" set 4", "", False],
    ["fairy tale", "fairytale", False],
    [" on the early year of hitler", "", False],
    [r"historic$", "historical", True],
    [r"history", "historical", False],
    [r"historical dram(a{0,4})", "historical drama", True],
    ["historical piece", "historical", False],
    ["horror from the novel by bram stoker", "horror", False],
    ["homosexual", "lgbt", False],
    ["human right", "human-right", False],
    ["independent movie", "independent", False],
    ["inspire by true event", "true-event", False],
    ["interactive cinema", "interactive-cinema", False],
    ["james bond", "james-bond", False],
    ["j-horror", "horror", False],
    [r"\bkid\b", "child", True],
    [r"lgbt-(theme|relate)", "lgbt", True],
    ["live action", "live-action", False],
    ["love story", "love", False],
    ["martial art", "martial-art", False],
    [r"\s*nationalboard\s*", "", True],
    ["nfb ", "", False],
    ["patriotism", "patriotic", False],
    ["perodic", "historical drama", False],
    [r"(psy|psycho|physiological) ", "psychological ", True],
    ["politics", "political", False],
    ["rom-com-drama", "rom-com drama", False],
    ["satirical", "satire", False],
    ["socio-", "social ", False],
    [r"\bsex\b", "sexual", True],
    ["sexploitation", "sexual exploitation", False],
    ["slice of life", "slice-of-life", False],
    ["space opera", "space-opera", False],
    ["stop motion", "stop-motion", False],
    ["summer camp", "summer-camp", False],
    ["swashbuckling", "swashbuckler", False],
    ["television", "tv", False],
    ["true crime", "true-event crime", False],
    ["cbc-tv", "tv", False],
    ["thai boxing", "boxing", False],
    ["time travel", "time-travel", False],
    [r"war[ ]?time", "war", True],
    [r"worldwar (i{,2})", r"worldwar-\1", True],
    ["worldwar1", "worldwar-i", False],
    ["worldwarii", "worldwar-ii", False],
]

for pattern, repl, is_regex in regex_fixes:
    df_genres["Genre"] = df_genres["Genre"].str.replace(pattern, repl, regex=is_regex)
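
To make the three-column layout of the fix table concrete, here is a minimal plain-Python sketch of how each row is interpreted. The `toy_fixes` list and the `apply_fixes` helper are hypothetical stand-ins for illustration; the notebook itself relies on pandas' `str.replace` with the `regex` flag.

```python
import re

# Hypothetical miniature of the fix table: (pattern, replacement, is_regex)
toy_fixes = [
    ("cartoon", "animation", False),  # literal substring replacement
    (r"\bkid\b", "child", True),      # regex: the whole word "kid" only
    (r"war[ ]?time", "war", True),    # regex: "wartime" or "war time"
]


def apply_fixes(text, fixes):
    """Apply every fix in order, as a literal or a regex replacement."""
    for pattern, repl, is_regex in fixes:
        if is_regex:
            text = re.sub(pattern, repl, text)
        else:
            text = text.replace(pattern, repl)
    return text


cleaned = [
    apply_fixes(g, toy_fixes)
    for g in ["kid cartoon", "kidnap drama", "war time romance"]
]
print(cleaned)
```

The word boundaries in `\bkid\b` keep "kidnap" untouched, which is why that row needs the regex flag. The fixes run top to bottom, so a later row sees the output of the earlier ones.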

Now that the genre strings are normalized, we split the `Genre` column on spaces and explode the resulting lists so that each row holds one genre label for one movie. We also:

1. Drop `unknown` rows when the movie has another valid genre.
2. Drop rows where the genre only appeared once in the dataset.

In [23]:
df_genres = (
    df_genres.assign(Genre=lambda dat: dat["Genre"].str.split(" "))
    .explode("Genre")
    .filter(["movieID", "Genre"])
    .drop_duplicates()
)

# Drop genres that only appear once
df_genre_count = (
    df_genres.groupby("Genre")
    .agg("count")
    .reset_index()
    .rename(columns={"movieID": "cnt"})
    .sort_values("cnt")
)

df_genres = df_genres[
    df_genres.Genre.isin(df_genre_count.loc[df_genre_count.cnt > 1, "Genre"])
]

# Drop unknown rows when there's another valid genre for the movie
df_valid_genres = df_genres[df_genres.Genre != "unknown"].copy()
df_unknown_genres = df_genres[
    (df_genres.Genre == "unknown") & (~df_genres.movieID.isin(df_valid_genres.movieID))
].copy()
df_genres = pd.concat([df_valid_genres, df_unknown_genres], ignore_index=True)

df_genres.Genre.unique().shape[0]

154
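
The same three steps can be followed on a tiny made-up frame. This is only a sketch: the `toy` data and the variable names are assumptions, and `value_counts` is used in place of the `groupby` above for brevity.

```python
import pandas as pd

# Hypothetical toy data: space-separated genre strings per movie
toy = pd.DataFrame(
    {
        "movieID": [0, 1, 2, 3],
        "Genre": ["drama unknown", "comedy drama", "unknown", "comedy noir"],
    }
)

# One row per (movie, genre) pair
long = (
    toy.assign(Genre=lambda d: d["Genre"].str.split(" "))
    .explode("Genre")
    .drop_duplicates()
)

# Drop genres that occur only once ("noir" in this toy example)
counts = long["Genre"].value_counts()
long = long[long["Genre"].isin(counts[counts > 1].index)]

# Keep "unknown" only for movies that have no other genre
valid = long[long["Genre"] != "unknown"]
unknown_only = long[
    (long["Genre"] == "unknown") & ~long["movieID"].isin(valid["movieID"])
]
result = pd.concat([valid, unknown_only], ignore_index=True)
print(result)
```

Movie 0 loses its `unknown` label because it also has `drama`, while movie 2 keeps `unknown` as its only label.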

# Cleaning the Plot and saving results

We applied only very simple cleaning to `Plot`, in the hope that the model will do the heavy lifting.

In [24]:
# Some data cleaning on the Plot
regex_fixes = [
    [r"\s*\[.*\]", " ", True],  # footnotes (greedy: first "[" to last "]" on a line)
    [r"–|—", "-", True],  # en and em dashes
    [r"\r\n", "", True],  # Windows line breaks
]

for pattern, repl, is_regex in regex_fixes:
    df["Plot"] = df["Plot"].str.replace(pattern, repl, regex=is_regex)
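
A short standalone sketch shows what these three substitutions do to a made-up sentence. The `sample` text, the `clean` helper and the `BOUNDED` pattern are assumptions for illustration; `BOUNDED` is a possible alternative to the footnote pattern, not what the notebook ran.

```python
import re

# Hypothetical plot snippet with two footnotes, a line break and an en dash
sample = "He flees.[1] She follows him[2] home.\r\nThey reunite \u2013 at last."

GREEDY = r"\s*\[.*\]"        # footnote pattern used in the notebook
BOUNDED = r"\s*\[[^\]]*\]"   # assumed alternative: stops at the first "]"


def clean(text, footnote_pattern):
    """Apply the footnote, dash and line-break substitutions in order."""
    text = re.sub(footnote_pattern, " ", text)
    text = re.sub("\u2013|\u2014", "-", text)  # en and em dashes
    text = re.sub(r"\r\n", "", text)
    return text


greedy_result = clean(sample, GREEDY)
bounded_result = clean(sample, BOUNDED)
print(greedy_result)
print(bounded_result)
```

With the greedy pattern, the text between the first and the last footnote on a line is removed as well. Deleting `\r\n` outright also joins neighbouring sentences without a space ("home.They").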

In [25]:
# End of data cleaning; save results
df.to_csv("./data_cleaned.csv", index=False)
df_genres.to_csv("./genres_cleaned.csv", index=False)

# Prepare dataset for modeling

Finally, the exciting part: fine-tuning a pre-trained DistilBERT model from [Hugging Face](https://huggingface.co/distilbert-base-uncased) for our multi-label classification task. We first load the relevant Hugging Face and PyTorch libraries:

In [26]:
from datasets import Dataset
from torch import logical_and, logical_or, nn
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Configurations
model_name = "distilbert-base-uncased"
test_ratio = 0.1
val_ratio = 0.1
batch_size = 16
num_genres = 10

## Merging similar genres

For faster training and better performance, we merge similar genres into groups and keep only movies that carry at least one of the `num_genres` most frequent genres.

In [27]:
# Merge similar genres (see https://aclanthology.org/Y18-1007.pdf)
genre_groups = pd.DataFrame(
    [
        (
            "action",
            [
                "action",
                "adventure",
                "sci-fi",
                "superhero",
                "sport",
                "spy",
                "war",
                "worldwar-i",
                "worldwar-ii",
            ],
        ),
        ("comedy", ["comedy", "rom-com", "black-comedy"]),
        ("drama", ["drama", "fantasy", "biodrama", "melodrama"]),
        ("family", ["family", "animation", "musical", "anime", "child"]),
        ("thriller", ["thriller", "mystery"]),
        ("documentary", ["documentary", "biographical", "historical"]),
    ],
    columns=["genre_group", "Genre"],
).explode("Genre")

df_genres = df_genres.merge(genre_groups, how="left", on="Genre")
df_genres.loc[df_genres.genre_group.notna(), "Genre"] = df_genres["genre_group"]
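
The effect of the left merge is easiest to see on a few rows. The sketch below uses a made-up two-group mapping and toy movies; `fillna` is used as an equivalent way to let ungrouped genres keep their own name.

```python
import pandas as pd

# Hypothetical two-group mapping, in the same shape as genre_groups above
groups = pd.DataFrame(
    [("action", ["action", "adventure"]), ("thriller", ["thriller", "mystery"])],
    columns=["genre_group", "Genre"],
).explode("Genre")

toy = pd.DataFrame(
    {"movieID": [0, 0, 1, 2], "Genre": ["adventure", "romance", "mystery", "action"]}
)

merged = toy.merge(groups, how="left", on="Genre")
# Genres without a group (here "romance") keep their original name
merged["Genre"] = merged["genre_group"].fillna(merged["Genre"])
print(merged[["movieID", "Genre"]])
```

After merging, a movie can carry the same group label more than once (for example `action` and `adventure` both become `action`); the pivot in the encoding step collapses such duplicates.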

In [28]:
# Only keep the top `num_genres` genres
top_genres = (
    df_genres.query("Genre != 'unknown'")
    .groupby("Genre")
    .agg(n=("Genre", "count"))
    .reset_index()
    .sort_values("n", ascending=False)
    .head(num_genres)
    .Genre.values
)

top_genres

array(['drama', 'comedy', 'action', 'family', 'thriller', 'romance',
       'crime', 'horror', 'western', 'documentary'], dtype=object)

Now we multi-hot encode the genres (one 0/1 entry per genre, with several 1s allowed per movie) so that the model can consume them.

In [29]:
# Encode genre labels to wide arrays
df_genres = (
    df_genres.query("Genre in @top_genres")
    .assign(cnt=1)
    .pivot_table(index=["movieID"], columns="Genre", values=["cnt"])
    .fillna(0)
    # Keep floats (no .astype(int)): the multi-label loss expects float targets
    .reset_index(col_level=1)  # get movieID out
)

df_genres.columns = [x[1] for x in df_genres.columns]
df_genres = df_genres.set_index("movieID")

genre_names = df_genres.columns.tolist()
labels = df_genres.values.tolist()
df_genres = pd.DataFrame({"movieID": df_genres.index, "labels": labels})
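
As a rough illustration of the resulting label format, the following sketch encodes three made-up movies. The `toy` frame and the `toy_`-prefixed names are assumptions; the pivot mirrors the one above but passes `values` as a plain string, so the columns stay single-level.

```python
import pandas as pd

# Hypothetical long-format genre table: one row per (movie, genre) pair
toy = pd.DataFrame(
    {"movieID": [0, 0, 1, 2], "Genre": ["comedy", "drama", "drama", "action"]}
)

wide = (
    toy.assign(cnt=1.0)
    .pivot_table(index="movieID", columns="Genre", values="cnt")
    .fillna(0.0)
)

toy_genre_names = wide.columns.tolist()  # column order defines label positions
toy_labels = wide.values.tolist()        # one multi-hot vector per movie
print(toy_genre_names)
print(toy_labels)
```

Movie 0 has two active positions (comedy and drama), which is what makes this a multi-label rather than a multi-class target.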

Our dataset is almost ready to go! Let's take a quick look:

In [30]:
df = (
    df.reset_index()
    .rename(columns={"index": "movieID"})
    .filter(["movieID", "Plot"])
    .merge(df_genres, on="movieID")
    .reset_index(drop=True)
)

print(f"Number of samples left: {df.shape[0]}")
df.sample(10)

Number of samples left: 27504


Unnamed: 0,movieID,Plot,labels
13182,14267,While the rest of his high school graduating c...,"[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
20261,22263,Dimanche tells the story of a young boy who go...,"[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
1695,1778,Hester Prynne has a child out of wedlock and r...,"[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
7397,8071,"The film briefly depicts Chappaqua, New York, ...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
16307,17589,Working as a prostitute on the weekend train t...,"[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
21924,24451,The film is a light-hearted comedy with Meena ...,"[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
24839,28796,Settled at the foothills of a fort is a quaint...,"[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
25010,30711,"The movie is about Velan (Prashanth), a respon...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
7419,8093,The story follows the fate of four Formula One...,"[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
18902,20260,"A mobster named Roxy Robinson is ""splurged"" by...","[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."


## Construct data loader and train-test split

In [31]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Construct dataset
dat = Dataset.from_pandas(df)

# Tokenize the Plot column
dat = dat.map(
    lambda batch: tokenizer(batch["Plot"], padding="max_length", truncation=True),
    batched=True,
    remove_columns=["movieID"],
)

# Retrieve tensors of the following columns as model inputs
valid_cols = ["input_ids", "token_type_ids", "attention_mask", "labels"]
cols = [c for c in dat.column_names if c in valid_cols]
dat.set_format(type="torch", columns=cols)

# Train/validation/test split
dat = dat.train_test_split(test_size=test_ratio, seed=42)
dat_train = dat["train"].train_test_split(test_size=val_ratio, seed=42)
dat["train"] = dat_train["train"]
dat["validation"] = dat_train["test"]
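
The two nested splits leave roughly 81% of the samples for training. The arithmetic below is a sketch that assumes each held-out size is rounded up. Under that assumption it reproduces the 22277 training and 2476 validation examples reported in the training log.

```python
import math

n_samples = 27504  # "Number of samples left" printed above
test_ratio = 0.1
val_ratio = 0.1

# Assumption: each held-out part is rounded up to a whole sample
n_test = math.ceil(n_samples * test_ratio)
n_val = math.ceil((n_samples - n_test) * val_ratio)
n_train = n_samples - n_test - n_val

print(n_train, n_val, n_test)
```

Note that `val_ratio` is applied to what remains after the test split, so the validation set is slightly smaller than the test set.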

# Fine-tune model

In [32]:
# Modify last layer of model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, problem_type="multi_label_classification", num_labels=num_genres
)
model.to("cuda")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
       

In [33]:
training_args = TrainingArguments(
    output_dir="distilbert_multilabel",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=2 * batch_size,
    learning_rate=1e-5,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=1,  # make sure validation loss is logged in each epoch
    seed=42,
)

In [34]:
# Login to Weights & Biases
import wandb
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()

wandb_api = user_secrets.get_secret("WANDB_API") 
wandb.login(key=wandb_api)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


wandb: W&B API key is configured (use `wandb login --relogin` to force relogin)
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [35]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dat["train"],
    eval_dataset=dat["validation"],
)
trainer.train()

The following columns in the training set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: Plot.
***** Running training *****
  Num examples = 22277
  Num Epochs = 5
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 6965
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
wandb: Currently logged in as: y1zhou (use `wandb login --relogin` to force relogin)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Epoch,Training Loss,Validation Loss
1,0.1973,0.230971
2,0.1191,0.214572
3,0.2744,0.210393
4,0.1801,0.211025
5,0.0982,0.211849


The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: Plot.
***** Running Evaluation *****
  Num examples = 2476
  Batch size = 32
Saving model checkpoint to distilbert_multilabel/checkpoint-1393
Configuration saved in distilbert_multilabel/checkpoint-1393/config.json
Model weights saved in distilbert_multilabel/checkpoint-1393/pytorch_model.bin
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: Plot.
***** Running Evaluation *****
  Num examples = 2476
  Batch size = 32
Saving model checkpoint to distilbert_multilabel/checkpoint-2786
Configuration saved in distilbert_multilabel/checkpoint-2786/config.json
Model weights saved in distilbert_multilabel/checkpoint-2786/pytorch_model.bin
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequence

TrainOutput(global_step=6965, training_loss=0.21355451704547804, metrics={'train_runtime': 3396.5097, 'train_samples_per_second': 32.794, 'train_steps_per_second': 2.051, 'total_flos': 1.4756986258176e+16, 'train_loss': 0.21355451704547804, 'epoch': 5.0})
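
The step counts in the log follow directly from the dataset size and batch size. The sketch below assumes a single device and no gradient accumulation, as the log indicates.

```python
import math

n_train = 22277   # "Num examples" in the training log
batch_size = 16   # per-device train batch size
num_epochs = 5

# The last, partially filled batch of each epoch still counts as a step
steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * num_epochs

print(steps_per_epoch, total_steps)
```

This matches the checkpoint names (`checkpoint-1393`, `checkpoint-2786`) saved at the end of each epoch and the 6965 total optimization steps.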

# Metrics on the test set

In [36]:
model.eval()

dl = DataLoader(dat["test"], batch_size=8)

proba_labels = []
sigmoid = nn.Sigmoid()
for batch in dl:
    batch = {k: v.to("cuda") for k, v in batch.items()}
    logits = model(**batch).get("logits")

    # Make sure there's at least one predicted label
    # by setting the logit of the maximum to 10
    max_proba = logits.argmax(axis=1)
    for i in range(logits.shape[0]):
        logits[i, max_proba[i]] = 10.0

    y_pred = (sigmoid(logits) > 0.5).cpu().detach().numpy()
    proba_labels.append(y_pred)

proba_labels = np.vstack(proba_labels)
y_labels = dat["test"]["labels"].bool().cpu().detach().numpy()
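
The "at least one label" trick can be checked on made-up logits. This sketch uses NumPy instead of PyTorch and replaces the Python loop with fancy indexing; the `logits` values are assumptions.

```python
import numpy as np

# Hypothetical logits for two samples and three genres
logits = np.array([[-2.0, -0.5, -1.0], [1.5, -3.0, 0.2]])

forced = logits.copy()
top = forced.argmax(axis=1)                      # most likely genre per sample
forced[np.arange(forced.shape[0]), top] = 10.0   # sigmoid(10) is close to 1

probs = 1.0 / (1.0 + np.exp(-forced))            # element-wise sigmoid
pred = probs > 0.5
print(pred)
```

In the first sample every logit is negative, so without the override no genre would pass the 0.5 threshold. The second sample is unaffected apart from its top genre, which was already above the threshold.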

In [37]:
from sklearn.metrics import classification_report

print(classification_report(y_labels, proba_labels, target_names=genre_names))

              precision    recall  f1-score   support

      action       0.65      0.61      0.63       514
      comedy       0.74      0.57      0.65       740
       crime       0.53      0.32      0.40       154
 documentary       0.65      0.18      0.28        74
       drama       0.61      0.67      0.64       969
      family       0.76      0.42      0.54       257
      horror       0.68      0.70      0.69       172
     romance       0.51      0.38      0.44       201
    thriller       0.52      0.35      0.42       228
     western       0.83      0.87      0.85        97

   micro avg       0.65      0.57      0.60      3406
   macro avg       0.65      0.51      0.55      3406
weighted avg       0.65      0.57      0.60      3406
 samples avg       0.66      0.60      0.61      3406



In [38]:
# Per-sample overlap between the true and predicted label sets
intersection = np.logical_and(y_labels, proba_labels).sum(axis=1)
union = np.logical_or(y_labels, proba_labels).sum(axis=1)

# Sample-averaged intersection over union, reported here as Hamming accuracy
hamming_score = np.nansum(intersection / union) / y_labels.shape[0]
# Precision divides by the number of predicted labels, recall by the number of true labels
precision = np.nansum(intersection / proba_labels.sum(axis=1)) / y_labels.shape[0]
recall = np.nansum(intersection / y_labels.sum(axis=1)) / y_labels.shape[0]

print(
    f"""
    Hamming accuracy: {hamming_score}
    Precision: {precision}
    Recall: {recall}
"""
)


    Hamming accuracy: 0.5801647885617351
    Precision: 0.6574578940991155
    Recall: 0.602447594814007
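
For reference, the sample-averaged metrics can be worked through by hand on two made-up samples. In this sketch, precision divides the overlap by the number of predicted labels and recall divides it by the number of true labels; the arrays are assumptions.

```python
import numpy as np

# Hypothetical labels: both samples have two true genres, one predicted
y_true = np.array([[1, 0, 1], [1, 1, 0]], dtype=bool)
y_pred = np.array([[1, 0, 0], [1, 0, 0]], dtype=bool)

intersection = np.logical_and(y_true, y_pred).sum(axis=1)  # correct labels
union = np.logical_or(y_true, y_pred).sum(axis=1)          # true or predicted

jaccard = np.mean(intersection / union)
sample_precision = np.mean(intersection / y_pred.sum(axis=1))
sample_recall = np.mean(intersection / y_true.sum(axis=1))

print(jaccard, sample_precision, sample_recall)
```

Every predicted label is correct, so precision is 1.0, but only half of the true labels are found, so recall and the intersection-over-union score are both 0.5.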



These are not great metrics (the micro-averaged F1 is 0.60, and recall drops to 0.18 for documentary), but not bad for a first version!

In [39]:
# Get genre names and plot text
true_labels = [
    [genre_names[x] for x in np.argwhere(arr == 1).flatten()] for arr in y_labels
]
plots = dat["test"]["Plot"]
pred_labels = [
    [genre_names[x] for x in np.argwhere(arr == 1).flatten()] for arr in proba_labels
]

# Look at some predictions
for i in range(20):
    print(
        f"""
    True label: {true_labels[i]}
    Predicted label: {pred_labels[i]}
    Plot: {plots[i]}
    """
    )


    True label: ['action']
    Predicted label: ['action']
    Plot: A British naval officer volunteers for a dangerous mission to infiltrate the base of pirates who threaten shipping off Madagascar.
    

    True label: ['comedy']
    Predicted label: ['comedy']
    Plot: The film starts off with Calvin "Babyface" Simms (Marlon Wayans) who is a very short convict. He is seen getting released and planning a robbery to steal a diamond with the help of his goofball cohort Percy (Tracy Morgan). After the successful robbery, the duo are almost arrested, but not before Calvin manages to stash the diamond in a nearby woman's purse. The thieves follow the handbag's owner to her home where they discover a couple, Darryl (Shawn Wayans) and Vanessa Edwards (Kerry Washington), who are eager to have a child.Calvin and Percy hatch a plot to pass Calvin off as a baby left on the couple's doorstep. After seeing Calvin, Darryl and Vanessa, wanting a child, immediately adopt the baby as their own. Ho