Install the Required Packages and Download the SAM Checkpoint

In [None]:
!pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth


Import the Required Libraries

In [None]:
import os
import sys

import cv2
import torch
from google.colab.patches import cv2_imshow
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator


Define the SAMMaskGenerator Class

In [None]:
class SAMMaskGenerator:
    """Generate masks for an image with SAM's automatic mask generator and save them as PNGs.

    Args:
        model_type: Key into ``sam_model_registry`` (e.g. ``"vit_h"``); should match the
            architecture of ``checkpoint``.
        checkpoint: Path to the SAM checkpoint file.
        device: Device the model is moved to (e.g. ``"cuda"`` or ``"cpu"``).
        **generator_kwargs: Optional extra keyword arguments forwarded unchanged to
            ``SamAutomaticMaskGenerator`` to tune mask generation; when omitted, its
            defaults are used (same behavior as before).
    """

    def __init__(self, model_type, checkpoint, device, **generator_kwargs):
        self.sam = sam_model_registry[model_type](checkpoint=checkpoint)  # Initialize the SAM model
        self.sam.to(device=device)
        # Initialize the mask generator
        self.mask_generator = SamAutomaticMaskGenerator(self.sam, **generator_kwargs)

    @staticmethod
    def _save_mask(mask_path, binary_mask):
        """Write ``binary_mask`` scaled to 0/255 uint8 to ``mask_path``; return the uint8 image.

        Raises:
            OSError: if OpenCV reports that the file could not be written.
        """
        binary_mask_uint8 = (binary_mask * 255).astype('uint8')
        # cv2.imwrite signals failure (bad path, unsupported extension) by returning False
        # rather than raising, so check it explicitly instead of losing the mask silently.
        if not cv2.imwrite(mask_path, binary_mask_uint8):
            raise OSError(f'Failed to write mask to {mask_path!r}')
        return binary_mask_uint8

    def generate_and_save_mask(self, image_path, save_all=False, output_dir=None):
        """Generate masks for ``image_path`` and save them into ``<image_name>_masks/``.

        Args:
            image_path: Path (str or path-like) of the image to segment.
            save_all: If True, save every mask as ``mask_<i>.png``, ordered by area
                (``mask_0`` is the largest). If False, save and display only the largest
                mask as ``mask_with_largest_area.png``.
            output_dir: Optional parent directory for the mask folder. When None
                (default), the folder is created relative to the current working directory.

        Returns:
            list[str]: Paths of the mask files that were written.

        Raises:
            FileNotFoundError: if the image cannot be read.
            ValueError: if ``save_all`` is False and SAM produced no masks.
            OSError: if a mask file cannot be written.
        """
        image_path = os.fspath(image_path)

        # cv2.imread returns None (instead of raising) for a missing or unreadable file;
        # without this check the failure surfaces as an opaque error inside cvtColor.
        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            raise FileNotFoundError(f'Could not read image: {image_path!r}')
        # OpenCV loads images as BGR; convert to RGB before handing the array to SAM.
        image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        # Generate the masks, largest area first.
        results = self.mask_generator.generate(image)
        results.sort(key=lambda x: x['area'], reverse=True)

        if not save_all and not results:
            raise ValueError(f'SAM produced no masks for {image_path!r}')

        # Output folder is named after the image: "<image_name>_masks".
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        folder_name = f'{image_name}_masks'
        if output_dir:
            folder_name = os.path.join(output_dir, folder_name)
        os.makedirs(folder_name, exist_ok=True)

        saved_paths = []
        if save_all:
            for i, result in enumerate(results):
                mask_path = os.path.join(folder_name, f'mask_{i}.png')
                self._save_mask(mask_path, result['segmentation'])
                saved_paths.append(mask_path)
        else:
            mask_path = os.path.join(folder_name, 'mask_with_largest_area.png')
            binary_mask_uint8 = self._save_mask(mask_path, results[0]['segmentation'])
            # Display the mask inline (Colab helper).
            cv2_imshow(binary_mask_uint8)
            saved_paths.append(mask_path)
        return saved_paths


Define the Checkpoint and Model Type

In [None]:
# SAM architecture key and its checkpoint; the checkpoint must be for the same
# architecture as model_type (vit_h <-> sam_vit_h_*.pth).
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"  # Replace with the path to the downloaded SAM checkpoint


Initialize the SAMMaskGenerator

In [None]:
# Use the GPU when available; hard-coding 'cuda' makes moving the model fail on a
# CPU-only runtime, so fall back to 'cpu' there (much slower, but it runs).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
mask_generator = SAMMaskGenerator(model_type, sam_checkpoint, device)


Generate and Save All Masks

In [None]:
image_path = 'path_to_your_image.jpg'  # Replace with the actual path to your image file
# Fail fast with a readable message if the placeholder was not replaced: otherwise
# cv2.imread returns None and the failure shows up as an opaque OpenCV error later.
assert os.path.isfile(image_path), f'Image not found: {image_path!r} - set image_path to a real file'
mask_generator.generate_and_save_mask(image_path, save_all=True)
