In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install nibabel
!pip install scipy
!pip install matplotlib

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-o190k8ad
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-o190k8ad
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment_anything
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment_anything: filename=segment_anything-1.0-py3-none-any.whl size=36592 sha256=488ee06d520fbc5f79728146fe6f4220f0f6f9a682efd68f86e2ef86345b7e5a
  Stored in directory: /tmp/pip-ephem-wheel-cache-pbrufgzt/wheels/29/82/ff/04e2be9805a1cb48bec0b85b5a6da6b63f647645750a0e42d4
Successfully built segment_anything
Installing collected packages: segment_anything
Successfully 

In [None]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth #importer les poids

--2025-11-15 13:55:56--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 65.8.76.77, 65.8.76.89, 65.8.76.35, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|65.8.76.77|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1249524607 (1.2G) [binary/octet-stream]
Saving to: ‘sam_vit_l_0b3195.pth’


2025-11-15 13:56:05 (136 MB/s) - ‘sam_vit_l_0b3195.pth’ saved [1249524607/1249524607]



In [None]:
# Mount the user's Google Drive at /content/drive (Colab only).
from google.colab import drive # connect to Google Drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
from pathlib import Path
import numpy as np
import nibabel as nib
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from scipy.spatial.distance import directed_hausdorff
import matplotlib.pyplot as plt

In [None]:
torch.cuda.is_available() # check that a GPU can be used

True

In [None]:
def load_nii(path):
    """Load a NIfTI file and return its voxel data together with the image object.

    Args:
        path: Location of the file; it is converted with ``str`` before being
            handed to nibabel.

    Returns:
        tuple: ``(volume, image)`` where ``volume`` is the float32 voxel array
        with its three axes reversed (axis 2 first, axis 0 last) and ``image``
        is the object returned by ``nib.load`` (callers use its ``affine``).
    """
    image = nib.load(str(path))
    volume = image.get_fdata(dtype=np.float32)
    # Reverse the axis order so the array is indexed as (Z, Y, X), slice index first.
    return volume.transpose(2, 1, 0), image

def zscore(x, eps=1e-6):
    """Standardise ``x`` with its global mean and standard deviation.

    Args:
        x: Array exposing ``mean()`` and ``std()`` (statistics are taken over
            all elements).
        eps: Small constant added to the standard deviation so a constant
            input does not divide by zero.

    Returns:
        Array of the same shape as ``x`` holding ``(x - mean) / (std + eps)``.
    """
    centered = x - x.mean()
    return centered / (x.std() + eps)

def dice_score(y_true, y_pred):
    """Dice coefficient between two masks.

    Computed as ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred) + 1e-6)``.
    The small constant keeps the ratio defined when both masks are empty, in
    which case the score is 0.

    Args:
        y_true: Reference mask (the notebook passes 0/1 arrays).
        y_pred: Predicted mask with a shape compatible with ``y_true``.

    Returns:
        The Dice score as a float.
    """
    overlap = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    return 2. * overlap / (total + 1e-6)

def iou_score(y_true, y_pred):
    """Intersection-over-union (Jaccard index) between two masks.

    The intersection is ``sum(y_true * y_pred)`` and the union is
    ``sum(y_true) + sum(y_pred) - intersection``; ``1e-6`` is added to the
    denominator so two empty masks yield 0 instead of a division by zero.

    Args:
        y_true: Reference mask (the notebook passes 0/1 arrays).
        y_pred: Predicted mask with a shape compatible with ``y_true``.

    Returns:
        The IoU as a float.
    """
    overlap = np.sum(y_true * y_pred)
    combined = np.sum(y_true) + np.sum(y_pred) - overlap
    return overlap / (combined + 1e-6)

In [None]:
# Dataset root (the cells below treat each entry of ROOT as one patient folder)
# and the folder that receives the predicted SAM masks; it is created if missing.
ROOT = Path("/content/drive/MyDrive/projet_ML_CS/train")
OUT_DIR = Path("/content/drive/MyDrive/projet_ML_CS/sam_predictions")
OUT_DIR.mkdir(exist_ok=True, parents=True)

# SAM variant (key of sam_model_registry) and its checkpoint file; the relative
# path is resolved against the current working directory.
SAM_MODEL = "vit_l"
SAM_CKPT = "sam_vit_l_0b3195.pth"
# Use the GPU when one is attached, otherwise fall back to the CPU instead of
# crashing when the model is moved to the device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Build the SAM variant selected by SAM_MODEL from the checkpoint and move it to DEVICE.
# Module.to returns the module itself, so the chained call binds the same model object.
sam = sam_model_registry[SAM_MODEL](checkpoint=SAM_CKPT).to(device=DEVICE)
# Prompt-free mask generator wrapping the model, with its default settings.
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
# --- Smoke test on the first patient folder
# Only directories are considered, so a stray file in ROOT is never mistaken for a patient,
# and an empty ROOT produces an explicit error instead of a bare IndexError.
patient_dirs = sorted(p for p in ROOT.glob("*") if p.is_dir())
if not patient_dirs:
    raise FileNotFoundError(f"Aucun dossier patient trouvé dans {ROOT}")
patient_dir = patient_dirs[0]  # test on the first patient
print(f"Patient : {patient_dir.name}")

flair_path = next(patient_dir.glob("*flair*.nii*"), None)
seg_path = next(patient_dir.glob("*seg*.nii*"), None)
# Fail early with a clear message rather than letting nibabel try to open the path "None".
if not flair_path or not seg_path:
    raise FileNotFoundError(f"Données manquantes pour {patient_dir.name}")

flair, ref_img = load_nii(flair_path)
seg, _ = load_nii(seg_path)
# The scores below multiply the two volumes element-wise; check the shapes before
# spending time on the SAM pass.
if flair.shape != seg.shape:
    raise ValueError(
        f"Formes incompatibles pour {patient_dir.name} : flair {flair.shape} vs seg {seg.shape}"
    )
flair = zscore(flair)
seg = (seg > 0).astype(np.uint8)  # every non-zero label becomes foreground

pred_mask = np.zeros_like(flair, dtype=np.uint8)

# --- SAM, slice by slice along the first axis
for z in range(flair.shape[0]):
    slice2d = flair[z]

    # Per-slice min-max rescaling to [0, 255]; the epsilon protects constant slices.
    slice2d_uint8 = ((slice2d - slice2d.min()) / (slice2d.max() - slice2d.min() + 1e-6) * 255).astype(np.uint8)

    # Replicate the grey slice into 3 channels (H, W, 3) before handing it to SAM.
    slice2d_3c = np.stack([slice2d_uint8]*3, axis=-1)

    masks = mask_generator.generate(slice2d_3c)

    print(f"Slice {z} : {len(masks)} masques générés")

    if masks:
        # Keep only the mask covering the most pixels.
        largest = max(masks, key=lambda x: x['segmentation'].sum())
        pred_mask[z] = largest['segmentation'].astype(np.uint8)

        # Debug display: the selected mask overlaid on the slice.
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(slice2d_uint8, cmap='gray')
        ax.imshow(largest['segmentation'], alpha=0.5)
        ax.set_title(f"Slice {z} - masque superposé")
        plt.show()
        # One figure is created per slice; release it explicitly so they do not accumulate.
        plt.close(fig)

# --- Save the mask, transposed back to the on-disk axis order, with the FLAIR affine
out_path = OUT_DIR / f"{patient_dir.name}_sam_mask.nii.gz"
nib.save(nib.Nifti1Image(np.transpose(pred_mask, (2,1,0)), affine=ref_img.affine), str(out_path))
print(f"Masque sauvegardé : {out_path}")

# --- Scores against the binarised ground truth
dice = dice_score(seg, pred_mask)
iou = iou_score(seg, pred_mask)
print(f"Dice={dice:.3f}, IoU={iou:.3f}")

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Per-patient scores, one dict per successfully processed patient.
results = []

# Loop over every patient folder (or a subset for testing)
for patient_dir in sorted(ROOT.glob("*")):
    # Stray files sitting next to the patient folders are not patients; skip them silently.
    if not patient_dir.is_dir():
        continue

    print(f"\n=== Traitement du patient : {patient_dir.name} ===")

    flair_path = next(patient_dir.glob("*flair*.nii*"), None)
    seg_path = next(patient_dir.glob("*seg*.nii*"), None)

    if not flair_path or not seg_path:
        print(f" Données manquantes pour {patient_dir.name}", flush=True)
        continue

    # --- Load the volumes
    flair, ref_img = load_nii(flair_path)
    seg, _ = load_nii(seg_path)

    print(f"Flair shape : {flair.shape}, min/max : {flair.min():.3f}/{flair.max():.3f}")
    print(f"Seg shape : {seg.shape}, min/max : {seg.min():.3f}/{seg.max():.3f}")

    # The scores multiply the two volumes element-wise. A mismatch would otherwise only
    # surface after the SAM pass and abort the whole loop, so skip the patient now.
    if flair.shape != seg.shape:
        print(f" Formes incompatibles pour {patient_dir.name} : flair {flair.shape} vs seg {seg.shape}", flush=True)
        continue

    flair = zscore(flair)
    seg = (seg > 0).astype(np.uint8)  # every non-zero label becomes foreground
    pred_mask = np.zeros_like(flair, dtype=np.uint8)

    # --- SAM, slice by slice along the first axis
    for z in range(flair.shape[0]):
        slice2d = flair[z]

        # Per-slice min-max rescaling to [0, 255]; the epsilon protects constant slices.
        slice2d_uint8 = ((slice2d - slice2d.min()) / (slice2d.max() - slice2d.min() + 1e-6) * 255).astype(np.uint8)

        # Replicate the grey slice into 3 channels (H, W, 3) before handing it to SAM.
        slice2d_3c = np.stack([slice2d_uint8]*3, axis=-1)

        masks = mask_generator.generate(slice2d_3c)
        print(f"Slice {z} : {len(masks)} masques générés", flush=True)

        if masks:
            # Keep only the mask covering the most pixels.
            largest = max(masks, key=lambda x: x['segmentation'].sum())
            pred_mask[z] = largest['segmentation'].astype(np.uint8)

    # --- Save the mask, transposed back to the on-disk axis order, with the FLAIR affine
    out_path = OUT_DIR / f"{patient_dir.name}_sam_mask.nii.gz"
    nib.save(nib.Nifti1Image(np.transpose(pred_mask, (2,1,0)), affine=ref_img.affine), str(out_path))
    print(f"Masque sauvegardé : {out_path}")

    # --- Scores against the binarised ground truth
    dice = dice_score(seg, pred_mask)
    iou = iou_score(seg, pred_mask)
    print(f"{patient_dir.name} : Dice={dice:.3f}, IoU={iou:.3f}")
    results.append({"patient": patient_dir.name, "dice": dice, "iou": iou})

# --- Overall summary (unweighted mean over the processed patients)
if results:
    mean_dice = np.mean([r["dice"] for r in results])
    mean_iou = np.mean([r["iou"] for r in results])
    print(f"\nDice moyen : {mean_dice:.3f}, IoU moyen : {mean_iou:.3f}")
else:
    print(" Aucun résultat à calculer", flush=True)