In [None]:
import SimpleITK as sitk
import numpy as np
import csv
from pathlib import Path
from scipy.ndimage import label as cc_label, binary_dilation



In [None]:


# Dataset root: a "data" directory located three levels above the current
# working directory.
_base_dir = Path.cwd().parent.parent.parent
data_path = _base_dir.joinpath("data")
mslesseg_path = data_path.joinpath("01_MSLesSeg_Dataset")

# Output directory: one CSV per patient, 2D (slice-by-slice) analysis.
# Created up front so the export loop can write into it directly.
lesion_contrast_path = data_path.joinpath("14_lesion_contrast_results_per_patient_csv_2D")
lesion_contrast_path.mkdir(parents=True, exist_ok=True)

# The bare string below is the author's in-code notes for this analysis:
# 2D connected components, 2D dilation (in-slice neighbourhood only), one CSV
# per patient (all timepoints and modalities), and
# contrast = lesion_mean / neighborhood_mean.
"""
01 -> 12
- Connected components en 2D (slice par slice)
- Dilatation 2D (voisinage dans la slice uniquement)
- 1 CSV par patient (tous timepoints + toutes modalités)
- Contraste = lesion_mean / neighborhood_mean
"""

# ---------------------------
# 2D CC slice-by-slice
# ---------------------------
def connected_components_2d_slice_by_slice(mask3d: np.ndarray, connectivity: int = 1):
    """Label connected components independently in every slice of a volume.

    Components are found in 2D only: foreground pixels that touch across
    neighbouring slices are never merged. Labels are unique over the whole
    volume and are assigned in increasing order, slice after slice.

    Args:
        mask3d: Boolean array indexed as (Z, Y, X).
        connectivity: ``1`` selects 4-connectivity (edge neighbours); any
            other value selects 8-connectivity (edge and corner neighbours).

    Returns:
        A tuple ``(labels3d, n_components)`` where ``labels3d`` is an int32
        array with the shape of ``mask3d`` (0 = background, 1..n = component
        ids) and ``n_components`` is the total number of components.
    """
    if connectivity == 1:
        # 4-connectivity
        struct = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8)
    else:
        # 8-connectivity
        struct = np.ones((3, 3), dtype=np.uint8)

    labels3d = np.zeros(mask3d.shape, dtype=np.int32)
    current = 0

    for z in range(mask3d.shape[0]):
        lbl2d, n = cc_label(mask3d[z].astype(np.uint8), structure=struct)
        if n == 0:
            continue
        # Shift the slice-local labels 1..n to current+1..current+n in a
        # single vectorized pass instead of rescanning the slice once per
        # component.
        foreground = lbl2d > 0
        plane = labels3d[z]  # view: writes go straight into labels3d
        plane[foreground] = lbl2d[foreground] + current
        current += int(n)

    return labels3d, current


# ---------------------------
# 2D dilatation slice-by-slice
# ---------------------------
def dilate_2d_slice_by_slice(mask3d: np.ndarray, iterations: int = 1, connectivity: int = 1):
    """Apply a binary dilation to each slice of a volume separately.

    The structuring element is 2D, so the dilation never spreads from one
    slice to the next.

    Args:
        mask3d: Boolean array indexed as (Z, Y, X).
        iterations: Forwarded to ``binary_dilation`` as its ``iterations``
            argument.
        connectivity: ``1`` uses the 4-connected cross; any other value uses
            the full 3x3 square.

    Returns:
        Boolean array with the shape of ``mask3d`` holding the dilated slices.
    """
    cross = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8)
    footprint = cross if connectivity == 1 else np.ones((3, 3), dtype=np.uint8)

    dilated = np.zeros_like(mask3d, dtype=bool)
    for index, plane in enumerate(mask3d):
        dilated[index] = binary_dilation(plane, structure=footprint, iterations=iterations)
    return dilated


# ---------------------------
# Indexation par patient
# ---------------------------
# Every lesion mask found anywhere below the dataset root.
mask_paths = [*mslesseg_path.rglob("*_MASK.nii.gz")]

# patient_id is the first "_"-separated token of the file name (e.g. "sub-xx");
# timepoint is the second token (e.g. "ses-yy") when there is one.
def parse_patient_timepoint(mask_path: Path):
    """Extract the patient identifier and timepoint from a mask file name.

    The name is split on underscores after removing the last extension. A
    trailing ``MASK`` token is dropped first, so a mask that carries no
    timepoint (``P54_MASK.nii.gz``) gets the default timepoint instead of
    having ``"MASK.nii"`` mistaken for one.

    Args:
        mask_path: Path of a mask file, e.g. ``.../P1_T1_MASK.nii.gz``.

    Returns:
        A tuple ``(patient_id, timepoint)``; ``timepoint`` is ``"T1"`` when
        the name has no timepoint token.
    """
    parts = mask_path.stem.split("_")

    # ``stem`` only strips ".gz", so the last token looks like "MASK.nii".
    if len(parts) > 1 and parts[-1].split(".")[0] == "MASK":
        parts = parts[:-1]

    patient_id = parts[0]
    timepoint = parts[1] if len(parts) >= 2 else "T1"
    return patient_id, timepoint

# Unique patient identifiers found among the mask files, in sorted order.
patients = sorted({parse_patient_timepoint(p)[0] for p in mask_paths})

# Column order shared by every per-patient CSV.
keys = [
    "patient", "timepoint", "modality", "lesion_id",
    "lesion_mean", "neighborhood_mean", "contrast_lesion_neighborhood",
]

# Write one CSV per patient: for every mask (timepoint), every image found
# next to it (modality) and every 2D lesion, record the mean lesion intensity,
# the mean intensity of its in-slice neighbourhood and their ratio.
for patient_id in patients:
    csv_file = lesion_contrast_path / f"{patient_id}_lesion_contrast_2D.csv"

    with open(csv_file, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=keys)
        writer.writeheader()

        # All masks belonging to this patient.
        patient_masks = sorted(
            p for p in mask_paths if parse_patient_timepoint(p)[0] == patient_id
        )

        if not patient_masks:
            print(f"[WARN] Aucun masque pour {patient_id}")
            continue

        for mask_path in patient_masks:
            patient_id2, timepoint = parse_patient_timepoint(mask_path)
            name = f"{patient_id2}_{timepoint}"
            print(f"\nTraitement {name}")

            # 1) Collect the images stored next to the mask. The modality is
            #    the last "_"-separated token of the file name (T1/T2/FLAIR...).
            #    Sorted because glob() order is arbitrary and would otherwise
            #    make the CSV row order differ between runs.
            modalities = sorted(
                ((img_path.stem.split("_")[-1]).split(".")[0], img_path)
                for img_path in mask_path.parent.glob("*_*.nii.gz")
                if "_MASK" not in img_path.stem
            )

            if not modalities:
                print(f"Aucune séquence trouvée pour {name}, skipping.")
                continue

            # 2) Read the mask, binarise it and label the lesions in 2D.
            mask_img = sitk.Cast(sitk.ReadImage(str(mask_path)) > 0, sitk.sitkUInt8)
            mask_array = sitk.GetArrayFromImage(mask_img).astype(bool)  # (Z,Y,X)

            labels_array, num_lesions = connected_components_2d_slice_by_slice(mask_array, connectivity=1)
            print(f"Nombre de lésions détectées (2D) : {num_lesions}")

            if num_lesions == 0:
                continue

            # 3) Lesion and neighbourhood voxels depend only on the mask, so
            #    compute them once here rather than once per modality. Only
            #    the voxel indices are kept, which stays small even with
            #    hundreds of lesions.
            background = labels_array == 0
            lesion_regions = []
            for lesion_id in range(1, num_lesions + 1):
                lesion_mask = labels_array == lesion_id
                if not lesion_mask.any():
                    continue

                # 2D neighbourhood = dilation inside the slice only.
                dilated = dilate_2d_slice_by_slice(lesion_mask, iterations=1, connectivity=1)

                # Restrict to background so the lesion itself and any other
                # lesion are excluded from the neighbourhood.
                neighborhood_mask = dilated & background

                lesion_regions.append(
                    (lesion_id, np.nonzero(lesion_mask), np.nonzero(neighborhood_mask))
                )

            # 4) Loop over the modalities and compute the contrast.
            for modality, img_path in modalities:
                t_img = sitk.Cast(sitk.ReadImage(str(img_path)), sitk.sitkFloat64)
                t_array = sitk.GetArrayFromImage(t_img)

                # Index-based lookups would not fail on a differently shaped
                # image the way boolean masks do, so check explicitly.
                if t_array.shape != labels_array.shape:
                    print(
                        f"[WARN] Dimensions différentes entre {img_path.name} "
                        f"et le masque de {name}, skipping."
                    )
                    continue

                for lesion_id, lesion_idx, neighborhood_idx in lesion_regions:
                    lesion_raw_mean = float(t_array[lesion_idx].mean())
                    lesion_mean = round(lesion_raw_mean, 1)

                    neighborhood_voxels = t_array[neighborhood_idx]

                    if neighborhood_voxels.size > 0:
                        neighborhood_raw_mean = float(neighborhood_voxels.mean())
                        neighborhood_mean = round(neighborhood_raw_mean, 1)
                        # Use the unrounded means: dividing the one-decimal
                        # values loses precision, and a small non-zero mean
                        # rounding to 0.0 would wrongly blank the contrast.
                        if neighborhood_raw_mean != 0:
                            contrast = round(lesion_raw_mean / neighborhood_raw_mean, 2)
                        else:
                            contrast = ""
                    else:
                        neighborhood_mean = ""
                        contrast = ""

                    writer.writerow({
                        "patient": patient_id2,
                        "timepoint": timepoint,
                        "modality": modality,
                        "lesion_id": lesion_id,
                        "lesion_mean": lesion_mean,
                        "neighborhood_mean": neighborhood_mean,
                        "contrast_lesion_neighborhood": contrast,
                    })

    print(f"CSV sauvegardé : {csv_file}")



Traitement P1_T1
Nombre de lésions détectées (2D) : 251
