In [1]:
import os
import sys
import json
import random
from datetime import datetime

import numpy as np
import polars as pl

from sklearn.model_selection import train_test_split

import torch
from torch import nn, optim, cuda
from torch.utils.data import DataLoader, Dataset

In [2]:
# Project root used for every data/model path. Outside Colab it is the parent
# of the current working directory.
ROOT_PATH = '../'
# Project folder inside "My Drive" used when running on Google Colab
DRIVE_PATH = 'Colab/ToxicityClassification'

# On Colab, switch the root to Google Drive so data and models persist between
# sessions, and make it the working directory.
_on_colab = 'google.colab' in sys.modules
if _on_colab:
    from google.colab import drive

    drive.mount('/content/drive')
    ROOT_PATH = os.path.join('/content/drive', 'My Drive', DRIVE_PATH)
    os.makedirs(ROOT_PATH, exist_ok=True)
    os.chdir(ROOT_PATH)

In [3]:
# Make the project's `src/` directory (resolved under ROOT_PATH) importable so the
# local `toxicity` package can be imported below. The membership check keeps the
# cell idempotent: re-running it does not append the same entry to sys.path again.
SRC_PATH = os.path.abspath(os.path.join(ROOT_PATH, 'src'))
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

from toxicity.training import train_epochs, model_metrics
from toxicity.embeddings.training import trainer, validate
from toxicity.embeddings.model import EmbeddingModel, EmbeddingDataset

## Setup

In [4]:
# --- Runtime -----------------------------------------------------------------
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU
PYTORCH_DEVICE = 'cuda' if cuda.is_available() else 'cpu'
RANDOM_SEED = 777

# --- Training & validation ---------------------------------------------------
TRAIN_RATIO = 0.8  # fraction of rows used for training; the rest is held out for validation
TRAIN_BATCH_SIZE = 16
TEST_BATCH_SIZE = 16
# NOTE(review): the saved training output below reports 10 epochs; re-run so the
# outputs match this value.
EPOCHS = 6
LEARNING_RATE = 1e-05
# Weight applied to positive examples in BCEWithLogitsLoss to counter class imbalance.
# NOTE(review): presumably the negative/positive label ratio of the data — confirm its origin.
POS_WEIGHT = 1.663

# --- Embeddings --------------------------------------------------------------
EMBEDDING_NAME = 'cbow_s100'
EMBEDDING_FILE = os.path.join(ROOT_PATH, f'{EMBEDDING_NAME}.txt')
MAX_LEN = 128  # sequence length given to both EmbeddingModel and EmbeddingDataset

print(f'Using device: {PYTORCH_DEVICE}')

Using device: cuda


In [5]:

def reseed(seed: int = RANDOM_SEED) -> None:
    """Seed every RNG the notebook relies on and force deterministic cuDNN.

    Covers Python's `random`, NumPy's global RNG, and PyTorch's CPU and all-GPU
    generators. cuDNN autotuning is disabled because it can pick different
    kernels between runs.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

reseed()

## Data Loading

In [6]:
# Load the pre-processed joint dataset. Each 0/1 offensiveness label is wrapped
# into a length-1 Int32 array. NOTE(review): presumably this is the per-sample
# target shape EmbeddingDataset expects — confirm.
df = (
    pl.read_parquet(os.path.join(ROOT_PATH, 'data', 'joint', 'pre_processed_data.parquet.zstd'))
    .with_columns([
        pl.col(label).cast(pl.Int32).cast(pl.List(pl.Int32)).cast(pl.Array(pl.Int32, 1))
        for label in ('off_relaxed', 'off_strict')
    ])
)
df.sample(5, seed=RANDOM_SEED)

dataset,id,text,off_strict,off_relaxed,base_clean,base_clean_lower,tokenized,lemmatized,no_accents,lemma_no_accents,no_stop_words,lemma_no_stop_words,no_stop_words_no_accents,lemma_no_stop_words_no_accents
str,str,str,"array[i32, 1]","array[i32, 1]",str,str,list[str],list[str],list[str],list[str],list[str],list[str],list[str],list[str]
"""ToLD-Br""","""17643984771725418028""","""caralho q vergonha kkkkk""",[1],[0],"""caralho q vergonha kkkkk""","""caralho q vergonha kkkkk""","[""caralho"", ""q"", … ""kkkkk""]","[""caralho"", ""q"", … ""kkkkk""]","[""caralho"", ""q"", … ""kkkkk""]","[""caralho"", ""q"", … ""kkkkk""]","[""caralho"", ""q"", … ""kkkkk""]","[""caralho"", ""q"", … ""kkkkk""]","[""caralho"", ""q"", … ""kkkkk""]","[""caralho"", ""q"", … ""kkkkk""]"
"""ToLD-Br""","""3886050625220892585""","""foda-se, vou encher o cu de po…",[1],[0],"""foda se vou encher o cu de por…","""foda se vou encher o cu de por…","[""foda"", ""se"", … ""lol""]","[""foda"", ""se"", … ""lol""]","[""foda"", ""se"", … ""lol""]","[""foda"", ""se"", … ""lol""]","[""foda"", ""vou"", … ""lol""]","[""foda"", ""ir"", … ""lol""]","[""foda"", ""vou"", … ""lol""]","[""foda"", ""ir"", … ""lol""]"
"""ToLD-Br""","""14936095030342170465""","""USER USER USER Vc só pensa no …",[1],[1],"""USER USER USER Vc só pensa no …","""user user user vc só pensa no …","[""user"", ""user"", … ""esperta""]","[""user"", ""user"", … ""esperto""]","[""user"", ""user"", … ""esperta""]","[""user"", ""user"", … ""esperto""]","[""user"", ""user"", … ""esperta""]","[""user"", ""user"", … ""esperto""]","[""user"", ""user"", … ""esperta""]","[""user"", ""user"", … ""esperto""]"
"""ToLD-Br""","""18279259074216789411""","""família""",[0],[0],"""família""","""família""","[""família""]","[""família""]","[""familia""]","[""familia""]","[""família""]","[""família""]","[""familia""]","[""familia""]"
"""OLID-Br""","""7f36b160e8624968a32e82b1c6750f…","""RT USER: vey a juliette veio c…",[0],[0],"""RT USER vey a juliette veio co…","""rt user vey a juliette veio co…","[""rt"", ""user"", … ""t""]","[""rt"", ""user"", … ""t""]","[""rt"", ""user"", … ""t""]","[""rt"", ""user"", … ""t""]","[""rt"", ""user"", … ""t""]","[""rt"", ""user"", … ""t""]","[""rt"", ""user"", … ""t""]","[""rt"", ""user"", … ""t""]"


### Load Embeddings

In [7]:
EMBEDDING_PATH = os.path.join(ROOT_PATH, 'models', f'embeddings-{EMBEDDING_NAME}')
os.makedirs(EMBEDDING_PATH, exist_ok=True)
# Parquet cache of the parsed text embeddings. Parsing the raw text file is slow,
# so it happens once and later runs load this cache instead.
EMBEDDING_CACHE = os.path.join(EMBEDDING_PATH, 'embeddings.parquet.zstd')


def _read_text_embeddings(path: str) -> tuple:
    """Parse a word2vec-style text embedding file.

    The first line must be ``<token_count> <emb_dim>``. Every following
    non-empty line is a token followed by ``emb_dim`` space-separated floats.

    Returns:
        ``(embeddings, emb_dim, token_count)``: a token -> list-of-floats dict,
        the vector length, and the vocabulary size declared in the header.

    Raises:
        ValueError: if a line does not carry exactly ``emb_dim`` values.
    """
    embeddings = {}
    # Explicit UTF-8: the corpus is Portuguese with accented tokens (e.g. "família"),
    # and the platform default encoding (e.g. cp1252 on Windows) would mis-decode them.
    with open(path, 'r', encoding='utf-8') as f:
        token_count, emb_dim = map(int, f.readline().split())
        for line in f:
            # rstrip() drops the newline, the '\r' of CRLF files and any trailing
            # space some exporters write. Any of these would otherwise leave an
            # empty last field that crashes float().
            fields = line.rstrip().split(' ')
            if fields == ['']:
                continue  # blank line, e.g. a trailing empty line at EOF
            values = [float(v) for v in fields[1:]]
            if len(values) != emb_dim:
                raise ValueError('Inconsistent embedding length')
            embeddings[fields[0]] = values
    return embeddings, emb_dim, token_count


if os.path.exists(EMBEDDING_CACHE):
    embedding_df = pl.read_parquet(EMBEDDING_CACHE)
    # Build the lookup column-wise instead of materializing one dict per row
    embeddings = dict(zip(embedding_df['token'].to_list(), embedding_df['embedding'].to_list()))
    if not embeddings:
        raise ValueError(f'Embedding cache is empty: {EMBEDDING_CACHE}')
    emb_dim = len(next(iter(embeddings.values())))
    # The cache does not keep the text header, so the loaded size is the only count available
    token_count = len(embeddings)
    print(f'Embedding Length: {emb_dim}')
    print(f'Embedding Vocab Size: {len(embeddings)}; Expected: {token_count}')
else:
    embeddings, emb_dim, token_count = _read_text_embeddings(EMBEDDING_FILE)
    print(f'Embedding Length: {emb_dim}')
    print(f'Embedding Vocab Size: {len(embeddings)}; Expected: {token_count}')
    embedding_df = pl.DataFrame({
        'token': list(embeddings.keys()),
        'embedding': list(embeddings.values()),
    })
    embedding_df.write_parquet(EMBEDDING_CACHE, compression="zstd", compression_level=9)


Embedding Length: 100
Embedding Vocab Size: 929606; Expected: 929606


## Init Model

### Loss and Optimizer

We use a binary cross-entropy loss computed on the model's logits (`BCEWithLogitsLoss`), a standard choice for binary classification tasks. The positive class is given a larger weight (`POS_WEIGHT`) than the negative class to account for the class imbalance.

The AdamW optimizer (Adam with decoupled weight decay) is used, as it is a robust general-purpose optimizer for training neural networks.

In [8]:
# Embedding-based classifier sized by the embedding width and the sequence length
model = EmbeddingModel(emb_dim, MAX_LEN)
model.to(PYTORCH_DEVICE)

# Positive-class weight tensor, created on the same device as the model outputs
pos_weight = torch.tensor([POS_WEIGHT], device=PYTORCH_DEVICE)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

## Data Split

In [9]:
train_df, test_df = train_test_split(df, train_size=TRAIN_RATIO, random_state=RANDOM_SEED)


def _make_loader(frame, shuffle, batch_size):
    """Wrap one split in an EmbeddingDataset over `lemma_no_stop_words` tokens and the
    `off_relaxed` label, and return a single-process DataLoader for it."""
    dataset = EmbeddingDataset(
        frame, 'lemma_no_stop_words', 'off_relaxed',
        embeddings=embeddings, emb_dim=emb_dim, seq_len=MAX_LEN,
    )
    return DataLoader(dataset, shuffle=shuffle, num_workers=0, batch_size=batch_size)


train_loader = _make_loader(train_df, shuffle=True, batch_size=TRAIN_BATCH_SIZE)
test_loader = _make_loader(test_df, shuffle=False, batch_size=TEST_BATCH_SIZE)

## Training

In [10]:
def validate_result(threshold: float = 0.5) -> dict:
    """Evaluate the notebook-level `model` on `test_loader` and print the weighted F2.

    Args:
        threshold: cut-off applied to the model outputs returned by `validate`.
            Outputs strictly above it are predicted as the positive class.

    Returns:
        The full metrics dict produced by `model_metrics`, so callers can log or
        compare more than the printed weighted F2.
    """
    raw_results, raw_targets = validate(model, test_loader, PYTORCH_DEVICE)

    # NOTE(review): training uses BCEWithLogitsLoss, so the model emits logits.
    # A 0.5 cut-off is only right if `validate` applies a sigmoid; on raw logits the
    # equivalent cut-off is 0.0. Confirm against toxicity.embeddings.training.validate.
    predictions = np.asarray(raw_results) > threshold
    # Targets are 0/1 labels. Binarize them at a fixed 0.5 so that changing the
    # prediction threshold can never flip a ground-truth label.
    targets = np.asarray(raw_targets) > 0.5

    metrics = model_metrics(targets, predictions)
    print(f'Weighted F2: {metrics["weighted_f2"]:.6f}')
    return metrics

### Training Loop

In [11]:
# A per-run timestamp keeps each run's final model and checkpoints in separate folders
TIMESTAMP = datetime.now().strftime('%Y%m%d%H%M%S')
_run_subdir = os.path.join(f'embeddings-{EMBEDDING_NAME}', TIMESTAMP)
MODEL_PATH = os.path.join(ROOT_PATH, 'models', _run_subdir)
CHECKPOINT_PATH = os.path.join(ROOT_PATH, 'checkpoints', _run_subdir)
for run_dir in (MODEL_PATH, CHECKPOINT_PATH):
    os.makedirs(run_dir, exist_ok=True)


def epoch_callback(epoch, avg_loss):
    """Per-epoch hook passed to `train_epochs`; prints metrics on the held-out split.

    `epoch` and `avg_loss` belong to the callback signature and are not used here.
    """
    print('Validation Results:')
    validate_result()


train_epochs(
    trainer,
    EPOCHS,
    model,
    train_loader,
    loss_fn,
    optimizer,
    PYTORCH_DEVICE,
    checkpoint_path=CHECKPOINT_PATH,
    epoch_callback=epoch_callback,
)

Running training epoch 1/10


  0%|          | 0/1398 [00:00<?, ?it/s]

Validation Results:


  0%|          | 0/350 [00:00<?, ?it/s]

Weighted F2: 0.590787
Finished training epoch 1/10; Average Loss: 0.8467
Running training epoch 2/10


  0%|          | 0/1398 [00:00<?, ?it/s]

Validation Results:


  0%|          | 0/350 [00:00<?, ?it/s]

Weighted F2: 0.674804
Finished training epoch 2/10; Average Loss: 0.7655
Running training epoch 3/10


  0%|          | 0/1398 [00:00<?, ?it/s]

Validation Results:


  0%|          | 0/350 [00:00<?, ?it/s]

Weighted F2: 0.691178
Finished training epoch 3/10; Average Loss: 0.7180
Running training epoch 4/10


  0%|          | 0/1398 [00:00<?, ?it/s]

Validation Results:


  0%|          | 0/350 [00:00<?, ?it/s]

Weighted F2: 0.675966
Finished training epoch 4/10; Average Loss: 0.6919
Running training epoch 5/10


  0%|          | 0/1398 [00:00<?, ?it/s]

Validation Results:


  0%|          | 0/350 [00:00<?, ?it/s]

Weighted F2: 0.689189
Finished training epoch 5/10; Average Loss: 0.6701
Running training epoch 6/10


  0%|          | 0/1398 [00:00<?, ?it/s]

Validation Results:


  0%|          | 0/350 [00:00<?, ?it/s]

Weighted F2: 0.692351
Finished training epoch 6/10; Average Loss: 0.6483
Running training epoch 7/10


  0%|          | 0/1398 [00:00<?, ?it/s]

Validation Results:


  0%|          | 0/350 [00:00<?, ?it/s]

Weighted F2: 0.679430
Finished training epoch 7/10; Average Loss: 0.6264
Running training epoch 8/10


  0%|          | 0/1398 [00:00<?, ?it/s]

Validation Results:


  0%|          | 0/350 [00:00<?, ?it/s]

Weighted F2: 0.688136
Finished training epoch 8/10; Average Loss: 0.6017
Running training epoch 9/10


  0%|          | 0/1398 [00:00<?, ?it/s]

Validation Results:


  0%|          | 0/350 [00:00<?, ?it/s]

Weighted F2: 0.684593
Finished training epoch 9/10; Average Loss: 0.5753
Running training epoch 10/10


  0%|          | 0/1398 [00:00<?, ?it/s]

Validation Results:


  0%|          | 0/350 [00:00<?, ?it/s]

Weighted F2: 0.689354
Finished training epoch 10/10; Average Loss: 0.5476


### Save model

In [12]:
# Whole-module pickle, kept for existing consumers that torch.load() this file.
# Loading it requires weights_only=False and the same EmbeddingModel import path.
torch.save(model, os.path.join(MODEL_PATH, 'model.pth'))
# Also save the state_dict. It is independent of the pickled class path and
# loads with torch.load(weights_only=True), the default since PyTorch 2.6.
torch.save(model.state_dict(), os.path.join(MODEL_PATH, 'model_state_dict.pth'))