In [1]:
import matplotlib
matplotlib.use('TkAgg')  # or 'Qt5Agg'

from ultralytics import SAM, YOLO
import matplotlib.pyplot as plt
import cv2
import numpy as np


In [2]:
# Load SAM model
sam_model = SAM('./sam2.1_b.pt')

# Load YOLO model
yolo_model = YOLO('./yolo11x-cls.pt')

# Load and preprocess the image
image_path = './image.png'
image = cv2.imread(image_path)
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


In [8]:
# Run SAM inference to get segmentation masks
sam_results = sam_model.predict(image_rgb)



0: 1024x1024 1 0, 1 1, 1 2, 1 3, 1 4, 1 5, 1 6, 1 7, 1 8, 1 9, 1 10, 1 11, 1 12, 1 13, 1 14, 1 15, 1 16, 1 17, 1 18, 3649.4ms
Speed: 4.9ms preprocess, 3649.4ms inference, 0.6ms postprocess per image at shape (1, 3, 1024, 1024)


In [19]:
masks = sam_results[0].masks  # Adjust based on actual output structure
# output_masks = masks.data.cpu().numpy()  # Ensure the mask is on the CPU and convert to numpy array
output_masks = masks.xy


range(0, 19)


In [23]:
# Create a figure
plt.figure(figsize=(10, 10))

# Display the original image
plt.imshow(image_rgb)

# Overlay each mask on the original image with transparency
for idx, mk in enumerate(output_masks):
    mask = output_masks[idx]
    plt.imshow(mask, cmap='jet', alpha=0.5)  # Use alpha for transparency

    # Convert mask to the correct type
    mask_uint8 = cv2.fillPoly(np.zeros(image_rgb.shape[:2], dtype=np.uint8), [mask.astype(np.int32)], 1)

    # Extract the segment from the original image
    segment = cv2.bitwise_and(image_rgb, image_rgb, mask=mask_uint8)

    # Run YOLO inference on the segment to get labels
    yolo_results = yolo_model.predict(segment)
    labels = yolo_results[0].names  # Adjust based on actual output structure

    # Display the label
    plt.text(10, 10 + idx * 20, f'Label {idx}: {labels[0]}', color='white', fontsize=12, backgroundcolor='black')

plt.title('Original Image with All Masks and Labels')
plt.axis('off')
plt.show()


0: 224x224 cleaver 0.03, space_shuttle 0.02, oboe 0.01, cassette 0.01, notebook 0.01, 86.5ms
Speed: 6.2ms preprocess, 86.5ms inference, 0.1ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 Band_Aid 0.33, digital_clock 0.04, rule 0.03, envelope 0.02, street_sign 0.02, 8.9ms
Speed: 4.4ms preprocess, 8.9ms inference, 0.1ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 digital_clock 0.08, digital_watch 0.07, Band_Aid 0.07, analog_clock 0.03, spotlight 0.03, 8.1ms
Speed: 4.4ms preprocess, 8.1ms inference, 0.0ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 jellyfish 0.32, sea_urchin 0.06, ping-pong_ball 0.04, digital_clock 0.03, nematode 0.03, 9.8ms
Speed: 4.9ms preprocess, 9.8ms inference, 0.0ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 Band_Aid 0.05, cleaver 0.05, loudspeaker 0.04, notebook 0.02, window_shade 0.02, 8.2ms
Speed: 4.5ms preprocess, 8.2ms inference, 0.0ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 Ba