In [1]:
from datetime import datetime

import numpy as np
import pandas as pd
import torch
from sklearn import metrics
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_from_disk

In [2]:
# Load the pre-cleaned DatasetDict and turn the string `genre` column into a
# ClassLabel, so labels become integer ids (needed for stratified splitting).
data = load_from_disk('cleaned_data/').class_encode_column('genre')

# Peek at the first training example.
data['train'][0]

Loading cached processed dataset at /home/logan/Desktop/MyProjects/movie-genre-prediction/project-code/research/cleaned_data/train/cache-b0ac180b1c1e0cb9.arrow
Loading cached processed dataset at /home/logan/Desktop/MyProjects/movie-genre-prediction/project-code/research/cleaned_data/test/cache-ae076913fb72cfd6.arrow


{'id': 44978,
 'movie_name': 'Super Me',
 'synopsis': 'A young scriptwriter starts bringing valuable objects back from his short nightmares of being chased by a demon. Selling them makes him rich.',
 'genre': 4,
 'final_text': 'movie name - super me, synopsis - a young scriptwriter starts bringing valuable objects back from his short nightmares of being chased by a demon. selling them makes him rich..'}

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

MODEL_PATH = 'distil-bert-base-uncased/'
SEED = 42

# Seed before building the model so the freshly initialised classification
# head (the `pre_classifier` / `classifier` weights reported as newly
# initialised below) is the same on every run.
torch.manual_seed(SEED)

# ClassLabel feature produced by `class_encode_column` in the loading cell.
genre_feature = data['train'].features['genre']

# DistilBertTokenizerFast is already the fast tokenizer; `use_fast` is not needed.
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_PATH, do_lower_case=True)
model = DistilBertForSequenceClassification.from_pretrained(
    MODEL_PATH,
    # Public ClassLabel API instead of the private `_int2str` list.
    num_labels=genre_feature.num_classes,
    # Keep human-readable genre names in the model config so saved checkpoints
    # report real labels rather than generic LABEL_<i> names.
    id2label=dict(enumerate(genre_feature.names)),
    label2id={name: idx for idx, name in enumerate(genre_feature.names)},
).to(device)

cuda


Some weights of the model checkpoint at distil-bert-base-uncased/ were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distil-bert-base-uncased/ and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'clas

In [4]:
# Longest tokenized `final_text` (special tokens included) over every split.
# ClassificationDataset pads/truncates to this length, so measuring the test
# split too means test synopses are not silently truncated at inference.
# Batch-tokenizing a whole column with the fast tokenizer is much faster than
# calling `tokenizer.encode` once per example in a Python loop.
max_len = 0

for split in data.values():
    # list(...) so newer `datasets` Column objects are accepted by the tokenizer.
    token_ids = tokenizer(list(split['final_text']), add_special_tokens=True)['input_ids']
    max_len = max(max_len, max((len(ids) for ids in token_ids), default=0))

# Never exceed the tokenizer/model positional limit (typically 512 for DistilBERT).
max_len = min(max_len, tokenizer.model_max_length)

print(f'Max sentence len - {max_len}')

Max sentence len - 100


In [5]:
class ClassificationDataset:
    """Map-style dataset that tokenizes one example at a time for the HF Trainer.

    Each item is a dict with `input_ids` and `attention_mask` (padded and
    truncated to `max_length`) plus `labels`, all as CPU `torch.long` tensors.

    Args:
        data: indexable collection of examples (e.g. a `datasets.Dataset`)
            whose rows have a `final_text` field and an int-convertible `genre`.
        tokenizer: callable tokenizer returning `input_ids` and `attention_mask`.
        max_length: padding/truncation length. Defaults to the notebook-level
            `max_len` computed in the previous cell.
    """

    def __init__(self, data, tokenizer, max_length=None):
        self.data = data
        self.tokenizer = tokenizer
        # Fall back to the notebook global only when no explicit length is
        # given, so existing `ClassificationDataset(ds, tok)` calls still work.
        self.max_length = max_len if max_length is None else max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        # Fetch the row once: each `self.data[item]` on a `datasets.Dataset`
        # materialises the whole row, so indexing twice did the work twice.
        row = self.data[item]
        text = str(row['final_text'])
        target = int(row['genre'])
        inputs = self.tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True)

        # Tensors stay on the CPU. Trainer moves each batch to the model's
        # device itself, and creating CUDA tensors per example inside the
        # dataset can break pinned-memory and multi-worker data loading.
        return {
            'input_ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'labels': torch.tensor(target, dtype=torch.long),
        }


def compute_metrics(eval_pred):
    """Accuracy callback for `Trainer`.

    `eval_pred` unpacks into (logits, gold labels). The predicted class is the
    arg-max of the logits along axis 1. Returns `{'accuracy': <float>}`.
    """
    logits, gold = eval_pred
    predicted = np.argmax(logits, axis=1)
    return dict(accuracy=metrics.accuracy_score(gold, predicted))


def _split_train_val(ds_train, seed, val_fraction=0.1):
    """Split a reproducible, genre-stratified validation set off the training split."""
    parts = ds_train.train_test_split(test_size=val_fraction, stratify_by_column='genre', seed=seed)
    return parts['train'], parts['test']


def _write_submission(ids, pred_ids, label_feature):
    """Write an `id,genre` CSV (genre as its string name) to a timestamped file; return its name."""
    # Convert class ids to names before building the frame so the column is
    # created with string values. Writing strings into an existing int column
    # via `.loc[:, 'genre']` is an incompatible-dtype in-place set that
    # recent pandas warns about.
    genre_names = [label_feature.int2str(int(p)) for p in pred_ids]
    submission = pd.DataFrame({'id': list(ids), 'genre': genre_names})
    file_name = f'submission_{datetime.now().strftime("%d-%m-%Y-%H-%M")}.csv'
    submission.to_csv(file_name, index=False)
    return file_name


def train(ds, seed=42):
    """Fine-tune the notebook-level `model`, score it on a validation split,
    and write test-set predictions to a timestamped submission CSV.

    Args:
        ds: DatasetDict with `train` and `test` splits. `genre` must be a
            ClassLabel (see `class_encode_column` above) so the split can be
            stratified and ids mapped back to names.
        seed: seed for the train/validation split and `TrainingArguments`,
            so re-running reproduces the same split and training order.
    """
    ds_train, ds_val = _split_train_val(ds['train'], seed)
    ds_test = ds['test']

    train_dataset = ClassificationDataset(ds_train, tokenizer)
    valid_dataset = ClassificationDataset(ds_val, tokenizer)
    test_dataset = ClassificationDataset(ds_test, tokenizer)

    args = TrainingArguments(
        'result',
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=5,
        weight_decay=0.01,
        seed=seed,
    )

    # NOTE: this trains the global `model` in place. Calling `train` again
    # without re-running the model cell continues from the fine-tuned weights.
    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # `args` sets no evaluation strategy, so without an explicit call the
    # held-out validation split would never be scored.
    val_metrics = trainer.evaluate()
    print(f"Validation accuracy: {val_metrics['eval_accuracy']:.4f}")

    logits = trainer.predict(test_dataset).predictions
    pred_ids = np.argmax(logits, axis=1)
    file_name = _write_submission(ds_test['id'], pred_ids, ds_train.features['genre'])
    print(f'Wrote {file_name}')


In [6]:
# Fine-tune on the loaded splits and write the submission CSV.
train(ds=data)



  0%|          | 0/15190 [00:00<?, ?it/s]

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 1.939, 'learning_rate': 1.9341672152732063e-05, 'epoch': 0.16}
{'loss': 1.7405, 'learning_rate': 1.868334430546412e-05, 'epoch': 0.33}
{'loss': 1.7031, 'learning_rate': 1.8025016458196183e-05, 'epoch': 0.49}
{'loss': 1.6783, 'learning_rate': 1.7366688610928244e-05, 'epoch': 0.66}
{'loss': 1.6769, 'learning_rate': 1.6708360763660305e-05, 'epoch': 0.82}
{'loss': 1.6474, 'learning_rate': 1.6050032916392363e-05, 'epoch': 0.99}
{'loss': 1.5055, 'learning_rate': 1.5391705069124425e-05, 'epoch': 1.15}
{'loss': 1.5075, 'learning_rate': 1.4733377221856486e-05, 'epoch': 1.32}
{'loss': 1.4942, 'learning_rate': 1.4075049374588544e-05, 'epoch': 1.48}
{'loss': 1.5152, 'learning_rate': 1.3416721527320606e-05, 'epoch': 1.65}
{'loss': 1.5004, 'learning_rate': 1.2758393680052667e-05, 'epoch': 1.81}
{'loss': 1.5037, 'learning_rate': 1.2100065832784729e-05, 'epoch': 1.97}
{'loss': 1.3629, 'learning_rate': 1.1441737985516787e-05, 'epoch': 2.14}
{'loss': 1.3243, 'learning_rate': 1.0783410138248848e

  0%|          | 0/2250 [00:00<?, ?it/s]