In [1]:
import os
import cv2
import torch
import numpy as np
import urllib.request
import matplotlib.pyplot as plt

from tqdm import tqdm

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

In [2]:
# Download the default SAM checkpoint (ViT-H) unless the checkpoint file is already on disk.
# The check is on the file itself, not on the 'models/' directory: guarding on the
# directory meant an interrupted first download was never retried.
MODEL_DIR = 'models'
MODEL_PATH = os.path.join(MODEL_DIR, 'defaultModel.pth')
MODEL_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'

if not os.path.isfile(MODEL_PATH):
    print('Downloading Model')
    os.makedirs(MODEL_DIR, exist_ok=True)
    # Download under a temporary name and rename only on success, so a truncated
    # file is never left at MODEL_PATH where it would pass the isfile() check above.
    partial_path = MODEL_PATH + '.part'
    urllib.request.urlretrieve(MODEL_URL, partial_path)
    os.replace(partial_path, MODEL_PATH)

In [3]:
def resize_image(image, length):
    """
    Resize an image so that its longer side is ``length`` pixels, preserving aspect ratio.

    The shorter side is scaled by the same factor and truncated to an integer, but is
    never allowed to drop below 1 pixel: for extremely elongated images the truncation
    would otherwise produce a zero-sized target, which ``cv2.resize`` rejects.
    Square images are resized to ``length`` x ``length``.

    Parameters:
    image (numpy.ndarray): The input image. Only ``image.shape[:2]`` (height, width)
        is used to compute the target size.
    length (int): Target size, in pixels, of the image's longer side.

    Returns:
    numpy.ndarray: The resized image, as returned by ``cv2.resize``.

    """
    height, width = image.shape[:2]

    if height > width:
        # Portrait: height is the longer side; scale the width to match.
        new_width = max(1, int((length / height) * width))
        new_size = (new_width, length)
    else:
        # Landscape or square: width is the longer side; scale the height to match.
        new_height = max(1, int((length / width) * height))
        new_size = (length, new_height)

    # cv2.resize expects the destination size as (width, height).
    return cv2.resize(image, new_size)

# Segmentation Function

In [4]:
def _select_segments(masks, image_area, max_num_of_segments, max_bbox_fraction):
    """
    Pick the largest SAM masks whose bounding box does not cover (almost) the whole image.

    Args:
        masks (list[dict]): Records as returned by ``SamAutomaticMaskGenerator.generate``;
            only the 'area' and 'bbox' keys are read.
        image_area (int): Height * width of the image the masks were generated on.
        max_num_of_segments (int | None): Maximum number of masks to return. ``None`` means
            no limit; 0 or a negative value returns an empty list.
        max_bbox_fraction (float): Masks whose bbox area exceeds this fraction of
            ``image_area`` are skipped.

    Returns:
        list[dict]: The selected mask records, ordered from largest to smallest 'area'.
    """
    if max_num_of_segments is not None and max_num_of_segments <= 0:
        return []

    selected = []
    for mask in sorted(masks, key=lambda m: m['area'], reverse=True):
        # SAM reports 'bbox' in XYWH format, so the box area is simply width * height.
        _, _, box_w, box_h = mask['bbox']
        if box_w * box_h / image_area <= max_bbox_fraction:
            selected.append(mask)
            if len(selected) == max_num_of_segments:
                break
    return selected


def pre_segmentate(image_dir: str, size: int, max_num_of_segments: int = 5, output_dir_masks: str = 'segments', output_dir_img: str = None,
                   model_checkpoint: str = 'models/defaultModel.pth', model_type: str = 'default',
                   max_bbox_fraction: float = 0.9):
    """
    Run SAM automatic mask generation on every image in a directory and save the largest masks.

    For each image in ``image_dir``:
    - it is resized so that its longer side is ``size`` pixels;
    - it is segmented with ``SamAutomaticMaskGenerator``;
    - up to ``max_num_of_segments`` masks are saved as 0/255 PNGs to
      ``<output_dir_masks>/<image name>/<i>.png``, ordered from largest to smallest mask area.

    Masks whose bounding box covers more than ``max_bbox_fraction`` of the image are skipped.

    Images whose mask directory already exists are skipped, so an interrupted run can be
    resumed. Directory entries that are not files, or that OpenCV cannot decode, are skipped
    with a message. Note: files that differ only by extension (e.g. ``a.jpg`` and ``a.png``)
    map to the same output name, so only the first one is processed.

    Args:
        image_dir (str): The directory path containing the images to be pre-segmented.
        size (int): Target length, in pixels, of the longer side of each resized image.
        max_num_of_segments (int, optional): The maximum number of segments saved per image.
            ``None`` means no limit. Defaults to 5.
        output_dir_masks (str, optional): The directory path to save the generated segment masks.
            Defaults to 'segments'.
        output_dir_img (str, optional): If given, the resized images are saved here as PNG.
            Defaults to None.
        model_checkpoint (str, optional): Path to the SAM checkpoint file.
            Defaults to 'models/defaultModel.pth'.
        model_type (str, optional): Key into ``sam_model_registry``; must match the checkpoint.
            Defaults to 'default'.
        max_bbox_fraction (float, optional): Masks whose bounding-box area exceeds this
            fraction of the image area are discarded. Defaults to 0.9.

    Returns:
        None
    """
    # Create output directories (including any missing parents) if they don't exist
    os.makedirs(output_dir_masks, exist_ok=True)
    if output_dir_img is not None:
        os.makedirs(output_dir_img, exist_ok=True)

    # Load the SAM model once for all images
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('Using:', device)
    sam = sam_model_registry[model_type](checkpoint=model_checkpoint).to(device)
    mask_generator = SamAutomaticMaskGenerator(sam)

    # sorted() gives a reproducible processing order; os.listdir's order is arbitrary.
    for file in tqdm(sorted(os.listdir(image_dir))):
        image_path = os.path.join(image_dir, file)
        if not os.path.isfile(image_path):
            continue

        # splitext also handles names without an extension (rfind('.') would return -1).
        image_name = os.path.splitext(file)[0]
        mask_dir = os.path.join(output_dir_masks, image_name)

        # Skip images that were already segmented in a previous run
        if os.path.isdir(mask_dir):
            continue

        # cv2.imread returns None (rather than raising) for unreadable / non-image files
        image = cv2.imread(image_path)
        if image is None:
            tqdm.write(f'Skipping {file}: not a readable image')
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = resize_image(image, size)

        # Save the resized image so the masks can be overlaid on matching pixels
        if output_dir_img is not None:
            cv2.imwrite(os.path.join(output_dir_img, image_name + '.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

        masks = mask_generator.generate(image)
        selected = _select_segments(masks, image.shape[0] * image.shape[1],
                                    max_num_of_segments, max_bbox_fraction)

        # Save the selected masks as 0/255 single-channel PNGs
        os.mkdir(mask_dir)
        for i, mask in enumerate(selected):
            binary_mask = mask['segmentation'].astype(np.uint8) * 255
            cv2.imwrite(os.path.join(mask_dir, f'{i}.png'), binary_mask)

            
# Segment the GonAesthetics dataset: longer side resized to 224 px, resized copies kept on disk.
pre_segmentate(
    image_dir='datasets/GonAesthetics',
    size=224,
    output_dir_img='resized_images',
)

Using: cuda


100%|██████████| 5/5 [00:22<00:00,  4.52s/it]
