In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import random
random.seed(0)
import os
from scipy.ndimage import label
import time
import torch
import tqdm
from PIL import Image


from segmenter import Segmenter

In [None]:

def visualized_masks(masks, image):
    """Render segments as flat patches of their mean image colour on a white canvas.

    Segments are painted from largest to smallest ``'area'``, so a smaller
    segment ends up on top of any larger one it overlaps.  After each segment
    is filled, its external contour is traced in light grey (200, 200, 200).

    Args:
        masks: iterable of dicts providing an ``'area'`` value (used only for
            ordering) and a ``'segmentation'`` array covering the spatial
            extent of ``image``.
        image: the image the masks were computed on; its pixels supply the
            fill colours.

    Returns:
        A new array shaped like ``image`` holding the painted segments.
    """
    canvas = np.ones_like(image) * 255

    largest_first = sorted(masks, key=lambda m: m['area'], reverse=True)
    for seg in largest_first:
        region = seg['segmentation'] == 1
        # Fill the region with the per-channel mean of the pixels it covers.
        canvas[region] = image[region].mean(axis=0)

        # Outline only the outermost boundary of the segment.
        outline, _hierarchy = cv2.findContours(
            seg['segmentation'].astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE,
        )
        cv2.drawContours(canvas, outline, -1, (200, 200, 200), 1)

    return canvas


def get_masked_area(masks, shape=None):
    """Count, per pixel, how many masks cover it, and flag uncovered pixels.

    Args:
        masks: iterable of dicts whose ``'segmentation'`` entry is a 2-D array
            (bool, or any dtype where non-zero means "inside the mask").  All
            segmentations must share the same shape.
        shape: optional output shape, used only when ``masks`` is empty so the
            function can still return all-zero / all-one maps.

    Returns:
        ``(masked_area, non_masked_area)``, both ``np.uint8`` arrays:
        ``masked_area`` holds the number of masks covering each pixel
        (saturated at 255), and ``non_masked_area`` is 1 exactly where
        ``masked_area`` is 0.

    Raises:
        ValueError: if ``masks`` is empty and ``shape`` is not given, because
            the output shape cannot be inferred.
    """
    masks = list(masks)  # allow generators; we need emptiness check + indexing

    if not masks:
        if shape is None:
            raise ValueError(
                "get_masked_area() received no masks; pass `shape` to get "
                "all-zero coverage maps")
        masked_area = np.zeros(shape, dtype=np.uint8)
    else:
        # Count in a wide integer type so heavy overlaps cannot wrap around
        # the uint8 range, then saturate to the uint8 dtype that callers get.
        counts = np.zeros(np.shape(masks[0]['segmentation']), dtype=np.int64)
        for mask in masks:
            # Coerce to bool: integer 0/1 masks would otherwise be used as
            # row indices (fancy indexing) instead of as a pixel selection.
            counts += np.asarray(mask['segmentation'], dtype=bool)
        masked_area = np.minimum(counts, np.iinfo(np.uint8).max).astype(np.uint8)

    non_masked_area = (masked_area == 0).astype(np.uint8)
    return masked_area, non_masked_area



# Restrict this process to a single GPU.  torch is already imported in the
# first cell, so this only takes effect if CUDA has not been initialised yet in
# this kernel -- run it before any model has been built.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

img_dir = '/home/dchenbs/workspace/datasets/coco2017/images/val2017'
num_test_images = 8
visualize = True

checkpoints = [
    '/home/dchenbs/workspace/Seq2Seq-AutoEncoder/RepViT/sam/weights/repvit_sam.pt',
    '/home/dchenbs/workspace/cache/mobile_sam.pt',
    "/home/dchenbs/workspace/cache/sam_vit_h_4b8939.pth",
    "/home/dchenbs/workspace/cache/sam_vit_l_0b3195.pth",
    "/home/dchenbs/workspace/cache/sam_vit_b_01ec64.pth",
]

# Previously recorded timings, one progress bar per checkpoint in the order of
# `checkpoints` above, under two generator configurations.  The bars show
# 100 images per run, i.e. a larger `num_test_images` than the current setting.
"""
(Default)
crop_n_layers=0,
crop_n_points_downscale_factor=1,

    RepVIT SAM
    100%|██████████| 100/100 [02:16<00:00,  1.36s/it]
    Mobile SAM
    100%|██████████| 100/100 [02:05<00:00,  1.25s/it]
    SAM vit_h
    100%|██████████| 100/100 [03:03<00:00,  1.84s/it]
    SAM vit_l
    100%|██████████| 100/100 [02:43<00:00,  1.64s/it]
    SAM vit_b
    100%|██████████| 100/100 [01:55<00:00,  1.16s/it]

------------------------
crop_n_layers=1,
crop_n_points_downscale_factor=2,

    RepVIT SAM
    100%|██████████| 100/100 [04:15<00:00,  2.55s/it]
    Mobile SAM
    100%|██████████| 100/100 [04:01<00:00,  2.41s/it]
    SAM vit_h
    100%|██████████| 100/100 [07:46<00:00,  4.67s/it]
    SAM vit_l
    100%|██████████| 100/100 [06:08<00:00,  3.69s/it]
    SAM vit_b
    100%|██████████| 100/100 [04:13<00:00,  2.54s/it]
"""

# Draw the test set ONCE so every checkpoint is evaluated on the same images.
# The listing is sorted because os.listdir order is filesystem-dependent, which
# would make the seeded draw irreproducible.  Sampling is without replacement.
image_names = sorted(os.listdir(img_dir))
test_image_names = random.sample(image_names, min(num_test_images, len(image_names)))

for checkpoint in checkpoints:
    # Release the previous checkpoint's model before loading the next one so
    # two models are never resident at the same time.  The last model stays
    # bound to `segmenter` after the loop.
    segmenter = None
    torch.cuda.empty_cache()

    # Default
    segmenter = Segmenter(
        checkpoint,
        # points_per_side = 32,
        # points_per_batch = 64,
        # pred_iou_thresh = 0.88,
        # stability_score_thresh = 0.95,
        # stability_score_offset = 1.0,
        # box_nms_thresh = 0.7,
        # crop_n_layers = 0,
        # crop_nms_thresh = 0.7,
        # crop_overlap_ratio = 512 / 1500,
        # crop_n_points_downscale_factor = 1,
        # min_mask_region_area = 0,
    )

    # # Finer but longer
    # segmenter = Segmenter(
    #     checkpoint,
    #     points_per_side=32,
    #     pred_iou_thresh=0.86,
    #     stability_score_thresh=0.92,
    #     crop_n_layers=1,
    #     crop_n_points_downscale_factor=2,
    #     )

    for image_name in tqdm.tqdm(test_image_names, desc=os.path.basename(checkpoint)):
        img_path = os.path.join(img_dir, image_name)
        image = np.array(Image.open(img_path).convert('RGB'))
        masks = segmenter(image)

        if visualize:
            fig, axes = plt.subplots(1, 3, figsize=(20, 8))

            axes[0].imshow(image)
            axes[0].set_title(image_name)

            canvas = visualized_masks(masks, image)
            axes[1].imshow(canvas)
            axes[1].set_title('segments (mean colour)')

            # An image can yield no masks; show an all-zero coverage map then
            # instead of calling get_masked_area on an empty list.
            if masks:
                masked_area, non_masked_area = get_masked_area(masks)
            else:
                masked_area = np.zeros(image.shape[:2], dtype=np.uint8)
                non_masked_area = np.ones(image.shape[:2], dtype=np.uint8)
            # Anchor the colour scale at 0 (uncovered) instead of rescaling by
            # 255 / max, which divided by zero when nothing was covered.
            axes[2].imshow(masked_area, vmin=0, vmax=max(int(masked_area.max()), 1))
            axes[2].set_title('masks per pixel')

            for ax in axes:
                ax.axis('off')
            fig.tight_layout()
            plt.show()