Skip to content

Refactor interpolation functionals and add point-to-grid API#1542

Merged
ktangsali merged 13 commits intoNVIDIA:mainfrom
loliverhennigh:pr3-interpolation-restack
Apr 16, 2026
Merged

Refactor interpolation functionals and add point-to-grid API#1542
ktangsali merged 13 commits intoNVIDIA:mainfrom
loliverhennigh:pr3-interpolation-restack

Conversation

@loliverhennigh
Copy link
Copy Markdown
Collaborator

PhysicsNeMo Pull Request

Description

Just flushing out the interpolation functionals a bit. Adding better names and expanding to point to grid interpolation.

Still trying to figure out how to get AI to open PRs for me ok so apologies for the like 4 closed PRs on this.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 1, 2026

Greptile Summary

This PR refactors the existing interpolation functional into a clearly named grid_to_point_interpolation (renaming InterpolationGridToPointInterpolation) and introduces a new point_to_grid_interpolation (scatter) operation, including both Warp and PyTorch backends with full forward/backward support across 1D/2D/3D. The old interpolation function is preserved as a deprecated alias, and the monolithic _warp_impl.py is replaced by a modular per-kernel file structure.

Key concerns:

  • Breaking API change: The Interpolation class is silently removed from physicsnemo.nn.functional.interpolation without a deprecation alias, while the interpolation function gets one. Any downstream code using from physicsnemo.nn.functional.interpolation import Interpolation will receive an ImportError at runtime.
  • Inconsistent dtype contracts: The grid_to_point warp backend silently accepts any dtype and converts to float32 internally; the point_to_grid warp backend raises TypeError for non-float32 inputs. The same asymmetry appears in the torch backends. Since the two functionals are natural adjoints of each other and are often used together in the same pipeline, this difference will cause confusing runtime failures.
  • Minor docstring typo: smooth_step_2_weighting in the g2p torch implementation has "pacing" instead of "Spacing".

Important Files Changed

Filename Overview
physicsnemo/nn/functional/interpolation/init.py Re-exports the new grid-to-point and point-to-grid symbols; drops the old Interpolation class alias without a deprecation shim, causing a breaking ImportError for any consumer that imported it directly.
physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/grid_to_point_interpolation.py Renames the old Interpolation FunctionSpec to GridToPointInterpolation; preserves the interpolation function as a deprecated alias that defaults to the torch backend for backward compatibility. Logic looks correct.
physicsnemo/nn/functional/interpolation/point_to_grid_interpolation/point_to_grid_interpolation.py New FunctionSpec for scattering point values onto a structured grid; warp and torch backends are registered correctly, adjoint tested against grid-to-point.
physicsnemo/nn/functional/interpolation/point_to_grid_interpolation/_torch_impl.py Torch backend for point-to-grid interpolation covering nearest-neighbor, linear, smooth-step, and Gaussian scatter; enforces float32-only dtype while the adjoint g2p torch backend is permissive — the asymmetry is notable but the accumulation logic itself is correct.
physicsnemo/nn/functional/interpolation/point_to_grid_interpolation/_warp_impl/op.py Warp custom-op registration for p2g; public entry point rejects non-float32 while the g2p warp entry point silently casts — inconsistent contract that will surprise users who use both operations with the same float16/float64 tensors.
physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/_warp_impl/op.py Warp custom-op registration for grid-to-point; silently converts inputs to float32 and casts outputs back, which is the opposite policy from the p2g warp entry point.
physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/_torch_impl.py Moved from old monolithic interpolation.py; one docstring typo ("pacing" → "Spacing") in smooth_step_2_weighting; implementation logic is unchanged from the original.
test/nn/functional/interpolation/test_point_to_grid_interpolation.py New test suite validating the adjoint property between p2g and g2p for all interpolation types, both warp and torch backends; good coverage of error paths and backward gradients.
test/nn/functional/interpolation/test_grid_to_point_interpolation.py Replaces the deleted test_interpolation_functional.py; covers deprecated alias, backend parity (forward and backward), and error-handling paths for the renamed GridToPointInterpolation.

Comments Outside Diff (1)

  1. physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/_torch_impl.py, line 228-237 (link)

    P2 Docstring typo: "pacing" → "Spacing"

    The dx parameter description in smooth_step_2_weighting is missing the leading "S":

Reviews (1): Last reviewed commit: "Refactor interpolation functionals and a..." | Re-trigger Greptile

Comment thread physicsnemo/nn/functional/interpolation/__init__.py
Comment on lines +214 to +221
def point_to_grid_interpolation_warp(
query_points: torch.Tensor,
point_values: torch.Tensor,
grid: List[Tuple[float, float, int]],
interpolation_type: str = "smooth_step_2",
mem_speed_trade: bool = True,
) -> torch.Tensor:
if query_points.dtype != torch.float32:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Inconsistent dtype enforcement between g2p and p2g backends

point_to_grid_interpolation_warp raises TypeError for any non-float32 input, while the grid_to_point_interpolation warp backend (interpolation_warp / interpolation_impl) silently converts any dtype to float32 and casts the output back. This asymmetry means that a workflow using float16 or float64 tensors can call grid_to_point_interpolation without error, but will crash when calling point_to_grid_interpolation with the same inputs, even though the underlying point_to_grid_interpolation_impl custom-op already handles non-float32 via the same cast-and-convert pattern.

Similarly, the torch backend for p2g (in _normalize_inputs) raises TypeError for non-float32, whereas the torch backend for g2p (interpolation_torch) applies no dtype check at all.

Consider aligning the two warp backends to the same policy — either both silently cast (as g2p does) or both require float32 explicitly (as p2g does) — so the two functionals are interchangeable in typical adjoint workflows.

.. rubric:: Benchmarks (ASV)

.. figure:: /img/nn/functional/radius_search/benchmark.png
.. figure:: ../../../img/nn/functional/radius_search/benchmark.png
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unrelated but fixes a doc bug.

loliverhennigh and others added 2 commits April 1, 2026 16:06
@mehdiataei
Copy link
Copy Markdown
Contributor

Hey @loliverhennigh. I opened a PR with the changes we discussed here: loliverhennigh#1

@loliverhennigh
Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@loliverhennigh
Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@loliverhennigh
Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@loliverhennigh
Copy link
Copy Markdown
Collaborator Author

/blossom-ci

1 similar comment
@ktangsali
Copy link
Copy Markdown
Collaborator

/blossom-ci

@ktangsali
Copy link
Copy Markdown
Collaborator

Discussed with @loliverhennigh offline. Seems like @mehdiataei has reviewed and approved but for some reason the approval is not showing up here. Merging this to unblock.

@ktangsali ktangsali merged commit 841def3 into NVIDIA:main Apr 16, 2026
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants