In [1]:
import cv2
import numpy as np
import supervision as sv
import torch
import torchvision
from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor



### 下载必需的库

```bash
git clone https://github.com/IDEA-Research/GroundingDINO.git
pip install --no-build-isolation -e GroundingDINO
pip install supervision==0.21.0
```

### 下载权重

后面的代码从 `./models/` 目录加载权重，因此需要把权重下载到该目录中：

```bash
cd Grounded-Segment-Anything
mkdir -p models

wget -P models https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
wget -P models https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
```

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GroundingDINO (text-prompted detector) config file and checkpoint.
GROUNDING_DINO_CONFIG_PATH = "./config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "./models/groundingdino_swint_ogc.pth"

# Segment Anything (SAM) model variant and checkpoint.
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "./models/sam_vit_h_4b8939.pth"

# Pass the device explicitly: GroundingDINO's Model wrapper defaults to "cuda",
# which would ignore the CPU fallback chosen above and fail on CPU-only machines.
grounding_dino_model = Model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
    device=str(DEVICE),
)

# Build SAM from the registry, move it to the same device, and wrap it in a predictor.
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)



final text_encoder_type: bert-base-uncased




In [3]:
# Path of the input image that is read with cv2.imread and passed to the detector.
SOURCE_IMAGE_PATH = "./img/test1.png"

In [4]:
# Text prompts (class names) handed to GroundingDINO as the `classes` argument;
# label strings are later looked up in this list by class_id.
CLASSES = ["white plate"]

In [5]:
BOX_THRESHOLD = 0.25    # confidence threshold for bounding boxes (GroundingDINO box_threshold)
TEXT_THRESHOLD = 0.25   # confidence threshold for text matching (GroundingDINO text_threshold)
NMS_THRESHOLD = 0.8     # threshold passed to torchvision.ops.nms (non-maximum suppression)

In [6]:
# cv2.imread returns None instead of raising when the file is missing or
# cannot be decoded, so fail early with a clear message.
image = cv2.imread(SOURCE_IMAGE_PATH)
if image is None:
    raise FileNotFoundError(f"Could not read image: {SOURCE_IMAGE_PATH}")

# Run text-prompted detection with the configured prompts and thresholds.
detections = grounding_dino_model.predict_with_classes(
    image=image,
    classes=CLASSES,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD
)



In [None]:
# Bounding-box annotator: draws the detected boxes and their labels.
box_annotator = sv.BoxAnnotator()
# GroundingDINO sets class_id to None when a detected phrase matches none of
# CLASSES; indexing CLASSES with None would raise TypeError, so use a fallback label.
# NOTE(review): the 6-field unpacking matches the pinned supervision==0.21.0 — confirm before upgrading.
labels = [
    f"{CLASSES[class_id] if class_id is not None else 'unknown'} {confidence:0.2f}"
    for _, _, confidence, class_id, _, _
    in detections]
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)

# Save the annotated image. cv2.imwrite reports failure by returning False
# rather than raising, so check the displayed value (the "result" directory must exist).
cv2.imwrite("result/groundingdino_annotated_image.jpg", annotated_frame)



True

In [8]:
print(f"Before NMS: {len(detections.xyxy)} boxes")

# torchvision's NMS operates on tensors, so wrap the numpy arrays first.
boxes_tensor = torch.from_numpy(detections.xyxy)
scores_tensor = torch.from_numpy(detections.confidence)
nms_idx = torchvision.ops.nms(boxes_tensor, scores_tensor, NMS_THRESHOLD).numpy().tolist()

# Drop duplicate detections: keep only the surviving indices in each per-box array.
for field in ("xyxy", "confidence", "class_id"):
    setattr(detections, field, getattr(detections, field)[nms_idx])

print(f"After NMS: {len(detections.xyxy)} boxes")

Before NMS: 1 boxes
After NMS: 1 boxes


In [9]:
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Run SAM on each bounding box and return the best-scoring mask per box.

    Args:
        sam_predictor: Predictor exposing ``set_image`` and ``predict``.
        image: Image handed to ``sam_predictor.set_image``. The caller in this
            notebook passes an RGB image; assumed to be (H, W, C) — TODO confirm.
        xyxy: Iterable of boxes, one row per box, each passed to ``predict`` as ``box``.

    Returns:
        Array with one mask per box, stacked along the first axis. For empty
        ``xyxy`` an empty boolean array of shape ``(0, H, W)`` is returned so
        that downstream code still receives a mask-stack-shaped array.
    """
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, _ = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        # The candidate masks are not necessarily ordered by confidence,
        # so pick the one with the highest score explicitly.
        result_masks.append(masks[np.argmax(scores)])
    if not result_masks:
        # np.array([]) would be a 1-D float array; return an empty mask stack instead.
        height, width = image.shape[:2]
        return np.zeros((0, height, width), dtype=bool)
    return np.array(result_masks)

In [None]:
# Convert each detection box into a segmentation mask with SAM.
# cv2 loads images as BGR, so convert to RGB before handing the image to SAM.
detections.mask = segment(
    sam_predictor=sam_predictor,
    image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
    xyxy=detections.xyxy
)

# Create the box and mask annotators.
box_annotator = sv.BoxAnnotator()
mask_annotator = sv.MaskAnnotator()
# GroundingDINO sets class_id to None when a detected phrase matches none of
# CLASSES; indexing CLASSES with None would raise TypeError, so use a fallback label.
# NOTE(review): the 6-field unpacking matches the pinned supervision==0.21.0 — confirm before upgrading.
labels = [
    f"{CLASSES[class_id] if class_id is not None else 'unknown'} {confidence:0.2f}"
    for _, _, confidence, class_id, _, _
    in detections]
# Draw masks first, then boxes and labels on top of them.
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

# cv2.imwrite reports failure by returning False rather than raising,
# so check the displayed value (the "result" directory must exist).
cv2.imwrite("result/grounded_sam_annotated_image.jpg", annotated_image)



True