<a href="https://colab.research.google.com/github/anarghya15/notebooks/blob/main/notebooks/Grounded_SAM_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Grounding DINO and Segment Anything Model

In [None]:
import os
HOME = os.getcwd()
print("HOME:", HOME)
%cd {HOME}
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd {HOME}/GroundingDINO
!git checkout -q 57535c5a79791cb76e36fdb64975271354f10251
!pip install -q -e .

In [None]:
%cd {HOME}

import sys
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!pip uninstall -y supervision
!pip install -q supervision==0.6.0

## Download Grounding DINO Model Weights

In [None]:
import os

# Swin-T model config that ships inside the GroundingDINO checkout under HOME.
GROUNDING_DINO_CONFIG_PATH = os.path.join(
    HOME,
    "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
)
# Report whether the config file is actually present on disk.
print(
    GROUNDING_DINO_CONFIG_PATH,
    "; exist:",
    os.path.isfile(GROUNDING_DINO_CONFIG_PATH),
)

In [None]:
%cd {HOME}
!mkdir -p {HOME}/weights
%cd {HOME}/weights

!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth

GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(HOME, "weights", "groundingdino_swint_ogc.pth")
print(GROUNDING_DINO_CHECKPOINT_PATH, "; exist:", os.path.isfile(GROUNDING_DINO_CHECKPOINT_PATH))

## Download Segment Anything Model (SAM) Weights

In [None]:
%cd {HOME}
!mkdir -p {HOME}/weights
%cd {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

SAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(SAM_CHECKPOINT_PATH, "; exist:", os.path.isfile(SAM_CHECKPOINT_PATH))

# Load Models

In [None]:
import torch

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

%cd {HOME}/GroundingDINO

from groundingdino.util.inference import Model

grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)

SAM_ENCODER_VERSION = "vit_h"

from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to(device=DEVICE)
sam_predictor = SamPredictor(sam)

# Test: Detect with Grounding DINO, Segment with SAM

In [None]:
import cv2
import numpy as np
import supervision as sv
from typing import List

# Annotators used by detect_and_segment to draw masks and labelled boxes.
box_annotator = sv.BoxAnnotator()
mask_annotator = sv.MaskAnnotator()


def enhance_class_name(class_names: List[str]) -> List[str]:
    """Turn bare class names into "all <name>s" text prompts for Grounding DINO.

    The output preserves input order, so a detected class_id can still be used
    to index the original ``class_names`` list.
    """
    return [f"all {class_name}s" for class_name in class_names]


def segment(sam_predictor: "SamPredictor", image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Produce one SAM mask per bounding box.

    Args:
        sam_predictor: predictor wrapping the loaded SAM model.
        image: RGB image (the caller converts from OpenCV's BGR).
        xyxy: boxes in pixel xyxy format, one row per box.

    Returns:
        Array stacking one mask per box; shape (0, H, W) when there are no boxes.
    """
    if len(xyxy) == 0:
        # Skip the expensive image embedding when there is nothing to segment,
        # and keep the (N, H, W) layout of the non-empty case.
        return np.zeros((0, *image.shape[:2]), dtype=bool)

    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        # multimask_output=True returns several candidates; keep the best-scoring one.
        result_masks.append(masks[np.argmax(scores)])
    return np.array(result_masks)


def detect_and_segment(image_path, classes, BOX_TRESHOLD, TEXT_TRESHOLD):
    """Detect ``classes`` with Grounding DINO, segment them with SAM, and annotate.

    Args:
        image_path: path of the image to read with OpenCV.
        classes: list of class names; each is turned into a text prompt.
        BOX_TRESHOLD: passed to Grounding DINO as ``box_threshold``.
        TEXT_TRESHOLD: passed to Grounding DINO as ``text_threshold``.

    Returns:
        (annotated_image, title): a copy of the BGR image with masks and labelled
        boxes drawn, and the distinct detected class names, sorted and joined by spaces.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=enhance_class_name(class_names=classes),
        box_threshold=BOX_TRESHOLD,
        text_threshold=TEXT_TRESHOLD,
    )
    # Keep only boxes with a class_id (presumably None means the phrase could not
    # be mapped to any prompt). `!=` is intentional: it compares element-wise.
    detections = detections[detections.class_id != None]  # noqa: E711
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy,
    )

    # Unpacking assumes supervision 0.6.0's iteration tuple
    # (xyxy, mask, confidence, class_id, tracker_id) — confirm if upgrading.
    labels = [
        f"{classes[class_id]} {confidence:0.2f}"
        for _, _, confidence, class_id, _ in detections
    ]
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

    # sorted() makes the title deterministic; set iteration order of strings
    # can differ between interpreter runs.
    title = " ".join(sorted({classes[class_id] for class_id in detections.class_id}))

    return annotated_image, title

## Classes are generic, thresholds are low

In [None]:
# Generic class prompts with permissive thresholds.
# NOTE(review): no cell in this notebook downloads the image; place it under HOME/data first.
SOURCE_IMAGE_PATH = "{}/data/dog-3.jpeg".format(HOME)
CLASSES = [
    'car', 'dog', 'person', 'nose',
    'chair', 'shoe', 'ear',
]

# Passed to Grounding DINO as box_threshold / text_threshold.
BOX_TRESHOLD, TEXT_TRESHOLD = 0.35, 0.25

In [None]:
annotated_image, title = detect_and_segment(SOURCE_IMAGE_PATH, CLASSES, BOX_TRESHOLD, TEXT_TRESHOLD)
%matplotlib inline
sv.plot_image(image=annotated_image, title=title, size=(16, 16))

## Classes are specific, thresholds are low

In [None]:
# Specific (attribute + object) prompts with permissive thresholds.
# NOTE(review): no cell in this notebook downloads the image; place it under HOME/data first.
SOURCE_IMAGE_PATH = "{}/data/dog-3.jpeg".format(HOME)
CLASSES = [
    'white car',
    'black dog',
]

# Passed to Grounding DINO as box_threshold / text_threshold.
BOX_TRESHOLD, TEXT_TRESHOLD = 0.35, 0.25

In [None]:
annotated_image, title = detect_and_segment(SOURCE_IMAGE_PATH, CLASSES, BOX_TRESHOLD, TEXT_TRESHOLD)
%matplotlib inline
sv.plot_image(image=annotated_image, title=title, size=(16, 16))

## Classes are generic, thresholds are high

In [None]:
# Generic class prompts with strict thresholds.
# NOTE(review): no cell in this notebook downloads the image; place it under HOME/data first.
SOURCE_IMAGE_PATH = "{}/data/dog-3.jpeg".format(HOME)
CLASSES = [
    'car', 'dog', 'person', 'nose',
    'chair', 'shoe', 'ear',
]

# Passed to Grounding DINO as box_threshold / text_threshold.
BOX_TRESHOLD, TEXT_TRESHOLD = 0.5, 0.5

In [None]:
annotated_image, title = detect_and_segment(SOURCE_IMAGE_PATH, CLASSES, BOX_TRESHOLD, TEXT_TRESHOLD)
%matplotlib inline
sv.plot_image(image=annotated_image, title=title, size=(16, 16))

## Classes are specific, thresholds are high

In [None]:
# Specific (attribute + object) prompts with strict thresholds.
# NOTE(review): no cell in this notebook downloads the image; place it under HOME/data first.
SOURCE_IMAGE_PATH = f"{HOME}/data/dog-3.jpeg"
# Single assignment (the original had a duplicated `CLASSES = CLASSES = ...`).
CLASSES = ['white car', 'black dog']

# Passed to Grounding DINO as box_threshold / text_threshold.
BOX_TRESHOLD = 0.5
TEXT_TRESHOLD = 0.5

In [None]:
annotated_image, title = detect_and_segment(SOURCE_IMAGE_PATH, CLASSES, BOX_TRESHOLD, TEXT_TRESHOLD)
%matplotlib inline
sv.plot_image(image=annotated_image, title=title, size=(16, 16))