In [1]:
import os
import sys
import random

import numpy as np


import torch
from torch import cuda

In [2]:
# Default layout: a local checkout where this notebook lives one directory
# below the repository root.
ROOT_PATH = '../'
# Project folder inside "My Drive" used to persist data and models on Colab.
DRIVE_PATH = 'Colab/ToxicityClassification'

# The `google.colab` module is only loaded when running on Colab. There, mount
# Google Drive and use the project folder on it as both the root path and the
# current working directory, so artifacts survive runtime resets.
if 'google.colab' in sys.modules:
    from google.colab import drive

    drive.mount('/content/drive')

    ROOT_PATH = os.path.join('/content/drive/My Drive/', DRIVE_PATH)
    os.makedirs(ROOT_PATH, exist_ok=True)  # first run: folder may not exist yet
    os.chdir(ROOT_PATH)

In [3]:

# Make `<ROOT_PATH>/src` importable so the project's `toxicity` package can be
# imported. The membership check keeps the cell idempotent: re-running it does
# not pile up duplicate sys.path entries.
SRC_PATH = os.path.abspath(os.path.join(ROOT_PATH, 'src'))
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

from toxicity.transformers.bertimbau_base import bert_tokenizer, BertModuleBF16

## Setup

In [4]:
# Device the model runs on: the GPU when PyTorch can see one, else the CPU.
PYTORCH_DEVICE = 'cpu'
if cuda.is_available():
    PYTORCH_DEVICE = 'cuda'

# Default seed for the random number generators (see `reseed`).
RANDOM_SEED = 777

print(f'Using device: {PYTORCH_DEVICE}')

Using device: cuda


In [5]:
def reseed(seed: int = RANDOM_SEED):
    """Seed every RNG this notebook touches and force deterministic cuDNN.

    Seeds Python's `random`, NumPy's legacy global RNG, PyTorch's CPU RNG and
    the RNGs of all CUDA devices. cuDNN is restricted to deterministic
    algorithms and its benchmark autotuner is turned off, trading some speed
    for reproducible results.

    Args:
        seed: value handed to every generator. Defaults to RANDOM_SEED as it
            was when this function was defined.
    """
    # Reproducibility over speed: no autotuned, possibly non-deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


reseed()

## Load Model

In [6]:
# Training run to load: <ROOT_PATH>/models/bertimbau-bf16/<TIMESTAMP>/best_model.pt
TIMESTAMP = "20240620230559"
# Length every input is padded/truncated to by `run_text`.
# NOTE(review): presumably the length used during training — confirm.
MAX_LEN = 256
MODEL_PATH = os.path.join(ROOT_PATH, f'models/bertimbau-bf16/{TIMESTAMP}/best_model.pt')

# The checkpoint holds a whole pickled model object (it is called and put in
# eval mode directly later), not a state_dict. PyTorch >= 2.6 defaults to
# weights_only=True, which refuses arbitrary pickled classes, so opt out
# explicitly. Unpickling can execute arbitrary code: only load trusted files.
# NOTE(review): unpickling presumably requires the model class (e.g.
# BertModuleBF16) to be importable at this point — confirm.
model = torch.load(MODEL_PATH, map_location=PYTORCH_DEVICE, weights_only=False)

## Run Model

In [7]:
# Probability above which a text is classified as toxic.
FIXED_THRESHOLD = 0.75

tokenizer = bert_tokenizer()


def run_text(text: str) -> bool:
    """Classify a single text as toxic (True) or not (False).

    The text is tokenized with special tokens, padded/truncated to MAX_LEN
    tokens, and run through `model` on PYTORCH_DEVICE as a batch of one. The
    sigmoid of the first output value is compared against FIXED_THRESHOLD.

    Args:
        text: raw input text.

    Returns:
        A Python bool: True when the predicted probability is strictly greater
        than FIXED_THRESHOLD.
    """
    tokens = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding="max_length",
        return_token_type_ids=False,
        truncation=True,
    )

    # Wrap each id list into a (1, MAX_LEN) int64 tensor on the model's device.
    input_ids = torch.tensor(tokens['input_ids'], dtype=torch.int64).unsqueeze(0).to(PYTORCH_DEVICE)
    attention_mask = torch.tensor(tokens['attention_mask'], dtype=torch.int64).unsqueeze(0).to(PYTORCH_DEVICE)

    model.eval()  # disable dropout etc. for inference
    # Inference only: no_grad avoids building (and holding memory for) an autograd graph.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    # Upcast BEFORE the sigmoid so the probability compared against the
    # threshold is computed in full precision even if the model emits a
    # reduced-precision dtype (the BF16 naming suggests bfloat16). float32 is
    # also needed because NumPy has no bfloat16 dtype.
    probs = torch.sigmoid(logits.to(torch.float32)).cpu().numpy().flatten()

    # NOTE(review): only the first value is used — assumes one logit per input; confirm.
    # bool(): comparing a NumPy scalar yields numpy.bool_, not the annotated bool.
    return bool(probs[0] > FIXED_THRESHOLD)

In [8]:
# Portuguese for "hey everyone, all good?" — a friendly greeting.
run_text(text='salve galera, tudo bem?')

False

In [9]:
# Portuguese for "you guys are trash" — a direct insult.
run_text(text='vocês são uns lixos')

True

In [10]:
# Portuguese for "you guys are pretty silly" — mild teasing (informal grammar).
run_text(text='vocês são bem bobo')

True