# In [None]:
#loading libraries
import os
import random
import time
from itertools import product
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_curve,
    auc,
    precision_recall_curve,
    average_precision_score,
)
from sklearn.preprocessing import label_binarize

from setfit import SetFitModel, SetFitTrainer


# Silence Weights & Biases and the tokenizers fork warning so the notebook
# output stays readable.
os.environ.update(
    {
        "WANDB_DISABLED": "true",
        "WANDB_MODE": "disabled",
        "WANDB_SILENT": "true",
        "TOKENIZERS_PARALLELISM": "false",
    }
)


# --- input data ---------------------------------------------------------
DATA_PATH = "/content/Sample_500_stratified.xlsx"

# --- hold-out split -----------------------------------------------------
SPLIT_SEED = 42
HOLDOUT_TEST_SIZE = 0.20

# --- sentence-transformer backbone for SetFit ---------------------------
BASE_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

# --- fixed training hyper-parameters ------------------------------------
BATCH_SIZE = 32
LEARNING_RATE = 2e-5

# --- grid searched over (full Cartesian product) ------------------------
GRID_NUM_ITERATIONS = [20, 50]
GRID_NUM_EPOCHS = [2, 3]

# seed used to rank the grid, then seeds used to re-run the winning config
RANKING_SEED = 42
BEST_SEEDS = [42, 7]

# classes with fewer samples than this are merged into RARE_CLASS_LABEL
MIN_SAMPLES_PER_CLASS = 2
RARE_CLASS_LABEL = "Other / very rare"


# Tasks to run (Services_Label intentionally excluded here). Every task
# predicts the column it is named after, from the listed feature columns.
TASKS = [
    {"TASK_NAME": name, "LABEL_COL": name, "FEATURE_COLS": cols}
    for name, cols in (
        ("cuisine_region", ["desc_1", "desc_2", "favourite_dish_ingredients"]),
        ("opening_class_label", ["opening_hours"]),
        ("Chain/Indep", ["title"]),
        ("concept_format", ["desc_1", "desc_2"]),
    )
]


def set_all_seeds(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    CUDA generators are seeded too, but only when a GPU is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def build_text(df: pd.DataFrame, feature_cols: List[str]) -> pd.Series:
    """Join ``feature_cols`` row-wise into one whitespace-normalised string.

    Missing values become empty strings and every cell is stripped. Cells are
    joined with single spaces in ``feature_cols`` order, then any remaining
    whitespace runs are collapsed and the result is stripped. The returned
    Series keeps ``df``'s index.
    """
    cleaned_cols = [df[col].fillna("").astype(str).str.strip() for col in feature_cols]
    joined = pd.Series(
        [" ".join(cells) for cells in zip(*cleaned_cols)],
        index=df.index,
        dtype=object,
    )
    return joined.str.replace(r"\s+", " ", regex=True).str.strip()


def bundle_rare_classes(y: pd.Series, min_count: int, rare_label: str) -> pd.Series:
    """Relabel every class with fewer than ``min_count`` members as ``rare_label``.

    Missing values are counted as their own class. If no class is rare, ``y``
    itself is returned unchanged. Otherwise a new Series with the same index
    is returned.

    Note: the merged bucket is not re-checked, so it can still hold fewer than
    ``min_count`` samples, e.g. when only a single class was rare.
    """
    counts = y.value_counts(dropna=False)
    rare_values = counts.index[(counts < min_count).to_numpy()]
    if rare_values.empty:
        return y
    # mask() replaces exactly the entries where the condition holds
    return y.mask(y.isin(rare_values), rare_label)


def safe_train_test_split(
    X: List[str], y: List[str], test_size: float, seed: int
) -> Tuple[List[str], List[str], List[str], List[str], bool]:
    """Split ``(X, y)`` into train/test parts, stratifying by ``y`` when possible.

    Args:
        X: input texts.
        y: labels aligned with ``X``.
        test_size: fraction of samples that goes to the test part.
        seed: ``random_state`` passed to scikit-learn for reproducibility.

    Returns:
        ``(X_tr, X_te, y_tr, y_te, stratified)``. ``stratified`` is False when
        the stratified split was impossible and a plain random split was used.
    """
    try:
        X_tr, X_te, y_tr, y_te = train_test_split(
            X, y, test_size=test_size, random_state=seed, stratify=y
        )
    except ValueError:
        # scikit-learn reports an impossible stratification (for example a
        # class with a single member) as ValueError. Anything else is a real
        # bug and must not be masked by the fallback.
        X_tr, X_te, y_tr, y_te = train_test_split(
            X, y, test_size=test_size, random_state=seed
        )
        return X_tr, X_te, y_tr, y_te, False
    return X_tr, X_te, y_tr, y_te, True


def compute_metrics(y_true: List[str], y_pred: List[str], labels: List[str]) -> Dict[str, float]:
    """Return accuracy plus macro-averaged precision, recall and F1.

    The macro averages cover the entries of ``labels`` that occur in ``y_true``
    or ``y_pred``, kept in ``labels`` order. scikit-learn scores a requested
    label that is absent from both as 0 precision/recall/F1. That would drag
    the macro average down for classes the evaluation split simply does not
    contain, which can happen when a rare class lands entirely in the
    training split. A class that is predicted but never true is still
    included, so false positives are still penalised.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels aligned with ``y_true``.
        labels: full, ordered label set of the task.

    Returns:
        Dict with ``accuracy``, ``precision_macro``, ``recall_macro`` and
        ``f1_macro`` as plain floats.
    """
    present = set(y_true) | set(y_pred)
    # Fall back to scikit-learn's default (labels seen in the data) if none
    # of the requested labels occur at all.
    eval_labels = [lab for lab in labels if lab in present] or None

    acc = accuracy_score(y_true, y_pred)
    p, r, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=eval_labels, average="macro", zero_division=0
    )
    return {
        "accuracy": float(acc),
        "precision_macro": float(p),
        "recall_macro": float(r),
        "f1_macro": float(f1),
    }


def to_numpy_proba(x) -> Optional[np.ndarray]:
    """Return ``x`` as a NumPy array, passing ``None`` through unchanged.

    Torch tensors are detached from the autograd graph and copied to CPU
    first. Anything else goes through ``np.asarray``.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return None if x is None else np.asarray(x)


def train_eval_setfit(
    X_train: List[str],
    y_train: List[str],
    X_eval: List[str],
    y_eval: List[str],
    class_list: List[str],
    num_iterations: int,
    num_epochs: int,
    seed: int,
    return_model: bool = False,
) -> Tuple[Dict[str, float], float, Optional[SetFitModel], Optional[np.ndarray], Optional[np.ndarray]]:
    """Train one SetFit configuration from scratch and score it on an eval set.

    A fresh model is loaded from ``BASE_MODEL``, trained on
    ``(X_train, y_train)`` with the module-level ``BATCH_SIZE`` and
    ``LEARNING_RATE``, and evaluated with :func:`compute_metrics`.

    Args:
        X_train, y_train: training texts and labels.
        X_eval, y_eval: evaluation texts and labels.
        class_list: ordered label set, passed to the model and the metrics.
        num_iterations, num_epochs: SetFit training hyper-parameters.
        seed: seed for all RNGs and for the trainer.
        return_model: if True, the trained model is returned. Otherwise it is
            released before returning.

    Returns:
        ``(metrics, train_time_sec, model_or_None, y_pred, y_prob)``.
        ``y_pred`` is an object array of predictions. ``y_prob`` is the
        ``predict_proba`` output as a NumPy array, or ``None`` if
        probabilities could not be computed.
    """
    set_all_seeds(seed)

    train_ds = Dataset.from_dict({"text": X_train, "label": y_train})

    model = SetFitModel.from_pretrained(
        BASE_MODEL,
        labels=class_list,
        # NOTE(review): presumably forwarded to the default scikit-learn
        # LogisticRegression head (balanced class weights, more solver
        # iterations) — confirm against the installed SetFit version.
        head_params={"class_weight": "balanced", "max_iter": 2000},
    )

    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_ds,
        column_mapping={"text": "text", "label": "label"},
        batch_size=BATCH_SIZE,
        num_epochs=num_epochs,
        num_iterations=num_iterations,
        learning_rate=LEARNING_RATE,
        seed=seed,
    )

    t0 = time.time()
    trainer.train()
    train_time = time.time() - t0

    y_pred = model.predict(X_eval)
    metrics = compute_metrics(y_eval, y_pred, class_list)

    # Probabilities are only needed for the optional ROC/PR plots, so a
    # failure is reported and tolerated rather than silently ignored.
    y_prob = None
    try:
        y_prob = to_numpy_proba(model.predict_proba(X_eval))
    except Exception as err:
        print(f"predict_proba failed ({type(err).__name__}: {err}) -> no probabilities")

    y_pred_arr = np.asarray(y_pred, dtype=object)

    # Drop references so the GPU memory can be reclaimed before the next run.
    del trainer
    if not return_model:
        model = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return metrics, float(train_time), model, y_pred_arr, y_prob


def _show_line_plot(x, y, xlabel: str, ylabel: str, title: str) -> None:
    """Draw a single x/y line in a new matplotlib figure and show it."""
    plt.figure()
    plt.plot(x, y)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()


def plot_roc_pr(task_name: str, y_true: List[str], y_prob, class_list: List[str]) -> None:
    """Plot ROC and precision-recall curves and print AUC / average precision.

    With two classes, ``class_list[1]`` is the positive class and column 1 of
    ``y_prob`` is its score. With more classes, all one-vs-rest decisions are
    pooled into micro-averaged curves.

    Args:
        task_name: used in plot titles and printed messages.
        y_true: true labels of the evaluation samples.
        y_prob: class-probability matrix, one row per sample and one column
            per entry of ``class_list`` (same order). ``None``, a wrong shape,
            or a single-class binary target skips plotting with a message.
        class_list: ordered label set.
    """
    y_prob = to_numpy_proba(y_prob)
    if y_prob is None:
        print(f"{task_name}: no probabilities -> skipping ROC/PR curves")
        return

    if len(class_list) < 2:
        print(f"{task_name}: not enough classes for curves")
        return

    y_true_arr = np.asarray(y_true, dtype=object)

    # A mismatched matrix would either crash deep inside scikit-learn or
    # silently pair scores with the wrong targets.
    expected_shape = (len(y_true_arr), len(class_list))
    if y_prob.ndim != 2 or y_prob.shape != expected_shape:
        print(
            f"{task_name}: probability shape {y_prob.shape} != expected "
            f"{expected_shape} -> skipping ROC/PR curves"
        )
        return

    if len(class_list) == 2:
        # binary case
        y_true_bin = (y_true_arr == class_list[1]).astype(int)
        if np.unique(y_true_bin).size < 2:
            # ROC AUC is undefined without both positives and negatives
            print(f"{task_name}: only one class in y_true -> skipping ROC/PR curves")
            return
        y_score = y_prob[:, 1].astype(float)
        fpr, tpr, _ = roc_curve(y_true_bin, y_score)
        prec, rec, _ = precision_recall_curve(y_true_bin, y_score)
        ap = average_precision_score(y_true_bin, y_score)
        title_tag, metric_tag = "", ""
    else:
        # multiclass (micro-average)
        y_true_bin = label_binarize(y_true_arr, classes=class_list)
        y_score = y_prob.astype(float)
        fpr, tpr, _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
        prec, rec, _ = precision_recall_curve(y_true_bin.ravel(), y_score.ravel())
        ap = average_precision_score(y_true_bin, y_score, average="micro")
        title_tag, metric_tag = " (micro)", "(micro)"

    roc_auc = auc(fpr, tpr)

    _show_line_plot(
        fpr, tpr, "False Positive Rate", "True Positive Rate",
        f"ROC{title_tag} — {task_name} (AUC={roc_auc:.3f})",
    )
    _show_line_plot(
        rec, prec, "Recall", "Precision",
        f"PR{title_tag} — {task_name} (AP={ap:.3f})",
    )

    print(f"{task_name}: AUC{metric_tag}={roc_auc:.4f} | AP{metric_tag}={ap:.4f}")


# Main experiment: for each task, clean the data, make one hold-out split,
# rank the hyper-parameter grid with RANKING_SEED, then retrain the best
# configuration with every seed in BEST_SEEDS and report mean ± std.
t_total0 = time.time()

df = pd.read_excel(DATA_PATH)
df.columns = df.columns.astype(str).str.strip()
print(f"File loaded: {DATA_PATH} | n={len(df)} | p={len(df.columns)}")

# full Cartesian product of the search space
grid = [
    {"num_iterations": it, "num_epochs": ep}
    for it, ep in product(GRID_NUM_ITERATIONS, GRID_NUM_EPOCHS)
]

for task in TASKS:
    t_task0 = time.time()

    task_name = task["TASK_NAME"]
    label_col = task["LABEL_COL"]
    feature_cols = task["FEATURE_COLS"]

    missing = [c for c in ([label_col] + feature_cols) if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns ({task_name}): {missing}")

    # --- data preparation ------------------------------------------------
    d = df[feature_cols + [label_col]].copy()
    # Drop missing labels *before* casting to str: once NaN has become the
    # literal "nan", dropna() can no longer see it.
    d = d.dropna(subset=[label_col])

    d["text"] = build_text(d, feature_cols)
    d["label"] = d[label_col].astype(str).str.strip()

    # quick cleanup: remove empty text/labels and textual "nan"/"none"
    # labels (both columns are already stripped at this point)
    keep = (d["text"] != "") & (d["label"] != "")
    keep &= ~d["label"].str.lower().isin(["nan", "none"])
    d = d[keep].reset_index(drop=True)

    if len(d) < 50 or d["label"].nunique() < 2:
        raise ValueError(f"Too few samples ({task_name}): n={len(d)}, k={d['label'].nunique()}")

    # bundle rare classes to stabilize stratification and training
    d["label"] = bundle_rare_classes(d["label"], MIN_SAMPLES_PER_CLASS, RARE_CLASS_LABEL)

    if d["label"].nunique() < 2:
        raise ValueError(f"Only one class after bundling ({task_name})")

    X_all = d["text"].tolist()
    y_all = d["label"].tolist()
    class_list = sorted(d["label"].unique().tolist())

    set_all_seeds(SPLIT_SEED)
    X_tr, X_te, y_tr, y_te, strat_used = safe_train_test_split(
        X_all, y_all, HOLDOUT_TEST_SIZE, SPLIT_SEED
    )

    print(f"\nTask: {task_name}")
    print(f"n={len(d)} | k={len(class_list)} | train={len(y_tr)} | test={len(y_te)} | stratified={strat_used}")
    print(f"Features: {feature_cols} | Label: {label_col}")

    # 1) ranking phase: evaluate each config once
    # NOTE(review): configurations are ranked on X_te/y_te, the same hold-out
    # set the final scores below are reported on, so those scores are
    # optimistically biased. Ranking on a validation split carved out of
    # X_tr (or CV on X_tr) would keep the test set untouched.
    ranking_rows = []
    best_cfg = None
    best_f1 = -1.0

    for cfg in grid:
        metrics, train_time, _, _, _ = train_eval_setfit(
            X_train=X_tr,
            y_train=y_tr,
            X_eval=X_te,
            y_eval=y_te,
            class_list=class_list,
            num_iterations=cfg["num_iterations"],
            num_epochs=cfg["num_epochs"],
            seed=RANKING_SEED,
            return_model=False,
        )

        ranking_rows.append(
            {
                "num_iterations": cfg["num_iterations"],
                "num_epochs": cfg["num_epochs"],
                "batch_size": BATCH_SIZE,
                "learning_rate": LEARNING_RATE,
                "train_time_sec": train_time,
                **metrics,
            }
        )

        # strict ">" keeps the earliest grid entry on ties
        if metrics["f1_macro"] > best_f1:
            best_f1 = metrics["f1_macro"]
            best_cfg = cfg

    # checked before building the table: an empty ranking has no f1_macro
    # column to sort on
    if best_cfg is None:
        raise ValueError(f"No valid configuration ({task_name})")

    # stable sort so tied configs keep grid order and the top row is best_cfg
    ranking_df = (
        pd.DataFrame(ranking_rows)
        .sort_values("f1_macro", ascending=False, kind="mergesort")
        .reset_index(drop=True)
    )

    print(f"\nGrid ranking (Seed {RANKING_SEED}), sorted by F1(macro):")
    print(
        ranking_df[
            [
                "num_iterations",
                "num_epochs",
                "batch_size",
                "learning_rate",
                "accuracy",
                "precision_macro",
                "recall_macro",
                "f1_macro",
                "train_time_sec",
            ]
        ].to_string(index=False)
    )

    print(
        f"\nBest setup: it={best_cfg['num_iterations']} | ep={best_cfg['num_epochs']} | bs={BATCH_SIZE} | lr={LEARNING_RATE}"
    )

    # 2) re-run the best config once per seed in BEST_SEEDS
    best_metrics_all = []
    best_probs_for_plots = None

    for seed_idx, seed in enumerate(BEST_SEEDS):
        # Only the probabilities are plotted, so the trained model is not
        # returned. That frees its GPU memory before the next seed trains.
        metrics, train_time, _, _, y_prob = train_eval_setfit(
            X_train=X_tr,
            y_train=y_tr,
            X_eval=X_te,
            y_eval=y_te,
            class_list=class_list,
            num_iterations=best_cfg["num_iterations"],
            num_epochs=best_cfg["num_epochs"],
            seed=seed,
            return_model=False,
        )

        best_metrics_all.append(
            {"seed": seed, "train_time_sec": train_time, **metrics}
        )

        if seed_idx == 0:
            best_probs_for_plots = y_prob

    best_df = pd.DataFrame(best_metrics_all)
    mean_row = best_df[
        ["accuracy", "precision_macro", "recall_macro", "f1_macro", "train_time_sec"]
    ].mean()
    # population std (ddof=0) across the seeds
    std_row = best_df[
        ["accuracy", "precision_macro", "recall_macro", "f1_macro"]
    ].std(ddof=0)

    seeds_label = " & ".join(str(s) for s in BEST_SEEDS)
    print(f"\nBest setup (Seeds {seeds_label}):")
    print(
        best_df[
            ["seed", "accuracy", "precision_macro", "recall_macro", "f1_macro", "train_time_sec"]
        ].to_string(index=False)
    )

    print(f"\nMean ± Std ({len(BEST_SEEDS)} seeds):")
    print(
        f"Accuracy:  {mean_row['accuracy']:.4f} ± {std_row['accuracy']:.4f}\n"
        f"Precision: {mean_row['precision_macro']:.4f} ± {std_row['precision_macro']:.4f}\n"
        f"Recall:    {mean_row['recall_macro']:.4f} ± {std_row['recall_macro']:.4f}\n"
        f"F1(macro): {mean_row['f1_macro']:.4f} ± {std_row['f1_macro']:.4f}"
    )

    # plot curves only once (first seed)
    plot_roc_pr(
        task_name=task_name,
        y_true=y_te,
        y_prob=best_probs_for_plots,
        class_list=class_list,
    )

    print(f"Task runtime: {time.time() - t_task0:.1f}s")

print(f"\nTotal runtime: {time.time() - t_total0:.1f}s")