In [None]:
import pandas as pd
import numpy as np
import torch
from datasets import Dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os
import json

In [None]:
# Pick the compute device: Apple-silicon GPU (MPS) when PyTorch reports it as
# available, otherwise the CPU. CUDA is not considered here.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Absolute path of the parent of the current working directory.
# NOTE(review): assumes the notebook is launched from a direct subfolder of the
# project (e.g. notebooks/) — TODO confirm; running from elsewhere changes this path.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

In [None]:
# Build the path from separate components instead of embedding a "/" separator.
DATASET_PATH = os.path.join(project_root, "datasets", "processed_split_toxicity_data.csv")
# Tokens per example after truncation/padding; 512 is the maximum input length
# of bert-large-uncased, the model fine-tuned below.
MAX_SEQ_LENGTH = 512

df = pd.read_csv(DATASET_PATH)

# Fail fast with a clear message rather than a KeyError deep inside
# dataset.map() (message/context) or set_format() (label) in the training cell.
_required_columns = {"message", "context", "label"}
_missing_columns = _required_columns - set(df.columns)
if _missing_columns:
    raise ValueError(f"{DATASET_PATH} is missing required columns: {sorted(_missing_columns)}")
# NOTE(review): empty cells in the CSV are read as NaN and may reach the
# tokenizer as None, which it rejects — confirm message/context have no blanks.

dataset = Dataset.from_pandas(df)


def _tokenize(texts, text_pairs=None):
    """Tokenize texts (optionally paired with text_pairs) using the shared settings.

    Every example is padded/truncated to exactly MAX_SEQ_LENGTH tokens. This uses
    the module-level ``tokenizer``, which is created in the training cell, so it
    may only be called after that cell has run (dataset.map() there does so).
    """
    return tokenizer(
        texts,
        text_pairs,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )


def tokenize_split(example):
    """Tokenize message and context together as a sentence pair (message first)."""
    return _tokenize(example["message"], example["context"])


def tokenize_message_only(example):
    """Tokenize the message alone, ignoring the context column."""
    return _tokenize(example["message"])


# Input variants to train and compare; the keys become part of the model name.
tokenize_funcs = {
    "split": tokenize_split,
    "message_only": tokenize_message_only
}

In [None]:
# JSON leaderboard of the best metrics recorded per model name; the path is
# relative to the notebook's current working directory.
BEST_MODELS_PATH = "best_models.json"
# F1 scores within this distance of each other are treated as a tie
# (update_best_model then uses precision to break the tie).
F1_TOLERANCE = 1e-2

def load_best_models():
    """Return the saved leaderboard as a dict, or {} when the file does not exist yet."""
    if os.path.exists(BEST_MODELS_PATH):
        with open(BEST_MODELS_PATH, "r") as handle:
            return json.load(handle)
    return {}

In [None]:
def update_best_model(model_name, f1_score, precision, recall, accuracy):
    """Record a model's metrics in the leaderboard file if they beat the stored best.

    The stored entry for ``model_name`` is replaced when:
      * there is no entry for that name yet (the first result is always recorded), or
      * the new F1 exceeds the stored F1 by more than F1_TOLERANCE, or
      * the new F1 is within F1_TOLERANCE of the stored F1 and precision is higher.

    Args:
        model_name: Leaderboard key for this model.
        f1_score, precision, recall, accuracy: Evaluation metrics. Any real number,
            including numpy scalars, is accepted; values are stored as plain floats.

    Returns:
        True if the leaderboard was updated (the caller then saves the weights),
        False otherwise.
    """
    best_models = load_best_models()
    stored = best_models.get(model_name)
    is_first = stored is None
    stored = stored or {}
    # Entries missing a key (e.g. hand-edited) compare as 0.0 instead of raising KeyError.
    prev_f1 = float(stored.get("f1", 0.0))
    prev_precision = float(stored.get("precision", 0.0))

    better_f1 = f1_score - F1_TOLERANCE > prev_f1
    similar_f1 = abs(f1_score - prev_f1) <= F1_TOLERANCE
    better_precision = precision > prev_precision

    should_update = is_first or better_f1 or (similar_f1 and better_precision)

    if should_update:
        print(f"🎯 New best for {model_name}!")
        print(f"F1: {f1_score:.4f} (prev: {prev_f1:.4f}) | Precision: {precision:.4f} (prev: {prev_precision:.4f})")
        # float() so numpy scalars (e.g. float32) are JSON-serializable.
        best_models[model_name] = {
            "f1": float(f1_score),
            "precision": float(precision),
            "recall": float(recall),
            "accuracy": float(accuracy),
        }
        # Write to a temporary file, then atomically swap it in, so an interrupted
        # write cannot leave a truncated JSON file that load_best_models() would
        # fail to parse on the next run.
        tmp_path = f"{BEST_MODELS_PATH}.tmp"
        with open(tmp_path, "w") as f:
            json.dump(best_models, f, indent=2)
        os.replace(tmp_path, BEST_MODELS_PATH)
        return True

    print(f"🧪 {model_name} did not improve:")
    print(f"F1: {f1_score:.4f} (best: {prev_f1:.4f}) | Precision: {precision:.4f} (best: {prev_precision:.4f})")
    return False

In [None]:
# Result tracking
results = {}
best_f1 = -1
best_model = None
best_version = ""

# Fixed split seed: splitting once with a seed means both input variants are
# trained and evaluated on exactly the same rows (so their F1 scores are
# comparable), and re-running the notebook reproduces the same split.
SPLIT_SEED = 42
TEST_SIZE = 0.25

# Tokenizer (for BERT large). The tokenize_* functions defined earlier look up
# this module-level name when dataset.map() calls them, so it must exist first.
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

# Split the raw rows once; each variant below tokenizes this same split.
raw_splits = dataset.train_test_split(test_size=TEST_SIZE, seed=SPLIT_SEED)


def compute_metrics(eval_pred):
    """Compute accuracy and binary precision/recall/F1 from Trainer eval output.

    eval_pred unpacks into (logits, labels). The predicted class is the argmax over
    the last axis. Precision/recall/F1 use average='binary' (positive class 1)
    and report 0 instead of warning when undefined (zero_division=0).
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', zero_division=0)
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}


# Training loop: one fine-tuning run per input variant in tokenize_funcs.
for version, tokenizer_func in tokenize_funcs.items():
    print(f"\n🔄 Training version: {version.upper()}")
    tokenized = raw_splits.map(tokenizer_func, batched=True)
    # Keep token_type_ids: for the "split" variant they mark which tokens belong
    # to the message vs. the context. If this column is dropped, BERT receives no
    # segment ids and uses all zeros, discarding the pair boundary.
    tokenized.set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "label"])

    model_name = f"bert-large-uncased-{version}"
    model = BertForSequenceClassification.from_pretrained("bert-large-uncased", num_labels=2)
    model.to(device)

    training_args = TrainingArguments(
        # Per-variant directories so one run's checkpoint-<step> folders and logs
        # cannot collide with the other's.
        output_dir=os.path.join("./results", model_name),
        eval_strategy="epoch",                   # Evaluate every epoch
        save_strategy="epoch",                   # Save every epoch (must match eval for load_best_model_at_end)
        learning_rate=2e-5,                      # From table
        per_device_train_batch_size=16,          # From table
        per_device_eval_batch_size=16,           # From table
        num_train_epochs=5,                      # From table
        weight_decay=0.01,                       # From table
        warmup_ratio=0.0,                        # Table didn't specify; assume no warmup
        lr_scheduler_type="linear",              # From table
        logging_dir=os.path.join("./logs", model_name),
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        optim="adamw_torch",                     # Explicitly using AdamW
        adam_epsilon=1e-8                        # From table
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    # NOTE(review): this evaluates on the same split used to pick the best
    # checkpoint (load_best_model_at_end), so the score is optimistically biased;
    # a separate held-out split would give an unbiased estimate.
    eval_result = trainer.evaluate()
    results[version] = eval_result

    if update_best_model(
        model_name=model_name,
        f1_score=eval_result["eval_f1"],
        precision=eval_result["eval_precision"],
        recall=eval_result["eval_recall"],
        accuracy=eval_result["eval_accuracy"]
    ):
        trainer.save_model(f"models/best-{model_name}")

    if eval_result["eval_f1"] > best_f1:
        best_f1 = eval_result["eval_f1"]
        best_model = model
        best_version = version

# Final summary
print("\n--- Summary of Results (BERT-LARGE) ---")
for version, result in results.items():
    print(f"{version.upper()}: F1 = {result['eval_f1']:.4f}")

print(f"\n✅ Best model: {best_version.upper()} (F1 = {best_f1:.4f})")
