In [1]:
import numpy as np
import open3d as o3d
import torch

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

from superprimitive_fusion.scanner import capture_spherical_scans
from superprimitive_fusion.utils import bake_uv_to_vertex_colours, distinct_colours
from superprimitive_fusion.mesh_fusion import fuse_meshes

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [None]:
# All CUDA work in this notebook targets the first GPU.
device0 = torch.device("cuda:0")

# Enter bfloat16 autocast once and never exit it: the context stays active for
# every subsequent cell executed in this kernel session.
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

# TF32 matmul/cuDNN paths are only enabled on GPUs with compute capability >= 8.
gpu_major = torch.cuda.get_device_properties(0).major
if gpu_major >= 8:
    for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        backend.allow_tf32 = True

In [8]:
# SAM2.1 Hiera-Large weights and the matching model config.
sam2_checkpoint = "../models/SAM2/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

sam2 = build_sam2(
    model_cfg,
    sam2_checkpoint,
    device=device0,
    apply_postprocessing=False,
)

# Automatic mask generator with a dense prompt grid and tuned filtering.
# NOTE(review): the scanning cell passes mask_generator=None, so this generator
# is currently not used anywhere in the notebook.
mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2,
    # Prompt sampling
    points_per_side=64,
    points_per_batch=128,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    # Mask quality filtering
    pred_iou_thresh=0.7,
    stability_score_thresh=0.92,
    stability_score_offset=0.7,
    box_nms_thresh=0.7,
    min_mask_region_area=25.0,
    # Mask-to-mask refinement
    use_m2m=True,
)

In [11]:
mesh_path = "../data/mustard-bottle/textured.obj"
mesh = o3d.io.read_triangle_mesh(mesh_path, enable_post_processing=True)
# read_triangle_mesh does not raise on a missing/unreadable file: it only logs a
# warning and returns an empty mesh. Fail here with a clear message instead of
# deep inside a later geometry call.
if not mesh.has_triangles():
    raise RuntimeError(f"Failed to load a triangle mesh from {mesh_path!r}")

# Presumably transfers the UV texture colours onto per-vertex colours (per the
# helper's name) so downstream geometry carries colour without a texture.
bake_uv_to_vertex_colours(mesh)

mesh.compute_vertex_normals()

# NOTE(review): `scale` is not used by any cell visible here; the mean of the
# oriented box's max-bound coordinates is only a rough size proxy.
bb = mesh.get_minimal_oriented_bounding_box()
scale = np.mean(bb.get_max_bound())

In [17]:
# Capture `num_views` virtual scans of the mesh from viewpoints chosen by the
# "fibonacci" sampler (presumably on a sphere of the given radius, per the
# helper's name). Every noise/perturbation source below is switched off.
scans = capture_spherical_scans(
    mesh=mesh,
    mask_generator=None,  # no segmentation for these scans
    sampler="fibonacci",
    num_views=6,
    radius=0.3,
    # Virtual camera
    fov=70,
    width_px=360,
    height_px=240,
    k=10,  # NOTE(review): meaning of `k` is defined inside the scanner module
    # Noise / perturbation (all disabled)
    dropout_rate=0,
    depth_error_std=0.0,
    translation_error_std=0,
    rotation_error_std_degs=0,
)

# Keep only the per-scan meshes.
meshes = [scan_result["mesh"] for scan_result in scans]

o3d.visualization.draw_geometries(
    meshes,
    window_name="Virtual scan",
    lookat=[0, 0, 0],
    front=[0.3, 1, 0],
    up=[0, 0, 1],
    zoom=0.7,
)

In [25]:
# Cut each scan mesh with an axis-aligned "null zone": every triangle touching
# the slab c - width <= coord <= c + width is dropped, so only triangles lying
# entirely on one side of the plane survive.
slices = []
axis = 'z'      # Choose from 'x', 'y', or 'z'
c = 0.0         # Plane location (e.g., z = 0.0)
width = 0.01    # Width of the null zone on each side of the plane

axis_index = {'x': 0, 'y': 1, 'z': 2}[axis]

# The loop variable is `scan_mesh`, not `mesh`: rebinding `mesh` here would
# clobber the textured model loaded earlier, so re-running the scanning cell
# afterwards would silently scan the wrong geometry.
for scan_mesh in meshes:
    # NOTE: mutates the scan mesh in place (adds triangle normals); kept because
    # later fusion cells receive these same mesh objects.
    scan_mesh.compute_triangle_normals()
    triangles = np.asarray(scan_mesh.triangles)
    vertices = np.asarray(scan_mesh.vertices)

    # Coordinate along the slicing axis for each vertex of each triangle,
    # shape (n_triangles, 3).
    tri_coords = vertices[triangles, axis_index]
    tri_min = tri_coords.min(axis=1)
    tri_max = tri_coords.max(axis=1)

    # Keep triangles entirely below or entirely above the null zone.
    outside_zone = (tri_max < c - width) | (tri_min > c + width)
    keep_triangles = triangles[outside_zone]

    # The full vertex array is kept because `keep_triangles` still indexes
    # into the original scan's vertices.
    new_mesh = o3d.geometry.TriangleMesh()
    new_mesh.vertices = o3d.utility.Vector3dVector(vertices)
    new_mesh.vertex_colors = scan_mesh.vertex_colors
    new_mesh.triangles = o3d.utility.Vector3iVector(keep_triangles)
    new_mesh.compute_vertex_normals()

    slices.append(new_mesh)

o3d.visualization.draw_geometries(
    slices,
    window_name="Virtual scan",
    front=[0.3, 1, 0],
    lookat=[0, 0, 0],
    up=[0, 0, 1],
    zoom=0.7,
)

In [27]:
# Inspect every full scan on its own, one viewer window at a time.
for scan_idx, _ in enumerate(meshes):
    o3d.visualization.draw_geometries(
        [meshes[scan_idx]],
        window_name="Virtual scan",
        lookat=[0, 0, 0],
        front=[0.3, 1, 0],
        up=[0, 0, 1],
        zoom=0.7,
    )

In [None]:
# Inspect the first two sliced scans individually (the pair fused in the next cell).
for slice_idx in (0, 1):
    o3d.visualization.draw_geometries(
        [slices[slice_idx]],
        window_name="Virtual scan",
        lookat=[0, 0, 0],
        front=[0.3, 1, 0],
        up=[0, 0, 1],
        zoom=0.7,
    )

In [29]:
# Incrementally fuse slices[1], ..., slices[j-1] into slices[0]
# (with j = 2, only the first two slices are fused). Hole filling is off.
j = 2
fusion_options = dict(h_alpha=5, trilat_iters=2, shift_all=False, fill_holes=False)

# NOTE(review): fused_mesh starts as an alias of slices[0]; if fuse_meshes
# mutates its first argument in place, re-running this cell is not idempotent.
fused_mesh = slices[0]
for slice_idx in range(1, j):
    fused_mesh = fuse_meshes(fused_mesh, slices[slice_idx], **fusion_options)

o3d.visualization.draw_geometries(
    [fused_mesh],
    window_name="Virtual scan",
    lookat=[0, 0, 0],
    front=[0.3, 1, 0],
    up=[0, 0, 1],
    zoom=0.7,
)

  clustered_overlap_nrms.append(merged_nrm / np.linalg.norm(merged_nrm))


In [None]:
# Incrementally fuse every full scan into the running result, with hole filling.
# The count comes from `meshes` rather than a hard-coded 6, so it stays in sync
# with `num_views` in the scanning cell (presumably one mesh per view).
j = len(meshes)
# NOTE(review): fused_mesh starts as an alias of meshes[0]; if fuse_meshes
# mutates its first argument in place, re-running this cell is not idempotent.
fused_mesh = meshes[0]
for i in range(1, j):
    fused_mesh = fuse_meshes(
        fused_mesh,
        meshes[i],
        h_alpha=5,
        trilat_iters=2,
        shift_all=False,
        fill_holes=True,
    )

o3d.visualization.draw_geometries(
    [fused_mesh],
    window_name="Virtual scan",
    front=[0.3, 1, 0],
    lookat=[0, 0, 0],
    up=[0, 0, 1],
    zoom=0.7,
)

In [18]:
import copy

# Split the fused mesh into connected components and colour each one
# distinctly, to see how many disconnected pieces remain after fusion.
# Use a dedicated name: rebinding the global `mesh` here would clobber the
# textured model loaded at the top of the notebook.
clustered_mesh = copy.deepcopy(fused_mesh)

# cluster_connected_triangles() returns (per-triangle cluster id, triangle
# count per cluster, area per cluster); only the per-triangle ids are used.
triangle_cluster_ids, _, _ = clustered_mesh.cluster_connected_triangles()
clust_ids = np.asarray(triangle_cluster_ids)
unique_clusters = np.unique(clust_ids)
print(f'There are {len(unique_clusters)} clusters of connected components')

# One RGB colour (components in [0, 1]) per cluster.
colours = np.asarray(distinct_colours(len(unique_clusters)), dtype=float)

# Every component is a copy of the same mesh, so its triangle array can be
# extracted once instead of once per component.
all_triangles = np.asarray(clustered_mesh.triangles)

components = []
# Index colours by position, not by cluster id, so non-contiguous ids cannot
# index past the end of `colours`.
for colour_idx, cluster_id in enumerate(unique_clusters):
    component = copy.deepcopy(clustered_mesh)
    component.triangles = o3d.utility.Vector3iVector(all_triangles[clust_ids == cluster_id])
    component.paint_uniform_color(colours[colour_idx])
    components.append(component)

o3d.visualization.draw_geometries(
    components
)

There are 22 clusters of connected components
