# Extraction de la meilleure coupe 2D par lésion (SimpleITK uniquement)

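Ce notebook parcourt les masques 3D de MSLesSeg et identifie chaque lésion comme une composante connexe 3D. Pour chaque lésion, il extrait la coupe axiale (Z) où l'aire de cette lésion est maximale. Les lésions sont numérotées par volume décroissant (L1 = la plus volumineuse).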

**Contraintes respectées :**
- SimpleITK uniquement (pas de NumPy)
- Sauvegarde de masques 2D en NIfTI


In [None]:
import SimpleITK as sitk
from pathlib import Path


## Chemins

In [None]:
# The "data" folder is expected two directory levels above the current
# working directory (i.e. the notebook is run from <root>/<a>/<b>/).
project_root = Path.cwd().parent.parent
data_path = project_root.joinpath("data")

# Input: MSLesSeg dataset tree; output: one 2D NIfTI mask per lesion.
mslesseg_path = data_path.joinpath("01_MSLesSeg_Dataset")
out_dir = data_path.joinpath("18_best_slice_2d_masks")
out_dir.mkdir(exist_ok=True, parents=True)


## Boucle principale

In [None]:
def _best_axial_slices(label_img):
    """Locate, for every non-background label of a 3D label image, its largest axial slice.

    Each Z slice is extracted exactly once and its per-label pixel counts are
    read with a single LabelShapeStatisticsImageFilter pass. The cost therefore
    grows with the number of slices, not with (number of lesions x slices).

    Args:
        label_img: 3D SimpleITK label image (here the output of
            ``sitk.ConnectedComponent``).

    Returns:
        dict ``{label: (best_z, best_area)}``, where ``best_area`` is the
        number of pixels of ``label`` on slice ``best_z``. When several
        slices tie, the lowest ``z`` is kept.
    """
    size_x, size_y, size_z = label_img.GetSize()
    slice_stats = sitk.LabelShapeStatisticsImageFilter()
    best = {}
    for z in range(size_z):
        # A zero size along Z collapses that axis and yields a 2D slice.
        slice_lbl = sitk.Extract(
            label_img,
            size=[size_x, size_y, 0],
            index=[0, 0, z]
        )
        slice_stats.Execute(slice_lbl)
        for lab in slice_stats.GetLabels():
            area = slice_stats.GetNumberOfPixels(lab)
            # Strict '>' keeps the first (lowest z) slice on ties.
            if area > best.get(lab, (None, 0))[1]:
                best[lab] = (z, area)
    return best


# Sorted so that runs process (and print) the masks in a reproducible order.
for mask_path in sorted(mslesseg_path.rglob("*_MASK.nii.gz")):

    # Expected names: "<patient>_<timepoint>_MASK.nii.gz" or
    # "<patient>_MASK.nii.gz" (the latter is treated as timepoint "T1").
    parts = mask_path.stem.split("_")
    if len(parts) == 3:
        patient_id, timepoint = parts[0], parts[1]
    elif len(parts) == 2:
        patient_id = parts[0]
        timepoint = "T1"
    else:
        raise ValueError(f"Nom de fichier inattendu : {mask_path.name}")

    print(f"\nTraitement : {mask_path.name}")

    # str(): older SimpleITK releases do not accept pathlib.Path objects.
    mask_img = sitk.ReadImage(str(mask_path))

    # Any voxel >= 0.5 is lesion, which also copes with float-valued masks.
    mask_bin = sitk.BinaryThreshold(
        mask_img,
        lowerThreshold=0.5,
        upperThreshold=1e9,
        insideValue=1,
        outsideValue=0
    )
    mask_bin = sitk.Cast(mask_bin, sitk.sitkUInt8)

    # One label per 3D connected component (default connectivity) = one lesion.
    cc = sitk.ConnectedComponent(mask_bin)

    shape3d = sitk.LabelShapeStatisticsImageFilter()
    shape3d.Execute(cc)
    labels = list(shape3d.GetLabels())

    if not labels:
        print("  Aucune lésion trouvée")
        continue

    # Lesions are numbered by decreasing physical volume: L1 is the largest.
    vols = [(lab, shape3d.GetPhysicalSize(lab)) for lab in labels]
    vols.sort(key=lambda x: x[1], reverse=True)

    # One pass over Z for all lesions instead of one pass per lesion.
    best_slices = _best_axial_slices(cc)
    size_x, size_y, _ = cc.GetSize()

    for lesion_id, (label, _) in enumerate(vols, start=1):

        best_z, best_area = best_slices.get(label, (None, 0))
        if best_area == 0:
            continue

        slice_cc = sitk.Extract(
            cc,
            size=[size_x, size_y, 0],
            index=[0, 0, best_z]
        )

        lesion_2d = sitk.BinaryThreshold(
            slice_cc,
            lowerThreshold=label,
            upperThreshold=label,
            insideValue=1,
            outsideValue=0
        )
        # No manual geometry override: Extract already gives the 2D slice
        # its in-plane spacing/origin/direction. The 3x3 direction tuple is
        # row-major, so GetDirection()[:4] is NOT the in-plane 2x2 block
        # (it is singular for axis-aligned volumes); the in-plane block
        # would be indices (0, 1, 3, 4).

        # "A" in the file name presumably stands for the axial plane.
        out_name = f"{patient_id}_{timepoint}_A_L{lesion_id}_Mask.nii.gz"
        out_path = out_dir / out_name

        sitk.WriteImage(lesion_2d, str(out_path))
        print(f"  OK: L{lesion_id} → slice {best_z} → {out_name}")
