# In [None]:
# Loading Libraries
import os
import random
import time
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from IPython.display import display
from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm.auto import tqdm

from setfit import SetFitModel, SetFitTrainer


# Environment switches: disable wandb and reduce tokenizer side messages
os.environ.update(
    {
        "WANDB_DISABLED": "true",
        "WANDB_MODE": "disabled",
        "WANDB_SILENT": "true",
        "TOKENIZERS_PARALLELISM": "false",
    }
)


# Input data: Excel workbook that the driver code at the bottom of the file
# loads once with pd.read_excel
DATA_PATH = "/content/Sample_500_stratified.xlsx"

# Split configuration (holdout split used for learning curves)
# SAAT_SPLIT seeds both the pool/holdout split and the drawing of the
# training subsets; TESTANTEIL is the fraction of rows held out for evaluation.
SAAT_SPLIT = 42
TESTANTEIL = 0.20

# Training setup: passed to SetFitTrainer as batch_size, num_epochs and
# learning_rate respectively
STAPELGROESSE = 16
ANZAHL_EPOCHEN = 1
LERNRATE = 2e-5

# Base model and SetFit-specific settings
# BASISMODELL is the checkpoint name given to SetFitModel.from_pretrained,
# ANZAHL_ITERATIONEN is passed as num_iterations, and SAATEN lists the seeds
# of the repeated runs that are averaged for each training size.
BASISMODELL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
ANZAHL_ITERATIONEN = 20
SAATEN = [42, 7, 123]

# Learning curve points (subset sizes drawn from the training pool);
# sizes larger than the pool of a task are dropped for that task
TRAININGSGROESSEN = [25, 50, 100, 200, 300, 400]

# Rare-class handling: classes with fewer than MIN_BEISPIELE_PRO_KLASSE
# samples are merged into a single class named SELTENE_KLASSE_LABEL
MIN_BEISPIELE_PRO_KLASSE = 2
SELTENE_KLASSE_LABEL = "Other / very rare"


def setze_alle_saaten(saat: int) -> None:
    """Seed the Python, NumPy and PyTorch random generators with ``saat``.

    When CUDA is available, all CUDA generators are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(saat)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(saat)


def finde_spalte(df: pd.DataFrame, kandidaten):
    """Resolve a column of ``df`` from one or more candidate names.

    Three passes are made over the candidates, each in the given order:

    1. exact match against ``df.columns``,
    2. case-insensitive match,
    3. normalized match (lower-cased, with spaces, hyphens, slashes and
       backslashes removed).

    A pass is only tried when the previous one found nothing for any
    candidate. If several columns collapse to the same lowered/normalized
    key, the last of them in column order is returned.

    Args:
        df: Frame whose column labels are searched.
        kandidaten: A single name or an iterable of names in priority order.

    Returns:
        The matching column label exactly as it appears in ``df.columns``,
        or ``None`` when no candidate matches.
    """
    # Materialize so that a one-shot iterable survives all three passes
    kandidaten = [kandidaten] if isinstance(kandidaten, str) else list(kandidaten)

    for k in kandidaten:
        if k in df.columns:
            return k

    # The fuzzy passes rely on str methods. Labels or candidates that are not
    # strings (e.g. integer or NaN headers) are left out here instead of
    # raising AttributeError; they can still match exactly above.
    text_spalten = [sp for sp in df.columns if isinstance(sp, str)]
    text_kandidaten = [k for k in kandidaten if isinstance(k, str)]

    lower_map = {sp.lower(): sp for sp in text_spalten}
    for k in text_kandidaten:
        if k.lower() in lower_map:
            return lower_map[k.lower()]

    trenner = str.maketrans("", "", " -/\\")

    def norm(s: str) -> str:
        # Lower-case and strip the separator characters in a single pass
        return s.lower().translate(trenner)

    norm_map = {norm(sp): sp for sp in text_spalten}
    for k in text_kandidaten:
        key = norm(k)
        if key in norm_map:
            return norm_map[key]

    return None


def erstelle_text_X(df: pd.DataFrame, spalten: List[str]) -> pd.Series:
    """Build one input string per row from the text columns ``spalten``.

    Each selected column has missing values replaced by an empty string, is
    cast to ``str`` and stripped; the cleaned cells of a row are then joined
    with a single space in the order given by ``spalten``. ``df`` itself is
    not modified.
    """
    bereinigt = df[spalten].assign(
        **{name: df[name].fillna("").astype(str).str.strip() for name in spalten}
    )
    zeilen_text = bereinigt.agg(" ".join, axis=1)
    return zeilen_text.astype(str)


def makro_f1(y_true, y_pred, klassen_liste) -> float:
    """Return the macro-averaged F1 score over the labels in ``klassen_liste``.

    Macro averaging weights every class equally; ``zero_division=0`` is
    passed through to scikit-learn.
    """
    kennzahlen = precision_recall_fscore_support(
        y_true,
        y_pred,
        labels=klassen_liste,
        average="macro",
        zero_division=0,
    )
    # The third element of the returned 4-tuple is the F-score
    return float(kennzahlen[2])


def sichere_train_test_teilung(
    X: List[str], y: List[str], testanteil: float, saat: int
) -> Tuple[List[str], List[str], List[str], List[str], bool]:
    """Split into pool and holdout, stratified by label when possible.

    A stratified ``train_test_split`` is attempted first. If it raises any
    exception, the split is repeated without stratification.

    Returns:
        ``(X_pool, X_eval, y_pool, y_eval, stratified)`` where ``stratified``
        is ``True`` only when the stratified attempt succeeded.
    """
    gemeinsam = {"test_size": testanteil, "random_state": saat}
    try:
        X_pool, X_eval, y_pool, y_eval = train_test_split(X, y, stratify=y, **gemeinsam)
        geschichtet = True
    except Exception:
        # Best-effort fallback: plain random split with the same size and seed
        X_pool, X_eval, y_pool, y_eval = train_test_split(X, y, **gemeinsam)
        geschichtet = False
    return X_pool, X_eval, y_pool, y_eval, geschichtet


def geschichtete_teilstichprobe_indizes(y: List[str], n: int, saat: int) -> np.ndarray:
    """Return ``n`` positional indices into ``y``, drawn stratified by label.

    Args:
        y: Labels of the pool.
        n: Requested subset size; must satisfy ``0 < n <= len(y)``.
        saat: Random state for the stratified draw.

    Returns:
        All indices ``0..len(y)-1`` when ``n == len(y)``; otherwise the
        ``n`` indices that ``train_test_split`` places in its test part.

    Raises:
        ValueError: If ``n`` is not positive or exceeds the pool size.
            Errors raised by ``train_test_split`` itself are not caught here.
    """
    klassen = np.asarray(y, dtype=object)
    pool_groesse = len(klassen)

    if n <= 0:
        raise ValueError("n must be > 0")
    if n > pool_groesse:
        raise ValueError("n must not be larger than the pool size")

    alle_indizes = np.arange(pool_groesse)
    if n == pool_groesse:
        return alle_indizes

    # test_size=n makes the "test" part of the split exactly the subset we want
    teilung = train_test_split(
        alle_indizes, klassen, test_size=n, random_state=saat, stratify=klassen
    )
    return teilung[1]


def trainiere_und_evaluiere_setfit(
    train_texte: List[str],
    train_klassen: List[str],
    eval_texte: List[str],
    eval_klassen: List[str],
    klassen_liste: List[str],
    basis_modell: str,
    anzahl_iterationen: int,
    saat: int,
) -> Tuple[float, float, float]:
    """Train one SetFit run and evaluate it on the holdout set.

    Args:
        train_texte: Training texts.
        train_klassen: Training labels, aligned with ``train_texte``.
        eval_texte: Holdout texts.
        eval_klassen: Holdout labels, aligned with ``eval_texte``.
        klassen_liste: Full label set; passed to the model as ``labels`` and
            used as the label set for macro-F1.
        basis_modell: Checkpoint name given to ``SetFitModel.from_pretrained``.
        anzahl_iterationen: Passed to the trainer as ``num_iterations``.
        saat: Seed applied globally and passed to the trainer.

    Returns:
        ``(accuracy, macro_f1, training_seconds)``. The time covers only the
        ``trainer.train()`` call.

    Batch size, epoch count and learning rate come from the module-level
    constants ``STAPELGROESSE``, ``ANZAHL_EPOCHEN`` and ``LERNRATE``.
    """
    setze_alle_saaten(saat)

    train_ds = Dataset.from_dict({"text": train_texte, "klasse": train_klassen})

    # NOTE(review): head_params is presumably forwarded to the scikit-learn
    # classification head — confirm against the installed SetFit version.
    modell = SetFitModel.from_pretrained(
        basis_modell,
        labels=klassen_liste,
        head_params={"class_weight": "balanced", "max_iter": 2000},
    )

    trainer = SetFitTrainer(
        model=modell,
        train_dataset=train_ds,
        column_mapping={"text": "text", "klasse": "label"},
        batch_size=STAPELGROESSE,
        num_epochs=ANZAHL_EPOCHEN,
        num_iterations=anzahl_iterationen,
        learning_rate=LERNRATE,
        seed=saat,
    )

    try:
        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments during training
        start = time.perf_counter()
        trainer.train()
        train_zeit = time.perf_counter() - start

        y_pred = modell.predict(eval_texte)
        acc = accuracy_score(eval_klassen, y_pred)
        f1m = makro_f1(eval_klassen, y_pred, klassen_liste)
    finally:
        # Drop the references even when training or prediction fails, so a
        # failed run does not keep the model alive while later runs start
        del trainer, modell
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return float(acc), float(f1m), float(train_zeit)


def fasse_laeufe_zusammen(accs, f1s, zeiten):
    """Aggregate the repeated runs of a single training size.

    Returns a dict of plain floats: mean and population standard deviation
    of accuracy and macro-F1, plus the mean training time in seconds.
    """

    def mittel(werte) -> float:
        return float(np.mean(werte))

    def streuung(werte) -> float:
        # np.std defaults to ddof=0, i.e. the population standard deviation
        return float(np.std(werte))

    return {
        "acc_mittel": mittel(accs),
        "acc_std": streuung(accs),
        "f1_makro_mittel": mittel(f1s),
        "f1_makro_std": streuung(f1s),
        "trainingszeit_sek_mittel": mittel(zeiten),
    }


@dataclass
class AufgabenKonfiguration:
    # Simple task definition: label candidates + feature candidates.
    # Column names are given as candidates and resolved against the loaded
    # frame with finde_spalte when the task is run.

    # Name used for the task in log messages, the plot title and the
    # "aufgabe" column of the result table
    aufgabenname: str
    # Candidate names for the label column
    label_kandidaten: List[str]
    # One candidate list per text feature; each inner list resolves to a
    # single column, and the resolved columns are joined into the input text
    merkmal_kandidaten: List[List[str]]


# Task list: every task uses its label column name as the task name, and each
# listed feature column forms its own single-candidate list.
AUFGABEN: List[AufgabenKonfiguration] = [
    AufgabenKonfiguration(
        aufgabenname=name,
        label_kandidaten=[name],
        merkmal_kandidaten=[[spalte] for spalte in merkmale],
    )
    for name, merkmale in (
        ("cuisine_region", ("desc_1", "desc_2")),
        ("opening_class_label", ("opening_hours",)),
        ("Chain-Indep", ("title",)),
        ("Services_Label", ("services",)),
        ("concept_format", ("desc_1", "desc_2", "title")),
    )
]


def _zeichne_lernkurve(ok_df: pd.DataFrame, aufgabenname: str) -> None:
    """Plot mean macro-F1 per training size with std error bars and show it."""
    plt.figure()
    plt.plot(ok_df["trainingsumfang"], ok_df["f1_makro_mittel"], marker="o")
    plt.errorbar(
        ok_df["trainingsumfang"],
        ok_df["f1_makro_mittel"],
        yerr=ok_df["f1_makro_std"],
        fmt="none",
        capsize=3,
    )
    plt.xlabel("Number of training samples")
    plt.ylabel("Macro-F1 (holdout)")
    plt.title(f"Learning curve — {aufgabenname}")
    plt.show()


def fuehre_lernkurven_fuer_aufgabe_aus(df_roh: pd.DataFrame, cfg: AufgabenKonfiguration) -> pd.DataFrame:
    """Run the learning-curve experiment for one task and return its table.

    Steps: resolve the label/feature columns, clean the rows, merge very rare
    classes, split once into pool and holdout, then train on stratified pool
    subsets of increasing size (one run per seed in ``SAATEN``) and evaluate
    every run on the holdout set.

    Args:
        df_roh: Raw input frame; it is not modified.
        cfg: Task definition with label and feature column candidates.

    Returns:
        One row per training size with the columns ``aufgabe``,
        ``trainingsumfang``, ``status`` and the aggregated metrics. Sizes for
        which the stratified subset could not be drawn carry a
        ``skipped (...)`` status and NaN metrics. An empty frame is returned
        when the task is skipped (missing columns, fewer than 50 rows, or
        fewer than two classes).

    Side effects: prints progress information, shows the learning-curve plot
    for the successful points and displays the result table.
    """
    # Resolve the label column and the feature columns for this task
    label_spalte = finde_spalte(df_roh, cfg.label_kandidaten)
    merkmal_spalten = [
        finde_spalte(df_roh, kandidatenliste) for kandidatenliste in cfg.merkmal_kandidaten
    ]

    fehlend = []
    if label_spalte is None:
        fehlend.append(f"label({cfg.aufgabenname})")
    fehlend.extend(
        f"feature_{nummer}({cfg.aufgabenname})"
        for nummer, sp in enumerate(merkmal_spalten, start=1)
        if sp is None
    )

    if fehlend:
        print(f"\nTask '{cfg.aufgabenname}' skipped (missing columns: {fehlend})")
        return pd.DataFrame()

    # Build a compact working frame for this task
    d = df_roh[merkmal_spalten + [label_spalte]].copy()
    d = d.dropna(subset=[label_spalte])

    d["text"] = erstelle_text_X(d, merkmal_spalten)
    d["klasse"] = d[label_spalte].astype(str).str.strip()

    # Basic cleanup: drop empty text/labels and textual 'nan'
    d = d[(d["text"].str.strip() != "") & (d["klasse"].str.strip() != "")]
    d = d[d["klasse"].str.lower().ne("nan")].reset_index(drop=True)

    if len(d) < 50 or d["klasse"].nunique() < 2:
        print(f"\nTask '{cfg.aufgabenname}' skipped (too few samples/classes after cleaning)")
        return pd.DataFrame()

    # Bundle very rare classes so that splitting/training is stable
    haeufigkeiten = d["klasse"].value_counts()
    seltene_klassen = haeufigkeiten[haeufigkeiten < MIN_BEISPIELE_PRO_KLASSE].index
    if len(seltene_klassen) > 0:
        d.loc[d["klasse"].isin(seltene_klassen), "klasse"] = SELTENE_KLASSE_LABEL

    if d["klasse"].nunique() < 2:
        print(f"\nTask '{cfg.aufgabenname}' skipped (only one class left after bundling)")
        return pd.DataFrame()

    X_gesamt = d["text"].tolist()
    y_gesamt = d["klasse"].tolist()
    klassen_liste = sorted(d["klasse"].unique().tolist())

    # Split once into pool + holdout; learning curve subsamples come from the pool
    setze_alle_saaten(SAAT_SPLIT)
    X_pool, X_eval, y_pool, y_eval, strat_verwendet = sichere_train_test_teilung(
        X_gesamt, y_gesamt, TESTANTEIL, SAAT_SPLIT
    )

    print(f"\nTask: {cfg.aufgabenname}")
    print(
        f"n={len(d)} | classes={len(klassen_liste)} | pool={len(y_pool)} | holdout={len(y_eval)} | stratified={strat_verwendet}"
    )
    print(f"Features: {merkmal_spalten} | Label: {label_spalte}")

    zeilen = []
    groessen = [g for g in TRAININGSGROESSEN if g <= len(y_pool)]
    if not groessen:
        # No configured size fits the pool: fall back to a single point
        groessen = [max(10, min(len(y_pool), 50))]

    # Loop over training sizes and average results across multiple seeds
    for n_train in tqdm(groessen, desc=f"{cfg.aufgabenname} (learning curve)", leave=False):
        try:
            # The subset is drawn with SAAT_SPLIT, so every seed in SAATEN
            # trains on the same samples for a given size
            idx_sub = geschichtete_teilstichprobe_indizes(y_pool, n_train, saat=SAAT_SPLIT)
        except Exception as e:
            # Record the failed size instead of aborting the whole task
            zeilen.append(
                {
                    "aufgabe": cfg.aufgabenname,
                    "trainingsumfang": n_train,
                    "status": f"skipped ({type(e).__name__})",
                    "acc_mittel": np.nan,
                    "acc_std": np.nan,
                    "f1_makro_mittel": np.nan,
                    "f1_makro_std": np.nan,
                    "trainingszeit_sek_mittel": np.nan,
                }
            )
            continue

        X_sub = [X_pool[i] for i in idx_sub]
        y_sub = [y_pool[i] for i in idx_sub]

        accs, f1s, zeiten = [], [], []
        for saat in SAATEN:
            acc, f1m, t = trainiere_und_evaluiere_setfit(
                train_texte=X_sub,
                train_klassen=y_sub,
                eval_texte=X_eval,
                eval_klassen=y_eval,
                klassen_liste=klassen_liste,
                basis_modell=BASISMODELL,
                anzahl_iterationen=ANZAHL_ITERATIONEN,
                saat=saat,
            )
            accs.append(acc)
            f1s.append(f1m)
            zeiten.append(t)

        zusammen = fasse_laeufe_zusammen(accs, f1s, zeiten)
        zeilen.append(
            {
                "aufgabe": cfg.aufgabenname,
                "trainingsumfang": n_train,
                "status": "ok",
                **zusammen,
            }
        )

    lernkurve_df = pd.DataFrame(zeilen).sort_values("trainingsumfang").reset_index(drop=True)

    # Plot only the successful points
    ok_df = lernkurve_df[lernkurve_df["status"] == "ok"]
    if not ok_df.empty:
        _zeichne_lernkurve(ok_df, cfg.aufgabenname)

    print("\nTable:")
    display(lernkurve_df)

    return lernkurve_df


# Read the workbook a single time, then process every configured task
df = pd.read_excel(DATA_PATH)
df.columns = df.columns.str.strip()

alle_ergebnisse = []

for cfg in tqdm(AUFGABEN, desc="Tasks"):
    out_df = fuehre_lernkurven_fuer_aufgabe_aus(df, cfg)
    if len(out_df) == 0:
        # Skipped tasks hand back an empty frame and contribute nothing
        continue
    alle_ergebnisse.append(out_df)

# Merge the per-task tables and export them as CSV
if not alle_ergebnisse:
    print("\nNo results produced (all tasks were skipped).")
else:
    ergebnisse_df = pd.concat(alle_ergebnisse, ignore_index=True)
    print("\nOverall table (all tasks):")
    display(ergebnisse_df)

    out_csv = "/content/setfit_lernkurven_alle_aufgaben.csv"
    ergebnisse_df.to_csv(out_csv, index=False)
    print(f"\nSaved to: {out_csv}")