Skip to content

Commit 9d9e4a3

Browse files
authored
Add significantly more coverage of the domino datapipe to catch more errors (#1102)
* Add significantly more coverage of the domino datapipe to catch more errors. * update changelog * Update tests based on PR review * Rework dataset tests for domino by using synthetic datasets. * Remove length scale
1 parent cd31a32 commit 9d9e4a3

File tree

3 files changed

+643
-76
lines changed

3 files changed

+643
-76
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- Added lead-time aware training support to the StormCast example.
1414
- Add a device aware kNN method to physicsnemo.utils.neighbors. Works with CPU or GPU
1515
by dispatching to the proper optimized library, and torch.compile compatible.
16+
- Added additional testing of the DoMINO datapipe.
1617

1718
### Changed
1819

physicsnemo/datapipes/cae/domino_datapipe.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -522,8 +522,6 @@ def preprocess_combined(self, data_dict):
522522
if mesh_indices_flattened.dtype != xp.int32:
523523
mesh_indices_flattened = mesh_indices_flattened.astype(xp.int32)
524524

525-
length_scale = xp.amax(xp.amax(stl_vertices, 0) - xp.amin(stl_vertices, 0))
526-
527525
center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes)
528526

529527
if self.config.bounding_box_dims_surf is None:
@@ -570,7 +568,6 @@ def preprocess_combined(self, data_dict):
570568
surf_grid_max_min = xp.stack([s_min, s_max])
571569

572570
return_dict = {
573-
"length_scale": length_scale,
574571
"surf_grid": surf_grid,
575572
"sdf_surf_grid": sdf_surf_grid,
576573
"surface_min_max": surf_grid_max_min,
@@ -651,7 +648,7 @@ def preprocess_surface(self, data_dict, core_dict, center_of_mass, s_min, s_max)
651648
(s_max[2] - s_min[2]) / nz,
652649
)
653650
pos_normals_com_surface = calculate_normal_positional_encoding(
654-
surface_coordinates, center_of_mass, cell_length=[dx, dy, dz]
651+
surface_coordinates, center_of_mass, cell_dimensions=[dx, dy, dz]
655652
)
656653
else:
657654
pos_normals_com_surface = surface_coordinates - xp.asarray(
@@ -744,7 +741,13 @@ def preprocess_surface(self, data_dict, core_dict, center_of_mass, s_min, s_max)
744741

745742
else:
746743
# We are *not* sampling, kNN on ALL points:
747-
ii = knn.kneighbors(surface_coordinates, return_distance=False)
744+
if self.array_provider == cp:
745+
ii = knn.kneighbors(surface_coordinates, return_distance=False)
746+
else:
747+
_, ii = interp_func.query(
748+
surface_coordinates,
749+
k=self.config.num_surface_neighbors,
750+
)
748751

749752
# Construct the neighbors arrays:
750753
surface_neighbors = surface_coordinates[ii][:, 1:]
@@ -892,10 +895,10 @@ def preprocess_volume(
892895
pos_normals_closest_vol = calculate_normal_positional_encoding(
893896
volume_coordinates,
894897
sdf_node_closest_point,
895-
cell_length=[dx, dy, dz],
898+
cell_dimensions=[dx, dy, dz],
896899
)
897900
pos_normals_com_vol = calculate_normal_positional_encoding(
898-
volume_coordinates, center_of_mass, cell_length=[dx, dy, dz]
901+
volume_coordinates, center_of_mass, cell_dimensions=[dx, dy, dz]
899902
)
900903
else:
901904
pos_normals_closest_vol = volume_coordinates - sdf_node_closest_point

0 commit comments

Comments
 (0)