In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb

# Automatically generating object masks with SAM 2

Since SAM 2 can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image.

The class `SAM2AutomaticMaskGenerator` implements this capability. It works by sampling single-point input prompts in a grid over the image, from each of which SAM can predict multiple masks. Then, masks are filtered for quality and deduplicated using non-maximal suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or postprocessing masks to remove small disconnected regions and holes.

<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/automatic_mask_generator_example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Environment Set-up

If running locally using Jupyter, first install `SAM 2` in your environment using the installation instructions in the repository.

If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'. Note that it's recommended to use **A100 or L4 GPUs when running in Colab** (T4 GPUs might also work, but could be slow and might run out of memory in some cases).

In [None]:
# Flip to True when running in Google Colab; the next cell then installs SAM 2 and downloads assets.
using_colab: bool = False

In [None]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam2.git'

    !mkdir -p images
    !wget -P images https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/cars.jpg

    !mkdir -p ../checkpoints/
    !wget -P ../checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

## Set-up

In [None]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from sklearn.cluster import KMeans

In [None]:
def _select_device() -> torch.device:
    """Return the preferred compute device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = _select_device()
print(f"using device: {device}")

if device.type == "cuda":
    # Enter a bfloat16 autocast context and deliberately never exit it, so every
    # later cell in the notebook runs under bfloat16 autocast.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Compute capability >= 8 (Ampere and newer) supports TF32; see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    ampere_or_newer = torch.cuda.get_device_properties(0).major >= 8
    if ampere_or_newer:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    mps_warning = (
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
    print(mps_warning)

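## Mask generation with the SAM 2 model

The cell below processes each PNG image in `image_dir` as follows:

- It generates masks with `SAM2AutomaticMaskGenerator`.
- It builds a second version that keeps only the masks passing a colour filter.
- It scores both versions against the ground-truth label with an F1 score.
- It saves the higher-scoring one to `output_dir`; ties favour the unfiltered masks.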

In [None]:
# Star import kept so that later cells relying on other helpers keep working.
# This cell uses: filter_masks_by_color_sam2, create_grayscale_mask,
# create_grayscale_mask_sam2 and calculate_f1.
from sam2MaskUtils import *
# `tqdm` is used in the loop below but was never imported explicitly (at best it
# leaked in through the star import); import it so the dependency is visible.
from tqdm.auto import tqdm

# --- SAM2 model initialisation ---

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

sam2_checkpoint = "../models/sam2/checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

print("Chargement du modèle SAM2...")
# `device` comes from the device-selection cell earlier in the notebook.
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(sam2)
print("Modèle SAM2 chargé.")


def _to_uint8_mask(mask: np.ndarray) -> np.ndarray:
    """Convert a mask array to uint8 so PIL can save it.

    uint8 input is returned unchanged. Boolean or [0, 1]-valued input is scaled
    by 255. Input whose values already exceed 1 is clipped to [0, 255] instead:
    multiplying such values by 255 and casting to uint8 would silently wrap around.
    """
    if mask.dtype == np.uint8:
        return mask
    values = mask.astype(np.float64)
    if values.size and values.max() > 1.0:
        return np.clip(values, 0, 255).astype(np.uint8)
    return np.clip(values * 255, 0, 255).astype(np.uint8)


# --- Iterate over the images in the folder ---

image_dir = "../../data/unified_set_rename/images/"
label_dir = "../../data/unified_set_rename/labels/"
output_dir = "../../data/sam2Labels"
os.makedirs(output_dir, exist_ok=True)

# Keep only PNG files; the extension check is case-insensitive so "*.PNG" is not skipped.
image_files = sorted(f for f in os.listdir(image_dir) if f.lower().endswith('.png'))

print(f"Nombre d'images à traiter : {len(image_files)}")

for image_file in tqdm(image_files, desc="Traitement des images"):
    # Extract the image number from the file name, e.g. "cell_00225.png" -> "00225".
    base_name = os.path.splitext(image_file)[0]  # "cell_00225"
    try:
        image_number = base_name.split('_')[1]
    except IndexError:
        print(f"Format inattendu pour le nom de fichier : {image_file}")
        continue

    print(f"\nTraitement de l'image {image_file} (numéro {image_number})")

    # Load the image and its ground-truth annotation.
    image_path = os.path.join(image_dir, image_file)
    label_file = f"cell_{image_number}_label.tiff"
    label_path = os.path.join(label_dir, label_file)

    try:
        # Context managers close the underlying files once the pixels are copied.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        # NOTE(review): .convert("RGB") on a TIFF label may clip >8-bit label values;
        # confirm calculate_f1 expects an RGB array.
        with Image.open(label_path) as pil_label:
            imageSoluce = np.array(pil_label.convert("RGB"))
    except Exception as e:
        # Best-effort batch: report the problematic pair and move on.
        print(f"Erreur lors du chargement de l'image ou de l'annotation pour {image_file}: {e}")
        continue

    # Generate masks with SAM2.
    masks = mask_generator.generate(image)
    # Keep only the masks whose colour falls in the given hue/saturation window.
    filtered_masks = filter_masks_by_color_sam2(
        image,
        masks,
        hue_range=(120, 170),
        sat_threshold=30,
        min_masks=3
    )

    # Build both candidate segmentations.
    # NOTE(review): the unfiltered path uses create_grayscale_mask(masks) while the
    # filtered path uses create_grayscale_mask_sam2(..., image.shape); confirm both
    # return the same format/range, otherwise the F1 comparison below is skewed.
    all_masks_display = create_grayscale_mask(masks) if len(masks) > 0 else None
    filtered_display = create_grayscale_mask_sam2(filtered_masks, image.shape) if len(filtered_masks) > 0 else None

    # Score each candidate against the ground truth (0 when no mask was produced).
    f1_all = calculate_f1(imageSoluce, all_masks_display) if all_masks_display is not None else 0
    f1_filtered = calculate_f1(imageSoluce, filtered_display) if filtered_display is not None else 0

    print(f"Image {image_number} : F1 (tous les masques) = {f1_all:.3f} ; F1 (masques filtrés) = {f1_filtered:.3f}")

    # Pick the better candidate; ties go to the unfiltered masks.
    if f1_all >= f1_filtered:
        chosen_mask = all_masks_display
        print(f"--> Choix de la segmentation avec tous les masques pour l'image {image_number}.")
    else:
        chosen_mask = filtered_display
        print(f"--> Choix de la segmentation avec masques filtrés pour l'image {image_number}.")

    # Save the chosen mask.
    if chosen_mask is not None:
        chosen_mask = _to_uint8_mask(chosen_mask)
        output_path = os.path.join(output_dir, f"cell_{image_number}_label_sam2.tiff")
        Image.fromarray(chosen_mask).save(output_path)
        print(f"Résultat sauvegardé : {output_path}")
    else:
        print(f"Aucun masque généré pour l'image {image_number}.")
