In [7]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb

# Automatically generating object masks with SAM 2

Since SAM 2 can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image.

The class `SAM2AutomaticMaskGenerator` implements this capability. It works by sampling single-point input prompts in a grid over the image, from each of which SAM can predict multiple masks. Then, masks are filtered for quality and deduplicated using non-maximal suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or postprocessing masks to remove small disconnected regions and holes.

<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/automatic_mask_generator_example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Environment Set-up

If running locally in Jupyter, first install SAM 2 in your environment following the installation instructions in the repository.

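If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'. Note that it's recommended to use **A100 or L4 GPUs when running in Colab** (T4 GPUs might also work, but could be slow and might run out of memory in some cases). The batch pipeline at the end of this notebook also needs the local helper module `sam2MaskUtils` and the dataset under `../../../data/unified_set_rename/`. The set-up cell downloads neither, so you must provide them yourself.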

In [8]:
# Switch to True when running on Google Colab; the following cell then
# installs the dependencies and downloads the model checkpoint.
using_colab: bool = False

In [9]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam2.git'

    !mkdir -p images
    !wget -P images https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/cars.jpg

    !mkdir -p ../checkpoints/
    !wget -P ../checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

## Set-up

In [10]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from sklearn.cluster import KMeans

In [11]:
# Pick the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context and deliberately never exit it, so all
    # later cells of the notebook run under autocast.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # TF32 kernels exist from compute capability 8.x (Ampere) onward, see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

using device: cpu


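## Mask generation with the SAM 2 model

The cell below runs the automatic mask generator on each selected image. It merges all predicted masks into a single binary label, saves that label as a TIFF file, and records its F1 score against the reference label in a CSV file.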

In [12]:
# Star import of the project's helper module.
# NOTE(review): the processing loop later in this notebook uses `csv`, `tqdm`
# and `calculate_f1` without importing them anywhere visible, so they
# presumably come from this star import. Verify before replacing it with
# explicit imports.
from sam2MaskUtils import*

# --- SAM2 model initialisation ---

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Checkpoint weights and model config for SAM 2.1 Hiera "large"; both paths
# are relative to the notebook's working directory. The two presumably must
# describe the same model size -- keep them in sync when switching models.
sam2_checkpoint = "../../models/sam2/checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

print("Chargement du modèle SAM2...")
# `device` is the torch.device chosen in the device-selection cell.
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
# Automatic mask generator with its default settings (only the model is passed).
mask_generator = SAM2AutomaticMaskGenerator(sam2)
print("Modèle SAM2 chargé.")

# --- Select the images to process ---

image_dir = "../../../data/unified_set_rename/images/"
label_dir = "../../../data/unified_set_rename/labels/"
output_dir = "../../../data/sam2Labels00"
os.makedirs(output_dir, exist_ok=True)

# True: process only the image numbers listed in `missing_files` below.
# False: process every PNG found in `image_dir`.
process_only_missing = True

# Every PNG file in the image folder, sorted for a deterministic order.
all_image_files = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))

missing_files = ['00306', '00384', '00513', '01035', '01235', '01239', '01241', '01247', '01250', '01257', '01271', '01272', '01281', '01305', '01306', '01313', '01317', '01321', '01331', '01338', '01344', '01497', '01500', '01512', '01540', '01553', '01577', '01608', '01616', '01627', '01651', '01653', '01669', '01673', '01696', '01697', '01700', '01913', '01916', '01927', '01936', '01941', '02005', '02051', '02127', '02136', '02137', '02207']

# Previously the directory listing was computed and then silently
# overwritten; the choice between the two sources is now explicit.
if process_only_missing:
    image_files = [f"cell_{num}.png" for num in missing_files]
else:
    image_files = all_image_files

import csv  # explicit stdlib import instead of relying on the star import


def mean_iou(anns):
    """Return the mean ``predicted_iou`` over SAM2 annotations (0.0 if empty).

    ``anns`` is the list of annotation dicts returned by
    ``mask_generator.generate``; only each dict's ``"predicted_iou"`` entry
    is read.
    """
    if not anns:
        return 0.0
    return float(np.mean([ann["predicted_iou"] for ann in anns]))


def _union_of_masks(anns, shape):
    """Return the boolean union of every ``ann["segmentation"]`` in ``anns``.

    ``shape`` is the ``(height, width)`` of the result; an empty ``anns``
    gives an all-False mask.
    """
    union = np.zeros(shape, dtype=bool)
    for ann in anns:
        union |= ann["segmentation"]
    return union


# Per-image results are written to a CSV and flushed row by row, so partial
# results survive an interrupted run. The `with` block guarantees the file is
# closed even if the loop raises or is interrupted (e.g. KeyboardInterrupt).
csv_path = os.path.join(output_dir, "sam2_f1_scores_onlyfull.csv")
with open(csv_path, mode='w', newline='') as csv_file:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(["image_name","label_type","f1_score"])
    csv_file.flush()

    print(f"Nombre d'images à traiter : {len(image_files)}")

    for image_file in tqdm(image_files, desc="Traitement des images"):
        # Extract the image number, e.g. "cell_00225.png" -> "00225".
        base_name = os.path.splitext(image_file)[0]
        try:
            image_number = base_name.split('_')[1]
        except IndexError:
            print(f"Format inattendu pour le nom de fichier : {image_file}")
            continue

        print(f"\nTraitement de l'image {image_file} (numéro {image_number})")

        # Load the RGB image and its annotation label.
        image_path = os.path.join(image_dir, image_file)
        label_file = f"cell_{image_number}_label.tiff"
        label_path = os.path.join(label_dir, label_file)

        try:
            image = np.array(Image.open(image_path).convert("RGB"))
            imageSoluce = np.array(Image.open(label_path).convert("RGB"))
        except Exception as e:
            # Best effort: report and skip images whose files cannot be read.
            print(f"Erreur lors du chargement de l'image ou de l'annotation pour {image_file}: {e}")
            continue

        # SAM2 automatic mask generation. Colour-based filtering
        # (filter_masks_by_color_sam2) is currently disabled, so every
        # generated mask is kept.
        all_anns = mask_generator.generate(image)
        mean_all = mean_iou(all_anns)
        print(f"IoU prédite moyenne (tous masques)   = {mean_all:.3f}")

        chosen_anns = all_anns
        label_type = "full"
        print(f"--> On garde TOUS les masques (mean_iou={mean_all:.3f})")

        # Merge all retained masks into one binary mask, stored as 0/255 uint8
        # for saving and for the F1 computation.
        h, w, _ = image.shape
        agg_mask = _union_of_masks(chosen_anns, (h, w))
        agg_mask8 = agg_mask.astype(np.uint8) * 255

        f1 = calculate_f1(imageSoluce, agg_mask8)

        # Save the aggregated mask.
        output_filename = f"cell_{image_number}_label_sam2.tiff"
        output_path = os.path.join(output_dir, output_filename)
        Image.fromarray(agg_mask8).save(output_path)
        print(f"Masque agrégé sauvegardé → {output_filename}")

        # CSV row: image_name (the saved label file), label_type, f1_score.
        csv_writer.writerow([output_filename, label_type, f"{f1:.3f}"])
        csv_file.flush()

print(f"\nLes résultats ont été sauvegardés dans {csv_path}")

Chargement du modèle SAM2...
Modèle SAM2 chargé.
Nombre d'images à traiter : 48


Traitement des images:   0%|          | 0/48 [00:00<?, ?it/s]


Traitement de l'image cell_00306.png (numéro 00306)


Traitement des images:   0%|          | 0/48 [00:04<?, ?it/s]


KeyboardInterrupt: 