In [31]:
import pandas as pd

train_df = pd.read_csv('train.csv')
train_df.head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0


In [32]:
texts = train_df["comment_text"].tolist()

In [33]:
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

labels = train_df[label_cols].values

## Fine-tuning DistilBERT

In [34]:
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

encodings = tokenizer(texts, truncation=True, padding=True, max_length=128)

In [35]:
# prepare dataset
import torch
from torch.utils.data import Dataset

class ToxicDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

    def __len__(self):
        return len(self.labels)

dataset = ToxicDataset(encodings, labels)

In [36]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)

train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(dataset, val_idx)

In [37]:
# model
from transformers import DistilBertForSequenceClassification

model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=6,
    problem_type="multi_label_classification"
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [38]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,  # you can still pass this, eval will default to after each epoch
)

trainer.train()


Step,Training Loss
100,0.2386
200,0.0656
300,0.0508
400,0.0451
500,0.0434
600,0.0398
700,0.038
800,0.0364
900,0.0347
1000,0.0346


TrainOutput(global_step=2500, training_loss=0.041157275485992434, metrics={'train_runtime': 3521.3004, 'train_samples_per_second': 362.525, 'train_steps_per_second': 0.71, 'total_flos': 4.227866131156992e+16, 'train_loss': 0.041157275485992434, 'epoch': 10.0})

In [39]:
# predictions
preds = trainer.predict(val_dataset)
pred_labels = (torch.sigmoid(torch.tensor(preds.predictions)) > 0.5).int().numpy()
true_labels = preds.label_ids

In [40]:
from sklearn.metrics import classification_report

print(classification_report(true_labels, pred_labels, target_names=label_cols))

               precision    recall  f1-score   support

        toxic       0.82      0.85      0.83      3056
 severe_toxic       0.54      0.35      0.43       321
      obscene       0.83      0.83      0.83      1715
       threat       0.58      0.50      0.54        74
       insult       0.76      0.77      0.76      1614
identity_hate       0.67      0.50      0.57       294

    micro avg       0.79      0.78      0.79      7074
    macro avg       0.70      0.63      0.66      7074
 weighted avg       0.79      0.78      0.78      7074
  samples avg       0.07      0.07      0.07      7074



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
