In [1]:
import os
import re
from itertools import combinations, product
from collections import defaultdict
from tqdm import tqdm
import yaml
import numpy as np
import cv2
import torch
import supervision as sv
from groundingdino.util.inference import Model
from mobile_sam import sam_model_registry, SamPredictor
  
# Define Function for IoU
def get_iou(boxA, boxB):
    # determine the (x, y)-coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    # compute the area of intersection rectangle
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    # compute the area of both the prediction and ground-truth
    # rectangles
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    # compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the interesection area
    iou = interArea / float(boxAArea + boxBArea - interArea)

    return iou

# merge function to  merge all sublist having common elements. 
def merge_common(lists): 
    neigh = defaultdict(set) 
    visited = set() 
    for each in lists: 
        for item in each: 
            neigh[item].update(each) 
    def comp(node, neigh = neigh, visited = visited, vis = visited.add): 
        nodes = set([node]) 
        next_node = nodes.pop 
        while nodes: 
            node = next_node() 
            vis(node) 
            nodes |= neigh[node] - visited 
            yield node 
    for node in neigh: 
        if node not in visited: 
            yield sorted(comp(node))

# Define Function for refining detections
def refine_detections(classes, detections, labels):
    # Select only boxes that correspond to a class
    detect_dict = {}
    detect_index = []    
    for i in classes:
        detect_dict[i] = list(filter(lambda x: labels[x] == i, range(len(labels))))
        detect_index = detect_index + list(filter(lambda x: labels[x] == i, range(len(labels))))

    # Refine to match one box per object based on IoU
    dup_id_lst = []
    for class_c in list(combinations(detect_dict.keys(), 2)):
        for box_c in product(detect_dict[class_c[0]], detect_dict[class_c[1]]):
            box_iou = get_iou(detections.xyxy[box_c[0]], detections.xyxy[box_c[1]])
            if box_iou >= 0.9:
                dup_id_lst.append(list(box_c))
    # Refine detect_index based on duplicated list
    dup_lst = list(merge_common(dup_id_lst))
    for dup in dup_lst:
        dup_rep_id = dup[np.argmax([detections.confidence[i] for i in dup])]
        detect_index = list(set([dup_rep_id if i in dup else i for i in detect_index]))
    
    return detect_index
    

def dino_process(g_dino_config_path, g_dino_weights_path, img, box_thresh, text_thresh, classes=None, text_prompt=None, to_prompt=False):
    model = Model(model_config_path=g_dino_config_path, model_checkpoint_path=g_dino_weights_path)
    
    if classes is not None:
        class_dict = dict(zip(classes, range(len(classes))))
        if to_prompt==False:
            detections = model.predict_with_classes(
                image=img,
                classes=classes, #enhance_class_name(class_names=classes)
                box_threshold=box_thresh,
                text_threshold=text_thresh
            )
            labels = [
                f"{classes[class_id]}"
                for _, _, _, class_id, _
                in detections]
        else:
            prompt = " , ".join(classes)
            detections, labels = model.predict_with_caption(
                image=img,
                caption=prompt,
                box_threshold=box_thresh,
                text_threshold=text_thresh
            )
            # Refine bounding box result by iou
            detect_index = refine_detections(classes, detections, labels)
            # update labels
            labels = np.array(labels)[detect_index]
            # update detections
            detections.xyxy = np.array(detections.xyxy)[detect_index]
            detections.confidence = np.array(detections.confidence)[detect_index]
            detections.class_id = np.array([class_dict[x] for x in labels])

    elif (classes is None) & (text_prompt is not None):
        detections, labels = model.predict_with_caption(
            image=img,
            caption=text_prompt,
            box_threshold=box_thresh,
            text_threshold=text_thresh
        )
    # return the detections & labels
    return detections, labels

def mobile_sam_process(device_nm, img, detections):
    sam_ckpt = "./mobile_sam.pt"
    sam_type = "vit_t"
    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device_nm).eval()
    # sam.to(device=device_nm)
    # sam.eval()
    sam_predictor = SamPredictor(sam)
    sam_predictor.set_image(img)
    masks_lst = []
    for bbox in detections.xyxy:        
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box = bbox[None,:],
            multimask_output=True,
        )
        best_mask = masks[np.argmax(scores), :, :]

        masks_lst.append(best_mask)
    detections.mask = masks_lst
    
    return detections

def write_cfg_yaml(dir_path, classes):
    yaml_dict = [{'names':classes},
                 {'nc':len(classes)},
                 {'path':dir_path},
                 {'test':'test/images'},
                 {'train':'train/images'},
                 {'val':'valid/images'}]
    with open(f'{dir_path}/config.yaml', 'w') as file:
        documents = yaml.dump(yaml_dict, file)
    
    return True

def detections_to_str(img, detections, detect_type):
    img_width = img.shape[1]
    img_height = img.shape[0]
    result_str = ""
    if detect_type=='box':
        for i in range(len(detections.class_id)):
            c = detections.class_id[i]
            bb = detections.xyxy[i]
            nm_bb = np.hstack((np.expand_dims((bb[0]+bb[2])/(2*img_width), axis=0),
                               np.expand_dims((bb[1]+bb[3])/(2*img_height), axis=0),
                               np.expand_dims((bb[2]-bb[0])/img_width, axis=0),
                               np.expand_dims((bb[3]-bb[1])/img_height, axis=0)))
            nm_bb_str = " ".join(re.sub(r"[^.0-9\s]","",str(nm_bb.round(7))).split())
            c_bb = "%s %s\n" % (c,nm_bb_str)
            result_str = result_str + c_bb
    elif detect_type=='seg':
        for i in range(len(detections.class_id)):
            c = detections.class_id[i]
            m = detections.mask[i]
            p = sv.mask_to_polygons(m)
            nm_p = np.hstack((np.expand_dims(p[0][:,0]/img_width, axis=1) , np.expand_dims(p[0][:,1]/img_height, axis=1)))
            nm_p_str = " ".join(re.sub(r"[^.0-9\s]","",str(nm_p.round(7).flatten())).split())
            c_p = "%s %s\n" % (c,nm_p_str)
            result_str = result_str + c_p
    
    return result_str

def detections_to_img(img, detections, labels, detect_type):
    img_labels = [
        f"{label} {confidence:0.2f}"
        for label, (_, _, confidence, class_id, _)
        in zip(labels, detections)]
    
    if detect_type=='box':
        box_annotator = sv.BoxAnnotator()
        annotated_image = box_annotator.annotate(scene=img, detections=detections, labels=img_labels)
    
    elif detect_type=='seg':
        box_annotator = sv.BoxAnnotator()
        mask_annotator = sv.MaskAnnotator()
        mask_image = mask_annotator.annotate(scene=img, detections=detections)
        annotated_image = box_annotator.annotate(scene=mask_image, detections=detections, labels=img_labels)

    return annotated_image

def grounded_sam(dir_path, g_dino_cfg_path, g_dino_wgts_path, class_lst, detect_type, box_threshold, text_threshold, annotate_path=None):
    device = "cuda" if torch.cuda.is_available() else "cpu"   
    file_lst = os.listdir(dir_path)
    img_file_lst = [file for file in file_lst if file.endswith(tuple([".png",".jpg",".mp4"]))]
    if len(img_file_lst) > 0:
        # make directory annotate project directory
        if annotate_path is None:
            annotate_root_path = dir_path + "/annotate"
        else:
            annotate_root_path = annotate_path + f"{dir_path.split('/')[-1]}"
        
        detect_img_path = annotate_root_path+'/detect/images'
        detect_labels_path = annotate_root_path+'/detect/labels'
        annotated_img_path = annotate_root_path+'/detect/annotated_img'
        non_detect_img_path = annotate_root_path+'/non_detect/img'
        
        os.makedirs(annotate_root_path, exist_ok=True)
        os.makedirs(detect_img_path, exist_ok=True)
        os.makedirs(detect_labels_path, exist_ok=True)
        os.makedirs(annotated_img_path, exist_ok=True)
        os.makedirs(non_detect_img_path, exist_ok=True)
        
        # Make meta config yaml file

        # Annotate files
        for file_path in img_file_lst:
            file_nm = file_path.split(".")[0]
            file_type = file_path.split(".")[1]
            # For video file
            if file_type == "mp4":                
                vid = cv2.VideoCapture(dir_path+f'/{file_path}')
                frame_cnt = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))

                width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))

                annotate_vid = cv2.VideoWriter(f'{annotated_img_path}/{file_path}', cv2.VideoWriter_fourcc(*'mp4v'), 60, (width, height))
                
                # Get auto annotate result
                for _ in tqdm(range(frame_cnt)):
                    ret, frame = vid.read()
                    frame_nm = f'{file_nm}_frame_{_}'
                    # Get bounding box by grounding dino
                    dino_detections, labels = dino_process(g_dino_config_path=g_dino_cfg_path,
                                                        g_dino_weights_path=g_dino_wgts_path,
                                                        img=frame,
                                                        box_thresh=box_threshold,
                                                        text_thresh=text_threshold,
                                                        classes=class_lst,
                                                        to_prompt=True)
                    if len(labels) > 0:
                        # Get instance segment by mobile_SAM
                        if detect_type == 'seg':
                            fin_detections = mobile_sam_process(device_nm=device, img=frame, detections=dino_detections)
                        else:
                            fin_detections = dino_detections
                        # Save raw image file
                        cv2.imwrite(f'{detect_img_path}/{frame_nm}.jpg', frame)
                        # Save annotated result meta data to txt file
                        annotate_str = detections_to_str(img=frame, detections=fin_detections, detect_type=detect_type)
                        lo = open(f"{detect_labels_path}/{frame_nm}.txt", "w")
                        lo.write(annotate_str)
                        lo.close()
                        # Save annotated result image file
                        annotate_img = detections_to_img(img=frame, detections=fin_detections, labels=labels, detect_type=detect_type)
                        cv2.imwrite(f'{annotated_img_path}/{frame_nm}.jpg', annotate_img)
                        annotate_vid.write(annotate_img)
                    else:
                        cv2.imwrite(f'{non_detect_img_path}/{frame_nm}.jpg', frame)
                        annotate_vid.write(frame)
                annotate_vid.release()
                vid.release()
            # For image files
            else:
                img = cv2.imread(dir_path+f'/{file_path}')
                # Get bounding box by grounding dino
                dino_detections, labels = dino_process(g_dino_config_path=g_dino_cfg_path,
                                                        g_dino_weights_path=g_dino_wgts_path,
                                                        img=img,
                                                        box_thresh=box_threshold,
                                                        text_thresh=text_threshold,
                                                        classes=class_lst,
                                                        to_prompt=True)
                if len(labels) > 0:
                    # Get instance segment by mobile_SAM
                    if detect_type == 'seg':
                        fin_detections = mobile_sam_process(device_nm=device, img=img, detections=dino_detections)
                    else:
                        fin_detections = dino_detections
                    # Save raw image file
                    cv2.imwrite(f'{detect_img_path}/{file_path}', img)
                    # Save annotated result meta data to txt file
                    annotate_str = detections_to_str(img=img, detections=fin_detections, detect_type=detect_type)
                    lo = open(f"{detect_labels_path}/{file_nm}.txt", "w")
                    lo.write(annotate_str)
                    lo.close()
                    # Save annotated result image file
                    annotate_img = detections_to_img(img=img, detections=fin_detections, labels=labels, detect_type=detect_type)
                    cv2.imwrite(f'{annotated_img_path}/{file_path}', annotate_img)
                else:
                    cv2.imwrite(f'{non_detect_img_path}/{file_path}', img)
    else:
        print("The image file does not exist in that path, please check.")

  from .autonotebook import tqdm as notebook_tqdm
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)


In [2]:
# set model configuration file path
g_dino_cfg_path = "./GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"

# %cd {HOME}/GroundingDINO/weights
# !wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth

# set model weight file path
g_dino_wgts_path = "./GroundingDINO/weights/groundingdino_swint_ogc.pth"

# set device
# device = "cuda" if torch.cuda.is_available() else "cpu"

dir_path = "./kr_soccer_img"
class_lst = ["yellow persons","blue persons","red persons","green persons"]
detect_type = "box"
grounded_sam(dir_path, g_dino_cfg_path, g_dino_wgts_path, class_lst, detect_type, 0.4, 0.3)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased




final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_

In [2]:
g_dino_cfg_path = "./GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
g_dino_wgts_path = "./GroundingDINO/weights/groundingdino_swint_ogc.pth"

model = Model(model_config_path=g_dino_cfg_path, model_checkpoint_path=g_dino_wgts_path)

class_lst = ["yellow persons","blue persons","red persons","green persons"]

prompt = " , ".join(class_lst)

frame = cv2.imread('./kr_soccer_img/korea_soccer_frame_3715.jpg')

detections, labels = model.predict_with_caption(
                image=frame,
                caption=prompt,
                box_threshold=0.4,
                text_threshold=0.3
            )


# dino_detections, labels = dino_process(g_dino_config_path=g_dino_cfg_path,
#                                     g_dino_weights_path=g_dino_wgts_path,
#                                     img=frame,
#                                     box_thresh=0.4,
#                                     text_thresh=0.3,
#                                     classes=class_lst,
#                                     to_prompt=True)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased




In [6]:
(detections.xyxy[0][0] + detections.xyxy[0][2])/(2*)

2118.6074