In [None]:


import sys, os, gzip, shutil, math
!{sys.executable} -m pip install openneuro-py nibabel scikit-learn seaborn tqdm --quiet

import numpy as np
import pandas as pd
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as T
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, f1_score, classification_report, roc_curve
from collections import Counter, defaultdict
from tqdm import tqdm
import matplotlib.pyplot as plt
import statistics as stats
import seaborn as sns
sns.set()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


#  OpenNeuro

dataset_root_path = "ds003463"
if not os.path.exists(dataset_root_path):
    os.system(f"openneuro-py download --dataset={dataset_root_path}")
else:
    print("Dossier déjà présent, téléchargement sauté.")


#prepa données

label_map = {
    'sub-m01': 1, 'sub-m02': 1, 'sub-m03': 1, 'sub-m04': 1, 'sub-m05': 1,
    'sub-m06': 1, 'sub-m07': 1, 'sub-m08': 1, 'sub-m09': 0, 'sub-m10': 0,
    'sub-m11': 0, 'sub-m12': 0, 'sub-m13': 0, 'sub-m14': 0, 'sub-r01': 1,
    'sub-r02': 1, 'sub-r03': 1, 'sub-r04': 1, 'sub-r05': 0, 'sub-r06': 0,
    'sub-r07': 0, 'sub-r08': 0
}

def is_valid_nifti(path: str) -> bool:
    try:
        img = nib.load(path, mmap=False)
        _ = img.get_fdata(dtype=np.float32)  # force la lecture
        return True
    except Exception:
        return False

def subject_from_path(p: str) -> str:
    parts = p.replace("\\", "/").split("/")
    for i, token in enumerate(parts):
        if token.startswith("sub-"):
            return token
    return "unknown"

all_files, all_labels, all_subjects = [], [], []
for sub in label_map.keys():
    sub_path = os.path.join(dataset_root_path, sub)
    if not os.path.isdir(sub_path):
        continue
    for ses in os.listdir(sub_path):
        if not ses.startswith("ses-"):
            continue
        anat = os.path.join(sub_path, ses, "anat")
        if not os.path.isdir(anat):
            continue
        for f in os.listdir(anat):
            if "MGE" in f and f.endswith(".nii.gz"):
                fullp = os.path.join(anat, f)
                if is_valid_nifti(fullp):
                    all_files.append(fullp)
                    all_subjects.append(sub)
                    all_labels.append(label_map[sub])

print(f"IRMs valides: {len(all_files)} fichiers, {len(set(all_subjects))} sujets.")


# Split stratif par SJT

print("\n### ÉTAPE 4: Split train/test stratifié par SUJET ###")

labels_np = np.array(all_labels)
subjects_np = np.array(all_subjects)
idxs = np.arange(len(all_files))

sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
train_idx, test_idx = next(sgkf.split(idxs, labels_np, groups=subjects_np))

train_files = [all_files[i] for i in train_idx]
train_labels = [all_labels[i] for i in train_idx]
train_subjects = [all_subjects[i] for i in train_idx]

test_files  = [all_files[i] for i in test_idx]
test_labels = [all_labels[i] for i in test_idx]
test_subjects = [all_subjects[i] for i in test_idx]

print(f"Train: {len(train_files)} fichiers / {len(set(train_subjects))} sujets")
print(f"Test : {len(test_files)} fichiers / {len(set(test_subjects))} sujets")


def robust_normalize(vol: np.ndarray):
    if vol.ndim == 4:
        vol = vol.mean(axis=3)
    lo, hi = np.percentile(vol, [0.5, 99.5])
    vol = np.clip(vol, lo, hi)
    mu, sd = vol.mean(), vol.std() + 1e-8
    return (vol - mu) / sd

def make_triplet(vol: np.ndarray, k: int, mode="center"):

    assert vol.ndim == 3
    z = vol.shape[2]
    mids = []
    if mode == "center":
        mid = z // 2

        half = k // 2
        start = max(0, mid - half)
        idxs = np.arange(start, min(start + k, z))
    else:

        idxs = np.linspace(z*0.35, z*0.65, num=k).astype(int)

    triplets = []
    for m in idxs:
        i1 = max(0, m-1); i2 = m; i3 = min(z-1, m+1)
        arr = np.stack([vol[:,:,i1], vol[:,:,i2], vol[:,:,i3]], axis=0).astype(np.float32)
        triplets.append(arr)
    return triplets

class MRITripletDataset(Dataset):
    def __init__(self, file_paths, labels, subjects, k_slices=5, train=True):
        self.samples = []
        self.train = train
        self.k = k_slices

        # Transf
        if train:
            self.tf = T.Compose([
                T.Resize((224,224), antialias=True),
                T.RandomRotation(degrees=10),
                T.RandomHorizontalFlip(p=0.5),
                T.Normalize(mean=[0.0,0.0,0.0], std=[1.0,1.0,1.0]),
            ])
        else:
            self.tf = T.Compose([
                T.Resize((224,224), antialias=True),
                T.Normalize(mean=[0.0,0.0,0.0], std=[1.0,1.0,1.0]),
            ])

        for p, y, s in tqdm(list(zip(file_paths, labels, subjects)),
                            total=len(file_paths),
                            desc=("Génération triplets [TRAIN]" if train else "Génération triplets [TEST]")):
            vol = nib.load(p, mmap=False).get_fdata(dtype=np.float32)
            vol = robust_normalize(vol)
            trips = make_triplet(vol, k=self.k, mode="center")
            for arr in trips:
                x = torch.from_numpy(arr)
                x = self.tf(x)
                self.samples.append((x, int(y), s, p))

    def __len__(self): return len(self.samples)
    def __getitem__(self, i):
        x, y, subj, path = self.samples[i]
        return x, torch.tensor(y, dtype=torch.long), subj, path

k_slices = 5
train_ds = MRITripletDataset(train_files, train_labels, train_subjects, k_slices=k_slices, train=True)
test_ds  = MRITripletDataset(test_files,  test_labels,  test_subjects,  k_slices=k_slices, train=False)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True,  num_workers=0)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False, num_workers=0)

print(f"Train samples (triplets): {len(train_ds)}")
print(f"Test  samples (triplets): {len(test_ds)}")



In [None]:
# CV 2.5D Resnet
#
def build_datasets(train_files, train_labels, train_subjects,
                   test_files,  test_labels,  test_subjects,
                   k_slices=5):
    train_ds = MRITripletDataset(train_files, train_labels, train_subjects,
                                 k_slices=k_slices, train=True)
    test_ds  = MRITripletDataset(test_files,  test_labels,  test_subjects,
                                 k_slices=k_slices, train=False)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True,  num_workers=0)
    test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False, num_workers=0)
    return train_ds, test_ds, train_loader, test_loader

def make_model(device):
    m = models.resnet18(weights='IMAGENET1K_V1')
    m.fc = nn.Linear(m.fc.in_features, 2)
    m = m.to(device)
    for p in m.parameters(): p.requires_grad = False
    for p in m.fc.parameters(): p.requires_grad = True
    return m

def train_one_fold(model, train_loader, test_loader, class_weights, device,
                   warmup_epochs=3, max_epochs=30):
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=3e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.5, patience=2)

    def eval_val_loss():
        model.eval()
        loss_sum, n = 0.0, 0
        with torch.no_grad():
            for xb, yb, _, _ in test_loader:
                xb, yb = xb.to(device), yb.to(device)
                loss_sum += criterion(model(xb), yb).item() * xb.size(0)
                n += xb.size(0)
        return loss_sum / max(1, n)

    best_val = float('inf'); patience, no_imp = 6, 0
    for epoch in range(1, max_epochs+1):
        model.train()
        for xb, yb, _, _ in tqdm(train_loader, desc=f"Fold train epoch {epoch}/{max_epochs}", leave=False):
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward(); optimizer.step()
        val_loss = eval_val_loss(); scheduler.step(val_loss)

        if epoch == warmup_epochs:
            for name, m in model.named_children():
                if name in ["layer3", "layer4", "bn1", "conv1", "fc"]:
                    for p in m.parameters(): p.requires_grad = True
            optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                                    lr=1e-4, weight_decay=1e-4)

        if val_loss < best_val - 1e-4:
            best_val = val_loss; no_imp = 0
            torch.save(model.state_dict(), "best_fold.pt")
        else:
            no_imp += 1
            if no_imp >= patience:
                print("  early stopping.")
                break

    model.load_state_dict(torch.load("best_fold.pt", map_location=device))
    return model

def evaluate_slice_and_subject(model, test_loader):
    model.eval()
    sm = nn.Softmax(dim=1)
    all_probs, all_preds, all_true = [], [], []
    all_subjects, all_paths = [], []
    with torch.no_grad():
        for xb, yb, subj, path in test_loader:
            xb = xb.to(device)
            probs = sm(model(xb)).cpu().numpy()
            preds = probs.argmax(1)
            all_probs.append(probs[:,1]); all_preds.append(preds); all_true.append(yb.numpy())
            all_subjects.extend(list(subj)); all_paths.extend(list(path))
    all_probs = np.concatenate(all_probs); all_preds = np.concatenate(all_preds); all_true = np.concatenate(all_true)

    # par cp
    acc_slice = accuracy_score(all_true, all_preds)
    f1_slice  = f1_score(all_true, all_preds)
    try: roc_slice = roc_auc_score(all_true, all_probs)
    except: roc_slice = float("nan")

    # par sjt
    by_subject_scores = defaultdict(list); by_subject_true = {}
    for p1, y, s in zip(all_probs, all_true, all_subjects):
        by_subject_scores[s].append(p1); by_subject_true[s] = y
    subj_probs, subj_true, subj_pred = [], [], []
    for s, scores in by_subject_scores.items():
        p = float(np.mean(scores)); y = by_subject_true[s]; pred = int(p >= 0.5)
        subj_probs.append(p); subj_true.append(y); subj_pred.append(pred)
    subj_probs = np.array(subj_probs); subj_true = np.array(subj_true); subj_pred = np.array(subj_pred)

    acc_subj = accuracy_score(subj_true, subj_pred)
    f1_subj  = f1_score(subj_true, subj_pred)
    try: roc_subj = roc_auc_score(subj_true, subj_probs)
    except: roc_subj = float("nan")

    return {
        "slice": {"acc": acc_slice, "f1": f1_slice, "roc": roc_slice},
        "subject": {"acc": acc_subj, "f1": f1_subj, "roc": roc_subj},
        "n_subjects_test": len(subj_true)
    }

# les listes
files = all_files
labels = np.array(all_labels)
subjects = np.array(all_subjects)
idxs = np.arange(len(files))

# Cross-validation 5-fold par sujet
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

fold_results = []
fold_id = 0
for train_idx, test_idx in sgkf.split(idxs, labels, groups=subjects):
    fold_id += 1
    print(f"\n========== FOLD {fold_id} ==========")
    tr_files = [files[i] for i in train_idx]
    tr_labels = [int(labels[i]) for i in train_idx]
    tr_subjects = [subjects[i] for i in train_idx]
    te_files = [files[i] for i in test_idx]
    te_labels = [int(labels[i]) for i in test_idx]
    te_subjects = [subjects[i] for i in test_idx]

    print(f"Train: {len(set(tr_subjects))} sujets, Test: {len(set(te_subjects))} sujets")


    counts = Counter(tr_labels); total = sum(counts.values())
    w0 = total/(2.0*counts[0]); w1 = total/(2.0*counts[1])
    class_weights = torch.tensor([w0, w1], dtype=torch.float, device=device)

    # Data
    train_ds, test_ds, train_loader, test_loader = build_datasets(
        tr_files, tr_labels, tr_subjects,
        te_files, te_labels, te_subjects,
        k_slices=5
    )

    # Modele + entrainement
    model = make_model(device)
    model = train_one_fold(model, train_loader, test_loader, class_weights, device,
                           warmup_epochs=3, max_epochs=30)

    # evaluation
    metrics = evaluate_slice_and_subject(model, test_loader)
    fold_results.append(metrics)

    s = metrics["slice"]; u = metrics["subject"]
    print(f"[FOLD {fold_id}] Par coupe  : Acc={s['acc']*100:.1f}% | F1={s['f1']:.3f} | ROC-AUC={s['roc']:.3f}")
    print(f"[FOLD {fold_id}] PAR SUJET : Acc={u['acc']*100:.1f}% | F1={u['f1']:.3f} | ROC-AUC={u['roc']:.3f} | n_test_sujets={metrics['n_subjects_test']}")

#
def mean_std(vals):
    vals = [v for v in vals if not (isinstance(v, float) and (np.isnan(v) or np.isinf(v)))]
    return (np.mean(vals), np.std(vals)) if vals else (float("nan"), float("nan"))

slice_acc = [fr["slice"]["acc"] for fr in fold_results]
slice_f1  = [fr["slice"]["f1"]  for fr in fold_results]
slice_roc = [fr["slice"]["roc"] for fr in fold_results]

subj_acc = [fr["subject"]["acc"] for fr in fold_results]
subj_f1  = [fr["subject"]["f1"]  for fr in fold_results]
subj_roc = [fr["subject"]["roc"] for fr in fold_results]

m_sa, s_sa = mean_std(slice_acc)
m_sf, s_sf = mean_std(slice_f1)
m_sr, s_sr = mean_std(slice_roc)

m_ua, s_ua = mean_std(subj_acc)
m_uf, s_uf = mean_std(subj_f1)
m_ur, s_ur = mean_std(subj_roc)

print("\n========== RÉSUMÉ CROSS-VAL ==========")
print(f"Par coupe  : Acc={m_sa*100:.1f}±{s_sa*100:.1f}% | F1={m_sf:.3f}±{s_sf:.3f} | ROC-AUC={m_sr:.3f}±{s_sr:.3f}")
print(f"PAR SUJET : Acc={m_ua*100:.1f}±{s_ua*100:.1f}% | F1={m_uf:.3f}±{s_uf:.3f} | ROC-AUC={m_ur:.3f}±{s_ur:.3f}")



Train: 17 sujets, Test: 4 sujets


Génération triplets [TRAIN]: 100%|██████████| 268/268 [00:45<00:00,  5.95it/s]
Génération triplets [TEST]: 100%|██████████| 75/75 [00:09<00:00,  8.01it/s]


  early stopping.
[FOLD 1] Par coupe  : Acc=60.5% | F1=0.357 | ROC-AUC=0.539
[FOLD 1] PAR SUJET : Acc=75.0% | F1=0.000 | ROC-AUC=0.667 | n_test_sujets=4

Train: 17 sujets, Test: 4 sujets


Génération triplets [TRAIN]: 100%|██████████| 283/283 [00:49<00:00,  5.73it/s]
Génération triplets [TEST]: 100%|██████████| 60/60 [00:04<00:00, 13.82it/s]


  early stopping.
[FOLD 2] Par coupe  : Acc=62.0% | F1=0.748 | ROC-AUC=0.533
[FOLD 2] PAR SUJET : Acc=75.0% | F1=0.857 | ROC-AUC=0.667 | n_test_sujets=4

Train: 18 sujets, Test: 3 sujets


Génération triplets [TRAIN]: 100%|██████████| 304/304 [00:47<00:00,  6.39it/s]
Génération triplets [TEST]: 100%|██████████| 39/39 [00:06<00:00,  5.61it/s]


  early stopping.




[FOLD 3] Par coupe  : Acc=92.8% | F1=0.963 | ROC-AUC=nan
[FOLD 3] PAR SUJET : Acc=100.0% | F1=1.000 | ROC-AUC=nan | n_test_sujets=3

Train: 16 sujets, Test: 5 sujets


Génération triplets [TRAIN]: 100%|██████████| 250/250 [00:35<00:00,  7.01it/s]
Génération triplets [TEST]: 100%|██████████| 93/93 [00:17<00:00,  5.24it/s]


  early stopping.
[FOLD 4] Par coupe  : Acc=49.5% | F1=0.319 | ROC-AUC=0.452
[FOLD 4] PAR SUJET : Acc=40.0% | F1=0.000 | ROC-AUC=0.333 | n_test_sujets=5

Train: 16 sujets, Test: 5 sujets


Génération triplets [TRAIN]: 100%|██████████| 267/267 [00:42<00:00,  6.30it/s]
Génération triplets [TEST]: 100%|██████████| 76/76 [00:11<00:00,  6.43it/s]


  early stopping.
[FOLD 5] Par coupe  : Acc=37.9% | F1=0.330 | ROC-AUC=0.338
[FOLD 5] PAR SUJET : Acc=40.0% | F1=0.400 | ROC-AUC=0.167 | n_test_sujets=5

Par coupe  : Acc=60.5±18.3% | F1=0.543±0.264 | ROC-AUC=0.466±0.082
PAR SUJET : Acc=66.0±23.1% | F1=0.451±0.419 | ROC-AUC=0.458±0.217
