### Based on the following:
Understanding how and why it works: <br>
> https://www.youtube.com/watch?v=0Fpb8TBH0nM <br>

Example for using SAM and processing its output: <br>
> https://github.com/IDEA-Research/Grounded-Segment-Anything/blob/main/grounded_sam.ipynb

# Preparing the environment for first time use
## Dependencies
To start off, you need to install <font color='green'>__CUDA__</font>:<br> 
> __https://developer.nvidia.com/cuda-downloads__

After you have installed the <font color='green'>__CUDA toolkit__</font>, you will have to install the latest version of <font color='ff8000'>__PyTorch__</font>.
> __https://pytorch.org/get-started/locally/__
<br>
This site will generate the command you need to install the package.

In your Python environment, run the following command to install missing dependencies:
> `pip install diffusers transformers accelerate scipy safetensors`

## AI Models
This script uses two models:
- [Segment Anything](https://github.com/facebookresearch/segment-anything) is a strong segmentation model, but it needs prompts (such as boxes or points) to generate masks.
- [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) is a strong zero-shot detector that generates high-quality boxes and labels from free-form text.

## Script setup
Create a folder on your system. For the sake of this tutorial, we will call it `mainFolder` <br>
Once you are in `mainFolder`, copy the Grounding DINO repository by running the command
> __<font color='#ff8000'>git clone https://github.com/IDEA-Research/GroundingDINO.git</font>__

In the same folder, clone the SAM repository:
> __<font color='#ff8000'>git clone https://github.com/facebookresearch/segment-anything.git</font>__

You also need to download a checkpoint for SAM, which has to be placed in `mainFolder`.
> __https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth__

Once you have all the files, you must do the following using your Python environment:
>1. Navigate to `mainFolder\GroundingDINO` and run the command `pip install -e .`
>
>2. Navigate to `mainFolder\segment-anything` and run the command `pip install -e .`

If there are no errors, then you have successfully installed the required dependencies.

The script will look for images to process in `mainFolder\content\input`, and will generate the output in `mainFolder\content\output`. It will mirror the folder structure from the input folder.
<br>
Once you have finished a batch of files, <font color='red'>make sure to remove them from the input folder</font>, as they will get processed again otherwise.

# Settings


In [19]:
# Object(s) to search for. Examples: "rock, card, plane" etc.
# This string is passed unchanged as the caption to Grounding DINO.
# NOTE(review): the example above separates terms with ',' while the value
# below uses ';' — confirm which separator the detector expects.
# Value type: string
search_for = 'rock; card'

# Set whether or not to stop at the first matching object.
# True selects `crop_image_single` (uses only the first mask);
# False selects `crop_image` (the multi-object variant).
# Value type: boolean
only_first_match = False

# Print additional information (used for debugging).
# Read by `dprint`, which stays silent when this is False.
# Value type: boolean
debug_print = True

def dprint(message: str):
    """Print *message* only while the ``debug_print`` setting is enabled."""
    if not debug_print:
        return
    print(message)

### Just run everything from top to bottom

In [2]:
import os, sys
import argparse
import os
import copy

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont, ImageSequence
from torchvision.ops import box_convert

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt


# diffusers
import PIL
import requests
import torch
from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline


from huggingface_hub import hf_hub_download

# If you have multiple GPUs, you can set the GPU to use here.
# The default is to use the first GPU, which is usually GPU 0.
# NOTE(review): this assignment runs after `import torch` above; presumably it
# still takes effect because CUDA is initialised lazily — confirm, or move it
# ahead of the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Load Grounding DINO model

In [None]:
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Fetch a Grounding DINO config and checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository that hosts both files.
        filename: Name of the checkpoint file inside the repository.
        ckpt_config_filename: Name of the model config file inside the repository.
        device: Value stored on the parsed config as ``device`` once the model
            has been built.

    Returns:
        The model, switched to eval mode, with the checkpoint weights loaded
        non-strictly (mismatching keys are reported, not fatal).
    """
    config_path = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    config = SLConfig.fromfile(config_path)
    model = build_model(config)
    # NOTE(review): the device is only recorded on the config, and only after
    # the model was built; the model itself is not moved here.
    config.device = device

    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # The checkpoint is always deserialised onto the CPU, whatever `device` is.
    state = torch.load(weights_path, map_location='cpu')
    load_report = model.load_state_dict(clean_state_dict(state['model']), strict=False)
    print("Model loaded from {} \n => {}".format(weights_path, load_report))
    model.eval()
    return model

# Download the Grounding DINO checkpoint and its config from the Hugging Face Hub
# and build the model. Alternatively, you can download the model files yourself.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"  # (sic) variable name kept as-is
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"

# `device` is left at its default ('cpu') for this call.
groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)

# Load SAM model

In [None]:
# SAM checkpoint file name; being a relative path, it is looked up in the
# current working directory.
sam_checkpoint = 'sam_vit_h_4b8939.pth'
# Build SAM from the checkpoint and wrap it in the predictor used by `run_sam`.
sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))

# Run Grounding DINO for detection
Runs the Grounding DINO model on a given image.

Returns an array with processed boxes for said image.

In [5]:
def run_grounding_dino(image):
    """Detect the objects named by the ``search_for`` setting in one image.

    Args:
        image: Preprocessed image in the form GroundingDINO's ``predict``
            accepts (the second value returned by ``load_image`` / ``load_frame``).

    Returns:
        The boxes reported by ``predict``; its logits and phrases are discarded.
    """
    # Detection thresholds handed to the detector.
    box_threshold = 0.3
    text_threshold = 0.25

    boxes, _logits, _phrases = predict(
        model=groundingdino_model,
        image=image,
        caption=search_for,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    return boxes

# Run the segmentation model
Runs the Segment Anything model on a given image with annotated boxes.

Returns an array with the masks identified by the model for a given image.

In [6]:
def run_sam(image_source, boxes):
    """Segment the regions described by *boxes* with the Segment Anything predictor.

    Args:
        image_source: Image array whose shape unpacks as (height, width, channels).
        boxes: Normalised boxes in (cx, cy, w, h) form, as produced by
            ``run_grounding_dino``.

    Returns:
        The masks returned by ``sam_predictor.predict_torch`` (one mask per box,
        since ``multimask_output`` is disabled).
    """
    sam_predictor.set_image(image_source)

    # Normalised cxcywh -> pixel-space xyxy.
    height, width, _ = image_source.shape
    pixel_scale = torch.Tensor([width, height, width, height])
    pixel_boxes = box_ops.box_cxcywh_to_xyxy(boxes) * pixel_scale

    # Map the boxes into the coordinate frame SAM uses internally.
    sam_boxes = sam_predictor.transform.apply_boxes_torch(pixel_boxes, image_source.shape[:2])
    prediction = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes,
        multimask_output=False,
    )
    masks, _, _ = prediction
    return masks

# Image Crop
Crops the image using the masks.

Returns the result image in OpenCV format.

In [8]:
def _mask_to_bool(mask):
    """Convert one SAM mask (a tensor whose first axis has length 1) to a 2-D bool array."""
    return mask[0].cpu().numpy().astype(bool)


def _cut_out(image_source, keep):
    """Return a BGRA uint8 copy of *image_source* that is transparent wherever *keep* is False.

    The alpha channel comes straight from the mask, so dark (even pure black)
    pixels that belong to the object stay opaque.
    """
    rgb = np.asarray(image_source)
    height, width = keep.shape
    bgra = np.zeros((height, width, 4), dtype=np.uint8)
    # The source image is RGB; OpenCV writers expect BGR(A) channel order.
    bgra[..., :3] = rgb[..., :3][..., ::-1]
    bgra[..., 3] = 255
    # Blank colour and alpha everywhere outside the mask.
    bgra[~keep] = 0
    return bgra


# use for multiple objects
def crop_image(image_source, masks):
    """Cut every detected object out of the image onto a transparent background.

    Args:
        image_source: RGB image as an (H, W, 3) uint8 array.
        masks: Iterable of SAM masks, each shaped (1, H, W); all of them are
            merged (logical OR) into a single cut-out mask.

    Returns:
        An (H, W, 4) uint8 array in BGRA order, ready for ``cv2.imwrite``.
        Pixels outside every mask are fully transparent black.
    """
    keep = np.zeros(np.asarray(image_source).shape[:2], dtype=bool)
    # Merge ALL masks, not just the first and the last one.
    for mask in masks:
        keep |= _mask_to_bool(mask)
    return _cut_out(image_source, keep)


# use for a single object
def crop_image_single(image_source, masks):
    """Cut only the first detected object out of the image onto a transparent background.

    Args:
        image_source: RGB image as an (H, W, 3) uint8 array.
        masks: Sequence of SAM masks, each shaped (1, H, W); only ``masks[0]``
            is used.

    Returns:
        An (H, W, 4) uint8 array in BGRA order, ready for ``cv2.imwrite``.

    Raises:
        IndexError: If *masks* is empty.
    """
    return _cut_out(image_source, _mask_to_bool(masks[0]))

# Image Processing
Runs the model (GroundingDINO + SAM + crop) for a given image.

Returns the result image in OpenCV format.

In [9]:
# Pick the cropping routine once, according to the `only_first_match` setting.
cropping_function = crop_image_single if only_first_match else crop_image

def process_image(image_path):
    """Run detection, segmentation and cropping on a single image file.

    Args:
        image_path: Path of the image to load with ``load_image``.

    Returns:
        Whatever the selected ``cropping_function`` returns for the image.

    Raises:
        Exception: When the detector returns no boxes.
    """
    image_source, image = load_image(image_path)

    boxes = run_grounding_dino(image)
    if boxes.numel() == 0:
        raise Exception('No object matching description was found')

    return cropping_function(image_source, run_sam(image_source, boxes))

def load_frame(frame):
    """Prepare one PIL frame for the detector.

    Args:
        frame: PIL image (for example a single GIF frame).

    Returns:
        A pair ``(image, image_transformed)``: the RGB frame as a numpy array,
        and the same frame after the resize / to-tensor / normalise pipeline.
    """
    rgb_frame = frame.convert('RGB')
    pixels = np.asarray(rgb_frame)

    preprocess = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # The second argument is the (absent) annotation target.
    tensor, _ = preprocess(rgb_frame, None)
    return pixels, tensor

def process_gif(image_path):
    """Process a GIF frame by frame and save each cropped frame as a PNG.

    Frames are written to a folder named after the GIF (without its extension),
    located at the GIF's path with the first ``'input'`` replaced by
    ``'output'``. Files are numbered from 1 by frame position; frames in which
    no object is found are skipped, leaving a gap in the numbering.

    Args:
        image_path: Path of the GIF file.
    """
    stem, _ = os.path.splitext(image_path)
    # NOTE: only the first occurrence of 'input' in the path is swapped, which
    # is how the rest of the script maps input paths to output paths.
    result_path = stem.replace('input', 'output', 1)

    # Copy every frame out, then release the file before the slow model runs.
    with Image.open(image_path) as im:
        frames = [frame.copy() for frame in ImageSequence.Iterator(im)]
    dprint(f'Current file is a gif, having {len(frames)} frames')

    # Create a folder having the file's name.
    os.makedirs(result_path, exist_ok=True)

    for frame_count, frame in enumerate(frames, start=1):
        image, image_transformed = load_frame(frame)

        boxes = run_grounding_dino(image_transformed)
        if not boxes.numel():
            dprint(f'Object not found in frame {frame_count}')
            continue

        masks = run_sam(image, boxes)
        final_image = cropping_function(image, masks)
        # os.path.join avoids the doubled, Windows-only separator.
        frame_path = os.path.join(result_path, f'{frame_count}.png')

        dprint(f'Saving frame {frame_count} to {frame_path}')
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(frame_path, final_image):
            print(f'Could not save frame {frame_count} to {frame_path}')

    print(f'Saved processed frames in {result_path}')

# Batch Processing
Process all images from the input folder and save the output.


In [None]:
# File extensions (lower case) the script treats as images.
image_extensions = ['.png', '.jpg', '.jpeg', '.webp', '.gif']

def process_image_or_navigate_folder(file, in_path):
    """Process one directory entry: recurse into folders, convert supported images.

    Still images are saved as PNG under the mirrored output folder (the first
    ``'input'`` in *in_path* replaced by ``'output'``); GIFs are handed to
    ``process_gif``. Entries with other extensions are ignored. Any error is
    reported on stdout and does not stop the batch.

    Args:
        file: Name of the entry inside *in_path*.
        in_path: Folder that contains *file*.
    """
    current_path = os.path.join(in_path, file)

    try:
        # current path is a folder
        if os.path.isdir(current_path):
            for file_name in os.listdir(current_path):
                process_image_or_navigate_folder(file_name, current_path)
            return

        # Compare extensions case-insensitively everywhere, so '.GIF' is
        # treated as a GIF and not as a still image.
        stem, file_extension = os.path.splitext(file)
        file_extension = file_extension.lower()
        if file_extension not in image_extensions:
            return

        dprint(f'Currently in {in_path}')
        print(f'Processing image {file}')

        if file_extension == '.gif':
            process_gif(current_path)
            return

        result = process_image(current_path)

        result_dir = in_path.replace('input', 'output', 1)
        os.makedirs(result_dir, exist_ok=True)

        # Swap only the real extension; str.replace would also rewrite any
        # earlier occurrence of the same text inside the file name.
        result_path = os.path.join(result_dir, stem + '.png')
        print(f'Saving processed file to {result_path}')
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(result_path, result):
            print(f'Could not save processed file to {result_path}')

    except Exception as error:
        # Best effort: report which entry failed and carry on with the batch.
        print(f'Failed to process {current_path}: {error}')

def process_all():
    """Process every entry found in ``content/input`` under the current working directory.

    Both the input and the output folder are created when missing. Results are
    written by ``process_image_or_navigate_folder`` to ``content/output``.
    """
    if only_first_match:
        dprint('Script will save only the first matching object')
    else:
        dprint('Script will save all matching objects')

    dprint('Searching for: ' + search_for)

    # os.path.join keeps the paths valid on every platform, not just Windows.
    content_directory = os.path.join(os.getcwd(), 'content')
    input_directory = os.path.join(content_directory, 'input')
    output_directory = os.path.join(content_directory, 'output')

    dprint(f'Files will be read from {input_directory}')
    dprint(f'Processed images will be saved to {output_directory}')

    if not os.path.isdir(input_directory):
        os.makedirs(input_directory, exist_ok=True)
        print('Input directory has been created. Add images and restart the script.')

    os.makedirs(output_directory, exist_ok=True)

    # A freshly created input folder is empty, so this loop is then a no-op.
    for file_name in os.listdir(input_directory):
        process_image_or_navigate_folder(file_name, input_directory)

# Start the batch run as soon as this cell executes.
process_all()