### Trouver les évolutions des lésions en comparant deux timepoints différents d'un seul patient (non automatisé)

In [None]:
import SimpleITK as sitk
import math
import csv
from pathlib import Path

# MODIFIABLE PARAMETERS: adapt to the patient and timepoints to compare.
# Patients P1 to P53 have several timepoints; P54 to P75 only have T1 and
# therefore cannot be used for this comparison.
PATIENT_ID = "P1"
DEPARTURE_TIMEPOINT = "T1"
ARRIVAL_TIMEPOINT = "T2"

# Data root, resolved relative to the notebook's working directory.
data_path = Path.cwd().parent.parent / "data"

# Output folder for the lesion-evolution results (created if missing).
variation_folder_path = data_path / "11_lesion_evolution_results"
variation_folder_path.mkdir(parents = True, exist_ok = True)

# Expected layout:
#   01_MSLesSeg_Dataset/train/<patient>/<timepoint>/<patient>_<timepoint>_MASK.nii.gz
patient_folder_path = data_path / "01_MSLesSeg_Dataset" / "train" / PATIENT_ID
departure_mask_path = patient_folder_path / DEPARTURE_TIMEPOINT / f"{PATIENT_ID}_{DEPARTURE_TIMEPOINT}_MASK.nii.gz"
arrival_mask_path = patient_folder_path / ARRIVAL_TIMEPOINT / f"{PATIENT_ID}_{ARRIVAL_TIMEPOINT}_MASK.nii.gz"

if DEPARTURE_TIMEPOINT == ARRIVAL_TIMEPOINT:
    arg = "Les timepoints de départ et d'arrivée doivent être différents."
    raise ValueError(arg)

if not departure_mask_path.exists():
    arg = "Chemin non validé (attention de ne pas mettre un patient de P54 à P75 car il n'y a qu'un seul timepoint T1) "
    raise ValueError(arg)

# Timepoints are the sub-folders of the PATIENT folder (the mask's grandparent),
# not of the timepoint folder that directly contains the mask file.
if sum(1 for sub_dir in patient_folder_path.iterdir() if sub_dir.is_dir()) < 2:
    arg = f"Il y a un seul timepoint T1 pour le patient {PATIENT_ID}, changer de patient svp."
    raise ValueError(arg)

# The arrival mask must exist as well; otherwise the later sitk.ReadImage call
# would fail with a much less explicit error.
if not arrival_mask_path.exists():
    arg = f"Masque introuvable pour le timepoint d'arrivée {ARRIVAL_TIMEPOINT} : {arrival_mask_path}"
    raise ValueError(arg)


In [None]:



# ==========================
# 2) Read both masks
# ==========================
# str() keeps ReadImage compatible with SimpleITK versions that only accept str paths.
mask_T1_img = sitk.ReadImage(str(departure_mask_path))
mask_T2_img = sitk.ReadImage(str(arrival_mask_path))

print("T1 mask size :", mask_T1_img.GetSize())
print("T2 mask size :", mask_T2_img.GetSize())
print("T1 pixel type:", mask_T1_img.GetPixelIDTypeAsString())
print("T2 pixel type:", mask_T2_img.GetPixelIDTypeAsString())
print()


def _geometry_mismatches(img_a, img_b, rel_coord_tol=1e-6, direction_tol=1e-6):
    """Return the names of the geometric properties that differ between two images.

    Size must match exactly. Origin and spacing are compared with an absolute
    tolerance of ``rel_coord_tol * spacing[0]`` (spacing of ``img_a``), and the
    direction cosines with ``direction_tol``. Exact float equality is avoided
    because header values can differ by rounding noise between files.
    NOTE(review): the defaults are intended to mirror ITK's default tolerances
    for multi-input filters (used later by sitk.Mask) -- confirm for the ITK
    version in use.

    Returns:
        List drawn from "size", "origin", "spacing", "direction"; empty when
        the images share the same physical space.
    """
    coord_tol = abs(rel_coord_tol * img_a.GetSpacing()[0])

    def _close(u, v, tol):
        return len(u) == len(v) and all(
            math.isclose(x, y, rel_tol=0.0, abs_tol=tol) for x, y in zip(u, v)
        )

    mismatches = []
    if img_a.GetSize() != img_b.GetSize():
        mismatches.append("size")
    if not _close(img_a.GetOrigin(), img_b.GetOrigin(), coord_tol):
        mismatches.append("origin")
    if not _close(img_a.GetSpacing(), img_b.GetSpacing(), coord_tol):
        mismatches.append("spacing")
    if not _close(img_a.GetDirection(), img_b.GetDirection(), direction_tol):
        mismatches.append("direction")
    return mismatches


# Check alignment: voxel-wise comparison below only makes sense in a shared space.
geometry_mismatches = _geometry_mismatches(mask_T1_img, mask_T2_img)
if geometry_mismatches:
    raise RuntimeError(
        "Les masques T1 et T2 ne sont PAS alignés (taille/origin/spacing/direction différents)."
        f" Propriétés en cause : {', '.join(geometry_mismatches)}"
    )


# ==========================================
# 3) Binarization + cast to UInt8
# ==========================================
def _binarize(mask_img):
    """Return a UInt8 image equal to 1 where 0.5 <= value <= 1e9 and 0 elsewhere.

    The image is cast to float first so that the fractional 0.5 threshold is
    applied as written; for integer pixel types the thresholds may otherwise
    be converted to the pixel type (0.5 -> 0), which would select every voxel.
    """
    as_float = sitk.Cast(mask_img, sitk.sitkFloat64)
    return sitk.Cast(sitk.BinaryThreshold(as_float, 0.5, 1e9, 1, 0), sitk.sitkUInt8)


mask_T1_bin = _binarize(mask_T1_img)
mask_T2_bin = _binarize(mask_T2_img)

# =======================================
# 4) Connected components + shape statistics
# =======================================
# Each connected component of a binary mask is treated as one lesion.
cc_T1 = sitk.ConnectedComponent(mask_T1_bin)
cc_T2 = sitk.ConnectedComponent(mask_T2_bin)

stats_T1 = sitk.LabelShapeStatisticsImageFilter()
stats_T1.Execute(cc_T1)

stats_T2 = sitk.LabelShapeStatisticsImageFilter()
stats_T2.Execute(cc_T2)

labels_T1 = sorted(stats_T1.GetLabels())
labels_T2 = sorted(stats_T2.GetLabels())

print(f"Nombre de lésions à T1 : {len(labels_T1)}")
print(f"Nombre de lésions à T2 : {len(labels_T2)}")
print()

# ==========================================
# 5) Overlap function: departure (T1) -> arrival (T2)
# ==========================================
def find_overlapping_lesions_T2(label_T1, *, cc_departure=None, cc_arrival=None):
    """Return the arrival-time lesions overlapping one departure-time lesion.

    Args:
        label_T1: Connected-component label of the lesion in the departure
            label image.
        cc_departure: Departure connected-component label image. Defaults to
            the module-level ``cc_T1`` (looked up at call time).
        cc_arrival: Arrival connected-component label image. Defaults to the
            module-level ``cc_T2`` (looked up at call time).

    Returns:
        dict mapping each overlapping arrival label to the number of voxels it
        shares with lesion ``label_T1``, in ascending label order. Empty when
        the lesion overlaps nothing at the arrival time.
    """
    if cc_departure is None:
        cc_departure = cc_T1
    if cc_arrival is None:
        cc_arrival = cc_T2

    # UInt8 mask selecting only the requested departure lesion.
    lesion_mask = sitk.Cast(sitk.Equal(cc_departure, label_T1), sitk.sitkUInt8)

    # Keep arrival labels inside that lesion; voxels outside it become 0.
    overlap_img = sitk.Mask(cc_arrival, lesion_mask)

    overlap_stats = sitk.LabelShapeStatisticsImageFilter()
    overlap_stats.Execute(overlap_img)

    # Label 0 (background) is excluded by the statistics filter, so every label
    # here is a real arrival lesion; its pixel count is the overlap in voxels.
    return {
        lab2: overlap_stats.GetNumberOfPixels(lab2)
        for lab2 in sorted(overlap_stats.GetLabels())
    }

# ==========================================
# 6) CSV export
# ==========================================
# Written into the results folder created in the first cell, with one file per
# patient / timepoint pair so successive runs do not overwrite each other.
csv_path = (
    variation_folder_path
    / f"{PATIENT_ID}_{DEPARTURE_TIMEPOINT}_{ARRIVAL_TIMEPOINT}_lesion_evolution.csv"
)

# Thresholds on the volume ratio (arrival / departure) of matched lesions.
RATIO_LOW = 0.8
RATIO_HIGH = 1.2


def _classify_ratio(ratio, low=RATIO_LOW, high=RATIO_HIGH):
    """Return "stable" if low <= ratio <= high, "progression" above high, else "regression"."""
    if low <= ratio <= high:
        return "stable"
    if ratio > high:
        return "progression"
    return "regression"


def _lesion_row(timepoint, label, label_match, lesion_type, classification, stats, ratio=""):
    """Build one CSV row for lesion ``label`` using the shape statistics ``stats``.

    Centroid coordinates come from ``stats.GetCentroid`` (3 components taken).
    """
    centroid = stats.GetCentroid(label)
    return [
        timepoint,
        label,
        label_match,
        lesion_type,
        classification,
        stats.GetNumberOfPixels(label),
        stats.GetPhysicalSize(label),
        centroid[0], centroid[1], centroid[2],
        ratio,
    ]


with open(csv_path, mode="w", newline="", encoding="utf-8") as f:
    w = csv.writer(f, delimiter=';')

    w.writerow([
        "temps",                 # T1 or T2
        "label",                 # label at this timepoint
        "label_match",           # matching label at the other timepoint (empty otherwise)
        "type_lesion",           # appariee / disparue / nouvelle
        "classification",        # stable / progression / regression (matched lesions only)
        "volume_voxels",
        "volume_mm3",
        "centroid_x",
        "centroid_y",
        "centroid_z",
        "ratio_volume_T2surT1",  # filled only on the matched T1 row
    ])

    matched_T2_labels = set()

    # ==========================================
    # 7) Walk T1 lesions: matching + disappeared lesions
    # ==========================================
    # NOTE(review): several T1 lesions can share the same best T2 match (e.g.
    # lesions that merged); that T2 lesion is then written once per matching
    # T1 lesion, each time compared against one T1 volume only.
    for lab1 in labels_T1:
        overlaps = find_overlapping_lesions_T2(lab1)

        if not overlaps:
            # Lesion disappeared: nothing at T2 overlaps it.
            w.writerow(_lesion_row("T1", lab1, "", "disparue", "", stats_T1))
            continue

        # Match with the T2 lesion sharing the most voxels.
        best_lab2 = max(overlaps, key=overlaps.get)
        matched_T2_labels.add(best_lab2)

        # The epsilon only guards the division; a labelled lesion has volume > 0.
        ratio = stats_T2.GetPhysicalSize(best_lab2) / (stats_T1.GetPhysicalSize(lab1) + 1e-6)
        classification = _classify_ratio(ratio)

        w.writerow(_lesion_row("T1", lab1, best_lab2, "appariee", classification, stats_T1, ratio))
        w.writerow(_lesion_row("T2", best_lab2, lab1, "appariee", classification, stats_T2))

    # ==========================================
    # 8) Walk T2 lesions: new lesions (never selected as a best match)
    # ==========================================
    for lab2 in labels_T2:
        if lab2 not in matched_T2_labels:
            w.writerow(_lesion_row("T2", lab2, "", "nouvelle", "", stats_T2))

print(f"CSV généré : {csv_path}")
