### Générateur du Best Slice (01_MSLesSeg_Dataset -> 17_best_slice_csv)

In [7]:
import os
import csv
import numpy as np
import SimpleITK as sitk
from pathlib import Path

# Data root resolved from the notebook's working directory, two levels up
# (NOTE(review): depends on where the notebook is launched from — confirm).
data_path = Path.cwd().parent.parent / "data"
mslesseg_path = data_path / "01_MSLesSeg_Dataset"  # input: searched for *_MASK.nii.gz

best_slice_csv_dir = data_path / "17_best_slice_csv"  # output: one CSV per lesion
# exist_ok=True already makes this a no-op when the directory exists,
# so no separate exists() check is needed.
best_slice_csv_dir.mkdir(parents=True, exist_ok=True)

In [None]:
"""
01 -> 17
"""

CSV_DELIM = ","

# CSV columns, identical for every per-lesion file, grouped by meaning.
fieldnames = (
    ["lesion_id"]
    # index of the slice with the largest lesion area along each axis
    + ["best_slice_z", "best_slice_y", "best_slice_x"]
    # 2-D centroid on the best z slice (centroid_z is that slice index)
    + ["centroid_x", "centroid_y", "centroid_z"]
    # 3-D voxel-index centroid of the whole lesion
    + ["centroid_x_3d", "centroid_y_3d", "centroid_z_3d"]
    # lesion area (pixel count) on each best slice
    + ["best_area_z_px", "best_area_y_px", "best_area_x_px"]
    # 3-D shape descriptors
    + ["flatness_3d", "density_3d"]
)


def compute_flatness_from_shape_stats(shape_filter, lab):
    """Return the flatness of label ``lab`` from a shape-statistics filter.

    Resolution order:

    1. ``shape_filter.GetFlatness(lab)`` when the filter provides it.
    2. Otherwise the same quantity rebuilt from ``GetPrincipalMoments(lab)``.
       With the moments sorted ascending (m_min <= m_mid <= ...), it is
       ``sqrt(m_mid / m_min)``, or 0.0 when ``m_min <= 0``.
       NOTE(review): this is meant to mirror ITK's ShapeLabelObject
       definition so both paths fill the same column with the same metric.
       The previous fallback ``sqrt(m_min / m_max)`` produced values in
       [0, 1], while GetFlatness is >= 1. Confirm against the installed
       ITK/SimpleITK version.
    3. Otherwise NaN.

    Parameters
    ----------
    shape_filter : object
        Typically an executed ``sitk.LabelShapeStatisticsImageFilter``.
        Any object exposing ``GetFlatness`` and/or ``GetPrincipalMoments``
        is accepted.
    lab : int
        Label whose flatness is requested.

    Returns
    -------
    float
        The flatness, or NaN when neither source is usable.
    """
    get_flatness = getattr(shape_filter, "GetFlatness", None)
    if get_flatness is not None:
        try:
            return float(get_flatness(lab))
        except (RuntimeError, TypeError, ValueError):
            # Unknown label / unexpected return type: try the moments instead.
            pass

    get_moments = getattr(shape_filter, "GetPrincipalMoments", None)
    if get_moments is not None:
        try:
            moments = sorted(float(x) for x in get_moments(lab))  # ascending
        except (RuntimeError, TypeError, ValueError):
            moments = []
        if len(moments) >= 2:
            if moments[0] <= 0.0:
                # Degenerate (e.g. single-slice) object: smallest moment is zero.
                return 0.0
            return float(np.sqrt(moments[1] / moments[0]))

    return float("nan")


# ======================================================================
# MAIN LOOP
# ======================================================================

def _parse_mask_name(mask_path):
    """Return ``(patient_id, timepoint)`` parsed from a ``*_MASK.nii.gz`` path.

    ``<patient>_<timepoint>_MASK.nii.gz`` yields both parts, and
    ``<patient>_MASK.nii.gz`` yields ``timepoint=None``.

    Raises:
        ValueError: the file name matches neither pattern.
    """
    # Path.stem only strips the last suffix (".gz"), so drop ".nii.gz" explicitly.
    name = mask_path.name
    if name.endswith(".nii.gz"):
        name = name[: -len(".nii.gz")]
    parts = name.split("_")
    if len(parts) == 3:
        return parts[0], parts[1]
    if len(parts) == 2:
        return parts[0], None
    raise ValueError(f"Nom de fichier inattendu : {mask_path.name}")


def _lesion_row(cc_arr, shape3d, old_lab, lesion_id):
    """Build the CSV row (keys = ``fieldnames``) for one connected component.

    Args:
        cc_arr: label array from ``sitk.GetArrayFromImage``, indexed (z, y, x).
        shape3d: LabelShapeStatisticsImageFilter already executed on that labels image.
        old_lab: label value of the component inside ``cc_arr``.
        lesion_id: 1-based rank of the component by decreasing volume.

    Returns:
        The row dict, or None if the component has no voxel.
    """
    # SimpleITK bounding box: (x0, y0, z0, sizeX, sizeY, sizeZ) in index space.
    x0, y0, z0, size_x, size_y, size_z = (int(v) for v in shape3d.GetBoundingBox(old_lab))

    # Work on the bounding-box crop only. Every lesion voxel lies inside it,
    # so areas outside are zero and the first argmax is the same as on the
    # full volume once shifted by the crop origin.
    lesion_3d = cc_arr[z0:z0 + size_z, y0:y0 + size_y, x0:x0 + size_x] == old_lab
    if not lesion_3d.any():
        return None

    # Best slices: largest lesion area along each axis.
    areas_z = lesion_3d.sum(axis=(1, 2))
    areas_y = lesion_3d.sum(axis=(0, 2))
    areas_x = lesion_3d.sum(axis=(0, 1))
    best_z_local = int(np.argmax(areas_z))
    best_y_local = int(np.argmax(areas_y))
    best_x_local = int(np.argmax(areas_x))
    best_z = z0 + best_z_local

    # 2-D centroid on the best z slice, in full-volume voxel indices.
    ys, xs = np.nonzero(lesion_3d[best_z_local])
    # 3-D centroid, in full-volume voxel indices.
    zs3, ys3, xs3 = np.nonzero(lesion_3d)

    # Density = voxel count / bounding-box volume (max guards against 0).
    bbox_vox = max(1, size_x * size_y * size_z)
    nvox = int(shape3d.GetNumberOfPixels(old_lab))

    return {
        "lesion_id": lesion_id,

        "best_slice_z": best_z,
        "best_slice_y": y0 + best_y_local,
        "best_slice_x": x0 + best_x_local,

        "centroid_x": float((xs + x0).mean()),
        "centroid_y": float((ys + y0).mean()),
        "centroid_z": float(best_z),

        "centroid_x_3d": float((xs3 + x0).mean()),
        "centroid_y_3d": float((ys3 + y0).mean()),
        "centroid_z_3d": float((zs3 + z0).mean()),

        "best_area_z_px": int(areas_z[best_z_local]),
        "best_area_y_px": int(areas_y[best_y_local]),
        "best_area_x_px": int(areas_x[best_x_local]),

        "flatness_3d": compute_flatness_from_shape_stats(shape3d, old_lab),
        "density_3d": float(nvox / bbox_vox),
    }


# ======================================================================
# MAIN LOOP: one CSV per lesion, for every *_MASK.nii.gz under 01
# ======================================================================

# Every output file goes to the same directory; create it once.
best_slice_csv_dir.mkdir(parents=True, exist_ok=True)

for mask_path in mslesseg_path.rglob("*_MASK.nii.gz"):
    patient_id, timepoint = _parse_mask_name(mask_path)
    # Include the timepoint when present. Otherwise several timepoints of the
    # same patient would write to (and overwrite) the same CSV files.
    case_id = patient_id if timepoint is None else f"{patient_id}_{timepoint}"

    print(f"\nTraitement : {mask_path.name}")

    # Read & binarise: any voxel >= 0.5 counts as lesion.
    mask_img = sitk.ReadImage(str(mask_path))
    mask_bin = sitk.BinaryThreshold(
        mask_img,
        lowerThreshold=0.5,
        upperThreshold=1e9,
        insideValue=1,
        outsideValue=0,
    )
    mask_bin = sitk.Cast(mask_bin, sitk.sitkUInt8)

    # 3-D connected components (SimpleITK default connectivity).
    cc = sitk.ConnectedComponent(mask_bin)
    shape3d = sitk.LabelShapeStatisticsImageFilter()
    shape3d.Execute(cc)
    labels = list(shape3d.GetLabels())
    if not labels:
        raise RuntimeError(f"Aucune lésion dans le masque : {mask_path.name}")

    # Rank components by decreasing physical volume. The sort is stable,
    # so labels with equal volume keep their original order.
    ranked_labels = sorted(labels, key=shape3d.GetPhysicalSize, reverse=True)

    cc_arr = sitk.GetArrayFromImage(cc)  # index order (z, y, x)

    for lesion_id, old_lab in enumerate(ranked_labels, start=1):
        row = _lesion_row(cc_arr, shape3d, old_lab, lesion_id)
        if row is None:
            continue

        out_csv = best_slice_csv_dir / f"{case_id}_L{lesion_id}_best_slice.csv"
        with open(out_csv, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter=CSV_DELIM)
            writer.writeheader()
            writer.writerow(row)

        print(f"  OK: {out_csv.name}")
