@@ -522,8 +522,6 @@ def preprocess_combined(self, data_dict):
522
522
if mesh_indices_flattened .dtype != xp .int32 :
523
523
mesh_indices_flattened = mesh_indices_flattened .astype (xp .int32 )
524
524
525
- length_scale = xp .amax (xp .amax (stl_vertices , 0 ) - xp .amin (stl_vertices , 0 ))
526
-
527
525
center_of_mass = calculate_center_of_mass (stl_centers , stl_sizes )
528
526
529
527
if self .config .bounding_box_dims_surf is None :
@@ -570,7 +568,6 @@ def preprocess_combined(self, data_dict):
570
568
surf_grid_max_min = xp .stack ([s_min , s_max ])
571
569
572
570
return_dict = {
573
- "length_scale" : length_scale ,
574
571
"surf_grid" : surf_grid ,
575
572
"sdf_surf_grid" : sdf_surf_grid ,
576
573
"surface_min_max" : surf_grid_max_min ,
@@ -651,7 +648,7 @@ def preprocess_surface(self, data_dict, core_dict, center_of_mass, s_min, s_max)
651
648
(s_max [2 ] - s_min [2 ]) / nz ,
652
649
)
653
650
pos_normals_com_surface = calculate_normal_positional_encoding (
654
- surface_coordinates , center_of_mass , cell_length = [dx , dy , dz ]
651
+ surface_coordinates , center_of_mass , cell_dimensions = [dx , dy , dz ]
655
652
)
656
653
else :
657
654
pos_normals_com_surface = surface_coordinates - xp .asarray (
@@ -744,7 +741,13 @@ def preprocess_surface(self, data_dict, core_dict, center_of_mass, s_min, s_max)
744
741
745
742
else :
746
743
# We are *not* sampling, kNN on ALL points:
747
- ii = knn .kneighbors (surface_coordinates , return_distance = False )
744
+ if self .array_provider == cp :
745
+ ii = knn .kneighbors (surface_coordinates , return_distance = False )
746
+ else :
747
+ _ , ii = interp_func .query (
748
+ surface_coordinates ,
749
+ k = self .config .num_surface_neighbors ,
750
+ )
748
751
749
752
# Construct the neighbors arrays:
750
753
surface_neighbors = surface_coordinates [ii ][:, 1 :]
@@ -892,10 +895,10 @@ def preprocess_volume(
892
895
pos_normals_closest_vol = calculate_normal_positional_encoding (
893
896
volume_coordinates ,
894
897
sdf_node_closest_point ,
895
- cell_length = [dx , dy , dz ],
898
+ cell_dimensions = [dx , dy , dz ],
896
899
)
897
900
pos_normals_com_vol = calculate_normal_positional_encoding (
898
- volume_coordinates , center_of_mass , cell_length = [dx , dy , dz ]
901
+ volume_coordinates , center_of_mass , cell_dimensions = [dx , dy , dz ]
899
902
)
900
903
else :
901
904
pos_normals_closest_vol = volume_coordinates - sdf_node_closest_point
0 commit comments