In [13]:
import os
import time
import groundingdino.datasets.transforms as T
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.inference import predict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from huggingface_hub import hf_hub_download
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Download a GroundingDINO config + checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository id (e.g. "ShilongLiu/GroundingDINO").
        filename: checkpoint file name inside the repository.
        ckpt_config_filename: SLConfig python config file name inside the repository.
        device: device written to the config before building and that the model is
            moved to before returning. Defaults to 'cpu'.

    Returns:
        The model in eval mode, placed on ``device``.
    """
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file)
    # Set the device *before* build_model so any device-dependent construction sees it;
    # assigning it after the build (as previously done) could not influence the model.
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # torch.load unpickles the file: only use with trusted checkpoints.
    # Loading onto CPU first avoids requiring the device the checkpoint was saved from.
    checkpoint = torch.load(cache_file, map_location='cpu')
    # strict=False: the checkpoint carries extra keys the built model does not use
    # (see the reported `unexpected_keys`).
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print(f"Model loaded from {cache_file} \n => {log}")
    model.eval()
    return model.to(device)


def transform_image(image) -> torch.Tensor:
    """Turn a PIL image into the normalised tensor fed to GroundingDINO.

    Applies GroundingDINO's RandomResize with the single size 800 (max_size=1333),
    converts to a tensor, and normalises with the standard ImageNet mean/std values.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    pipeline = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize(mean, std),
        ]
    )

    # These transforms take an (image, target) pair; there is no target for inference.
    tensor, _target = pipeline(image, None)
    return tensor


def plot_boxes(image, boxes, labels=None):
    """Display ``image`` with each box drawn as a red outline.

    Args:
        image: anything ``Axes.imshow`` accepts (e.g. a PIL image or ndarray).
        boxes: iterable of (x_min, y_min, x_max, y_max) in pixel coordinates. Each
            coordinate is cast with ``float()``, so rows of a torch tensor work too.
        labels: optional sequence of strings, one per box, drawn at each box's
            top-left corner. ``None`` (default) draws boxes only.
    """
    _, ax = plt.subplots(1)
    ax.imshow(image)
    for i, box in enumerate(boxes):
        x_min, y_min, x_max, y_max = (float(v) for v in box)
        rect = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                 linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        if labels is not None:
            ax.text(x_min, y_min, str(labels[i]), color='r', fontsize=12,
                    verticalalignment='top')
    plt.show()


class GDino:
    """Wrapper around the GroundingDINO Swin-B checkpoint for text-prompted box detection.

    Args:
        return_prompts: stored on the instance but currently unused. It was intended to be
            forwarded to ``predict`` as ``remove_combined``, which the installed
            GroundingDINO version apparently does not accept — TODO confirm.
    """

    def __init__(self, return_prompts: bool = False):
        self.return_prompts = return_prompts
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_groundingdino()

    def build_groundingdino(self):
        """Fetch the Swin-B config/checkpoint from the HF Hub and load it for ``self.device``."""
        ckpt_repo_id = "ShilongLiu/GroundingDINO"
        ckpt_filename = "groundingdino_swinb_cogcoor.pth"
        ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
        # Pass the chosen device through instead of silently loading with the default.
        self.groundingdino = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename,
                                           device=str(self.device))

    def predict_dino(self, image_pil, text_prompt, box_threshold, text_threshold):
        """Run GroundingDINO on a PIL image and return boxes in pixel coordinates.

        Returns:
            Tensor of (x_min, y_min, x_max, y_max) rows scaled to ``image_pil``'s size
            (assumed shape (N, 4), one row per detection — TODO confirm).
        """
        image_trans = transform_image(image_pil)
        # `predict` here is the module-level groundingdino.util.inference.predict,
        # not this class's own `predict` method.
        boxes, _logits, _phrases = predict(model=self.groundingdino,
                                           image=image_trans,
                                           caption=text_prompt,
                                           box_threshold=box_threshold,
                                           text_threshold=text_threshold,
                                           device=self.device)
        W, H = image_pil.size  # PIL reports (width, height)
        # Boxes are normalised (cx, cy, w, h): convert to corners, then scale to pixels.
        # Match the boxes' dtype/device so the multiply never mixes devices.
        scale = torch.tensor([W, H, W, H], dtype=boxes.dtype, device=boxes.device)
        return box_ops.box_cxcywh_to_xyxy(boxes) * scale

    def predict(self, image_pil, text_prompt, box_threshold=0.3, text_threshold=0.25):
        """Detect regions matching ``text_prompt``; returns pixel-space xyxy boxes."""
        return self.predict_dino(image_pil, text_prompt, box_threshold, text_threshold)

# Build the detector once here (it fetches/loads the checkpoint); later cells reuse `model`.
model = GDino(return_prompts=False)

final text_encoder_type: bert-base-uncased




Model loaded from C:\Users\Sushaanth\.cache\huggingface\hub\models--ShilongLiu--GroundingDINO\snapshots\a94c9b567a2a374598f05c584e96798a170c56fb\groundingdino_swinb_cogcoor.pth 
 => _IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight', 'bert.embeddings.position_ids'])


In [15]:
# Time one end-to-end detection (image load + inference) on a captured frame.
# perf_counter is a monotonic, high-resolution clock intended for measuring intervals;
# time.time() is wall-clock time and can jump if the system clock is adjusted.
# The first call after loading may also include one-off device warm-up cost.
start_time = time.perf_counter()

prompt = 'cube'

# `with` closes the underlying file handle as soon as the RGB copy is made.
with Image.open("./captured_images/20240519-152936649981.png") as img:
    image_pil = img.convert("RGB")
boxes = model.predict(image_pil, prompt)
# plot_boxes(image_pil, boxes)

print("--- %s seconds ---" % (time.perf_counter() - start_time))

--- 1.921921968460083 seconds ---


In [18]:
# Second captured frame with the same prompt; timing and plotting are disabled here.
# (Re-enable plotting with: plot_boxes(image_pil, boxes))

prompt = "cube"
image_pil = Image.open("./captured_images/20240519-152939714675.png").convert("RGB")
boxes = model.predict(image_pil, text_prompt=prompt)