In [4]:
import json
import numpy as np
import open3d as o3d
from sklearn.neighbors import NearestNeighbors
import os

In [5]:
def find_nearest_points_from_json(json_file: str,
                                  model_id: str,
                                  stl_dir: str,
                                  max_points: int = 8192,
                                  sampling_method: str = 'uniform',
                                  normalize: bool = True,
                                  num_keypoints: int = 2) -> list:
    """
    Load GT coords for a model_id from JSON, sample its mesh, find nearest point indices.

    Args:
        json_file: Path to keypoints.json.
        model_id: The mesh ID to lookup.
        stl_dir: Directory containing .stl files.
        max_points: Points to sample.
        sampling_method: 'uniform' or 'poisson'.
        normalize: Normalize mesh like training.
        num_keypoints: Number of keypoints to use.

    Returns:
        List of dicts [{'index': int, 'point': [x,y,z], 'gt_coord': [x,y,z]}]
    """

    # === 1. Load annotations ===
    with open(json_file, 'r') as f:
        annotations = json.load(f)

    match = None
    for entry in annotations:
        if entry['model_id'] == model_id:
            match = entry
            break

    if match is None:
        raise ValueError(f"Model ID {model_id} not found in {json_file}")

    # Pull GT keypoint coordinates (only use first num_keypoints)
    gt_coords = [kp['xyz'] for kp in match['keypoints'][:num_keypoints]]

    # === 2. Load mesh ===
    mesh_path = os.path.join(stl_dir, f"{model_id}.stl")
    mesh = o3d.io.read_triangle_mesh(mesh_path)
    if len(mesh.vertices) == 0:
        raise ValueError(f"Failed to load mesh: {mesh_path}")

    # === 3. Sample points ===
    num_sample_points = max_points * 2
    if sampling_method == 'uniform':
        pcd = mesh.sample_points_uniformly(number_of_points=num_sample_points)
    elif sampling_method == 'poisson':
        pcd = mesh.sample_points_poisson_disk(number_of_points=num_sample_points)
    else:
        raise ValueError(f"Unsupported sampling method: {sampling_method}")

    points = np.asarray(pcd.points)

    # === 4. Normalize ===
    if normalize:
        centroid = np.mean(points, axis=0)
        points = points - centroid
        max_distance = np.max(np.linalg.norm(points, axis=1))
        points = points / max_distance

    # === 5. Nearest neighbors ===
    nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(points)

    nearest = []
    for coord in gt_coords:
        query = np.array(coord)
        if normalize:
            query = (query - centroid) / max_distance
        distances, indices = nbrs.kneighbors([query])
        idx = indices[0][0]
        nearest.append({
            'index': idx,
            'point': points[idx].tolist(),
            'gt_coord': coord
        })

    return nearest

In [30]:
result = find_nearest_points_from_json(
    json_file='knee_annotations/7-2-25/knee_points_4_5_flipped.json',
    model_id='12216',
    stl_dir='scans_3/',
    max_points=8192,
    sampling_method='uniform',
    normalize=True,
    num_keypoints=2
)

print('gt_coord')
for i in result:
    print(*i['gt_coord'])
print('--------------------------------')
print('point')
for i in result:
    print(*i['point'])

gt_coord
-0.0030665805392982304 0.11286537717454001 0.296266566188385
0.037489796848741874 -0.19636781349888416 0.3651748456235039
--------------------------------
point
0.10753322063502067 0.22316784320738903 -0.0814913912146363
0.2147687742965619 -0.6241310194616749 0.10349931818217269


In [31]:
import open3d as o3d

# Create Open3D point cloud
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)

# GT point as a sphere
gt_sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.01)
gt_sphere.paint_uniform_color([1, 0, 0])  # red
gt_sphere.translate(gt_coord)

# Nearest point as a sphere
pred_sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.01)
pred_sphere.paint_uniform_color([0, 1, 0])  # green
pred_sphere.translate(point)

o3d.visualization.draw_geometries([pcd, gt_sphere, pred_sphere])


NameError: name 'points' is not defined