# Semantic Segmentation of Road and Sidewalk
Tony Wang July 04 2023

This notebook is a tutorial demo. Compared with a standalone .py script, which is still unstable, a Jupyter notebook gives a clearer description and a step-by-step demonstration of the data pipeline.



## Library & Model Loading

In [1]:
import os
import cv2
# filter some annoying debug info
import warnings
warnings.filterwarnings('ignore')

import torch
import torchvision
import supervision as sv

import numpy as np
from PIL import Image
from pathlib import Path

import termcolor
import matplotlib.pyplot as plt

from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
#TODO name!
from groundingdino.util.inference import load_model, load_image, predict, annotate

# import SAM_utility # 

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Paths to GroundingDINO and SAM checkpoints
GROUNDING_DINO_CONFIG_PATH = "/root/autodl-tmp/DINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "/root/autodl-tmp/DINO/weights/groundingdino_swint_ogc.pth"
MODEL_TYPE = "default"
SAM_CHECKPOINT_PATH = "/root/autodl-tmp/sam_vit_h_4b8939.pth"

# Detection thresholds for GroundingDINO and the IoU threshold for NMS
BOX_TRESHOLD = 0.25
TEXT_TRESHOLD = 0.25
NMS_THRESHOLD = 0.8

Model loading takes quite a while, and GroundingDINO prints some warnings that cannot be suppressed; they can be safely ignored.

In [2]:
# Initialize GroundingDINO model
grounding_dino_model = Model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH, 
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, 
    device=DEVICE
)

# Initialize SAM model and predictor
sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)

final text_encoder_type: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Utility Functions

In [3]:

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    
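show_mask overlays an RGBA mask through matplotlib. If you want the overlay as a plain array instead (for example to write it directly with cv2.imwrite rather than saving a figure), the same alpha blending can be done in NumPy. This is a minimal sketch; blend_mask is a hypothetical helper, not part of SAM or supervision, and its default color simply mirrors the blue used in show_mask.

```python
import numpy as np


def blend_mask(image: np.ndarray, mask: np.ndarray,
               color=(30, 144, 255), alpha: float = 0.6) -> np.ndarray:
    """Alpha-blend a boolean (H, W) mask onto an RGB uint8 (H, W, 3) image.

    Returns a new uint8 image; the input image is not modified.
    """
    out = image.astype(np.float32)  # astype returns a copy
    m = mask.astype(bool)
    # Only masked pixels change: out = (1 - alpha) * image + alpha * color
    out[m] = (1.0 - alpha) * out[m] + alpha * np.asarray(color, dtype=np.float32)
    return np.clip(out, 0, 255).astype(np.uint8)
```

The image in this notebook is RGB, so convert it back with cv2.cvtColor(..., cv2.COLOR_RGB2BGR) before passing the result to cv2.imwrite.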


In [4]:
def is_image_file(filename):
    IMAGE_EXT = ('.jpg', '.jpeg', '.png', '.bmp')
    # Case-insensitive, so files such as IMAGE_0001.PNG are also picked up
    return filename.lower().endswith(IMAGE_EXT)

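def display_mask(SAM_masks, image_path, output_dir, DINO_boxes):
    # image_path is expected to be relative (e.g. "input/video_0020/image_0001.png"):
    # os.path.join returns an absolute image_path unchanged, which would overwrite the input
    output_path = os.path.join(output_dir, image_path)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Create a new figure and display the original image
    plt.figure(figsize=(16, 9))
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    plt.imshow(image)
    plt.axis('off')

    # Overlay each SAM mask (random color) and each GroundingDINO box
    for mask in SAM_masks:
        show_mask(mask, plt.gca(), random_color=True)
    for box in DINO_boxes:
        show_box(box, plt.gca())

    plt.savefig(output_path)
    plt.close()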
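The run log in the main loop shows many PNGs failing with "libpng error: Read Error", which often means the file is truncated or corrupted. A cheap pre-scan can flag such files before the models are run. The sketch below is only a heuristic under that assumption: every complete PNG starts with the 8-byte PNG signature and ends with the 12-byte IEND chunk, so a file that does not is probably cut off (PNGs with trailing bytes after IEND would also be flagged). find_images and png_looks_complete are hypothetical helpers that use only the standard library.

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# IEND chunk: zero length, type "IEND", fixed CRC AE 42 60 82
PNG_IEND = b"\x00\x00\x00\x00IEND\xaeB`\x82"


def png_looks_complete(path: Path) -> bool:
    """Heuristic: True if the file starts with the PNG signature and ends with IEND."""
    if path.stat().st_size < len(PNG_SIGNATURE) + len(PNG_IEND):
        return False
    with path.open("rb") as f:
        head = f.read(len(PNG_SIGNATURE))
        f.seek(-len(PNG_IEND), 2)  # whence=2 seeks relative to end of file
        tail = f.read()
    return head == PNG_SIGNATURE and tail == PNG_IEND


def find_images(root: Path):
    """Yield (path, ok) for image files under root; ok is False for PNGs that look truncated."""
    for path in sorted(root.rglob("*")):
        suffix = path.suffix.lower()
        if path.is_file() and suffix in IMAGE_EXTS:
            ok = png_looks_complete(path) if suffix == ".png" else True
            yield path, ok
```

For example, [p for p, ok in find_images(Path("input")) if not ok] would list the suspicious files so they can be re-extracted from the source videos.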



## Architecture:
1. gDINO: grounding_dino_model.predict_with_classes

   CLASSES_prompt = ['road', 'sidewalk']

   Based on my tests, this pair of prompts is the most reliable (with other prompts, the sidewalk tends to be confused with the road).

   The detected boxes are used as regions of interest (ROIs) to further prompt SAM.

2. Non-maximum suppression (NMS):

   Remove redundant, overlapping bounding boxes.

   The overlap metric is Intersection over Union (IoU).

3. Prompt SAM with each ROI and keep the candidate mask with the highest predicted score. In this step, the road and the sidewalk are segmented at the pixel level, each with its class name.

4. Save the result.

5. TODO: annotate the result with class label and confidence.

6. TODO: run experiments on image sequences and analyze pedestrian behavior.

7. TODO: split cases based on JAAD info

   - car is moving 
   - car is stopping
   - time
   - weather
   - more...
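Step 2 can be sketched in plain NumPy to make NMS_THRESHOLD concrete. The notebook itself calls torchvision.ops.nms; the functions below are a hypothetical re-implementation of the same greedy rule (keep the highest-scoring box, discard every remaining box whose IoU with it is greater than the threshold, repeat), not the torchvision code.

```python
import numpy as np


def box_iou(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """IoU between one xyxy box and an (N, 4) array of xyxy boxes."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def nms(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float) -> list:
    """Greedy NMS; returns kept indices in decreasing score order."""
    order = np.argsort(-scores)
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        # Drop boxes that overlap the kept box by more than the threshold
        order = rest[box_iou(boxes[best], boxes[rest]) <= iou_threshold]
    return keep
```

With NMS_THRESHOLD = 0.8, only boxes overlapping a higher-scoring box by more than 80% IoU are removed, so the setting is fairly permissive; partially overlapping road and sidewalk boxes survive.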

On an RTX 3090, the pipeline runs relatively fast with GPU acceleration.

(Faster than I expected, and much faster than any of the online demos.)

1. GroundingDINO finds the road (region of interest) √

2. Use the road's bounding box as the prompt for SAM

   Text prompts: person, sidewalk, road, vehicle

3. Rule-based analysis

   - compare the pixel-level relationship between a person and the road / sidewalk
   - alternatively, use a VQA method to generate text

4. Analyze image sequences and predict behavior

   - Question: could an LLM be used for this?
     - Some ideas, to be explained at a later meeting

5. Overall, the method must use both video and images
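A first rule for step 3 could look at the pixels just below the lowest part of a person's mask and check whether they fall on the road mask or the sidewalk mask. ground_under_person is a hypothetical helper; it assumes the three masks are boolean arrays of the same H×W shape (like the masks returned by segment_ROI) and that the bottom of the person mask approximates where the person stands, which breaks down under occlusion or when the person touches the bottom image border.

```python
import numpy as np


def ground_under_person(person_mask, road_mask, sidewalk_mask, band: int = 5) -> dict:
    """Fraction of the strip below a person's feet covered by road / sidewalk pixels."""
    person = np.asarray(person_mask, dtype=bool)
    road = np.asarray(road_mask, dtype=bool)
    sidewalk = np.asarray(sidewalk_mask, dtype=bool)

    ys, xs = np.nonzero(person)
    if ys.size == 0:
        return {"road": 0.0, "sidewalk": 0.0}
    bottom = ys.max()
    # Columns occupied by the person's lowest `band` rows (roughly the feet)
    cols = np.unique(xs[ys >= bottom - band + 1])
    # Strip of `band` rows directly below the person, clipped to the image
    y0, y1 = bottom + 1, min(bottom + 1 + band, person.shape[0])
    if y0 >= y1:  # person reaches the bottom border: nothing to sample
        return {"road": 0.0, "sidewalk": 0.0}
    strip_road = road[y0:y1, cols]
    strip_side = sidewalk[y0:y1, cols]
    total = strip_road.size
    return {"road": float(strip_road.sum() / total),
            "sidewalk": float(strip_side.sum() / total)}
```

A person could then be labeled "on road" when the road fraction exceeds some threshold; that threshold would have to be tuned on JAAD frames.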

In [12]:
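# Prompt SAM with each ROI box and keep one mask per box
def segment_ROI(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    # Compute the image embedding once and reuse it for every box
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        # multimask_output=True returns 3 candidate masks with predicted quality scores
        masks_np, scores_np, _ = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box,
            multimask_output=True,
        )
        # Keep the candidate with the highest predicted score
        index = np.argmax(scores_np)
        result_masks.append(masks_np[index])

    return np.array(result_masks)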

def detect_road(image_path, output_path):
    try:
        image = cv2.imread(image_path)
        if image is None:
            print(f"Image at path {image_path} could not be loaded. Skipping.")
            return None
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    except Exception as e:
        print(f"Failed to process image at {image_path}. Error: {e}")
        return None

    CLASSES = ['road', 'sidewalk']

    # Detect objects with GroundingDINO
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_TRESHOLD,
        text_threshold=TEXT_TRESHOLD
    )

    # NMS post-processing: drop boxes whose IoU with a higher-scoring box exceeds NMS_THRESHOLD
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        NMS_THRESHOLD
    ).numpy().tolist()

    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]

    # Build labels after NMS so they stay aligned with the remaining boxes;
    # class_id can be None when a detected phrase matches none of CLASSES
    labels = [
        f"{CLASSES[class_id] if class_id is not None else 'unknown'} {confidence:0.2f}"
        for confidence, class_id in zip(detections.confidence, detections.class_id)
    ]

    DINO_boxes = np.array(detections.xyxy)

    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)

    # Prompt SAM with each remaining box (ROI)
    SAM_masks = segment_ROI(
        sam_predictor=sam_predictor,
        image=image,
        xyxy=DINO_boxes,
    )

    # Overlay masks and boxes on the annotated frame and save the figure
    plt.figure(figsize=(16, 9))
    plt.imshow(annotated_frame)
    plt.axis('off')
    for mask in SAM_masks:
        show_mask(mask, plt.gca(), random_color=True)
    for box in DINO_boxes:
        show_box(box, plt.gca())

    plt.savefig(output_path)
    plt.close()

    return DINO_boxes, labels
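detect_road saves a matplotlib overlay. For a semantic segmentation output you may also want a single label map in which every pixel holds a class index. The sketch below is one way to merge the per-box SAM masks, assuming the masks, class ids and scores are aligned after NMS (for example SAM_masks, detections.class_id and detections.confidence inside detect_road) and that overlapping pixels should go to the more confident detection. to_label_map is a hypothetical helper, not part of supervision or SAM.

```python
import numpy as np


def to_label_map(masks, class_ids, scores, shape) -> np.ndarray:
    """Merge per-box boolean masks (N, H, W) into one (H, W) uint8 label map.

    0 = background, k + 1 = CLASSES[k]. Masks are painted in ascending score
    order, so where masks overlap the higher-scoring detection wins.
    """
    label_map = np.zeros(shape, dtype=np.uint8)
    for i in np.argsort(scores):  # lowest score first, highest painted last
        if class_ids[i] is None:  # skip detections without a matched class
            continue
        label_map[np.asarray(masks[i], dtype=bool)] = int(class_ids[i]) + 1
    return label_map
```

Inside detect_road this could be called as to_label_map(SAM_masks, detections.class_id, detections.confidence, image.shape[:2]) and the result written as a PNG next to the overlay.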


In [7]:
# Quick check of the argmax selection used in segment_ROI (arbitrary numbers, not real masks)
result_masks = []
mask = [  1.7368164 ,187.55162,   893.4925 ,   430.34235  ]
score = np.array( [  1.7368164 ,187.55162,   893.4925 ,   430.34235  ])
index = np.argmax(score)
result_masks.append(mask[index])
result_masks


[893.4925]
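The cell above checks the argmax step of segment_ROI on dummy numbers. segment_ROI keeps, for each box, the candidate with the highest SAM score; keeping the candidate with the largest area is a different policy that may favor whole-region masks for large surfaces such as the road. The hypothetical pick_mask helper below makes the choice explicit so both can be compared on real outputs; it assumes masks has shape (K, H, W) and scores shape (K,), as returned by SamPredictor.predict with multimask_output=True.

```python
import numpy as np


def pick_mask(masks: np.ndarray, scores: np.ndarray, policy: str = "score") -> np.ndarray:
    """Pick one mask from SAM's multimask output.

    policy="score": highest predicted quality score (what segment_ROI does)
    policy="area":  largest number of foreground pixels
    """
    if policy == "score":
        index = int(np.argmax(scores))
    elif policy == "area":
        index = int(np.argmax(masks.reshape(masks.shape[0], -1).sum(axis=1)))
    else:
        raise ValueError(f"unknown policy: {policy}")
    return masks[index]
```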

## Main Function

In [13]:
image_dir = Path("input")  # contains many sub-folders
output_dir = Path('DINOmasked')
output_dir.mkdir(parents=True, exist_ok=True)

print("===== Start =====")
i = 1
# Use rglob to recursively find all image files
for image_path in image_dir.rglob('*'):
    if is_image_file(str(image_path)):
        relative_path = image_path.relative_to(image_dir)

        output_path = output_dir / relative_path
        output_path.parent.mkdir(parents=True,exist_ok=True)

        if not output_path.exists():
            print("Processing: ", i)
            i += 1
            print(f"Image path: {termcolor.colored(os.path.basename(str(image_path)), 'green')}")

            result = detect_road(str(image_path),str(output_path))

            if result is not None:
                print(f"Detected: {image_path}") # {termcolor.colored(result, 'blue')}")
            else:
                fail_str = "failed to detect result"
                print(f" {termcolor.colored(fail_str, 'red')}")


===== Start =====
Processing:  1
Image path: image_0008.png
Image at path input/video_0018/image_0008.png could not be loaded. Skipping.
 failed to detect result
Processing:  2
Image path: image_0001.png
Image at path input/video_0020/image_0001.png could not be loaded. Skipping.
 failed to detect result
Processing:  3
Image path: image_0002.png
Image at path input/video_0020/image_0002.png could not be loaded. Skipping.
 failed to detect result
Processing:  4
Image path: image_0003.png
Image at path input/video_0020/image_0003.png could not be loaded. Skipping.
 failed to detect result
Processing:  5
Image path: image_0004.png


libpng error: Read Error
libpng error: Read Error
libpng error: Read Error
libpng error: Read Error
libpng error: Read Error


Image at path input/video_0020/image_0004.png could not be loaded. Skipping.
 failed to detect result
Processing:  6
Image path: image_0005.png
Image at path input/video_0020/image_0005.png could not be loaded. Skipping.
 failed to detect result
Processing:  7
Image path: image_0006.png
Image at path input/video_0020/image_0006.png could not be loaded. Skipping.
 failed to detect result
Processing:  8
Image path: image_0007.png
Image at path input/video_0020/image_0007.png could not be loaded. Skipping.
 failed to detect result
Processing:  9
Image path: image_0008.png
Image at path input/video_0020/image_0008.png could not be loaded. Skipping.
 failed to detect result
Processing:  10
Image path: image_0009.png


libpng error: Read Error
libpng error: Read Error
libpng error: Read Error
libpng error: Read Error
libpng error: Read Error


Image at path input/video_0020/image_0009.png could not be loaded. Skipping.
 failed to detect result
Processing:  11
Image path: image_0010.png
Image at path input/video_0020/image_0010.png could not be loaded. Skipping.
 failed to detect result
Processing:  12
Image path: image_0011.png
Image at path input/video_0020/image_0011.png could not be loaded. Skipping.
 failed to detect result
Processing:  13
Image path: image_0012.png
Image at path input/video_0020/image_0012.png could not be loaded. Skipping.
 failed to detect result
Processing:  14
Image path: image_0013.png
Image at path input/video_0020/image_0013.png could not be loaded. Skipping.
 failed to detect result
Processing:  15
Image path: image_0014.png


libpng error: Read Error
libpng error: Read Error
libpng error: Read Error
libpng error: Read Error
libpng error: Read Error


Image at path input/video_0020/image_0014.png could not be loaded. Skipping.
 failed to detect result
Processing:  16
Image path: image_0015.png
Image at path input/video_0020/image_0015.png could not be loaded. Skipping.
 failed to detect result
Processing:  17
Image path: image_0016.png
Image at path input/video_0020/image_0016.png could not be loaded. Skipping.
 failed to detect result
Processing:  18
Image path: image_0017.png
Image at path input/video_0020/image_0017.png could not be loaded. Skipping.
 failed to detect result
Processing:  19
Image path: image_0018.png
Image at path input/video_0020/image_0018.png could not be loaded. Skipping.
 failed to detect result
Processing:  20
Image path: image_0001.png


libpng error: Read Error
libpng error: Read Error
libpng error: Read Error
libpng error: Read Error
libpng error: Read Error


Image at path input/video_0021/image_0001.png could not be loaded. Skipping.
 failed to detect result
Processing:  21
Image path: image_0002.png
Image at path input/video_0021/image_0002.png could not be loaded. Skipping.
 failed to detect result
Processing:  22
Image path: image_0003.png
Image at path input/video_0021/image_0003.png could not be loaded. Skipping.
 failed to detect result
Processing:  23
Image path: image_0004.png
Image at path input/video_0021/image_0004.png could not be loaded. Skipping.
 failed to detect result
Processing:  24
Image path: image_0005.png
Image at path input/video_0021/image_0005.png could not be loaded. Skipping.
 failed to detect result
Processing:  25
Image path: image_0006.png


libpng error: Read Error
libpng error: Read Error
libpng error: Read Error
libpng error: Read Error
libpng error: Read Error


Image at path input/video_0021/image_0006.png could not be loaded. Skipping.
 failed to detect result
Processing:  26
Image path: image_0001.png
Detected: (array([[   4.040344 ,  681.7572   , 1916.9829   , 1070.0835   ],
       [   2.4958496,  695.2447   , 1824.8105   ,  881.4327   ]],
      dtype=float32), ['road 0.56', 'sidewalk 0.34'])
Processing:  27
Image path: image_0002.png
Detected: (array([[   4.387146 ,  687.65015  , 1916.2673   , 1070.9369   ],
       [   3.0578613,  688.7768   , 1917.5137   ,  990.799    ],
       [   4.111847 ,  797.8878   , 1012.699    ,  896.75195  ],
       [   5.4229736,  711.2475   , 1633.386    ,  897.3488   ]],
      dtype=float32), ['road 0.40', 'sidewalk 0.30', 'road 0.32', 'sidewalk 0.25'])
Processing:  28
Image path: image_0003.png
Detected: (array([[4.0161133e+00, 6.9651025e+02, 1.9165444e+03, 1.0736506e+03],
       [5.7609863e+00, 6.9726196e+02, 1.9161823e+03, 9.7091809e+02],
       [1.2233887e+00, 8.0690332e+02, 8.3867938e+02, 9.0724951e+02],