# Neural Solution - Transformers: BERT

In [1]:
import os
import sys
import random
from datetime import datetime

import numpy as np
import polars as pl
from sklearn.model_selection import train_test_split

import torch
from torch import nn, optim, cuda
from torch.utils.data import DataLoader

In [2]:
ROOT_PATH = '../'
DRIVE_PATH = 'Colab/ToxicityClassification'

# When on Colab, use Google Drive as the root path to persist and load data
if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount('/content/drive')
    ROOT_PATH = os.path.join('/content/drive/My Drive/', DRIVE_PATH)
    os.makedirs(ROOT_PATH, exist_ok=True)
    os.chdir(ROOT_PATH)

In [3]:
# Register the parent directory of the current script as a package root,
# so that we can import modules from the parent directory
sys.path.append(os.path.abspath(os.path.join(ROOT_PATH, 'src')))

from toxicity.transformers.xlmroberta_base import xlm_roberta_tokenizer, XLMRobertaDatasetBF16, XLMRobertaModuleBF16
from toxicity.transformers.training import train_epochs, validate, model_metrics

## Setup

In [4]:
# Target device for running the model
PYTORCH_DEVICE = 'cuda' if cuda.is_available() else 'cpu'

# Random Seed
RANDOM_SEED = 777

# Training & Validation configs
TRAIN_RATIO = 0.8
MAX_LEN = 256
TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32
EPOCHS = 12
LEARNING_RATE = 3e-05
POS_WEIGHT = 1.663

print(f'Using device: {PYTORCH_DEVICE}')

Using device: cuda


In [5]:
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Data Loading

In [6]:
df = pl.read_parquet(os.path.join(ROOT_PATH, 'data', 'joint', 'data.parquet.zstd'))
df = df.with_columns(
    df['off_relaxed'].cast(pl.Int32).cast(pl.List(pl.Int32)).cast(pl.Array(pl.Int32, 1)),
    df['off_strict'].cast(pl.Int32).cast(pl.List(pl.Int32)).cast(pl.Array(pl.Int32, 1)),
)
df.sample(5, seed=42)

dataset,id,text,off_strict,off_relaxed
str,str,str,"array[i32, 1]","array[i32, 1]"
"""ToLD-Br""","""5508727285226739644""","""medo de ir pra um rolê de novo…",[0],[0]
"""ToLD-Br""","""16827841903506270139""","""https://t.co/2bs6oD330q Ele a…",[0],[0]
"""ToLD-Br""","""7641628880024884135""","""rt USER bruno fernandes assina…",[0],[0]
"""ToLD-Br""","""16866242508514532033""","""tinha que ter jogado esse bran…",[1],[1]
"""ToLD-Br""","""3068271252403811869""","""eu sou a pessoa certa no bairr…",[0],[0]


In [7]:
train_df, test_df = train_test_split(df, train_size=TRAIN_RATIO, random_state=RANDOM_SEED)
display(train_df.head(5))
display(test_df.head(5))

dataset,id,text,off_strict,off_relaxed
str,str,str,"array[i32, 1]","array[i32, 1]"
"""OLID-Br""","""3d85473d1c4b4f86a78159f23d7746…","""USER merda, ridículo essa impo…",[1],[1]
"""OLID-Br""","""b344c5518f0d44688ed45cc4a3183e…","""USER espero que eles sejam mor…",[1],[1]
"""ToLD-Br""","""4335543317461660187""","""eu tenho essas paran贸ias de ac…",[0],[0]
"""OLID-Br""","""7ada9be164434f0e925f50616b637c…","""USER USER é USER USER""",[0],[0]
"""ToLD-Br""","""16784738693255454158""","""meu pai me deu esse perfume eu…",[1],[0]


dataset,id,text,off_strict,off_relaxed
str,str,str,"array[i32, 1]","array[i32, 1]"
"""ToLD-Br""","""10657414299548058873""","""rt USER mano tá tudo me irrita…",[1],[0]
"""ToLD-Br""","""11088205621966361413""","""USER horrível!""",[1],[0]
"""ToLD-Br""","""11546370057009176494""","""gnt como pode falar q esse hom…",[0],[0]
"""ToLD-Br""","""9450469262872738701""","""Que foda o USER PUTA QUE PARIU…",[0],[0]
"""ToLD-Br""","""16835911729407698751""","""sapatão é foda, não pode beber…",[1],[1]


## Setup Model

In [8]:
tokenizer = xlm_roberta_tokenizer()

model = XLMRobertaModuleBF16(feature_count=1)
model.to(PYTORCH_DEVICE)

train_loader = DataLoader(XLMRobertaDatasetBF16(data_frame=train_df, tokenizer=tokenizer, max_len=MAX_LEN, target_col='off_relaxed'), shuffle=True,
                          num_workers=0, batch_size=TRAIN_BATCH_SIZE)
test_loader = DataLoader(XLMRobertaDatasetBF16(data_frame=test_df, tokenizer=tokenizer, max_len=MAX_LEN, target_col='off_relaxed'), shuffle=True,
                         num_workers=0, batch_size=TEST_BATCH_SIZE)

### Loss and Optimizer

Using a Binary Cross Entropy loss as it shows good results for binary classification tasks. We are also applying differente weights to the positive and negative classes to account for the class imbalance.

Adam optimizer is also used as it is a good general optimizer for training neural networks, with good known results for BERT models.

In [9]:
loss_function = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT], device=PYTORCH_DEVICE))
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

### Training

In [10]:
def validate_result():
    # Validate the results
    raw_results, raw_targets = validate(model, test_loader, PYTORCH_DEVICE)
    raw_results = np.array(raw_results)
    raw_targets = np.array(raw_targets)

    # Apply a fixed threshold to the results
    FIXED_THRESHOLD = 0.75
    fixed_results = raw_results > FIXED_THRESHOLD
    fixed_targets = raw_targets > FIXED_THRESHOLD

    # Compute metrics
    model_metrics(fixed_targets, fixed_results, print_metrics=True)

In [11]:
TIMESTAMP = datetime.now().strftime('%Y%m%d%H%M%S')
CHECKPOINT_PATH = os.path.join(ROOT_PATH, 'checkpoints', 'bertimbau-bf16', TIMESTAMP)

def epoch_callback(epoch, avg_loss):
    print(f'Epoch: {epoch}, Avg Loss: {avg_loss}')
    print('Validation Results:')
    validate_result()

train_epochs(
    EPOCHS, model, train_loader, loss_function, optimizer, PYTORCH_DEVICE,
    checkpoint_path=CHECKPOINT_PATH, epoch_callback=epoch_callback)

Running training epoch 1/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 0, Avg Loss: 0.7465184638769671
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.7349947163626172
Macro F1: 0.7038388193920848
Weighted F2: 0.7416194323243825
Macro F2: 0.6938928497614651
Accuracy: 0.752101591844035
Recall: 0.752101591844035
Precision: 0.755180865709636
Finished training epoch 1/12; Average Loss: 0.7465
Running training epoch 2/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 1, Avg Loss: 0.6419270833333334
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.7719474743293655
Macro F1: 0.749507141196923
Weighted F2: 0.7749286642973829
Macro F2: 0.7421804862157453
Accuracy: 0.7791092827758899
Recall: 0.7791092827758899
Precision: 0.7768438646710073
Finished training epoch 2/12; Average Loss: 0.6419
Running training epoch 3/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 2, Avg Loss: 0.5980334629828327
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.7735616435387823
Macro F1: 0.7488499777636499
Weighted F2: 0.7776069736555253
Macro F2: 0.7383585660223377
Accuracy: 0.7846539080665355
Recall: 0.7846539080665355
Precision: 0.787898035927992
Finished training epoch 3/12; Average Loss: 0.5980
Running training epoch 4/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 3, Avg Loss: 0.5713938438841202
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.7757018111534794
Macro F1: 0.7514678242239272
Weighted F2: 0.7795692942386013
Macro F2: 0.7411690625644154
Accuracy: 0.7862636379896262
Recall: 0.7862636379896262
Precision: 0.7890432914407931
Finished training epoch 4/12; Average Loss: 0.5714
Running training epoch 5/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 4, Avg Loss: 0.5444496602288984
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.7938615875787027
Macro F1: 0.7743861459780432
Weighted F2: 0.7960456379323992
Macro F2: 0.7674972957366737
Accuracy: 0.7991414773743516
Recall: 0.7991414773743516
Precision: 0.797443670145626
Finished training epoch 5/12; Average Loss: 0.5444
Running training epoch 6/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 5, Avg Loss: 0.5276762562589413
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.7866779444105086
Macro F1: 0.7646387639954157
Weighted F2: 0.7897872075997086
Macro F2: 0.7552064507044782
Accuracy: 0.7950277231264532
Recall: 0.7950277231264532
Precision: 0.7962374472304914
Finished training epoch 6/12; Average Loss: 0.5277
Running training epoch 7/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 6, Avg Loss: 0.5153092587625179
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.7993151774483657
Macro F1: 0.7819125422011183
Weighted F2: 0.800729437869708
Macro F2: 0.7774245892995892
Accuracy: 0.802360937220533
Recall: 0.802360937220533
Precision: 0.7998704260693706
Finished training epoch 7/12; Average Loss: 0.5153
Running training epoch 8/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 7, Avg Loss: 0.49632566613018597
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8002640488004263
Macro F1: 0.7831194630985985
Weighted F2: 0.8015879829971616
Macro F2: 0.778904051427146
Accuracy: 0.8030763727419067
Recall: 0.8030763727419067
Precision: 0.8006086848491462
Finished training epoch 8/12; Average Loss: 0.4963
Running training epoch 9/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 8, Avg Loss: 0.4859215732296137
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.80423252696951
Macro F1: 0.7872670811838022
Weighted F2: 0.805594014146408
Macro F2: 0.7827048762248567
Accuracy: 0.8071901269898051
Recall: 0.8071901269898051
Precision: 0.8048748287279052
Finished training epoch 9/12; Average Loss: 0.4859
Running training epoch 10/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 9, Avg Loss: 0.47342045556151646
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.7905214003531067
Macro F1: 0.7689505943934498
Weighted F2: 0.7935002270081191
Macro F2: 0.7594456957919107
Accuracy: 0.7986049007333214
Recall: 0.7986049007333214
Precision: 0.8000444978087619
Finished training epoch 10/12; Average Loss: 0.4734
Running training epoch 11/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 10, Avg Loss: 0.46837697827253216
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8037565983214607
Macro F1: 0.7875461974298231
Weighted F2: 0.8047510220890167
Macro F2: 0.7843482897117156
Accuracy: 0.8057592559470578
Recall: 0.8057592559470578
Precision: 0.8035324436329842
Finished training epoch 11/12; Average Loss: 0.4684
Running training epoch 12/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 11, Avg Loss: 0.4511872429363376
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.796562206978919
Macro F1: 0.7775575180993192
Weighted F2: 0.7986096347579544
Macro F2: 0.7709323691683581
Accuracy: 0.801466642818816
Recall: 0.801466642818816
Precision: 0.7996832473207056
Finished training epoch 12/12; Average Loss: 0.4512


### Save the model and tokenizer

In [12]:
target_dir = os.path.join(ROOT_PATH, 'models/trained-xlm-roberta-bf16-{TIMESTAMP}')

os.makedirs(target_dir, exist_ok=True)

tokenizer.save_vocabulary(target_dir)
torch.save(model, f'{target_dir}/model.pth')

### Run the model with the test data

In [13]:
# Validate the results
raw_results, raw_targets = validate(model, test_loader, PYTORCH_DEVICE)
raw_results = np.array(raw_results)
raw_targets = np.array(raw_targets)

  0%|          | 0/175 [00:00<?, ?it/s]

### Check results

In [14]:
FIXED_THRESHOLD = 0.75
fixed_results = raw_results > FIXED_THRESHOLD
fixed_targets = raw_targets > FIXED_THRESHOLD

In [15]:
from sklearn.metrics import (
    f1_score, fbeta_score, accuracy_score, recall_score, precision_score)

fixed_weighted_f1 = f1_score(fixed_targets, fixed_results, average='weighted')
fixed_macro_f1 = f1_score(fixed_targets, fixed_results, average='macro')
fixed_weighted_f2 = fbeta_score(fixed_targets, fixed_results, beta=2, average='weighted')
fixed_macro_f2 = fbeta_score(fixed_targets, fixed_results, beta=2, average='macro')
fixed_accuracy = accuracy_score(fixed_targets, fixed_results)
fixed_recall = recall_score(fixed_targets, fixed_results, average='weighted')
fixed_precision = precision_score(fixed_targets, fixed_results, average='weighted')

print("Model Metrics:")
print(f"Weighted F1 = {fixed_weighted_f1:.6f}")
print(f"Macro F1 = {fixed_macro_f1:.6f}")
print(f"Weighted F2 Score = {fbeta_score(fixed_targets, fixed_results, beta=2, average='weighted'):.6f}")
print(f"Macro F2 Score = {fbeta_score(fixed_targets, fixed_results, beta=2, average='macro'):.6f}")
print(f"Accuracy = {fixed_accuracy:.6f}")
print(f"Recall = {fixed_recall:.6f}")
print(f"Precision = {fixed_precision:.6f}")

Model Metrics:
Weighted F1 = 0.796562
Macro F1 = 0.777558
Weighted F2 Score = 0.798610
Macro F2 Score = 0.770932
Accuracy = 0.801467
Recall = 0.801467
Precision = 0.799683
