In [40]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import random
random.seed(0)
import os
from scipy.ndimage import label
import time
import torch
import tqdm
from PIL import Image


# Pin this process to physical GPU 2. This has to run before CUDA is
# initialised (torch initialises CUDA lazily, so importing it above is fine).
os.environ.update(CUDA_VISIBLE_DEVICES='2')

class Segmenter():
    """Automatic SAM segmentation followed by mask dilation and gap filling.

    Wraps ``SamAutomaticMaskGenerator``. Every mask SAM proposes is grown
    slightly (see ``expand_mask_blur``). Every connected region of the image
    that no grown mask covers is then appended as an extra segment, so the
    returned segments together cover the whole image.
    """

    def __init__(self, sam_ckpt, model_type=None, device="cuda"):
        """Load a SAM checkpoint and build the automatic mask generator.

        Args:
            sam_ckpt: path to a SAM checkpoint, e.g. ``.../sam_vit_h_4b8939.pth``.
            model_type: key into ``sam_model_registry`` (e.g. ``'vit_h'``).
                When None, it is parsed from the checkpoint file name.
            device: torch device the model is moved to.
        """
        if model_type is None:
            model_type = self._parse_model_type(sam_ckpt)
        print(f'Loading SAM model {model_type} from {sam_ckpt}')

        sam = sam_model_registry[model_type](checkpoint=sam_ckpt).to(device).eval()

        self.generator = SamAutomaticMaskGenerator(
            sam,
            points_per_side=32,
            pred_iou_thresh=0.86,
            stability_score_thresh=0.92,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            # min_mask_region_area=100,  # Requires open-cv to run post-processing
        )

    @staticmethod
    def _parse_model_type(sam_ckpt):
        """Extract the SAM variant token (e.g. ``'vit_h'``) from a checkpoint path.

        Searches the file name for ``vit_<letters>``. The previous fixed
        slice ``[4:9]`` only worked for names starting with exactly ``sam_``.

        Raises:
            ValueError: if the file name contains no ``vit_<letters>`` token.
        """
        import re

        file_name = os.path.basename(sam_ckpt)
        match = re.search(r'vit_[a-z]+', file_name)
        if match is None:
            raise ValueError(
                f'Cannot infer SAM model type from {sam_ckpt!r}; pass model_type explicitly')
        return match.group(0)

    def __call__(self, image):
        """Segment ``image`` and return the post-processed segment dicts.

        Args:
            image: image array passed straight to
                ``SamAutomaticMaskGenerator.generate`` (the notebook passes an
                RGB uint8 array from PIL).

        Returns:
            list of dicts as described in ``post_processing_masks``.
        """
        masks = self.generator.generate(image)
        return self.post_processing_masks(masks, image)

    def expand_mask_blur(self, mask, kernel):
        """Grow a mask's segmentation by the non-zero footprint of ``kernel``.

        A pixel is in the result if any segmentation pixel lies under the
        kernel footprint centred on it, i.e. a morphological dilation. The
        former "cv2.filter2D then threshold > 0" formulation computes the same
        thing for a 0/1 symmetric structuring element such as the one built
        in ``post_processing_masks``; ``cv2.dilate`` states the intent directly.

        Args:
            mask: dict with a 2-D 'segmentation' array (bool or 0/1). Not modified.
            kernel: structuring element, e.g. from ``cv2.getStructuringElement``.

        Returns:
            2-D bool array with the expanded segmentation.
        """
        segmentation = np.asarray(mask['segmentation']).astype(np.uint8)
        return cv2.dilate(segmentation, kernel).astype(bool)

    def post_processing_masks(self, masks, image, expand_ratio=0.015):
        """Dilate SAM masks and add one segment per uncovered region.

        Args:
            masks: SAM mask dicts, each with a 2-D 'segmentation' array whose
                shape equals ``image.shape[:2]``. May be empty.
            image: the segmented image; only its height and width are used.
            expand_ratio: dilation kernel diameter as a fraction of the
                image's shorter side (rounded down to an odd size, minimum 1).

        Returns:
            list of dicts with keys 'segmentation' (bool array), 'area'
            (pixel count) and 'bbox' (``cv2.boundingRect`` tuple). First comes
            one entry per input mask (dilated, same order), then one entry per
            connected region (scipy's default connectivity) that no dilated
            mask covers.
        """
        height, width = image.shape[:2]
        # Force an odd kernel size so the kernel has a centre pixel.
        kernel_size = int(min(height, width) * expand_ratio) // 2 * 2 + 1
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

        # Boolean coverage map. The previous uint8 per-pixel counter wrapped to
        # 0 after 256 overlapping masks (turning covered pixels into "gaps") and
        # stayed None for an empty mask list, which crashed the labelling step.
        covered = np.zeros((height, width), dtype=bool)
        post_processed_masks = []
        for mask in masks:
            expanded_mask = self.expand_mask_blur(mask, kernel)
            post_processed_masks.append(self._make_segment(expanded_mask))
            covered |= expanded_mask

        labeled_mask, num_labels = label(~covered)
        for region_id in range(1, num_labels + 1):
            post_processed_masks.append(self._make_segment(labeled_mask == region_id))
        return post_processed_masks

    @staticmethod
    def _make_segment(segmentation):
        """Build a {'segmentation', 'area', 'bbox'} dict from a 2-D bool array."""
        return {
            'segmentation': segmentation,
            'area': segmentation.sum(),
            'bbox': cv2.boundingRect(segmentation.astype(np.uint8)),
        }

In [41]:

def visualized_masks(masks, image, boundary_color=(200, 200, 200)):
    """Render segments as flat-colour regions on a white canvas.

    Each segment is filled with the mean colour of ``image`` under it.
    Segments are painted largest-first (by their 'area' key), so smaller
    segments end up on top of larger ones they overlap.

    Args:
        masks: dicts with a 2-D 'segmentation' array (bool or 0/1) and an
            'area' key, e.g. as returned by ``Segmenter``.
        image: array of shape (H, W, C) matching the segmentations.
        boundary_color: colour of the 1-px outline drawn around each segment;
            None skips the outlines.

    Returns:
        a new array with the same shape and dtype as ``image``.
    """
    canvas = np.full_like(image, 255)
    for mask in sorted(masks, key=lambda m: m['area'], reverse=True):
        segmentation = np.asarray(mask['segmentation']) == 1
        if not segmentation.any():
            # Nothing to paint; np.mean over zero pixels would also emit a
            # "Mean of empty slice" RuntimeWarning.
            continue
        canvas[segmentation] = image[segmentation].mean(axis=0)

        if boundary_color is not None:
            contours, _ = cv2.findContours(
                segmentation.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(canvas, contours, -1, boundary_color, 1)

    return canvas


def get_masked_area(masks, shape=None):
    """Count how many segments cover each pixel.

    Args:
        masks: iterable of dicts with a 2-D 'segmentation' array (bool or 0/1),
            all of the same shape. Not modified.
        shape: (H, W) of the maps to return when ``masks`` is empty; ignored
            otherwise.

    Returns:
        (masked_area, non_masked_area): an int32 per-pixel coverage count, and
        a uint8 map that is 1 where no segment covers the pixel.

    Raises:
        ValueError: if ``masks`` is empty and ``shape`` is not given.
    """
    masks = list(masks)
    if masks:
        # int32, not uint8: a uint8 counter wraps to 0 after 256 overlapping
        # segments, and those pixels would be reported as uncovered.
        masked_area = np.zeros(np.shape(masks[0]['segmentation']), dtype=np.int32)
        for mask in masks:
            # Cast to bool so 0/1 integer masks add 0/1 instead of being
            # misused as integer fancy indices.
            masked_area += np.asarray(mask['segmentation'], dtype=bool)
    elif shape is not None:
        masked_area = np.zeros(shape, dtype=np.int32)
    else:
        raise ValueError('masks is empty; pass shape to get an all-uncovered map')

    non_masked_area = (masked_area == 0).astype(np.uint8)
    return masked_area, non_masked_area





# Benchmark SAM automatic segmentation speed for three backbone sizes on a
# random sample of COCO val2017 images.
img_dir = '/home/dchenbs/workspace/datasets/coco2017/images/val2017'
num_test_images = 100
sample_seed = 0
# Set to True to plot image / segment rendering / coverage map per image.
visualize = False

checkpoints = [
    "/home/dchenbs/workspace/cache/sam_vit_h_4b8939.pth",
    "/home/dchenbs/workspace/cache/sam_vit_l_0b3195.pth",
    "/home/dchenbs/workspace/cache/sam_vit_b_01ec64.pth",
]
# Timings recorded on an earlier run, when each checkpoint drew its own random
# images (so the three models were not timed on the same inputs):
#   vit_h: 100/100 [07:46<00:00, 4.67s/it]
#   vit_l: 100/100 [06:08<00:00, 3.69s/it]
#   vit_b: 100/100 [04:13<00:00, 2.54s/it]

# Draw the sample once, so every checkpoint is timed on the same images.
# Sorting removes os.listdir's filesystem-dependent order, and a local RNG
# keeps the draw reproducible regardless of earlier cells' random use.
image_files = sorted(
    name for name in os.listdir(img_dir)
    if name.lower().endswith(('.jpg', '.jpeg', '.png'))
)
test_images = random.Random(sample_seed).sample(
    image_files, min(num_test_images, len(image_files)))

for checkpoint in checkpoints:
    # Release the previous model before loading the next one, so two SAM
    # models are never resident in GPU memory at the same time.
    segmenter = None
    torch.cuda.empty_cache()
    segmenter = Segmenter(checkpoint)

    for file_name in tqdm.tqdm(test_images):
        img_path = os.path.join(img_dir, file_name)
        # Context manager closes the underlying file handle right away.
        with Image.open(img_path) as pil_image:
            image = np.array(pil_image.convert('RGB'))
        masks = segmenter(image)

        if visualize:
            masked_area, _ = get_masked_area(masks)
            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            axes[0].imshow(image)
            axes[0].set_title('image')
            axes[1].imshow(visualized_masks(masks, image))
            axes[1].set_title('segments')
            axes[2].imshow(masked_area)
            axes[2].set_title('coverage count')
            for ax in axes:
                ax.axis('off')
            fig.tight_layout()
            plt.show()
            plt.close(fig)

Using /home/dchenbs/workspace/cache/sam_vit_h_4b8939.pth
Loading SAM model vit_h from /home/dchenbs/workspace/cache/sam_vit_h_4b8939.pth


100%|██████████| 100/100 [07:46<00:00,  4.67s/it]


Using /home/dchenbs/workspace/cache/sam_vit_l_0b3195.pth
Loading SAM model vit_l from /home/dchenbs/workspace/cache/sam_vit_l_0b3195.pth


100%|██████████| 100/100 [06:08<00:00,  3.69s/it]


Using /home/dchenbs/workspace/cache/sam_vit_b_01ec64.pth
Loading SAM model vit_b from /home/dchenbs/workspace/cache/sam_vit_b_01ec64.pth


100%|██████████| 100/100 [04:13<00:00,  2.54s/it]
