In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import cv2

In [None]:
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import predict

In [None]:
# Locations of the GroundingDINO config file and its pretrained checkpoint.
config_file_path = './tracking_SAM/third_party/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
model_path = './pretrained_weights/groundingdino_swint_ogc.pth'

# The detector runs on the CPU in this notebook.
device = 'cpu'

# Build the network described by the config file.
args = SLConfig.fromfile(config_file_path)
dino_model = build_model(args)

# Restore the weights. Tensors are mapped onto the CPU while loading, and the
# load is non-strict, so key mismatches are reported instead of raising.
checkpoint = torch.load(model_path, map_location='cpu')
state_dict = clean_state_dict(checkpoint['model'])
log = dino_model.load_state_dict(state_dict, strict=False)

# Switch to evaluation mode and move the model to the chosen device.
dino_model = dino_model.eval().to(device)

# Show the outcome of the non-strict load (missing / unexpected keys).
print(log)

In [None]:
# Sample frame used by the rest of the notebook, decoded as an RGB array.
test_img_path = "./sample_data/DAVIS_bear/images/00000.jpg"
image_pil = Image.open(test_img_path).convert("RGB")
image_np = np.asarray(image_pil)

# Preprocessing applied before detection: resize, tensor conversion and
# per-channel normalisation with the mean / std values given below.
transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# The transform is called with (image, target); no target is supplied here,
# so the second return value is discarded.
img_chw, _ = transform(Image.fromarray(image_np), None)

In [None]:
# Text query handed to the detector as its `caption` argument.
text_prompt = 'bear'

# Values passed to predict() as `box_threshold` and `text_threshold`.
# NOTE(review): the spelling "TRESHOLD" is kept as-is because the predict
# cell refers to these exact names.
BOX_TRESHOLD = 0.3
TEXT_TRESHOLD = 0.25

In [None]:
# Run text-conditioned detection on the preprocessed frame. The call yields
# the boxes, their logits and the phrase matched to each box.
boxes, logits, phrases = predict(
    model=dino_model, device=device,
    image=img_chw, caption=text_prompt,
    box_threshold=BOX_TRESHOLD, text_threshold=TEXT_TRESHOLD,
)

In [None]:
# Image height and width in pixels; only the first two axes are needed, so
# the channel count is not unpacked into a throw-away name.
H, W = image_np.shape[:2]

# Convert the detector output from (cx, cy, w, h) to corner format and scale
# it by the image size (the boxes are presumably normalised to [0, 1], hence
# the multiplication -- TODO confirm against the detector's documentation).
# `new_tensor` creates the scale factors with the same dtype and device as
# `boxes`, instead of the legacy `torch.Tensor(...)` constructor, which
# always produces a CPU float32 tensor.
boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * boxes.new_tensor([W, H, W, H])

In [None]:
# Draw on a copy so the original frame stays untouched for the later cells.
viz_img = image_np.copy()

# The frame is RGB and shown with matplotlib, so the colour tuples below are
# RGB: green rectangles, red labels.
for box, phrase in zip(boxes_xyxy, phrases):
    # Truncate the pixel coordinates to integers ...
    box = box.cpu().numpy().astype(np.int32)
    # ... and hand OpenCV plain Python ints: some OpenCV builds refuse NumPy
    # scalar types as point coordinates ("Can't parse 'pt1'").
    x0, y0, x1, y1 = box.tolist()
    viz_img = cv2.rectangle(viz_img, (x0, y0), (x1, y1), (0, 255, 0), 2)
    # The label is anchored at the box's top-left corner.
    viz_img = cv2.putText(viz_img, phrase, (x0, y0), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

plt.imshow(viz_img)

In [None]:
from segment_anything import sam_model_registry, SamPredictor

# Checkpoint file and the matching registry key for the SAM variant to load.
sam_checkpoint = "./pretrained_weights/sam_vit_h_4b8939.pth"  # default model

model_type = "vit_h"

# Prefer the GPU, but fall back to the CPU when CUDA is not available so the
# notebook still runs on CPU-only machines instead of failing in `sam.to`.
# NOTE: this rebinds `device`, which the detector cells above set to 'cpu'.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model from the checkpoint and move it to the selected device.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Wrap the model in the predictor object used by the cells below.
predictor = SamPredictor(sam)

In [None]:
# Give the predictor the RGB frame (the array loaded with .convert("RGB"));
# the predict() call further down is made on this same predictor object.
predictor.set_image(image_np)

In [None]:
# The cells below prompt SAM with a single box (boxes_xyxy[0]), so stop here
# with an explicit message unless the detector returned exactly one box.
assert len(boxes_xyxy) == 1, (
    f"expected exactly one detected box, got {len(boxes_xyxy)}"
)

In [None]:
# First detected box as a NumPy array; it is used as the box prompt for the
# SAM prediction below.
input_box = boxes_xyxy[0].cpu().numpy()

In [None]:
# Prompt SAM with the detected box only (no point prompts) and request a
# single mask. The two other return values are not used.
masks, _, _ = predictor.predict(
    box=input_box[np.newaxis, :],  # add a leading axis in front of the 4 coordinates
    point_coords=None,
    point_labels=None,
    multimask_output=False,
)

In [None]:
# Display the first returned mask (presumably the only one, since
# multimask_output=False was requested -- confirm against the SAM docs).
plt.imshow(masks[0])