In [None]:
import os
import time
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    f1_score, classification_report, roc_curve, auc,
    multilabel_confusion_matrix, precision_recall_curve
)
from sklearn.calibration import calibration_curve
from transformers import (
    BertTokenizer, BertForSequenceClassification,
    BertConfig, Trainer, TrainingArguments, EarlyStoppingCallback
)
from sklearn.metrics import precision_recall_curve
from datasets import Dataset
from wordcloud import WordCloud, STOPWORDS
import nltk
from nltk.corpus import stopwords

# Fetch the NLTK "stopwords" corpus; stopwords.words("english") is called later in this notebook.
nltk.download("stopwords")

In [None]:
# === CONFIGURATION ===
# Base directories and the path of the combined Stage 2 CSV.
base_dir = Path("./notebooks/stage2/v2")
base_dir2 = Path("./datasets/stage2/v2")
combined_csv = base_dir2.joinpath("stage2_final_combined.csv")

# Saved splits (read from) plus model and figure output locations (written to).
splits_dir, model_dir, visual_dir = (
    Path("./s2_split"),
    Path("./s2_mb_model"),
    Path("./s2_mb_visual"),
)

# Create the output directories up front; already-existing ones are left as they are.
for out_dir in (model_dir, visual_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# === LABELS ===
# Target categories of the multi-label classifier. The names are capitalised exactly as
# written here and are reused for metric keys, report rows and output file names.
label_cols = ["Race", "Religion", "Gender", "Sexual_Orientation"]
# The same names with a "label_" prefix.
ml_label_cols = [f"label_{c}" for c in label_cols]

In [None]:
# 2) Load saved splits (npy files)
def _load_split(filename):
    """Read one saved array from ``splits_dir`` (pickled object arrays are allowed)."""
    return np.load(splits_dir / filename, allow_pickle=True)


train_texts, train_labels = _load_split("train_texts.npy"), _load_split("train_labels.npy")
val_texts, val_labels = _load_split("val_texts.npy"), _load_split("val_labels.npy")
test_texts, test_labels = _load_split("test_texts.npy"), _load_split("test_labels.npy")
test_langs = _load_split("test_langs.npy")

In [None]:
# === TOKENIZER & DATASET FUNCTIONS ===
# Tokenizer for the "bert-base-multilingual-cased" checkpoint, the same checkpoint name
# the model-definition cell loads.
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")

def tokenize(texts, labels, max_length=128):
    """Tokenize texts and bundle them with their labels into a ``Dataset``.

    Args:
        texts: Iterable of inputs; every item is converted with ``str`` before
            being passed to the module-level ``tokenizer``.
        labels: Iterable of per-sample label vectors; every value is cast to
            ``float``.
        max_length: Fixed sequence length. Inputs are padded to it and truncated
            at it. Defaults to 128, the value that was previously hard-coded.

    Returns:
        A ``Dataset`` with ``input_ids``, ``attention_mask`` and ``labels`` columns.
    """
    enc = tokenizer(
        [str(t) for t in texts],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    return Dataset.from_dict({
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "labels": [[float(v) for v in row] for row in labels],
    })

# Tokenize the three loaded splits in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    tokenize(split_texts, split_labels)
    for split_texts, split_labels in (
        (train_texts, train_labels),
        (val_texts, val_labels),
        (test_texts, test_labels),
    )
)

In [None]:
# 5) Compute pos_weight for multilabel BCE
# The statistics come from the training split only, so nothing about the validation/test
# label distribution feeds into the training loss.
train_label_matrix = np.asarray(train_labels, dtype=np.float64)
total_samples = train_label_matrix.shape[0]
label_counts = train_label_matrix.sum(axis=0)
# A label with no positive example would divide by zero (inf weight); clamp the
# denominator to 1 so the weight stays finite.
safe_counts = np.clip(label_counts, 1.0, None)
pos_weights = torch.tensor((total_samples - label_counts) / safe_counts, dtype=torch.float)
if torch.cuda.is_available():
    pos_weights = pos_weights.cuda()

In [None]:
# === MODEL DEFINITION ===
# One checkpoint name is shared by the config and the weights.
pretrained_name = "bert-base-multilingual-cased"

# Multi-label head with one output per entry of label_cols; both dropout rates are 0.1.
config = BertConfig.from_pretrained(
    pretrained_name,
    num_labels=len(label_cols),
    problem_type="multi_label_classification",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
model = BertForSequenceClassification.from_pretrained(pretrained_name, config=config)

# Move the model to the GPU when one is available.
if torch.cuda.is_available():
    model.cuda()

In [None]:
class WeightedTrainer(Trainer):
    """``Trainer`` whose loss is ``BCEWithLogitsLoss`` weighted by the module-level ``pos_weights``."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the positively-weighted binary cross-entropy for one batch.

        Args:
            model: Model to run; called as ``model(**inputs)`` after the labels are removed.
            inputs: Batch dict. Its ``"labels"`` entry is popped (the dict is mutated).
            return_outputs: When true, return ``(loss, outputs)`` instead of just the loss.
            **kwargs: Accepted and ignored, so extra arguments passed by the caller do not fail.

        Returns:
            The loss tensor, or a ``(loss, outputs)`` tuple when ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # pos_weights is only moved to CUDA where it is defined. Align it (and the labels)
        # with whatever device the logits are actually on, so the loss does not fail with
        # a device mismatch on other accelerators.
        loss_fct = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weights.to(logits.device))
        loss = loss_fct(logits, labels.to(device=logits.device, dtype=torch.float))
        return (loss, outputs) if return_outputs else loss

In [None]:
def compute_metrics(eval_preds):
    """Return per-label F1 and macro-F1 for a ``(logits, labels)`` pair.

    Logits go through a sigmoid and are binarised at 0.5. The result has one
    ``f1_<label>`` entry per name in ``label_cols`` followed by ``macro_f1``.
    """
    raw_logits, gold = eval_preds
    sigmoid_out = torch.sigmoid(torch.tensor(raw_logits)).numpy()
    hard_preds = (sigmoid_out > 0.5).astype(int)

    scores = {}
    for col_idx, col_name in enumerate(label_cols):
        scores[f"f1_{col_name}"] = f1_score(
            gold[:, col_idx], hard_preds[:, col_idx], zero_division=0
        )
    scores["macro_f1"] = f1_score(gold, hard_preds, average="macro", zero_division=0)
    return scores

In [None]:
# === TRAINING ARGUMENTS ===
# Evaluation, checkpoint saving and logging are all configured to happen once per epoch.
# Checkpoints go under model_dir / "checkpoints" and logs under model_dir / "logs".
training_args = TrainingArguments(
    output_dir=str(model_dir / "checkpoints"), 
    eval_strategy="epoch",
    save_strategy="epoch", 
    logging_strategy="epoch", 
    learning_rate=2e-5,
    per_device_train_batch_size=32, 
    per_device_eval_batch_size=32,
    num_train_epochs=6, 
    weight_decay=0.01, 
    # The best checkpoint is reloaded at the end of training; "macro_f1" is the key that
    # compute_metrics returns, and a higher value counts as better.
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1", 
    greater_is_better=True,
    # Mixed precision is requested only when a CUDA device is available.
    fp16=torch.cuda.is_available(), 
    logging_dir=str(model_dir / "logs"),
    save_total_limit=2
)

In [None]:
# Early-stopping callback with a patience of 3, handed to the trainer below.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

# Train on the training split, evaluate on the validation split with compute_metrics.
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

In [None]:
# === TRAIN & MEASURE TIME ===
# perf_counter is a monotonic clock, so the measured duration cannot be distorted by
# system clock adjustments the way a time.time() difference can.
start_time = time.perf_counter()
trainer.train()
train_time = time.perf_counter() - start_time
# The emoji was stored as mis-decoded bytes; restored to the intended character.
print(f"🕒 Stage 2 training time: {train_time:.2f} seconds")

In [None]:
# === PREDICTION & INFERENCE TIME ===
# perf_counter is a monotonic clock, which is the appropriate choice for measuring durations.
start_inf = time.perf_counter()
preds_output = trainer.predict(test_dataset)
inf_time = time.perf_counter() - start_inf
# Average wall-clock time per test sample over the whole predict() call.
inference_time_per_sample = inf_time / len(test_dataset)
# The emoji was stored as mis-decoded bytes; restored to the intended character.
print(f"🕒 Inference time per sample: {inference_time_per_sample:.4f} seconds")

In [None]:
# Extract arrays
logits = preds_output.predictions
true_labels = preds_output.label_ids
probs_array = torch.sigmoid(torch.tensor(logits)).numpy()

# Tune the per-label decision thresholds on the VALIDATION split. Picking them on the test
# set and then scoring that same test set would make every reported test metric optimistic.
val_output = trainer.predict(val_dataset)
val_true = val_output.label_ids
val_probs = torch.sigmoid(torch.tensor(val_output.predictions)).numpy()

optimal_thresholds = []
for i, label in enumerate(label_cols):
    precision, recall, thresholds = precision_recall_curve(val_true[:, i], val_probs[:, i])
    # precision/recall carry one extra trailing point that has no matching threshold;
    # drop it so that argmax indexes `thresholds` directly.
    f1_scores = 2 * precision[:-1] * recall[:-1] / (precision[:-1] + recall[:-1] + 1e-8)
    best_thresh = thresholds[np.argmax(f1_scores)]
    optimal_thresholds.append(best_thresh)
    print(f"✅ {label}: Best threshold = {best_thresh:.2f}")

# Apply the thresholds to the test probabilities. precision_recall_curve treats a sample as
# positive when score >= threshold, so use >= here too; a strict > would drop the sample
# sitting exactly on the chosen threshold. The dtype matches probs_array, as before.
preds_bin = (probs_array >= np.asarray(optimal_thresholds)).astype(probs_array.dtype)

In [None]:
# Save arrays for reuse: each (file name, array) pair is written under model_dir.
for filename, array in (
    ("y_true.npy", true_labels),
    ("y_pred.npy", preds_bin),
    ("probs.npy", probs_array),
    ("test_texts.npy", test_texts),
    ("test_langs.npy", test_langs),
):
    np.save(model_dir / filename, array)

In [None]:
# === CREATE EVAL DATAFRAME ===
# Collect all columns first: text, lang, then a true_/pred_ pair for each label in order.
eval_columns = {"text": test_texts, "lang": test_langs}
for col_idx, col_name in enumerate(label_cols):
    eval_columns[f"true_{col_name}"] = true_labels[:, col_idx]
    eval_columns[f"pred_{col_name}"] = preds_bin[:, col_idx]
df_eval = pd.DataFrame(eval_columns)

In [None]:
# Overall Classification Report
report_dict = classification_report(
    true_labels,
    preds_bin,
    target_names=label_cols,
    zero_division=0,
    output_dict=True,
)
# One row per report entry, written with the index kept.
overall_report_df = pd.DataFrame(report_dict).T
overall_report_df.to_csv(visual_dir / "classification_report_overall.csv", index=True)

In [None]:
# Classification Report by Language
# Column selections are the same for every language, so build them once.
true_cols = [f"true_{c}" for c in label_cols]
pred_cols = [f"pred_{c}" for c in label_cols]

for lang in df_eval["lang"].unique():
    lang_rows = df_eval[df_eval["lang"] == lang]
    rep_lang = classification_report(
        lang_rows[true_cols].values,
        lang_rows[pred_cols].values,
        target_names=label_cols,
        zero_division=0,
        output_dict=True,
    )
    rep_lang_df = pd.DataFrame(rep_lang).T
    rep_lang_df.to_csv(visual_dir / f"classification_report_{lang}.csv", index=True)

In [None]:
# ROC Curves (Overall)
fig, ax = plt.subplots(figsize=(8, 6))
for col_idx, col_name in enumerate(label_cols):
    fpr, tpr, _ = roc_curve(true_labels[:, col_idx], probs_array[:, col_idx])
    roc_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, label=f"{col_name} (AUC = {roc_auc:.2f})")
# Dashed diagonal as a reference line.
ax.plot([0, 1], [0, 1], "k--")
ax.set_title("ROC Curves (Overall)")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig(visual_dir / "roc_curves_overall.png")
plt.close(fig)

In [None]:
# Precision-Recall Curves (Overall)
fig, ax = plt.subplots(figsize=(8, 6))
for col_idx, col_name in enumerate(label_cols):
    prec, rec, _ = precision_recall_curve(true_labels[:, col_idx], probs_array[:, col_idx])
    ax.plot(rec, prec, label=col_name)
ax.set_title("Precision-Recall Curves (Overall)")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig(visual_dir / "pr_curves_overall.png")
plt.close(fig)

In [None]:
# Calibration Curves (Overall)
fig, ax = plt.subplots(figsize=(8, 6))
for col_idx, col_name in enumerate(label_cols):
    # 10 quantile bins per label.
    frac_pos, mean_pred = calibration_curve(
        true_labels[:, col_idx],
        probs_array[:, col_idx],
        n_bins=10,
        strategy="quantile",
    )
    ax.plot(mean_pred, frac_pos, marker="o", label=col_name)
# Dashed diagonal as a reference line.
ax.plot([0, 1], [0, 1], "k--")
ax.set_title("Calibration Curves (Overall)")
ax.set_xlabel("Mean Predicted Probability")
ax.set_ylabel("Fraction of Positives")
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig(visual_dir / "calibration_curves_overall.png")
plt.close(fig)

In [None]:
# Multilabel Confusion Matrices (Overall)
mcm = multilabel_confusion_matrix(true_labels, preds_bin)
fig, axes = plt.subplots(1, len(label_cols), figsize=(16, 4))
# One heatmap panel per label, side by side.
for panel_idx, panel_ax in enumerate(axes):
    panel_name = label_cols[panel_idx]
    sns.heatmap(mcm[panel_idx], annot=True, fmt="d", cmap="Blues", ax=panel_ax)
    panel_ax.set_title(panel_name)
    panel_ax.set_xlabel("Predicted")
    panel_ax.set_ylabel("True")
fig.tight_layout()
fig.savefig(visual_dir / "confusion_matrices_overall.png")
plt.close(fig)

In [None]:
# 1) Combine stopwords
eng_stop = set(stopwords.words("english"))
# Extra stopwords (these appear to be Malay function words plus placeholder tokens such as
# "user", "number" and "url"). Each entry is listed once; the resulting set is unchanged.
custom_stop = {
    "saya", "awak", "kau", "kita", "kamu", "dia", "mereka", "kami",
    "yang", "itu", "ini", "dan", "atau", "dengan", "dalam", "kepada", "untuk",
    "akan", "telah", "boleh", "tidak", "hanya", "lagi", "kerana", "jika",
    "oleh", "pada", "sebagai", "adalah", "apa", "semua", "daripada", "lebih",
    "perlu", "juga", "sudah", "masih", "pun", "satu", "mana", "setiap",
    "tiada", "seorang", "bagaimana", "kenapa", "jadi", "mungkin",
    "mempunyai", "anda", "user", "number", "url", "menjadi", "dari",
    "tetapi", "bahawa", "seperti", "di", "sangat", "ada", "apabila", "ia",
}
# Union of the wordcloud defaults, the NLTK English list and the custom list.
all_stopwords = set(STOPWORDS) | eng_stop | custom_stop

# 2) Load test_texts array
# Resolve the file through splits_dir instead of a second hard-coded path, so changing the
# configured split location cannot leave this cell reading a stale file.
test_texts = np.load(splits_dir / "test_texts.npy", allow_pickle=True)

# 3) Word Cloud (All Text)
# Every test text is stringified and joined into one document.
all_text = " ".join(map(str, test_texts.tolist()))
wc_all = WordCloud(
    width=1200,
    height=600,
    stopwords=all_stopwords,
    background_color="white"
).generate(all_text)

plt.figure(figsize=(12, 6))
plt.imshow(wc_all, interpolation="bilinear")
plt.axis("off")
plt.title("Word Cloud (All Text)")
plt.tight_layout()
plt.savefig(visual_dir / "wordcloud_all_text.png")
plt.close()

In [None]:
# 4) Word Clouds by Label
# Resolve the file through splits_dir instead of a second hard-coded path.
test_labels = np.load(splits_dir / "test_labels.npy", allow_pickle=True)

for i, c in enumerate(label_cols):
    # Texts whose label column i equals 1.
    subset_indices = np.where(test_labels[:, i] == 1)[0]
    subset_texts = [str(test_texts[idx]) for idx in subset_indices]
    # Labels with no positive test sample are skipped (no figure is written for them).
    if subset_texts:
        wc_label = WordCloud(
            width=1200,
            height=600,
            stopwords=all_stopwords,
            background_color="white"
        ).generate(" ".join(subset_texts))

        plt.figure(figsize=(12, 6))
        plt.imshow(wc_label, interpolation="bilinear")
        plt.axis("off")
        # The dash in the title was stored as mis-decoded bytes; restored to an en dash.
        plt.title(f"Word Cloud – {c}")
        plt.tight_layout()
        plt.savefig(visual_dir / f"wordcloud_{c}.png")
        plt.close()

In [None]:
# 9) Macro-F1 by Language
true_cols = [f"true_{c}" for c in label_cols]
pred_cols = [f"pred_{c}" for c in label_cols]

lang_scores = []
for lang in df_eval["lang"].unique():
    sub = df_eval[df_eval["lang"] == lang]
    y_true_lang = sub[true_cols].values
    y_pred_lang = sub[pred_cols].values
    # zero_division=0 matches every other f1_score/classification_report call in this
    # notebook; a language subset that lacks a label then scores 0 without a warning.
    lang_scores.append(
        (lang, f1_score(y_true_lang, y_pred_lang, average="macro", zero_division=0))
    )
lang_df = pd.DataFrame(lang_scores, columns=["lang", "macro_f1"])

# Draw on an explicitly created figure rather than whatever figure happens to be current.
fig, ax = plt.subplots()
sns.barplot(data=lang_df, x="lang", y="macro_f1", ax=ax)
ax.set_ylim(0, 1)
ax.set_title("Macro-F1 by Language")
fig.tight_layout()
fig.savefig(visual_dir / "macro_f1_by_language.png")
plt.close(fig)

In [None]:
# 10) Label Co-occurrence Heatmap
# Cast to int before counting: the heatmap annotations use fmt="d", which cannot format
# float values, and the dtype of the stored label array is not guaranteed to be integer.
label_matrix = np.asarray(test_labels).astype(int)
# Entry (a, b) counts the test samples carrying both label a and label b.
co_matrix = label_matrix.T.dot(label_matrix)

# Draw on an explicitly created figure rather than whatever figure happens to be current.
fig, ax = plt.subplots()
sns.heatmap(
    co_matrix,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=label_cols,
    yticklabels=label_cols,
    ax=ax,
)
ax.set_title("Label Co-occurrence")
fig.tight_layout()
fig.savefig(visual_dir / "label_cooccurrence.png")
plt.close(fig)

In [None]:
# 11) Confidence KDE Plot
fig, ax = plt.subplots(figsize=(12, 6))
# One filled density curve per label over its predicted probabilities.
for col_idx, col_name in enumerate(label_cols):
    sns.kdeplot(probs_array[:, col_idx], fill=True, label=col_name, ax=ax)
ax.set_title("Confidence Distribution (Overall)")
ax.set_xlabel("Predicted Probability")
ax.legend()
fig.tight_layout()
fig.savefig(visual_dir / "confidence_kde.png")
plt.close(fig)


In [None]:
# Save inference time
# Written as a single line of text next to the figures.
inference_time_file = visual_dir / "inference_time.txt"
inference_time_file.write_text(
    f"Inference time per sample: {inference_time_per_sample:.4f} seconds\n"
)

print("Stage 2 visualisations saved under:", visual_dir)

In [None]:
# === Underfitting/Overfitting Loss Plot ===

import matplotlib.pyplot as plt

# Collect training and validation losses from the trainer's log_history:
# entries with a "loss" key feed the training curve, entries with "eval_loss" the validation curve.
log_history = trainer.state.log_history

train_loss, val_loss = [], []
for entry in log_history:
    if "loss" in entry:
        train_loss.append(entry["loss"])
    if "eval_loss" in entry:
        val_loss.append(entry["eval_loss"])

# Truncate both series to the shorter one so they share an x-axis numbered from 1.
min_len = min(len(train_loss), len(val_loss))
train_loss, val_loss = train_loss[:min_len], val_loss[:min_len]
epochs = list(range(1, min_len + 1))

# Plot both curves on one axis.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, train_loss, marker="o", label="Training Loss")
ax.plot(epochs, val_loss, marker="o", label="Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training vs Validation Loss")
ax.legend()
ax.grid(True)
fig.tight_layout()

# Save to visual_dir
fig.savefig(visual_dir / "loss_overfit_underfit.png")
plt.close(fig)


In [None]:
# === Save the Best Model and Tokenizer ===
import json

# Reuse the configured output directory instead of a second hard-coded copy of its path,
# so the two cannot drift apart when the configuration cell is edited.
model_path = model_dir
model_path.mkdir(parents=True, exist_ok=True)

# Save the trainer's current model (training was configured with
# load_best_model_at_end=True) together with the tokenizer.
trainer.save_model(model_path)
tokenizer.save_pretrained(model_path)

# Save thresholds for multi-label inference as plain floats keyed by label name.
thresholds_dict = {label: float(thresh) for label, thresh in zip(label_cols, optimal_thresholds)}
with open(model_path / "thresholds.json", "w") as f:
    json.dump(thresholds_dict, f)

# Save predictions for visuals
np.save(visual_dir / "y_pred_mb.npy", preds_bin)

print(f"Model and tokenizer saved to: {model_path}")