# Diagnostic MLP – 430 Features

Version 0.2  
Auteur : Yoan  
Date : 2025‑06‑30

Objectif : prédire automatiquement si un patient est malade (1) ou sain (0) à l’aide de 430 caractéristiques numériques extraites de données d’IRM.

Le notebook suit un pipeline complet : ingestion, feature engineering, split, hyperparameter tuning, entraînement final, évaluation et explicabilité.


In [None]:
# --- Configuration générale ---
from pathlib import Path
from robust_evaluation_tools.robust_MLP import PatientMLP, MODEL_DIR
# Directory holding the per-disease compilation CSVs.
DATA_DIR  = Path("DONNES_F/COMPILATIONS_AUG_3/")      # <-- adjust if needed
# Cohort selector: "ALL" makes the loading cell read the merged training file,
# any other value selects the matching per-disease compilation in DATA_DIR.
disease = "ALL"
# Name used for the TensorBoard run directory and the saved artifacts.
RUN_NAME  = f"mlp2_{disease}"
# Make sure the model output directory exists before anything is written to it.
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# Seed passed to set_seed() and to both train_test_split calls.
SEED = 41


In [None]:
# INSTALLATION (décommente si nécessaire)
# %pip install -q pandas numpy scikit-learn torch optuna shap tensorboard joblib tqdm


In [None]:
import numpy as np
import pandas as pd
import torch, torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (classification_report, roc_auc_score, f1_score,
                             confusion_matrix, ConfusionMatrixDisplay, RocCurveDisplay,
                             PrecisionRecallDisplay)
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import joblib, random, os, json, optuna
from tqdm.auto import tqdm

# ----- Helpers -----
def set_seed(seed: int = 42):
    """Seed NumPy, the stdlib ``random`` module and PyTorch (CPU and CUDA).

    Args:
        seed: Value handed to every random number generator.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
# Seed every random number generator once, using the notebook-wide SEED.
set_seed(SEED)

# Device that models and batches are moved to in the cells below.
device = "cpu"
print("Device:", device)

def show_class_balance(y):
    """Print how many samples fall in each distinct class of *y*.

    One line is printed per distinct value, in the sorted order returned by
    ``np.unique``. Nothing is printed for an empty input.
    """
    classes, occurrences = np.unique(y, return_counts=True)
    for position, cls in enumerate(classes):
        print(f"Classe {int(cls)} : {occurrences[position]}")

def plot_curves(train, val, ylabel="Loss"):
    """Plot the per-epoch training and validation curves on one figure.

    Args:
        train: Sequence of training values, one per epoch.
        val: Sequence of validation values, plotted against the same epochs.
        ylabel: Label of the y axis, also used in the figure title.
    """
    plt.figure(figsize=(6, 4))
    # Epochs are numbered from 1 and sized on the training curve.
    x_axis = range(1, len(train) + 1)
    for series, tag in ((train, "train"), (val, "val")):
        plt.plot(x_axis, series, label=tag)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(f"Courbe {ylabel}")
    plt.legend()
    plt.grid(True)
    plt.show()


In [None]:


# When disease == "ALL", train on the pre-merged multi-disease table;
# otherwise load the compilation of the selected disease only.
if disease == "ALL":
    # Previous approach, kept for reference: merge every per-disease
    # compilation while skipping SIDs that were already seen.
    # sids_vus = set()
    # df_total = pd.DataFrame()
    # for maladie in ["AD", "ADHD", "BIP", "MCI", "SCHZ", "TBI"]:
    #     df_raw = pd.read_csv(DATA_DIR / f"{maladie}_combination_all_metrics_CamCAN.csv.gz")
    #
    #     # Drop the SIDs that were already seen
    #     df_filtré = df_raw[~df_raw["sid"].isin(sids_vus)]
    #
    #     # Register the new SIDs
    #     sids_vus.update(df_filtré["sid"].unique())
    #
    #     # Append the filtered DataFrame
    #     df_total = pd.concat([df_total, df_filtré], ignore_index=True)
    # df_raw = df_total
    df_raw = pd.read_csv("DONNES_MLP/train_data_all_aug5.csv")
    # Drop healthy controls (disease == 'HC') that do not come from CamCAN.
    # The filtered frame has to be assigned back: indexing alone returns a
    # new DataFrame and leaves df_raw untouched.
    df_raw = df_raw[~((df_raw['disease'] == 'HC') & (df_raw['old_site'] != 'CamCAN'))]
else:
    df_raw = pd.read_csv(DATA_DIR / f"{disease}_combination_all_metrics_CamCAN.csv.gz")
print("Raw shape:", df_raw.shape)
display(df_raw.head())


In [None]:
# Minimal cleanup: drop the rows of the two ventricle bundles and work on a copy.
df_raw = df_raw[~df_raw['bundle'].isin(['left_ventricle', 'right_ventricle'])].copy()
print("Sans ventricules :", df_raw.shape)


In [None]:
# ----- 3. Feature engineering -----
def compute_zscore(df, value_col="mean_no_cov"):
    """Add a ``zscore`` column standardising *value_col* per ``metric_bundle``.

    The mean and (sample) standard deviation are computed over all rows of
    *df* sharing the same ``metric_bundle``.

    Args:
        df: Long-format table with a ``metric_bundle`` column and *value_col*.
        value_col: Name of the column to standardise.

    Returns:
        A new DataFrame with the columns of *df* plus ``zscore``. The input
        frame is not modified.

    NOTE(review): the statistics use every row passed in, so calling this
    before the train/val/test split lets held-out subjects influence the
    normalisation -- confirm this is acceptable.
    """
    stats = (df.groupby("metric_bundle")[value_col]
               .agg(['mean', 'std'])
               .rename(columns={'mean': 'global_mean', 'std': 'global_std'}))
    # A bundle with a single row has an undefined (NaN) sample std and a
    # constant bundle has std == 0. In both cases every value equals the
    # mean, so use a tiny positive std to get a z-score of 0 instead of
    # NaN or a division by zero.
    stats['global_std'] = stats['global_std'].fillna(0).replace(0, 1e-6)
    df = df.merge(stats, on="metric_bundle", how="left")
    df["zscore"] = (df[value_col] - df["global_mean"]) / df["global_std"]
    return df.drop(columns=["global_mean", "global_std"])

def build_feature_matrix(df, value_col="zscore", bundle_col="metric_bundle", healthy_tag="HC"):
    """Pivot the long table into one row per subject plus a binary label.

    Args:
        df: Long-format table with ``sid``, ``disease``, *bundle_col* and
            *value_col* columns.
        value_col: Column whose values fill the feature cells.
        bundle_col: Column whose distinct values become the feature columns.
        healthy_tag: ``disease`` value identifying healthy subjects.

    Returns:
        DataFrame with a ``sid`` column, one column per *bundle_col* value and
        a ``label`` column (0 when the subject's first ``disease`` entry equals
        *healthy_tag*, 1 otherwise).

    Raises:
        ValueError: propagated from ``DataFrame.pivot`` when a
            (sid, bundle) pair appears more than once.
    """
    wide = df.pivot(index="sid", columns=bundle_col, values=value_col)
    # The first disease entry of each subject decides its class.
    first_diagnosis = df.groupby("sid")["disease"].first()
    is_patient = (first_diagnosis != healthy_tag).astype(int)
    wide = wide.assign(label=is_patient)
    return wide.reset_index()

# Standardise the raw values per metric_bundle.
df_clean = compute_zscore(df_raw, value_col="mean_no_cov")
# Diagnostic: count the (sid, metric_bundle) pairs that occur more than once.
# NOTE(review): DataFrame.pivot raises on duplicated index/column pairs, so the
# build_feature_matrix call below is expected to fail when this count is not 0.
dupes = (df_clean
         .groupby(["sid", "metric_bundle"])
         .size()
         .loc[lambda s: s > 1]
         .sort_values(ascending=False))
print(f"Nombre de paires sid / metric_bundle en double : {dupes.shape[0]}")
# One row per subject, one column per metric_bundle, plus the label.
df_mat   = build_feature_matrix(df_clean, value_col="zscore")
# The subject id is an identifier, not a feature: drop it before modelling.
df_mat = df_mat.drop(columns=["sid"])
print("Matrix shape:", df_mat.shape)
display(df_mat.head())


In [None]:
# ----- 4. Split & normalization -----
# Features: every column except the label, cast to float32.
X = df_mat.drop(columns="label").values.astype(np.float32)
# Binary target, also cast to float32.
y = df_mat["label"].values.astype(np.float32)
show_class_balance(y)

# Two stratified splits with the same seed: half of the data goes to training,
# then the remaining half is divided equally between validation and test.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=SEED)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=SEED)

# Standardization is currently disabled; the features are the z-scores
# produced by compute_zscore.
# scaler = StandardScaler().fit(X_train)
# X_train = scaler.transform(X_train)
# X_val   = scaler.transform(X_val)
# X_test  = scaler.transform(X_test)

print("Train:", X_train.shape, "Val:", X_val.shape, "Test:", X_test.shape)


In [None]:
# ----- 5. DataLoader -----
class PatientDataset(Dataset):
    """In-memory dataset pairing each feature row with its label.

    Both inputs are converted to float32 tensors once, at construction.
    """

    def __init__(self, X, y):
        """Store *X* (features) and *y* (labels) as float32 tensors."""
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples (length of the feature tensor)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position *idx*."""
        features = self.X[idx]
        target = self.y[idx]
        return features, target

# Mini-batch size shared by the three loaders.
BATCH = 64
# Only the training loader shuffles; validation and test keep a fixed order.
train_dl = DataLoader(PatientDataset(X_train, y_train), batch_size=BATCH, shuffle=True)
val_dl   = DataLoader(PatientDataset(X_val,   y_val),   batch_size=BATCH)
test_dl  = DataLoader(PatientDataset(X_test,  y_test),  batch_size=BATCH)


In [None]:
# ----- 6A. Baseline LogisticRegression -----
# Reference model: fitted on the training split, scored by ROC AUC on the
# validation split so it can be compared with the MLP's validation AUC.
from sklearn.linear_model import LogisticRegression
baseline = LogisticRegression(max_iter=1000, n_jobs=-1)
baseline.fit(X_train, y_train)
# Column 1 of predict_proba is the probability of the positive class.
prob_val = baseline.predict_proba(X_val)[:,1]
auc_base = roc_auc_score(y_val, prob_val)
print(f"AUC validation LogisticRegression: {auc_base:.3f}")


In [None]:
# ----- 7. Training helpers -----
def train_epoch(model, loader, crit, opt):
    """Run one optimisation pass over *loader* and return the mean loss.

    Args:
        model: Network to train; it is switched to training mode.
        loader: Iterable of ``(features, targets)`` batches exposing ``.dataset``.
        crit: Loss callable applied as ``crit(model(features), targets)``.
        opt: Optimizer stepped once per batch.

    Returns:
        Sum of the per-batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    model.train()
    loss_sum = 0
    for features, targets in loader:
        features = features.to(device)
        targets = targets.to(device)
        opt.zero_grad()
        batch_loss = crit(model(features), targets)
        batch_loss.backward()
        opt.step()
        # Weight by the batch size so a smaller last batch counts proportionally.
        loss_sum += batch_loss.item() * features.size(0)
    n_samples = len(loader.dataset)
    return loss_sum / n_samples

@torch.no_grad()
def eval_epoch(model, loader, crit):
    """Evaluate *model* on *loader* without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable of ``(features, targets)`` batches exposing ``.dataset``.
        crit: Loss callable applied to the raw model outputs (logits).

    Returns:
        Tuple ``(mean_loss, auc, f1)``: the size-weighted mean loss, the ROC AUC
        of the sigmoid probabilities, and the F1 score at a 0.5 threshold.
    """
    model.eval()
    weighted_losses = []
    prob_chunks = []
    label_chunks = []
    for features, targets in loader:
        features = features.to(device)
        logits = model(features)
        batch_loss = crit(logits, targets.to(device)).item()
        weighted_losses.append(batch_loss * features.size(0))
        # Probabilities go back to the CPU for the scikit-learn metrics.
        prob_chunks.append(torch.sigmoid(logits).cpu())
        label_chunks.append(targets)
    y_prob = torch.cat(prob_chunks).numpy()
    y_true = torch.cat(label_chunks).numpy()
    auc = roc_auc_score(y_true, y_prob)
    hard_preds = (y_prob > 0.5).astype(int)
    f1 = f1_score(y_true, hard_preds)
    return np.sum(weighted_losses) / len(loader.dataset), auc, f1

def fit(model, train_dl, val_dl, epochs=100, lr=1e-3, wd=1e-4, patience=10, run_name="run"):
    """Train *model* with early stopping on the validation AUC.

    Uses BCEWithLogitsLoss, Adam, and a ReduceLROnPlateau scheduler driven by
    the validation loss. Losses and the validation AUC are logged to
    TensorBoard under ``{MODEL_DIR}/runs/{run_name}``.

    Args:
        model: Network to train (already on the target device).
        train_dl: Loader of training batches.
        val_dl: Loader of validation batches.
        epochs: Maximum number of epochs.
        lr: Adam learning rate.
        wd: Adam weight decay.
        patience: Number of consecutive epochs without an AUC gain of more
            than 1e-4 after which training stops.
        run_name: Sub-directory name of the TensorBoard run.

    Returns:
        Tuple ``(best_state, tr_losses, val_losses, best_auc)``. ``best_state``
        is an independent copy of the weights at the best validation AUC (or
        of the final weights if no epoch improved on the initial AUC of 0);
        *model* is left holding those weights.
    """
    import copy  # local import: only needed to snapshot the best weights

    crit = nn.BCEWithLogitsLoss()
    opt  = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    sched= torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=5, factor=0.5)
    writer = SummaryWriter(f"{MODEL_DIR}/runs/{run_name}")
    best_auc, best_state, counter = 0, None, 0
    tr_losses, val_losses = [], []
    try:
        for ep in tqdm(range(1, epochs+1)):
            tr_loss = train_epoch(model, train_dl, crit, opt)
            val_loss, val_auc, _ = eval_epoch(model, val_dl, crit)
            tr_losses.append(tr_loss); val_losses.append(val_loss)
            writer.add_scalar("Loss/train", tr_loss, ep)
            writer.add_scalar("Loss/val",   val_loss, ep)
            writer.add_scalar("AUC/val",    val_auc,  ep)
            sched.step(val_loss)
            if val_auc > best_auc + 1e-4:
                best_auc = val_auc
                # state_dict() returns references to the live parameters, which
                # keep changing as training continues; deep-copy them so the
                # snapshot really holds the weights of this epoch.
                best_state = copy.deepcopy(model.state_dict())
                counter = 0
            else:
                counter += 1
                if counter >= patience:
                    print("Early stopping.")
                    break
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        writer.close()
    if best_state is None:
        # No epoch improved on the initial AUC (or epochs == 0): keep the
        # current weights instead of calling load_state_dict(None).
        best_state = copy.deepcopy(model.state_dict())
    else:
        model.load_state_dict(best_state)
    return best_state, tr_losses, val_losses, best_auc


In [None]:
# ----- 8. Hyperparameter tuning (Optuna) -----
def objective(trial):
    """Optuna objective: train a short run and return its best validation AUC.

    Samples three hidden-layer widths, the dropout rate, the learning rate and
    the weight decay, builds a PatientMLP with them and trains it with ``fit``
    for at most 15 epochs (early-stopping patience of 5).

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        Best validation AUC reached during the run.
    """
    hidden_dim1 = trial.suggest_int("h1", 128, 512, step=64)
    hidden_dim2 = trial.suggest_int("h2", 64, 256, step=32)
    hidden_dim3 = trial.suggest_int("h3", 32, 128, step=16)
    drop        = trial.suggest_float("dropout", 0.1, 0.5)
    lr          = trial.suggest_float("lr", 1e-4, 5e-3, log=True)
    wd          = trial.suggest_float("wd", 1e-6, 1e-3, log=True)

    model = PatientMLP(hidden_dims=(hidden_dim1, hidden_dim2, hidden_dim3), drop=drop).to(device)
    # One TensorBoard run per trial: with a shared run name every trial would
    # write its curves into the same directory and they could not be told apart.
    _, _, _, best_auc = fit(model, train_dl, val_dl,
                            epochs=15, lr=lr, wd=wd,
                            patience=5, run_name=f"tune_{trial.number}")
    return best_auc

# Maximise the validation AUC returned by objective() over 30 trials.
# NOTE(review): objective() never calls trial.report(...), so the MedianPruner
# receives no intermediate values -- confirm whether pruning is actually wanted.
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=30, show_progress_bar=True)

print("Best AUC:", study.best_value)
print("Best params:", study.best_params)


In [None]:
# ----- 9. Final training with the best hyperparameters -----
# Rebuild the network from the best Optuna trial and train it for up to
# 100 epochs (early-stopping patience of 10), logging under RUN_NAME.
best = study.best_params
model_final = PatientMLP(
    hidden_dims=(best["h1"], best["h2"], best["h3"]),
    drop=best["dropout"]).to(device)
state, train_losses, val_losses, best_auc = fit(
    model_final, train_dl, val_dl,
    epochs=100, lr=best["lr"], wd=best["wd"],
    patience=10, run_name=RUN_NAME)


In [None]:
# Learning curves: training vs. validation loss per epoch of the final run.
plot_curves(train_losses, val_losses, ylabel="BCE Loss")


In [None]:
# ----- 11. Final evaluation on the test set -----
# The loss returned by eval_epoch is ignored; only AUC and F1 are reported.
_, test_auc, test_f1 = eval_epoch(model_final, test_dl, nn.BCEWithLogitsLoss())
print(f"AUC test: {test_auc:.3f} | F1 test: {test_f1:.3f}")

# Confusion matrix: recompute the test-set probabilities batch by batch and
# threshold them at 0.5 to obtain hard predictions.
model_final.eval()
preds, labels = [], []
with torch.no_grad():
    for xb, yb in test_dl:
        preds.append(torch.sigmoid(model_final(xb.to(device))).cpu())
        labels.append(yb)
preds = torch.cat(preds).numpy()
labels= torch.cat(labels).numpy()
ConfusionMatrixDisplay.from_predictions(labels, preds>0.5)
plt.show()


In [None]:
# ----- 12. Saving -----
# Persist the state dict returned by fit() and the best hyperparameters,
# both named after RUN_NAME, into MODEL_DIR.
torch.save(state, MODEL_DIR / f"{RUN_NAME}_weights.pt")
with open(MODEL_DIR / f"{RUN_NAME}_params.json", "w") as fp:
    json.dump(study.best_params, fp, indent=2)
print("Artifacts saved in", MODEL_DIR)


In [None]:
# ----- 13. Inference example -----
# Placeholder input: one row of 430 random values, not a real patient.
# NOTE(review): 430 is assumed to match the model's input width -- confirm
# against the number of feature columns actually produced above.
sample = np.random.rand(430).reshape(1, -1)
with torch.no_grad():
    prob = torch.sigmoid(model_final(torch.tensor(sample, dtype=torch.float32).to(device))).item()
print(f"Probabilité malade: {prob:.3f}")


In [None]:
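# ----- 14. Explainability (optional) -----
# NOTE: `sample_std` is not defined in the cells above; create it (the inputs
# to explain, as a float32 array) before uncommenting the lines below.
# import shap
# explainer = shap.DeepExplainer(model_final, torch.tensor(X_train[:100]).to(device))
# shap_values = explainer.shap_values(torch.tensor(sample_std).to(device))
# shap.summary_plot(shap_values, features=sample_std)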
