In [1]:
import cv2
import time
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from segmentation import Sam2HF

In [2]:
# Create the Sam2HF helper; a later cell calls its generate_masks_from_prompt
# method with the points and boxes collected in the OpenCV window.
# NOTE(review): presumably this loads model weights and may be slow — confirm.
mask_gen = Sam2HF()

In [None]:
# Read the sample picture straight into an RGB NumPy array.
image = np.array(Image.open("no-smile-1024.png").convert("RGB"))

# Quick inline preview: 5x5 inch figure, no axis ticks or frame.
plt.figure(figsize=(5, 5))
plt.gca().imshow(image)
plt.gca().set_axis_off()
plt.show()

In [4]:
# Shared state for the OpenCV annotation window (read and updated by `draw`).

# Collected prompts: point coordinates, one label per point, and boxes.
clicked, labels, rectangles = [], [], []

# Current tool; 'point' is the default, 'rectangle' is the alternative.
mode = 'point'

# Rectangle tool: anchor corner of the drag and whether a drag is in progress.
ix = iy = -1
drawing = False

# Point tool: timestamp of the last point added while dragging, and the
# minimum number of seconds between two such points.
last_point_time = 0
delay = 0.2
 
# Mouse callback function
def _add_point(x, y, label, color):
    """Record one prompt point and paint its marker on the preview.

    Appends ``[x, y]`` to ``clicked`` and ``label`` to ``labels`` (keeping the
    two lists index-aligned), draws a filled circle of radius 5 in ``color``
    on ``show_image`` and refreshes the 'image' window.
    """
    clicked.append([x, y])
    labels.append(label)
    cv2.circle(show_image, (x, y), 5, color, -1)
    cv2.imshow('image', show_image)


def draw(event, x, y, flags, param):
    """OpenCV mouse callback that collects point and box prompts.

    Behaviour depends on the global ``mode``:

    * ``'point'``: a left click adds a point labelled 1, a middle click adds
      a point labelled 0, and dragging with the left button held adds further
      points labelled 1, at most one every ``delay`` seconds.
    * ``'rectangle'``: press, drag and release the left button to add a box
      to ``rectangles`` as ``[x_min, y_min, x_max, y_max]``.

    Args:
        event: OpenCV mouse event code (``cv2.EVENT_*``).
        x: Cursor x coordinate reported by OpenCV.
        y: Cursor y coordinate reported by OpenCV.
        flags: OpenCV event flags; used to detect a held left button.
        param: Unused user data required by the callback signature.
    """
    global ix, iy, drawing, last_point_time

    current_time = time.time()

    if mode == 'point':
        if event == cv2.EVENT_LBUTTONDOWN:
            _add_point(x, y, 1, (0, 255, 0))
            # Arm the throttle at the click itself; otherwise the very first
            # mouse-move after the click records a near-duplicate point.
            last_point_time = current_time
        elif event == cv2.EVENT_MBUTTONDOWN:
            _add_point(x, y, 0, (0, 0, 255))
        elif event == cv2.EVENT_MOUSEMOVE:
            left_held = flags & cv2.EVENT_FLAG_LBUTTON
            if left_held and current_time - last_point_time >= delay:
                _add_point(x, y, 1, (0, 255, 0))
                last_point_time = current_time
    elif mode == 'rectangle':
        if event == cv2.EVENT_LBUTTONDOWN:
            drawing = True
            ix, iy = x, y
        elif event == cv2.EVENT_MOUSEMOVE:
            if drawing:
                # Rubber-band preview on a copy so the stored image stays clean.
                preview = show_image.copy()
                cv2.rectangle(preview, (ix, iy), (x, y), (0, 255, 0), 2)
                cv2.imshow('image', preview)
        elif event == cv2.EVENT_LBUTTONUP:
            if not drawing:
                # Release without a matching press in rectangle mode (e.g. the
                # mode was switched mid-click): there is no valid anchor.
                return
            drawing = False
            # Normalise so the box is the same whichever way it was dragged.
            x_min, x_max = sorted((ix, x))
            y_min, y_max = sorted((iy, y))
            if x_min == x_max or y_min == y_max:
                # Zero-area box (click without drag): discard it and erase
                # any leftover preview.
                cv2.imshow('image', show_image)
                return
            cv2.rectangle(show_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
            cv2.imshow('image', show_image)
            rectangles.append([x_min, y_min, x_max, y_max])

In [None]:
# Load an image
image = Image.open("no-smile-1024.png")
# cv2.imshow interprets channels as BGR while PIL delivers RGB(A); convert so
# the preview shows true colours. `image` stays a PIL image because a later
# cell calls image.convert("RGB") on it.
show_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)

# The preview starts without markers, so drop prompts left over from a
# previous run of this cell; otherwise invisible stale prompts would still be
# used for mask generation.
clicked.clear()
labels.clear()
rectangles.clear()
drawing = False

# Register the callback only once show_image exists, so an early mouse event
# cannot hit an undefined name.
cv2.namedWindow('image')
cv2.setMouseCallback('image', draw)
cv2.imshow('image', show_image)

# Press 'p' to switch to point mode, 'r' to switch to rectangle mode, 'q' to quit
try:
    while True:
        # The callback refreshes the window itself; redrawing here on every
        # iteration would wipe the rectangle rubber-band preview.
        key = cv2.waitKey(20) & 0xFF
        if key == ord('p'):
            mode = 'point'
            cv2.imshow('image', show_image)  # clear any unfinished preview
            print("Switched to point mode")
        elif key == ord('r'):
            mode = 'rectangle'
            cv2.imshow('image', show_image)
            print("Switched to rectangle mode")
        elif key == ord('q'):
            break
        # Leave the loop when the window is closed via its title bar instead
        # of spinning forever on a window that no longer exists.
        if cv2.getWindowProperty('image', cv2.WND_PROP_VISIBLE) < 1:
            break
finally:
    # Also runs on a kernel interrupt, so no orphaned window is left behind.
    cv2.destroyAllWindows()

In [6]:
# Turn the three prompt lists gathered by the mouse callback into NumPy arrays.
input_point, input_label, input_rectangles = (
    np.array(prompt) for prompt in (clicked, labels, rectangles)
)

In [7]:
# The model input is the RGB pixel data of the PIL image loaded earlier.
image_array = np.array(image.convert("RGB"))

# Run the prompt-based mask generation with every collected prompt.
# The third returned value is unpacked as `unkn` and not used in this section.
masks, scores, unkn = mask_gen.generate_masks_from_prompt(
    image_array,
    input_points=input_point,
    input_labels=input_label,
    input_rectangles=input_rectangles,
)

In [8]:
# Render the predicted masks into one image via Sam2HF.get_rgba_mask, with
# borders enabled.
# NOTE(review): the helper's name suggests a 4-channel RGBA result even though
# the variable is called rgb_mask; later cells compare it against 4-element
# vectors — confirm shape and value range.
rgb_mask = Sam2HF.get_rgba_mask(masks=masks, borders=True)

In [9]:
# Show the rendered mask in an OpenCV window; blocks until a key is pressed.
try:
    cv2.imshow('Image', rgb_mask)
    cv2.waitKey(0)
finally:
    # Close the window even if the wait is interrupted from the notebook.
    cv2.destroyAllWindows()

In [10]:
# Blend the rendered mask over the source image and preview the result.
# NOTE(review): image_array holds RGB data while cv2.imshow interprets
# channels as BGR — confirm whether the overlay needs a channel swap for a
# colour-accurate preview.
final_image = Sam2HF.image_overlay(image_array, rgb_mask)
try:
    cv2.imshow('Image', final_image)
    cv2.waitKey(0)  # blocks until a key is pressed
finally:
    # Close the window even if the wait is interrupted from the notebook.
    cv2.destroyAllWindows()

In [11]:
# Collapse the coloured mask to black and white: every pixel that has any
# non-zero channel becomes fully opaque white, all others stay as they are
# (all zeros). The decision is made per pixel, not per channel; a per-channel
# test would leave the zero channels of a coloured pixel at zero and produce a
# tinted pixel instead of a white one.
# Assumes the last axis of rgb_mask holds the 4 RGBA channels — TODO confirm.
rgb_mask_bw = np.where(
    np.any(rgb_mask != 0, axis=-1, keepdims=True),
    np.array([1., 1., 1., 1.]),
    rgb_mask,
)

In [33]:
# Show the black-and-white mask in an OpenCV window; blocks until a key is pressed.
try:
    cv2.imshow('Image', rgb_mask_bw)
    cv2.waitKey(0)
finally:
    # Close the window even if the wait is interrupted from the notebook.
    cv2.destroyAllWindows()

In [12]:
# Scale the mask by 255 and cast to 8-bit unsigned integers for PIL.
# NOTE(review): assumes rgb_mask_bw values lie in [0, 1]; larger values would
# wrap around in the uint8 cast — confirm.
rgba_array = (rgb_mask_bw * 255).astype(np.uint8)

In [13]:
# Wrap the 8-bit array in a PIL image, interpreting it as RGBA.
pil_mask = Image.fromarray(rgba_array, 'RGBA')

In [None]:
# Open the mask with PIL's Image.show() (an external viewer, outside the notebook).
pil_mask.show()

In [None]:
# Convert the mask to RGB; as the cell's last expression the result is
# rendered inline by the notebook. The converted image is not stored.
pil_mask.convert("RGB")