<a href="https://colab.research.google.com/github/Himagination/NLP_Transformers/blob/main/BERT_Fine_Tuning_Toxic_Comment_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install pytorch-lightning --quiet
!pip install transformers --quiet
!pip install tf-estimator-nightly==2.8.0.dev2021122109
!pip install folium==0.2.1

Collecting tf-estimator-nightly==2.8.0.dev2021122109
  Downloading tf_estimator_nightly-2.8.0.dev2021122109-py2.py3-none-any.whl (462 kB)
[?25l[K     |▊                               | 10 kB 16.6 MB/s eta 0:00:01[K     |█▍                              | 20 kB 11.2 MB/s eta 0:00:01[K     |██▏                             | 30 kB 9.4 MB/s eta 0:00:01[K     |██▉                             | 40 kB 8.5 MB/s eta 0:00:01[K     |███▌                            | 51 kB 4.3 MB/s eta 0:00:01[K     |████▎                           | 61 kB 5.1 MB/s eta 0:00:01[K     |█████                           | 71 kB 5.5 MB/s eta 0:00:01[K     |█████▊                          | 81 kB 5.7 MB/s eta 0:00:01[K     |██████▍                         | 92 kB 6.3 MB/s eta 0:00:01[K     |███████                         | 102 kB 5.1 MB/s eta 0:00:01[K     |███████▉                        | 112 kB 5.1 MB/s eta 0:00:01[K     |████████▌                       | 122 kB 5.1 MB/s eta 0:00:01[K     |█

In [43]:
# Imports
import pandas as pd
import numpy as np

from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torchmetrics.functional import auroc
from torch.utils.data import Dataset, DataLoader

import transformers

import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

In [5]:
%matplotlib inline
RANDOM_SEED = 42
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8
pl.seed_everything(RANDOM_SEED)

Global seed set to 42


42

## Data

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
df = pd.read_csv("/content/drive/MyDrive/train.csv")
df.head(20)

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0
5,00025465d4725e87,"""\n\nCongratulations from me as well, use the ...",0,0,0,0,0,0
6,0002bcb3da6cb337,COCKSUCKER BEFORE YOU PISS AROUND ON MY WORK,1,1,1,0,1,0
7,00031b1e95af7921,Your vandalism to the Matt Shirvington article...,0,0,0,0,0,0
8,00037261f536c51d,Sorry if the word 'nonsense' was offensive to ...,0,0,0,0,0,0
9,00040093b2687caa,alignment on this subject and which are contra...,0,0,0,0,0,0


In [34]:
train_df, val_df = train_test_split(df, test_size=0.05)
train_df.shape, val_df.shape

((151592, 8), (7979, 8))

In [35]:
LABEL_COLUMNS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

In [36]:
train_df[LABEL_COLUMNS].sum()

toxic            14516
severe_toxic      1511
obscene           8003
threat             448
insult            7466
identity_hate     1339
dtype: int64

## Exploratory Data Analysis

In [37]:
train_toxic = train_df[train_df[LABEL_COLUMNS].sum(axis=1) > 0]
train_clean = train_df[train_df[LABEL_COLUMNS].sum(axis=1) == 0]
train_toxic.shape, train_clean.shape

((15389, 8), (136203, 8))

In [39]:
train_df = pd.concat([
                      train_toxic, 
                      train_clean.sample(15_000)
])

In [13]:
sample_row = df.iloc[16]
sample_comment = sample_row.comment_text
sample_label = sample_row[LABEL_COLUMNS]

print(f"Sample comment: {sample_comment}\n")
print(f"Sample label: {sample_label.to_dict()}")

Sample comment: Bye! 

Don't look, come or think of comming back! Tosser.

Sample label: {'toxic': 1, 'severe_toxic': 0, 'obscene': 0, 'threat': 0, 'insult': 0, 'identity_hate': 0}


In [14]:
BERT_MODEL_NAME = "bert-base-cased"
tokenizer = transformers.BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/426k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [16]:
encoding = tokenizer.encode_plus(
    sample_comment, 
    add_special_tokens=True, 
    max_length=512, 
    return_token_type_ids=False, 
    padding="max_length", 
    return_attention_mask=True, 
    return_tensors="pt"
)

In [17]:
encoding.keys()

dict_keys(['input_ids', 'attention_mask'])

In [18]:
encoding["input_ids"].shape, encoding["attention_mask"].shape

(torch.Size([1, 512]), torch.Size([1, 512]))

In [21]:
class ToxicCommentsDataset(Dataset):
  def __init__(self, data: pd.DataFrame, 
               tokenizer: transformers.BertTokenizerFast, 
               max_token_len: int = 128):
    self.data = data
    self.tokenizer = tokenizer
    self.max_token_len = max_token_len

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index: int):
    data_row = self.data.iloc[index]
    comment_text = data_row.comment_text
    labels = data_row[LABEL_COLUMNS]

    encoding = self.tokenizer.encode_plus(
        comment_text, 
        add_special_tokens = True, 
        max_length = self.max_token_len, 
        return_token_type_ids = False, 
        padding = "max_length", 
        truncation = True, 
        return_attention_mask = True, 
        return_tensors = "pt"
    )
    return dict(
        comment_text = comment_text, 
        input_ids = encoding["input_ids"].flatten(), 
        attention_mask = encoding["attention_mask"].flatten(), 
        labels = torch.FloatTensor(labels)
    )

In [22]:
train_dataset = ToxicCommentsDataset(train_df, tokenizer)
sample_item = train_dataset[0]

In [23]:
sample_item.keys()

dict_keys(['comment_text', 'input_ids', 'attention_mask', 'labels'])

In [24]:
sample_item["comment_text"]

'India related links \n\n • Talk • [ Reply]'

In [25]:
sample_item["labels"]

tensor([0., 0., 0., 0., 0., 0.])

In [26]:
sample_item["input_ids"].shape

torch.Size([128])

In [29]:
bert_model = transformers.BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [30]:
sample_item["input_ids"].unsqueeze(dim=0).shape

torch.Size([1, 128])

In [31]:
prediction = bert_model(sample_item["input_ids"].unsqueeze(dim=0), 
                        sample_item["attention_mask"].unsqueeze(dim=0))

In [32]:
class ToxicCommentDataModule(pl.LightningDataModule):
  def __init__(self, train_df, test_df, tokenizer, batch_size=8, max_token_len=128):
    super().__init__()
    self.train_df = train_df
    self.test_df = test_df
    self.tokenizer = tokenizer
    self.batch_size = batch_size
    self.max_token_len = max_token_len

  def setup(self):
    self.train_dataset = ToxicCommentsDataset(
        self.train_df, 
        self.tokenizer, 
        self.max_token_len
    )

    self.test_dataset = ToxicCommentsDataset(
        self.test_df, 
        self.tokenizer, 
        self.max_token_len
    )

  def train_dataloader(self):
    return DataLoader(
        self.train_dataset, 
        batch_size=self.batch_size, 
        shuffle=True, 
        num_workers=4
    )

  def val_dataloader(self):
    return DataLoader(
        self.test_dataset, 
        batch_size=1, 
        num_workers=4
    )

  def test_dataloader(self):
    return DataLoader(self.test_dataset, 
                      batch_size=1, 
                      num_workers=4)

In [40]:
N_EPOCHS = 10
BATCH_SIZE = 32
data_module = ToxicCommentDataModule(train_df, 
                                     val_df, 
                                     tokenizer, 
                                     batch_size=BATCH_SIZE)
data_module.setup()

## Modeling

### Evaluation

In [44]:
from torchmetrics.functional.classification.auroc import auroc
class ToxicCommentClassifier(pl.LightningModule):
  def __init__(self, n_classes:int, 
               steps_per_epoch=None, 
               n_epochs=None):
    super().__init__()
    self.bert = transformers.BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
    self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
    self.steps_per_epoch = steps_per_epoch
    self.n_epochs = n_epochs
    self.criterion = nn.BCELoss()

  def forward(self, input_ids, attention_mask, labels=None):
    output = self.bert(input_ids, attention_mask=attention_mask)
    output = self.classifier(output.pooler_output)
    output = torch.sigmoid(output)
    loss = 0
    if labels is not None:
      loss = self.criterion(output, labels)
    return loss, output

  def training_step(self, batch, batch_idx):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("train_loss", loss, prog_bar=True, logger=True)
    return {"loss": loss, "predictions": outputs, "labels": labels}

  def validation_step(self, batch, batch_idx):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("train_loss", loss, prog_bar=True, logger=True)
    return loss

  def test_step(self, batch, batch_idx):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("train_loss", loss, prog_bar=True, logger=True)
    return loss

  def training_epoch_end(self, outputs):
    labels = []
    predictions = []
    for output in outputs:
      for out_labels in output["labels"].detach().cpu():
        labels.append(out_labels)
      for out_predictions in output["predictions"].detach().cpu():
        predictions = torch.stack(predictions)

    labels = torch.stack(labels)
    predictions = torch.stack(predictions)

    for i, name in enumerate(LABEL_COLUMNS):
      roc_score = auroc(predictions[:,i], labels[:,i])
      self.logger.experiment.add_scatter(f"{name}_roc_auc/Train", roc_score, self.current_epoch)

  def configure_optimizers(self):
    optimizer = transformers.AdamW(self.parameters(), lr=2e-5)
    warmup_steps = self.steps_per_epoch // 3
    total_steps = self.steps_per_epoch*self.n_epochs - warmup_steps
    scheduler = transformers.get_linear_schedule_with_warmup(
        optimizer, 
        warmup_steps, 
        total_steps
    )
    return [optimizer], [scheduler]

In [45]:
model = ToxicCommentClassifier(
    n_classes = 6, 
    steps_per_epoch=len(train_df) // BATCH_SIZE, 
    n_epochs=N_EPOCHS
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [46]:
_, predictions = model(
    sample_item["input_ids"].unsqueeze(dim=0), 
    sample_item["attention_mask"].unsqueeze(dim=0)
)

In [47]:
predictions

tensor([[0.5028, 0.4061, 0.3472, 0.5252, 0.4653, 0.2846]],
       grad_fn=<SigmoidBackward0>)

In [48]:
trainer = pl.Trainer(max_epochs=N_EPOCHS, 
                     gpus=1, 
                     progress_bar_refresh_rate=30)

  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [49]:
trainer.fit(model, data_module)

  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: /content/lightning_logs

  | Name       | Type      | Params
-----------------------------------------
0 | bert       | BertModel | 108 M 
1 | classifier | Linear    | 4.6 K 
2 | criterion  | BCELoss   | 0     
-----------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
433.260   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  cpuset_checked))
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
Global seed set to 42


Training: 0it [00:00, ?it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Validating: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection.

RuntimeError: ignored