In [None]:
import numpy as np
import open3d as o3d
import torch

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

from superprimitive_fusion.scanner import capture_spherical_scans
from superprimitive_fusion.utils import bake_uv_to_vertex_colours
from superprimitive_fusion.mesh_fusion_utils import get_mesh_components, show_mesh_boundaries
from superprimitive_fusion.mesh_fusion import fuse_meshes

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [None]:
# Device handed to the SAM2 model builder further down.
device0 = torch.device("cuda:0")

# Enter a bfloat16 autocast context without ever leaving it, so it is still
# active for the code that runs after this cell.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# On GPUs with compute capability major version >= 8, let matmul and cuDNN
# kernels use TF32.
tf32_capable = torch.cuda.get_device_properties(0).major >= 8
if tf32_capable:
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# Model config and checkpoint for the SAM 2.1 model used in this notebook.
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_checkpoint = "../models/SAM2/sam2.1_hiera_large.pt"

sam2 = build_sam2(
    model_cfg,
    sam2_checkpoint,
    device=device0,
    apply_postprocessing=False,
)

# Explicit settings for the automatic mask generator, collected in one place.
# (Calling SAM2AutomaticMaskGenerator(sam2) alone would use the library
# defaults instead.)
amg_settings = {
    "points_per_side": 64,
    "points_per_batch": 128,
    "pred_iou_thresh": 0.7,
    "stability_score_thresh": 0.92,
    "stability_score_offset": 0.7,
    "crop_n_layers": 1,
    "box_nms_thresh": 0.7,
    "crop_n_points_downscale_factor": 2,
    "min_mask_region_area": 25.0,
    "use_m2m": True,
}
mask_generator = SAM2AutomaticMaskGenerator(model=sam2, **amg_settings)

In [None]:
# You can place this in a new file or use it as a utility function.

import torch
from torchvision.ops.boxes import batched_nms, box_area

from segment_anything.utils.amg import MaskData, batched_mask_to_box
import tool.point_utils as point_utils

def infer_sam_masks_batch(model_sam, image, points, logits=True):
    """Prompt SAM with one positive point per segment and return its outputs.

    Args:
        model_sam: Predictor object providing ``set_image``, ``transform`` and
            ``predict_torch``.
        image: Image whose ``shape`` unpacks as ``(H, W, channels)``.
        points: Tensor of shape ``[N, 2]``, in the same form as keypoints and
            normalised to the range ``[-1, 1]``.
        logits: Forwarded to ``predict_torch`` as ``return_logits``.

    Returns:
        Dict with the three values returned by ``predict_torch`` under the
        keys ``'masks'``, ``'iou_pred'`` and ``'lowres'``.
    """
    prompt_device = points.device
    prompt_count = points.shape[0]

    model_sam.set_image(image)

    # Image size after the predictor's preprocessing resize.
    img_h, img_w, _ = image.shape
    resizer = model_sam.transform
    sam_h, sam_w = resizer.get_preprocess_shape(img_h, img_w, resizer.target_length)

    # Map the normalised points onto the resized image, then reverse the last
    # axis (presumably (row, col) -> (x, y) as SAM expects -- confirm).
    pixel_points = point_utils.denormalise_coordinates(points, (sam_h, sam_w)).flip(-1)

    # SAM label ids for positive / negative points: a single positive (1) per
    # segment is used here.
    positive_labels = torch.ones(
        (prompt_count, 1), dtype=torch.int64, device=prompt_device
    )

    masks, iou_predictions, lowres = model_sam.predict_torch(
        pixel_points[:, None],
        positive_labels,
        multimask_output=True,
        return_logits=logits,
    )

    return dict(masks=masks, iou_pred=iou_predictions, lowres=lowres)

def _keep_after_box_nms(mask_data, box_nms_thresh, rank_by_box_size):
    """Return the indices of the entries in ``mask_data`` that survive box NMS.

    Boxes are ranked by inverse box area when ``rank_by_box_size`` is true
    (smaller boxes win), otherwise by the stored ``'iou_preds'``.
    """
    boxes = mask_data['boxes']
    scores = 1 / box_area(boxes) if rank_by_box_size else mask_data['iou_preds']
    return batched_nms(
        boxes.float(),
        scores,
        torch.zeros_like(boxes[:, 0]),  # every box is put in the same category
        iou_threshold=box_nms_thresh,
    )


def predict_sam_segmentation_masks(
    sam_model,
    image,
    sam_config,
    keypoints=None,
    num_pts=300,
    num_pts_active=50,
    device=torch.device('cuda:0')
):
    """
    Standalone function to predict segmentation masks from an image using SAM.

    The image is prompted with ``keypoints``, one mask is selected per
    keypoint, and overlapping boxes are optionally suppressed. When
    ``num_pts_active > 0`` a second round is run on points returned by
    ``active_sample_pos`` for the current mask coverage, and both rounds are
    concatenated.

    Args:
        sam_model: A SamPredictor instance (passed to ``infer_sam_masks_batch``).
        image: Input image as a numpy array (H, W, 3).
        sam_config: Dict with SAM config options. Keys read here, with their
            defaults: ``select_smallest`` (True), ``nms`` (True),
            ``box_nms_thresh`` (0.7), ``iou_threshold`` (0.88),
            ``stability_threshold`` (0.95), ``filter_by_box_size`` (False).
        keypoints: Optional tensor of normalized keypoints [-1, 1], shape (N, 2).
        num_pts: Number of random keypoints to sample if keypoints is None.
        num_pts_active: Number of active sampling points; 0 or less skips the
            second round.
        device: Torch device for the randomly sampled keypoints.

    Returns:
        Dict with keys ``'masks'``, ``'boxes'``, ``'keypoints'`` (the keypoints
        whose masks were kept) and ``'iou_preds'``.
    """
    # Local import to avoid a circular import. It has to run before the first
    # call: importing a name anywhere inside a function makes that name local
    # to the whole function, so importing it further down would leave the
    # earlier call reading an unbound local (UnboundLocalError).
    from frontend.segment.mask_generation import active_sample_pos, smallest_good_mask_batch

    if keypoints is None:
        keypoints = torch.rand(num_pts, 2, device=device) * 2 - 1

    select_smallest = sam_config.get('select_smallest', True)
    nms = sam_config.get('nms', True)
    box_nms_thresh = sam_config.get('box_nms_thresh', 0.7)
    iou_threshold = sam_config.get('iou_threshold', 0.88)
    stability_score_thresh = sam_config.get('stability_threshold', 0.95)
    filter_by_box_size = sam_config.get('filter_by_box_size', False)

    def _segment(points):
        """Predict, select and (optionally) NMS-filter masks for ``points``."""
        raw = infer_sam_masks_batch(sam_model, image, points)
        selected = smallest_good_mask_batch(
            raw['masks'],
            raw['iou_pred'],
            iou_threshold=iou_threshold,
            stability_score_thresh=stability_score_thresh,
            select_smallest=select_smallest
        )
        # Keep only the prompt points whose mask passed the selection.
        kept_points = points[selected['keypoints_ids'], ...]
        mask_data = MaskData(**selected)
        if nms:
            keep_by_nms = _keep_after_box_nms(mask_data, box_nms_thresh, filter_by_box_size)
            mask_data.filter(keep_by_nms)
            kept_points = kept_points[keep_by_nms, ...]
        return mask_data, kept_points

    # 1-2. Initial mask prediction and NMS filtering
    masks, keypoints_final = _segment(keypoints)

    # 3. Active sampling for additional masks
    if num_pts_active > 0:
        # Pixels already covered by at least one kept mask.
        coverage_mask = masks['masks'].any(dim=0)
        sampled = active_sample_pos(coverage_mask[None], num_samples=num_pts_active)
        keypoints_active = sampled['normalised_coords'][0]
        masks_active, keypoints_active_kept = _segment(keypoints_active)
        keypoints_final = torch.cat([keypoints_final, keypoints_active_kept], dim=0)
        masks.cat(masks_active)

    # 4. Return masks and keypoints
    return {
        'masks': masks['masks'],
        'boxes': masks['boxes'],
        'keypoints': keypoints_final,
        'iou_preds': masks['iou_preds'],
    }

In [4]:
# Load the textured power-drill model with Open3D's post-processing enabled.
mesh = o3d.io.read_triangle_mesh("../data/power-drill/textured.obj", enable_post_processing=True)

# Return value is discarded, so this presumably writes per-vertex colours
# onto `mesh` in place -- confirm against superprimitive_fusion.utils.
bake_uv_to_vertex_colours(mesh)

mesh.compute_vertex_normals()

# `scale` is the mean of the coordinates of the minimal oriented bounding
# box's max bound.
# NOTE(review): neither `bb` nor `scale` is referenced by the cells visible here.
bb = mesh.get_minimal_oriented_bounding_box()
scale = np.mean(bb.get_max_bound())

In [6]:
from superprimitive_fusion.mesh_fusion_utils import sanitise_mesh, colour_transfer, count_inconsistent_normal_pairs
from superprimitive_fusion.debug_utils import debug_mesh
import pymeshfix

# Capture settings: six views, with every dropout / error term set to zero.
scan_settings = dict(
    num_views=6,
    radius=0.3,
    width_px=360,
    height_px=240,
    fov=70,
    dropout_rate=0,
    depth_error_std=0.0,
    translation_error_std=0,
    rotation_error_std_degs=0,
    k=10,
    sampler="fibonacci",
)

# No mask generator is passed for this capture.
scans = capture_spherical_scans(mesh=mesh, mask_generator=None, **scan_settings)

# Only the 'mesh' entry of each scan is needed for fusion.
meshes = list(map(lambda scan: scan['mesh'], scans))

# o3d.visualization.draw_geometries(
#     meshes,
#     window_name="Virtual scan",
#     front=[0.3, 1, 0],
#     lookat=[0, 0, 0],
#     up=[0, 0, 1],
#     zoom=0.7,
# )

In [7]:
# Number of scans to fuse. Taken from the capture itself so it cannot drift
# from the number of scans actually produced; lower it to fuse only the first
# `j` scans.
j = len(meshes)

# Start from the first scan. `fused_mesh` is pre-set so the final viewer call
# still has a mesh to show when there is nothing to fuse (a single scan).
repaired = meshes[0]
fused_mesh = repaired
for i in range(1, j):
    # Fuse the running (repaired) result with the next scan.
    fused_mesh = fuse_meshes(
        repaired,
        meshes[i],
        h_alpha=3,
        trilat_iters=2,
        shift_all=False,
        fill_holes=False,
    )
    V0 = np.asarray(fused_mesh.vertices)
    F0 = np.asarray(fused_mesh.triangles)
    C0 = np.asarray(fused_mesh.vertex_colors)

    # NOTE(review): C0 was read before sanitising, while V0 is replaced by the
    # sanitised vertices. Pairing them in colour_transfer below is only valid
    # if sanitise_mesh keeps vertex count and order -- confirm.
    V0, F0 = sanitise_mesh(V0, F0)

    # Repair the fused surface without joining or discarding components.
    meshfix = pymeshfix.MeshFix(V0, F0)
    meshfix.repair(verbose=False, joincomp=False, remove_smallest_components=False)
    V1, F1 = meshfix.points, meshfix.faces

    # Carry the colours over to the repaired vertex set.
    C1 = colour_transfer(V0, C0, V1)

    # Rebuild an Open3D mesh; it is the input to the next fusion step.
    repaired = o3d.geometry.TriangleMesh()
    repaired.vertices = o3d.utility.Vector3dVector(V1)
    repaired.triangles = o3d.utility.Vector3iVector(F1)
    repaired.vertex_colors = o3d.utility.Vector3dVector(C1)
    repaired.compute_vertex_normals()

    # Show the intermediate result after each fusion step.
    o3d.visualization.draw_geometries([repaired])


# Show the last fused mesh as returned by fuse_meshes (before repair).
o3d.visualization.draw_geometries(
    [fused_mesh],
    window_name="Virtual scan",
    front=[0.3, 1, 0],
    lookat=[0, 0, 0],
    up=[0, 0, 1],
    zoom=0.7,
)

In [None]:
# Diagnostics for the last fused mesh, followed by a PyTMesh-based repair.

# --- Open3D topology checks on the fused mesh ---
print('start stats')
print("edge manifold:", fused_mesh.is_edge_manifold(allow_boundary_edges=True))
print("vertex manifold:", fused_mesh.is_vertex_manifold())
print("self-intersecting:", fused_mesh.is_self_intersecting())
print("watertight:", fused_mesh.is_watertight())

# --- The same geometry checked through trimesh ---
# process=False is passed so trimesh is asked not to process the arrays.
# NOTE(review): "watertight:" is printed twice (Open3D above, trimesh below)
# with the same label.
import trimesh
V = np.asarray(fused_mesh.vertices)
F = np.asarray(fused_mesh.triangles)
tm = trimesh.Trimesh(vertices=V, faces=F, process=False)
print("winding consistent:", tm.is_winding_consistent)
print("watertight:", tm.is_watertight)
print("euler number:", tm.euler_number)  # sanity: wildly negative often = non-manifold mess
print('finish stats')

# Show the fused mesh before any repair in this cell.
o3d.visualization.draw_geometries(
    [fused_mesh],
    window_name="Virtual scan",
    front=[0.3, 1, 0],
    lookat=[0, 0, 0],
    up=[0, 0, 1],
    zoom=0.7,
)

# --- Repair pass ---
new_points = np.asarray(fused_mesh.vertices)
fused_mesh_triangles = np.asarray(fused_mesh.triangles)
new_colours = np.asarray(fused_mesh.vertex_colors)

# NOTE(review): C0 keeps the colours of the unsanitised mesh while V0 comes
# from sanitise_mesh; this assumes sanitise_mesh keeps vertex count and
# order -- confirm.
V0, F0 = sanitise_mesh(new_points, fused_mesh_triangles)
C0 = new_colours

# Alternative repair via MeshFix, currently disabled.
# meshfix = pymeshfix.MeshFix(V0, F0)
# meshfix.repair(verbose=True, joincomp=False, remove_smallest_components=False)
# V1, F1 = meshfix.points, meshfix.faces

# Active repair path: load the arrays into a PyTMesh and run clean() only;
# boundary filling and component joining are left commented out.
mf = pymeshfix.PyTMesh(True)
mf.load_array(V0, F0)
# mf.fill_small_boundaries(nbe=5)
mf.clean()
# mf.join_closest_components()
V1, F1 = mf.return_arrays()

# Alternative: fill boundaries only, currently disabled.
# mf = pymeshfix.PyTMesh(False)
# mf.load_array(V0, F0)

# mf.fill_small_boundaries(nbe=100)

# V1, F1 = mf.return_arrays()

# Carry the colours over to the repaired vertex set.
C1 = colour_transfer(V0, C0, V1)

# Rebuild an Open3D mesh from the repaired arrays and show it.
repaired = o3d.geometry.TriangleMesh()
repaired.vertices = o3d.utility.Vector3dVector(V1)
repaired.triangles = o3d.utility.Vector3iVector(F1)
repaired.vertex_colors = o3d.utility.Vector3dVector(C1)
repaired.compute_vertex_normals()

o3d.visualization.draw_geometries([repaired])

# Components are computed on the unrepaired fused mesh, with show=True.
components = get_mesh_components(fused_mesh, show=True)

# Further statistics and boundary views, currently disabled.
# n_components = len(components)
# tris_added = len(F1) - len(fused_mesh_triangles)
# n_inconsistent_pairs_holes = count_inconsistent_normal_pairs(fused_mesh, show=False)
# n_inconsistent_pairs_filled = count_inconsistent_normal_pairs(repaired, show=False)
# show_mesh_boundaries(fused_mesh, show=True, edges=True)
# show_mesh_boundaries(repaired, show=True, edges=True)
# show_mesh_boundaries(repaired, show=True, edges=True, base_mesh=fused_mesh)

# print(f'{tris_added} triangles were added to fill holes')
# print(f'{n_inconsistent_pairs_holes} (w/ holes) is the number of triangle neighbours with inconsistent normals')
# print(f'{n_inconsistent_pairs_filled} (holes filled) is the number of triangle neighbours with inconsistent normals')