In [1]:
import pandas as pd
from groundingdino.util.inference import Model
from typing import List
import os
import supervision as sv
import cv2
import warnings
from tqdm import tqdm
from ultralytics import YOLO, SAM
import torch
import numpy as np
import matplotlib.pyplot as plt
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.build_sam import build_sam2, build_sam2_video_predictor
import rerun as rr
#from track_utils import sample_points_from_masks
#from video_utils import create_video_from_images
import json
import random
from uuid import uuid4
import pandas as pd
import logging

In [2]:
#!pip install supervision==0.22.0

In [3]:
# Display the installed supervision version. This notebook's output shows 0.22.0,
# matching the version in the commented-out pinned install above; the annotator
# calls further down were run against that version.
sv.__version__

'0.22.0'

In [4]:
# Hide FutureWarning and UserWarning noise. The order matters: UserWarning ends up
# first in warnings.filters, exactly as with two sequential simplefilter calls.
for _category in (FutureWarning, UserWarning):
    warnings.simplefilter(action='ignore', category=_category)

# Resolve every logger registered so far, then lower any "transformers" logger
# to ERROR so only real problems are printed.
loggers = list(map(logging.getLogger, logging.root.manager.loggerDict))
for logger in loggers:
    if "transformers" not in logger.name.lower():
        continue
    logger.setLevel(logging.ERROR)

In [5]:
class GLAMModel:
    """Grounding DINO + SAM2 (+ YOLO) wrapper with a shared, growing class registry.

    ``class_names`` is a list of ``{'id', 'color', 'name'}`` dicts: the entries loaded
    from the class-description JSON followed by one entry per text prompt.
    ``class_dict`` maps id -> name and ``class_descriptions`` holds the matching
    ``rr.AnnotationInfo`` objects; both are rebuilt whenever a new class is registered.
    """

    # Lowest id given to prompt classes. Presumably the JSON file holds the 80 COCO
    # classes (ids 0-79) used by YOLO -- TODO confirm. Prompt ids are additionally
    # placed after the largest id found in the JSON so they can never collide with it.
    PROMPT_ID_OFFSET = 80

    def __init__(self, grounding_dino_config_path, grounding_dino_checkpoint_path, sam_model_cfg, sam_checkpoint_path,
                 prompt=None, device="cuda", yolo_weights='yolov8x-seg.pt',
                 class_descriptions_path='class_descriptions.json'):
        """Load all models and build the class registry.

        Args:
            grounding_dino_config_path: Grounding DINO config file.
            grounding_dino_checkpoint_path: Grounding DINO checkpoint file.
            sam_model_cfg: SAM2 config name passed to ``build_sam2``.
            sam_checkpoint_path: SAM2 checkpoint file.
            prompt: text classes for Grounding DINO; a default street-scene list
                is used when None.
            device: device passed to ``build_sam2`` (default "cuda").
            yolo_weights: weights file passed to ``YOLO``.
            class_descriptions_path: UTF-8 JSON file containing the base list of
                ``{'id', 'color', 'name'}`` class dicts.
        """
        if prompt is None:
            # prompt = ['pathways', 'trails', 'walkways', 'sidewalks', 'tracks', 'footpaths', 'routes', 'pedestrian paths', 'walking paths', 'lanes']
            prompt = ['pavement', 'fence', 'cyclepath', 'trees', 'grasses', 'sidewalk', 'buildings', 'skies', 'streetlights']
        self.prompt = prompt
        self.grounding_dino_model = Model(model_config_path=grounding_dino_config_path, model_checkpoint_path=grounding_dino_checkpoint_path)
        self.sam = build_sam2(sam_model_cfg, sam_checkpoint_path, device=device)
        self.sam_predictor = SAM2ImagePredictor(self.sam)
        self.yolo = YOLO(yolo_weights)
        with open(class_descriptions_path, 'r', encoding='utf-8') as file:
            self.class_names = json.load(file)

        # Prompt classes start at PROMPT_ID_OFFSET, or right after the JSON's largest id if that is higher.
        first_prompt_id = max([self.PROMPT_ID_OFFSET] + [item['id'] + 1 for item in self.class_names])
        self.class_names += [{'id': first_prompt_id + i, 'color': self.generate_random_color(), 'name': p}
                             for i, p in enumerate(self.prompt)]
        self._rebuild_class_index()
        self.dino_classes = self.enhance_class_name(self.prompt)
        # Box / text confidence thresholds intended for Grounding DINO predictions.
        self.dino_box_threshold = 0.35
        self.dino_text_threshold = 0.25
        # YOLO class ids of interest (presumably COCO ids -- TODO confirm); not referenced by this class.
        self.yolo_classes = [0, 1, 2, 3, 5, 7, 9, 11, 30]
        self.persist = []
        self.video_outs = dict()

    def _rebuild_class_index(self):
        """Recompute ``class_dict`` and ``class_descriptions`` from ``class_names``."""
        self.class_dict = {item['id']: item['name'] for item in self.class_names}
        self.class_descriptions = [rr.AnnotationInfo(id=cat["id"], color=cat["color"], label=cat["name"])
                                   for cat in self.class_names]

    @staticmethod
    def enhance_class_name(class_names: List[str]) -> List[str]:
        """Return a new list with each class name formatted as-is (currently an identity mapping)."""
        return [f"{class_name}" for class_name in class_names]

    @staticmethod
    def generate_random_color():
        """Return a random ``(r, g, b)`` tuple with each channel in [0, 255]."""
        r = random.randint(0, 255)
        g = random.randint(0, 255)
        b = random.randint(0, 255)
        return r, g, b

    def add_dino_class(self, _phrase):
        """Register ``_phrase`` as a new class with a random color and return its id.

        The new id is one larger than the largest registered id (0 for an empty registry).
        """
        _class_id = max(self.class_dict.keys(), default=-1) + 1
        self.class_names.append({'id': _class_id, 'color': self.generate_random_color(), 'name': _phrase})
        self._rebuild_class_index()
        return _class_id

    def dino_id_to_class_name(self, dino_id):
        """Return the registered name for ``dino_id`` (raises KeyError if unknown)."""
        return self.class_dict[dino_id]

    def phrases2classes(self, phrases: List[str]) -> "tuple[np.ndarray, bool]":
        """Map Grounding DINO phrases to class ids, registering unknown phrases.

        Exactly one id is produced per phrase (the first registered id carrying that
        name), so the result always lines up one-to-one with the detections the
        phrases came from.

        Args:
            phrases: one phrase per detection.

        Returns:
            ``(class_ids, added)``: an integer array with ``len(phrases)`` ids, and
            True if at least one new class was registered during the call.
        """
        # First registered id per name: avoids a linear scan per phrase, and stops a
        # duplicated name from producing several ids for a single detection.
        name_to_id = {}
        for class_id, name in self.class_dict.items():
            name_to_id.setdefault(name, class_id)

        class_ids = []
        added = False
        for phrase in phrases:
            if phrase not in name_to_id:
                name_to_id[phrase] = self.add_dino_class(phrase)
                added = True
            class_ids.append(name_to_id[phrase])
        # dtype=int keeps the array integer-typed even when there are no phrases.
        return np.array(class_ids, dtype=int), added

In [6]:
# Model config / weight locations. Grounding DINO files live relative to the
# notebook's parent directory; the SAM2 checkpoint is given as an absolute path.
_PARENT_DIR = '../'
GROUNDING_DINO_CONFIG_PATH = os.path.join(_PARENT_DIR, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(_PARENT_DIR, "weights", "groundingdino_swint_ogc.pth")
SAM_CHECKPOINT_PATH = "/home/lnt/PycharmProjects/sam/weights/sam2_hiera_large.pt"
SAM_MODEL_CFG = "sam2_hiera_l.yaml"

In [9]:
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

# --- Paths -------------------------------------------------------------------
HOME = '/home/lnt/PycharmProjects/sam'
GAZE_CSV_PATH = f"{HOME}/data/gaze_positions.csv"
SOURCE_VIDEO_PATH = f"{HOME}/data/world_raw.mp4"
output_path = f"{HOME}/data/video_dino-sam1_yolo_masks.mp4"
RESULTS_CSV_PATH = f"{HOME}/data/dino_sam2_gazed.csv"
# Column order of the results CSV; also gives an empty result a header row.
RESULT_COLUMNS = ['index', 'timestamp', 'frame', 'label', 'x', 'y']
# First video frame to process.
frame_pos = 0


def gaze_to_pixels(gaze_points, frame_width, frame_height):
    """Convert normalised gaze coordinates to integer pixel coordinates.

    ``norm_pos_x`` is scaled by the frame width; ``norm_pos_y`` is flipped
    (``1 - y``) and scaled by the frame height to obtain an image row. Values are
    floored so points just left of / above the frame stay negative (and are then
    rejected by the bounds check) instead of being truncated onto pixel 0.
    Rows whose coordinates are NaN/inf are dropped.

    Returns:
        ``(rows, xs, ys)``: the kept rows of ``gaze_points`` and their pixel
        coordinates as integer arrays of the same length.
    """
    xs = np.floor(gaze_points['norm_pos_x'].to_numpy(dtype=float) * frame_width)
    ys = np.floor((1 - gaze_points['norm_pos_y'].to_numpy(dtype=float)) * frame_height)
    finite = np.isfinite(xs) & np.isfinite(ys)
    return gaze_points[finite], xs[finite].astype(int), ys[finite].astype(int)


def gaze_hits_on_mask(mask, xs, ys):
    """Return a boolean array: True where ``(xs[i], ys[i])`` is inside ``mask``'s bounds and on the mask."""
    mask_height, mask_width = mask.shape[0], mask.shape[1]
    in_bounds = (xs >= 0) & (xs < mask_width) & (ys >= 0) & (ys < mask_height)
    hits = np.zeros(len(xs), dtype=bool)
    hits[in_bounds] = mask[ys[in_bounds], xs[in_bounds]]
    return hits


# --- Models and gaze data ----------------------------------------------------
glam_model = GLAMModel(grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH,
                       grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
                       sam_model_cfg=SAM_MODEL_CFG,
                       sam_checkpoint_path=SAM_CHECKPOINT_PATH)

gaze_data = pd.read_csv(GAZE_CSV_PATH)
# Group gaze samples by world-video frame once instead of scanning the whole
# table on every frame (row order within each frame is preserved).
gaze_by_frame = {frame_idx: rows for frame_idx, rows in gaze_data.groupby('world_index')}
no_gaze = gaze_data.iloc[0:0]

# Loop-invariant objects: the Grounding DINO caption and the annotators
# (created once rather than once per detection).
caption = ' . '.join(glam_model.prompt)
mask_annotator = sv.MaskAnnotator()
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)

# --- Video I/O ---------------------------------------------------------------
cap = cv2.VideoCapture(SOURCE_VIDEO_PATH)
if not cap.isOpened():
    raise IOError(f"Could not open video: {SOURCE_VIDEO_PATH}")
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
# Keep the fractional frame rate (e.g. 29.97 fps); int() would truncate it and
# the annotated video would no longer match the source timing.
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

# One record per (gaze sample, detection mask) hit.
results = []

try:
    for frame_number in tqdm(range(frame_pos, total_frames)):
        frame_ok, frame = cap.read()
        if not frame_ok:
            break

        annotated_frame = frame.copy()

        # Gaze samples belonging to this frame, with their pixel coordinates.
        current_gaze_points, gaze_xs, gaze_ys = gaze_to_pixels(
            gaze_by_frame.get(frame_number, no_gaze), width, height)

        # Grounding DINO open-vocabulary detection on the (BGR) OpenCV frame.
        dino_results, phrases = glam_model.grounding_dino_model.predict_with_caption(
            image=frame,
            caption=caption,
            box_threshold=glam_model.dino_box_threshold,
            text_threshold=glam_model.dino_text_threshold
        )
        dino_results.class_id, added_new_class = glam_model.phrases2classes(phrases)

        # SAM2ImagePredictor.set_image expects RGB, while OpenCV frames are BGR.
        glam_model.sam_predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        for box, class_id, confidence, phrase in zip(dino_results.xyxy, dino_results.class_id,
                                                     dino_results.confidence, phrases):
            # Segment inside the box and keep the highest-scoring candidate mask.
            masks, scores, _ = glam_model.sam_predictor.predict(box=box, multimask_output=True)
            mask = masks[np.argmax(scores)].astype('bool')

            # Record every gaze sample of this frame that lands on the mask.
            hits = gaze_hits_on_mask(mask, gaze_xs, gaze_ys)
            hit_timestamps = current_gaze_points['gaze_timestamp'].to_numpy()[hits]
            for gaze_index, timestamp, gaze_x, gaze_y in zip(current_gaze_points.index[hits], hit_timestamps,
                                                           gaze_xs[hits], gaze_ys[hits]):
                results.append({
                    'index': gaze_index,
                    'timestamp': timestamp,
                    'frame': frame_number,
                    'label': phrase,
                    'x': int(gaze_x),
                    'y': int(gaze_y),
                })

            # Overlay mask, box and label for this detection (visualisation only).
            detection = sv.Detections(xyxy=np.array([box]), mask=np.array([mask]),
                                      confidence=np.array([confidence]), class_id=np.array([class_id]))
            annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detection)
            annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detection)
            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detection, labels=[phrase])

        cv2.putText(annotated_frame, f'Frame: {frame_number}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
        out.write(annotated_frame)
finally:
    # Release the capture and the writer even if the loop fails or is interrupted,
    # so the output video file is finalised.
    out.release()
    cap.release()

# Save the gaze/mask hits; explicit columns keep the header even when there are none.
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS)
results_df.to_csv(RESULTS_CSV_PATH, index=False)


final text_encoder_type: bert-base-uncased


100%|███████████████████████████████████| 13077/13077 [1:11:40<00:00,  3.04it/s]
