In [None]:
import csv
from pathlib import Path

import numpy as np
import SimpleITK as sitk
from scipy.ndimage import find_objects
from scipy.ndimage import label as cc_label
from scipy.stats import skew, kurtosis

In [None]:
# Project data root, three directory levels above the current working directory.
data_path = Path.cwd().parent.parent.parent.joinpath("data")
mslesseg_path = data_path.joinpath("01_MSLesSeg_Dataset")
registration_path = data_path.joinpath("17_registered_aseg_results_2D")

# The registered aseg volumes come from a separate script; stop early without them.
if not registration_path.exists():
    raise FileNotFoundError(
        "You must run first the script of 'aseg_t1_registration.py' to generate the registered data"
    )

# Output: one CSV per patient (all timepoints/modalities grouped together),
# with lesions analysed in 2D (slice by slice).
csv_folder_path = data_path.joinpath("13_lesion_info_per_patient_csv_2D")
csv_folder_path.mkdir(exist_ok=True, parents=True)


"""
Pipeline step: 01 and 07 -> 13 (source data folders -> output folder).
Generates one CSV per patient.
Lesions are analysed in 2D (connected components computed slice by slice).
"""

# =============================
# INDEXATION DES FICHIERS
# =============================
def _session_key(path: Path):
    """Return the lookup key encoded in a '<patient>_<timepoint>_<suffix>' file name.

    Example: "sub-01_ses-01_t1.nii.gz" -> ("sub-01", "ses-01").
    The same rule is applied to T1 and aseg files so their keys can be matched.
    A name with fewer than three underscore-separated fields carries no timepoint.
    For such names the first field alone (a plain str) is returned.
    """
    # The ".nii.gz" extension contains no underscore, so splitting the full name
    # yields the same fields as splitting the stem.
    fields = path.name.split("_")
    return tuple(fields[:2]) if len(fields) >= 3 else fields[0]


def _index_by_session(paths, kind: str) -> dict:
    """Map _session_key(path) -> path for every path.

    When two files share a key, a warning is printed and the later one wins.
    """
    index = {}
    for path in paths:
        key = _session_key(path)
        if key in index:
            print(f"[WARN] {kind} en double pour {key} : {index[key]} remplacé par {path}")
        index[key] = path
    return index


aseg_split_path_dict = _index_by_session(registration_path.rglob("*_aseg.nii.gz"), "ASEG")
t1_split_path_dict = _index_by_session(mslesseg_path.rglob("*_t1.nii.gz"), "T1")

# =============================
# DICO DES ZONES FREE-SURFER
# =============================
# FreeSurfer aseg label id -> anatomical zone name. Insertion order is kept and
# defines the order of the per-zone CSV columns built right after.
zone_dict = dict(
    [
        # Left hemisphere and midline structures
        (2, "Left-Cerebral-White-Matter"),
        (3, "Left-Cerebral-Cortex"),
        (4, "Left-Lateral-Ventricle"),
        (5, "Left-Inf-Lat-Vent"),
        (7, "Left-Cerebellum-White-Matter"),
        (8, "Left-Cerebellum-Cortex"),
        (10, "Left-Thalamus-Proper"),
        (11, "Left-Caudate"),
        (12, "Left-Putamen"),
        (13, "Left-Pallidum"),
        (14, "3rd-Ventricle"),
        (15, "4th-Ventricle"),
        (16, "Brain-Stem"),
        (17, "Left-Hippocampus"),
        (18, "Left-Amygdala"),
        (24, "CSF"),
        (26, "Left-Accumbens-area"),
        (28, "Left-VentralDC"),
        (30, "Left-vessel"),
        (31, "Left-choroid-plexus"),
        # Right hemisphere
        (41, "Right-Cerebral-White-Matter"),
        (42, "Right-Cerebral-Cortex"),
        (43, "Right-Lateral-Ventricle"),
        (44, "Right-Inf-Lat-Vent"),
        (46, "Right-Cerebellum-White-Matter"),
        (47, "Right-Cerebellum-Cortex"),
        (49, "Right-Thalamus-Proper"),
        (50, "Right-Caudate"),
        (51, "Right-Putamen"),
        (52, "Right-Pallidum"),
        (53, "Right-Hippocampus"),
        (54, "Right-Amygdala"),
        (58, "Right-Accumbens-area"),
        (60, "Right-VentralDC"),
        (62, "Right-vessel"),
        (63, "Right-choroid-plexus"),
        # Zones the author flags as often affected by multiple sclerosis;
        # labels 77-82 are the hypointensity labels.
        (72, "5th-Ventricle"),
        (77, "WM-hypointensities"),
        (78, "Left-WM-hypointensities"),
        (79, "Right-WM-hypointensities"),
        (80, "Non-WM-hypointensities"),
        (81, "Left-Non-WM-hypointensities"),
        (82, "Right-Non-WM-hypointensities"),
        # Fornix and corpus callosum segments
        (250, "Fornix"),
        (251, "CC_Posterior"),
        (252, "CC_Mid_Posterior"),
        (253, "CC_Central"),
        (254, "CC_Mid_Anterior"),
        (255, "CC_Anterior"),
    ]
)

# One output column per zone: the share of a lesion's voxels lying in that zone.
zone_columns = ["prop_" + zone_name for zone_name in zone_dict.values()]

# =============================
# OUTILS / FONCTIONS
# =============================
def safe_stats(values: np.ndarray) -> dict:
    """Summarise an array of intensities.

    Returns a dict with keys mean, std, median, vmin, vmax, range, skew and kurt.
    Each value is a Python float or NaN.

    - ``std`` uses numpy's default population formula.
    - ``skew`` needs at least 3 values and ``kurt`` at least 4 (scipy defaults);
      with fewer values the statistic is NaN.
    - An empty input yields NaN for every key.
    """
    if not values.size:
        return dict.fromkeys(
            ("mean", "std", "median", "vmin", "vmax", "range", "skew", "kurt"),
            np.nan,
        )

    arr = values.astype(np.float64)
    n = arr.size
    lo = float(arr.min())
    hi = float(arr.max())
    return {
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "median": float(np.median(arr)),
        "vmin": lo,
        "vmax": hi,
        "range": float(hi - lo),
        "skew": np.nan if n < 3 else float(skew(arr)),
        "kurt": np.nan if n < 4 else float(kurtosis(arr)),
    }


def lesion_zone_proportions(lesion_mask: np.ndarray, aseg: np.ndarray, zones=None) -> dict:
    """
    Fraction of the lesion's voxels that lies in each anatomical zone.

    lesion_mask: mask array (non-zero = lesion voxel), same shape as ``aseg``.
    aseg: integer label array (FreeSurfer aseg labels).
    zones: optional {label_id: zone_name} mapping; defaults to the module-level
        ``zone_dict``.
    Returns {"prop_<zone_name>": overlap_voxels / lesion_voxels} for every zone,
    in ``zones`` order, or {} when the mask is empty.
    Raises ValueError if the mask and label arrays have different shapes.
    """
    if zones is None:
        zones = zone_dict

    mask = np.asarray(lesion_mask, dtype=bool)
    labels = np.asarray(aseg)
    if mask.shape != labels.shape:
        raise ValueError(
            f"lesion_mask shape {mask.shape} does not match aseg shape {labels.shape}"
        )

    # Only the labels under the lesion matter: count them once instead of
    # comparing the whole volume against every zone.
    inside = labels[mask]
    vox = inside.size
    if vox == 0:
        return {}

    found, counts = np.unique(inside, return_counts=True)
    count_by_label = dict(zip(found.tolist(), counts.tolist()))
    return {
        f"prop_{zone_name}": count_by_label.get(label_id, 0) / vox
        for label_id, zone_name in zones.items()
    }


def connected_components_2d_slice_by_slice(mask3d: np.ndarray, connectivity: int = 1):
    """
    Label 2D connected components independently on each slice along axis 0.

    Axis 0 is Z for arrays obtained from SimpleITK. Components touching across
    slices are NOT merged.

    mask3d: 3D array; non-zero voxels are foreground.
    connectivity: 1 -> 4-connectivity in-plane; any other value -> 8-connectivity.
    Returns:
      - label3d: int32 3D array with globally unique labels (0 = background);
        slice z receives the labels following those of slices 0..z-1.
      - n_components: total number of components (labels are 1..n_components).
    """
    if connectivity == 1:
        # 4-connectivity in 2D
        struct = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.int32)
    else:
        # 8-connectivity in 2D
        struct = np.ones((3, 3), dtype=np.int32)

    label3d = np.zeros(mask3d.shape, dtype=np.int32)
    total = 0
    for z in range(mask3d.shape[0]):
        lbl2d, n = cc_label(mask3d[z], structure=struct)
        if n == 0:
            continue
        # Shift labels 1..n to total+1..total+n in one vectorized step
        # (the per-label loop rescanned the whole slice for every label).
        foreground = lbl2d > 0
        slice_out = label3d[z]  # view: writes go straight into label3d
        slice_out[foreground] = lbl2d[foreground] + total
        total += n

    return label3d, total


# =============================
# BOUCLE PATIENTS / TIMEPOINTS
# =============================

# On regroupe par patient
# keys dans t1_split_path_dict / aseg_split_path_dict: (sub-XX, ses-YY)
# Sorted unique patient ids. T1 keys are (patient, timepoint) tuples, or a bare
# value when the file name could not be split into those fields.
patients = sorted(
    set(key[0] if isinstance(key, tuple) else key for key in t1_split_path_dict)
)

# Fixed CSV columns: lesion identification, then intensity statistics,
# followed by one proportion column per anatomical zone.
base_columns = [
    "patient",
    "timepoint",
    "modality",
    "lesion_id",
    "lesion_voxels",
    "lesion_volume_mm3",
    *(
        f"lesion_{stat}"
        for stat in ("mean", "std", "median", "min", "max", "range", "skew", "kurt")
    ),
]
columns = [*base_columns, *zone_columns]

def _round3(value: float) -> float:
    """Round a statistic to 3 decimals, passing non-finite values (NaN/inf) through."""
    return round(value, 3) if np.isfinite(value) else value


# CSV column -> key of the dict returned by safe_stats()
_stat_columns = {
    "lesion_mean": "mean",
    "lesion_std": "std",
    "lesion_median": "median",
    "lesion_min": "vmin",
    "lesion_max": "vmax",
    "lesion_range": "range",
    "lesion_skew": "skew",
    "lesion_kurt": "kurt",
}

for patient in patients:
    csv_path = csv_folder_path / f"{patient}_lesions_2D.csv"

    # Every timepoint recorded for this patient (only tuple keys carry one).
    timepoints = sorted(
        {k[1] for k in t1_split_path_dict if isinstance(k, tuple) and k[0] == patient}
    )

    if not timepoints:
        # The patient's T1 key is not a (patient, timepoint) tuple.
        print(f"[WARN] Aucun timepoint trouvé pour {patient}")
        continue

    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=columns)
        writer.writeheader()

        for timepoint in timepoints:
            key = (patient, timepoint)

            if key not in t1_split_path_dict:
                print(f"[WARN] T1 manquant pour {patient} {timepoint}")
                continue
            if key not in aseg_split_path_dict:
                print(f"[WARN] ASEG manquant pour {patient} {timepoint}")
                continue

            # Read images; GetArrayFromImage returns arrays indexed (Z, Y, X)
            t1_img = sitk.ReadImage(str(t1_split_path_dict[key]))
            aseg_img = sitk.ReadImage(str(aseg_split_path_dict[key]))

            t1 = sitk.GetArrayFromImage(t1_img).astype(np.float32)
            aseg = sitk.GetArrayFromImage(aseg_img).astype(np.int32)

            # The voxel-wise indexing below needs both volumes on the same grid;
            # skip instead of crashing with an IndexError.
            if t1.shape != aseg.shape:
                print(
                    f"[WARN] Dimensions différentes T1 {t1.shape} / ASEG {aseg.shape} "
                    f"pour {patient} {timepoint}"
                )
                continue

            # Voxel volume (mm^3) from the T1 spacing, which SimpleITK reports as (X, Y, Z)
            spacing = t1_img.GetSpacing()
            voxel_volume_mm3 = float(spacing[0] * spacing[1] * spacing[2])

            # Lesion mask = aseg labels 77-82 (WM and non-WM hypointensities).
            # NOTE(review): because lesions are defined from aseg labels, their zone
            # proportions can only be non-zero for labels 77-82 — confirm intended.
            lesion_mask = np.isin(aseg, [77, 78, 79, 80, 81, 82])

            # Connected components, 2D, slice by slice
            lesion_labels, n_lesions = connected_components_2d_slice_by_slice(
                lesion_mask, connectivity=1
            )

            # find_objects()[i] is the bounding box of label i + 1 (None if unused).
            # Each lesion is processed inside its bounding box rather than scanning
            # the whole volume once per lesion.
            bboxes = find_objects(lesion_labels, max_label=n_lesions) if n_lesions else []

            lesion_counter = 0
            for lesion_id, bbox in enumerate(bboxes, start=1):
                if bbox is None:
                    continue
                comp = lesion_labels[bbox] == lesion_id
                voxels = int(np.count_nonzero(comp))
                if voxels == 0:
                    continue

                lesion_counter += 1

                # T1 intensities inside the lesion
                st = safe_stats(t1[bbox][comp])

                # Per-zone overlap proportions (same result on the bounding box,
                # since every lesion voxel lies inside it)
                zone_props = lesion_zone_proportions(comp, aseg[bbox])

                row = {
                    "patient": patient,
                    "timepoint": timepoint,
                    "modality": "t1",
                    "lesion_id": lesion_counter,
                    "lesion_voxels": voxels,
                    "lesion_volume_mm3": round(voxels * voxel_volume_mm3, 3),
                }
                row.update({col: _round3(st[stat]) for col, stat in _stat_columns.items()})
                row.update({col: zone_props.get(col, 0.0) for col in zone_columns})

                writer.writerow(row)

            print(f"[{patient} {timepoint}] lésions 2D écrites : {lesion_counter}")

    print(f"CSV généré : {csv_path}")


FileNotFoundError: You must run first the script of 'aseg_t1_registration.py' to generate the registered data