RGB-D传感器

In [None]:
import numpy as np
import torch
import cv2
from torch.nn.functional import interpolate
from kmeans_pytorch import kmeans
from sklearn.cluster import MeanShift

def filter_points_by_bounds(points, bounds_min, bounds_max, strict=True):
    """
    Return a boolean mask selecting the points inside the workspace bounds.

    Args:
        points: Array of shape (N, 3) holding xyz coordinates.
        bounds_min: Length-3 lower corner (x, y, z) of the workspace.
        bounds_max: Length-3 upper corner (x, y, z) of the workspace.
        strict: When True the bounds are used exactly as given. When False,
            x and y are padded on both sides by 10% of their extent and the
            z lower bound is lowered by 10% of the z extent; the z upper
            bound is left untouched.

    Returns:
        Boolean array of shape (N,), True for points within the (inclusive)
        bounds. The caller's bound arrays are never modified.
    """
    assert points.shape[1] == 3, "points must be (N, 3)"
    # Float copies: protects the caller's arrays and keeps integer-typed
    # bounds from truncating the fractional padding added below.
    bounds_min = np.array(bounds_min, dtype=float)
    bounds_max = np.array(bounds_max, dtype=float)
    if not strict:
        # Compute the margin once from the ORIGINAL extent so that both sides
        # are padded by the same amount (previously the upper side used the
        # already-lowered minimum and ended up padded by 11%).
        margin = 0.1 * (bounds_max - bounds_min)
        bounds_min -= margin
        bounds_max[:2] += margin[:2]
    return np.all((points >= bounds_min) & (points <= bounds_max), axis=1)

class KeypointProposer:
    """Propose candidate 3D keypoints from an RGB image, a point map and a segmentation mask.

    Pipeline (see `get_keypoints`): extract dense DINOv2 features, cluster
    them per mask with k-means, keep the candidates inside the workspace
    bounds, merge candidates that are close in Cartesian space with
    MeanShift, and draw the numbered result on a copy of the image.
    """

    def __init__(self, config):
        """Load the feature extractor, build the MeanShift helper and seed the RNGs.

        Args:
            config: Mapping read with the keys 'device', 'bounds_min',
                'bounds_max', 'min_dist_bt_keypoints', 'seed',
                'max_mask_ratio' and 'num_candidates_per_mask'.
        """
        self.config = config
        self.device = torch.device(self.config['device'])  # compute device
        # Pretrained DINOv2 ViT-S/14 backbone.
        # NOTE: the attribute is called `dinov3` but holds a DINOv2 model; the
        # name is kept so existing code accessing it keeps working.
        self.dinov3 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').eval().to(self.device)
        # Workspace bounds
        self.bounds_min = np.array(self.config['bounds_min'])
        self.bounds_max = np.array(self.config['bounds_max'])
        # MeanShift is used to merge keypoint candidates that lie close together
        self.mean_shift = MeanShift(bandwidth=self.config['min_dist_bt_keypoints'], bin_seeding=True, n_jobs=32)
        self.patch_size = 14  # DINOv2 patch size
        # Seed every RNG in use so results are reproducible
        np.random.seed(self.config['seed'])
        torch.manual_seed(self.config['seed'])
        torch.cuda.manual_seed(self.config['seed'])

    def get_keypoints(self, rgb, points, masks):
        """Extract keypoints from an RGB image, a point map and a segmentation mask.

        Args:
            rgb: RGB image [H, W, 3]
            points: 3D point map [H, W, 3]
            masks: segmentation id map [H, W]

        Returns:
            candidate_keypoints: 3D coordinates of the kept keypoints, [K, 3]
                (K may be 0 when no candidate survives).
            projected: copy of the image annotated with the keypoint indices.
        """
        # Split the id map into binary masks and build the network input
        transformed_rgb, rgb, points, masks, shape_info = self._preprocess(rgb, points, masks)

        # Dense per-pixel DINOv2 features
        features_flat = self._get_features(transformed_rgb, shape_info)

        # Cluster the features of every mask to get keypoint candidates
        candidate_keypoints, candidate_pixels, candidate_rigid_group_ids = self._cluster_features(points, features_flat, masks)

        # Drop candidates outside the workspace
        within_space = filter_points_by_bounds(candidate_keypoints, self.bounds_min, self.bounds_max, strict=True)
        candidate_keypoints = candidate_keypoints[within_space]
        candidate_pixels = candidate_pixels[within_space]
        candidate_rigid_group_ids = candidate_rigid_group_ids[within_space]

        # Merge candidates that are close in Cartesian space. An explicit
        # integer index array keeps the indexing valid when nothing is left.
        merged_indices = np.asarray(self._merge_clusters(candidate_keypoints), dtype=int)
        candidate_keypoints = candidate_keypoints[merged_indices]
        candidate_pixels = candidate_pixels[merged_indices]
        candidate_rigid_group_ids = candidate_rigid_group_ids[merged_indices]

        # Order candidates by pixel column first, then by pixel row
        sort_idx = np.lexsort((candidate_pixels[:, 0], candidate_pixels[:, 1]))
        candidate_keypoints = candidate_keypoints[sort_idx]
        candidate_pixels = candidate_pixels[sort_idx]
        candidate_rigid_group_ids = candidate_rigid_group_ids[sort_idx]

        # Draw the keypoints on the image for visualisation
        projected = self._project_keypoints_to_img(rgb, candidate_pixels, candidate_rigid_group_ids, masks, features_flat)

        return candidate_keypoints, projected

    def _preprocess(self, rgb, points, masks):
        """Prepare the inputs for feature extraction.

        Args:
            rgb: original RGB image [H, W, 3]
            points: original point map (returned unchanged)
            masks: segmentation id map

        Returns:
            transformed_rgb: float32 image in [0, 1], resized so both sides
                are multiples of the patch size
            rgb: the original RGB image
            points: the original point map
            masks: list with one boolean mask per unique id in the input
            shape_info: dict with 'img_h', 'img_w', 'patch_h', 'patch_w'
        """
        # One boolean mask per unique id
        masks = [masks == uid for uid in np.unique(masks)]

        # DINOv2 needs both image sides to be multiples of the patch size
        H, W, _ = rgb.shape
        patch_h = int(H // self.patch_size)  # patches along the vertical axis
        patch_w = int(W // self.patch_size)  # patches along the horizontal axis
        new_H = patch_h * self.patch_size    # resized height
        new_W = patch_w * self.patch_size    # resized width

        # Resize and scale to [0, 1]
        transformed_rgb = cv2.resize(rgb, (new_W, new_H))
        transformed_rgb = transformed_rgb.astype(np.float32) / 255.0  # float32 [new_H, new_W, 3]

        shape_info = {
            'img_h': H,          # original image height
            'img_w': W,          # original image width
            'patch_h': patch_h,  # patch-grid height
            'patch_w': patch_w,  # patch-grid width
        }

        return transformed_rgb, rgb, points, masks, shape_info

    def _project_keypoints_to_img(self, rgb, candidate_pixels, candidate_rigid_group_ids, masks, features_flat):
        """Draw a numbered label box at every keypoint pixel.

        Args:
            rgb: original RGB image (not modified)
            candidate_pixels: (row, col) pixel coordinates of the keypoints
            candidate_rigid_group_ids: unused, kept for interface compatibility
            masks: unused, kept for interface compatibility
            features_flat: unused, kept for interface compatibility

        Returns:
            projected: copy of `rgb` with the keypoint indices drawn on it
        """
        projected = rgb.copy()

        for keypoint_count, pixel in enumerate(candidate_pixels):
            displayed_text = f"{keypoint_count}"
            text_length = len(displayed_text)
            # Plain Python ints: OpenCV point arguments are (x, y) = (col, row)
            row, col = int(pixel[0]), int(pixel[1])

            # Background box grows with the number of digits
            box_width = 30 + 10 * (text_length - 1)
            box_height = 30
            top_left = (col - box_width // 2, row - box_height // 2)
            bottom_right = (col + box_width // 2, row + box_height // 2)
            # White filled rectangle with a black border
            cv2.rectangle(projected, top_left, bottom_right, (255, 255, 255), -1)
            cv2.rectangle(projected, top_left, bottom_right, (0, 0, 0), 2)

            # Keypoint index; (255, 0, 0) is red when the image is in RGB order
            org = (col - 7 * text_length, row + 7)
            color = (255, 0, 0)
            cv2.putText(projected, displayed_text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        return projected

    @torch.inference_mode()  # no gradient tracking, saves memory
    @torch.amp.autocast('cuda')  # mixed precision on CUDA
    def _get_features(self, transformed_rgb, shape_info):
        """Extract dense per-pixel DINOv2 features.

        Args:
            transformed_rgb: preprocessed float32 image from `_preprocess`
            shape_info: dict with 'img_h', 'img_w', 'patch_h', 'patch_w'

        Returns:
            features_flat: tensor [img_h * img_w, feature_dim]
        """
        img_h = shape_info['img_h']
        img_w = shape_info['img_w']
        patch_h = shape_info['patch_h']
        patch_w = shape_info['patch_w']

        # HWC numpy image -> [1, 3, H, W] tensor on the target device
        img_tensors = torch.from_numpy(transformed_rgb).permute(2, 0, 1).unsqueeze(0).to(self.device)
        assert img_tensors.shape[1] == 3, "unexpected image shape"

        # Patch tokens from the backbone
        features_dict = self.dinov3.forward_features(img_tensors)
        raw_feature_grid = features_dict['x_norm_patchtokens']  # [1, patch_h*patch_w, feature_dim]
        raw_feature_grid = raw_feature_grid.reshape(1, patch_h, patch_w, -1)  # [1, patch_h, patch_w, feature_dim]

        # Bilinear upsampling of the patch grid back to the original image size
        interpolated_feature_grid = interpolate(
            raw_feature_grid.permute(0, 3, 1, 2),  # [1, feature_dim, patch_h, patch_w]
            size=(img_h, img_w),
            mode='bilinear'
        ).permute(0, 2, 3, 1).squeeze(0)  # [img_h, img_w, feature_dim]

        # Flatten the spatial dimensions for the clustering step
        features_flat = interpolated_feature_grid.reshape(-1, interpolated_feature_grid.shape[-1])

        return features_flat

    @staticmethod
    def _min_max_normalize(values, eps=1e-12):
        """Scale every column of `values` to [0, 1]; constant columns map to 0 instead of NaN."""
        col_min = values.min(0)[0]
        col_range = torch.clamp(values.max(0)[0] - col_min, min=eps)
        return (values - col_min) / col_range

    def _cluster_features(self, points, features_flat, masks):
        """Cluster the features inside every mask to obtain keypoint candidates.

        Masks covering more than `config['max_mask_ratio']` of the image are
        skipped, as are masks with too few pixels to run the clustering.

        Args:
            points: 3D point map [H, W, 3]
            features_flat: flattened per-pixel features [H*W, feature_dim]
            masks: list of boolean masks [H, W]

        Returns:
            candidate_keypoints: 3D coordinates of the candidates, [K, 3]
            candidate_pixels: (row, col) pixel coordinates, [K, 2]
            candidate_rigid_group_ids: index of the source mask, [K]
            All three are correctly shaped empty arrays when K == 0.
        """
        candidate_keypoints = []
        candidate_pixels = []
        candidate_rigid_group_ids = []

        num_candidates = self.config['num_candidates_per_mask']
        # k-means needs at least one sample per centroid, and the PCA needs at
        # least three rows to produce the three components used below.
        min_pixels = max(num_candidates, 3)

        for rigid_group_id, binary_mask in enumerate(masks):
            # Ignore masks that are too large
            if np.mean(binary_mask) > self.config['max_mask_ratio']:
                continue

            feature_pixels = np.argwhere(binary_mask)  # pixel coordinates inside the mask
            if len(feature_pixels) < min_pixels:
                continue

            # Foreground features and 3D points only
            obj_features_flat = features_flat[binary_mask.reshape(-1)]
            feature_points = points[binary_mask]

            # PCA down to 3 components to reduce sensitivity to noise and texture
            obj_features_flat = obj_features_flat.double()
            _, _, v = torch.pca_lowrank(obj_features_flat, center=False)
            features_pca = torch.mm(obj_features_flat, v[:, :3])
            features_pca = self._min_max_normalize(features_pca)

            # Append the normalised 3D coordinates as extra feature dimensions
            feature_points_torch = torch.tensor(feature_points, dtype=features_pca.dtype, device=features_pca.device)
            feature_points_torch = self._min_max_normalize(feature_points_torch)
            X = torch.cat([features_pca, feature_points_torch], dim=-1)

            # k-means over the combined feature/position space
            cluster_ids_x, cluster_centers = kmeans(
                X=X,
                num_clusters=num_candidates,
                distance='euclidean',
                device=self.device,
            )
            cluster_centers = cluster_centers.to(self.device)

            # For every cluster keep the real point closest to the centroid
            for cluster_id in range(num_candidates):
                member_mask = cluster_ids_x == cluster_id
                # k-means can leave a cluster empty; argmin over nothing would raise
                if not bool(member_mask.any()):
                    continue
                member_idx = member_mask.cpu().numpy()
                member_points = feature_points[member_idx]
                member_pixels = feature_pixels[member_idx]
                member_features = features_pca[member_mask.to(features_pca.device)]

                # Distance is measured on the PCA part of the centroid only
                cluster_center = cluster_centers[cluster_id][:3]
                dist = torch.norm(member_features - cluster_center, dim=-1)
                closest_idx = int(torch.argmin(dist))

                candidate_keypoints.append(member_points[closest_idx])
                candidate_pixels.append(member_pixels[closest_idx])
                candidate_rigid_group_ids.append(rigid_group_id)

        if not candidate_keypoints:
            # Keep the 2-D shapes so downstream column indexing still works
            return np.empty((0, 3)), np.empty((0, 2), dtype=int), np.empty((0,), dtype=int)

        candidate_keypoints = np.array(candidate_keypoints)
        candidate_pixels = np.array(candidate_pixels)
        candidate_rigid_group_ids = np.array(candidate_rigid_group_ids)

        return candidate_keypoints, candidate_pixels, candidate_rigid_group_ids

    def _merge_clusters(self, candidate_keypoints):
        """Merge nearby keypoint candidates with MeanShift.

        Args:
            candidate_keypoints: 3D coordinates of the candidates, [K, 3]

        Returns:
            merged_indices: list with, for every MeanShift cluster centre, the
                index of the closest original candidate (empty when K == 0).
        """
        # Nothing to merge; fitting MeanShift on zero samples would raise
        if len(candidate_keypoints) == 0:
            return []

        self.mean_shift.fit(candidate_keypoints)
        cluster_centers = self.mean_shift.cluster_centers_
        # Map every cluster centre back to the nearest original candidate
        merged_indices = []
        for center in cluster_centers:
            dist = np.linalg.norm(candidate_keypoints - center, axis=-1)
            merged_indices.append(np.argmin(dist))
        return merged_indices
