In [44]:
import os
import cv2
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image

In [45]:
# Select the compute device: CUDA if present, else Apple MPS, else CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "cuda":
    # Enter a bfloat16 autocast context and deliberately never exit it, so every
    # later cell in this notebook runs under bf16 autocast on CUDA.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Compute capability >= 8 (Ampere and newer) supports TensorFloat-32; enable it for
    # matmul and cuDNN (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_capability(0)[0] >= 8:
        torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

using device: cuda


In [46]:
def show_anns(anns, borders=True):
    """Overlay a list of mask annotations on the current matplotlib axes.

    Each annotation is painted in a random RGB color at alpha 0.5. Annotations are
    painted largest-``'area'``-first, so smaller masks end up drawn on top.

    Args:
        anns: list of dicts, each with at least an ``'area'`` key (used for sorting)
            and a ``'segmentation'`` key holding an (H, W) mask used as an index into
            the canvas — assumed boolean (as returned by
            ``SAM2AutomaticMaskGenerator.generate`` in this notebook); TODO confirm for
            other sources, since an integer 0/1 array would index rows instead.
        borders: if True, also draw each mask's external contour in blue (alpha 0.4).

    Returns:
        None. Draws onto ``plt.gca()``; returns immediately if ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    # RGBA float canvas sized like the masks, fully transparent until painted.
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask
        if borders:
            # cv2 is imported at the top of the notebook. [-2] selects the contour list
            # under both OpenCV 3.x (3 return values) and 4.x (2 return values).
            contours = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
            # approxPolyDP's epsilon is a distance in pixels; 0.01 barely simplifies the contour.
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(img)

def show_mask(mask, ax, random_color=False, borders=True):
    """Draw a single mask as a translucent RGBA overlay on ``ax``.

    Args:
        mask: array whose last two axes are (H, W); any leading axes must be
            singletons (e.g. shape (1, H, W)). It is converted to uint8 and multiplied
            into the color, so it is assumed to hold 0/1 values (bool or float) —
            TODO confirm for other sources.
        ax: axes object; only ``ax.imshow`` is called on it.
        random_color: if True, use a random RGB color; otherwise a fixed blue
            (30, 144, 255). Alpha is 0.6 in both cases.
        borders: if True, draw the mask's external contours in white (alpha 0.5,
            thickness 2) on top of the fill.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    # Collapse leading singleton axes so cv2.findContours below receives a 2-D image;
    # otherwise a (1, H, W) mask fails there when borders=True.
    mask = mask.reshape(h, w).astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        # cv2 is imported at the top of the notebook. [-2] selects the contour list
        # under both OpenCV 3.x (3 return values) and 4.x (2 return values).
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
        # approxPolyDP's epsilon is a distance in pixels; 0.01 barely simplifies the contour.
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    ``coords`` is an (N, 2) array of (x, y) positions and ``labels`` an (N,) array
    selecting which points belong to each group; points with other labels are skipped.
    """
    for wanted_label, point_color in ((1, 'green'), (0, 'red')):
        group = coords[labels == wanted_label]
        ax.scatter(group[:, 0], group[:, 1], color=point_color, marker='.', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    """Outline a box given as (x_min, y_min, x_max, y_max) on ``ax`` in green (2 px, no fill)."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Display each mask over ``image`` in its own 10x10 figure.

    Args:
        image: image passed to ``plt.imshow`` as the background.
        masks: iterable of masks, drawn one per figure via ``show_mask``.
        scores: per-mask scores, paired with ``masks``; when more than one score is
            given, each figure is titled with its 1-based index and score.
        point_coords: optional prompt points; requires ``input_labels``.
        box_coords: optional prompt box in (x_min, y_min, x_max, y_max) form.
        input_labels: labels for ``point_coords`` (1 = green, 0 = red).
        borders: forwarded to ``show_mask`` to toggle contour drawing.
    """
    for number, (mask, score) in enumerate(zip(masks, scores), start=1):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        axes = plt.gca()
        show_mask(mask, axes, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, axes)
        if box_coords is not None:
            show_box(box_coords, axes)
        if len(scores) > 1:
            plt.title(f"Mask {number}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()

In [47]:
# Candidate wing photographs from test set 1; all seven are opened, only wing_2 is used below.
# NOTE(review): PIL's Image.open is lazy, so the six unused images keep their file
# handles open until closed — consider opening only the image actually analysed.
wing_paths = [
    '../testdata/set_1/043870_L_O.JPG',
    '../testdata/set_1/043870_R_O.JPG',
    '../testdata/set_1/043874_L_O.JPG',
    '../testdata/set_1/043874_R_O.JPG',
    '../testdata/set_1/043878_L_O.JPG',
    '../testdata/set_1/043878_R_O.JPG',
    '../testdata/set_1/043884_L_O.JPG',
]
wing_1, wing_2, wing_3, wing_4, wing_5, wing_6, wing_7 = [Image.open(path) for path in wing_paths]

# Work on wing_2 as an RGB NumPy array (fed to predictor.set_image later).
wing = np.array(wing_2.convert("RGB"))

In [48]:
# Optional downscaling of `wing` before segmentation.
# Set RESIZE_SCALE to a fraction (e.g. 0.30) to shrink the image; None keeps full resolution.
RESIZE_SCALE = None
if RESIZE_SCALE is not None:
    new_width = int(wing.shape[1] * RESIZE_SCALE)
    new_height = int(wing.shape[0] * RESIZE_SCALE)
    # cv2.resize takes the target size as (width, height).
    wing = cv2.resize(wing, (new_width, new_height))

In [49]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Checkpoint location: set the SAM2_CHECKPOINT environment variable to point elsewhere
# instead of editing this machine-specific absolute path (kept as the fallback).
sam2_checkpoint = os.environ.get(
    "SAM2_CHECKPOINT",
    "/home/wsl/bin/segment-anything-2/checkpoints/sam2_hiera_large.pt",
)
# Presumably the config paired with the hiera-large checkpoint above — keep them in sync.
model_cfg = "sam2_hiera_l.yaml"

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

predictor = SAM2ImagePredictor(sam2_model)

In [None]:
# Register the selected wing image with the predictor; presumably this computes the
# image embedding that predictor.predict reuses below (SAM2ImagePredictor API) — confirm.
predictor.set_image(wing)

In [None]:
# Two foreground prompts (label 1) on the horizontal midline, at 1/3 and 2/3 of the width.
height, width = wing.shape[:2]
input_point = np.array([
    [1/3 * width, 1/2 * height],
    [2/3 * width, 1/2 * height],
])
input_label = np.array([1, 1])

In [None]:
# Preview the prompt points on the selected wing before running prediction.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(wing)
show_points(input_point, input_label, ax)
ax.axis('off')
plt.show()

In [None]:
# Single-mask prediction from the point prompts. The descending-score sort is a no-op
# for one mask but keeps the best mask first if multimask_output is switched to True.
mask, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=False,
)
best_first = np.argsort(scores)[::-1]
mask, scores, logits = mask[best_first], scores[best_first], logits[best_first]

In [None]:
# Overlay the predicted mask(s) and the prompt points on the wing, one figure per mask.
show_masks(wing, mask, scores, point_coords=input_point, input_labels=input_label)

In [None]:
# Keep only the wing pixels: zero out everything outside the top-ranked mask.
# mask is indexed per prediction along axis 0 (see the score sort above), so mask[0]
# selects the best mask explicitly. Unlike squeeze(), this still works when
# multimask_output=True and never drops a size-1 image axis.
mask2 = mask[0]
new_image = (mask2[..., None] * wing).astype('uint8')  # broadcast (H, W, 1) over the colour channels

In [None]:
# Show the wing with its background zeroed out by the predicted mask.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(new_image)
ax.axis('off')
plt.show()

In [None]:
# Sanity check: the masked image should keep the same shape as `wing`.
new_image.shape

In [None]:
import time

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Wall-clock timing of model construction plus mask generation. perf_counter is
# monotonic and high-resolution, unlike time.time, which can jump with clock changes.
start = time.perf_counter()

# A second model instance is built because apply_postprocessing=False differs from the
# predictor model above. It reuses sam2_checkpoint / model_cfg from the model-loading
# cell rather than duplicating the path here.
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

mask_generator = SAM2AutomaticMaskGenerator(sam2)

# Run on the background-masked image, so pixels outside the wing mask are zero.
masks = mask_generator.generate(new_image)

end = time.perf_counter()
print(round(end - start, 2), "seconds")

In [None]:
# Overlay every automatically generated mask on the original (unmasked) wing image.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(wing)
show_anns(masks)
ax.axis('off')
plt.show()