In [1]:
# Tensor/autograd library plus the atom3d mesh BVH and octree grid helpers.
import torch
from atom3d import MeshBVH
# NOTE(review): CubeGrid is not referenced in the cells visible here — confirm it is still needed.
from atom3d.grid import CubeGrid, OctreeIndexer

# trimesh builds the test geometry; pyvista renders it.
import trimesh
import pyvista as pv
# Start the Xvfb virtual framebuffer so pyvista can render on a machine
# without a display (this call must run before any plot is shown).
pv.start_xvfb()
# Embed plots in the notebook output using pyvista's 'html' backend.
pv.set_jupyter_backend('html')

import numpy as np


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [2]:
# Test geometry: an icosphere with one subdivision level, converted to
# float32 vertex positions and int32 triangle indices for atom3d.
meshtem = trimesh.creation.icosphere(subdivisions=1)
V = torch.tensor(meshtem.vertices, dtype=torch.float32)
F = torch.tensor(meshtem.faces, dtype=torch.int32)
bvh = MeshBVH(V, F, device='cuda')

# Target grid resolution per axis; the octree depth is log2 of it
# (32 -> 5 levels), and the octree is built over the mesh bounds.
res = 32
octree_depth = int(np.log2(res))
grid_indexer = OctreeIndexer(
    max_level=octree_depth,
    bounds=bvh.get_bounds(),
    device='cuda',
)


In [3]:
# Walk the octree against the BVH to collect candidate cells (as ijk
# coordinates), then convert them to cube indices.
candidates_ijk = grid_indexer.octree_traverse(bvh)
candidates_idx = grid_indexer.ijk_to_cube(candidates_ijk)

# De-duplicated corner vertices of the candidate cubes.
# NOTE(review): the exact meaning of `vertex_unique_idx` and `mapping` is
# defined by atom3d; they are not used further in this cell.
vertex_unique_idx, unique_coords, mapping = grid_indexer.voxel_unique_vertices(candidates_idx)

# Put the coordinates on the GPU and track gradients with respect to them.
unique_coords = unique_coords.cuda().requires_grad_(True)

# Evaluate the distance field at every unique vertex.
# NOTE(review): presumably unsigned distances to the mesh — confirm in atom3d.
udfs = bvh.udf(unique_coords, return_grad=True)

# d(sum of distances)/d(coords) gives one gradient vector per query point;
# autograd.grad returns a 1-tuple because a single input is passed.
(udfs_grad,) = torch.autograd.grad(udfs.sum(), unique_coords)

# Sanity check: a distance field's gradient should have norm 1 per point.
print(udfs_grad.norm(dim=1))



tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000], device='cuda:0')


In [None]:
# Host-side copies for pyvista: query positions and per-point colours.
coords_np = unique_coords.detach().cpu().numpy()
# Remap each gradient component from [-1, 1] to [0, 1] so the gradient
# direction can be passed directly as an RGB colour.
grad_rgb = (1 + udfs_grad.detach().cpu().numpy()) / 2

# Faint wireframe of the source mesh with the coloured sample points on top.
pl = pv.Plotter()
pl.add_mesh(meshtem, show_edges=True, opacity=0.1)
pl.add_points(coords_np, scalars=grad_rgb, rgb=True)
pl.show()

In [None]:
# Build the min/max corners of an axis-aligned box for every candidate cell at
# level grid_indexer.max_level, then test those boxes against the mesh BVH.
voxel_min, voxel_max = grid_indexer.cube_aabb_level(candidates_ijk, grid_indexer.max_level)
# NOTE(review): the meaning of mode=3 is defined by atom3d — confirm it also
# populates `result.hit`, which is read on the next line.
result = bvh.intersect_aabb(voxel_min, voxel_max, mode=3)
# Keep only the cube indices selected by result.hit
# (presumably a per-box boolean mask — TODO confirm).
surface_idx = candidates_idx[result.hit]



In [None]:
# Intersect the candidate voxel boxes with the mesh and draw the resulting
# per-voxel polygons, coloured by the mesh face they came from.
# NOTE(review): unlike the previous cell, no level is passed to
# cube_aabb_level here, so atom3d's default applies — confirm that is intended.
voxel_min, voxel_max = grid_indexer.cube_aabb_level(candidates_ijk)
result = bvh.intersect_aabb(voxel_min, voxel_max, mode=3)

print(f"Total intersections: {len(result.poly_counts)}")
# min()/max() raise on an empty tensor, so only report the range when at least
# one intersection record exists; otherwise fall through to the
# "No valid polygons!" branch below instead of crashing here.
if result.poly_counts.numel() > 0:
    print(f"poly_counts range: {result.poly_counts.min().item()} - {result.poly_counts.max().item()}")
else:
    print("poly_counts range: n/a (no intersections)")

# Host-side copies of the per-intersection records, indexed by record number.
poly_counts = result.poly_counts.cpu().numpy()
poly_verts = result.poly_verts.cpu().numpy()
aabb_ids = result.aabb_ids.cpu().numpy()
face_ids = result.face_ids.cpu().numpy()

all_points = []    # one vertex array per kept polygon
all_faces = []     # flat pyvista connectivity: [n, i0, ..., i(n-1), n, ...]
all_face_ids = []  # face id recorded for each kept polygon
all_aabb_ids = []  # box id recorded for each kept polygon
vertex_offset = 0  # index of the next polygon's first vertex in the stacked array

for i, count in enumerate(poly_counts):
    n_verts = int(count)
    if n_verts < 3:
        # Fewer than three vertices cannot form a polygon.
        continue

    # Only the first n_verts rows of this record are used.
    verts = poly_verts[i, :n_verts]
    all_points.append(verts)

    # Vertex count followed by that polygon's consecutive vertex indices.
    all_faces.append(n_verts)
    all_faces.extend(range(vertex_offset, vertex_offset + n_verts))

    all_face_ids.append(face_ids[i])
    all_aabb_ids.append(aabb_ids[i])

    vertex_offset += n_verts

valid_count = len(all_points)
print(f"Valid polygons: {valid_count}")

if valid_count > 0:
    points = np.vstack(all_points)
    faces = np.array(all_faces, dtype=np.int32)

    mesh = pv.PolyData(points, faces)
    # One entry per polygon cell, in the order the polygons were appended.
    mesh.cell_data['face_id'] = np.array(all_face_ids)
    mesh.cell_data['voxel_id'] = np.array(all_aabb_ids)

    pl = pv.Plotter()
    pl.add_mesh(mesh, show_edges=True, scalars='face_id', cmap='tab20', opacity=0.8)
    # Original mesh drawn faintly for context.
    pl.add_mesh(meshtem, opacity=0.2, color='gray')
    pl.show()
else:
    print("No valid polygons!")