In [1]:
import cv2  # NOTE(review): not referenced in the cells visible here — confirm it is needed
import matplotlib.pyplot as plt
import numpy as np
import random
# Seeds only Python's `random` module (numpy / torch are NOT seeded here).
# The test-image selection in the next cell uses random.choice, so this is the
# seed that governs which images are benchmarked.
random.seed(0)
import os
# Restrict this process to GPU 4. Keep this line ABOVE `import torch` and the
# `segmentation` import: CUDA only honours CUDA_VISIBLE_DEVICES if it is set
# before CUDA is initialised, and setting it before those imports is the safe way
# to guarantee that.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
from scipy.ndimage import label  # NOTE(review): not referenced in the cells visible here
import time
import torch
import tqdm
from PIL import Image
import pprint
# Project-local helpers: model wrapper, mask overlay renderer, and mask-coverage helper.
from segmentation import Segmenter, visualized_masks, get_masked_area

In [2]:
# Benchmark configuration.
num_test_images = 100
visualize = False  # True -> plot image, mask overlay and mask coverage for every test image

img_dir = '/home/dchenbs/workspace/datasets/coco2017/images/val2017'
# List the directory once (it used to be re-listed on every iteration) and sort it:
# os.listdir returns entries in arbitrary, filesystem-dependent order, so without
# sorting the random seed set in the first cell would not pin down which images
# are selected on another machine / filesystem.
img_filenames = sorted(os.listdir(img_dir))
# Draw num_test_images file names uniformly *with replacement* (duplicates are
# possible), exactly one random.choice call per image as before.
image_paths = [
    os.path.join(img_dir, random.choice(img_filenames))
    for _ in range(num_test_images)
]

In [4]:
import json, requests
import pickle
import base64

port = '5000'

# NOTE(review): 0.0.0.0 is a "bind to all interfaces" address. Connecting to it
# happens to reach localhost on Linux but is not portable; 127.0.0.1 is the
# conventional client target.
url = f'http://0.0.0.0:{port}/segment_provider'
# Upper bound (seconds) per request. Without a timeout, requests can block
# forever if the server hangs.
request_timeout = 120

# Send every test image path to the segmentation server and decode the returned
# masks. Only the last response is kept in `masks`; the tqdm bar reports the
# server's end-to-end throughput.
for image_path in tqdm.tqdm(image_paths):
    content_lst = {
        'post_processing': True,
        'img_path': image_path,
    }
    # Build the request body in one step so `d` is not rebound from dict to bytes.
    d = json.dumps({"content_lst": content_lst, 'typ': 'None'}).encode('utf8')
    r = requests.post(url, data=d, timeout=request_timeout)
    # Surface HTTP errors here instead of failing later on a missing 'result' key.
    r.raise_for_status()
    js = r.json()

    masks_bytes = base64.b64decode(js['result']['response'])
    # SECURITY: pickle.loads can execute arbitrary code embedded in the payload.
    # Only point this client at a server you trust.
    masks = pickle.loads(masks_bytes)


100%|██████████| 100/100 [00:51<00:00,  1.96it/s]


In [None]:
# Models to benchmark, as (model_name, checkpoint_path) pairs; model_name is
# passed straight to Segmenter. Uncomment an entry to include it in the run.
# The "X GB, Y s/image" notes under each entry are figures recorded from earlier
# runs (presumably GPU memory usage and per-image latency — TODO confirm how
# they were measured; this cell does not measure memory).
models = [

    # ('sam', '/home/dchenbs/workspace/cache/sam_weights/sam/sam_vit_h_4b8939.pth'),
    #     # 6.972 GB, 1.83 s/image

    # ('sam', '/home/dchenbs/workspace/cache/sam_weights/sam/sam_vit_l_0b3195.pth'),
    #     # 5.516 GB, 1.59 s/image

    # ('sam', '/home/dchenbs/workspace/cache/sam_weights/sam/sam_vit_b_01ec64.pth'),
    #     # 4.572 GB, 1.12 s/image

    # ('mobile_sam', '/home/dchenbs/workspace/cache/sam_weights/mobile_sam.pt'),
    #     # 4.376 GB, 1.24 s/image

    ('mobile_sam_v2', '/home/dchenbs/workspace/cache/sam_weights/mobile_sam_v2/l2.pt'),
    #     # 11.982 GB, 0.20 s/image

    # ('repvit_sam', '/home/dchenbs/workspace/Seq2Seq-AutoEncoder/RepViT/sam/weights/repvit_sam.pt'),
    #     # 4.722 GB, 1.35 s/image

    # --- The FastSAM entries below were noted as "somehow broken" (failure mode
    # not recorded here); kept for reference only.

    # ('fast_sam', '/home/dchenbs/workspace/cache/sam_weights/fast_sam/FastSAM-s.pt'),
    #     # 1.326 GB, 0.34 s/image

    # ('fast_sam', '/home/dchenbs/workspace/cache/sam_weights/fast_sam/FastSAM.pt'),
    #     # 1.946 GB, 0.24 s/image
]

# Benchmark each model: load it, segment every test image, and record the mean
# wall-clock time of the segmenter call per image.
cuda_available = torch.cuda.is_available()

results = []
for model_name, checkpoint in models:
    print(f'Running [{model_name.upper()}]: {checkpoint.split("/")[-1]}')

    # Drop the reference to the previous model first so empty_cache() can hand
    # its cached GPU memory back before the next model is loaded.
    segmenter = None
    torch.cuda.empty_cache()
    segmenter = Segmenter(model_name, checkpoint)

    # Only the segmenter call is timed. Previously the measurement also included
    # an image decode that is only needed for plotting, the plotting itself, and
    # tqdm bookkeeping.
    segment_seconds = 0.0
    for img_path in tqdm.tqdm(image_paths):
        call_start = time.perf_counter()
        masks = segmenter(img_path, post_processing=True)
        if cuda_available:
            # CUDA kernels run asynchronously; wait for queued work so it is counted.
            torch.cuda.synchronize()
        segment_seconds += time.perf_counter() - call_start

        if visualize:
            # Decoded only for display: the segmenter receives the path, not this array.
            image = np.array(Image.open(img_path).convert('RGB'))
            fig, axes = plt.subplots(1, 3, figsize=(20, 8))

            axes[0].imshow(image)

            canvas = visualized_masks(masks, image)
            axes[1].imshow(canvas)

            # Stretch the coverage map towards 0-255 for display. When no pixel is
            # covered (peak == 0) skip the scaling instead of dividing by zero.
            masked_area, non_masked_area = get_masked_area(masks)
            peak = masked_area.max()
            scale = int(255 / peak) if peak > 0 else 0
            axes[2].imshow(masked_area * scale)

            for ax in axes:
                ax.axis('off')
            fig.tight_layout()
            plt.show()
            # Release the figure so many plotted images do not accumulate in memory.
            plt.close(fig)

    result = {
        'model_name': model_name,
        'checkpoint': checkpoint,
        # Averaged over the images actually processed; NaN if there were none.
        'seconds_per_image': segment_seconds / len(image_paths) if image_paths else float('nan'),
    }
    results.append(result)

pprint.pprint(results)