# Neural Solution - Transformers: BERT

In [1]:
import os
import sys
import random
from datetime import datetime

import numpy as np
import polars as pl
import plotly.express as px
from sklearn.model_selection import train_test_split

import torch
from torch import nn, optim, cuda
from torch.utils.data import DataLoader

In [2]:
ROOT_PATH = '../'
DRIVE_PATH = 'Colab/ToxicityClassification'

# When on Colab, use Google Drive as the root path to persist and load data
if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount('/content/drive')
    ROOT_PATH = os.path.join('/content/drive/My Drive/', DRIVE_PATH)
    os.makedirs(ROOT_PATH, exist_ok=True)
    os.chdir(ROOT_PATH)

In [3]:
# Register the parent directory of the current script as a package root,
# so that we can import modules from the parent directory
sys.path.append(os.path.abspath(os.path.join(ROOT_PATH, 'src')))

from toxicity.transformers.bertimbau_base import bert_tokenizer, BertDatasetBF16, BertModuleBF16
from toxicity.transformers.training import trainer, validate
from toxicity.training import train_epochs, model_metrics

## Setup

In [4]:
# Target device for running the model
PYTORCH_DEVICE = 'cuda' if cuda.is_available() else 'cpu'

# Random Seed
RANDOM_SEED = 777

# Training & Validation configs
TRAIN_RATIO = 0.8
MAX_LEN = 256
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 64
EPOCHS = 6
LEARNING_RATE = 5e-05
POS_WEIGHT = 1.663

print(f'Using device: {PYTORCH_DEVICE}')

Using device: cuda


In [5]:
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Data Loading

In [6]:
df = pl.read_parquet(os.path.join(ROOT_PATH, 'data', 'joint', 'pre_processed_data.parquet.zstd'))
df = df.with_columns(
    df['off_relaxed'].cast(pl.Int32).cast(pl.List(pl.Int32)).cast(pl.Array(pl.Int32, 1)),
    df['off_strict'].cast(pl.Int32).cast(pl.List(pl.Int32)).cast(pl.Array(pl.Int32, 1)),
)
df.sample(5, seed=42)

dataset,id,text,off_strict,off_relaxed,base_clean,base_clean_lower,tokenized,lemmatized,no_accents,lemma_no_accents,no_stop_words,lemma_no_stop_words,no_stop_words_no_accents,lemma_no_stop_words_no_accents
str,str,str,"array[i32, 1]","array[i32, 1]",str,str,list[str],list[str],list[str],list[str],list[str],list[str],list[str],list[str]
"""ToLD-Br""","""5508727285226739644""","""medo de ir pra um rolê de novo…",[0],[0],"""medo de ir pra um rolê de novo…","""medo de ir pra um rolê de novo…","[""medo"", ""de"", … ""kkkkkkk""]","[""medo"", ""de"", … ""kkkkkkk""]","[""medo"", ""de"", … ""kkkkkkk""]","[""medo"", ""de"", … ""kkkkkkk""]","[""medo"", ""pra"", … ""kkkkkkk""]","[""medo"", ""pra"", … ""kkkkkkk""]","[""medo"", ""pra"", … ""kkkkkkk""]","[""medo"", ""pra"", … ""kkkkkkk""]"
"""ToLD-Br""","""16827841903506270139""","""https://t.co/2bs6oD330q Ele a…",[0],[0],"""https t co 2bs6oD330q Ele até …","""https t co 2bs6od330q ele até …","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]"
"""ToLD-Br""","""7641628880024884135""","""rt USER bruno fernandes assina…",[0],[0],"""rt USER bruno fernandes assina…","""rt user bruno fernandes assina…","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]"
"""ToLD-Br""","""16866242508514532033""","""tinha que ter jogado esse bran…",[1],[1],"""tinha que ter jogado esse bran…","""tinha que ter jogado esse bran…","[""tinha"", ""que"", … ""trem""]","[""ter"", ""que"", … ""tr""]","[""tinha"", ""que"", … ""trem""]","[""ter"", ""que"", … ""tr""]","[""tinha"", ""jogado"", … ""trem""]","[""ter"", ""jogar"", … ""tr""]","[""tinha"", ""jogado"", … ""trem""]","[""ter"", ""jogar"", … ""tr""]"
"""ToLD-Br""","""3068271252403811869""","""eu sou a pessoa certa no bairr…",[0],[0],"""eu sou a pessoa certa no bairr…","""eu sou a pessoa certa no bairr…","[""eu"", ""sou"", … ""errado""]","[""eu"", ""ser"", … ""errar""]","[""eu"", ""sou"", … ""errado""]","[""eu"", ""ser"", … ""errar""]","[""pessoa"", ""certa"", … ""errado""]","[""pessoa"", ""certo"", … ""errar""]","[""pessoa"", ""certa"", … ""errado""]","[""pessoa"", ""certo"", … ""errar""]"


In [7]:
train_df, test_df = train_test_split(df, train_size=TRAIN_RATIO, random_state=RANDOM_SEED)
display(train_df.head(5))
display(test_df.head(5))

dataset,id,text,off_strict,off_relaxed,base_clean,base_clean_lower,tokenized,lemmatized,no_accents,lemma_no_accents,no_stop_words,lemma_no_stop_words,no_stop_words_no_accents,lemma_no_stop_words_no_accents
str,str,str,"array[i32, 1]","array[i32, 1]",str,str,list[str],list[str],list[str],list[str],list[str],list[str],list[str],list[str]
"""OLID-Br""","""3d85473d1c4b4f86a78159f23d7746…","""USER merda, ridículo essa impo…",[1],[1],"""USER merda ridículo essa impos…","""user merda ridículo essa impos…","[""user"", ""merda"", … ""crianças""]","[""user"", ""merda"", … ""criança""]","[""user"", ""merda"", … ""criancas""]","[""user"", ""merda"", … ""crianca""]","[""user"", ""merda"", … ""crianças""]","[""user"", ""merda"", … ""criança""]","[""user"", ""merda"", … ""criancas""]","[""user"", ""merda"", … ""crianca""]"
"""OLID-Br""","""b344c5518f0d44688ed45cc4a3183e…","""USER espero que eles sejam mor…",[1],[1],"""USER espero que eles sejam mor…","""user espero que eles sejam mor…","[""user"", ""espero"", … ""novamente""]","[""user"", ""esperar"", … ""novamente""]","[""user"", ""espero"", … ""novamente""]","[""user"", ""esperar"", … ""novamente""]","[""user"", ""espero"", … ""novamente""]","[""user"", ""esperar"", … ""novamente""]","[""user"", ""espero"", … ""novamente""]","[""user"", ""esperar"", … ""novamente""]"
"""ToLD-Br""","""4335543317461660187""","""eu tenho essas paran贸ias de ac…",[0],[0],"""eu tenho essas paran贸ias de ac…","""eu tenho essas paran贸ias de ac…","[""eu"", ""tenho"", … ""vivo""]","[""eu"", ""ter"", … ""vivo""]","[""eu"", ""tenho"", … ""vivo""]","[""eu"", ""ter"", … ""vivo""]","[""paran贸ias"", ""achar"", … ""vivo""]","[""paran贸ia"", ""achar"", … ""vivo""]","[""paranMao ias"", ""achar"", … ""vivo""]","[""paranMao ia"", ""achar"", … ""vivo""]"
"""OLID-Br""","""7ada9be164434f0e925f50616b637c…","""USER USER é USER USER""",[0],[0],"""USER USER é USER USER""","""user user é user user""","[""user"", ""user"", … ""user""]","[""user"", ""user"", … ""user""]","[""user"", ""user"", … ""user""]","[""user"", ""user"", … ""user""]","[""user"", ""user"", … ""user""]","[""user"", ""user"", … ""user""]","[""user"", ""user"", … ""user""]","[""user"", ""user"", … ""user""]"
"""ToLD-Br""","""16784738693255454158""","""meu pai me deu esse perfume eu…",[1],[0],"""meu pai me deu esse perfume eu…","""meu pai me deu esse perfume eu…","[""meu"", ""pai"", … ""aasacy52xy""]","[""meu"", ""pai"", … ""aasacy52xyr""]","[""meu"", ""pai"", … ""aasacy52xy""]","[""meu"", ""pai"", … ""aasacy52xyr""]","[""pai"", ""deu"", … ""aasacy52xy""]","[""pai"", ""dar"", … ""aasacy52xyr""]","[""pai"", ""deu"", … ""aasacy52xy""]","[""pai"", ""dar"", … ""aasacy52xyr""]"


dataset,id,text,off_strict,off_relaxed,base_clean,base_clean_lower,tokenized,lemmatized,no_accents,lemma_no_accents,no_stop_words,lemma_no_stop_words,no_stop_words_no_accents,lemma_no_stop_words_no_accents
str,str,str,"array[i32, 1]","array[i32, 1]",str,str,list[str],list[str],list[str],list[str],list[str],list[str],list[str],list[str]
"""ToLD-Br""","""10657414299548058873""","""rt USER mano tá tudo me irrita…",[1],[0],"""rt USER mano tá tudo me irrita…","""rt user mano tá tudo me irrita…","[""rt"", ""user"", … ""pqp""]","[""rt"", ""user"", … ""pqp""]","[""rt"", ""user"", … ""pqp""]","[""rt"", ""user"", … ""pqp""]","[""rt"", ""user"", … ""pqp""]","[""rt"", ""user"", … ""pqp""]","[""rt"", ""user"", … ""pqp""]","[""rt"", ""user"", … ""pqp""]"
"""ToLD-Br""","""11088205621966361413""","""USER horrível!""",[1],[0],"""USER horrível""","""user horrível""","[""user"", ""horrível""]","[""user"", ""horrível""]","[""user"", ""horrivel""]","[""user"", ""horrivel""]","[""user"", ""horrível""]","[""user"", ""horrível""]","[""user"", ""horrivel""]","[""user"", ""horrivel""]"
"""ToLD-Br""","""11546370057009176494""","""gnt como pode falar q esse hom…",[0],[0],"""gnt como pode falar q esse hom…","""gnt como pode falar q esse hom…","[""gnt"", ""como"", … ""jesus""]","[""gnt"", ""como"", … ""jesus""]","[""gnt"", ""como"", … ""jesus""]","[""gnt"", ""como"", … ""jesus""]","[""gnt"", ""falar"", … ""jesus""]","[""gnt"", ""falar"", … ""jesus""]","[""gnt"", ""falar"", … ""jesus""]","[""gnt"", ""falar"", … ""jesus""]"
"""ToLD-Br""","""9450469262872738701""","""Que foda o USER PUTA QUE PARIU…",[0],[0],"""Que foda o USER PUTA QUE PARIU…","""que foda o user puta que pariu…","[""que"", ""foda"", … ""3tv8fum5v8""]","[""que"", ""fodar"", … ""3tv8fum5v8""]","[""que"", ""foda"", … ""3tv8fum5v8""]","[""que"", ""fodar"", … ""3tv8fum5v8""]","[""foda"", ""user"", … ""3tv8fum5v8""]","[""fodar"", ""user"", … ""3tv8fum5v8""]","[""foda"", ""user"", … ""3tv8fum5v8""]","[""fodar"", ""user"", … ""3tv8fum5v8""]"
"""ToLD-Br""","""16835911729407698751""","""sapatão é foda, não pode beber…",[1],[1],"""sapatão é foda não pode beber …","""sapatão é foda não pode beber …","[""sapatão"", ""é"", … ""hein""]","[""sapatão"", ""ser"", … ""hein""]","[""sapatao"", ""e"", … ""hein""]","[""sapatao"", ""ser"", … ""hein""]","[""sapatão"", ""foda"", … ""hein""]","[""sapatão"", ""foda"", … ""hein""]","[""sapatao"", ""foda"", … ""hein""]","[""sapatao"", ""foda"", … ""hein""]"


## Setup Model

In [8]:
tokenizer = bert_tokenizer()

model = BertModuleBF16(feature_count=1)
model.to(PYTORCH_DEVICE)

train_loader = DataLoader(BertDatasetBF16(data_frame=train_df, tokenizer=tokenizer, max_len=MAX_LEN, target_col='off_relaxed', text_col='base_clean_lower'), shuffle=True,
                          num_workers=0, batch_size=TRAIN_BATCH_SIZE)
test_loader = DataLoader(BertDatasetBF16(data_frame=test_df, tokenizer=tokenizer, max_len=MAX_LEN, target_col='off_relaxed', text_col='base_clean_lower'), shuffle=True,
                         num_workers=0, batch_size=TEST_BATCH_SIZE)

### Loss and Optimizer

Using a Binary Cross Entropy loss as it shows good results for binary classification tasks. We are also applying differente weights to the positive and negative classes to account for the class imbalance.

Adam optimizer is also used as it is a good general optimizer for training neural networks, with good known results for BERT models.

In [9]:
loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT], device=PYTORCH_DEVICE))
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

### Training

### Result Validation

In [10]:
def validate_result():
    # Validate the results
    raw_results, raw_targets = validate(model, test_loader, PYTORCH_DEVICE)
    raw_results = np.array(raw_results)
    raw_targets = np.array(raw_targets)

    # Apply a fixed threshold to the results
    FIXED_THRESHOLD = 0.5
    fixed_results = raw_results > FIXED_THRESHOLD
    fixed_targets = raw_targets > FIXED_THRESHOLD

    # Compute metrics
    return model_metrics(fixed_targets, fixed_results)

### Train Model

In [11]:
TIMESTAMP = datetime.now().strftime('%Y%m%d%H%M%S')
CHECKPOINT_PATH = os.path.join(ROOT_PATH, 'checkpoints', 'bertimbau-bf16', TIMESTAMP)
MODEL_PATH = os.path.join(ROOT_PATH, 'models', 'bertimbau-bf16', TIMESTAMP)
BEST_MODEL_PATH = os.path.join(MODEL_PATH, 'best_model.pt')
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
os.makedirs(MODEL_PATH, exist_ok=True)

loss_history = []
metric_history = []
target_metric = ('Weighted F2', 'weighted_f2')
best_metric = float('-inf')
best_epoch = 0

# Save the best model; Override checkpoints; Track metrics
def epoch_callback(epoch, avg_loss):
    global loss_history, metric_history, target_metric, best_metric, best_epoch

    metrics = validate_result()
    print(f'Epoch {epoch+1}: Loss: {avg_loss:.4f}')
    print(f'{target_metric[0]}: {metrics[target_metric[1]]:.4f}')
    loss_history.append(avg_loss)
    metric_history.append(metrics)

    if metrics[target_metric[1]] > best_metric:
        print(f'New best model found!')
        best_metric = metrics[target_metric[1]]
        best_epoch = epoch
        torch.save(model, BEST_MODEL_PATH)



train_epochs(
    trainer, EPOCHS, model, train_loader, loss_fn, optimizer, PYTORCH_DEVICE,
    checkpoint_path=CHECKPOINT_PATH, epoch_callback=epoch_callback)

Running training epoch 1/6


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 1: Loss: 0.6244
Weighted F2: 0.7908
New best model found!
Finished training epoch 1/6; Average Loss: 0.6244
Running training epoch 2/6


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 2: Loss: 0.5249
Weighted F2: 0.8044
New best model found!
Finished training epoch 2/6; Average Loss: 0.5249
Running training epoch 3/6


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 3: Loss: 0.4683
Weighted F2: 0.8068
New best model found!
Finished training epoch 3/6; Average Loss: 0.4683
Running training epoch 4/6


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 4: Loss: 0.4132
Weighted F2: 0.8102
New best model found!
Finished training epoch 4/6; Average Loss: 0.4132
Running training epoch 5/6


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 5: Loss: 0.3480
Weighted F2: 0.8044
Finished training epoch 5/6; Average Loss: 0.3480
Running training epoch 6/6


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 6: Loss: 0.2914
Weighted F2: 0.7981
Finished training epoch 6/6; Average Loss: 0.2914


In [12]:
# Build result metrics data frame per epoch
result_df = pl.DataFrame({
    'epoch': range(1, len(loss_history) + 1),
    'loss': loss_history,
}).with_columns(pl.from_dicts(metric_history))

In [13]:
# Plot Loss and Target Metric per Epoch, highlighting the peak
fig_a = px.line(result_df, x='epoch', y='loss', title='Loss per Epoch')
fig_a.add_scatter(x=[best_epoch+1], y=[loss_history[best_epoch]], mode='markers', showlegend=False, marker={'color': 'red', 'size': 10}, name='Best Model')
fig_a.update_layout()
fig_a.show()

fig_b = px.line(result_df, x='epoch', y='weighted_f2', title='Weighted F2 per Epoch')
fig_b.add_scatter(x=[best_epoch+1], y=[metric_history[best_epoch]['weighted_f2']], mode='markers', showlegend=False, marker={'color': 'red', 'size': 10}, name='Best Model')
fig_b.update_layout()
fig_b.show()
