In [1]:
%matplotlib inline
import torch
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os

from segment_anything import sam_model_registry, SamPredictor
from ultralytics import YOLO



In [2]:
%matplotlib qt
#%wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In [3]:
print(os.getcwd())

/Users/romainmorin/Desktop/TN/3A/PI/Projet/src/TrainPI/src/segmentation


In [4]:
# Chemin vers les poids du modèle
sam_checkpoint = "sam_weights/sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()
    )


In [5]:
image_path = "../../data/content/helicopter.jpg"
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
model = YOLO("yolov8x.pt")
results = model(image, conf=0.25)

Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8x.pt to 'yolov8x.pt'...


100%|██████████| 131M/131M [00:06<00:00, 22.4MB/s] 



0: 512x640 1 airplane, 590.3ms
Speed: 2.1ms preprocess, 590.3ms inference, 1.4ms postprocess per image at shape (1, 3, 512, 640)


In [7]:
helicopter_box = None
for r in results:
    for box in r.boxes:
        cls_id = int(box.cls)
        cls_name = r.names[cls_id]
        if cls_name == "airplane":
            # Récupérer la bbox (x_min, y_min, x_max, y_max)
            x_min, y_min, x_max, y_max = box.xyxy[0].cpu().numpy().astype(int)
            helicopter_box = [x_min, y_min, x_max, y_max]
            break
    if helicopter_box is not None:
        break

if helicopter_box is None:
    print("Aucun hélicoptère détecté !")
else:
    # Utiliser SAM avec la boîte englobante
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    box_coords = np.array(helicopter_box)
    masks, scores, logits = predictor.predict(
        box=box_coords[None, :],  # (1,4)
        multimask_output=True
    )


N_POINTS_DESIRES = 1

input_points = []
input_labels = []

# Fonction de rappel pour événement de clic
def onclick(event):
    ix, iy = event.xdata, event.ydata
    if ix is not None and iy is not None:
        print(f'Point sélectionné : ({ix}, {iy})')
        input_points.append([ix, iy])
        input_labels.append(1)  # Indicateur positif

        # Marquer le point sur l'image
        ax.plot(ix, iy, 'ro')  # point rouge
        fig.canvas.draw()

        if len(input_points) == N_POINTS_DESIRES:
            fig.canvas.mpl_disconnect(cid)
            print("Nombre de points requis atteints.")

# Affichage de l'image et sélection du point
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.set_title("Cliquez sur l'hélicoptère pour sélectionner un point")
cid = fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

# Attente jusqu'à ce qu'un point soit sélectionné
while len(input_points) < N_POINTS_DESIRES:
    plt.pause(0.1)

input_point = np.array(input_points)
input_label = np.array(input_labels)


"""input_points = np.array([500, 300])
input_labels = np.array([1, 0])"""

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True
)

In [8]:
# Afficher les masques
fig, axes = plt.subplots(1, len(masks), figsize=(15, 5))
for i, (mask, score) in enumerate(zip(masks, scores)):
    axes[i].imshow(image)
    axes[i].imshow(mask, alpha=0.5)
    axes[i].set_title(f"Masque {i+1} - Score : {score:.3f}")
    axes[i].axis('off')
plt.show()

In [9]:
selected_mask = masks[0]

# Post-traitement du masque
# 1. Sélection de la plus grande composante connectée
mask_uint8 = selected_mask.astype(np.uint8)
num_labels, labels_im = cv2.connectedComponents(mask_uint8)

largest_label = 0
largest_area = 0
for label in range(1, num_labels):  # Ignorer l'arrière-plan (label=0)
    area = np.sum(labels_im == label)
    if area > largest_area:
        largest_area = area
        largest_label = label

final_mask = (labels_im == largest_label).astype(np.uint8)

In [10]:
kernel = np.ones((3, 3), np.uint8)
# Ouverture pour supprimer les petits artefacts isolés
final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_OPEN, kernel)
# Fermeture pour combler les petits trous à l'intérieur du masque
final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel)


# Application du masque sur l'image
masked_image = image.copy()
masked_image[final_mask == 0] = 0  # Mettre le fond à noir
cv2.imwrite("helicopter_masked.png", masked_image)
# Affichage de l'image masquée finale
plt.figure(figsize=(10, 10))
plt.imshow(masked_image)
plt.title("Hélicoptère isolé avec fond noir")
plt.axis('off')
plt.show()
