In [28]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
from pathlib import Path
import itertools
from os.path import join
import re

import pandas as pd
import numpy as np
from tqdm.auto import tqdm

from sklearn.model_selection import GroupShuffleSplit, RepeatedKFold, RepeatedStratifiedKFold
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from xgboost import XGBRegressor, XGBClassifier
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, LogisticRegression, Lasso, ElasticNet
from sklearn.base import clone
from sklearn.model_selection import GridSearchCV
from stabl.data import load_onset_of_labor, load_ssi
from stabl.multi_omic_pipelines import multi_omic_stabl_cv_noe_test

# STABL imports
from stabl.adaptive import ALasso
from stabl.stabl import Stabl
from stabl.preprocessing import remove_low_info_samples, LowInfoFilter
from stabl.pipelines_utils import compute_scores_table, save_plots


# -------------------------------------------------------------------
# 1. STABL estimator definitions
# -------------------------------------------------------------------
_SEED = 42
_MAX_ITER = int(1e6)


def _build_stabl(base_estimator, lambda_grid):
    """Return a Stabl selector using the bootstrap/knockoff settings shared by this script."""
    return Stabl(
        base_estimator=base_estimator,
        lambda_grid=lambda_grid,
        n_bootstraps=1000,
        artificial_type="knockoff",
        artificial_proportion=1.0,
        replace=False,
        fdr_threshold_range=np.arange(0.1, 1, 0.01),
        sample_fraction=0.5,
        random_state=_SEED,
        verbose=1,
    )


# Inner splitter; in the visible code it is referenced only by a commented-out grid search.
inner_cv = RepeatedStratifiedKFold(n_splits=3, n_repeats=3, random_state=_SEED)

# Base learners wrapped by the STABL selectors below.
lasso = Lasso(max_iter=_MAX_ITER, random_state=_SEED)
en = ElasticNet(max_iter=_MAX_ITER, random_state=_SEED)
alasso = ALasso(max_iter=_MAX_ITER, random_state=_SEED)
xgb = XGBRegressor(random_state=_SEED, importance_type="gain", objective="reg:squarederror")

# Selectors built from scratch.
stabl_lasso = _build_stabl(lasso, {"alpha": np.logspace(0, 2, 10)})
stabl_xgb = _build_stabl(xgb, {"max_depth": [3, 6, 9], "reg_alpha": [0, 0.5, 1, 2]})

# Selectors derived from the Lasso configuration: same settings, different learner (and grid).
stabl_alasso = clone(stabl_lasso).set_params(base_estimator=alasso)
stabl_en = clone(stabl_lasso).set_params(
    base_estimator=en,
    lambda_grid={"alpha": np.logspace(0.5, 2, 10), "l1_ratio": [0.5, 0.7, 0.9]},
)


# -------------------------------------------------------------------
# 2. Fonction de CV sur features existantes (avec union/intersect)
# -------------------------------------------------------------------
def cv_on_existing_feats(
    data_dict, y, outer_splitter,
    estimators, task_type, model_chosen,
    models, fold_feats_path, save_path,
    outer_groups=None, early_fusion=False,
    late_fusion=False, n_iter_lf=10000,
    use_ega=False
):
    """Cross-validate a final model on per-fold feature sets saved by an earlier run.

    For every outer fold and every entry of ``models`` the fold's feature list
    is resolved, a median-imputation + standard-scaling pipeline and the chosen
    final estimator are fitted on the training samples, and the held-out
    samples are predicted. Per-fold feature lists are written to
    ``<save_path>/Training CV`` and a score table to ``<save_path>/Summary``.

    Parameters
    ----------
    data_dict : dict
        DataFrames that are concatenated column-wise into one feature matrix.
    y : pandas.Series
        Outcome; its index provides the sample identifiers.
    outer_splitter : object
        Splitter exposing ``get_n_splits`` and ``split`` (called with
        ``X``, ``y`` and ``groups``).
    estimators : dict
        Must contain the keys ``"random_forest"`` and ``"xgboost"``.
    task_type : str
        ``"binary"`` uses ``predict_proba(...)[:, 1]``; any other value uses
        ``predict``.
    model_chosen : str or None
        ``"xgboost"``, ``"random_forest"`` or ``"logit"`` (case-insensitive).
        Any other value, or ``None``, falls back to a built-in
        ``LogisticRegression(max_iter=1000)`` for binary tasks and to
        ``LinearRegression()`` otherwise.
    models : list of str
        Model names. ``"STABL <x>"`` reads
        ``<fold_feats_path>/Training CV/Selected Features STABL <x>.csv``;
        ``"Union A & B"`` / ``"Intersect A & B"`` combine the per-fold
        selections of the named base models; any other name uses all columns.
    fold_feats_path : path-like
        Root folder of the earlier run holding the selected-feature CSV files.
    save_path : path-like
        Output root for this run.
    outer_groups : array-like, optional
        Groups forwarded to the splitter.
    early_fusion, late_fusion, n_iter_lf, use_ega
        Accepted for signature compatibility; not used by this function.

    Returns
    -------
    dict
        Model name -> Series holding, for each sample, the median prediction
        over the folds in which the sample was held out.
    """
    import ast  # local import: only needed to parse the saved feature lists

    def _parse_feature_cell(cell):
        """Convert one CSV cell into a list of feature names (empty for NaN/blank cells)."""
        if isinstance(cell, str):
            # literal_eval only accepts Python literals, unlike eval which
            # would execute arbitrary code found in the CSV.
            cell = ast.literal_eval(cell)
        if isinstance(cell, (list, tuple, set)):
            return list(cell)
        return []

    def _load_fold_selections(model_name):
        """Read the per-fold selections that the earlier run stored for ``model_name``."""
        table = pd.read_csv(
            Path(fold_feats_path, "Training CV", f"Selected Features {model_name}.csv"),
            index_col=0,
        )
        return [_parse_feature_cell(cell) for cell in table["Fold selected features"]]

    raw_sel = {}

    def _selections_for(model_name):
        """Cached access, so Union/Intersect bases need not be listed in ``models``."""
        if model_name not in raw_sel:
            raw_sel[model_name] = _load_fold_selections(model_name)
        return raw_sel[model_name]

    logit = LogisticRegression(max_iter=1000)
    linreg = LinearRegression()
    rf_est = estimators["random_forest"]
    xgb_est = estimators["xgboost"]

    # The final estimator does not depend on the fold or the model: resolve it once.
    key = (model_chosen or "logit").lower()
    if task_type == "binary":
        final_est = {"xgboost": xgb_est, "random_forest": rf_est, "logit": logit}.get(key, logit)
    else:
        final_est = {"xgboost": xgb_est, "random_forest": rf_est}.get(key, linreg)

    cv_dir = Path(save_path, "Training CV")
    summary_dir = Path(save_path, "Summary")
    os.makedirs(cv_dir, exist_ok=True)
    os.makedirs(summary_dir, exist_ok=True)

    X_tot = pd.concat(data_dict.values(), axis=1)
    preds_dict = {m: pd.DataFrame(index=y.index) for m in models}
    feats_dict = {m: [] for m in models}

    # Read the original STABL selections up front so a missing file fails early.
    for m in models:
        if m.startswith("STABL "):
            _selections_for(m)

    n_splits = outer_splitter.get_n_splits(X=X_tot, y=y, groups=outer_groups)

    for fold_idx, (train_i, test_i) in enumerate(tqdm(
        outer_splitter.split(X_tot, y, groups=outer_groups),
        total=n_splits, desc="Outer CV"
    )):
        train_ids = y.iloc[train_i].index
        test_ids = y.iloc[test_i].index
        y_train = y.loc[train_ids]
        fold_col = f"Fold_{fold_idx}"

        for m in models:
            # Resolve the feature list of this model for the current fold.
            if m.startswith("Union "):
                # "Union STABL Lasso & STABL ALasso" -> ['STABL Lasso', 'STABL ALasso']
                bases = m[len("Union "):].split(" & ")
                merged = set.union(*(set(_selections_for(b)[fold_idx]) for b in bases))
                # Sorted so the column order does not depend on set hashing.
                feats = sorted(merged, key=str)
            elif m.startswith("Intersect "):
                bases = m[len("Intersect "):].split(" & ")
                merged = set.intersection(*(set(_selections_for(b)[fold_idx]) for b in bases))
                feats = sorted(merged, key=str)
            elif m.startswith("STABL "):
                feats = list(_selections_for(m)[fold_idx])
            else:
                # Non-STABL model: fall back to every available feature.
                feats = list(X_tot.columns)

            if not feats:
                # Nothing selected: constant prediction (0.5, or the training mean).
                fold_pred = 0.5 if task_type == "binary" else np.mean(y_train)
            else:
                pipe = Pipeline([
                    ("imputer", SimpleImputer(strategy="median")),
                    ("scaler", StandardScaler()),
                ])
                # Preprocessing is fitted on the training samples only.
                Xtr = pd.DataFrame(pipe.fit_transform(X_tot.loc[train_ids, feats]),
                                   index=train_ids, columns=feats)
                Xte = pd.DataFrame(pipe.transform(X_tot.loc[test_ids, feats]),
                                   index=test_ids, columns=feats)

                fitted = clone(final_est).fit(Xtr, y_train)
                if task_type == "binary":
                    fold_pred = fitted.predict_proba(Xte)[:, 1]
                else:
                    fold_pred = fitted.predict(Xte)

            preds_dict[m].loc[test_ids, fold_col] = fold_pred
            feats_dict[m].append(feats)

    # Save the feature lists actually used in each fold.
    for m in models:
        dfm = pd.DataFrame({
            "Fold selected features": feats_dict[m],
            "Fold #features": [len(f) for f in feats_dict[m]]
        }, index=[f"Fold_{i}" for i in range(n_splits)])
        dfm.to_csv(cv_dir / f"Selected Features {m}.csv")

    # Scores & plots on the per-sample median over folds.
    med_preds = {m: preds_dict[m].median(axis=1) for m in models}
    scores = compute_scores_table(
        predictions_dict=med_preds,
        y=y,
        task_type=task_type,
        selected_features_dict=None
    )
    scores.to_csv(summary_dir / "Scores_training_CV.csv")
    save_plots(
        predictions_dict=med_preds,
        y=y,
        task_type=task_type,
        save_path=cv_dir
    )

    return med_preds


# -------------------------------------------------------------------
# 3. Point d'entrée
# -------------------------------------------------------------------
if __name__=="__main__":
    # 3.1 Input/output locations (absolute paths specific to the author's machine).
    # X_PATH / Y_PATH are referenced only by the commented-out INA loader in the visible code.
    X_PATH     = Path("/Users/noeamar/Documents/Stanford/data/olivier_data/ina_13OG_final_long_allstims_filtered (6).csv")
    Y_PATH     = Path("/Users/noeamar/Documents/Stanford/data/olivier_data/outcome_table_all_pre.csv")
    # Earlier run whose "Training CV/Selected Features <model>.csv" files are read below.
    FOLD_FEATS = Path("/Users/noeamar/Documents/Stanford/Sherlock Results/Maigane/results_maigane_normalized_001_omics/control_vs_severe/MAIGANE_control_vs_severe_Functional_xgboost_knockoff_total_cover_GSS/Functional")
    # Output root passed as save_path to cv_on_existing_feats.
    SAVE_ROOT  = Path("/Users/noeamar/Documents/Stanford/Fitting results/Maigane data/Control_vs_severe/Functional KO_total_cover + Logit")
 
    # # Load INA
    # X = pd.read_csv(X_PATH, index_col=0)
    # y = pd.read_csv(Y_PATH, index_col=0).squeeze()
    # common = X.index.intersection(y.index)
    # X, y = X.loc[common], y.loc[common]

    # # # 3.2 Splitter & data_dict
    # task_type = "regression"
    # groups   = X.index.to_series().str.split("_").str[0]
    # splitter = GroupShuffleSplit(n_splits=100, test_size=0.2, random_state=42)
    # data_dict= {"allstim": X}

    # # 1) Load OOL
    # train_dict, valid_data_dict, y, y_valid, patients_id, task_type = load_onset_of_labor(
    # "/Users/noeamar/Documents/Stanford/data/Onset of Labor"
    # )

    # cyto_train = train_dict["CyTOF"]
    # prot_train = train_dict["Proteomics"].copy()
    # # Préfixe PRO_ pour distinguer les colonnes
    # prot_train.columns = [f"{c}" for c in prot_train.columns]

    # # 2) Intersection des index sur (y, cyto, prot)
    # common = y.index
    # for df in (cyto_train, prot_train):
    #     common = common.intersection(df.index)

    # y = y.loc[common]
    # cyto_train = cyto_train.loc[common]
    # prot_train = prot_train.loc[common]

    # # 3) Early-fusion: un seul X
    # X = pd.concat([cyto_train, prot_train], axis=1)

    # # 4) Splitter & data_dict (un seul bloc "allstim")
    # groups   = X.index.to_series().str.split("_").str[0]
    # splitter = GroupShuffleSplit(n_splits=25, test_size=0.2, random_state=42)
    # data_dict = {"allstim": X}

    # # Load SSI
    # train_data_dict, _, y, _, _, task_type = load_ssi("/Users/noeamar/Documents/Stanford/data/Biobank SSI")  # task_type = "binary"
    
    # # Assainir les labels
    # # y = y.astype(int)
    # # vals = sorted(set(y.unique()))
    # # print("Labels uniques:", vals)  # doit afficher [0, 1]
    # # assert set(vals) <= {0, 1}

    # # # 3) Early-fusion: concat CyTOF + Proteomics en un seul X
    # cyto = train_data_dict["CyTOF"]
    # prot = train_data_dict["Proteomics"].copy()
    # # print("Taille de CyTOF :", cyto.shape)
    # # print("Taille de Proteomics :", prot.shape)
    # # (Option recommandé) Préfixer la protéomique pour éviter toute collision de noms :
    # prot.columns = [f"PRO_{c}" for c in prot.columns]

    # # 4) Aligner les index sur l’intersection commune (y, cyto, prot)
    # common = y.index
    # for df in (cyto, prot):
    #     common = common.intersection(df.index)

    # y = y.loc[common]
    # X = pd.concat([cyto.loc[common], prot.loc[common]], axis=1)

    # X=pd.concat([cyto, prot], axis=1)
    # print("Taille de X :", X.shape)
    # print("Taille de y :", y.shape)

    # print("\n--- VERSION INTERSECTION COMMUNE ---")

    # # Index de départ
    # print("Index y :", y.index.tolist())
    # print("Index cyto :", cyto.index.tolist())
    # print("Index prot :", prot.index.tolist())

    # # Intersection progressive
    # common = y.index
    # print("\nÉtape 0 - common (y.index) :", common.tolist())

    # for i, df in enumerate((cyto, prot), start=1):
    #     common = common.intersection(df.index)
    #     print(f"Étape {i} - intersection avec df{i} :", common.tolist())

    # # Sélection finale
    # y_common = y.loc[common]
    # X_common = pd.concat([cyto.loc[common], prot.loc[common]], axis=1)

    # print("\nTaille finale y_common :", y_common.shape)
    # print("Taille finale X_common :", X_common.shape)
    # print("Index final commun :", common.tolist())

    # # Exemple de premières lignes
    # print("\nAperçu X_common :")
    # print(X_common.head())

    # # ------------------------
    # print("\n--- VERSION CONCAT SIMPLE ---")
    # X_simple = pd.concat([cyto, prot], axis=1)

    # print("Taille X_simple :", X_simple.shape)
    # print("Index X_simple :", X_simple.index.tolist())

    # # Vérification des NaN éventuels
    # nan_counts = X_simple.isna().sum().sum()
    # print("Nombre total de NaN dans X_simple :", nan_counts)

    # # Option) S’assurer que y est bien 0/1
    # y = y.astype(int)

    # # 5) Splitter & data_dict (identique à ton snippet)
    # groups   = X_simple.index.to_series().str.split("_").str[0]
    # #splitter = GroupShuffleSplit(n_splits=100, test_size=0.2, random_state=42)
    # splitter=RepeatedStratifiedKFold(n_splits=5, n_repeats=20, random_state=42)
    # data_dict = {"allstim": X_simple}

    # # Load COVID-19
    # from stabl.data import load_covid_19
    # X_train, X_valid, y_train, y_valid, ids, task_type = load_covid_19("/Users/noeamar/Documents/Stanford/data/COVID-19")
    # # Assainir les labels
    # y_train = y_train.astype(int)
    # y_valid = y_valid.astype(int)
    # vals = sorted(set(y_valid.unique()))
    # print("Labels uniques:", vals)  # doit afficher [0, 1]
    # assert set(vals) <= {0, 1} 
    # data_dict = X_valid
    # # If X_valid is a dict, print the shape of each DataFrame inside
    # if isinstance(X_valid, dict):
    #     for k, v in X_valid.items():
    #         print(f"X_valid[{k}] shape:", v.shape)
    # else:
    #     print("X_valid shape:", X_valid.shape)

    # if isinstance(X_train, dict):
    #     for k, v in X_valid.items():
    #         print(f"X_train[{k}] shape:", v.shape)
    # else:
    #     print("X_train shape:", X_valid.shape)
    # splitter = RepeatedStratifiedKFold(n_splits=5, n_repeats=20, random_state=42)
    # y= y_valid
    # groups = ids

    # Load CFRNA
    # from stabl.data import load_cfrna
    # data_dict, valid_data_dict, y, y_valid, groups, task_type = load_cfrna("/Users/noeamar/Documents/Stanford/data/CFRNA")
    # # Assainir les labels
    # y = y.astype(int)
    # vals = sorted(set(y.unique()))
    # print("Labels uniques:", vals)  # doit afficher [0, 1]
    # assert set(vals) <= {0, 1}

    #splitter  = GroupShuffleSplit(n_splits=100, test_size=0.2, random_state=42)

    # Load Maigane dataset

    def load_maigane(data_path, label_mode="control_vs_em", miss_thresh=0.99):
        """Load and preprocess the MAIGANE tables.

        Parameters
        ----------
        data_path : str
            Folder containing the CSV files (e.g. ".../Maigane data").
        label_mode : {"control_vs_em", "control_vs_mild", "control_vs_severe", "mild_vs_severe"}
            - "control_vs_em"     : Control (0) vs EM (1), from ControlVsEM.csv
            - "control_vs_mild"   : Control (0) vs stage I&II (1)
            - "control_vs_severe" : Control (0) vs stage III&IV (1)
            - "mild_vs_severe"    : stage I&II (0) vs stage III&IV (1)
            Matching is case-insensitive.
        miss_thresh : float
            Maximum fraction of missing values allowed per feature column
            (default 0.99); columns above it are dropped.

        Returns
        -------
        tuple
            ``(X_dict, None, y, None, groups, "binary")`` where ``X_dict`` maps
            each omic name to a DataFrame restricted to the samples shared by
            the labels and every omic, ``y`` is the integer outcome for those
            samples, and ``groups`` maps each sample ID to the group token
            extracted from it.

        Raises
        ------
        ValueError
            If a label file lacks a required column, ``label_mode`` is unknown,
            or no sample is shared between the labels and the omics.
        """
        import re
        import numpy as np
        import pandas as pd
        from os.path import join

        na_tokens = ["NA", "NaN", ""]

        # -------- helpers --------
        def _clean_ids(index):
            """Stringify sample IDs and strip surrounding whitespace and quotes."""
            return pd.Index(index).astype(str).str.strip().str.strip('"').str.strip("'")

        def _encode_labels(labels, mapping):
            """Return ``labels`` as nullable integers with cleaned, unique sample IDs.

            Numeric columns (int, float or bool) are kept as they are: a 0/1
            column that contains a missing value is read as float, and sending
            it through the string mapping would turn every label into NA.
            Text columns are mapped with ``mapping`` after trimming whitespace.
            """
            if labels.dtype.kind in "iufb":
                encoded = labels
            else:
                encoded = labels.str.strip().map(mapping)
            encoded = encoded.astype("Int64")
            encoded.index = _clean_ids(encoded.index)
            # Same rule as for the omics tables: keep the first row of a duplicated ID,
            # otherwise y and X could end up with different numbers of rows.
            return encoded[~encoded.index.duplicated(keep="first")]

        def _ids_where(labels, value):
            """IDs whose label equals ``value``; missing labels never match."""
            mask = labels.eq(value).fillna(False).to_numpy(dtype=bool)
            return labels.index[mask]

        def _sanitize_cols(cols):
            def clean(c):
                s = str(c).strip().strip('"').strip("'")
                s = re.sub(r'[\\/:*?"<>|]+', '_', s)   # forbidden characters -> _
                s = re.sub(r'\s+', '_', s)            # whitespace runs -> _
                s = re.sub(r'_+', '_', s)             # __ -> _
                if s == '' or s.lower().startswith('unnamed'):
                    s = 'feature'
                return s
            cleaned = [clean(c) for c in cols]
            # Stable de-duplication: repeated names get a "__<k>" suffix.
            seen, uniq = {}, []
            for s in cleaned:
                if s in seen:
                    uniq.append(f"{s}__{seen[s]}")
                    seen[s] += 1
                else:
                    uniq.append(s)
                    seen[s] = 1
            return uniq

        def _read_omic(path):
            # Tolerant read (separator sniffed by the python engine).
            df0 = pd.read_csv(path, engine="python", sep=None,
                              na_values=na_tokens, keep_default_na=True)
            first = df0.columns[0]
            cond_unnamed = first in ("", " ", "Unnamed: 0", "Unnamed: 0.1", "Unnamed: 1")
            cond_slide = df0[first].astype(str).str.startswith("slide").any()

            if cond_unnamed or cond_slide:
                # The first column holds the sample IDs.
                df = df0.set_index(first)
            else:
                df = pd.read_csv(path, engine="python", sep=None,
                                 na_values=na_tokens, keep_default_na=True,
                                 index_col=0)

            # Clean the index and keep the first row of each duplicated ID.
            df.index = _clean_ids(df.index)
            df = df[~df.index.duplicated(keep="first")]

            # Sanitize column names.
            df.columns = _sanitize_cols(df.columns)

            # Keep numeric columns only.
            df = df.select_dtypes(include=[np.number])

            # Infinite values are treated as missing.
            df = df.replace([np.inf, -np.inf], np.nan)

            # Drop columns with too many missing values.
            if len(df) > 0:
                frac_nan = df.isna().mean(axis=0)
                keep = frac_nan[frac_nan <= miss_thresh].index
                df = df[keep]

            # Drop columns that are entirely missing.
            df = df.dropna(axis=1, how="all")

            # Drop columns whose variance is not strictly positive (NaN ignored).
            if df.shape[1] > 0:
                var = df.var(axis=0, ddof=1, numeric_only=True)
                df = df.loc[:, var > 0]

            return df

        # -------- label sources --------
        # Control vs EM
        y1 = pd.read_csv(join(data_path, "ControlVsEM.csv"),
                         na_values=na_tokens, keep_default_na=True)
        if "sampleID" not in y1.columns or "EM" not in y1.columns:
            raise ValueError("ControlVsEM.csv doit contenir les colonnes 'sampleID' et 'EM'.")
        y1 = _encode_labels(y1.set_index("sampleID")["EM"], {"Control": 0, "EM": 1})

        # Stage I&II vs III&IV
        y2 = pd.read_csv(join(data_path, "StageI&IIVsStageIII&IV.csv"),
                         na_values=na_tokens, keep_default_na=True)
        if "sampleID" in y2.columns:
            y2 = y2.set_index("sampleID")
        if "Stage" not in y2.columns:
            raise ValueError("StageI&IIVsStageIII&IV.csv doit contenir la colonne 'Stage'.")
        y2 = _encode_labels(y2["Stage"], {"I&II": 0, "III&IV": 1})

        # -------- build y according to label_mode --------
        mode = label_mode.lower()
        if mode == "control_vs_em":
            # Plain binary Control(0) vs EM(1)
            y = y1.dropna().astype(int)

        elif mode in ("control_vs_mild", "control_vs_severe"):
            # Control(0) vs one disease stage(1): I&II for "mild", III&IV for "severe".
            stage_code = 0 if mode == "control_vs_mild" else 1
            control_ids = _ids_where(y1, 0)
            case_ids = _ids_where(y2, stage_code)
            idx = sorted(set(control_ids).union(set(case_ids)))
            y = pd.Series(pd.NA, index=idx, dtype="Int64")
            y.loc[control_ids] = 0
            y.loc[case_ids] = 1
            y = y.dropna().astype(int)

        elif mode == "mild_vs_severe":
            # I&II(0) vs III&IV(1)
            y = y2.dropna().astype(int)

        else:
            raise ValueError(
                "label_mode inconnu. Utilise l'un de : "
                "'control_vs_em', 'control_vs_mild', 'control_vs_severe', 'mild_vs_severe'."
            )

        # -------- omics --------
        file_map = {
            "CellDensities": "EMIMCmdiop_celldensities.csv",
            "Functional":    "EMIMCmdiop_functional.csv",
            "Metavariables": "EMIMCmdiop_metavariables.csv",
            "Neighborhood":  "EMIMCmdiop_neighborhood.csv",
        }
        X = {k: _read_omic(join(data_path, fname)) for k, fname in file_map.items()}

        # -------- align sample IDs --------
        common = pd.Index(y.index)
        for df in X.values():
            common = common.intersection(df.index)
        if common.empty:
            raise ValueError("Aucun échantillon commun entre omiques et y.")
        y = y.loc[common]
        X = {k: df.loc[common] for k, df in X.items()}

        # -------- groups derived from the sample ID --------
        def _extract_group(e: str):
            # 4th token of the ID split on underscores/whitespace, else the last token.
            parts = re.split(r'[_\s]+', e)
            return parts[3] if len(parts) > 3 else parts[-1]

        groups = pd.Series({_id: _extract_group(_id) for _id in y.index})

        return X, None, y, None, groups, "binary"


    # Load the MAIGANE cohort with the Control vs stage III&IV labelling.
    data_dict, _, y, _, groups, task_type = load_maigane("./data/Maigane data", label_mode="control_vs_severe")

    splitter = GroupShuffleSplit(n_splits=100, test_size=0.2, random_state=42)

    # 3.3 Initial list of STABL models
    stabl_models = ["STABL Lasso", "STABL XGBoost", "STABL ALasso"]
    #stabl_models = ["STABL XGBoost"]
    base_models  = stabl_models

    xgb_param_grid = {
    "n_estimators": [170],
    "learning_rate": [0.005],
    "max_depth": [6],
    "subsample": [1],
    "colsample_bytree": [1]
    }

    rf_reg = RandomForestRegressor(n_estimators=600, max_depth=8, min_samples_split=2, min_samples_leaf=2, max_features="sqrt", ccp_alpha=1e-3, random_state=42, n_jobs=-1)
    xgb_reg = XGBRegressor(n_estimators=800, learning_rate=0.01, max_depth=3, subsample=0.6, colsample_bytree=0.5 ,random_state=42, n_jobs=-1)
    xgb_test_reg = XGBRegressor(n_jobs=-1)
    linreg  = LinearRegression(fit_intercept=True, copy_X=False, positive=True, n_jobs=-1)
    rf_cls = RandomForestClassifier(n_estimators=600, max_depth=8, min_samples_split=2, min_samples_leaf=2, max_features="sqrt", ccp_alpha=1e-3, random_state=42, n_jobs=-1)
    xgb_cls = XGBClassifier(n_estimators=800, learning_rate=0.01, max_depth=3, subsample=0.6, colsample_bytree=0.5 ,random_state=42, n_jobs=-1)
    Logit = LogisticRegression(penalty=None, class_weight="balanced", max_iter=int(1e6), random_state=42)
    xgb_test =XGBClassifier(n_estimators=200, importance_type="gain", n_jobs=-1)
    xgb_test2= XGBClassifier(n_estimators=300, max_depth=4, learning_rate=0.05, subsample=0.8, colsample_bytree=0.8, reg_alpha=0.0, reg_lambda=1.0, eval_metric="auc", n_jobs=-1, random_state=42)
    #xgb_grid = GridSearchCV(estimator=xgb_cls, param_grid=xgb_param_grid, cv=inner_cv, scoring="roc_auc", n_jobs=-1, verbose=2, refit=True)

    # 3.4 Automatic generation of the pairwise Union/Intersect combos
    # (only r=2 is enabled), with a per-combo summary of the feature counts.
    import ast  # used to parse the saved feature lists without executing them

    _fold_sets_cache = {}

    def _fold_feature_sets(model_name):
        """Return one set of selected features per fold for ``model_name`` (cached)."""
        if model_name not in _fold_sets_cache:
            table = pd.read_csv(
                Path(FOLD_FEATS, "Training CV", f"Selected Features {model_name}.csv"),
                index_col=0,
            )
            per_fold = []
            for cell in table["Fold selected features"]:
                # Each row holds the textual form of the fold's feature list;
                # an empty (NaN) cell counts as "no feature selected".
                parsed = ast.literal_eval(cell) if isinstance(cell, str) else cell
                per_fold.append(set(parsed) if isinstance(parsed, (list, set, tuple)) else set())
            _fold_sets_cache[model_name] = per_fold
        return _fold_sets_cache[model_name]

    sep = " & "
    combos = []
    for r in [2]:
        for group in itertools.combinations(stabl_models, r):
            new_combos = [f"Union {sep.join(group)}", f"Intersect {sep.join(group)}"]
            combos += new_combos
            # Report only the two combos just created, so earlier ones are not
            # recomputed and printed again on every iteration.
            per_model = [_fold_feature_sets(b) for b in group]
            for combo in new_combos:
                merge = set.union if combo.startswith("Union ") else set.intersection
                fold_counts = [len(merge(*fold_sets)) for fold_sets in zip(*per_model)]
                print(f"{combo}: Moyenne={np.mean(fold_counts):.1f}, Médiane={np.median(fold_counts):.1f}")
            print(len(combos), f"→ {combos[-2]} / {combos[-1]}")

    models = base_models + combos

    # 3.5 Estimators handed to the CV routine
    estimators = {
        "Logit" : Logit,
        "random_forest": rf_cls,
        "xgboost":       xgb_cls,
        "stabl_alasso":  stabl_alasso,
        "stabl_en":      stabl_en,
        "linreg":      linreg,

    }

    # 3.6 Run
    # NOTE: for this choice cv_on_existing_feats fits its own built-in
    # LogisticRegression(max_iter=1000); it does not read the "Logit" entry of
    # `estimators` (only "random_forest" and "xgboost" are looked up there).
    preds = cv_on_existing_feats(
        data_dict       = data_dict,
        y               = y,
        outer_splitter  = splitter,
        estimators      = estimators,
        task_type       = task_type,
        model_chosen    = "Logit",  # final model fitted on the selected features
        models          = models,
        fold_feats_path = FOLD_FEATS,
        save_path       = SAVE_ROOT,
        outer_groups    = groups,
    )

    print("→ Terminé, prédictions agrégées renvoyées.")


Union STABL Lasso & STABL XGBoost: Moyenne=36.7, Médiane=18.0
Intersect STABL Lasso & STABL XGBoost: Moyenne=0.4, Médiane=0.0
2 → Union STABL Lasso & STABL XGBoost / Intersect STABL Lasso & STABL XGBoost
Union STABL Lasso & STABL XGBoost: Moyenne=36.7, Médiane=18.0
Intersect STABL Lasso & STABL XGBoost: Moyenne=0.4, Médiane=0.0
Union STABL Lasso & STABL ALasso: Moyenne=55.9, Médiane=30.0
Intersect STABL Lasso & STABL ALasso: Moyenne=21.9, Médiane=10.0
4 → Union STABL Lasso & STABL ALasso / Intersect STABL Lasso & STABL ALasso
Union STABL Lasso & STABL XGBoost: Moyenne=36.7, Médiane=18.0
Intersect STABL Lasso & STABL XGBoost: Moyenne=0.4, Médiane=0.0
Union STABL Lasso & STABL ALasso: Moyenne=55.9, Médiane=30.0
Intersect STABL Lasso & STABL ALasso: Moyenne=21.9, Médiane=10.0
Union STABL XGBoost & STABL ALasso: Moyenne=41.5, Médiane=19.0
Intersect STABL XGBoost & STABL ALasso: Moyenne=0.4, Médiane=0.0
6 → Union STABL XGBoost & STABL ALasso / Intersect STABL XGBoost & STABL ALasso


Outer CV: 100%|██████████| 100/100 [00:03<00:00, 26.64it/s]


STABL Lasso slide10_P01endoIMC_mgd_UC724_1_4    0.255490
slide10_P01endoIMC_mgd_UC724_2_5    0.038612
slide10_P01endoIMC_mgd_UC724_3_6    0.139699
slide10_P01endoIMC_mgd_UC738_1_1    0.390969
slide10_P01endoIMC_mgd_UC738_2_2    0.323937
                                      ...   
slide8_P01endoIMC_mgd_UC272_2_2     0.259649
slide8_P01endoIMC_mgd_UC272_3_3     0.272123
slide9_P01endoIMC_mgd_NCB024_1_1    0.518708
slide9_P01endoIMC_mgd_NCB024_2_2    0.727690
slide9_P01endoIMC_mgd_NCB024_3_3    0.716713
Length: 188, dtype: float64
STABL XGBoost slide10_P01endoIMC_mgd_UC724_1_4    0.5
slide10_P01endoIMC_mgd_UC724_2_5    0.5
slide10_P01endoIMC_mgd_UC724_3_6    0.5
slide10_P01endoIMC_mgd_UC738_1_1    0.5
slide10_P01endoIMC_mgd_UC738_2_2    0.5
                                   ... 
slide8_P01endoIMC_mgd_UC272_2_2     0.5
slide8_P01endoIMC_mgd_UC272_3_3     0.5
slide9_P01endoIMC_mgd_NCB024_1_1    0.5
slide9_P01endoIMC_mgd_NCB024_2_2    0.5
slide9_P01endoIMC_mgd_NCB024_3_3    0.5
Length: 188

In [None]:
# %% [markdown]
# --- CONFIG ---
# Fill in these 3 paths plus the options. Leave everything else at its default.

DATA_ROOT       = "/Users/noeamar/Documents/Stanford/data/Biobank SSI"  # dataset folder
FOLD_FEATS_PATH = "/Users/noeamar/Documents/Stanford/results_cls_24c_CyTOF_xgboost_knockoff_shap_CLS"  # existing run holding the "Selected Features *.csv" files
SAVE_ROOT       = "results_CI_SSI_vf_xgb_shap_existing_notebook"  # output folder for THIS run

DATASET         = "ssi"               # "ssi" (default) or "onset"; the loading cell treats any value other than "ssi" as onset of labor
MODEL_CHOSEN    = "xgboost"          # for the 'existing-fit' comparison: "xgboost" | "random_forest" | "logit"
N_SPLITS        = 100                 # n_splits of the GroupShuffleSplit built below
TEST_SIZE       = 0.2                 # test_size of the GroupShuffleSplit built below
RANDOM_STATE    = 42                  # random_state of the GroupShuffleSplit built below

# Optional: add Union/Intersect combos between STABL models when present in FOLD_FEATS_PATH
ADD_COMBOS      = False  # True to generate the pairwise Union/Intersect combos


# %%
import os
from pathlib import Path
import itertools
import numpy as np
import pandas as pd

from sklearn.model_selection import GroupShuffleSplit
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from xgboost import XGBClassifier, XGBRegressor

# Dataset loaders & STABL pipeline
from stabl.data import load_ssi, load_onset_of_labor
from stabl.multi_omic_pipelines import multi_omic_stabl_cv_noe_test  # <- ta fonction modifiée

# %%
# Normalise the configured locations to Path objects and make sure the output
# folder exists before anything is written into it.
DATA_ROOT = Path(DATA_ROOT)
FOLD_FEATS_PATH = Path(FOLD_FEATS_PATH)
SAVE_ROOT = Path(SAVE_ROOT)
SAVE_ROOT.mkdir(parents=True, exist_ok=True)

# Load the requested dataset. An unknown DATASET value is rejected explicitly
# instead of silently falling through to the onset-of-labor loader.
if DATASET == "ssi":
    train_dict, _, y, _, _, task_type = load_ssi(str(DATA_ROOT))
    y = y.astype(int)
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if not set(y.unique()) <= {0, 1}:
        raise ValueError(f"Labels non binaires: {sorted(y.unique())}")
elif DATASET == "onset":
    train_dict, valid_data_dict, y, y_valid, patients_id, task_type = load_onset_of_labor(str(DATA_ROOT))
else:
    raise ValueError(f"Unknown DATASET {DATASET!r}: expected 'ssi' or 'onset'.")

# Common to both datasets: merge the CyTOF and Proteomics tables column-wise,
# restricted to the samples present in y and in both tables.
cyto = train_dict["CyTOF"]
prot = train_dict["Proteomics"].copy()
prot.columns = [f"PRO_{c}" for c in prot.columns]  # prefix avoids column-name collisions
common = y.index
for df in (cyto, prot):
    common = common.intersection(df.index)
y = y.loc[common]
X = pd.concat([cyto.loc[common], prot.loc[common]], axis=1)

print(f"X shape: {X.shape} | y: {y.shape} | task_type: {task_type}")
display(X.head(2))
display(y.value_counts())


# %%
# One group per sample: the part of the index label before the first underscore,
# used so that GroupShuffleSplit never splits a group across train and test.
sample_ids = X.index.to_series()
groups = sample_ids.str.split("_").str[0]
splitter = GroupShuffleSplit(
    n_splits=N_SPLITS,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

# All features are passed to the pipeline as a single "omic" table.
data_dict = {"allstim": X}

# Candidate STABL models; only those whose feature CSV actually exists are kept.
candidate_stabl = ["STABL Lasso", "STABL ElasticNet", "STABL ALasso", "STABL XGBoost"]
available = []
for m in candidate_stabl:
    csv_path = FOLD_FEATS_PATH / "Training CV" / f"Selected Features {m}.csv"
    if not csv_path.exists():
        print(f"[WARN] Missing features file for {m}: {csv_path}")
        continue
    available.append(m)

# Nothing to refit on if no feature file was found at all.
if len(available) == 0:
    raise FileNotFoundError(
        f"Aucun CSV 'Selected Features {{model}}.csv' trouvé dans {FOLD_FEATS_PATH/'Training CV'}.\n"
        "Vérifie les noms exacts de modèles et le chemin 'fold_feats_path'."
    )

# Optional: pairwise Union/Intersect entries built from the available models.
models = list(available)
if ADD_COMBOS and len(available) >= 2:
    sep = " & "
    combos = [
        label
        for group in itertools.combinations(available, 2)
        for label in (f"Union {sep.join(group)}", f"Intersect {sep.join(group)}")
    ]
    models += combos

print("Models utilisés :", models)

# %%
# Classifiers used for the refit (both for the "existing-fit" comparison and
# for the fit performed by this run).
rf_cls = RandomForestClassifier(
    n_estimators=600,
    max_depth=8,
    min_samples_split=2,
    min_samples_leaf=2,
    max_features="sqrt",
    ccp_alpha=1e-3,
    random_state=RANDOM_STATE,
    n_jobs=-1,
)
xgb_cls = XGBClassifier(
    n_estimators=800,
    learning_rate=0.01,
    max_depth=3,
    subsample=0.6,
    colsample_bytree=0.5,
    random_state=RANDOM_STATE,
    n_jobs=-1,
    eval_metric="logloss",
    use_label_encoder=False,
)

# Estimator dictionary. Per the original author's note, only the rf/xgb entries
# are needed when feature selection is skipped; the long names are the keys
# expected by cv_on_existing_feats (NOTE(review): confirm against that function).
estimators = {
    "rf": rf_cls,
    "xgb": xgb_cls,
    "random_forest": rf_cls,
    "xgboost": xgb_cls,
    # Placeholders set to None: selection estimators are not used in this run.
    **dict.fromkeys(
        ("lasso", "en", "alasso", "stabl_lasso", "stabl_alasso", "stabl_en", "stabl_xgb")
    ),
}

print("Estimators ready.")

# %%
print("[INFO] Starting CV with skip-selection (using existing STABL features)...")

# Arguments are collected first so the call itself stays on one line.
cv_kwargs = {
    "data_dict": data_dict,
    "y": y,
    "outer_splitter": splitter,
    "estimators": estimators,
    "task_type": task_type,
    "model_chosen": MODEL_CHOSEN,
    "models": models,
    "save_path": str(SAVE_ROOT),
    "outer_groups": groups,
    "early_fusion": False,
    "late_fusion": False,
    "n_iter_lf": 100000,
    # Folder holding the "Selected Features {model}.csv" files of the earlier run.
    "fold_feats_path": str(FOLD_FEATS_PATH),
    # Per the original author's note: skip feature selection and only fit.
    "use_existing_feats_only": True,
}
preds = multi_omic_stabl_cv_noe_test(**cv_kwargs)

print("✓ Terminé.")


  from tqdm.autonotebook import tqdm


X shape: (91, 1846) | y: (91,) | task_type: binary


Unnamed: 0_level_0,unstim_Baso_149Sm_CREB,unstim_Baso_150Nd_STAT5,unstim_Baso_151Eu_p38,unstim_Baso_153Eu_STAT1,unstim_Baso_154Sm_STAT3,unstim_Baso_155Gd_S6,unstim_Baso_159Tb_MAPKAPK2,unstim_Baso_164Dy_IkB,unstim_Baso_166Er_NFkB,unstim_Baso_167Er_ERK,...,PRO_SAT1,PRO_NFKB1,PRO_CDKN2B,PRO_RAP2A,PRO_XRCC4,PRO_ARID1A,PRO_EGLN1,PRO_TOPBP1,PRO_SLC22A16,PRO_IRF6
sampleID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
BBCR0002,0.632924,0.344065,0.028052,0,0.0,1.143817,0.150393,0.80114,0.68125,0.031901,...,9.571184,11.314243,10.127736,9.873905,8.62315,9.615262,10.399064,8.828454,12.143543,11.980247
BBCR0003,0.921823,0.677804,0.041205,0,0.000672,1.614141,0.138246,0.966064,0.562168,0.0,...,9.349171,13.429106,9.698531,9.763378,8.799929,9.465362,10.176298,8.386294,12.56198,11.355186


model1b
0    75
1    16
Name: count, dtype: int64

Models utilisés : ['STABL Lasso', 'STABL ElasticNet', 'STABL ALasso', 'STABL XGBoost']
Estimators ready.
[INFO] Starting CV with skip-selection (using existing STABL features)...
72 train samples, 19 test samples:   0%|          | 0/100 [00:00<?, ?it/s]~~~~~~~~~~~~~~~~~~~
This fold: 28 features selected for STABL Lasso
This fold: 108 features selected for STABL ElasticNet
This fold: 0 features selected for STABL ALasso
This fold: 10 features selected for STABL XGBoost
~~~~~~~~~~~~~~~~~~~

72 train samples, 19 test samples:   1%|          | 1/100 [00:00<00:18,  5.38it/s]~~~~~~~~~~~~~~~~~~~
This fold: 37 features selected for STABL Lasso
This fold: 51 features selected for STABL ElasticNet
This fold: 0 features selected for STABL ALasso
This fold: 14 features selected for STABL XGBoost
~~~~~~~~~~~~~~~~~~~

72 train samples, 19 test samples:   2%|▏         | 2/100 [00:00<00:16,  6.12it/s]~~~~~~~~~~~~~~~~~~~
This fold: 23 features selected for STABL Lasso
This fold: 109 features selected f

Outer CV: 100%|██████████| 100/100 [01:22<00:00,  1.21it/s]


STABL Lasso sampleID
BBCR0002      0.156201
BBCR0003      0.070953
BBCR0005-1    0.036685
BBCR0006      0.037055
BBCR0007      0.054711
                ...   
BBCR0319      0.213661
BBCR0327      0.019168
BBCR0328      0.034958
BBCR0329      0.260888
BBCR0338      0.275634
Length: 91, dtype: float32
STABL ElasticNet sampleID
BBCR0002      0.184156
BBCR0003      0.108035
BBCR0005-1    0.036078
BBCR0006      0.052397
BBCR0007      0.087836
                ...   
BBCR0319      0.115432
BBCR0327      0.027934
BBCR0328      0.052606
BBCR0329      0.187648
BBCR0338      0.112948
Length: 91, dtype: float32
STABL ALasso sampleID
BBCR0002      0.5
BBCR0003      0.5
BBCR0005-1    0.5
BBCR0006      0.5
BBCR0007      0.5
             ... 
BBCR0319      0.5
BBCR0327      0.5
BBCR0328      0.5
BBCR0329      0.5
BBCR0338      0.5
Length: 91, dtype: float64
STABL XGBoost sampleID
BBCR0002      0.082850
BBCR0003      0.049179
BBCR0005-1    0.043061
BBCR0006      0.092490
BBCR0007      0.041539
        

In [2]:
# %%
# Display the training-CV score table from the output folder, or warn if it is
# missing (presumably it is written by the CV run above — verify).
scores_csv = SAVE_ROOT / "Summary" / "Scores training CV.csv"
if not scores_csv.exists():
    print(f"[WARN] Fichier scores introuvable: {scores_csv}")
else:
    scores_df = pd.read_csv(scores_csv, index_col=0)
    display(scores_df)


Unnamed: 0,ROC AUC,Average Precision,N features,CVS
STABL Lasso,"0.863 [0.769, 0.937]","0.555 [0.331, 0.788]","38.000 [30.000, 48.000]","0.250 [0.203, 0.302]"
STABL ElasticNet,"0.732 [0.603, 0.844]","0.401 [0.220, 0.639]","165.000 [87.250, 252.750]","0.137 [0.070, 0.251]"
STABL ALasso,"0.500 [0.500, 0.500]","0.176 [0.110, 0.253]","0.000 [0.000, 0.000]","0.000 [0.000, 0.000]"
STABL XGBoost,"0.932 [0.852, 0.988]","0.847 [0.684, 0.957]","8.000 [6.000, 11.000]","0.125 [0.067, 0.200]"
STABL Lasso (existing-fit),"0.884 [0.804, 0.953]","0.597 [0.363, 0.821]","38.000 [30.000, 48.000]","0.033 [0.027, 0.039]"
STABL ElasticNet (existing-fit),"0.769 [0.662, 0.865]","0.335 [0.200, 0.556]","165.000 [87.250, 252.750]","0.006 [0.005, 0.009]"
STABL ALasso (existing-fit),"0.500 [0.500, 0.500]","0.176 [0.099, 0.253]","0.000 [0.000, 0.000]","1.000 [1.000, 1.000]"
STABL XGBoost (existing-fit),"0.954 [0.893, 0.995]","0.881 [0.720, 0.979]","8.000 [6.000, 11.000]","0.108 [0.090, 0.130]"


In [11]:
import glob, ast, statistics, os
import pandas as pd
from pathlib import Path

# --- 1) folder containing "Training CV/Selected Features *.csv"
ROOT = Path(
    "/Users/noeamar/Documents/Stanford/Fitting results/SSI/SSI normalized RP_0.01_100 + Logit no intersect"
)  # <- adjust to your setup
# Glob pattern (as a Path) matching every selected-features CSV of that run.
CSV_PATTERN = ROOT.joinpath("Training CV", "Selected Features *.csv")

def to_list(cell):
    """Convert a cell into a list of features, or return None.

    A list, tuple or set is copied into a new list. A string is parsed with
    ``ast.literal_eval`` and the parsed value is passed to ``list()``; if
    either step raises, None is returned. Any other type yields None.
    """
    if isinstance(cell, (list, tuple, set)):
        return list(cell)
    if not isinstance(cell, str):
        return None
    try:
        parsed = ast.literal_eval(cell)
        return list(parsed)
    except Exception:
        return None

def count_feats(cell):
    """Return the number of features described by a cell, or None.

    A numeric cell is taken to already be a count and is truncated to int.
    A non-finite number (NaN, which is how pandas represents an empty CSV
    cell, or +/-inf) cannot be converted and yields None instead of raising.
    Any other cell is parsed with ``to_list`` and its length is returned;
    None is returned when the cell cannot be parsed.
    """
    if isinstance(cell, (int, float)):
        try:
            return int(cell)
        except (ValueError, OverflowError):
            # int(nan) raises ValueError, int(inf) raises OverflowError.
            return None
    lst = to_list(cell)
    return len(lst) if lst is not None else None

# One summary row per "Selected Features *.csv" file matched by CSV_PATTERN.
summary = []

for csv_file in glob.glob(str(CSV_PATTERN)):
    df = pd.read_csv(csv_file, index_col=0)

    # 1) Per-fold feature counts: use a dedicated count column when present,
    #    otherwise recount from the first column (expected to hold the lists).
    if "Fold nb of features" in df.columns:
        counts = df["Fold nb of features"]
    elif "Fold #features" in df.columns:
        counts = df["Fold #features"]
    else:
        counts = df.iloc[:, 0].apply(count_feats)

    # Drop missing folds BEFORE the integer cast: casting NaN to int raises.
    counts = counts.dropna().astype(int)

    # 2) Union of the selected features across all folds.
    if "Fold selected features" in df.columns:
        fold_lists = df["Fold selected features"].apply(to_list).dropna()
        all_feats = set().union(*fold_lists)
    else:
        # Without the list column the union cannot be computed.
        all_feats = set()

    stats = {
        "file":           os.path.basename(csv_file),
        "folds":          len(counts),
        "mean":           counts.mean(),
        "median":         counts.median(),
        "std":            counts.std(ddof=1),
        "min":            counts.min(),
        "max":            counts.max(),
        "25%":            counts.quantile(0.25),
        "75%":            counts.quantile(0.75),
        "unique_feats":   len(all_feats) if all_feats else "n/a",
    }
    summary.append(stats)

# With no matching file the DataFrame would have no "file" column and
# set_index would fail with an opaque KeyError; report the real cause instead.
if not summary:
    raise FileNotFoundError(f"No file matches the pattern: {CSV_PATTERN}")

summary_df = (
    pd.DataFrame(summary)
      .set_index("file")
      .round(2)
      .sort_index()
)
print(summary_df)

# (optional) save the table next to the run
summary_df.to_csv(ROOT / "summary_feature_counts.csv")


                                                    folds  mean  median   std  \
file                                                                            
Selected Features Intersect STABL ALasso & STAB...    100  3.86     4.0  1.84   
Selected Features Intersect STABL Lasso & STABL...    100  4.28     4.0  1.88   
Selected Features Intersect STABL Lasso & STABL...    100  4.10     4.0  2.01   
Selected Features STABL ALasso.csv                    100  4.61     5.0  1.90   
Selected Features STABL ElasticNet.csv                100  7.76     5.0  7.04   
Selected Features STABL Lasso.csv                     100  5.53     5.5  2.53   
Selected Features Union STABL ALasso & STABL El...    100  8.51     6.0  6.84   
Selected Features Union STABL Lasso & STABL ALa...    100  5.86     6.0  2.52   
Selected Features Union STABL Lasso & STABL Ela...    100  9.19     7.5  6.82   

                                                    min  max   25%    75%  \
file                           