In [None]:
import torch
import torchmetrics
import pandas as pd
import numpy as np

from transformers import AutoTokenizer, DistilBertForSequenceClassification
from tqdm import tqdm
from ast import literal_eval
from torchmetrics.classification import MultilabelF1Score

In [None]:
df1 = pd.read_csv('movie_5000/tmdb_5000_movies.csv')
df2 = pd.read_csv('movie_45000/movies_metadata.csv')
df3 = pd.read_csv('data.csv')

In [None]:
df1['genres'] = df1['genres'].fillna('[]').apply(literal_eval).apply(lambda x: [i['name'] for i in x] if isinstance(x, list) else [])
df2['genres'] = df2['genres'].fillna('[]').apply(literal_eval).apply(lambda x: [i['name'] for i in x] if isinstance(x, list) else [])
df2 = df2[df2.genres.apply(lambda x: len(x) > 0)]
df3['context'] = df3['context'].apply(lambda x: x.replace('<pad>',''))
df3['context'] = df3['context'].apply(lambda x: x.replace('</s>',''))
df2.reset_index(inplace = True)
df2 = df2.drop(columns='index')

In [None]:
#change the position of the columns 'label' and 'context' in df3 so that the result is text -> context -> label
cols = list(df3.columns)
cols = [cols[0]] + [cols[-1]] + cols[1:-1]
df3 = df3[cols]
df3['label'] = df3['label'].apply(literal_eval)
df3


In [None]:
df1 = df1.drop(columns='id')
df2 = df2.drop(columns='id')

In [None]:
# Check if all df1 movies are in df2
df1['title'] = df1['title'].str.lower()
df2['title'] = df2['title'].str.lower()

# Make every movie in df1 unique by title
df1 = df1.drop_duplicates(subset=['title'], keep='first')
df1.reset_index(inplace = True)
df1 = df1.drop(columns='index')

# Make every movie in df2 unique by title
df2 = df2.drop_duplicates(subset=['title'], keep='first')
df2.reset_index(inplace = True)
df2 = df2.drop(columns='index')

len(df1), len(df2)

In [None]:
df = pd.concat([df1, df2], ignore_index=True)
df = df.drop_duplicates(subset=['title'], keep='first')
df.reset_index(inplace = True)
df.drop(columns='index', inplace=True)

In [None]:
# Remove movies in df that has no overview
final_df = df.copy()
final_df = final_df[final_df['overview'].notnull()]
final_df.reset_index(inplace = True)
final_df = final_df.drop(columns='index')
final_df = final_df[final_df['genres'].apply(lambda x: len(x) > 0)]
final_df.reset_index(inplace = True)
final_df = final_df.drop(columns='index')

In [None]:
genres = ["Crime", "Thriller", "Fantasy", "Horror", "Sci-Fi", "Comedy", "Documentary", "Adventure", "Film-Noir", "Animation", "Romance", "Drama", "Western", "Musical", "Action", "Mystery", "War", "Children\'s"]
mapping = {}
for i in range(len(genres)):
    mapping[i] = genres[i]
mapping

for i in range(len(final_df)):
    gens = final_df['genres'][i]
    for g in gens:
        if g == 'Science Fiction':
            final_df['genres'][i].remove(g)
            final_df['genres'][i].append('Sci-Fi')
        else:
            if g not in genres:
                final_df['genres'][i].remove(g)

final_df = final_df[final_df['genres'].apply(lambda x: len(x) > 0)]
final_df.reset_index(inplace = True)
final_df = final_df.drop(columns='index')


In [38]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", problem_type="multi_label_classification", num_labels=18)
model.config.id2label = mapping

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'pre_classifier.weight', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [39]:
def preprocess(df, genres = genres) -> pd.DataFrame:
    df['label'] = df.genres.apply(lambda x: [1 if genre in x else 0 for genre in genres])
    df.drop(columns=['genres'], inplace=True)
    df = df.reset_index(drop=True)
    return df

In [40]:
final_df = preprocess(final_df)

In [41]:
#merge df3 and final_df
final_df = final_df.rename(columns={'title':'text'})
final_df = final_df.rename(columns={'overview':'context'})
final_df

Unnamed: 0,text,context,label
0,avatar,"In the 22nd century, a paraplegic Marine is di...","[0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, ..."
1,pirates of the caribbean: at world's end,"Captain Barbossa, long believed to be dead, ha...","[0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, ..."
2,spectre,A cryptic message from Bond’s past sends him o...,"[1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, ..."
3,the dark knight rises,Following the death of District Attorney Harve...,"[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, ..."
4,john carter,"John Carter is a war-weary, former military ca...","[0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, ..."
...,...,...,...
39274,shadow of the blair witch,"In this true-crime documentary, we delve into ...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
39275,the burkittsville 7,A film archivist revisits the story of Rustin ...,"[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
39276,caged heat 3000,It's the year 3000 AD. The world's most danger...,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
39277,subdue,Rising and falling between a man and woman.,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."


In [42]:
#add df3 to final_df
final_df = pd.concat([final_df, df3], ignore_index=True)
final_df = final_df.reset_index(drop=True)
final_df

Unnamed: 0,text,context,label
0,avatar,"In the 22nd century, a paraplegic Marine is di...","[0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, ..."
1,pirates of the caribbean: at world's end,"Captain Barbossa, long believed to be dead, ha...","[0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, ..."
2,spectre,A cryptic message from Bond’s past sends him o...,"[1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, ..."
3,the dark knight rises,Following the death of District Attorney Harve...,"[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, ..."
4,john carter,"John Carter is a war-weary, former military ca...","[0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, ..."
...,...,...,...
67079,Kein Bund für's Leben,Kein Bund für's Leben is a German film about ...,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
67080,"Feuer, Eis & Dosenbier","Feuer, Eis & Dosenbier is a movie about a gro...","[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
67081,The Pirates,The Pirates is a movie about a group of pirat...,"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ..."
67082,Rentun Ruusu,Rentun Ruusu is a movie about a young woman n...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [43]:
# make final_df a .csv file, split using | and remove the index
# final_df.to_csv('data_final.csv', sep='|', index=False)

In [44]:
#split final_df into train, val and test with the portion 0.8, 0.1, 0.1 randomly
trainset = final_df
validset = final_df.sample(frac=0.1, random_state=42)
testset = final_df.sample(frac=0.1, random_state=32)

In [45]:
class Poroset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer, max_len=256):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        context = row.context
        label = row.label

        if len(context) > self.max_len:
            context = context[:self.max_len]
        
        encoding = self.tokenizer.encode_plus(
            context,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=True,
            padding='max_length',
            pad_to_max_length=True,
            return_attention_mask=True,
            truncation=True,
            return_tensors='pt'
        )
        return {
            'context': context,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

In [46]:
trainset = Poroset(trainset, tokenizer)
validset = Poroset(validset, tokenizer)
testset = Poroset(testset, tokenizer)

In [47]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
validloader = torch.utils.data.DataLoader(validset, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True)

In [48]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
device

device(type='cuda')

In [49]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)

In [50]:
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

In [51]:
def train(epoch):
    model.train()
    train_loss = 0
    f1 = MultilabelF1Score(num_labels=18, threshold=0.5, average='macro')
    f1.to(device)
    for _, data in tqdm(enumerate(trainloader, 0), total=len(trainloader)):
        ids = data['input_ids'].to(device, dtype=torch.long)
        mask = data['attention_mask'].to(device, dtype=torch.long)
        targets = data['label'].to(device, dtype=torch.float)

        outputs = model(ids, mask).logits
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        f1.update(outputs, targets)

    print(f'Epoch: {epoch}, Train Loss: {train_loss / len(trainloader)}, Train F1: {f1.compute()}')

In [52]:
def valid(epoch):
    # Print valid loss, f1 Macro and map@k
    model.eval()
    valid_loss = 0
    f1 = MultilabelF1Score(num_labels=18, threshold=0.5, average='macro')
    f1.to(device)
    with torch.no_grad():
        for _, data in tqdm(enumerate(validloader, 0), total=len(validloader)):
            ids = data['input_ids'].to(device, dtype=torch.long)
            mask = data['attention_mask'].to(device, dtype=torch.long)
            targets = data['label'].to(device, dtype=torch.float)

            outputs = model(ids, mask).logits
            loss = loss_fn(outputs, targets)
            valid_loss += loss.item()

            f1.update(outputs, targets)

    print(f'Epoch: {epoch}, Valid loss: {valid_loss/len(validloader)}, Valid F1: {f1.compute()}')

In [53]:
for epoch in range(32):
    train(epoch)
    valid(epoch)

100%|██████████| 2097/2097 [14:57<00:00,  2.34it/s]


Epoch: 0, Train Loss: 0.2563466115201378, Train F1: 0.15443573892116547


100%|██████████| 210/210 [00:36<00:00,  5.71it/s]


Epoch: 0, Valid loss: 0.2094253305168379, Valid F1: 0.30158019065856934


100%|██████████| 2097/2097 [14:31<00:00,  2.41it/s]


Epoch: 1, Train Loss: 0.2102780885777703, Train F1: 0.3259690999984741


100%|██████████| 210/210 [00:31<00:00,  6.62it/s]


Epoch: 1, Valid loss: 0.19114060224521728, Valid F1: 0.38837894797325134


100%|██████████| 2097/2097 [14:37<00:00,  2.39it/s]


Epoch: 2, Train Loss: 0.19602112207716285, Train F1: 0.3884584307670593


100%|██████████| 210/210 [00:31<00:00,  6.62it/s]


Epoch: 2, Valid loss: 0.1780037684809594, Valid F1: 0.42632678151130676


100%|██████████| 2097/2097 [14:27<00:00,  2.42it/s]


Epoch: 3, Train Loss: 0.18392981151114432, Train F1: 0.43035808205604553


100%|██████████| 210/210 [00:31<00:00,  6.62it/s]


Epoch: 3, Valid loss: 0.16324013421932856, Valid F1: 0.46685314178466797


100%|██████████| 2097/2097 [14:27<00:00,  2.42it/s]


Epoch: 4, Train Loss: 0.17187120179207369, Train F1: 0.4704265892505646


100%|██████████| 210/210 [00:31<00:00,  6.62it/s]


Epoch: 4, Valid loss: 0.14815063895214173, Valid F1: 0.5322701930999756


100%|██████████| 2097/2097 [14:27<00:00,  2.42it/s]


Epoch: 5, Train Loss: 0.1600008910547624, Train F1: 0.5077268481254578


100%|██████████| 210/210 [00:31<00:00,  6.61it/s]


Epoch: 5, Valid loss: 0.13412483902204606, Valid F1: 0.5689296126365662


100%|██████████| 2097/2097 [14:27<00:00,  2.42it/s]


Epoch: 6, Train Loss: 0.14868086472359737, Train F1: 0.5430753827095032


100%|██████████| 210/210 [00:32<00:00,  6.56it/s]


Epoch: 6, Valid loss: 0.1214335526738848, Valid F1: 0.6095781326293945


100%|██████████| 2097/2097 [14:27<00:00,  2.42it/s]


Epoch: 7, Train Loss: 0.13759269409013125, Train F1: 0.5738568305969238


100%|██████████| 210/210 [00:31<00:00,  6.62it/s]


Epoch: 7, Valid loss: 0.10795478756938662, Valid F1: 0.6578888893127441


100%|██████████| 2097/2097 [14:27<00:00,  2.42it/s]


Epoch: 8, Train Loss: 0.12686712657220034, Train F1: 0.6048570871353149


100%|██████████| 210/210 [00:31<00:00,  6.62it/s]


Epoch: 8, Valid loss: 0.09508114073957716, Valid F1: 0.6890876293182373


100%|██████████| 2097/2097 [14:27<00:00,  2.42it/s]


Epoch: 9, Train Loss: 0.11658952702645126, Train F1: 0.6324417591094971


100%|██████████| 210/210 [00:31<00:00,  6.62it/s]


Epoch: 9, Valid loss: 0.08469299464708283, Valid F1: 0.7225284576416016


100%|██████████| 2097/2097 [14:39<00:00,  2.38it/s]


Epoch: 10, Train Loss: 0.10676599797936184, Train F1: 0.6601005792617798


100%|██████████| 210/210 [00:31<00:00,  6.65it/s]


Epoch: 10, Valid loss: 0.07487770194808642, Valid F1: 0.7449080944061279


100%|██████████| 2097/2097 [14:24<00:00,  2.42it/s]


Epoch: 11, Train Loss: 0.09779321892088404, Train F1: 0.6836209297180176


100%|██████████| 210/210 [00:31<00:00,  6.65it/s]


Epoch: 11, Valid loss: 0.06666479389227573, Valid F1: 0.779660701751709


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 12, Train Loss: 0.08919428060658216, Train F1: 0.7143036723136902


100%|██████████| 210/210 [00:31<00:00,  6.65it/s]


Epoch: 12, Valid loss: 0.0563022683330235, Valid F1: 0.8112510442733765


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 13, Train Loss: 0.0811456969246787, Train F1: 0.740847647190094


100%|██████████| 210/210 [00:31<00:00,  6.64it/s]


Epoch: 13, Valid loss: 0.04888497868641501, Valid F1: 0.8350722193717957


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 14, Train Loss: 0.07368104701837756, Train F1: 0.7592310905456543


100%|██████████| 210/210 [00:31<00:00,  6.65it/s]


Epoch: 14, Valid loss: 0.04180424496354092, Valid F1: 0.8563195466995239


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 15, Train Loss: 0.06652588295383031, Train F1: 0.781432569026947


100%|██████████| 210/210 [00:31<00:00,  6.64it/s]


Epoch: 15, Valid loss: 0.03476818627899601, Valid F1: 0.8979200124740601


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 16, Train Loss: 0.06057885615532604, Train F1: 0.8038603663444519


100%|██████████| 210/210 [00:31<00:00,  6.66it/s]


Epoch: 16, Valid loss: 0.029814790854496617, Valid F1: 0.9006966948509216


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 17, Train Loss: 0.05444948567618457, Train F1: 0.8209039568901062


100%|██████████| 210/210 [00:31<00:00,  6.64it/s]


Epoch: 17, Valid loss: 0.02549927400070287, Valid F1: 0.9294713139533997


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 18, Train Loss: 0.04943474044541006, Train F1: 0.838260293006897


100%|██████████| 210/210 [00:31<00:00,  6.66it/s]


Epoch: 18, Valid loss: 0.021405992506160623, Valid F1: 0.9411076307296753


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 19, Train Loss: 0.04452021439898622, Train F1: 0.8522422313690186


100%|██████████| 210/210 [00:31<00:00,  6.65it/s]


Epoch: 19, Valid loss: 0.01863916944490657, Valid F1: 0.9516871571540833


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 20, Train Loss: 0.04030283938686536, Train F1: 0.8689140677452087


100%|██████████| 210/210 [00:31<00:00,  6.66it/s]


Epoch: 20, Valid loss: 0.015511514969347488, Valid F1: 0.94398432970047


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 21, Train Loss: 0.0364469126320782, Train F1: 0.8838423490524292


100%|██████████| 210/210 [00:31<00:00,  6.66it/s]


Epoch: 21, Valid loss: 0.012639388535171747, Valid F1: 0.9634753465652466


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 22, Train Loss: 0.03372361266281159, Train F1: 0.8940054178237915


100%|██████████| 210/210 [00:31<00:00,  6.65it/s]


Epoch: 22, Valid loss: 0.011663989271480768, Valid F1: 0.9672171473503113


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 23, Train Loss: 0.030508222871157575, Train F1: 0.902606725692749


100%|██████████| 210/210 [00:31<00:00,  6.67it/s]


Epoch: 23, Valid loss: 0.009825690749234386, Valid F1: 0.9683904647827148


100%|██████████| 2097/2097 [15:07<00:00,  2.31it/s]


Epoch: 24, Train Loss: 0.027805716100223914, Train F1: 0.9121745824813843


100%|██████████| 210/210 [00:31<00:00,  6.67it/s]


Epoch: 24, Valid loss: 0.008876392127768624, Valid F1: 0.9728063344955444


100%|██████████| 2097/2097 [14:22<00:00,  2.43it/s]


Epoch: 25, Train Loss: 0.025478802847422388, Train F1: 0.9186450242996216


100%|██████████| 210/210 [00:31<00:00,  6.65it/s]


Epoch: 25, Valid loss: 0.008202799710090317, Valid F1: 0.9777411222457886


100%|██████████| 2097/2097 [14:22<00:00,  2.43it/s]


Epoch: 26, Train Loss: 0.023642981972085277, Train F1: 0.9282799959182739


100%|██████████| 210/210 [00:31<00:00,  6.67it/s]


Epoch: 26, Valid loss: 0.006059200879341612, Valid F1: 0.9802449941635132


100%|██████████| 2097/2097 [14:22<00:00,  2.43it/s]


Epoch: 27, Train Loss: 0.02197228901389796, Train F1: 0.9309113621711731


100%|██████████| 210/210 [00:31<00:00,  6.67it/s]


Epoch: 27, Valid loss: 0.006096584703551517, Valid F1: 0.983128547668457


100%|██████████| 2097/2097 [14:24<00:00,  2.43it/s]


Epoch: 28, Train Loss: 0.020616667839695384, Train F1: 0.9357055425643921


100%|██████████| 210/210 [00:31<00:00,  6.65it/s]


Epoch: 28, Valid loss: 0.006126290519854852, Valid F1: 0.9841627478599548


100%|██████████| 2097/2097 [14:32<00:00,  2.40it/s]


Epoch: 29, Train Loss: 0.019314871164075195, Train F1: 0.9415281414985657


100%|██████████| 210/210 [00:31<00:00,  6.66it/s]


Epoch: 29, Valid loss: 0.004953842487330327, Valid F1: 0.9881401062011719


100%|██████████| 2097/2097 [14:27<00:00,  2.42it/s]


Epoch: 30, Train Loss: 0.01784731240114344, Train F1: 0.9454348683357239


100%|██████████| 210/210 [00:31<00:00,  6.66it/s]


Epoch: 30, Valid loss: 0.004524072896345474, Valid F1: 0.9883469343185425


100%|██████████| 2097/2097 [14:25<00:00,  2.42it/s]


Epoch: 31, Train Loss: 0.01664195477235911, Train F1: 0.9520139694213867


100%|██████████| 210/210 [00:31<00:00,  6.66it/s]

Epoch: 31, Valid loss: 0.004041251857831542, Valid F1: 0.9904213547706604





In [58]:
#validate the model with the testset
def test(testing_loader):
    model.eval()
    fin_targets = []
    fin_outputs = []

    with torch.no_grad():
        for _, data in tqdm(enumerate(testing_loader, 0), total=len(testing_loader)):
            input_ids = data['input_ids'].to(device)
            attention_mask = data['attention_mask'].to(device)
            labels = data['label'].to(device)

            outputs = model(input_ids, attention_mask)
            fin_targets.extend(labels.cpu().detach().numpy().tolist())
            fin_outputs.extend(torch.sigmoid(outputs.logits).cpu().detach().numpy().tolist())

    return fin_outputs, fin_targets

In [59]:
outputs, targets = test(testloader)

outputs = np.array(outputs) >= 0.5

100%|██████████| 210/210 [00:34<00:00,  6.04it/s]


In [60]:
# Multi-label F1 score, macro-averaged
f1 = MultilabelF1Score(num_labels=18, threshold=0.5, average='macro')
f1.update(torch.tensor(outputs), torch.tensor(targets))

f1.compute()

tensor(0.9887)

In [61]:
torch.save(model.state_dict(), 'model.pt')