You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
PR 2 of Sym upstream efforts. In this PR we port the examples to use the new Sym interface and drop the geometry specific imports from previous sym in support of PhysicsNeMo.Mesh + PyVista.
Depends on #1567 (PhysicsInformer upstreaming - merged)
Changes
Import updates: Replace physicsnemo.sym.hydra, physicsnemo.sym.distributed, and physicsnemo.sym.eq.spatial_grads imports with their PhysicsNeMo equivalents
Inline PDE definitions: Replace pre-built PDE classes (NavierStokes, Diffusion, IncompressibleNavierStokes) with inline SymPy definitions in each example
Key/Arch removal: Replace Key objects and Arch base class with plain strings and torch.nn.Module
Geometry replacement: Drop physicsnemo.sym.geometry imports in favor of physicsnemo.mesh primitives, sample_random_points_on_cells, and PyVista for STL loading; analytical SDF for axis-aligned box domains
Derivative functional: Replace FirstDeriv with mesh_lsq_gradient from physicsnemo.nn.functional.derivatives in DoMINO physics loss
Documentation: Update README, FAQ, example READMEs, and migration guide to reflect that PhysicsNeMo-Sym is archived and its functionality is now built-in
All changes have been verified by running the actual example (in case of LDC) or a minimal test to check the reproducibility of the behavior using the new code changes
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.
This PR ports PhysicsNeMo examples from the archived physicsnemo-sym package to the new physicsnemo.sym module, replacing geometry primitives with physicsnemo.mesh + PyVista, dropping Key/Arch wrappers in favor of plain torch.nn.Module, and inlining PDE definitions with SymPy. Two issues need attention before merging:
ldc_pinns/train.py: mask_no_slip (y < height/2) and mask_top_wall (y >= height/2 - 1e-7) overlap for a thin strip of boundary points, causing conflicting boundary condition loss signals in the same training step.
domino_nim_finetuning/src/train.py: The eqn dict is now built with {c.outputs[0]: c for c in ns.make_computations()} instead of eqn.make_nodes(return_as_dict=True). If the key type or structure differs, downstream physics-loss lookups could silently break.
Important Files Changed
Filename
Overview
examples/cfd/ldc_pinns/train.py
Replaces GeometryDatapipe/Rectangle with physicsnemo.mesh; refactors boundary sampling and training loop — overlapping mask_no_slip/mask_top_wall creates conflicting BC gradients (P1).
Removes FirstDeriv/IncompressibleNavierStokes imports; defines PDE inline and switches to make_computations() — potential API mismatch vs old make_nodes(return_as_dict=True) dict structure (P1).
Replaces Arch/Key-based MdlsSymWrapper with plain torch.nn.Module DeepONet; inlines Diffusion PDE — duplicates the class with darcy_physics_informed_fno.py.
Removes first_deriv parameter; replaces manual neighbor-loop gradients with mesh_lsq_gradient; duplicates _build_csr_from_neighbors with the nim-finetuning file.
Drops physicsnemo.sym.eq.pdes.diffusion import; inlines Diffusion PDE — duplicates the class definition that also appears in darcy_physics_informed_deeponet.py.
Rewrites Point2D as a plain class dropping all sym geometry dependencies; SDF and boundary sampling logic faithfully reproduced.
examples/cfd/datacenter/train_physics_informed.py
Inlines NavierStokes PDE using SymPy; drops physicsnemo.sym.eq.pdes.navier_stokes import; minor docstring additions only.
Comments Outside Diff (1)
examples/cfd/external_aerodynamics/domino_nim_finetuning/src/train.py, line 993-995 (link)
make_computations() API produces a different dict structure than make_nodes(return_as_dict=True)
The original code called eqn.make_nodes(return_as_dict=True), which returned a dict keyed by equation name (e.g., "continuity", "momentum_x", …). The replacement {c.outputs[0]: c for c in ns.make_computations()} keys the dict by the first output symbol of each computation object instead. Any downstream code that looks up eqn["continuity"] or similar string keys will receive a KeyError or silently use the wrong computation, breaking the physics-informed training.
Please verify that make_computations() is the correct new API and that all eqn consumers use the new key type, or use the equivalent new API (e.g., ns.make_nodes(return_as_dict=True)) if it still exists in physicsnemo.sym.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
PhysicsNeMo Pull Request
Description
PR 2 of Sym upstream efforts. In this PR we port the examples to use the new Sym interface and drop the geometry specific imports from previous sym in support of PhysicsNeMo.Mesh + PyVista.
Depends on #1567 (PhysicsInformer upstreaming - merged)
Changes
All changes have been verified by running the actual example (in case of LDC) or a minimal test to check the reproducibility of the behavior using the new code changes
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.