In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import random
random.seed(0)
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
from scipy.ndimage import label
import time
import torch
import tqdm
from PIL import Image

In [2]:
# from segmenter import Segmenter

def visualized_masks(masks, image):
    canvas = np.ones_like(image) * 255
    masks = sorted(masks, key=lambda x: x['area'], reverse=True)
    for mask in masks:
        average_color = np.mean(image[mask['segmentation'] == 1], axis=0)
        canvas[mask['segmentation'] == 1] = average_color

        # visualize segment boundary
        contours, _ = cv2.findContours(mask['segmentation'].astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(canvas, contours, -1, (200, 200, 200), 1)

    return canvas


def get_masked_area(masks):
    masked_area = None
    for mask in masks:
        if masked_area is None:
            masked_area = mask['segmentation'].astype(np.uint8)
        else:
            masked_area[mask['segmentation']] += 1

    non_masked_area = (masked_area == 0).astype(np.uint8)
    return masked_area, non_masked_area


In [3]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import random
random.seed(0)
import os
from scipy.ndimage import label
import time
import torch
import tqdm
from PIL import Image


class Segmenter():

    def __init__(
            self, 
            model_name,
            checkpoint,
            points_per_side = 32,
            points_per_batch = 64,
            pred_iou_thresh = 0.88,
            stability_score_thresh = 0.95,
            stability_score_offset = 1.0,
            box_nms_thresh = 0.7,
            crop_n_layers = 0,
            crop_nms_thresh = 0.7,
            crop_overlap_ratio = 512 / 1500,
            crop_n_points_downscale_factor = 1,
            min_mask_region_area = 0,
            device = 'cuda',
            ):
        self.generator = None
        self.model_name = model_name
        
        if self.model_name=='fast_sam':
            # Fast Segment Anything 
            # https://arxiv.org/abs/2306.12156 (21 Jun 2023)
            # https://github.com/CASIA-IVA-Lab/FastSAM

            from fastsam import FastSAM
            self.generator = FastSAM(checkpoint)

        elif self.model_name=='mobile_sam':
            # Faster Segment Anything: Towards Lightweight SAM for Mobile Applications 
            # https://arxiv.org/abs/2306.14289.pdf (25 Jun 2023)
            # https://github.com/ChaoningZhang/MobileSAM
            
            from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator
            sam = sam_model_registry["vit_t"](checkpoint=checkpoint).to(device).eval()

        elif self.model_name=='repvit_sam':
            from repvit_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
            sam = sam_model_registry["repvit"](checkpoint=checkpoint).to(device).eval()

        elif self.model_name=='sam':
            from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
            model_type = checkpoint.split('/')[-1][4:9]
            sam = sam_model_registry[model_type](checkpoint=checkpoint).to(device).eval()

        else:
            raise NotImplementedError(f'Model {self.model_name} not implemented')
        
        if self.generator is None:
            self.generator = SamAutomaticMaskGenerator(
                sam,
                points_per_side=points_per_side,
                points_per_batch=points_per_batch,
                pred_iou_thresh=pred_iou_thresh,
                stability_score_thresh=stability_score_thresh,
                stability_score_offset=stability_score_offset,
                box_nms_thresh=box_nms_thresh,
                crop_n_layers=crop_n_layers,
                crop_nms_thresh=crop_nms_thresh,
                crop_overlap_ratio=crop_overlap_ratio,
                crop_n_points_downscale_factor=crop_n_points_downscale_factor,
                min_mask_region_area=min_mask_region_area,
                )
        
        
    def __call__(self, image_path, post_processing=True):  
        
        image = np.array(Image.open(image_path).convert('RGB')) 
        
        if self.model_name=='fast_sam':
            everything_results = self.generator(image_path, device='cuda', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9,)
            masks = []
            for i in range(everything_results[0].boxes.data.shape[0]):
                box = everything_results[0].boxes.data[i]
                mask = everything_results[0].masks.data[i]
                masks.append({'segmentation': mask.cpu().numpy().astype(bool), 'area': mask.sum(), 'bbox': box.cpu().tolist(),})
        else:
            masks = self.generator.generate(image)

        if post_processing:
            masks = self.post_processing_masks(masks, image)
        return masks

    def expand_mask_blur(self, mask, kernel):
        mask = mask.copy()
        mask['segmentation'] = mask['segmentation'].astype(np.uint8)
        blurred_mask = cv2.filter2D(mask['segmentation'],-1,kernel)
        expanded_mask = (blurred_mask > 0).astype(bool)
        return expanded_mask


    def post_processing_masks(self, masks, image):

        kernel_size = int(min(image.shape[:2]) * 0.015) // 2 * 2 + 1
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(kernel_size,kernel_size))

        masked_area = None
        post_processed_masks = []
        for mask in masks:
            expanded_mask = self.expand_mask_blur(mask, kernel)
            post_processed_masks.append({
                'segmentation': expanded_mask.astype(bool),
                'area': expanded_mask.sum(),
                'bbox': list(cv2.boundingRect(expanded_mask.astype(np.uint8))),
            })
            if masked_area is None:
                masked_area = expanded_mask.astype(np.uint8)
            else:
                masked_area[expanded_mask] += 1

        non_masked_area = masked_area == 0
        labeled_mask, num_labels = label(non_masked_area)
        
        for i in range(1, num_labels + 1):
            post_processed_masks.append({
                'segmentation': labeled_mask == i,
                'area': (labeled_mask == i).sum(),
                'bbox': list(cv2.boundingRect((labeled_mask == i).astype(np.uint8))),
            })
        return post_processed_masks

In [5]:

num_test_images = 10
visualize = False

img_dir = '/home/dchenbs/workspace/datasets/coco2017/images/val2017'
image_paths = []
for i in range(num_test_images):
    img_path = os.path.join(img_dir, random.choice(os.listdir(img_dir)))
    image_paths.append(img_path)


models = [
    
    # ('sam', '/home/dchenbs/workspace/cache/sam_vit_h_4b8939.pth'), 
        # 6.798 GB, 2.47 s/image

    # ('sam', '/home/dchenbs/workspace/cache/sam_vit_l_0b3195.pth'), 
        # 5.346 GB, 1.76 s/image

    # ('sam', '/home/dchenbs/workspace/cache/sam_vit_b_01ec64.pth'), 
        # 4.404 GB, 1.14 s/image

    # ('fast_sam', '/home/dchenbs/workspace/cache/FastSAM-s.pt'),  
        # 1.326 GB, 0.34 s/image

    # ('fast_sam', '/home/dchenbs/workspace/cache/FastSAM-x.pt'), 
        # 1.946 GB, 0.24 s/image
    
    ('mobile_sam', '/home/dchenbs/workspace/cache/mobile_sam.pt'),
        # 4.376 GB, 1.15 s/image
    
    # ('repvit_sam', '/home/dchenbs/workspace/Seq2Seq-AutoEncoder/RepViT/sam/weights/repvit_sam.pt'), 
        # 4.722 GB, 1.49 s/image
    
    
]


for model_name, checkpoint in models:
    start = time.time()
    print(f'Running [{model_name.upper()}]: {checkpoint.split("/")[-1]}')

    segmenter = None
    torch.cuda.empty_cache()

    segmenter = Segmenter(model_name, checkpoint)
    for img_path in tqdm.tqdm(image_paths):
        masks = segmenter(img_path, post_processing=True)
        image = np.array(Image.open(img_path).convert('RGB'))
        
        if visualize:
            plt.figure(figsize=(20, 8))
            plt.subplot(1, 3, 1)
            plt.imshow(image)
            plt.axis('off')

            canvas = visualized_masks(masks, image)
            plt.subplot(1, 3, 2)
            plt.imshow(canvas)
            plt.axis('off')

            masked_area, non_masked_area = get_masked_area(masks)
            plt.subplot(1, 3, 3)
            plt.imshow(masked_area*int(255/max(masked_area.flatten())))
            plt.axis('off')

            plt.tight_layout()
            plt.show()

    print(f'[{model_name.upper()}]: {checkpoint.split("/")[-1]}\n{(time.time()-start)/num_test_images :2f}s/image')

Running [MOBILE_SAM]: mobile_sam.pt


100%|██████████| 10/10 [00:11<00:00,  1.12s/it]

[MOBILE_SAM]: mobile_sam.pt
1.149364s/image



