In [1]:
import torch
import pandas as pd
import numpy as np

from torchmetrics.classification import MultilabelF1Score
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm

In [2]:
with open('ml1m/content/dataset/genres.txt', 'r') as f:
    genre_all = f.readlines()
genres = [genre.strip() for genre in genre_all]

mapping = {}
for genre, i in enumerate(genres):
    mapping[genre] = i

mapping

{0: 'Crime',
 1: 'Thriller',
 2: 'Fantasy',
 3: 'Horror',
 4: 'Sci-Fi',
 5: 'Comedy',
 6: 'Documentary',
 7: 'Adventure',
 8: 'Film-Noir',
 9: 'Animation',
 10: 'Romance',
 11: 'Drama',
 12: 'Western',
 13: 'Musical',
 14: 'Action',
 15: 'Mystery',
 16: 'War',
 17: "Children's"}

In [3]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", problem_type="multi_label_classification", num_labels=18)
model.config.id2label = mapping

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'pre_classifier.weight', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
trainset = pd.read_csv('data_final.csv', encoding='utf-8')
trainset = trainset.drop(columns=['text'])
trainset.rename(columns={'context':'text'}, inplace=True)
trainset = trainset.dropna()
trainset.label = trainset.label.apply(lambda x: eval(x))
testset = trainset.sample(frac=0.1, random_state=42)
validset = trainset.drop(testset.index)

In [5]:
len(trainset), len(testset), len(validset)

(67084, 6708, 60376)

# Hard Code

In [6]:
class Poroset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer, max_len=256):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        text = row.text
        label = row.label

        if len(text) > self.max_len:
            text = text[:self.max_len]
        
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=True,
            padding='max_length',
            pad_to_max_length=True,
            return_attention_mask=True,
            truncation=True,
            return_tensors='pt'
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

In [7]:
trainset = Poroset(trainset, tokenizer)
testset = Poroset(testset, tokenizer)
validset = Poroset(validset, tokenizer)

In [8]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True)
validloader = torch.utils.data.DataLoader(validset, batch_size=32, shuffle=True)

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [10]:
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [11]:
def train(epoch):
    model.train()
    train_loss = 0
    f1 = MultilabelF1Score(num_labels=18, threshold=0.5, average='macro')
    f1.to(device)
    for _, data in tqdm(enumerate(trainloader, 0), total=len(trainloader)):
        ids = data['input_ids'].to(device, dtype=torch.long)
        mask = data['attention_mask'].to(device, dtype=torch.long)
        targets = data['label'].to(device, dtype=torch.float)

        outputs = model(ids, mask).logits
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        f1.update(outputs, targets)

    print(f'Epoch: {epoch}, Train Loss: {train_loss / len(trainloader)}, Train F1: {f1.compute()}')

In [12]:
def valid(epoch):
    # Print valid loss, f1 Macro and map@k
    model.eval()
    valid_loss = 0
    f1 = MultilabelF1Score(num_labels=18, threshold=0.5, average='macro')
    f1.to(device)
    with torch.no_grad():
        for _, data in tqdm(enumerate(validloader, 0), total=len(validloader)):
            ids = data['input_ids'].to(device, dtype=torch.long)
            mask = data['attention_mask'].to(device, dtype=torch.long)
            targets = data['label'].to(device, dtype=torch.float)

            outputs = model(ids, mask).logits
            loss = loss_fn(outputs, targets)
            valid_loss += loss.item()

            f1.update(outputs, targets)

    print(f'Epoch: {epoch}, Valid loss: {valid_loss/len(validloader)}, Valid F1: {f1.compute()}')

In [13]:
epochs = 32
for epoch in range(epochs):
    train(epoch)
    valid(epoch)

  0%|          | 0/2097 [00:00<?, ?it/s]

100%|██████████| 2097/2097 [11:10<00:00,  3.13it/s]


Epoch: 0, Train Loss: 0.23235662458451178, Train F1: 0.2525077760219574


100%|██████████| 1887/1887 [03:46<00:00,  8.34it/s]


Epoch: 0, Valid loss: 0.1921309399649019, Valid F1: 0.38867348432540894


100%|██████████| 2097/2097 [11:16<00:00,  3.10it/s]


Epoch: 1, Train Loss: 0.1917233191533605, Train F1: 0.41021084785461426


100%|██████████| 1887/1887 [03:47<00:00,  8.31it/s]


Epoch: 1, Valid loss: 0.1545693561332088, Valid F1: 0.5179757475852966


100%|██████████| 2097/2097 [11:18<00:00,  3.09it/s]


Epoch: 2, Train Loss: 0.15971508895748276, Train F1: 0.5161657333374023


100%|██████████| 1887/1887 [03:39<00:00,  8.60it/s]


Epoch: 2, Valid loss: 0.12553278491824912, Valid F1: 0.6190528869628906


100%|██████████| 2097/2097 [10:55<00:00,  3.20it/s]


Epoch: 3, Train Loss: 0.12867488111264944, Train F1: 0.6063541173934937


100%|██████████| 1887/1887 [03:39<00:00,  8.61it/s]


Epoch: 3, Valid loss: 0.09400106864430062, Valid F1: 0.7013272643089294


100%|██████████| 2097/2097 [10:55<00:00,  3.20it/s]


Epoch: 4, Train Loss: 0.10334676401454264, Train F1: 0.678597092628479


100%|██████████| 1887/1887 [03:39<00:00,  8.61it/s]


Epoch: 4, Valid loss: 0.06905812540925174, Valid F1: 0.7746954560279846


100%|██████████| 2097/2097 [14:25<00:00,  2.42it/s]


Epoch: 5, Train Loss: 0.08180252629615149, Train F1: 0.7475529313087463


100%|██████████| 1887/1887 [05:08<00:00,  6.12it/s]


Epoch: 5, Valid loss: 0.05085406591103989, Valid F1: 0.8578332662582397


100%|██████████| 2097/2097 [14:51<00:00,  2.35it/s]


Epoch: 6, Train Loss: 0.065465156656002, Train F1: 0.8015155792236328


100%|██████████| 1887/1887 [05:08<00:00,  6.11it/s]


Epoch: 6, Valid loss: 0.03982186100135247, Valid F1: 0.8732893466949463


100%|██████████| 2097/2097 [15:32<00:00,  2.25it/s]


Epoch: 7, Train Loss: 0.05265480582724456, Train F1: 0.8390929698944092


100%|██████████| 1887/1887 [04:44<00:00,  6.63it/s]


Epoch: 7, Valid loss: 0.029378511786222, Valid F1: 0.9152898788452148


100%|██████████| 2097/2097 [15:29<00:00,  2.26it/s]


Epoch: 8, Train Loss: 0.04374722349042189, Train F1: 0.8708322048187256


100%|██████████| 1887/1887 [05:24<00:00,  5.82it/s]


Epoch: 8, Valid loss: 0.023593734013815926, Valid F1: 0.930281937122345


100%|██████████| 2097/2097 [15:11<00:00,  2.30it/s]


Epoch: 9, Train Loss: 0.03749557166117384, Train F1: 0.8912842273712158


100%|██████████| 1887/1887 [04:48<00:00,  6.54it/s]


Epoch: 9, Valid loss: 0.021194223130924335, Valid F1: 0.9423518180847168


100%|██████████| 2097/2097 [16:39<00:00,  2.10it/s]


Epoch: 10, Train Loss: 0.03284143318538054, Train F1: 0.9034895300865173


100%|██████████| 1887/1887 [05:07<00:00,  6.14it/s]


Epoch: 10, Valid loss: 0.015880267300766908, Valid F1: 0.9559046626091003


100%|██████████| 2097/2097 [15:47<00:00,  2.21it/s]


Epoch: 11, Train Loss: 0.029054239759568976, Train F1: 0.9167182445526123


100%|██████████| 1887/1887 [05:09<00:00,  6.09it/s]


Epoch: 11, Valid loss: 0.014495264251174027, Valid F1: 0.9625974893569946


100%|██████████| 2097/2097 [14:45<00:00,  2.37it/s]


Epoch: 12, Train Loss: 0.026640505161319366, Train F1: 0.9233362674713135


100%|██████████| 1887/1887 [04:54<00:00,  6.41it/s]


Epoch: 12, Valid loss: 0.012614811458141223, Valid F1: 0.9660592079162598


100%|██████████| 2097/2097 [12:49<00:00,  2.72it/s]


Epoch: 13, Train Loss: 0.023678713343024047, Train F1: 0.9336626529693604


100%|██████████| 1887/1887 [04:02<00:00,  7.79it/s]


Epoch: 13, Valid loss: 0.01148875238168807, Valid F1: 0.9692578911781311


100%|██████████| 2097/2097 [11:39<00:00,  3.00it/s]


Epoch: 14, Train Loss: 0.022274192229768685, Train F1: 0.9390428066253662


100%|██████████| 1887/1887 [03:56<00:00,  8.00it/s]


Epoch: 14, Valid loss: 0.009352909098391652, Valid F1: 0.9702039957046509


100%|██████████| 2097/2097 [11:30<00:00,  3.04it/s]


Epoch: 15, Train Loss: 0.020494203815184724, Train F1: 0.9422042369842529


100%|██████████| 1887/1887 [03:53<00:00,  8.09it/s]


Epoch: 15, Valid loss: 0.009571895369962273, Valid F1: 0.9723016619682312


100%|██████████| 2097/2097 [11:31<00:00,  3.03it/s]


Epoch: 16, Train Loss: 0.01939460986629608, Train F1: 0.9475372433662415


100%|██████████| 1887/1887 [03:49<00:00,  8.22it/s]


Epoch: 16, Valid loss: 0.008917410031439036, Valid F1: 0.97434401512146


100%|██████████| 2097/2097 [11:21<00:00,  3.08it/s]


Epoch: 17, Train Loss: 0.01811545464617916, Train F1: 0.9514101147651672


100%|██████████| 1887/1887 [03:49<00:00,  8.23it/s]


Epoch: 17, Valid loss: 0.008022324647692737, Valid F1: 0.981061577796936


100%|██████████| 2097/2097 [11:32<00:00,  3.03it/s]


Epoch: 18, Train Loss: 0.01708767093890371, Train F1: 0.9538013935089111


100%|██████████| 1887/1887 [03:56<00:00,  7.97it/s]


Epoch: 18, Valid loss: 0.007964441176367184, Valid F1: 0.9776462316513062


100%|██████████| 2097/2097 [11:55<00:00,  2.93it/s]


Epoch: 19, Train Loss: 0.016411068038688505, Train F1: 0.9551047086715698


100%|██████████| 1887/1887 [04:00<00:00,  7.85it/s]


Epoch: 19, Valid loss: 0.0069602642631368975, Valid F1: 0.9824789762496948


100%|██████████| 2097/2097 [11:55<00:00,  2.93it/s]


Epoch: 20, Train Loss: 0.015321717814997133, Train F1: 0.9581196308135986


100%|██████████| 1887/1887 [04:00<00:00,  7.85it/s]


Epoch: 20, Valid loss: 0.006578724860832111, Valid F1: 0.9820476174354553


100%|██████████| 2097/2097 [11:40<00:00,  2.99it/s]


Epoch: 21, Train Loss: 0.01496704661538962, Train F1: 0.9593393802642822


100%|██████████| 1887/1887 [03:39<00:00,  8.59it/s]


Epoch: 21, Valid loss: 0.006545136926362437, Valid F1: 0.9816089868545532


100%|██████████| 2097/2097 [10:54<00:00,  3.20it/s]


Epoch: 22, Train Loss: 0.014390801711389452, Train F1: 0.9615979790687561


100%|██████████| 1887/1887 [03:39<00:00,  8.60it/s]


Epoch: 22, Valid loss: 0.006061858327147612, Valid F1: 0.9844773411750793


100%|██████████| 2097/2097 [11:36<00:00,  3.01it/s]


Epoch: 23, Train Loss: 0.013798245346907777, Train F1: 0.9639631509780884


100%|██████████| 1887/1887 [03:56<00:00,  7.97it/s]


Epoch: 23, Valid loss: 0.005432173435326526, Valid F1: 0.9854211807250977


100%|██████████| 2097/2097 [11:46<00:00,  2.97it/s]


Epoch: 24, Train Loss: 0.013029633012456347, Train F1: 0.9653505086898804


100%|██████████| 1887/1887 [03:57<00:00,  7.96it/s]


Epoch: 24, Valid loss: 0.006224769993675808, Valid F1: 0.9852772355079651


100%|██████████| 2097/2097 [11:43<00:00,  2.98it/s]


Epoch: 25, Train Loss: 0.012923402873048417, Train F1: 0.9641183614730835


100%|██████████| 1887/1887 [03:56<00:00,  7.97it/s]


Epoch: 25, Valid loss: 0.005926729398837563, Valid F1: 0.9850278496742249


100%|██████████| 2097/2097 [11:43<00:00,  2.98it/s]


Epoch: 26, Train Loss: 0.012284589288336556, Train F1: 0.9696133136749268


100%|██████████| 1887/1887 [03:49<00:00,  8.21it/s]


Epoch: 26, Valid loss: 0.005454725524647261, Valid F1: 0.9860429763793945


100%|██████████| 2097/2097 [11:28<00:00,  3.04it/s]


Epoch: 27, Train Loss: 0.011989611095453946, Train F1: 0.967768669128418


100%|██████████| 1887/1887 [03:54<00:00,  8.03it/s]


Epoch: 27, Valid loss: 0.005151425484622417, Valid F1: 0.9865025281906128


100%|██████████| 2097/2097 [11:46<00:00,  2.97it/s]


Epoch: 28, Train Loss: 0.011762591759965373, Train F1: 0.9698853492736816


100%|██████████| 1887/1887 [04:06<00:00,  7.65it/s]


Epoch: 28, Valid loss: 0.004492706929142691, Valid F1: 0.9875601530075073


100%|██████████| 2097/2097 [11:48<00:00,  2.96it/s]


Epoch: 29, Train Loss: 0.0111801783522418, Train F1: 0.9714550375938416


100%|██████████| 1887/1887 [03:54<00:00,  8.05it/s]


Epoch: 29, Valid loss: 0.004415218927151024, Valid F1: 0.989145040512085


100%|██████████| 2097/2097 [11:39<00:00,  3.00it/s]


Epoch: 30, Train Loss: 0.01091533914415706, Train F1: 0.9725842475891113


100%|██████████| 1887/1887 [04:16<00:00,  7.37it/s]


Epoch: 30, Valid loss: 0.0048528528429140775, Valid F1: 0.9880671501159668


100%|██████████| 2097/2097 [12:19<00:00,  2.84it/s]


Epoch: 31, Train Loss: 0.010832403493374183, Train F1: 0.9722378253936768


100%|██████████| 1887/1887 [04:05<00:00,  7.70it/s]

Epoch: 31, Valid loss: 0.004414875472307987, Valid F1: 0.9889860153198242





In [20]:
def test(testloader):
    model.eval()
    test_loss = 0
    f1 = MultilabelF1Score(num_labels=18, threshold=0.5, average='macro')
    f1.to(device)
    with torch.no_grad():
        for _, data in tqdm(enumerate(testloader, 0), total=len(testloader)):
            ids = data['input_ids'].to(device, dtype=torch.long)
            mask = data['attention_mask'].to(device, dtype=torch.long)
            targets = data['label'].to(device, dtype=torch.float)

            outputs = model(ids, mask).logits
            loss = loss_fn(outputs, targets)
            test_loss += loss.item()

            f1.update(outputs, targets)

    print(f'Test loss: {test_loss/len(testloader)}, Test F1: {f1.compute()}')

In [22]:
test(testloader)

100%|██████████| 210/210 [00:26<00:00,  7.80it/s]

Test loss: 0.004088942478466336, Test F1: 0.9873391389846802





In [None]:
# Save model
model.save_pretrained('model')

In [23]:
model.push_to_hub('plot-classification')
tokenizer.push_to_hub('plot-classification')

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/dduy193/plot-classification/commit/750deaf8dc19c54407e5e67ce5591cc0fe5bc511', commit_message='Upload tokenizer', commit_description='', oid='750deaf8dc19c54407e5e67ce5591cc0fe5bc511', pr_url=None, pr_revision=None, pr_num=None)

In [None]:
# Inferencing
def inference(input, threshold = 0.5, model=model, tokenizer=tokenizer):
    encoding = tokenizer.encode_plus(
        input,
        add_special_tokens=True,
        max_length=64,
        return_token_type_ids=True,
        padding='max_length',
        pad_to_max_length=True,
        return_attention_mask=True,
        truncation=True,
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    outputs = model(input_ids, attention_mask=attention_mask)
    outputs = torch.sigmoid(outputs.logits).cpu().detach().numpy().tolist()[0]
    outputs = np.array(outputs) >= threshold
    outputs = np.where(outputs == True)[0]
    outputs = [genres[i] for i in outputs]
    print(outputs)

In [None]:
inference('The Untouchables (1987)')

['Crime', 'Drama', 'Action']
