# In [None]:
import numpy as np
import torch
import cv2
from torch.nn.functional import interpolate
from kmeans_pytorch import kmeans
from sklearn.cluster import MeanShift
from PIL import Image


from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

def sam2_generate_mask(image_path, model_cfg, checkpoint, device="cuda"):
    """Segment an image with SAM2 and return an integer label map.

    Args:
        image_path: path of the image to segment; it is converted to RGB.
        model_cfg: SAM2 model config, passed to ``build_sam2``.
        checkpoint: SAM2 checkpoint path, passed to ``build_sam2``.
        device: torch device string the model is built on.

    Returns:
        (image_np, masks): ``image_np`` is the RGB image as a numpy array and
        ``masks`` is an (H, W) int32 array where 0 marks pixels not covered by
        any generated mask and 1..K identify the K masks.
    """
    # Context manager releases the underlying file handle deterministically.
    with Image.open(image_path) as img:
        image_np = np.array(img.convert("RGB"))

    sam2 = build_sam2(model_cfg, checkpoint, device=torch.device(device), apply_postprocessing=False)
    mask_generator = SAM2AutomaticMaskGenerator(model=sam2, use_m2m=False)
    anns = mask_generator.generate(image_np)

    # Masks may overlap (e.g. a whole object and its parts). Painting the
    # largest first means smaller masks, painted later, stay visible instead
    # of being erased by an enclosing mask that happened to come later.
    anns = sorted(anns, key=lambda ann: int(np.count_nonzero(ann["segmentation"])), reverse=True)

    h, w, _ = image_np.shape
    masks = np.zeros((h, w), dtype=np.int32)
    for idx, ann in enumerate(anns, start=1):
        masks[ann["segmentation"]] = idx  # unique ID per mask region
    return image_np, masks

def filter_points_by_bounds(points, bounds_min, bounds_max, strict=True):
    """Return a boolean mask selecting points inside an axis-aligned box.

    Args:
        points: array of shape (N, 3).
        bounds_min: length-3 lower corner of the box (array-like, any numeric dtype).
        bounds_max: length-3 upper corner of the box (array-like, any numeric dtype).
        strict: if False, the box is enlarged by 10% of its extent on both
            sides in x and y, and downward by 10% in z (the upper z bound is
            left unchanged).

    Returns:
        (N,) boolean numpy array; True where the point lies inside the
        (possibly enlarged) box. Bounds are inclusive.
    """
    assert points.shape[1] == 3, "points must be (N, 3)"
    # Float copies: the caller's arrays are never mutated, and integer bounds
    # (e.g. np.array([-1, -1, -1])) would otherwise reject fractional margins.
    lo = np.array(bounds_min, dtype=np.float64)
    hi = np.array(bounds_max, dtype=np.float64)
    if not strict:
        # Measure the extent BEFORE moving either bound, so both sides get
        # the same 10% margin.
        margin = 0.1 * (hi - lo)
        lo[:2] -= margin[:2]
        hi[:2] += margin[:2]
        lo[2] -= margin[2]
    return np.all((points >= lo) & (points <= hi), axis=1)


def preprocess(rgb, points, masks, patch_size=14):
    """Prepare the image and masks for patch-based feature extraction.

    The integer label map is split into one boolean mask per distinct value
    returned by ``np.unique`` (so label 0, if present, also gets a mask).
    The image is resized so both sides are multiples of ``patch_size``.

    Args:
        rgb: (H, W, C) image array; it is divided by 255 below, so it is
            presumably in the 0-255 range.
        points: passed through unchanged.
        masks: integer label map, compared elementwise against each label.
        patch_size: the resized image's sides are floored to multiples of this.

    Returns:
        (transformed_rgb, rgb, points, binary_masks, shape_info):
        ``transformed_rgb`` is the resized image as float32 scaled by 1/255
        (no mean/std normalisation is applied); ``rgb`` and ``points`` are the
        inputs; ``binary_masks`` is the list of boolean masks; ``shape_info``
        holds 'img_h', 'img_w', 'patch_h', 'patch_w'.
    """
    binary_masks = [masks == label for label in np.unique(masks)]

    img_h, img_w, _ = rgb.shape
    patch_h, patch_w = img_h // patch_size, img_w // patch_size

    # cv2.resize takes (width, height); each side shrinks to the nearest
    # lower multiple of patch_size.
    target_size = (patch_w * patch_size, patch_h * patch_size)
    transformed_rgb = cv2.resize(rgb, target_size).astype(np.float32) / 255.0

    shape_info = dict(img_h=img_h, img_w=img_w, patch_h=patch_h, patch_w=patch_w)
    return transformed_rgb, rgb, points, binary_masks, shape_info


@torch.inference_mode()
@torch.amp.autocast('cuda')
def get_features(dinov2, transformed_rgb, shape_info, device):
    """Extract dense per-pixel DINOv2 features.

    Runs the model on the resized image and bilinearly upsamples the patch
    token grid back to the original image resolution. Executes under
    ``torch.inference_mode`` and CUDA autocast (so on CUDA the output is
    likely reduced precision — confirm for your torch version).

    Args:
        dinov2: model whose ``forward_features(x)`` returns a dict with key
            'x_norm_patchtokens' holding patch_h * patch_w tokens (required by
            the reshape below).
        transformed_rgb: (H', W', 3) numpy image, presumably sized
            (patch_h * patch_size, patch_w * patch_size).
        shape_info: dict with 'img_h', 'img_w', 'patch_h', 'patch_w'.
        device: device the input tensor is moved to.

    Returns:
        Tensor of shape (img_h * img_w, C), one feature row per pixel in
        row-major order.
    """
    out_size = (shape_info['img_h'], shape_info['img_w'])
    grid_size = (shape_info['patch_h'], shape_info['patch_w'])

    # HWC numpy image -> NCHW batch of one on the target device.
    batch = torch.from_numpy(transformed_rgb).permute(2, 0, 1)[None].to(device)
    tokens = dinov2.forward_features(batch)['x_norm_patchtokens']

    # Lay the token sequence out as an NCHW feature map, then upsample it.
    feature_map = tokens.reshape(1, *grid_size, -1).permute(0, 3, 1, 2)
    upsampled = interpolate(feature_map, size=out_size, mode='bilinear')

    # (1, C, H, W) -> (H * W, C)
    channels = upsampled.shape[1]
    return upsampled[0].permute(1, 2, 0).reshape(-1, channels)


def _minmax_normalize(values, eps=1e-12):
    """Scale each column of a 2-D tensor to [0, 1].

    Columns with (near-)zero range map to 0 instead of NaN/inf, which would
    otherwise poison the k-means input.
    """
    col_min = values.min(0)[0]
    col_range = values.max(0)[0] - col_min
    return (values - col_min) / col_range.clamp_min(eps)


def cluster_features(points, features_flat, masks, config, device):
    """Propose candidate keypoints by clustering features inside each mask.

    For every binary mask covering at most ``config['max_mask_ratio']`` of the
    image, the per-pixel features are projected onto their top-3 uncentred
    principal directions and min-max normalised, then concatenated with the
    min-max normalised points. The result is clustered with k-means into
    ``config['num_candidates_per_mask']`` clusters. Within each cluster, the
    member whose PCA feature is closest to the feature part of the cluster
    centre becomes a candidate.

    Args:
        points: array indexed by each 2-D binary mask to give per-pixel points
            (presumably (H, W, 3) — TODO confirm with caller).
        features_flat: tensor whose rows align with ``binary_mask.reshape(-1)``.
        masks: iterable of 2-D boolean masks; a mask's index is its group id.
        config: dict with 'max_mask_ratio' and 'num_candidates_per_mask'.
        device: device passed to k-means and used for the cluster centres.

    Returns:
        (candidate_keypoints, candidate_pixels, candidate_rigid_group_ids):
        numpy arrays of the selected points, their (row, col) pixel indices
        from ``np.argwhere``, and the index of the mask each came from.
        When nothing is selected they have shapes (0, 3), (0, 2) and (0,).
    """
    num_clusters = config['num_candidates_per_mask']
    # pca_lowrank only yields 3 components from >= 3 rows, and k-means needs
    # at least as many samples as clusters; smaller masks are skipped.
    min_pixels = max(num_clusters, 3)

    candidate_keypoints, candidate_pixels, candidate_rigid_group_ids = [], [], []

    for rigid_group_id, binary_mask in enumerate(masks):
        if np.mean(binary_mask) > config['max_mask_ratio']:
            continue
        feature_pixels = np.argwhere(binary_mask)
        if len(feature_pixels) < min_pixels:
            continue

        feature_points = points[binary_mask]
        obj_features = features_flat[binary_mask.reshape(-1)].double()

        _, _, v = torch.pca_lowrank(obj_features, center=False)
        features_pca = _minmax_normalize(obj_features @ v[:, :3])
        feature_points_torch = _minmax_normalize(
            torch.tensor(feature_points, dtype=features_pca.dtype, device=features_pca.device)
        )
        X = torch.cat([features_pca, feature_points_torch], dim=-1)

        cluster_ids, cluster_centers = kmeans(
            X=X,
            num_clusters=num_clusters,
            distance='euclidean',
            device=device,
        )
        cluster_centers = cluster_centers.to(device)
        cluster_ids = cluster_ids.cpu()

        for cluster_id in range(num_clusters):
            member_mask = cluster_ids == cluster_id
            if not bool(member_mask.any()):
                # A k-means cluster can end up empty; argmin over no members would raise.
                continue
            member_np = member_mask.numpy()
            member_points = feature_points[member_np]
            member_pixels = feature_pixels[member_np]
            member_features = features_pca[member_mask.to(features_pca.device)]

            # The first 3 dims of each centre are the PCA-feature part of X.
            cluster_center = cluster_centers[cluster_id][:3]
            dist = torch.norm(member_features - cluster_center, dim=-1)
            closest_idx = int(torch.argmin(dist))

            candidate_keypoints.append(member_points[closest_idx])
            candidate_pixels.append(member_pixels[closest_idx])
            candidate_rigid_group_ids.append(rigid_group_id)

    if not candidate_keypoints:
        # Keep column dimensions so callers can index [:, k] and filter safely.
        return (
            np.empty((0, 3)),
            np.empty((0, 2), dtype=np.intp),
            np.empty((0,), dtype=np.int64),
        )
    return (
        np.array(candidate_keypoints),
        np.array(candidate_pixels),
        np.array(candidate_rigid_group_ids),
    )


def merge_clusters(candidate_keypoints, min_dist_bt_keypoints, n_jobs=32):
    """Merge nearby candidate keypoints with MeanShift.

    Fits ``sklearn.cluster.MeanShift`` (bandwidth ``min_dist_bt_keypoints``,
    bin seeding) on the candidates. For each resulting cluster centre, it
    picks the index of the nearest candidate.

    Args:
        candidate_keypoints: (N, D) array of candidate positions.
        min_dist_bt_keypoints: MeanShift bandwidth.
        n_jobs: parallel jobs for MeanShift (default keeps the previous value).

    Returns:
        List of indices into ``candidate_keypoints``, one per cluster centre,
        in centre order with duplicates removed. The list is empty for empty
        input.
    """
    candidate_keypoints = np.asarray(candidate_keypoints)
    if len(candidate_keypoints) == 0:
        # MeanShift cannot be fitted on zero samples.
        return []

    ms = MeanShift(bandwidth=min_dist_bt_keypoints, bin_seeding=True, n_jobs=n_jobs)
    ms.fit(candidate_keypoints)
    centers = ms.cluster_centers_

    # (num_centers, num_candidates) distances; argmin per centre.
    dists = np.linalg.norm(candidate_keypoints[None, :, :] - centers[:, None, :], axis=-1)
    nearest = dists.argmin(axis=1).tolist()
    # Two centres can share the same nearest candidate; keep it once,
    # preserving first-seen order, to avoid duplicate keypoints.
    return list(dict.fromkeys(nearest))


def project_keypoints_to_img(rgb, candidate_pixels, box_size=30):
    """Draw a numbered label box at each keypoint pixel.

    Args:
        rgb: image array; it is copied, not modified.
        candidate_pixels: sequence of (row, col) pixel coordinates.
        box_size: side length in pixels of each label box (default 30).

    Returns:
        A copy of ``rgb``. For keypoint i it has a filled white box with a
        black 2-px border centred on the pixel, and the text ``str(i)`` drawn
        in colour (255, 0, 0) (red if the image is RGB).
    """
    projected = rgb.copy()
    half = box_size // 2
    for idx, pixel in enumerate(candidate_pixels):
        # OpenCV wants plain Python ints in (x, y) order; pixels are
        # (row, col) and may be numpy scalars.
        row, col = int(pixel[0]), int(pixel[1])
        top_left = (col - half, row - half)
        bottom_right = (col + half, row + half)
        cv2.rectangle(projected, top_left, bottom_right, (255, 255, 255), -1)
        cv2.rectangle(projected, top_left, bottom_right, (0, 0, 0), 2)

        text = str(idx)
        # Shift left roughly half the text width so the label is centred.
        org = (col - 7 * len(text), row + 7)
        cv2.putText(projected, text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
    return projected


def get_keypoints(rgb, points, masks, dinov2, config):
    """Propose keypoints for a scene and render them on the image.

    Pipeline:
      1. Split the label map into binary masks.
      2. Extract DINOv2 features.
      3. Cluster each mask into candidate keypoints.
      4. Drop candidates outside ``config['bounds_min']``/``config['bounds_max']``.
      5. Merge nearby candidates with MeanShift
         (bandwidth ``config['min_dist_bt_keypoints']``).
      6. Sort the result by pixel column, with ties broken by row.

    Args:
        rgb: image array.
        points: per-pixel points aligned with ``rgb`` (presumably (H, W, 3)).
        masks: integer label map.
        dinov2: DINOv2 model exposing ``forward_features``.
        config: dict with 'device', 'bounds_min', 'bounds_max',
            'min_dist_bt_keypoints', 'max_mask_ratio',
            'num_candidates_per_mask', and optionally 'seed'.

    Returns:
        (candidate_keypoints, projected): a (K, 3) array of keypoints (K may
        be 0) and a copy of ``rgb`` with the K numbered labels drawn.
    """
    device = torch.device(config['device'])

    # k-means initialisation is random; config['seed'] was previously ignored,
    # so results changed run to run. Seeding the global NumPy/torch RNGs
    # presumably covers kmeans_pytorch's initialisation — TODO confirm.
    seed = config.get('seed')
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    transformed_rgb, rgb, points, masks, shape_info = preprocess(
        rgb, points, masks, patch_size=14
    )
    features_flat = get_features(dinov2, transformed_rgb, shape_info, device)
    keypoints, pixels, group_ids = cluster_features(
        points, features_flat, masks, config, device
    )

    # An empty clustering result may arrive as shape (0,); normalise shapes
    # so the bounds filter and column indexing below work for zero candidates.
    keypoints = np.asarray(keypoints).reshape(-1, 3)
    pixels = np.asarray(pixels).reshape(-1, 2)
    group_ids = np.asarray(group_ids).reshape(-1)

    keep = filter_points_by_bounds(
        keypoints, np.array(config['bounds_min']), np.array(config['bounds_max']), strict=True
    )
    keypoints, pixels, group_ids = keypoints[keep], pixels[keep], group_ids[keep]

    # MeanShift cannot be fitted on zero samples.
    if len(keypoints) > 0:
        merged = merge_clusters(keypoints, config['min_dist_bt_keypoints'])
        keypoints, pixels, group_ids = keypoints[merged], pixels[merged], group_ids[merged]

    # np.lexsort uses the LAST key as primary: sort by column, then row.
    order = np.lexsort((pixels[:, 0], pixels[:, 1]))
    keypoints, pixels, group_ids = keypoints[order], pixels[order], group_ids[order]

    projected = project_keypoints_to_img(rgb, pixels)
    return keypoints, projected

# 1. Load the DINOv2 feature extractor and the SAM2 paths.
DEVICE = "cuda"
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').eval().to(DEVICE)
sam2_checkpoint = "./models/sam2/sam2_hiera_base_plus.pt"
model_cfg_0 = "configs/sam2/sam2_hiera_b+.yaml"

# 2. Segment the image with SAM2 and load the matching per-pixel points.
rgb, masks = sam2_generate_mask(
    "./img/test.png",
    model_cfg=model_cfg_0,
    checkpoint=sam2_checkpoint,
    device=DEVICE,
)
points = np.load("points.npy")  # (H, W, 3) point cloud aligned with the image
if points.shape[:2] != rgb.shape[:2]:
    raise ValueError(
        f"points.npy has shape {points.shape}, expected {rgb.shape[:2]} + (3,) to match the image"
    )

# 3. Configuration.
config = {
    'device': DEVICE,
    'bounds_min': [-1, -1, -1],
    'bounds_max': [1, 1, 1],
    'min_dist_bt_keypoints': 0.05,
    'seed': 42,
    'max_mask_ratio': 0.5,
    'num_candidates_per_mask': 5,
}

# 4. Propose keypoints and save the visualisation.
candidate_keypoints, projected = get_keypoints(rgb, points, masks, dinov2, config)

print(candidate_keypoints.shape)  # (N, 3)
# projected is RGB; flip to BGR for OpenCV. imwrite signals failure by
# returning False rather than raising.
if not cv2.imwrite("projected.png", projected[:, :, ::-1]):
    raise RuntimeError("failed to write projected.png")

# In [None]:
# Re-run keypoint proposal with the models, inputs and config from the previous cell.
candidate_keypoints, projected = get_keypoints(rgb, points, masks, dinov2, config)

print(candidate_keypoints.shape)  # (N, 3)
# projected is RGB; flip to BGR for OpenCV. imwrite signals failure by
# returning False rather than raising.
if not cv2.imwrite("projected.png", projected[:, :, ::-1]):
    raise RuntimeError("failed to write projected.png")