In [None]:
"""  OPTIONAL: Download the strawberries dataset from Roboflow  """
!pip install roboflow

from roboflow import Roboflow
rf = Roboflow(api_key="6kfjjO565pfpxqd6YL4S")
project = rf.workspace("skripsie").project("strawberry.00")
dataset = project.version(15).download("yolov5")

In [None]:
from collections import namedtuple
import numpy as np
import cv2
import os
import random
import torch
import time
from torchvision import ops

In [None]:
""" Setup Grounding DINO """

HOME = os.getcwd()
print(HOME)
%cd {HOME}
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd {HOME}/GroundingDINO
!pip install -q -e .
!pip install -q roboflow

CONFIG_PATH = os.path.join(HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
print(CONFIG_PATH, "; exist:", os.path.isfile(CONFIG_PATH))

%cd {HOME}
!mkdir {HOME}/weights
%cd {HOME}/weights

!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
WEIGHTS_NAME = "groundingdino_swint_ogc.pth"
WEIGHTS_PATH = os.path.join(HOME, "weights", WEIGHTS_NAME)
print(WEIGHTS_PATH, "; exist:", os.path.isfile(WEIGHTS_PATH))

In [None]:
# Switch into the cloned repository before importing the package.
# NOTE(review): presumably required so the editable install from the setup
# cell is importable without restarting the kernel — confirm.
%cd {HOME}/GroundingDINO

from groundingdino.util.inference import load_model, load_image, predict, annotate
import supervision as sv

# Build the model from the config file and checkpoint located by the setup cell.
model = load_model(CONFIG_PATH, WEIGHTS_PATH)

In [None]:
""" Helpful Functions """


def get_image_label_paths(folder_path, n_files):
  """Pick a random subset of images and derive the matching label paths.

  Args:
    folder_path: directory that contains the image files.
    n_files: number of images to draw. When the folder holds fewer files,
      all of them are returned.

  Returns:
    (image_paths, label_paths): two lists of equal length. Each label path
    has the same file stem as its image, the extension ".txt", and lives in
    the directory obtained by replacing the last "images" in `folder_path`
    with "labels" (or in `folder_path` itself if it contains no "images").
  """
  # Sorted listing: the sample then depends only on the RNG state and not on
  # the arbitrary order in which os.listdir returns entries.
  file_names = sorted(os.listdir(folder_path))
  # random.sample draws without replacement, so no image is evaluated twice.
  chosen = random.sample(file_names, k=min(n_files, len(file_names)))

  # Only the directory part is rewritten; file names that happen to contain
  # "images" or "jpg" (e.g. "0001_jpg.rf.<hash>.jpg") are left intact.
  head, found, tail = folder_path.rpartition("images")
  labels_dir = head + "labels" + tail if found else folder_path

  image_paths = [os.path.join(folder_path, name) for name in chosen]
  label_paths = [os.path.join(labels_dir, os.path.splitext(name)[0] + ".txt")
                 for name in chosen]
  return image_paths, label_paths


def get_bbox_annotations(label_paths):
  """Read the box annotations stored in a list of label files.

  Every non-blank line of a label file is a whitespace-separated record whose
  first field (the class id in the "yolov5" export used by this notebook) is
  dropped; the remaining fields are parsed as floats.

  Args:
    label_paths: paths of the label text files.

  Returns:
    One list per file, each holding one list of floats per annotated box.
    A file without annotations yields an empty list.
  """
  anns = []
  for label_path in label_paths:
    with open(label_path, "r") as file:
      # Blank lines (including the one produced by a trailing newline) are
      # skipped; split() without an argument tolerates repeated spaces/tabs.
      rows = [line.split() for line in file if line.strip()]
    anns.append([[float(value) for value in row[1:]] for row in rows])
  return anns


def create_detection(path, gtruth, predict):
  """Bundle an image path with its ground-truth and predicted boxes.

  The values are passed positionally to the notebook-level `Detection`
  record type, in the order (path, gtruth, predict).
  """
  return Detection(path, gtruth, predict)


def compute_correct_coord(boxes, w, h):
  """Convert normalized centre-format boxes to pixel corner boxes.

  Args:
    boxes: iterable of boxes given as (cx, cy, bw, bh) with all values
      relative to the image size. This is the layout of YOLO label files and
      of the boxes returned by GroundingDINO's `predict`.
    w: image width in pixels.
    h: image height in pixels.

  Returns:
    List of [x_min, y_min, x_max, y_max] boxes in integer pixel coordinates
    (origin in the top-left corner of the image).
  """
  ok_boxes = []

  for box in boxes:
    # Centre and half-extent in pixels.
    cx = box[0]*w
    cy = box[1]*h
    half_w = box[2]*w / 2
    half_h = box[3]*h / 2

    # The first two fields are the box centre, so the corners lie half a
    # width / height away on either side of it.
    tlx, tly = int(round(float(cx - half_w))), int(round(float(cy - half_h)))
    brx, bry = int(round(float(cx + half_w))), int(round(float(cy + half_h)))

    ok_boxes.append([tlx, tly, brx, bry])
  return ok_boxes


def compute_pixel_boxes(gt_boxes, pred_boxes, image):
  """Convert ground-truth and predicted boxes to pixel corner boxes.

  Args:
    gt_boxes: normalized (cx, cy, bw, bh) ground-truth boxes.
    pred_boxes: normalized (cx, cy, bw, bh) predicted boxes.
    image: array whose first two dimensions are (height, width).

  Returns:
    (gt_boxes, pred_boxes), each a list of
    [x_min, y_min, x_max, y_max] pixel boxes.
  """
  h, w = image.shape[:2]
  # input: centre x, centre y, width, height (normalized)
  # output: top_left, bottom_right (pixels)
  gt_boxes = compute_correct_coord(gt_boxes, w, h)
  pred_boxes = compute_correct_coord(pred_boxes, w, h)
  return gt_boxes, pred_boxes


def compute_iou(gtruth_bbox, pred_bbox):
  """Best IoU of every ground-truth box against the predicted boxes.

  Args:
    gtruth_bbox: sequence of N ground-truth boxes as [x1, y1, x2, y2].
    pred_bbox: sequence of M predicted boxes as [x1, y1, x2, y2].

  Returns:
    numpy array of length N; entry i is the largest IoU between ground-truth
    box i and any predicted box. With no predictions every entry is 0, and
    with no ground-truth boxes the array is empty.
  """
  n_gt = len(gtruth_bbox)
  # An empty list cannot be turned into an (N, 4) tensor and a max over zero
  # predictions is undefined, so the degenerate cases are answered directly.
  if n_gt == 0 or len(pred_bbox) == 0:
    return np.zeros(n_gt, dtype=np.float32)

  gtruth_bbox = torch.tensor(gtruth_bbox, dtype=torch.float)
  pred_bbox = torch.tensor(pred_bbox, dtype=torch.float)

  # box_iou yields an (N, M) matrix; reduce over the prediction axis.
  iou = ops.box_iou(gtruth_bbox, pred_bbox).numpy()
  iou = np.max(iou, axis=1)
  return iou


def compute_prec(iou, threshold, length_pred):
  """Fraction of predictions accounted for by IoU values above a threshold.

  Args:
    iou: array-like of IoU values.
    threshold: an IoU value counts as a hit only when strictly greater.
    length_pred: number of predicted boxes (the denominator).

  Returns:
    Number of hits divided by `length_pred`; 0.0 when there are no
    predictions, so the caller never divides by zero.
  """
  if length_pred == 0:
    return 0.0
  hits = np.count_nonzero(np.asarray(iou) > threshold)
  return hits / length_pred


def compute_average_precision(iou, length_pred):
  """Mean of `compute_prec` over a range of IoU thresholds.

  The thresholds come from np.arange(0.5, 0.951, 0.05), i.e. steps of 0.05
  starting at 0.5. The mean is rounded to two decimals before being returned.
  """
  thresholds = np.arange(0.5, 0.951, 0.05)
  per_threshold = []
  for thresh in thresholds:
    per_threshold.append(compute_prec(iou, thresh, length_pred))
  return round(np.mean(per_threshold), 2)

In [None]:
""" Variables """

INSPECT_VAR = True        # when True, the inspection cell prints the loop's last values
BOX_TRESHOLD = 0.65       # forwarded to predict() as box_threshold
TEXT_TRESHOLD = 0.50      # forwarded to predict() as text_threshold
N_FILES = 50              # number of test images to sample
SEED = 42                 # fixes the random image sample between runs
FOLDER_PATH = f'{HOME}/strawberry.00-15/test/images'
TEXT_PROMPT = 'red strawberry'

# Define the 'Detection' object
Detection = namedtuple("Detection", ["image_path", "gt", "pred"])

# Seed right before sampling so re-running this cell draws the same images.
random.seed(SEED)
image_paths, label_paths = get_image_label_paths(FOLDER_PATH, N_FILES)
bbox_anns = get_bbox_annotations(label_paths)


In [None]:

# Run the model on every sampled image and collect a per-image average
# precision together with the wall-clock time spent on the image.
avg_precs = []
exec_times = []
for image_path, gt_boxes in zip(image_paths, bbox_anns):
  start_time = time.perf_counter()
  image_source, image = load_image(image_path)

  pred_boxes, logits, phrases = predict(model=model,
                                  image=image,
                                  caption=TEXT_PROMPT,
                                  box_threshold=BOX_TRESHOLD,
                                  text_threshold=TEXT_TRESHOLD,
                                  )
  gt_boxes_ok, pred_boxes_ok = compute_pixel_boxes(gt_boxes, pred_boxes.numpy(), image_source)
  length_pred = len(pred_boxes_ok)

  # Images without any detection contribute no precision value (as before),
  # but their processing time is still recorded below.
  # NOTE(review): leaving these images out of the mean may flatter mAP — confirm
  # this is the intended metric.
  if length_pred > 0:
    detection = create_detection(path=image_path,
                                gtruth=gt_boxes_ok,
                                predict=pred_boxes_ok
                                )

    iou = compute_iou(detection.gt, detection.pred)
    avg_precs.append(compute_average_precision(iou, length_pred))

  exec_times.append(time.perf_counter() - start_time)

# Guard against empty lists so the summary is a plain nan instead of a
# RuntimeWarning from np.mean.
mAP = round(np.mean(avg_precs), 2) if avg_precs else float("nan")
mET = round(np.mean(exec_times), 2) if exec_times else float("nan")

In [None]:
# Report the two averages computed at the end of the evaluation loop:
# the mean of the per-image average precisions and the mean time per image.
print(f'mAP: {mAP}')
print(f'mET: {mET}s')

In [None]:
""" Inspect Variables """

# Debug dump, enabled by INSPECT_VAR. The names printed here are loop
# variables of the evaluation cell, so they hold the values of the last image
# that the loop processed; the evaluation cell must have run first.
if INSPECT_VAR:
  print(f'boxes type: {type(pred_boxes_ok)}')
  print(f'boxes: {pred_boxes_ok}\n')
  print(f'logits type: {type(logits)}')
  print(f'logits: {logits}\n')
  print(f'phrases type: {type(phrases)}')
  print(f'phrases: {phrases}\n')
  print(f'image : {image.shape}')
  print(f'image_source: {image_source.shape}')

In [None]:
# Draw the predictions for the last image processed by the evaluation loop
# (image_source, pred_boxes, logits and phrases are left over from its final
# iteration).
annotated_frame = annotate(image_source=image_source,
                           boxes=pred_boxes,
                           logits=logits,
                           phrases=phrases)

# Render figures inside the notebook, then show the annotated frame.
%matplotlib inline
sv.plot_image(annotated_frame, (16, 16))