In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import random
random.seed(0)
import os
from scipy.ndimage import label
import time
import torch

# Restrict this process to physical GPU 5 (it then appears as "cuda" / cuda:0).
# NOTE(review): the variable is only honoured if it is set before CUDA is first
# initialised in this kernel (the first CUDA use here is `.to("cuda")` in the
# model-loading cell). Changing it later in a live kernel has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '5'

def normalize_image(image):
    """Linearly rescale ``image`` so its minimum maps to 0 and its maximum to 1.

    Args:
        image: numeric array of any shape.

    Returns:
        An array of the same shape.
        - Floating-point inputs keep their dtype.
        - Integer/bool inputs are promoted to float64 before arithmetic, which also
          avoids wrap-around in ``max - min`` for signed integer dtypes.
        - A constant image (max == min) returns all zeros instead of NaNs from a
          division by zero.
    """
    image = np.asarray(image)
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)
    lo = image.min()
    span = image.max() - lo
    if span == 0:
        return np.zeros_like(image)
    return (image - lo) / span


def visualized_masks(masks, image):
    """Render SAM-style masks as flat colour regions with light-grey outlines.

    Masks are painted from largest to smallest ``'area'``, so smaller segments
    stay visible on top of larger ones. For each mask:
    - its pixels are filled with the mean colour of ``image`` under it;
    - its external contour is then drawn in (200, 200, 200), 1 px wide.

    Args:
        masks: iterable of dicts with at least:
            - ``'segmentation'``: a 2-D array, 1/True inside the mask, with the
              same height and width as ``image``;
            - ``'area'``: used only for ordering.
        image: the image the masks were computed on. The notebook passes an RGB
            uint8 array with three channels; other layouts are untested.

    Returns:
        A new array shaped and typed like ``image``. Pixels covered by no mask
        stay 0.
    """
    canvas = np.zeros_like(image)
    for mask in sorted(masks, key=lambda m: m['area'], reverse=True):
        region = mask['segmentation'] == 1
        if not region.any():
            # Averaging an empty selection warns ("Mean of empty slice") and
            # yields NaN. There is nothing to fill or outline, so skip it.
            continue
        canvas[region] = np.mean(image[region], axis=0)

        # visualize segment boundary (of exactly the region that was filled)
        contours, _ = cv2.findContours(region.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(canvas, contours, -1, (200, 200, 200), 1)

    return canvas


def expand_mask_blur(mask, kernel):
    """Grow a mask's ``'segmentation'`` by the footprint of ``kernel``.

    The segmentation is converted to uint8 and correlated with ``kernel`` via
    ``cv2.filter2D``; every pixel with a non-zero response is switched on.
    With a kernel of ones (the notebook passes an elliptical structuring
    element) this behaves like a dilation. The caller's dict is not modified.

    Args:
        mask: dict holding a 2-D ``'segmentation'`` array.
        kernel: correlation kernel handed to ``cv2.filter2D``.

    Returns:
        A 2-D boolean array with the segmentation's height and width.
    """
    seg_u8 = mask['segmentation'].astype(np.uint8)
    response = cv2.filter2D(seg_u8, -1, kernel)
    return response > 0


def _mask_record(segmentation):
    """Build a SAM-like mask dict (segmentation, area, bbox) from a 2-D bool array."""
    return {
        'segmentation': segmentation,
        'area': segmentation.sum(),
        'bbox': cv2.boundingRect(segmentation.astype(np.uint8)),
    }


def post_processing_masks(masks, image):
    """Slightly dilate every mask, then add one mask per uncovered gap region.

    Steps:
    1. Grow each input mask with ``expand_mask_blur``. The elliptical kernel's
       diameter is ~1.5% of the image's shorter side, forced to be odd.
    2. Find the pixels that no expanded mask covers and split them into
       connected components with ``scipy.ndimage.label`` (default
       connectivity).
    3. Append each component as an extra mask, so the returned masks jointly
       cover every pixel.

    Args:
        masks: iterable of dicts with a 2-D ``'segmentation'`` array matching
            ``image``'s height and width.
        image: array whose first two dimensions give the mask size.

    Returns:
        A list of dicts with keys:
        - ``'segmentation'``: bool array;
        - ``'area'``: pixel count;
        - ``'bbox'``: ``cv2.boundingRect`` tuple ``(x, y, w, h)``.
        The expanded input masks come first, followed by the gap masks. With
        no input masks, the whole image becomes gap mask(s).
    """
    height, width = image.shape[:2]
    kernel_size = int(min(height, width) * 0.015) // 2 * 2 + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

    # A boolean coverage map is all that is needed here. The previous uint8
    # overlap counter wrapped to 0 after 256 overlapping masks, making covered
    # pixels look like gaps. Sizing it from `image` also handles empty `masks`.
    covered = np.zeros((height, width), dtype=bool)
    post_processed_masks = []
    for mask in masks:
        expanded_mask = expand_mask_blur(mask, kernel)
        covered |= expanded_mask
        post_processed_masks.append(_mask_record(expanded_mask))

    labeled_mask, num_labels = label(~covered)
    for i in range(1, num_labels + 1):
        # Compute each component's mask once instead of three times.
        post_processed_masks.append(_mask_record(labeled_mask == i))
    return post_processed_masks


def get_masked_area(masks):
    """Count how many masks cover each pixel.

    Args:
        masks: non-empty iterable of dicts whose ``'segmentation'`` entries are
            2-D arrays of identical shape. Entries are interpreted as
            booleans (non-zero = covered).

    Returns:
        A tuple ``(masked_area, non_masked_area)``:
        - ``masked_area``: an ``int32`` per-pixel overlap count;
        - ``non_masked_area``: a ``uint8`` map that is 1 where no mask covers
          the pixel and 0 elsewhere.

    Raises:
        ValueError: if ``masks`` is empty, because the output shape is unknown.
    """
    masked_area = None
    for mask in masks:
        # Coerce to bool so 0/1 integer masks are not used as fancy indices.
        segmentation = np.asarray(mask['segmentation'], dtype=bool)
        if masked_area is None:
            # int32 rather than uint8: a uint8 counter silently wraps after 255 overlaps.
            masked_area = np.zeros(segmentation.shape, dtype=np.int32)
        masked_area += segmentation
    if masked_area is None:
        raise ValueError('get_masked_area needs at least one mask to infer the output shape')

    non_masked_area = (masked_area == 0).astype(np.uint8)
    return masked_area, non_masked_area


In [None]:
# Locally cached SAM checkpoints. Switch backbones by changing SAM_MODEL_TYPE
# instead of (un)commenting code.
SAM_CACHE_DIR = "/home/dchenbs/workspace/cache"
SAM_CHECKPOINTS = {
    "vit_h": os.path.join(SAM_CACHE_DIR, "sam_vit_h_4b8939.pth"),
    "vit_l": os.path.join(SAM_CACHE_DIR, "sam_vit_l_0b3195.pth"),
    "vit_b": os.path.join(SAM_CACHE_DIR, "sam_vit_b_01ec64.pth"),
}
SAM_MODEL_TYPE = "vit_b"

checkpoint = SAM_CHECKPOINTS[SAM_MODEL_TYPE]
sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=checkpoint).to("cuda").eval()

In [None]:
# Settings for SAM's automatic mask generator. Any option not listed keeps
# SamAutomaticMaskGenerator's default.
AMG_KWARGS = dict(
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    # min_mask_region_area=100,  # Requires open-cv to run post-processing
)
generator = SamAutomaticMaskGenerator(sam, **AMG_KWARGS)

In [None]:
N_IMAGES = 5
IMG_DIR = '/home/dchenbs/workspace/datasets/coco2017/images/val2017'
# os.listdir order is filesystem-dependent. Sort once so the seeded
# random.choice picks the same files everywhere, and avoid re-listing per image.
image_files = sorted(os.listdir(IMG_DIR))

times = []
for _ in range(N_IMAGES):
    img_path = os.path.join(IMG_DIR, random.choice(image_files))
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'cv2.imread could not read {img_path}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    start = time.perf_counter()
    raw_masks = generator.generate(image)
    amg_time = time.perf_counter() - start
    # Read the clock once so the recorded time matches the printed one.
    print(f'AMG time: {amg_time:.2f}s')
    times.append(amg_time)

    start = time.perf_counter()
    masks = post_processing_masks(raw_masks, image)
    print(f'Post processing time: {time.perf_counter() - start:.2f}s')

    masked_area, _ = get_masked_area(masks)

    fig, axes = plt.subplots(1, 4, figsize=(30, 8))
    panels = [
        (image, 'Input Image'),
        (visualized_masks(raw_masks, image), 'Raw Masks from SAM'),
        (visualized_masks(masks, image), 'Post-processed Masks'),
    ]
    for ax, (data, title) in zip(axes[:3], panels):
        ax.imshow(data)
        ax.set_title(title)

    # imshow already maps the 2-D count map onto its colormap, so no manual
    # 255/max rescaling is needed. The colorbar shows the raw overlap counts.
    overlap_ax = axes[3]
    overlap_im = overlap_ax.imshow(masked_area)
    overlap_ax.set_title('Mask Overlaps')
    fig.colorbar(overlap_im, ax=overlap_ax, fraction=0.046)

    fig.tight_layout()
    plt.show()
    plt.close(fig)
print(f'Average time: {np.mean(times):.2f}s')