# 🧠 Weighted Loss Fine-Tuning: Adversarial DistilBERT

Fine-tunes DistilBERT with a class-weighted cross-entropy loss that penalizes false positives more heavily, aiming to balance precision and recall.

In [None]:
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding, DistilBertTokenizerFast
import wandb
import torch
import numpy as np
from torch.nn import CrossEntropyLoss

# Authenticate with Weights & Biases and open a run under the experiment's project.
wandb.login()
wandb.init(project="adversarial-phishing-defense-weighted")

# A single checkpoint name feeds both the classifier and its tokenizer so the
# two cannot drift apart.
MODEL_CHECKPOINT = "distilbert-base-uncased"
model = DistilBertForSequenceClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=2,
)
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_CHECKPOINT)

# Pads each batch with the tokenizer's padding settings at collation time.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Per-class loss weights, indexed by label id. Label 0 (the legitimate class)
# is weighted 1.5x relative to label 1, so misclassifying a legitimate example
# costs more -- the intent is to reduce false positives.
loss_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class_weights = torch.tensor([1.5, 1.0], device=loss_device)

# Custom Trainer with class-weighted loss
class WeightedTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the module-level ``class_weights``.

    Only ``compute_loss`` is overridden; everything else is inherited from
    ``Trainer``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the class-weighted cross-entropy loss for one batch.

        Args:
            model: The model to run the forward pass with.
            inputs: Mapping of model inputs; must contain a ``"labels"`` entry.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted for compatibility with newer
                ``transformers`` releases, which pass this argument to
                ``compute_loss``. It is not used here.

        Returns:
            The loss tensor, or a ``(loss, outputs)`` tuple when
            ``return_outputs`` is true.
        """
        labels = inputs.get("labels")
        # Run the forward pass without labels so the model does not also
        # compute its own unweighted loss, which would be discarded anyway.
        # A filtered copy is used so the caller's ``inputs`` is not mutated.
        forward_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        outputs = model(**forward_inputs)
        logits = outputs.get("logits")
        # Move the weights to wherever the logits actually live; the Trainer
        # decides the model's device, which may differ from the device chosen
        # when ``class_weights`` was created.
        loss_fct = CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss


In [None]:
# Training configuration for the weighted-loss run.
training_args = TrainingArguments(
    output_dir="./results_weighted",
    num_train_epochs=4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    # Evaluate and checkpoint on the same schedule (once per epoch);
    # load_best_model_at_end needs the two strategies to line up.
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` -- confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=200,
    load_best_model_at_end=True,
    # Mixed precision is only usable on a CUDA device. The rest of the
    # notebook falls back to CPU when CUDA is missing, so do the same here
    # instead of failing during argument validation.
    fp16=torch.cuda.is_available(),
    disable_tqdm=True,
    report_to="wandb"
)

# `train_dataset` and `val_dataset` are expected to be defined elsewhere in
# the notebook (not visible in this section).
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator
)

trainer.train()


In [None]:
from sklearn.metrics import classification_report

# Score the test split, then reduce the per-class scores to a predicted label
# by taking the index of the largest value along the last axis.
preds_output = trainer.predict(test_dataset)
y_pred = np.argmax(preds_output.predictions, axis=-1)

# Ground-truth labels come straight from the dataset's "label" column.
y_true = list(test_dataset["label"])

# Per-class precision / recall / F1 summary.
report = classification_report(y_true, y_pred)
print("🔍 Classification Report (Weighted Loss):")
print(report)
