# Test dino
use ipynb to speed up the base line work


In [24]:
import os
import cv2
import torch
import torchvision
import supervision as sv

import warnings
import numpy as np
from PIL import Image
from glob import glob

import termcolor
import matplotlib.pyplot as plt

from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
#TODO name!
from groundingdino.util.inference import load_model, load_image, predict, annotate

warnings.filterwarnings('ignore')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Paths to GroundingDINO and SAM checkpoints
GROUNDING_DINO_CONFIG_PATH = "/root/autodl-tmp/DINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "/root/autodl-tmp/DINO/weights/groundingdino_swint_ogc.pth"
model_type = "default"
SAM_CHECKPOINT_PATH = "/root/autodl-tmp/sam_vit_h_4b8939.pth"

# Predict classes and hyper-param for GroundingDINO
BOX_TRESHOLD = 0.25
TEXT_TRESHOLD = 0.25
NMS_THRESHOLD = 0.8

In [25]:

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    


## functional 
#### Architecture:

1. dino find road √ （Regieon of interest)

2. use road's bbox as prompt to use SAM

   text: person, sidewalk, road, vehicle

3. rule base 

   - comparing pixel relationship betweeen person and road, sidewalk
   - other VQA method to generate text

4. analyze image sequence, predict behavior

   - question: can I try LLM to do this?
     - 一些想法，之后会议解释

5. overall must use video & image

In [26]:
# Initialize GroundingDINO model
grounding_dino_model = Model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH, 
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, 
    device=DEVICE
)

# Initialize SAM model and predictor
sam = sam_model_registry[model_type](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)

# Classes of interest (add as needed)

output_dir = 'DINOmasked'
os.makedirs(output_dir, exist_ok=True)


final text_encoder_type: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [27]:
import supervision as sv


def display_mask(SAM_masks, image_path,image):
    # Create a new subplot
    output_path = os.path.join(output_dir, image_path)
    plt.figure(figsize=(16,9))

    # Display the original image
    plt.imshow(image)
    plt.axis('off')


    for mask in SAM_masks:
        show_mask(mask, plt.gca(), random_color=True)
    for box in DINO_boxes:
        show_box(box, plt.gca())

    plt.savefig(output_path)
    plt.close()




# Prompting SAM with ROI
def segment_ROI(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray):
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks_np, scores_np, _ = sam_predictor.predict(
        point_coords=None,
        point_labels=None,
        box= box,
        multimask_output=True,
        )
        index = np.argmax(scores_np)
        result_masks.append(masks_np[index])

    return np.array(result_masks)

def detect_road(image_path):
    
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


    image_source, image2 = load_image(image_path)

    TEXT_PROMPT = "road . sidewalk"
    CLASSES = ['road', 'sidewalk']
    

    # detect objects
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes = CLASSES,
        box_threshold= BOX_TRESHOLD,
        text_threshold=TEXT_TRESHOLD
    )
    detections.box_area 

    box_annotator = sv.BoxAnnotator()

    labels = [
    f"{CLASSES[class_id]} {confidence:0.2f}" 
    for _, _, confidence, class_id, _ 
    in detections]

    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
    %matplotlib inline
    # sv.plot_image(annotated_frame, (16, 16))


    # cv2.imwrite("annotated_image.jpg", annotated_frame)
    
    # NMS post process
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy), 
        torch.from_numpy(detections.confidence), 
        NMS_THRESHOLD
    ).numpy().tolist()

    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]

    DINO_boxes = np.array(detections.xyxy)


    SAM_masks = segment_ROI(
        sam_predictor=sam_predictor,
        image= image,
        xyxy= DINO_boxes,
    )

    # plt.figure(figsize=(10, 10))
    # plt.imshow(image)

    # for mask in SAM_masks:
    #     show_mask(mask, plt.gca(), random_color=True)
    # for box in DINO_boxes:
    #     show_box(box, plt.gca())
        
    # plt.axis('off')
    # plt.show()
    # display_mask(SAM_masks,image_path,image)
    output_path = os.path.join(output_dir, image_path)
    plt.figure(figsize=(16,9))

    # Display the original image
    plt.imshow(image)
    plt.axis('off')


    for mask in SAM_masks:
        show_mask(mask, plt.gca(), random_color=True)
    for box in DINO_boxes:
        show_box(box, plt.gca())

    plt.savefig(output_path)
    plt.close()


    return DINO_boxes,labels


# text program to make sure the label works
DINO_boxes,labels = detect_road("input/scene_2.png")
print(DINO_boxes, labels)



[[  1.7368164 187.55162   893.4925    430.34235  ]
 [351.7953    195.08667   893.7616    429.8313   ]] ['road 0.50', 'sidewalk 0.31']


In [28]:
result_masks = []
mask = [  1.7368164 ,187.55162,   893.4925 ,   430.34235  ]
score = np.array( [  1.7368164 ,187.55162,   893.4925 ,   430.34235  ])
index = np.argmax(score)
result_masks.append(mask[index])
result_masks


[893.4925]

In [29]:
DINO_boxes[1]

array([351.7953 , 195.08667, 893.7616 , 429.8313 ], dtype=float32)

In [30]:
image_dir = "selected"
    
print("Start =====")
i = 1
for filename in os.listdir(image_dir):
    if filename.endswith(".jpg") or filename.endswith(".png") or filename.endswith(".jpeg"):
        image_path = os.path.join(image_dir, filename)
        output_path = os.path.join(output_dir, filename)
        if not os.path.exists(output_path):
            print("Processing: ", i)
            i += 1
            print(f"Image path: {termcolor.colored(os.path.basename(image_path), 'green')}")

            result = detect_road(image_path)
            
            print(f"Detected: {termcolor.colored(result, 'blue')}")

Start =====
Processing:  1
Image path: video_0001_0012.png


Detected: (array([[   4.183899 ,  696.2627   , 1917.4514   , 1010.33154  ],
       [   2.8311157,  937.5635   , 1917.977    , 1077.463    ]],
      dtype=float32), ['road 0.75', 'sidewalk 0.42'])
Processing:  2
Image path: video_0002_0007.png
Detected: (array([[   2.9528198,  781.04114  , 1523.0205   , 1000.4623   ],
       [   4.013916 ,  781.1417   , 1914.7395   , 1073.8596   ],
       [   3.230835 ,  932.0819   , 1917.1422   , 1076.9109   ]],
      dtype=float32), ['road 0.39', 'road 0.29', 'road 0.26'])
Processing:  3
Image path: video_0003_0004.png
Detected: (array([[   4.653076 ,  699.9023   , 1915.5515   , 1075.1248   ],
       [ 281.41516  ,  705.7569   , 1620.6094   ,  973.64496  ],
       [ 289.2341   ,  777.806    ,  810.494    ,  933.048    ],
       [   2.6627197,  931.6172   , 1916.9866   , 1076.7948   ]],
      dtype=float32), ['road 0.42', 'sidewalk 0.34', 'road 0.36', 'sidewalk 0.30'])
Processing:  4
Image path: video_0004_0005.png
Detected: (array([[   3.0351562,  682