# Neural Solution - Transformers: BERT

In [1]:
import os
import sys
import random
from datetime import datetime

import numpy as np
import polars as pl
import plotly.express as px
from sklearn.model_selection import train_test_split

import torch
from torch import nn, optim, cuda
from torch.utils.data import DataLoader

In [2]:
# Folder inside Google Drive that holds the project when running on Colab
DRIVE_PATH = 'Colab/ToxicityClassification'
# Default root for local runs: the parent of the working directory
ROOT_PATH = '../'

running_on_colab = 'google.colab' in sys.modules
if running_on_colab:
    # Mount Google Drive and work from a folder inside it, so that data
    # is persisted and loaded from there across sessions
    from google.colab import drive

    drive.mount('/content/drive')
    ROOT_PATH = os.path.join('/content/drive/My Drive/', DRIVE_PATH)
    os.makedirs(ROOT_PATH, exist_ok=True)
    os.chdir(ROOT_PATH)

In [3]:
# Add `<ROOT_PATH>/src` (as an absolute path) to the module search path so
# that the `toxicity` package imported below can be resolved
# (presumably the package lives under `src` - confirm against the repo layout)
sys.path.append(os.path.abspath(os.path.join(ROOT_PATH, 'src')))

from toxicity.transformers.bertimbau_base import bert_tokenizer, BertDatasetBF16, BertModuleBF16
from toxicity.transformers.training import trainer, validate
from toxicity.training import train_epochs, model_metrics

## Setup

In [4]:
# Target device for running the model
PYTORCH_DEVICE = 'cuda' if cuda.is_available() else 'cpu'

# Random Seed (used for the RNG seeding and for both data splits)
RANDOM_SEED = 777

# Training & Validation configs
# Fraction of the data used for training; the remainder is split in half
# between validation and test in the Data Split section
TRAIN_RATIO = 0.8
# Passed as `max_len` to the dataset (presumably the maximum number of tokens
# per sample - confirm in BertDatasetBF16)
MAX_LEN = 256
TRAIN_BATCH_SIZE = 32
# Used for both the test and the validation loaders
TEST_BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 3e-05
# NOTE(review): this value is overwritten in the Data Split section with the
# negative/positive ratio measured on the training split, before the loss is
# created - so the number below is never the one actually used by the loss
POS_WEIGHT = 1.663

print(f'Using device: {PYTORCH_DEVICE}')

Using device: cuda


In [5]:
# Seed every random number generator in use (Python, NumPy, PyTorch CPU and
# all CUDA devices) with the same value
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(RANDOM_SEED)

# Reproducibility flags for the cuDNN backend
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Data Loading

In [6]:
data_path = os.path.join(ROOT_PATH, 'data', 'joint', 'pre_processed_data.parquet.zstd')

# Re-encode both label columns as fixed-size arrays holding a single Int32,
# i.e. one target vector of width 1 per row
label_dtype = pl.Array(pl.Int32, 1)
df = pl.read_parquet(data_path).with_columns([
    pl.col(label).cast(pl.Int32).cast(pl.List(pl.Int32)).cast(label_dtype)
    for label in ('off_relaxed', 'off_strict')
])

# Preview a few rows (fixed seed so the preview is stable)
df.sample(5, seed=42)

dataset,id,text,off_strict,off_relaxed,base_clean,base_clean_lower,tokenized,lemmatized,no_accents,lemma_no_accents,no_stop_words,lemma_no_stop_words,no_stop_words_no_accents,lemma_no_stop_words_no_accents
str,str,str,"array[i32, 1]","array[i32, 1]",str,str,list[str],list[str],list[str],list[str],list[str],list[str],list[str],list[str]
"""ToLD-Br""","""5508727285226739644""","""medo de ir pra um rolê de novo…",[0],[0],"""medo de ir pra um rolê de novo…","""medo de ir pra um rolê de novo…","[""medo"", ""de"", … ""kkkkkkk""]","[""medo"", ""de"", … ""kkkkkkk""]","[""medo"", ""de"", … ""kkkkkkk""]","[""medo"", ""de"", … ""kkkkkkk""]","[""medo"", ""pra"", … ""kkkkkkk""]","[""medo"", ""pra"", … ""kkkkkkk""]","[""medo"", ""pra"", … ""kkkkkkk""]","[""medo"", ""pra"", … ""kkkkkkk""]"
"""ToLD-Br""","""16827841903506270139""","""https://t.co/2bs6oD330q Ele a…",[0],[0],"""https t co 2bs6oD330q Ele até …","""https t co 2bs6od330q ele até …","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]","[""https"", ""t"", … ""gd2j98vrkg""]"
"""ToLD-Br""","""7641628880024884135""","""rt USER bruno fernandes assina…",[0],[0],"""rt USER bruno fernandes assina…","""rt user bruno fernandes assina…","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]","[""rt"", ""user"", … ""user""]"
"""ToLD-Br""","""16866242508514532033""","""tinha que ter jogado esse bran…",[1],[1],"""tinha que ter jogado esse bran…","""tinha que ter jogado esse bran…","[""tinha"", ""que"", … ""trem""]","[""ter"", ""que"", … ""tr""]","[""tinha"", ""que"", … ""trem""]","[""ter"", ""que"", … ""tr""]","[""tinha"", ""jogado"", … ""trem""]","[""ter"", ""jogar"", … ""tr""]","[""tinha"", ""jogado"", … ""trem""]","[""ter"", ""jogar"", … ""tr""]"
"""ToLD-Br""","""3068271252403811869""","""eu sou a pessoa certa no bairr…",[0],[0],"""eu sou a pessoa certa no bairr…","""eu sou a pessoa certa no bairr…","[""eu"", ""sou"", … ""errado""]","[""eu"", ""ser"", … ""errar""]","[""eu"", ""sou"", … ""errado""]","[""eu"", ""ser"", … ""errar""]","[""pessoa"", ""certa"", … ""errado""]","[""pessoa"", ""certo"", … ""errar""]","[""pessoa"", ""certa"", … ""errado""]","[""pessoa"", ""certo"", … ""errar""]"


## Setup Model

In [7]:
# Tokenizer provided by the project's `bert_tokenizer` helper
tokenizer = bert_tokenizer()

# `feature_count=1` presumably sets the number of output values per sample
# (one binary label here) - confirm in BertModuleBF16
model = BertModuleBF16(feature_count=1)
# Move the model to the target device; being the last expression of the cell,
# the returned module is also what gets displayed below
model.to(PYTORCH_DEVICE)

BertModuleBF16(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(29794, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

## Data Split

In [8]:

def _build_loader(frame, *, shuffle, batch_size):
    """Wrap one split of the data in a `BertDatasetBF16` and a `DataLoader`.

    Args:
        frame: Data-frame split to serve.
        shuffle: Whether the loader shuffles the samples.
        batch_size: Number of samples per batch.

    Returns:
        A `DataLoader` over `frame` using `off_relaxed` as the target column.
    """
    dataset = BertDatasetBF16(
        data_frame=frame, tokenizer=tokenizer, max_len=MAX_LEN, target_col='off_relaxed',
    )
    return DataLoader(dataset, shuffle=shuffle, num_workers=0, batch_size=batch_size)


# TRAIN_RATIO of the data goes to training; the remainder is split in half
# between validation and test
train_df, other_df = train_test_split(df, train_size=TRAIN_RATIO, random_state=RANDOM_SEED)
validate_df, test_df = train_test_split(other_df, train_size=0.5, random_state=RANDOM_SEED)

train_loader = _build_loader(train_df, shuffle=True, batch_size=TRAIN_BATCH_SIZE)
test_loader = _build_loader(test_df, shuffle=False, batch_size=TEST_BATCH_SIZE)
validate_loader = _build_loader(validate_df, shuffle=False, batch_size=TEST_BATCH_SIZE)

# Count the classes by label VALUE. Picking the first/second entry of a
# frequency-sorted `value_counts()` only yields (negative, positive) while the
# negative class happens to be the majority; otherwise the weight is inverted.
train_labels = np.asarray(train_df['off_relaxed'].to_list()).reshape(-1)
neg_count = (train_labels == 0).sum()
pos_count = (train_labels == 1).sum()
if pos_count == 0:
    raise ValueError('Training split has no positive samples; cannot compute the positive weight')

print(f'Training distribution: {neg_count} negative, {pos_count} positive')
# Supersedes the placeholder POS_WEIGHT from the configuration cell
POS_WEIGHT = neg_count / pos_count
print(f'Positive weight: {POS_WEIGHT}')

Training distribution: 13932 negative, 8429 positive
Positive weight: 1.6528651085538024


### Loss and Optimizer

Using a Binary Cross Entropy loss, as it shows good results for binary classification tasks. We also apply different weights to the positive and negative classes to account for the class imbalance.

The AdamW optimizer is used, as it is a good general-purpose optimizer for training neural networks, with well-known good results for BERT models.

In [9]:
# Weight applied to the positive class inside the loss, placed on the same
# device as the model
positive_class_weight = torch.tensor([POS_WEIGHT], device=PYTORCH_DEVICE)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=positive_class_weight)

# AdamW over all model parameters
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

### Training

### Result Validation

In [10]:
def validate_result(loader: DataLoader, model: nn.Module, threshold: float = 0.5):
    """Evaluate `model` on `loader` and compute metrics on binarized outputs.

    Args:
        loader: Batches to evaluate the model on.
        model: Model to evaluate.
        threshold: Cut-off applied to both the model outputs and the targets;
            values strictly greater than it are treated as positive. Defaults
            to 0.5, the value that used to be hard-coded in this function.

    Returns:
        The result of `model_metrics` for the binarized targets and
        predictions (callers index it by metric name, e.g. 'weighted_f2').
    """
    # NOTE(review): the threshold is applied to whatever `validate` returns.
    # The model is trained with BCEWithLogitsLoss, so a 0.5 cut-off is only
    # meaningful if `validate` already applies a sigmoid - confirm.
    raw_results, raw_targets = validate(model, loader, PYTORCH_DEVICE)

    predictions = np.asarray(raw_results) > threshold
    targets = np.asarray(raw_targets) > threshold

    return model_metrics(targets, predictions)

### Train Model

In [11]:
# Every run writes into its own timestamped folders
TIMESTAMP = datetime.now().strftime('%Y%m%d%H%M%S')
CHECKPOINT_PATH, MODEL_PATH = (
    os.path.join(ROOT_PATH, folder, 'bertimbau-bf16', TIMESTAMP)
    for folder in ('checkpoints', 'models')
)
BEST_MODEL_PATH = os.path.join(MODEL_PATH, 'best_model.pt')
for run_directory in (CHECKPOINT_PATH, MODEL_PATH):
    os.makedirs(run_directory, exist_ok=True)

# Per-epoch tracking, filled in by the epoch callback
loss_history, metric_history, test_metric_history = [], [], []
# (display name, metric key) of the metric used to select the best model
target_metric = ('Weighted F2', 'weighted_f2')
best_metric, best_epoch = float('-inf'), 0

def epoch_callback(epoch, avg_loss):
    """Track metrics after a training epoch and save the model when it improves.

    Evaluates the model on the validation and test loaders, appends the
    results to the history lists and, when the validation value of the target
    metric beats the best one seen so far, saves the model to BEST_MODEL_PATH.
    Test metrics are only recorded; they do not influence model selection.

    Args:
        epoch: Zero-based index of the epoch that just finished.
        avg_loss: Average training loss of that epoch.
    """
    # Only names that are rebound need `global`; the history lists are
    # mutated in place and `target_metric` is only read
    global best_metric, best_epoch

    metric_label, metric_key = target_metric

    metrics = validate_result(validate_loader, model)
    loss_history.append(avg_loss)
    metric_history.append(metrics)
    test_metrics = validate_result(test_loader, model)
    test_metric_history.append(test_metrics)

    print(f'Epoch {epoch+1}: Loss: {avg_loss:.4f}')
    print(f'Validation {metric_label}: {metrics[metric_key]:.4f}')
    print(f'Test {metric_label}: {test_metrics[metric_key]:.4f}')

    if metrics[metric_key] > best_metric:
        print('New best model found!')
        best_metric = metrics[metric_key]
        best_epoch = epoch
        # NOTE(review): this pickles the whole module object, so loading it
        # requires the model's class to be importable under the same path.
        # Saving `model.state_dict()` is more portable, but would change what
        # existing loaders of this file receive - left as is.
        torch.save(model, BEST_MODEL_PATH)



# Run the training loop for EPOCHS epochs. Per the logged output below,
# `epoch_callback` runs once per epoch, before the "Finished training epoch"
# line is printed.
train_epochs(
    trainer, EPOCHS, model, train_loader, loss_fn, optimizer, PYTORCH_DEVICE,
    checkpoint_path=CHECKPOINT_PATH, epoch_callback=epoch_callback)

Running training epoch 1/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 1: Loss: 0.6361
Validation Weighted F2: 0.7883
Test Weighted F2: 0.7842
New best model found!
Finished training epoch 1/20; Average Loss: 0.6361
Running training epoch 2/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 2: Loss: 0.5438
Validation Weighted F2: 0.8025
Test Weighted F2: 0.7976
New best model found!
Finished training epoch 2/20; Average Loss: 0.5438
Running training epoch 3/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 3: Loss: 0.4995
Validation Weighted F2: 0.7892
Test Weighted F2: 0.7855
Finished training epoch 3/20; Average Loss: 0.4995
Running training epoch 4/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 4: Loss: 0.4715
Validation Weighted F2: 0.8031
Test Weighted F2: 0.7961
New best model found!
Finished training epoch 4/20; Average Loss: 0.4715
Running training epoch 5/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 5: Loss: 0.4431
Validation Weighted F2: 0.8082
Test Weighted F2: 0.8016
New best model found!
Finished training epoch 5/20; Average Loss: 0.4431
Running training epoch 6/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 6: Loss: 0.4195
Validation Weighted F2: 0.8010
Test Weighted F2: 0.7961
Finished training epoch 6/20; Average Loss: 0.4195
Running training epoch 7/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 7: Loss: 0.3994
Validation Weighted F2: 0.7844
Test Weighted F2: 0.7837
Finished training epoch 7/20; Average Loss: 0.3994
Running training epoch 8/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 8: Loss: 0.3765
Validation Weighted F2: 0.7941
Test Weighted F2: 0.7935
Finished training epoch 8/20; Average Loss: 0.3765
Running training epoch 9/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 9: Loss: 0.3608
Validation Weighted F2: 0.7982
Test Weighted F2: 0.7915
Finished training epoch 9/20; Average Loss: 0.3608
Running training epoch 10/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 10: Loss: 0.3436
Validation Weighted F2: 0.7997
Test Weighted F2: 0.7955
Finished training epoch 10/20; Average Loss: 0.3436
Running training epoch 11/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 11: Loss: 0.3334
Validation Weighted F2: 0.7943
Test Weighted F2: 0.7980
Finished training epoch 11/20; Average Loss: 0.3334
Running training epoch 12/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 12: Loss: 0.3170
Validation Weighted F2: 0.7860
Test Weighted F2: 0.7857
Finished training epoch 12/20; Average Loss: 0.3170
Running training epoch 13/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 13: Loss: 0.3075
Validation Weighted F2: 0.7830
Test Weighted F2: 0.7830
Finished training epoch 13/20; Average Loss: 0.3075
Running training epoch 14/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 14: Loss: 0.2995
Validation Weighted F2: 0.7986
Test Weighted F2: 0.7984
Finished training epoch 14/20; Average Loss: 0.2995
Running training epoch 15/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 15: Loss: 0.2885
Validation Weighted F2: 0.7922
Test Weighted F2: 0.7927
Finished training epoch 15/20; Average Loss: 0.2885
Running training epoch 16/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 16: Loss: 0.2818
Validation Weighted F2: 0.7957
Test Weighted F2: 0.7941
Finished training epoch 16/20; Average Loss: 0.2818
Running training epoch 17/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 17: Loss: 0.2704
Validation Weighted F2: 0.7861
Test Weighted F2: 0.7876
Finished training epoch 17/20; Average Loss: 0.2704
Running training epoch 18/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 18: Loss: 0.2652
Validation Weighted F2: 0.7882
Test Weighted F2: 0.7891
Finished training epoch 18/20; Average Loss: 0.2652
Running training epoch 19/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 19: Loss: 0.2608
Validation Weighted F2: 0.7825
Test Weighted F2: 0.7883
Finished training epoch 19/20; Average Loss: 0.2608
Running training epoch 20/20


  0%|          | 0/699 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch 20: Loss: 0.2557
Validation Weighted F2: 0.7839
Test Weighted F2: 0.7847
Finished training epoch 20/20; Average Loss: 0.2557


In [12]:
# One row per epoch: validation metrics (plus the training loss) in
# `result_df`, test metrics in `test_result_df`
epoch_numbers = range(1, len(loss_history) + 1)

validation_columns = pl.from_dicts(metric_history).get_columns()
result_df = pl.DataFrame({'epoch': epoch_numbers, 'loss': loss_history}).with_columns(validation_columns)

test_columns = pl.from_dicts(test_metric_history).get_columns()
test_result_df = pl.DataFrame({'epoch': epoch_numbers}).with_columns(test_columns)

result_df.head()


epoch,loss,weighted_f1,macro_f1,weighted_f2,macro_f2,accuracy,recall,precision
i64,f64,f64,f64,f64,f64,f64,f64,f64
1,0.636126,0.791011,0.781505,0.788261,0.787686,0.788193,0.788193,0.802364
2,0.543765,0.804206,0.794056,0.802465,0.798546,0.802147,0.802147,0.810366
3,0.499501,0.792718,0.784253,0.789249,0.791817,0.789624,0.789624,0.808785
4,0.471502,0.805076,0.795277,0.80312,0.800267,0.802862,0.802862,0.812368
5,0.443094,0.809315,0.798525,0.808227,0.801579,0.807871,0.807871,0.81257


In [13]:
def _plot_with_best_epoch(frame, y, title, best_value):
    """Show a per-epoch line chart with the best epoch highlighted.

    Args:
        frame: Data frame with an 'epoch' column and the column named by `y`.
        y: Name of the column to plot against the epoch.
        title: Chart title.
        best_value: Value of the plotted series at the best epoch; drawn as a
            red marker at x = best_epoch + 1 (epochs are shown one-based).

    Returns:
        The figure, after it has been shown.
    """
    fig = px.line(frame, x='epoch', y=y, title=title, template='plotly_dark')
    fig.add_scatter(
        x=[best_epoch + 1], y=[best_value], mode='markers', showlegend=False,
        marker={'color': 'red', 'size': 10}, name='Best Model',
    )
    fig.show()
    return fig


# Plot Loss and Target Metric per Epoch, highlighting the best epoch
fig_a = _plot_with_best_epoch(
    result_df, 'loss', 'Loss per Epoch', loss_history[best_epoch])
fig_b = _plot_with_best_epoch(
    result_df, 'weighted_f2', 'Validation Weighted F2 per Epoch',
    metric_history[best_epoch]['weighted_f2'])
fig_c = _plot_with_best_epoch(
    test_result_df, 'weighted_f2', 'Test Weighted F2 per Epoch',
    test_metric_history[best_epoch]['weighted_f2'])