# Toxic Comment Classification

In [1]:
!git clone https://github.com/GitYCC/bert-minimal-tutorial.git

Cloning into 'bert-minimal-tutorial'...
remote: Enumerating objects: 128, done.[K
remote: Counting objects: 100% (128/128), done.[K
remote: Compressing objects: 100% (107/107), done.[K
remote: Total 128 (delta 68), reused 62 (delta 19), pack-reused 0[K
Receiving objects: 100% (128/128), 38.88 MiB | 23.00 MiB/s, done.
Resolving deltas: 100% (68/68), done.


In [2]:
%cd bert-minimal-tutorial

/content/bert-minimal-tutorial


In [3]:
!pip install -q -r requirements.txt

[K     |████████████████████████████████| 235kB 11.2MB/s 
[K     |████████████████████████████████| 829kB 34.5MB/s 
[K     |████████████████████████████████| 1.3MB 53.3MB/s 
[K     |████████████████████████████████| 225kB 57.3MB/s 
[K     |████████████████████████████████| 512kB 52.9MB/s 
[K     |████████████████████████████████| 727kB 51.6MB/s 
[K     |████████████████████████████████| 71kB 11.2MB/s 
[K     |████████████████████████████████| 890kB 51.0MB/s 
[K     |████████████████████████████████| 6.8MB 52.7MB/s 
[K     |████████████████████████████████| 25.9MB 95kB/s 
[K     |████████████████████████████████| 1.1MB 47.6MB/s 
[K     |████████████████████████████████| 51kB 9.0MB/s 
[K     |████████████████████████████████| 2.9MB 46.1MB/s 
[K     |████████████████████████████████| 1.3MB 47.4MB/s 
[K     |████████████████████████████████| 133kB 58.3MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone
  Building wheel for sacremoses (setup.py) ... [?25l

In [4]:
import os

import pandas as pd
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, BertModel, BertPreTrainedModel
from tqdm.notebook import tqdm
from sklearn import metrics

from utils import RunningAverage

# Name of the pretrained checkpoint; the tokenizer and the model are both
# loaded from it further down.
MODEL_NAME = 'bert-base-uncased'
# Seed handed to the torch random number generators below.
SEED = 1234

# Seed the CUDA generators and the default torch generator.
torch.cuda.manual_seed_all(SEED)
torch.manual_seed(SEED)

## Dataloader

In [5]:
# Load the labelled comments and shuffle the rows in a single statement, so
# re-running the cell always starts from the file on disk.
# random_state=SEED makes the shuffle repeatable: without it pandas draws a
# fresh permutation on every run and the torch seeds have no effect on it.
df = (
    pd.read_csv('data/toxic_comment_classification.csv')
    .sample(frac=1, random_state=SEED)  # shuffle
    .reset_index(drop=True)
)

In [6]:
# Preview the shuffled frame; pandas elides the middle rows of a long frame.
df

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,157282532efe7f28,"the source, so that the ambiguity can be furth...",0,0,0,0,0,0
1,f663a661afaa5f0e,"""\n\n Vandal box question \n\nWhat can I say, ...",0,0,0,0,0,0
2,cba00537b7bf014e,This article seems like advertising to me. It ...,0,0,0,0,0,0
3,2e21c70cc5d4be4f,bullshit the only reason striver is even pushi...,1,0,1,0,0,0
4,4422cc7ce657f7d9,"""\n\nYelling """"gook"""" over and over again at a...",1,0,0,0,0,0
...,...,...,...,...,...,...,...,...
159566,19ef219f8820d8e5,"""\n\n Hey \n\nI saw what you have done to info...",0,0,0,0,0,0
159567,28f074584e497232,"""\n\n Previously unreleased track on """"The Bes...",0,0,0,0,0,0
159568,eb1f8a915e951f62,I don't know you from a whole inthe ground. I...,0,0,0,0,0,0
159569,1fb532e2001a10db,Add Media Matters back\n add Opposition to the...,0,0,0,0,0,0


In [7]:
# Target columns of the dataset. Their order is the order of the label vector
# built by the dataset and of the per-label scores computed during evaluation.
LABELS = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]

In [8]:
class MultiLabelDataset(Dataset):
    """Dataset turning comment texts into BERT inputs with optional multi-label targets.

    Args:
        tokenizer: object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        df: DataFrame with a ``comment_text`` column and, when ``for_train``
            is true, one column per entry of ``LABELS``.
        max_len: maximum sequence length, counting the ``[CLS]`` and
            ``[SEP]`` markers.
        for_train: when true, items and batches also carry the label tensor.
    """

    def __init__(self, tokenizer, df, max_len=512, for_train=True):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.for_train = for_train

        # Pull whole columns at once rather than looping over df.iterrows(),
        # which builds one Series per row and is very slow on large frames.
        self.texts = df['comment_text'].tolist()
        self.labels = df[LABELS].values.tolist() if for_train else []

    def __getitem__(self, idx):
        """Return ``(input_ids, token_type_ids, attention_mask[, label])`` for one text."""
        tokens = self.tokenizer.tokenize(self.texts[idx])
        # Leave room for the two special markers added below.
        tokens = tokens[:self.max_len - 2]
        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']

        input_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(processed_tokens))
        # A single segment: every position gets type 0 and mask 1.
        token_type_ids = torch.tensor([0] * len(processed_tokens))
        attention_mask = torch.tensor([1] * len(processed_tokens))

        outputs = (input_ids, token_type_ids, attention_mask)

        if self.for_train:
            outputs += (torch.tensor(self.labels[idx]), )

        return outputs

    def __len__(self):
        """Number of texts held by the dataset."""
        return len(self.texts)

    def create_mini_batch(self, samples):
        """Collate ``__getitem__`` results into batch tensors.

        The three input sequences are zero-padded to the longest sequence in
        the batch; labels (only when ``for_train`` is true) are stacked.
        """
        outputs = list(zip(*samples))

        # Zero-pad to a common sequence length; padded positions therefore
        # also get attention mask 0.
        input_ids = pad_sequence(outputs[0], batch_first=True)
        token_type_ids = pad_sequence(outputs[1], batch_first=True)
        attention_mask = pad_sequence(outputs[2], batch_first=True)

        batch_output = (input_ids, token_type_ids, attention_mask)

        if self.for_train:
            labels = torch.stack(outputs[3])
            batch_output += (labels, )

        return batch_output

In [9]:
# Tokenizer matching the pretrained checkpoint.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

# Wrap the whole frame, then carve it randomly into train / validation parts.
dataset = MultiLabelDataset(tokenizer, df)

CUT_RATIO = 0.9  # fraction of the samples used for training
train_size = int(len(dataset) * CUT_RATIO)
valid_size = len(dataset) - train_size  # remainder goes to validation
train_dataset, valid_dataset = random_split(
    dataset,
    [train_size, valid_size],
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [10]:
batch_size = 2

# Both loaders pad their batches with the collate function of the full dataset.
# Only the training loader reshuffles its samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=dataset.create_mini_batch,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    collate_fn=dataset.create_mini_batch,
)

## Model

In [15]:
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT encoder followed by dropout and a linear head with one logit per label.

    The loss is ``BCEWithLogitsLoss``, so each label is scored independently
    of the others.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        """Compute per-label logits and, when labels are given, the loss.

        Args:
            input_ids, token_type_ids, attention_mask: batch tensors passed
                straight to the BERT encoder.
            labels: optional 0/1 targets with the same shape as the logits;
                they are cast to float for the loss.

        Returns:
            ``logits`` when ``labels`` is None, otherwise ``(logits, loss)``.
        """
        encoder_outputs = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )
        # Take the pooled output by position instead of tuple-unpacking.
        # Indexing works for a plain tuple of any length and for the dict-like
        # ModelOutput returned by newer transformers releases, where unpacking
        # would yield field names rather than tensors.
        pooled_output = encoder_outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is None:
            return logits

        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(logits, labels.float())
        return logits, loss

In [16]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'device: {device}')

# Load the pretrained checkpoint with a classification head sized to LABELS.
model = BertForMultiLabelSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(LABELS))
model.to(device)

device: cuda


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMultiLabelSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultiLabelSequenceClassification were not 

BertForMultiLabelSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-1

## Train

In [17]:
def train_batch(model, data, optimizer, device):
    """Run one optimisation step on a single batch and return its loss as a float.

    ``data`` must hold exactly four tensors: input ids, token type ids,
    attention mask and labels. ``model`` is called with those as keyword
    arguments and must return a ``(logits, loss)`` pair.
    """
    model.train()

    ids, segments, mask, targets = (tensor.to(device) for tensor in data)
    _logits, batch_loss = model(
        input_ids=ids,
        token_type_ids=segments,
        attention_mask=mask,
        labels=targets,
    )

    # Clear stale gradients, backpropagate this batch, then update the weights.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

def evaluate(model, valid_loader, device):
    """Score ``model`` on every batch of ``valid_loader`` without gradients.

    Returns:
        A ``(loss, f1)`` pair. ``loss`` is whatever ``RunningAverage.get()``
        reports after each batch loss has been added. ``f1`` maps each name
        in ``LABELS`` to its F1 score, with predictions thresholded at a
        sigmoid probability of 0.5.
    """
    model.eval()

    batch_losses = RunningAverage()
    predicted = {name: [] for name in LABELS}
    expected = {name: [] for name in LABELS}

    with torch.no_grad():
        for batch in tqdm(valid_loader, desc='evaluate'):
            ids, segments, mask, targets = (tensor.to(device) for tensor in batch)

            logits, batch_loss = model(
                input_ids=ids,
                token_type_ids=segments,
                attention_mask=mask,
                labels=targets,
            )
            batch_losses.add(batch_loss.item())

            # Hard 0/1 decisions per label, moved to numpy for column slicing.
            decisions = (torch.sigmoid(logits) > 0.5).int().cpu().numpy()
            truth = targets.cpu().numpy()
            for column, name in enumerate(LABELS):
                predicted[name].extend(decisions[:, column].tolist())
                expected[name].extend(truth[:, column].tolist())

    f1 = {}
    for name in LABELS:
        f1[name] = metrics.f1_score(expected[name], predicted[name])

    return batch_losses.get(), f1

In [18]:
lr = 0.00001
max_iter = 10000          # total number of optimisation steps
show_per_iter = 500       # print the averaged training loss every N steps
valid_per_iter = 5000     # run validation every N steps
save_per_iter = 5000      # save a checkpoint every N steps
save_checkpoint_dir = 'models/'
model_prefix = 'en_toxic_label_'

# A checkpoint is named after the latest validation loss, so every save step
# must also be a validation step.
assert save_per_iter % valid_per_iter == 0
# With no batches the outer loop below would never reach max_iter.
assert len(train_loader) > 0, 'train_loader yields no batches'

optimizer = optim.Adam(model.parameters(), lr=lr)

i = 1
is_running = True
train_loss_averager = RunningAverage()
model_paths = []
while is_running:  # keep cycling over the loader until max_iter steps are done
    for train_data in train_loader:
        # The per-batch training loss gets its own name so it can never
        # overwrite the validation loss used in the checkpoint path below.
        train_loss = train_batch(model, train_data, optimizer, device)
        train_loss_averager.add(train_loss)

        if i % show_per_iter == 0:
            print('train [{}]: loss={}'.format(i, train_loss_averager.get()))
            train_loss_averager.flush()

        if i % valid_per_iter == 0:
            loss, f1 = evaluate(model, valid_loader, device)
            print(f'valid: loss={loss}, f1={f1}')

        if i % save_per_iter == 0:
            path = os.path.join(save_checkpoint_dir, model_prefix + f'loss{loss:.5}/')
            print(f'save model at {path}')
            # Make sure the target directory exists before saving; this is
            # harmless if save_pretrained creates it on its own.
            os.makedirs(path, exist_ok=True)
            model.save_pretrained(path)
            model_paths.append(path)

        if i == max_iter:
            is_running = False
            break

        i += 1

train [500]: loss=0.17788839216157795
train [1000]: loss=0.07699349229969084
train [1500]: loss=0.05766050508990884
train [2000]: loss=0.06068963208142668
train [2500]: loss=0.053303373756352815
train [3000]: loss=0.055981332342606036
train [3500]: loss=0.049565680519212035
train [4000]: loss=0.0585138334967196
train [4500]: loss=0.04934469473268837
train [5000]: loss=0.05635646560182795


HBox(children=(FloatProgress(value=0.0, description='evaluate', max=7979.0, style=ProgressStyle(description_wi…


valid: loss=0.0473065771338643, f1={'toxic': 0.7990039131981501, 'severe_toxic': 0.045714285714285714, 'obscene': 0.8170802077322563, 'threat': 0.0, 'insult': 0.7529722589167768, 'identity_hate': 0.0}
save model at models/en_toxic_label_loss0.047307/
train [5500]: loss=0.04969648266583681
train [6000]: loss=0.051277401978615675
train [6500]: loss=0.05795165065466426
train [7000]: loss=0.04712677399790846
train [7500]: loss=0.05325233582430519
train [8000]: loss=0.04140716752503067
train [8500]: loss=0.05056395696895197
train [9000]: loss=0.047686740265460686
train [9500]: loss=0.0501032347320579
train [10000]: loss=0.04579673587461002


HBox(children=(FloatProgress(value=0.0, description='evaluate', max=7979.0, style=ProgressStyle(description_wi…


valid: loss=0.046249122887243206, f1={'toxic': 0.8221632382216324, 'severe_toxic': 0.5161290322580645, 'obscene': 0.8293228875209849, 'threat': 0.05128205128205128, 'insult': 0.765755053507729, 'identity_hate': 0.28013029315960913}
save model at models/en_toxic_label_loss0.046249/


## Predict

In [19]:
# Reload the most recently saved checkpoint.
reload_checkpoint = model_paths[-1]

examples = ['Fuck you! You son of bitch', 'I will kill you soon']
examples_df = pd.DataFrame({'comment_text': examples})

# Inference dataset: items carry no label tensor.
pred_dataset = MultiLabelDataset(tokenizer, examples_df, for_train=False)
pred_loader = DataLoader(
    pred_dataset,
    batch_size=batch_size,
    collate_fn=pred_dataset.create_mini_batch,
)

model = BertForMultiLabelSequenceClassification.from_pretrained(reload_checkpoint)
model.to(device)

pred_labels = []
with torch.no_grad():
    for data in tqdm(pred_loader, desc='predict'):
        input_ids, token_type_ids, attention_mask = (d.to(device) for d in data)

        logits = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )

        # One 0/1 decision per label, thresholded at probability 0.5.
        pred_labels.extend((torch.sigmoid(logits) > 0.5).int().cpu().tolist())

print('predict result: ', list(zip(examples, pred_labels)))

HBox(children=(FloatProgress(value=0.0, description='predict', max=1.0, style=ProgressStyle(description_width=…


predict result:  [('Fuck you! You son of bitch', [1, 1, 1, 1, 1, 1]), ('I will kill you soon', [1, 0, 1, 0, 1, 0])]
