In [20]:
import pandas as pd
import numpy as np
import torch
from datasets import Dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os
import json

In [21]:
# ✅ Detect MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [22]:
project_root = os.path.abspath("..")
# Load dataset

In [16]:
DATASET_PATH = os.path.join(project_root, "datasets/processed_split_toxicity_data.csv")
df = pd.read_csv(DATASET_PATH)
dataset = Dataset.from_pandas(df)

def tokenize_split(example):
    return tokenizer(example["message"], example["context"], padding="max_length", truncation=True, max_length=512)

def tokenize_message_only(example):
    return tokenizer(example["message"], padding="max_length", truncation=True, max_length=512)

tokenize_funcs = {
    "split": tokenize_split,
    "message_only": tokenize_message_only
}

In [17]:
BEST_MODELS_PATH = "best_models.json"
F1_TOLERANCE = 1e-2

def load_best_models():
    if not os.path.exists(BEST_MODELS_PATH):
        return {
            'bert-base-uncased-split': {'f1': 0.0, 'precision': 0.0, "recall": 0.0, "accuracy": 0.0},
            'bert-base-uncased-message_only': {'f1': 0.0, 'precision': 0.0, "recall": 0.0, "accuracy": 0.0},
            'bert-large-uncased-split': {'f1': 0.0, 'precision': 0.0, "recall": 0.0, "accuracy": 0.0},
            'bert-large-uncased-message_only': {'f1': 0.0, 'precision': 0.0, "recall": 0.0, "accuracy": 0.0},
            's-nlp-roberta-toxicity-classifier-split': {'f1': 0.0, 'precision': 0.0, "recall": 0.0, "accuracy": 0.0},
            's-nlp-roberta-toxicity-classifier-message_only': {'f1': 0.0, 'precision': 0.0, "recall": 0.0, "accuracy": 0.0}
        }
    with open(BEST_MODELS_PATH, "r") as f:
        return json.load(f)

In [18]:
def update_best_model(model_name, f1_score, precision, recall, accuracy):
    best_models = load_best_models()
    current_best = best_models.get(model_name, {"f1": 0.0, "precision": 0.0})

    better_f1 = f1_score - F1_TOLERANCE > current_best["f1"]
    similar_f1 = abs(f1_score - current_best["f1"]) <= F1_TOLERANCE
    better_precision = precision > current_best["precision"]

    should_update = better_f1 or (similar_f1 and better_precision)

    if should_update:
        print(f"🎯 New best for {model_name}!")
        print(f"F1: {f1_score:.4f} (prev: {current_best['f1']:.4f}) | Precision: {precision:.4f} (prev: {current_best['precision']:.4f})")
        best_models[model_name] = {
            "f1": f1_score,
            "precision": precision,
            "recall": recall,
            "accuracy": accuracy,
        }
        with open(BEST_MODELS_PATH, "w") as f:
            json.dump(best_models, f, indent=2)
        return True
    else:
        print(f"🧪 {model_name} did not improve:")
        print(f"F1: {f1_score:.4f} (best: {current_best['f1']:.4f}) | Precision: {precision:.4f} (best: {current_best['precision']:.4f})")
        return False

In [19]:
results = {}
best_f1 = -1
best_model = None
best_version = ""

# Load tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")



for version, tokenizer_func in tokenize_funcs.items():
    print(f"\nTraining version: {version}")
    tokenized = dataset.map(tokenizer_func, batched=True)
    tokenized = tokenized.train_test_split(test_size=0.25)
    tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

    model_name = f"bert-base-uncased-{version}"
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    model.to(device)

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        acc = accuracy_score(labels, preds)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', zero_division=0)
        return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}

    training_args = TrainingArguments(
        output_dir="./results",
        eval_strategy="epoch",             # Evaluate every epoch
        save_strategy="epoch",                   # Save every epoch
        learning_rate=2e-5,                      # From table
        per_device_train_batch_size=16,          # From table
        per_device_eval_batch_size=16,           # From table
        num_train_epochs=5,                      # From table
        weight_decay=0.01,                       # From table
        warmup_ratio=0.0,                        # Table didn't specify, assume 0 unless specified
        lr_scheduler_type="linear",              # From table
        logging_dir="./logs",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        optim="adamw_torch",                     # Explicitly using AdamW
        adam_epsilon=1e-8                        # From table
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    eval_result = trainer.evaluate()
    results[version] = eval_result

    if update_best_model(
        model_name=model_name,
        f1_score=eval_result["eval_f1"],
        precision=eval_result["eval_precision"],
        recall=eval_result["eval_recall"],
        accuracy=eval_result["eval_accuracy"]
    ):
        trainer.save_model(f"models/best-{model_name}")

    if eval_result["eval_f1"] > best_f1:
        best_f1 = eval_result["eval_f1"]
        best_model = model
        best_version = version

# Print summary
print("\n--- Summary of Results ---")
for version, result in results.items():
    print(f"{version.upper()}: F1 = {result['eval_f1']:.4f}")

print(f"\nBest model: {best_version.upper()} (F1 = {best_f1:.4f})")


Training version: split


Map:   0%|          | 0/1024 [00:00<?, ? examples/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.621341,0.675781,0.625,0.785124,0.695971
2,No log,0.59164,0.710938,0.674074,0.752066,0.710938
3,No log,0.580322,0.707031,0.655405,0.801653,0.72119
4,No log,0.580494,0.699219,0.648649,0.793388,0.713755
5,No log,0.566059,0.722656,0.678571,0.785124,0.727969


🎯 New best for bert-base-uncased-split!
F1: 0.7280 (prev: 0.0000) | Precision: 0.6786 (prev: 0.0000)

Training version: message_only


Map:   0%|          | 0/1024 [00:00<?, ? examples/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.689965,0.613281,0.569307,0.905512,0.699088
2,No log,0.613243,0.714844,0.698529,0.748031,0.722433
3,No log,0.530195,0.765625,0.755725,0.779528,0.767442
4,No log,0.500653,0.789062,0.817391,0.740157,0.77686
5,No log,0.489992,0.792969,0.808333,0.76378,0.785425


🎯 New best for bert-base-uncased-message_only!
F1: 0.7854 (prev: 0.0000) | Precision: 0.8083 (prev: 0.0000)

--- Summary of Results ---
SPLIT: F1 = 0.7280
MESSAGE_ONLY: F1 = 0.7854

Best model: MESSAGE_ONLY (F1 = 0.7854)
