Skip to content

FSDP optimizer state channels last fix#1597

Merged
pzharrington merged 4 commits intoNVIDIA:mainfrom
pzharrington:ckpt-channels-last-fix
Apr 30, 2026
Merged

FSDP optimizer state channels last fix#1597
pzharrington merged 4 commits intoNVIDIA:mainfrom
pzharrington:ckpt-channels-last-fix

Conversation

@pzharrington
Copy link
Copy Markdown
Collaborator

PhysicsNeMo Pull Request

Description

Fixes a silent corruption of optimizer state when checkpointing FSDP-wrapped models that have channels_last params (e.g. DiT in stormcast, or any SongUNet trained with apex GroupNorm).

PyTorch FSDP with use_orig_params=False packs and unpacks the FlatParameter asymmetrically for non-truly-contiguous params:

  • Save (_get_unflat_views in torch/distributed/fsdp/_flat_param.py)
    uses as_strided((numel,),(1,)) -- reads bytes in storage order.
  • Load (_flatten_tensor_optim_state in _optim_utils.py)
    uses torch.flatten -- reads bytes in logical (row-major) order.

For a 4-D Conv2d weight in channels_last format, those two byte orders differ. After a save/load round-trip the optimizer state for that param is silently scrambled, while the model weights survive (the model load path uses the storage-aware variant on both ends).

Fixes add a new helper to physicsnemo/utils/checkpoint.py that detects the case and performs appropriate permute+contiguous+view calls to load state correctly. Adds test coverage in package and recipe unit tests for the case. This PR also removes train dataloader seeding in the recipe as requested by SAs.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@pzharrington pzharrington self-assigned this Apr 27, 2026
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 27, 2026

Greptile Summary

This PR fixes a silent corruption of FSDP optimizer state during checkpoint round-trips for models with channels_last parameters (e.g. DiT/SongUNet). The root cause is an asymmetry in PyTorch FSDP's FlatParameter pack (as_strided, storage order) vs. unpack (torch.flatten, logical order) for non-truly-contiguous params; the fix pre-permutes the saved tensors so the loader's torch.flatten reproduces the original storage byte sequence.

  • The permute+contiguous+view transform in _remap_channels_last_optim_sd is mathematically sound for use_orig_params=False, but it is applied unconditionally: a use_orig_params=True FSDP model whose optimizer state happens to contain channels_last tensors would have those tensors incorrectly remapped, producing the same silent corruption the fix aims to prevent.
  • examples/weather/stormcast/test_training.py also corrects a pre-existing bug where optimizer state was compared against itself (opt_params0 on both sides of the assertion) rather than against the loaded checkpoint.

Important Files Changed

Filename Overview
physicsnemo/utils/checkpoint.py Adds _remap_channels_last_optim_sd to work around an FSDP pack/unpack asymmetry for channels_last params; the permute+contiguous+view trick is mathematically sound for use_orig_params=False, but the remap is applied unconditionally without checking the FSDP use_orig_params flag, which could corrupt optimizer state for use_orig_params=True models.
test/utils/test_checkpoint_distributed.py Adds test_fsdp_channels_last_optim_roundtrip covering the channels_last optimizer state round-trip; tolerance in an unrelated test was loosened from 1e-5 to 1e-4 without explanation; new test only uses NO_SHARD strategy.
examples/weather/stormcast/test_training.py Fixes a pre-existing test bug (comparing opt_params0 with itself instead of opt_params1); extends test to world_size=1 with force_sharding parametrize; logic is correct.
examples/weather/stormcast/utils/trainer.py Removes seed= from dataloader construction as requested by SAs; intentional reproducibility trade-off.
CHANGELOG.md Adds changelog entry for the FSDP optimizer state channels_last fix.

Reviews (1): Last reviewed commit: "lint" | Re-trigger Greptile

Comment thread physicsnemo/utils/checkpoint.py
Comment thread test/utils/test_checkpoint_distributed.py Outdated
Comment thread test/utils/test_checkpoint_distributed.py
@pzharrington
Copy link
Copy Markdown
Collaborator Author

/ok to test cc703b5

Copy link
Copy Markdown
Collaborator

@CharlelieLrt CharlelieLrt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@pzharrington
Copy link
Copy Markdown
Collaborator Author

/ok to test d781754

@pzharrington pzharrington enabled auto-merge April 29, 2026 17:08
@pzharrington
Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@pzharrington pzharrington added this pull request to the merge queue Apr 30, 2026
Merged via the queue into NVIDIA:main with commit 16a336f Apr 30, 2026
6 checks passed
peterdsharpe added a commit to peterdsharpe/physicsnemo that referenced this pull request May 4, 2026
commit 91a942b
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 21:12:32 2026 -0400

    Adds Greptile minor fixes

commit b24f9b6
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:56:24 2026 -0400

    Back-merges dataset interrogate fix

commit 6ddfb5a
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:43:01 2026 -0400

    Removes accidentally-commited benchmarks; these will come later

commit 9fa0b5d
Merge: 3e67057 4c52a45
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:37:57 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-drivaerml-standalone

commit 3e67057
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:35:32 2026 -0400

    Partial merge from add-GLOBE-3D-BarnesHut

commit 4c52a45
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 16:38:28 2026 -0400

    Synchronizes `GLOBE` model progress for 26.05 (NVIDIA#1595)

    * Migrate cached_dataset.py

    * verified model arch new features (self_regularization_beta)

    * minor formatting syncs

    * Adds nonregression testing

    * Adds compile_logging utilities and prefetching utilities

    * Adds self to pade.py codeowners

    * Syncs AirFRANS updates

    * corrects a docstring

    * Strips out broken ram caching

    * Adds helpful error messages

    * Adds helpful error messages

    * docs

    * Refactor compile logging in training script

    - Removed the CompileDiagnosticsCollector and replaced it with a new utility function, silence_compile_logs_on_non_zero_ranks, to suppress non-error logs from torch.compile on all ranks except rank 0.
    - Updated the training script to call this new function, improving log clarity during distributed training.
    - Adjusted logging levels for the globe logger to ensure proper diagnostics are captured only during the first launch.

    * Enhance DataLoader worker configuration for distributed training

    - Updated the logic for auto-computing `num_workers` in the `AirFRANSDataSet` class to consider CPU affinity and local world size, improving efficiency in distributed environments.
    - Adjusted logging to provide detailed information about the computed `num_workers`, including CPU count and GPU visibility.
    - Modified the run script comments to reflect the new method of calculating `num_workers`, ensuring clarity on process-level parallelism.

    * Partial merge from add-GLOBE-3D-BarnesHut

commit 4cb586a
Merge: 645701f ed855da
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 15:23:50 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-model-progress

commit 645701f
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 15:06:55 2026 -0400

    Partial merge from add-GLOBE-3D-BarnesHut

commit 15d7913
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 13:45:06 2026 -0400

    Enhance DataLoader worker configuration for distributed training

    - Updated the logic for auto-computing `num_workers` in the `AirFRANSDataSet` class to consider CPU affinity and local world size, improving efficiency in distributed environments.
    - Adjusted logging to provide detailed information about the computed `num_workers`, including CPU count and GPU visibility.
    - Modified the run script comments to reflect the new method of calculating `num_workers`, ensuring clarity on process-level parallelism.

commit 65675a4
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 1 17:53:47 2026 -0400

    Refactor compile logging in training script

    - Removed the CompileDiagnosticsCollector and replaced it with a new utility function, silence_compile_logs_on_non_zero_ranks, to suppress non-error logs from torch.compile on all ranks except rank 0.
    - Updated the training script to call this new function, improving log clarity during distributed training.
    - Adjusted logging levels for the globe logger to ensure proper diagnostics are captured only during the first launch.

commit 948da86
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 1 15:40:31 2026 -0400

    docs

commit ed855da
Author: Charlelie Laurent <84199758+CharlelieLrt@users.noreply.github.com>
Date:   Thu Apr 30 20:44:37 2026 -0700

    Implements Predictor specialization for multi-diffusion (NVIDIA#1573)

    * Implements Predictor specialization for multi-diffusion

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Compile denoiser in multi-diffusion sampling compile tests

    Compiling the predictor instance directly was producing divergent results
    under torch 2.10 in the sample() loop (euler cases only). Follow the same
    pattern as test_samplers.py::TestSampleCompile and compile the denoiser
    closure instead — tracing through it still verifies that the predictor's
    __call__ path is compile-compatible.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Avoid fullgraph compile in multi-diffusion sampling test

    torch 2.10 Dynamo crashes with Fatal Python error: Aborted when tracing
    the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
    inside sample() with fullgraph=True. Allow graph breaks here; the
    predictor compile contract is still tested in isolation by
    test_multi_diffusion_predictor.py::TestCompile.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Flatten MultiDiffusionPredictor hot path for torch.compile

    Dispatch on pos_embd presence and model_kwargs is now resolved once at
    __init__ into a specialized closure, so __call__ is branch-free and the
    no-kwargs path avoids ** expansion. This keeps fullgraph=True compile
    cleanly traceable under torch 2.10 (which was hitting a Dynamo abort on
    the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
    when the denoiser closure was compiled in the sample() loop).

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Loosen torch.compile euler check in multi-diffusion sampling tests

    Reverts the two earlier CI-fix attempts (compile-denoiser switch, predictor
    hot-path flatten) since neither actually fixed the divergence. The
    underlying issue is an upstream torch>=2.10 Dynamo bug: euler + compiled
    MultiDiffusionPredictor produces numerically divergent results. Heun works,
    predictor compiles correctly in isolation. For euler we now assert only
    shape + isfinite until the upstream bug is resolved.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Force contiguous t_cur/t_next in Euler solvers

    sample() passes t.expand(B) (a stride-0 non-contiguous tensor) into
    solver.step(). HeunSolver already forces .contiguous() on both tensors to
    prevent torch.compile from specializing on the stride pattern of the first
    call and then either mis-firing guards or silently recompiling on
    subsequent calls with different underlying storage.

    EulerSolver and EDMStochasticEulerSolver had no such guard, which was a
    latent bug exposed by torch 2.10 (stricter stride tracking) in the
    multi-diffusion compiled sample loop — producing 90%+ element divergence
    vs eager on the first call and a Dynamo abort on the second call. Apply
    the same fix uniformly across all four solver steps.

    Also revert the temporary loosened euler assertion in
    test_multi_diffusion_sampling.py now that the real fix is in place.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Drop dead is_compiling guard and inherit from Predictor in MultiDiffusionPredictor

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Narrow _patching type and tighten multi-diffusion tests

    Move the _patching None check out of the is_compiling guard in
    MultiDiffusionModel2D so the type checker narrows self._patching
    to RandomPatching2D | GridPatching2D for the rest of each method,
    and route fuse/reset_patch_indices through isinstance.

    Streamline TestConstructor to only exercise the public contract
    (.fuse, .model, setter round-trip) and drop assertions on private
    caches. Compile the denoiser instead of the predictor in
    TestMultiDiffusionSampleCompile and add TestMultiDiffusionFullSamplerCompile
    mirroring test_samplers.py.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Force contiguous pos_embd before patching

    pos_embd.unsqueeze(0).expand(B, -1, -1, -1) produces a stride-0 view
    (all B copies share storage). Passing this through nn.ReflectionPad2d
    and F.unfold inside image_batching triggers a glibc heap corruption
    on torch 2.10 (CI, not locally on torch 2.8) when the first non-regression
    posembd_sin test runs. Same class of fix as the earlier euler solver.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Use functional F.pad in image_batching

    Instantiating torch.nn.ReflectionPad2d inside image_batching on every
    call creates a fresh nn.Module each time, which torch.compile / AOT
    autograd struggles to trace cleanly under fullgraph=True on torch 2.10.
    Switch to torch.nn.functional.pad which is a plain functional call and
    traces without allocating a module. Same result semantically.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Replace einops.rearrange with native torch reshape+permute

    einops.rearrange goes through a pattern-matched lowering path that
    torch.compile / inductor on torch 2.10 handles fragilely in the
    image_batching / image_fuse hot paths. The underlying transform is a
    plain view + permute + view, so express it directly: this gives inductor
    a straightforward sequence of ops to trace, and drops the einops
    dependency from this module.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Materialise returned tensors in multi-diffusion fuse path

    Under torch.compile / inductor on torch 2.10, a compiled sample() call
    through MultiDiffusionPredictor was returning a tensor whose metadata
    was valid but whose data pointer was dangling (use-after-free) — the
    caller SIGABRTed on the first read of the tensor data. Add .contiguous()
    at the two boundaries that returned a view: image_fuse returns
    x_folded[...] / overlap_count[...], and MultiDiffusionModel2D.forward
    returns the (possibly fused) inner-model output. Forcing fresh storage
    on each boundary prevents the returned tensor from aliasing a buffer
    whose lifetime ends with the compiled frame.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Use clone instead of contiguous at fuse boundary

    The second torch.compile call of a fused MultiDiffusionPredictor was
    segfaulting (SIGSEGV) while the first succeeded. .contiguous() is a
    no-op when the tensor is already contiguous, so inductor could still
    see the returned tensor as aliasing an internal buffer across calls.
    .clone() always allocates fresh storage, so successive compiled calls
    get independent outputs. Also drop the redundant .contiguous() added
    earlier in MultiDiffusionModel2D.forward now that image_fuse owns that
    boundary.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Revert speculative fuse-boundary copies and xfail full-sampler compile on torch>=2.10

    Revert commits 3dfcdb5, 746518f and a007c46 (native-torch rearrange
    in image_batching/image_fuse, .contiguous() on returned tensors, .clone()
    at fuse boundary) since they did not resolve the torch 2.10 inductor
    codegen segfault in TestMultiDiffusionFullSamplerCompile. Keep commits
    7e1db11 (pos_embd .contiguous() for the glibc heap corruption in
    posembd_sin non-regression tests) and feb0d9e (ReflectionPad2d → F.pad).

    Gate TestMultiDiffusionFullSamplerCompile with xfail(run=False) when
    torch>=2.10 so the SIGSEGV does not bring down the pytest process.
    TestMultiDiffusionSampleCompile (per-step denoiser compile) still runs.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Minor updates to predictor.py

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Drop redundant _patching_type and add test-time-only docstring warning

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    ---------

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

commit 16a336f
Author: Peter Harrington <48932392+pzharrington@users.noreply.github.com>
Date:   Thu Apr 30 00:03:24 2026 -0700

    FSDP optimizer state channels last fix (NVIDIA#1597)

    * Fix channels last FSDP optimizer state load bug

    * lint

    * Catch use_orig_params=True case

commit 845906f
Author: Peter Harrington <48932392+pzharrington@users.noreply.github.com>
Date:   Wed Apr 29 23:59:29 2026 -0700

    Add HealDA dataloader protocols and init recipe (NVIDIA#1555)

    * Add healda protocols and loaders to experimental

    * Cleanup and address imports

    * Update precommit for examples tests

    * integrate restartable sampler, other updates, migrate tests

    * move imports, cleanup

    * ruff check fix

    * skip prefetch on CPU

    * Rename to local_platform

    * Revert precommit change

    * greptile feedback

    * Migrate CSVs and deps to example

    * lockfile fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants