# Neural Solution - Transformers: BERT

In [1]:
import os
import sys
import random
from datetime import datetime

import numpy as np
import polars as pl
from sklearn.model_selection import train_test_split

import torch
from torch import nn, optim, cuda
from torch.utils.data import DataLoader

In [2]:
DRIVE_PATH = 'Colab/ToxicityClassification'
RUNNING_ON_COLAB = 'google.colab' in sys.modules

if RUNNING_ON_COLAB:
    # Persist data and checkpoints in Google Drive: mount it, create the project folder
    # inside it if needed, and make that folder the working directory.
    from google.colab import drive
    drive.mount('/content/drive')
    ROOT_PATH = os.path.join('/content/drive/My Drive/', DRIVE_PATH)
    os.makedirs(ROOT_PATH, exist_ok=True)
    os.chdir(ROOT_PATH)
else:
    # Outside Colab, project paths are resolved relative to the parent of the working directory.
    ROOT_PATH = '../'

In [3]:
# Make the project's `src` directory (ROOT_PATH/src) importable so the `toxicity` package can be
# imported below. Only append it when missing, so re-running this cell does not keep adding
# duplicate entries to sys.path.
SRC_PATH = os.path.abspath(os.path.join(ROOT_PATH, 'src'))
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

from toxicity.transformers.bertimbau_base import bert_tokenizer, BertDatasetBF16, BertModuleBF16
from toxicity.transformers.training import train_epochs, validate, model_metrics

## Setup

In [4]:
# Hardware: run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU
if cuda.is_available():
    PYTORCH_DEVICE = 'cuda'
else:
    PYTORCH_DEVICE = 'cpu'

# Reproducibility: shared by every RNG seeded below and by the train/test split
RANDOM_SEED = 777

# Data: fraction of rows used for training (the rest is held out) and tokenizer max length
TRAIN_RATIO = 0.8
MAX_LEN = 256

# Optimisation: batch sizes for both loaders, number of passes over the data, step size,
# and the BCE pos_weight applied to the positive class
TRAIN_BATCH_SIZE = TEST_BATCH_SIZE = 32
EPOCHS = 12
LEARNING_RATE = 3e-5
POS_WEIGHT = 1.663

print(f'Using device: {PYTORCH_DEVICE}')

Using device: cuda


In [5]:
# Seed every RNG the pipeline can touch: Python, NumPy, PyTorch (CPU) and all CUDA devices
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(RANDOM_SEED)

# Trade cuDNN speed for repeatability: no autotuned kernel selection, deterministic algorithms only
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Data Loading

In [6]:
# Load the joint dataset and reshape each label column into a fixed-size (length-1) Int32 array
# per row: scalar -> Int32 -> single-element list -> Array(Int32, 1).
df = pl.read_parquet(os.path.join(ROOT_PATH, 'data', 'joint', 'data.parquet.zstd'))
label_dtype = pl.Array(pl.Int32, 1)
df = df.with_columns([
    df[label].cast(pl.Int32).cast(pl.List(pl.Int32)).cast(label_dtype)
    for label in ('off_relaxed', 'off_strict')
])
df.sample(5, seed=42)

dataset,id,text,off_strict,off_relaxed
str,str,str,"array[i32, 1]","array[i32, 1]"
"""ToLD-Br""","""5508727285226739644""","""medo de ir pra um rolê de novo…",[0],[0]
"""ToLD-Br""","""16827841903506270139""","""https://t.co/2bs6oD330q Ele a…",[0],[0]
"""ToLD-Br""","""7641628880024884135""","""rt USER bruno fernandes assina…",[0],[0]
"""ToLD-Br""","""16866242508514532033""","""tinha que ter jogado esse bran…",[1],[1]
"""ToLD-Br""","""3068271252403811869""","""eu sou a pessoa certa no bairr…",[0],[0]


In [7]:
# Hold out (1 - TRAIN_RATIO) of the rows for validation; the seed makes the split repeatable
train_df, test_df = train_test_split(
    df,
    train_size=TRAIN_RATIO,
    random_state=RANDOM_SEED,
)
# Preview the first rows of each split
display(train_df.head(5), test_df.head(5))

dataset,id,text,off_strict,off_relaxed
str,str,str,"array[i32, 1]","array[i32, 1]"
"""OLID-Br""","""3d85473d1c4b4f86a78159f23d7746…","""USER merda, ridículo essa impo…",[1],[1]
"""OLID-Br""","""b344c5518f0d44688ed45cc4a3183e…","""USER espero que eles sejam mor…",[1],[1]
"""ToLD-Br""","""4335543317461660187""","""eu tenho essas paran贸ias de ac…",[0],[0]
"""OLID-Br""","""7ada9be164434f0e925f50616b637c…","""USER USER é USER USER""",[0],[0]
"""ToLD-Br""","""16784738693255454158""","""meu pai me deu esse perfume eu…",[1],[0]


dataset,id,text,off_strict,off_relaxed
str,str,str,"array[i32, 1]","array[i32, 1]"
"""ToLD-Br""","""10657414299548058873""","""rt USER mano tá tudo me irrita…",[1],[0]
"""ToLD-Br""","""11088205621966361413""","""USER horrível!""",[1],[0]
"""ToLD-Br""","""11546370057009176494""","""gnt como pode falar q esse hom…",[0],[0]
"""ToLD-Br""","""9450469262872738701""","""Que foda o USER PUTA QUE PARIU…",[0],[0]
"""ToLD-Br""","""16835911729407698751""","""sapatão é foda, não pode beber…",[1],[1]


## Setup Model

In [8]:
tokenizer = bert_tokenizer()

model = BertModuleBF16(feature_count=1)
model.to(PYTORCH_DEVICE)

# Label column used as the target for both the training and the evaluation split.
TARGET_COL = 'off_relaxed'


def make_loader(frame, batch_size, shuffle):
    """Wrap one data split in a BertDatasetBF16 and return a DataLoader over it.

    Args:
        frame: the split to load (``train_df`` or ``test_df``).
        batch_size: number of samples per batch.
        shuffle: whether the loader reshuffles the samples on every pass.
    """
    dataset = BertDatasetBF16(data_frame=frame, tokenizer=tokenizer, max_len=MAX_LEN, target_col=TARGET_COL)
    return DataLoader(dataset, shuffle=shuffle, num_workers=0, batch_size=batch_size)


# Only the training split needs shuffling. Evaluation order does not change the aggregated
# metrics, and a fixed order keeps per-sample validation outputs comparable across epochs.
train_loader = make_loader(train_df, TRAIN_BATCH_SIZE, shuffle=True)
test_loader = make_loader(test_df, TEST_BATCH_SIZE, shuffle=False)

### Loss and Optimizer

We use a binary cross-entropy loss (`BCEWithLogitsLoss`, which applies the sigmoid internally), the standard choice for binary classification. Its `pos_weight` argument up-weights the loss of positive examples relative to negative ones, to account for the class imbalance.

We use the AdamW optimizer (Adam with decoupled weight decay). It is a good general-purpose optimizer for neural networks and the usual choice for fine-tuning BERT models.

In [9]:
# Scale the loss of positive examples by POS_WEIGHT to compensate for the class imbalance
positive_class_weight = torch.tensor([POS_WEIGHT], device=PYTORCH_DEVICE)
loss_function = nn.BCEWithLogitsLoss(pos_weight=positive_class_weight)

optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

### Training

### Result Validation

In [10]:
def validate_result(threshold=0.75):
    """Score the model on ``test_loader`` and print the classification metrics.

    Args:
        threshold: cut-off applied to the scores returned by ``validate``; scores strictly
            above it count as positive predictions. Defaults to 0.75, the value this
            notebook has always used.

    Returns:
        Whatever ``model_metrics`` returns for the binarised targets and predictions.
    """
    raw_results, raw_targets = validate(model, test_loader, PYTORCH_DEVICE)
    raw_results = np.asarray(raw_results)
    raw_targets = np.asarray(raw_targets)

    # NOTE(review): training uses BCEWithLogitsLoss; confirm whether `validate` returns sigmoid
    # probabilities or raw logits, since the same threshold means different things for each.
    fixed_results = raw_results > threshold
    # Targets are 0/1 labels, not scores: binarise them at 0.5 so they never depend on the
    # prediction threshold (a threshold >= 1 would otherwise mark every target negative).
    fixed_targets = raw_targets > 0.5

    return model_metrics(fixed_targets, fixed_results, print_metrics=True)

### Train Model

In [11]:
TIMESTAMP = datetime.now().strftime('%Y%m%d%H%M%S')
CHECKPOINT_PATH = os.path.join(ROOT_PATH, 'checkpoints', 'bertimbau-bf16', TIMESTAMP)


def epoch_callback(epoch, avg_loss):
    """Report an epoch's mean training loss, then evaluate the model on the test split.

    ``train_epochs`` passes a 0-based ``epoch``: its own log prints "epoch 1/12" for epoch 0.
    The value is shifted by one here so that both logs refer to the same epoch number.
    """
    print(f'Epoch: {epoch + 1}/{EPOCHS}, Avg Loss: {avg_loss:.4f}')
    print('Validation Results:')
    validate_result()


train_epochs(
    EPOCHS, model, train_loader, loss_function, optimizer, PYTORCH_DEVICE,
    checkpoint_path=CHECKPOINT_PATH, epoch_callback=epoch_callback)

Running training epoch 1/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 0, Avg Loss: 0.6400158708869814
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.7931951136129235
Macro F1: 0.7737028822938455
Weighted F2: 0.7953706579424175
Macro F2: 0.7669071390313769
Accuracy: 0.798426041852978
Recall: 0.798426041852978
Precision: 0.7966286954917692
Finished training epoch 1/12; Average Loss: 0.6400
Running training epoch 2/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 1, Avg Loss: 0.5463049892703863
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.7966963521691435
Macro F1: 0.7776702021289508
Weighted F2: 0.7987560549200781
Macro F2: 0.770986107973366
Accuracy: 0.8016455016991594
Recall: 0.8016455016991594
Precision: 0.7999094885395409
Finished training epoch 2/12; Average Loss: 0.5463
Running training epoch 3/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 2, Avg Loss: 0.5025804162195995
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8108954223404946
Macro F1: 0.7958310449076755
Weighted F2: 0.8115745138556835
Macro F2: 0.7935452436393144
Accuracy: 0.8121981756394205
Recall: 0.8121981756394205
Precision: 0.8104645601434637
Finished training epoch 3/12; Average Loss: 0.5026
Running training epoch 4/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 3, Avg Loss: 0.4760888881437768
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8068909030857607
Macro F1: 0.7896871372331477
Weighted F2: 0.8084309932173975
Macro F2: 0.7842463379968945
Accuracy: 0.8104095868359864
Recall: 0.8104095868359864
Precision: 0.8083991000177365
Finished training epoch 4/12; Average Loss: 0.4761
Running training epoch 5/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 4, Avg Loss: 0.448842934102289
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8059282161997973
Macro F1: 0.7880685021591497
Weighted F2: 0.8077181964180307
Macro F2: 0.7816444182819828
Accuracy: 0.810230727955643
Recall: 0.810230727955643
Precision: 0.808681148116658
Finished training epoch 5/12; Average Loss: 0.4488
Running training epoch 6/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 5, Avg Loss: 0.4263666286659514
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8106811325378588
Macro F1: 0.7946385812032595
Weighted F2: 0.8118179262525382
Macro F2: 0.7906191785357679
Accuracy: 0.8130924700411375
Recall: 0.8130924700411375
Precision: 0.8109523597667077
Finished training epoch 6/12; Average Loss: 0.4264
Running training epoch 7/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 6, Avg Loss: 0.4068521436874106
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8108387738034076
Macro F1: 0.797330328511838
Weighted F2: 0.8106838746747175
Macro F2: 0.7978098013070485
Accuracy: 0.8105884457163298
Recall: 0.8105884457163298
Precision: 0.8111281427992858
Finished training epoch 7/12; Average Loss: 0.4069
Running training epoch 8/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 7, Avg Loss: 0.3848298909155937
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8065647585166505
Macro F1: 0.7913129203962257
Weighted F2: 0.8071828542855349
Macro F2: 0.7893362665354764
Accuracy: 0.8077267036308353
Recall: 0.8077267036308353
Precision: 0.806073216742467
Finished training epoch 8/12; Average Loss: 0.3848
Running training epoch 9/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 8, Avg Loss: 0.36872219800608014
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8078787796671424
Macro F1: 0.7921347529387823
Weighted F2: 0.808786424401079
Macro F2: 0.7891055001002043
Accuracy: 0.8096941513146128
Recall: 0.8096941513146128
Precision: 0.807618823922394
Finished training epoch 9/12; Average Loss: 0.3687
Running training epoch 10/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 9, Avg Loss: 0.3516428323944921
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8035023797799367
Macro F1: 0.7866141609356361
Weighted F2: 0.8048070347102951
Macro F2: 0.7823051829681349
Accuracy: 0.806295832588088
Recall: 0.806295832588088
Precision: 0.8039367574808922
Finished training epoch 10/12; Average Loss: 0.3516
Running training epoch 11/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 10, Avg Loss: 0.34036556129291845
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8025450622700593
Macro F1: 0.7861266630309831
Weighted F2: 0.8035992638647049
Macro F2: 0.7827634200222545
Accuracy: 0.8046861026649973
Recall: 0.8046861026649973
Precision: 0.802385888154133
Finished training epoch 11/12; Average Loss: 0.3404
Running training epoch 12/12


  0%|          | 0/699 [00:00<?, ?it/s]

Epoch: 11, Avg Loss: 0.32450375536480686
Validation Results:


  0%|          | 0/175 [00:00<?, ?it/s]

Weighted F1: 0.8051014020872056
Macro F1: 0.7900509025506027
Weighted F2: 0.8055614263243022
Macro F2: 0.7886220698732176
Accuracy: 0.8059381148274012
Recall: 0.8059381148274012
Precision: 0.8046190281303939
Finished training epoch 12/12; Average Loss: 0.3245


### Save the model and tokenizer

In [12]:
# Save the artifacts under a run-specific folder that shares TIMESTAMP with the training
# checkpoints. The f-string is required: without it every run writes to a folder literally
# named "trained-bertimbau-bf16-{TIMESTAMP}" and overwrites the previous model.
target_dir = os.path.join(ROOT_PATH, 'models', f'trained-bertimbau-bf16-{TIMESTAMP}')

os.makedirs(target_dir, exist_ok=True)

# NOTE(review): save_vocabulary writes only the vocabulary; if this is a Hugging Face tokenizer,
# save_pretrained would also keep its config/special tokens. TODO confirm what the loader expects.
tokenizer.save_vocabulary(target_dir)
# Pickles the whole nn.Module, so loading it later needs the `toxicity` package importable;
# saving model.state_dict() would be more portable.
torch.save(model, os.path.join(target_dir, 'model.pth'))