## IMPORTS

In [None]:
# Set to True when running on Google Colab: dependencies are then installed and
# the volume is read from Google Drive instead of the local data folder.
ON_COLAB: bool = False

In [None]:
# Installer les d√©pendances
if ON_COLAB:
    !pip install torch>=2.0.0 torchvision>=0.15.0 lightning>=2.0.0 torchmetrics>=0.11.4 \
    hydra-core==1.3.2 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 \
    mlflow opencv-python Pillow ultralytics tifffile \
    rootutils pre-commit rich pytest tqdm pandas
    !pip install codecarbon

In [None]:
import csv
from datetime import datetime
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tifffile
import torch
from codecarbon import EmissionsTracker
from skimage.metrics import structural_similarity as ssim
from ultralytics import SAM

In [None]:
if not ON_COLAB:
    # Local run: the volume lives in the repository's data folder.
    tif_path = "../data/Romane_Martin_urne_sature_10-4.tif"
else:
    # Colab run: mount Google Drive first, then point at the copy stored there.
    from google.colab import drive

    drive.mount("/content/drive")
    tif_path = "/content/drive/MyDrive/Romane_Martin_urne_sature_10-4.tif"

## PARAMETRES

In [None]:
# --- Point-grid prompting -------------------------------------------------
GRID_STRIDE = 64  # spacing in pixels between grid points (smaller = denser grid)
POINTS_PER_CALL = 30  # number of prompts sent to SAM in one predict() call

# --- Mask filtering -------------------------------------------------------
MIN_AREA = 300  # masks with fewer pixels than this are dropped
CONF_THR = 0.35  # minimum SAM confidence, forwarded as predict(conf=...)
DEDUP_IOU_THR = 0.90  # a mask whose IoU with an already kept mask exceeds this is dropped
MIN_MASK_REGION_AREA = 200  # NOTE(review): not used by any visible cell

# --- Inputs -----------------------------------------------------------------
IMAGE_3D_PATH = tif_path  # 3-D TIFF volume resolved in the previous cell
SAM_WEIGHTS = "../data/sam_b.pt"  # local SAM checkpoint path, handed to ultralytics.SAM
POINT_LABEL = 1  # label attached to every grid point (SAM convention: 1 = foreground)

## UTILS

In [None]:
def make_grid_points(h, w, stride, label=1):
    """Build a regular grid of SAM point prompts over an h x w image.

    Grid centres sit at ``stride // 2 + k * stride`` along each axis.

    Args:
        h: image height in pixels.
        w: image width in pixels.
        stride: spacing between neighbouring points, in pixels.
        label: prompt label given to every point.

    Returns:
        (points, labels): ``points`` is a list of ``(x, y)`` tuples of Python ints,
        ordered row by row (y outer, x inner); ``labels`` repeats ``label`` once per point.
    """
    offset = stride // 2
    col_centres = np.arange(offset, w, stride)
    row_centres = np.arange(offset, h, stride)

    grid = []
    for yy in row_centres:
        grid.extend((int(xx), int(yy)) for xx in col_centres)

    return grid, [label for _ in grid]


def dedup_by_iou(masks, iou_thr=0.9, min_area=0):
    """Greedy de-duplication of boolean masks by IoU.

    Masks are visited in input order. A mask is kept when it has at least
    ``min_area`` pixels and its IoU with every mask already kept is <= ``iou_thr``.
    The first occurrence wins, so the input order decides which duplicate survives.

    Args:
        masks: iterable of boolean arrays of identical shape.
        iou_thr: IoU above which a mask counts as a duplicate.
        min_area: minimum pixel count for a mask to be considered at all.

    Returns:
        list of the kept masks (the same array objects, not copies).
    """
    kept = []

    def _iou(a, b):
        # max(..., 1) guards the division when both masks are empty.
        return np.logical_and(a, b).sum() / max(np.logical_or(a, b).sum(), 1)

    for candidate in masks:
        large_enough = candidate.sum() >= min_area
        if large_enough and all(_iou(candidate, other) <= iou_thr for other in kept):
            kept.append(candidate)
    return kept


def colorize_masks(image_gray, masks_bool, seed=42):
    """Paint each boolean mask over a grayscale image with a random solid colour.

    Args:
        image_gray: 2-D grayscale image (H, W), or an already 3-channel image (H, W, 3).
            Non-uint8 images are min-max rescaled to [0, 255] (same rule as the SAM
            runners) instead of being cast; a plain cast wraps e.g. uint16 values.
            A constant non-uint8 image becomes all zeros.
        masks_bool: iterable of boolean (H, W) masks. Later masks overwrite earlier
            ones where they overlap.
        seed: seed of the colour generator; one colour is drawn per mask, in order.

    Returns:
        uint8 array of shape (H, W, 3); channel order is whatever the caller
        interprets it as (matplotlib shows it as RGB).
    """
    img = np.asarray(image_gray)
    if img.dtype != np.uint8:
        if img.size and img.max() > img.min():
            lo, hi = float(img.min()), float(img.max())
            img = ((img - lo) / (hi - lo) * 255).astype(np.uint8)
        else:
            img = np.zeros_like(img, dtype=np.uint8)

    # Grayscale -> 3 identical channels; 3-channel input is copied so the caller's array is untouched.
    out = img.copy() if img.ndim == 3 else np.dstack([img, img, img])
    rng = np.random.default_rng(seed)
    for m in masks_bool:
        out[m] = rng.integers(0, 256, size=3, dtype=np.uint8)  # random colour per mask
    return out

## CHARGEMENT DONNEES & MODELE


## Chargement des données

In [None]:
def to_sam_handled_picture(picture_3D: np.ndarray) -> np.ndarray:
    """Replicate a single-channel image into three identical channels on a new last axis.

    (H, W) -> (H, W, 3); the dtype is preserved.
    NOTE(review): not called by any visible cell.
    """
    return np.stack((picture_3D, picture_3D, picture_3D), axis=-1)

In [None]:
# Load the whole stack once; later cells index slices out of `vol`.
vol = tifffile.imread(IMAGE_3D_PATH)
print(f"Volume: shape={vol.shape}, dtype={vol.dtype}")
# The slice-wise helpers treat the volume as (Z, Y, X): fail early on anything else
# (e.g. a multi-channel stack) instead of deep inside SAM or SSIM.
assert vol.ndim == 3, f"expected a (Z, Y, X) volume, got shape {vol.shape}"

# Example: extract the slice in the middle of the stack.
mid = vol.shape[0] // 2
sl = vol[mid]

In [None]:
# Quick look at the example slice before segmentation.
fig, ax = plt.subplots()
ax.imshow(sl, cmap="gray")
ax.set_title(f"Slice {mid} (middle of the volume)")
plt.show()

## chargement SAM

In [None]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Device: {device}")

# The model is built the same way whether or not the file exists locally; the
# check only selects the message (ultralytics resolves SAM_WEIGHTS itself,
# presumably downloading known checkpoint names when missing).
_weights_message = (
    f"{SAM_WEIGHTS} already exists locally"
    if Path(SAM_WEIGHTS).exists()
    else f"Downloading {SAM_WEIGHTS}..."
)
print(_weights_message)
model = SAM(SAM_WEIGHTS)

print("SAM charg√©")

## INFERENCE PAR CHUNKS DE POINTS


In [None]:
from typing import Union


def run_sam_chunked_points(
    model: SAM,
    image,
    points: Union[list[tuple[int, int]], int],
    labels=None,
    conf=0.35,
    points_per_call=25,
    device="cpu",
    binarize_thr=0.5,
):
    """Segment one image with SAM using point prompts, sending them in chunks.

    Args:
        model: a loaded ``ultralytics.SAM`` model.
        image: 2-D grayscale image (H, W) or 3-channel image (H, W, 3). Non-uint8
            images are min-max rescaled to uint8 before inference.
        points: either a list of ``(x, y)`` pixel prompts, or an int. An int is
            forwarded as ``points_stride`` to ``model.predict`` (SAM's own grid mode).
        labels: one prompt label per point. Defaults to 1 for every point when
            omitted. Ignored in grid mode.
        conf: confidence threshold forwarded to ``model.predict``.
        points_per_call: prompts per ``model.predict`` call (``points_batch_size`` in grid mode).
        device: device string forwarded to ``model.predict``.
        binarize_thr: threshold applied to ``result.masks.data`` to obtain boolean masks.

    Returns:
        list of boolean (H, W) numpy masks, not de-duplicated.

    Raises:
        ValueError: if ``labels`` and ``points`` have different lengths.
    """
    img = image if image.ndim == 3 else np.dstack([image, image, image])
    # SAM is fed 8-bit images: min-max rescale other dtypes to [0, 255] (a plain
    # cast would wrap e.g. uint16 intensities); a constant image maps to all zeros.
    if img.dtype != np.uint8:
        mn, mx = float(img.min()), float(img.max())
        img = (
            ((img - mn) / (mx - mn) * 255).astype(np.uint8)
            if mx > mn
            else np.zeros_like(img, dtype=np.uint8)
        )
    img = np.ascontiguousarray(img)

    masks = []

    def _collect(result):
        # Binarize the masks of one ultralytics Results object, if it has any.
        if getattr(result, "masks", None) is not None:
            arr = (result.masks.data > binarize_thr).cpu().numpy()  # [N, H, W]
            masks.extend(mi.astype(bool) for mi in arr)

    with torch.inference_mode():
        if isinstance(points, int):
            # predict() returns a list of Results (one per source image): take the first.
            # NOTE(review): confirm this ultralytics version accepts points_stride /
            # points_batch_size as predict() arguments.
            results = model.predict(
                source=img,
                points_stride=points,
                points_batch_size=points_per_call,
                conf=conf,
                device=device,
            )
            _collect(results[0])
        else:
            if labels is None:
                labels = [1] * len(points)
            if len(labels) != len(points):
                raise ValueError(f"got {len(points)} points but {len(labels)} labels")
            for i in range(0, len(points), points_per_call):
                results = model.predict(
                    source=img,
                    points=points[i : i + points_per_call],
                    labels=labels[i : i + points_per_call],
                    conf=conf,
                    device=device,
                )
                _collect(results[0])
        if device == "cuda":
            torch.cuda.empty_cache()
    return masks  # not de-duplicated

### Grille de points et inférence

In [None]:
# Sanity check before inference: the SAM helpers expect a numpy image, either
# 2-D grayscale or 3-channel.
assert isinstance(sl, np.ndarray) and sl.ndim in (2, 3), (
    f"unexpected slice: type={type(sl)}, shape={getattr(sl, 'shape', None)}"
)
sl.shape, sl.dtype

In [None]:
# Start CodeCarbon measurement; the tracker is stopped later, in the CodeCarbon section.
tracker = EmissionsTracker()
tracker.start()

# Regular grid of point prompts covering the example slice.
H, W = sl.shape[0], sl.shape[1]
points, labels = make_grid_points(H, W, stride=GRID_STRIDE, label=POINT_LABEL)
print(f"Points de grille: {len(points)}  (stride={GRID_STRIDE})")

In [None]:
# Prompt SAM with the whole point grid on the example slice (raw, not de-duplicated).
raw_masks = run_sam_chunked_points(
    model,
    sl,
    points,
    labels=labels,
    conf=CONF_THR,
    points_per_call=POINTS_PER_CALL,
    device=device,
)

In [None]:
# Drop masks smaller than MIN_AREA and near-duplicates (IoU > DEDUP_IOU_THR); first occurrence wins.
masks = dedup_by_iou(raw_masks, min_area=MIN_AREA, iou_thr=DEDUP_IOU_THR)
print(f" {len(masks)} masques apr√®s filtrage & d√©-dup")

## SAVE VISUALISATION

In [None]:
# Overlay the kept masks on the slice, show it, and save it as PNG.
# Largest masks are painted first so smaller masks stay visible on top of them.
if not masks:
    print(" Aucun objet d√©tect√© ")
else:
    masks_sorted = sorted(masks, key=lambda x: x.sum(), reverse=True)
    colored = colorize_masks(sl, masks_sorted, seed=0)  # uint8, displayed as RGB
    plt.imshow(colored)
    plt.axis("off")
    plt.show()
    out_png = f"sam_colored_slice_{mid}.png"
    # OpenCV writes BGR: convert so the saved colours match the figure above.
    if not cv2.imwrite(out_png, cv2.cvtColor(colored, cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite failed to write {out_png}")
    print(f" Sauvegarde: {out_png}")

## Segmentation avec bounding box


In [None]:
def make_grid_boxes(h, w, stride, box_size=64):
    """Square box prompts centred on a regular grid and clipped to the image.

    Centres follow the same grid as ``make_grid_points``: ``stride // 2 + k * stride``
    along each axis. Boxes near the border are clipped, so they may be smaller than
    ``box_size``.

    Args:
        h: image height in pixels.
        w: image width in pixels.
        stride: spacing between neighbouring box centres, in pixels.
        box_size: nominal side length of each box, in pixels.

    Returns:
        list of ``[x1, y1, x2, y2]`` boxes as plain Python ints (consistent with
        ``make_grid_points``), ordered row by row (y outer, x inner).
    """
    half = box_size // 2
    centres_x = [int(x) for x in np.arange(stride // 2, w, stride)]
    centres_y = [int(y) for y in np.arange(stride // 2, h, stride)]
    return [
        [
            int(max(0, x - half)),
            int(max(0, y - half)),
            int(min(w, x + half)),
            int(min(h, y + half)),
        ]
        for y in centres_y
        for x in centres_x
    ]


def run_sam_chunked_boxes(
    model: SAM, image, boxes, conf=0.35, boxes_per_call=25, device="cpu", binarize_thr=0.5
):
    """Segment one image with SAM using box prompts, sending them in chunks.

    Args:
        model: a loaded ``ultralytics.SAM`` model.
        image: 2-D grayscale image or 3-channel image; non-uint8 input is min-max
            rescaled to uint8 (a constant image becomes all zeros).
        boxes: sequence of ``[x1, y1, x2, y2]`` box prompts.
        conf: confidence threshold forwarded to ``model.predict``.
        boxes_per_call: number of boxes per ``model.predict`` call.
        device: device string forwarded to ``model.predict``.
        binarize_thr: threshold applied to ``masks.data`` to obtain boolean masks.

    Returns:
        list of boolean numpy masks, not de-duplicated.
    """
    rgb = np.dstack([image, image, image]) if image.ndim != 3 else image
    if rgb.dtype != np.uint8:
        lo = float(rgb.min())
        hi = float(rgb.max())
        if hi > lo:
            rgb = ((rgb - lo) / (hi - lo) * 255).astype(np.uint8)
        else:
            rgb = np.zeros_like(rgb, dtype=np.uint8)
    rgb = np.ascontiguousarray(rgb)

    collected = []
    with torch.inference_mode():
        for start in range(0, len(boxes), boxes_per_call):
            batch = boxes[start : start + boxes_per_call]
            # predict() returns one Results object per source image; we pass a single image.
            result = model.predict(source=rgb, bboxes=batch, conf=conf, device=device)[0]
            mask_obj = getattr(result, "masks", None)
            if mask_obj is None:
                continue
            binary = (mask_obj.data > binarize_thr).cpu().numpy()
            collected.extend(plane.astype(bool) for plane in binary)
        if device == "cuda":
            torch.cuda.empty_cache()
    return collected

## Analyse de l'IOU entre Slices


In [None]:
# Slice-to-slice comparison: the masks of each slice are merged into one binary
# foreground before comparing (no per-object tracking).
def match_masks_between_slices(masks_prev, masks_curr, iou_threshold=0.3):
    """IoU between the union of the masks of two slices.

    Args:
        masks_prev: list of boolean (H, W) masks for the first slice (may be empty).
        masks_curr: list of boolean (H, W) masks for the second slice (may be empty).
        iou_threshold: unused; kept so existing calls keep working.

    Returns:
        IoU of the two merged foregrounds, in [0, 1]. It is 0.0 when exactly one
        slice has foreground, and NaN when neither has any (the IoU is undefined).
    """
    if len(masks_prev) == 0 and len(masks_curr) == 0:
        return float("nan")
    # Slice shape, taken from whichever side has at least one mask.
    shape = np.shape(masks_prev[0] if len(masks_prev) else masks_curr[0])

    def _union(masks):
        # Pixel-wise OR of all masks; an empty list gives an all-False mask.
        if len(masks) == 0:
            return np.zeros(shape, dtype=bool)
        return np.logical_or.reduce([np.asarray(m, dtype=bool) for m in masks])

    fg_prev = _union(masks_prev)
    fg_curr = _union(masks_curr)

    union = np.logical_or(fg_prev, fg_curr).sum()
    if union == 0:
        return float("nan")
    return np.logical_and(fg_prev, fg_curr).sum() / union

In [None]:
# Multi-slice segmentation: recall the shape and dtype of the volume being analysed.
print("Volume: shape={}, dtype={}".format(vol.shape, vol.dtype))

In [None]:
# IoU of the SAM foreground between consecutive slices around the middle of the volume.
# The window covers the pairs (i, i + 1) for i in [mid - 10, mid + 10), i.e. 20 pairs.
IOU_HALF_WINDOW = 10
# Clamp so both i and i + 1 stay inside the volume on short stacks.
slice_beginning = max(0, mid - IOU_HALF_WINDOW)
slice_end = min(mid + IOU_HALF_WINDOW, len(vol) - 1)
iou_values = []

points, labels = make_grid_points(H, W, GRID_STRIDE, label=POINT_LABEL)

masks_prev = None
for i in range(slice_beginning, slice_end):
    print(f"Slice{i - slice_beginning}")
    # Slice i was already segmented as "current" in the previous iteration:
    # reuse it so every slice goes through SAM only once.
    if masks_prev is None:
        masks_prev = run_sam_chunked_points(
            model,
            vol[i],
            points,
            labels,
            conf=CONF_THR,
            points_per_call=POINTS_PER_CALL,
            device=device,
        )
    masks_curr = run_sam_chunked_points(
        model,
        vol[i + 1],
        points,
        labels,
        conf=CONF_THR,
        points_per_call=POINTS_PER_CALL,
        device=device,
    )
    iou = match_masks_between_slices(masks_prev, masks_curr)
    iou_values.append(iou)
    masks_prev = masks_curr


# IoU values obtained with the point-grid prompts
print("using point grids\n")
print(iou_values)

In [None]:
# Same consecutive-slice IoU analysis, prompting SAM with a grid of boxes instead of points.
BOX_SIZE = 64  # side of each square box prompt, in pixels

iou_values_boxes = []
boxes = make_grid_boxes(H, W, GRID_STRIDE, box_size=BOX_SIZE)

masks_prev = None
for i in range(slice_beginning, slice_end):
    print(f"[Boxes] Slice {i - slice_beginning}")
    # Reuse the previous iteration's "current" masks so each slice is segmented once.
    # Boxes are sent in chunks of the same size as the point prompts.
    if masks_prev is None:
        masks_prev = run_sam_chunked_boxes(
            model, vol[i], boxes, conf=CONF_THR, boxes_per_call=POINTS_PER_CALL, device=device
        )
    masks_curr = run_sam_chunked_boxes(
        model, vol[i + 1], boxes, conf=CONF_THR, boxes_per_call=POINTS_PER_CALL, device=device
    )
    iou = match_masks_between_slices(masks_prev, masks_curr)
    iou_values_boxes.append(iou)
    masks_prev = masks_curr

print("\nUsing bounding boxes:\n")
print(iou_values_boxes)

In [None]:
# IoU comparison between the point-grid and bounding-box prompting strategies.
# nanmean: a pair whose merged masks are both empty has an undefined IoU (NaN)
# and is left out of the average instead of turning it into NaN.
mean_iou_points = np.nanmean(iou_values)
mean_iou_boxes = np.nanmean(iou_values_boxes)

print(f"\nMoyenne IoU (points): {mean_iou_points:.4f}")
print(f"Moyenne IoU (boxes) : {mean_iou_boxes:.4f}")

# x axis: index z of the first slice of each compared pair (z, z + 1).
z_points = np.arange(slice_beginning, slice_beginning + len(iou_values))
z_boxes = np.arange(slice_beginning, slice_beginning + len(iou_values_boxes))

fig, ax = plt.subplots()
ax.plot(z_points, iou_values, label="Points grid")
ax.plot(z_boxes, iou_values_boxes, label="Bounding boxes")
ax.set_xlabel("Slice index z (pair z, z+1)")
ax.set_ylabel("IoU")
ax.legend()
ax.set_title("Comparaison IoU entre slices (points vs boxes)")
plt.show()

### STOP THE CODECARBON TRACKER AND WRITE ITS RESULTS

In [None]:
# Stop the tracker and append one (timestamp, emissions) row to a summary file.
# CodeCarbon itself writes its detailed report to "emissions.csv" in the working
# directory by default; appending rows with a different layout to that file would
# corrupt it, so the summary goes to a separate file with its own header.
emissions = tracker.stop()
EMISSIONS_SUMMARY_CSV = Path("emissions_summary.csv")
write_header = not EMISSIONS_SUMMARY_CSV.exists()
with open(EMISSIONS_SUMMARY_CSV, "a", newline="") as f:
    writer = csv.writer(f)
    if write_header:
        writer.writerow(["timestamp", "emissions"])
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    writer.writerow([timestamp, emissions])

In [None]:
from tqdm import tqdm


# ============================================================================
# SIMPLIFIED FUNCTION: SSIM BETWEEN CONSECUTIVE SLICES
# ============================================================================
def calculate_consecutive_ssim(vol, data_range=255):
    """Compute the SSIM between every pair of consecutive slices (i, i+1).

    Args:
        vol: 3-D volume indexed as (Z, Y, X).
        data_range: intensity range passed to skimage's SSIM. The default 255 suits
            8-bit data; use e.g. 65535 for uint16 volumes, or None to defer to
            skimage's own default.

    Returns:
        ssim_values: float array of length max(Z - 1, 0), where
            ssim_values[i] = SSIM(slice i, slice i+1).
        indices: np.arange(len(ssim_values)), the first slice of each pair.
    """
    n_slices = vol.shape[0]
    n_pairs = max(n_slices - 1, 0)  # fewer than 2 slices -> no pair to compare
    ssim_values = np.zeros(n_pairs)

    print(f"\n{'='*60}")
    print("üîÑ CALCUL SSIM ENTRE COUPES CONS√âCUTIVES")
    print(f"{'='*60}")
    print(f"Nombre de coupes: {n_slices}")
    print(f"Nombre de comparaisons: {n_pairs}")

    # Compute with a progress bar
    for i in tqdm(range(n_pairs), desc="Calcul SSIM", unit="comp"):
        ssim_values[i] = ssim(vol[i, :, :], vol[i + 1, :, :], data_range=data_range)

    indices = np.arange(n_pairs)

    if n_pairs == 0:
        # min()/argmin() in the summary below would raise on an empty array.
        return ssim_values, indices

    print("\n‚úÖ SSIM calcul√©s avec succ√®s!")
    print(f"   Nombre de valeurs: {len(ssim_values)}")
    print(
        f"   SSIM min: {ssim_values.min():.4f} (entre coupes {np.argmin(ssim_values)} et {np.argmin(ssim_values)+1})"
    )
    print(f"   SSIM max: {ssim_values.max():.4f}")
    print(f"   SSIM moyen: {ssim_values.mean():.4f}")

    return ssim_values, indices


# ============================================================================
# FUNCTION: SIMPLE SSIM PLOT
# ============================================================================
def plot_ssim_simple(ssim_values, indices):
    """Plot SSIM values with their mean as a dashed line (no detection, no statistics).

    Args:
        ssim_values: 1-D array of SSIM values.
        indices: x positions (first slice of each compared pair), same length.
    """
    ssim_values = np.asarray(ssim_values, dtype=float)
    fig, ax = plt.subplots(1, 1, figsize=(14, 6))

    # SSIM curve
    ax.plot(indices, ssim_values, "b-", linewidth=2, marker="o", markersize=3, label="SSIM")

    # Mean line (skipped for empty input, whose mean is undefined)
    if ssim_values.size:
        mean_ssim = np.mean(ssim_values)
        ax.axhline(
            y=mean_ssim, color="r", linestyle="--", linewidth=2, label=f"Moyenne = {mean_ssim:.4f}"
        )

    # Labels and title
    ax.set_xlabel("Index de coupe Z", fontsize=12, fontweight="bold")
    ax.set_ylabel("SSIM", fontsize=12, fontweight="bold")
    ax.set_title("SSIM entre coupes cons√©cutives", fontsize=14, fontweight="bold", pad=15)

    # Grid and legend
    ax.grid(True, alpha=0.3, linestyle="--")
    ax.legend(fontsize=11, loc="best")
    # SSIM lies in [-1, 1]: extend the axis below 0 only when negative values occur,
    # otherwise keep the [0, 1.05] view.
    y_min = min(0.0, float(np.nanmin(ssim_values)) - 0.05) if ssim_values.size else 0.0
    ax.set_ylim([y_min, 1.05])

    plt.tight_layout()
    plt.show()


# ============================================================================
# FUNCTION: SSIM WITH A GIVEN SLICE SPACING (DELTA)
# ============================================================================
def calculate_ssim_with_delta(vol, delta=1, data_range=255):
    """Compute SSIM between slices z and z + delta, for z = 0, delta, 2*delta, ...

    Args:
        vol: 3-D volume indexed as (Z, Y, X).
        delta: spacing between compared slices (1 = consecutive, 2 = every other, ...).
            Must be >= 1.
        data_range: intensity range passed to skimage's SSIM (default 255, for 8-bit data).

    Returns:
        ssim_values: array of SSIM values, one per compared pair.
        indices: array of the first slice index z of each pair.

    Raises:
        ValueError: if ``delta`` < 1 (the stepping would never advance past the end).
    """
    if delta < 1:
        raise ValueError(f"delta must be >= 1, got {delta}")

    n_slices = vol.shape[0]

    print(f"\n{'='*60}")
    print(f"üîÑ CALCUL SSIM AVEC DELTA={delta}")
    print(f"{'='*60}")
    print(f"Nombre de coupes: {n_slices}")

    # Every z with z + delta < n_slices, stepping by delta.
    starts = range(0, max(n_slices - delta, 0), delta)
    ssim_values = np.array(
        [
            ssim(vol[z, :, :], vol[z + delta, :, :], data_range=data_range)
            for z in tqdm(starts, desc=f"Calcul SSIM (delta={delta})", unit="comp")
        ]
    )
    indices = np.array(starts)

    print(f"‚úÖ SSIM calcul√©s: {len(ssim_values)} valeurs")

    return ssim_values, indices

In [None]:
# ============================================================================
# COMPUTE AND PLOT THE CONSECUTIVE-SLICE SSIM
# ============================================================================

# SSIM over the whole volume can be slow, so the result is cached on disk:
# load it when present, otherwise compute it and save it.
# NOTE: the cache is not tied to IMAGE_3D_PATH — delete it when the volume changes.
SSIM_CACHE = Path("ssim_consecutive.npy")
if SSIM_CACHE.exists():
    ssim_values = np.load(SSIM_CACHE)
    indices = np.arange(len(ssim_values))  # pair i compares slices i and i + 1
else:
    ssim_values, indices = calculate_consecutive_ssim(vol)
    np.save(SSIM_CACHE, ssim_values)
    print("\nüíæ R√©sultats sauvegard√©s: ssim_consecutive.npy")

# Simple plot: SSIM values and their mean only
plot_ssim_simple(ssim_values, indices)

In [None]:
# ============================================================================
# MANUAL ANALYSIS OF THE SSIM VALUES
# ============================================================================

# Weakest similarity between two consecutive slices.
min_ssim_idx = ssim_values.argmin()
print(
    f"\nSSIM minimum: {ssim_values[min_ssim_idx]:.4f} entre coupes {min_ssim_idx} et {min_ssim_idx+1}"
)

# Pairs whose SSIM falls below a threshold (only the first 10 are listed).
threshold = 0.85
low_ssim_indices = np.nonzero(ssim_values < threshold)[0]
print(f"\nNombre de SSIM < {threshold}: {len(low_ssim_indices)}")
if low_ssim_indices.size:
    print("Indices concern√©s:")
    for idx in low_ssim_indices[:10]:
        print(f"  Coupe {idx} ‚Üí {idx+1}: SSIM = {ssim_values[idx]:.4f}")

# Summary statistics
print("\nStatistiques:")
for stat_name, stat_value in (
    ("Moyenne", ssim_values.mean()),
    ("M√©diane", np.median(ssim_values)),
    ("√âcart-type", ssim_values.std()),
):
    print(f"  {stat_name}: {stat_value:.4f}")

In [None]:
# ============================================================================
# OPTIONAL: COMPARE SEVERAL DELTAS
# ============================================================================

# Each delta needs its own SSIM computation because it compares different pairs:
# delta=2 compares (0, 2), (2, 4), (4, 6)... and NOT (0, 1), (2, 3), (4, 5)...
# delta=1 is the consecutive-slice series already held in `ssim_values` / `indices`.
deltas_to_test = [2, 3, 4, 5, 10]


def _plot_ssim_panel(ax, inds, vals, delta):
    """Draw one SSIM-vs-Z panel, with its mean line, for the given delta."""
    vals = np.asarray(vals, dtype=float)
    ax.plot(inds, vals, "b-", linewidth=2, marker="o", markersize=3)
    ax.axhline(
        y=np.mean(vals),
        color="r",
        linestyle="--",
        linewidth=2,
        label=f"Moyenne = {np.mean(vals):.4f}",
    )
    ax.set_xlabel("Index Z", fontsize=11, fontweight="bold")
    ax.set_ylabel("SSIM", fontsize=11, fontweight="bold")
    ax.set_title(f"Delta = {delta} ({len(vals)} comparaisons)", fontsize=12, fontweight="bold")
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=10)
    # SSIM lies in [-1, 1]: extend the axis below 0 only when negative values occur.
    y_min = min(0.0, float(np.nanmin(vals)) - 0.05) if vals.size else 0.0
    ax.set_ylim([y_min, 1.05])


fig, axes = plt.subplots(len(deltas_to_test) + 1, 1, figsize=(14, 4 * (len(deltas_to_test) + 1)))

# First panel: delta = 1 (already computed)
_plot_ssim_panel(axes[0], indices, ssim_values, delta=1)

# Following panels: one per delta in deltas_to_test
for idx, delta in enumerate(deltas_to_test):
    ssim_vals, inds = calculate_ssim_with_delta(vol, delta=delta)
    _plot_ssim_panel(axes[idx + 1], inds, ssim_vals, delta)

plt.tight_layout()
plt.savefig("ssim_multiple_deltas.png", dpi=300, bbox_inches="tight")
plt.show()