<a href="https://colab.research.google.com/github/RohanT766/Prompt-Mass-Object-Extraction/blob/main/Prompt_Mass_Object_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Make sure to run on a GPU by going to Runtime -> Change runtime type and selecting GPU or T4 GPU.

In [None]:
import cv2
import numpy as np
from PIL import Image
from google.colab import files
import matplotlib.pyplot as plt
import os
HOME = os.getcwd()

# Working folders, both directly under HOME: one receives the uploaded
# images, the other collects the segmented results.
INPUT_FOLDER, OUTPUT_FOLDER = (
    os.path.join(HOME, folder_name)
    for folder_name in ('input_images', 'output_images')
)

# exist_ok=True keeps this cell safe to run more than once.
for _working_folder in (INPUT_FOLDER, OUTPUT_FOLDER):
    os.makedirs(_working_folder, exist_ok=True)

### Run this cell, click "Choose Files", and select the images to upload.

In [None]:
# Ask the user to pick image files to upload (individual files, not a folder).
# The keys of the returned mapping are used by the next cell as the file names.
uploaded_files = files.upload()

In [None]:
# Move uploaded images to the input folder.
# Each upload is expected at HOME/<file_name>; it is moved to INPUT_FOLDER/<file_name>.
for file_name in uploaded_files.keys():
    source_path = os.path.join(HOME, file_name)
    target_path = os.path.join(INPUT_FOLDER, file_name)

    # Re-running this cell after the files were already moved must not crash
    # with FileNotFoundError, so missing sources are reported and skipped.
    if not os.path.exists(source_path):
        print(f"Skipping {file_name}: not found at {source_path} (already moved?)")
        continue

    os.rename(source_path, target_path)

In [None]:
# Clone the GroundingDINO repository
!git clone https://github.com/IDEA-Research/GroundingDINO.git
# Switch the notebook's working directory into the clone.
# NOTE: %cd persists, so later cells also run from {HOME}/GroundingDINO
# unless the runtime is restarted.
%cd {HOME}/GroundingDINO
# Check out one fixed commit (-q: quiet) so every run installs the same code
!git checkout -q 57535c5a79791cb76e36fdb64975271354f10251
# Install the checked-out package in editable mode (-q: quiet)
!pip install -q -e .

In [None]:
# Install PyTorch and torchvision (no versions are pinned here)
!pip install torch
!pip install torchvision

In [None]:
# Install the supervision library, pinned to version 0.12.0
!pip install supervision==0.12.0

### If you receive an error saying "You must restart the runtime in order to use newly installed versions." click "RESTART RUNTIME", run the first cell, and then skip down to this cell below.

In [None]:
# Create the weights directory under HOME (-p: no error if it already exists)
!mkdir -p {HOME}/weights

In [None]:
# Download GroundingDINO model weights into HOME/weights
# (-q: quiet output, -P: directory to save into)
!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P {HOME}/weights

In [None]:
# Install Segment Anything directly from its GitHub repository
!pip install 'git+https://github.com/facebookresearch/segment-anything.git'

In [None]:
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
# Set up the environment
import torch

# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    _device_name = 'cuda:0'
else:
    _device_name = 'cpu'
DEVICE = torch.device(_device_name)

In [None]:
# Set model paths
# The GroundingDINO config lives inside the cloned repository; both checkpoint
# paths point into HOME/weights.
# NOTE: CHECKPOINT_PATH is assigned again in the cell that loads SAM.
CHECKPOINT_PATH = '/'.join((HOME, 'weights', 'sam_vit_h_4b8939.pth'))
GROUNDING_DINO_CONFIG_PATH = '/'.join(
    (HOME, 'GroundingDINO', 'groundingdino', 'config', 'GroundingDINO_SwinT_OGC.py')
)
GROUNDING_DINO_CHECKPOINT_PATH = '/'.join((HOME, 'weights', 'groundingdino_swint_ogc.pth'))

In [None]:
import supervision as sv
from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor

In [None]:
# Build the GroundingDINO detector from its config file and checkpoint.
GD_model = Model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
)

In [None]:
# Load models
MODEL_TYPE = "vit_h"

# Look for the SAM checkpoint in every folder this notebook may have saved it
# to, instead of assuming one hard-coded location. The original fixed path is
# tried first so existing setups behave exactly as before.
SAM_CHECKPOINT_NAME = 'sam_vit_h_4b8939.pth'
sam_checkpoint_candidates = [
    os.path.join('/content', SAM_CHECKPOINT_NAME),
    os.path.join(os.getcwd(), SAM_CHECKPOINT_NAME),
    os.path.join(HOME, SAM_CHECKPOINT_NAME),
    os.path.join(HOME, 'GroundingDINO', SAM_CHECKPOINT_NAME),
    os.path.join(HOME, 'weights', SAM_CHECKPOINT_NAME),
]
CHECKPOINT_PATH = next(
    (path for path in sam_checkpoint_candidates if os.path.isfile(path)),
    None,
)
if CHECKPOINT_PATH is None:
    # Fail early with the list of searched paths rather than deep inside the loader.
    raise FileNotFoundError(
        "SAM checkpoint not found; looked in: " + ", ".join(sam_checkpoint_candidates)
    )

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
mask_predictor = SamPredictor(sam)

### Input the object that you wish to extract out of all the images.

In [None]:
# Ask the user which object to extract. The answer is kept as a one-element
# list because the detector is called with a list of class names.
text_prompt = [input("Object to segment: ")]

In [None]:
# Detect and segment objects for each image.
#
# For every readable image in INPUT_FOLDER:
#   1. GroundingDINO proposes boxes for the prompted object.
#   2. SAM turns every box into a mask.
#   3. The union of all masks is cut out onto a white background, shown next to
#      the original and the annotated image, and saved as
#      <name>_segmented.png in the output folder.
BOX_THRESHOLD = 0.30   # passed to GroundingDINO as box_threshold
TEXT_THRESHOLD = 0.25  # passed to GroundingDINO as text_threshold
SEGMENTED_OUTPUT_FOLDER = '/content/output_images/'  # folder the download cell zips


def _label_for(class_id):
    """Return the prompt text to display for one detection.

    A class id of None (a detection whose phrase did not map to a prompted
    class) falls back to the first prompt instead of raising on the lookup.
    """
    if class_id is None:
        return text_prompt[0]
    return text_prompt[class_id]


def _cut_out(image_bgr, object_mask):
    """Return a copy of image_bgr with every pixel outside object_mask set to white.

    image_bgr is the (H, W, 3) array from cv2.imread; object_mask is an (H, W)
    boolean array. Using the mask directly (rather than recolouring all black
    pixels) keeps genuinely black pixels that belong to the object.
    """
    cutout = np.full_like(image_bgr, 255)
    cutout[object_mask] = image_bgr[object_mask]
    return cutout


os.makedirs(SEGMENTED_OUTPUT_FOLDER, exist_ok=True)
box_annotator = sv.BoxAnnotator()
mask_annotator = sv.MaskAnnotator(color=sv.Color.blue())

for image_name in sorted(os.listdir(INPUT_FOLDER)):
    image_path = os.path.join(INPUT_FOLDER, image_name)
    image_bgr = cv2.imread(image_path)

    # cv2.imread returns None for files it cannot decode
    if image_bgr is None:
        print(f"Skipping image {image_name} due to empty image_bgr")
        continue

    try:
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        # GroundingDINO's Model.predict_with_classes converts BGR -> RGB
        # itself, so it must be given the BGR image as read by OpenCV.
        detections = GD_model.predict_with_classes(
            image=image_bgr,
            classes=text_prompt,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD,
        )
        detected_boxes = detections.xyxy
        class_ids = detections.class_id
        if class_ids is None:
            class_ids = [None] * len(detected_boxes)

        print(f"{image_name}: {len(detected_boxes)} detection(s)")
        if len(detected_boxes) == 0:
            # Nothing to extract, so no output file is written for this image.
            continue

        # The image embedding is independent of the box prompt: compute it once
        # per image instead of once per box.
        mask_predictor.set_image(image_rgb)

        union_mask = np.zeros(image_bgr.shape[:2], dtype=bool)
        annotated_image = image_bgr.copy()

        for box_index, box in enumerate(detected_boxes):
            masks, scores, logits = mask_predictor.predict(
                point_coords=None,
                point_labels=None,
                box=np.array(box),
                multimask_output=False,
            )
            masks = np.asarray(masks, dtype=bool)

            # Accumulate every object's mask so that all detections end up in
            # the final cut-out, not just the first one.
            union_mask |= masks.any(axis=0)

            # Draw this object's box, label and mask on the running annotation.
            mask_detections = sv.Detections(
                xyxy=sv.mask_to_xyxy(masks=masks),
                mask=masks,
            )
            mask_detections = mask_detections[
                mask_detections.area == np.max(mask_detections.area)
            ]
            annotated_image = box_annotator.annotate(
                scene=annotated_image.copy(),
                detections=mask_detections,
                skip_label=False,
                labels=[_label_for(class_ids[box_index])],
            )
            annotated_image = mask_annotator.annotate(
                scene=annotated_image.copy(),
                detections=mask_detections,
            )

        segmented_image = _cut_out(image_bgr, union_mask)

        # One comparison figure per image (not one per detected box).
        sv.plot_images_grid(
            images=[image_bgr, annotated_image, segmented_image],
            grid_size=(1, 3),
            titles=['Original Image', 'Annotated Image', 'Segmented Image'],
        )

        # PIL expects RGB channel order, OpenCV arrays are BGR.
        segmented_image_rgb = cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(segmented_image_rgb)

        output_image_path = os.path.join(
            SEGMENTED_OUTPUT_FOLDER,
            os.path.splitext(os.path.basename(image_name))[0] + '_segmented.png',
        )
        pil_image.save(output_image_path)

    except cv2.error as e:
        print(f"Skipping image {image_name} due to OpenCV error: {e}")
        continue


### Run to download segmented images.

In [None]:
import shutil

from google.colab import files

# Define the folder containing your PNG images
folder_path = '/content/output_images'  # Change this path if needed

# Define the name of the ZIP file
zip_filename = '/content/output_images.zip'  # Change the filename if needed

created_archive = None
if os.path.isdir(folder_path):
    # shutil.make_archive appends ".zip" to the base name itself and returns the
    # path it actually wrote, so strip a trailing ".zip" before passing it in.
    if zip_filename.endswith('.zip'):
        archive_base = zip_filename[:-len('.zip')]
    else:
        archive_base = zip_filename
    created_archive = shutil.make_archive(archive_base, 'zip', folder_path)

    # Only rename when the requested name differs from what make_archive wrote
    # (e.g. zip_filename was given without the ".zip" suffix).
    if os.path.abspath(created_archive) != os.path.abspath(zip_filename):
        shutil.move(created_archive, zip_filename)
else:
    print(f'Output folder "{folder_path}" does not exist; nothing to archive.')

# Check if the ZIP file was created successfully
if created_archive is not None and os.path.exists(zip_filename):
    print(f'ZIP file "{zip_filename}" created successfully!')
    # Provide a download link for the ZIP file
    files.download(zip_filename)
else:
    print(f'Failed to create ZIP file.')