# In [None]:
# --- Step 1. Load dataset ---
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from datasets import ClassLabel, Dataset
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import seaborn as sns

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)

# Load data
NUM_LABELS = 4  # 0=Valid, 1=Spam, 2=LowQuality, 3=Rant

df = pd.read_csv("/content/masterList_relabelled_v2.csv")
df = df.dropna(subset=["text", "label"])
# Coerce to str so non-string cells (e.g. numbers parsed by pandas) don't crash the tokenizer.
df["text"] = df["text"].astype(str)
df["label"] = df["label"].astype(int)

# Fail fast on labels the 4-way classification head cannot represent; otherwise
# the problem only surfaces mid-training as an opaque device-side assert.
bad_labels = sorted(set(df["label"].unique().tolist()) - set(range(NUM_LABELS)))
if bad_labels:
    raise ValueError(f"Unexpected label values {bad_labels}; expected 0..{NUM_LABELS - 1}")

# preserve_index=False: dropna leaves a non-contiguous index, which from_pandas
# would otherwise store as an extra "__index_level_0__" column.
dataset = Dataset.from_pandas(df[["text", "label"]], preserve_index=False)

# --- Step 2. Train/validation split ---
# Casting to ClassLabel enables a stratified split, so every class keeps the same
# share in train and validation. This matters for an imbalanced label
# distribution, where a plain random split can under-represent rare classes.
dataset = dataset.cast_column("label", ClassLabel(num_classes=NUM_LABELS))
dataset = dataset.train_test_split(test_size=0.2, seed=42, stratify_by_column="label")
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

# --- Step 3. Tokenizer ---
MODEL_NAME = "microsoft/deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_function(batch):
    """Tokenize a batch of examples for the classifier.

    Intended for ``Dataset.map(..., batched=True)``: ``batch["text"]`` holds the
    batch's raw strings. Each sequence is truncated and padded to exactly 256
    tokens, and the tokenizer's encoding (input_ids, attention_mask, ...) is
    returned for ``map`` to merge into the dataset.
    """
    texts = batch["text"]
    encoded = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=256,
    )
    return encoded

train_dataset = train_dataset.map(tokenize_function, batched=True)
eval_dataset = eval_dataset.map(tokenize_function, batched=True)

# --- Step 4. Compute dynamic class weights ---
NUM_LABELS = 4
# Fall back to CPU instead of crashing when no GPU is present.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Use the *training* labels only, so the validation split does not leak into
# the loss. Read them before set_format() so they come back as plain ints.
labels = np.asarray(train_dataset["label"])
# Listing every class id keeps the weight vector aligned with the NUM_LABELS
# logits; compute_class_weight raises if a class is missing from the training
# labels, rather than silently returning fewer weights.
label_ids = np.arange(NUM_LABELS)
weights = compute_class_weight(class_weight="balanced", classes=label_ids, y=labels)
print("Class Weights:", dict(zip(label_ids, weights)))

train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

class_weights = torch.tensor(weights, dtype=torch.float).to(DEVICE)
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

# --- Step 5. Load Model ---
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS).to(DEVICE)

# --- Step 6. Metrics ---
def compute_metrics(pred):
    """Compute evaluation metrics for the Trainer.

    Args:
        pred: the prediction object passed by Trainer. It has ``label_ids`` and
            ``predictions``; the predictions are either logits or a tuple whose
            first element is the logits.

    Returns:
        dict with accuracy, weighted F1, macro F1, and weighted precision/recall.
    """
    labels = pred.label_ids
    logits = pred.predictions
    # predictions can be a tuple when the model returns extra outputs; logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)

    # zero_division=0 gives the same 0.0 score sklearn already uses for classes
    # with no predicted samples, but without an UndefinedMetricWarning every epoch.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="weighted", zero_division=0
    )
    macro_f1 = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )[2]
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1_weighted": f1, "f1_macro": macro_f1,
            "precision": precision, "recall": recall}

# --- Step 7. Custom Trainer ---
class WeightedTrainer(Trainer):
    """Trainer whose training/eval loss is the module-level weighted ``loss_fn``."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the class-weighted loss, plus the model outputs if requested.

        ``**kwargs`` absorbs extra arguments newer Trainer versions pass
        (e.g. ``num_items_in_batch``); they are not used here.
        """
        # Pop (not get) the labels so the model's forward() does not also compute
        # its own unweighted loss. This follows the custom-loss pattern in the
        # Trainer docs; Trainer collects the labels for evaluation before calling
        # compute_loss.
        labels = inputs.pop("labels").to(torch.long)
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Module.to() moves the weight buffer in place (a no-op if it is already
        # there), so the loss works wherever Trainer placed the model.
        criterion = loss_fn.to(logits.device)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# --- Step 8. Training arguments ---
training_args = TrainingArguments(
    output_dir="./results_deberta",
    eval_strategy="epoch",
    save_strategy="epoch",
    report_to="none",
    learning_rate=5e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=6,
    weight_decay=0.01,
    logging_dir="./logs_deberta",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="f1_weighted",
    greater_is_better=True,
)

# --- Step 9. Trainer ---
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; switch once the installed version supports it.
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# --- Step 10. Train ---
trainer.train()

# --- Step 11. Evaluate ---
# NOTE: this is the same split used to pick the best checkpoint
# (load_best_model_at_end=True), so these scores are optimistic. Use a separate
# held-out test set for an unbiased estimate.
LABEL_NAMES = ["Valid (0)", "Spam (1)", "LowQuality (2)", "Rant (3)"]
LABEL_IDS = list(range(len(LABEL_NAMES)))

predictions = trainer.predict(eval_dataset)
y_true = predictions.label_ids
y_pred = np.argmax(predictions.predictions, axis=1)

# Pass the label ids explicitly so the report stays aligned with LABEL_NAMES
# even when a class never appears in y_true or y_pred. Without them sklearn
# raises because the observed class count differs from len(target_names).
print(classification_report(
    y_true, y_pred,
    labels=LABEL_IDS,
    target_names=LABEL_NAMES,
    zero_division=0,
))

# --- Step 12. Confusion Matrix ---
cm = confusion_matrix(y_true, y_pred, labels=LABEL_IDS)
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=LABEL_NAMES,
            yticklabels=LABEL_NAMES)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix - DeBERTa v3 Base")
plt.tight_layout()  # keep the long tick labels inside the figure
plt.show()
