FSDP optimizer state channels last fix#1597
Merged
pzharrington merged 4 commits into NVIDIA:main on Apr 30, 2026
Conversation
Contributor
Greptile Summary: This PR fixes a silent corruption of FSDP optimizer state during checkpoint round-trips for models with channels_last parameters.
Collaborator (Author)
/ok to test cc703b5
Collaborator (Author)
/ok to test d781754
Collaborator (Author)
/blossom-ci
peterdsharpe added a commit to peterdsharpe/physicsnemo that referenced this pull request on May 4, 2026.
PhysicsNeMo Pull Request
Description
Fixes a silent corruption of optimizer state when checkpointing FSDP-wrapped models that have `channels_last` params (e.g. DiT in StormCast, or any SongUNet trained with Apex GroupNorm).

PyTorch FSDP with `use_orig_params=False` packs and unpacks the `FlatParameter` asymmetrically for params that are not truly contiguous:

- `_get_unflat_views` (in `torch/distributed/fsdp/_flat_param.py`) uses `as_strided((numel,), (1,))` -- reads bytes in storage order.
- `_flatten_tensor_optim_state` (in `_optim_utils.py`) uses `torch.flatten` -- reads bytes in logical (row-major) order.

For a 4-D Conv2d weight in `channels_last` format, those two byte orders differ. After a save/load round-trip the optimizer state for that param is silently scrambled, while the model weights survive (the model load path uses the storage-aware variant on both ends).
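The asymmetry is easy to see in isolation. Below is a minimal, self-contained sketch (the tensor shape and values are arbitrary, chosen only for illustration) that flattens the same `channels_last` weight through the two calls named above and shows that the resulting orders disagree:

```python
import torch

# An arbitrary 4-D "Conv2d weight"-shaped tensor, converted to channels_last so the
# storage (NHWC) order no longer matches the logical (NCHW) row-major order.
w = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
w_cl = w.to(memory_format=torch.channels_last)

# Model load path (_get_unflat_views): reads the bytes in storage order.
storage_order = w_cl.as_strided((w_cl.numel(),), (1,))

# Optimizer state path (_flatten_tensor_optim_state): reads in logical row-major order.
logical_order = torch.flatten(w_cl)

print(torch.equal(storage_order, logical_order))  # False: the two flat orders differ
```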
The fix adds a new helper to `physicsnemo/utils/checkpoint.py` that detects this case and performs the appropriate permute + contiguous + view calls so the state is loaded correctly. Test coverage for the case is added in the package and recipe unit tests. This PR also removes train dataloader seeding in the recipe as requested by SAs.
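For intuition, here is a hypothetical sketch of the kind of reordering such a helper can perform. The function name, signature, and the assumed direction of the conversion are illustrative assumptions, not the actual implementation in `physicsnemo/utils/checkpoint.py`:

```python
import torch


def logical_to_storage_order(flat_state: torch.Tensor, param: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reorder a flat optimizer-state tensor that was flattened in
    logical (row-major) order so it matches the storage order of a channels_last param."""
    is_channels_last = (
        param.dim() == 4
        and param.is_contiguous(memory_format=torch.channels_last)
        and not param.is_contiguous()
    )
    if is_channels_last:
        n, c, h, w = param.shape
        # permute + contiguous + view: logical NCHW values -> NHWC memory order -> flat
        return flat_state.view(n, c, h, w).permute(0, 2, 3, 1).contiguous().view(-1)
    # Truly contiguous params are unaffected; both flattening paths already agree.
    return flat_state
```

Applied to the `logical_order` tensor from the previous snippet (with `w_cl` as the param), this returns a tensor equal to `storage_order`, i.e. the layout the load path expects.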
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.