In [None]:
import itertools
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

In [None]:
def show_anns(anns):
    """Paint every annotation mask as a translucent, randomly coloured overlay.

    The overlay is drawn on the current matplotlib axes (``plt.gca()``).

    Args:
        anns: Sequence of dicts. Each dict must provide ``'area'`` (the sort
            key) and ``'segmentation'`` (an array with at least two dimensions
            that is used to index the overlay -- presumably a boolean H x W
            mask; confirm against the producer).

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) < 1:
        return

    # Largest masks are painted first so that smaller ones end up on top.
    by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)

    axes = plt.gca()
    axes.set_autoscale_on(False)

    first_mask = by_area[0]['segmentation']
    overlay = np.ones((first_mask.shape[0], first_mask.shape[1], 4))
    overlay[:, :, 3] = 0  # alpha channel: start fully transparent
    for ann in by_area:
        mask = ann['segmentation']
        # Random RGB colour with a fixed 0.35 alpha.
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[mask] = rgba
    axes.imshow(overlay)

In [None]:
# Load the sample image. cv2.imread does not raise on a missing/unreadable
# file -- it returns None -- so check explicitly to fail with a clear message
# instead of an obscure cvtColor error.
image_path = '../data/images/000_img.png'
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# OpenCV decodes to BGR channel order; convert to RGB for matplotlib and SAM.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [None]:
# Preview the loaded image on a large canvas with the axis decorations hidden.
preview_fig, preview_ax = plt.subplots(figsize=(20, 20))
preview_ax.imshow(image)
preview_ax.axis('off')
plt.show()

In [None]:
import glob

# glob.glob returns paths in arbitrary (filesystem-dependent) order; sort them
# so that every run of the notebook sees the images in the same sequence.
images = sorted(glob.glob('../data/images/*'))

In [None]:
import sys

# Add the parent directory to the import path (presumably so that the
# segment_anything package can be imported from there). The guard keeps
# sys.path from growing each time this cell is re-run.
if ".." not in sys.path:
    sys.path.append("..")

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "../sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA instead of failing in sam.to(...).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model from the checkpoint and move it to the selected device.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Mask generator with the library's default settings.
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
def mask_generator_factory(
    points_per_side,
    pred_iou_thresh,
    stability_score_thresh,
    crop_n_layers,
    crop_n_points_downscale_factor,
    min_mask_region_area,
):
    """Build a ``SamAutomaticMaskGenerator`` for the module-level ``sam`` model.

    Every argument is forwarded unchanged, as a keyword argument of the same
    name, to ``SamAutomaticMaskGenerator``.

    Returns:
        The newly constructed ``SamAutomaticMaskGenerator`` instance.
    """
    generator_options = dict(
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        crop_n_layers=crop_n_layers,
        crop_n_points_downscale_factor=crop_n_points_downscale_factor,
        min_mask_region_area=min_mask_region_area,
    )
    return SamAutomaticMaskGenerator(sam, **generator_options)

In [None]:
import time

# Time one mask-generation pass with the default generator.
# The clock function has to be *called*: subtracting the bare function objects
# (time.time - time.time) raises TypeError. perf_counter() is the clock meant
# for measuring elapsed intervals.
initial_time = time.perf_counter()
masks = mask_generator.generate(image)
print(time.perf_counter() - initial_time)

In [None]:
# Column-oriented results table: one independent, initially empty list per
# recorded field.
_result_fields = (
    'image',
    'anns',
    'points_per_side',
    'pred_iou_thresh',
    'stability_score_thresh',
    'crop_n_layers',
    'crop_n_points_downscale_factor',
    'min_mask_region_area',
    'time',
)
data = {field: [] for field in _result_fields}

In [None]:
# Grid search over the automatic-mask-generator settings.
#
# Tuples (not set literals) keep the iteration order deterministic, and the
# stability thresholds are listed explicitly because range() does not accept
# float arguments (range(0.8, 1, 0.05) raises TypeError).
points_per_side_values = (32, 64)
pred_iou_thresh_values = (0.8, 0.85, 0.9)
stability_score_thresh_values = (0.8, 0.85, 0.9, 0.95)
crop_n_layers_values = (0, 1, 2)
crop_n_points_downscale_factor_values = (1, 2)
min_mask_region_area_values = (0, 50, 100)

# itertools.product walks the combinations in the same order as the equivalent
# nested for-loops (rightmost value varies fastest).
for (points_per_side, pred_iou_thresh, stability_score_thresh, crop_n_layers,
     crop_n_points_downscale_factor, min_mask_region_area) in itertools.product(
        points_per_side_values,
        pred_iou_thresh_values,
        stability_score_thresh_values,
        crop_n_layers_values,
        crop_n_points_downscale_factor_values,
        min_mask_region_area_values):

    # Time generator construction + mask generation only (plotting excluded).
    initial_time = time.perf_counter()

    mask_generator = mask_generator_factory(points_per_side, pred_iou_thresh, stability_score_thresh, crop_n_layers, crop_n_points_downscale_factor, min_mask_region_area)
    masks = mask_generator.generate(image)

    elapsed = time.perf_counter() - initial_time

    print("time elapsed: " + str(elapsed))
    print("points_per_side: ", points_per_side, "pred_iou_thresh: ", pred_iou_thresh, "stability_score_thresh: ", stability_score_thresh, "crop_n_layers: ", crop_n_layers, "crop_n_points_downscale_factor: ", crop_n_points_downscale_factor, "min_mask_region_area: ", min_mask_region_area)

    # generate() already returns the annotation list that show_anns consumes
    # (it reads each entry's 'segmentation' and 'area'). SamPredictor is the
    # prompt-driven API and cannot be built from a mask generator or called
    # with just an image, so it is not used here.
    anns = masks

    # Show the image with the masks overlaid; the list of annotations itself
    # is not something plt.imshow can render.
    fig = plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(anns)
    plt.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations

    data['image'].append(image)
    data['anns'].append(anns)
    data['points_per_side'].append(points_per_side)
    data['pred_iou_thresh'].append(pred_iou_thresh)
    data['stability_score_thresh'].append(stability_score_thresh)
    data['crop_n_layers'].append(crop_n_layers)
    data['crop_n_points_downscale_factor'].append(crop_n_points_downscale_factor)
    data['min_mask_region_area'].append(min_mask_region_area)
    data['time'].append(elapsed)
                        