In [1]:
import matplotlib.pyplot as plt
from PIL import Image
import json
import random
import cv2
import pickle
import numpy as np
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
import torch
import tqdm
# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Seed every global RNG this notebook draws from. show_anns picks its overlay
# colours with np.random, which random.seed() alone does not affect.
random.seed(41)
np.random.seed(41)

In [2]:
# Size thresholds in pixels.
#   max_dim: an image with a side longer than this is downscaled by resizeImage.
#   min_dim: bound below which an image counts as low resolution.
max_dim, min_dim = 2000, 600

def getFrontCutout(masks,image):
    """Crop `image` to the bounding box of the mask chosen as the front object.

    Selection heuristic, applied to `masks` in the order given:
      * no masks            -> return None;
      * exactly one mask    -> use it;
      * first mask's bbox starts within 5 px of the top-left corner -> it is
        treated as background, and the choice falls to the following mask(s):
          - two masks: use the second;
          - more than two: use the second if its area is more than 1.5x the
            third's, or if its bbox starts higher up (smaller y); otherwise
            use the third;
      * otherwise           -> use the first mask.

    Args:
        masks: sequence of dicts with a 'bbox' entry (x, y, w, h) and, when
            more than two are given, an 'area' entry.
        image: array indexed as image[row, column, ...].

    Returns:
        The slice image[y:y+h, x:x+w] for the chosen mask, or None when
        `masks` is empty.
    """
    def starts_at_top_left(mask):
        # A bbox anchored at (<=5, <=5) is taken to be the background region.
        return mask['bbox'][0] <= 5 and mask['bbox'][1] <= 5

    count = len(masks)
    if count == 0:
        return None

    first = masks[0]
    if count == 1 or not starts_at_top_left(first):
        chosen = first
    elif count == 2:
        chosen = masks[1]
    else:
        second, third = masks[1], masks[2]
        # Prefer the second mask when it is clearly larger, or failing that
        # when it sits higher in the image than the third.
        prefer_second = (
            second['area'] > third['area']*1.5
            or second['bbox'][1] < third['bbox'][1]
        )
        chosen = second if prefer_second else third

    # bbox values may be floats; slicing needs integers.
    left, top, box_w, box_h = (int(value) for value in chosen['bbox'])
    return image[top:top+box_h, left:left+box_w]

def resizeImage(image):
    """Halve the image repeatedly until neither side exceeds max_dim.

    Each pass shrinks both sides to half (rounded down) using cv2.INTER_AREA
    interpolation. An image already within max_dim is returned unchanged.
    """
    while True:
        rows, cols = image.shape[0], image.shape[1]
        if rows/max_dim <= 1 and cols/max_dim <= 1:
            return image
        # cv2.resize takes the target size as (width, height).
        half_size = (cols // 2, rows // 2)
        image = cv2.resize(image, half_size, interpolation = cv2.INTER_AREA)


def show_anns(anns):
    """Draw every annotation mask on the current matplotlib axes.

    Masks are painted largest-area first, each in a random colour with alpha
    0.35, onto an otherwise fully transparent RGBA overlay, so later (smaller)
    masks overwrite earlier ones where they overlap. Does nothing when `anns`
    is empty.

    Args:
        anns: sequence of dicts with an 'area' entry and a 'segmentation'
            entry usable as a boolean index into the overlay.
    """
    if len(anns) == 0:
        return
    by_area_desc = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    # Size the overlay from the first (largest) mask; alpha starts at 0.
    mask_shape = np.array(by_area_desc[0]['segmentation']).shape
    overlay = np.ones((mask_shape[0], mask_shape[1], 4))
    overlay[:, :, 3] = 0
    for ann in by_area_desc:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[ann['segmentation']] = rgba
    axes.imshow(overlay)



In [3]:
# Build the SAM model from the "vit_h" checkpoint file, move it to the
# selected device, and wrap it in an automatic mask generator.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"

build_sam = sam_model_registry[model_type]
sam = build_sam(checkpoint=sam_checkpoint)
sam.to(device=device)

# Generator settings kept together so they are easy to tune in one place.
mask_generator_options = {
    "points_per_side": 5,
    "pred_iou_thresh": 0.94,
    "stability_score_thresh": 0.90,
    "crop_n_layers": 1,
    "crop_n_points_downscale_factor": 2,
    "min_mask_region_area": 10000,  # Requires open-cv to run post-processing
}
mask_generator = SamAutomaticMaskGenerator(model=sam, **mask_generator_options)

In [4]:
# Load the list of ids from disk; the trailing expression displays how many
# ids were read.
all_ids_path = '../../../full_data/all_ids.json'
with open(all_ids_path, 'r') as ids_file:
    all_ids = json.load(ids_file)

len(all_ids)

56694

In [None]:
# Batch pass over a slice of ids: every image with a side longer than max_dim
# is downscaled in place, segmented with SAM, and its five largest masks plus
# the front-object cutout are written to disk. Failures are reported per id
# and do not stop the loop.
for pid in tqdm.tqdm(all_ids[56000:56300]):
    # Every file belonging to one id shares the zero-padded "P<6 digits>" stem.
    stem = "P" + str(pid).zfill(6)
    image_path = "../../../full_data/images/" + stem + ".jpg"
    masks_path = "../../../full_data/segmented_mask_info_compressed/" + stem + ".pkl"
    cutout_path = "../../../full_data/segmented_images/" + stem + ".jpg"

    try:
        image = cv2.imread(image_path)  # OpenCV loads pixels in BGR order
        if image is None:
            # cv2.imread returns None instead of raising when the file is
            # missing or cannot be decoded.
            print(pid, "could not read", image_path)
            continue

        height, width = image.shape[0], image.shape[1]

        # Only very high resolution images are processed in this pass.
        # Images that already fit within max_dim are left untouched (the
        # earlier branches that skipped images below min_dim and re-cut the
        # rest from previously saved masks are disabled).
        if height <= max_dim and width <= max_dim:
            continue

        image = resizeImage(image)
        # NOTE: this overwrites the original file with the downscaled image.
        cv2.imwrite(image_path, image)

        # SAM takes RGB input by default, so convert the BGR array before
        # segmenting. Mask coordinates are unaffected by the channel order.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        masks = mask_generator.generate(image_rgb)
        masks = sorted(masks, key=lambda d: d['area'], reverse=True)
        topFive = masks[:5]

        with open(masks_path, 'wb') as f:
            pickle.dump(topFive, f)

        cutout = getFrontCutout(topFive, image)
        if cutout is None:
            # getFrontCutout returns None when SAM produced no masks.
            print(pid, "no masks found")
            continue

        # The cutout is a slice of the BGR image and cv2.imwrite expects BGR,
        # so it is written as-is; converting to RGB first would swap the red
        # and blue channels in the saved file.
        cv2.imwrite(cutout_path, cutout)
    except Exception as exc:
        # Best effort: report the failing id with the reason and carry on.
        # KeyboardInterrupt is deliberately not caught, so the run can still
        # be stopped by hand.
        print(pid, repr(exc))

 67%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                         | 200/300 [06:36<07:32,  4.53s/it]