Skip to content

Implements Predictor specialization for multi-diffusion#1573

Merged
CharlelieLrt merged 20 commits intoNVIDIA:mainfrom
CharlelieLrt:multi_diffusion_sampling
May 1, 2026
Merged

Implements Predictor specialization for multi-diffusion#1573
CharlelieLrt merged 20 commits intoNVIDIA:mainfrom
CharlelieLrt:multi_diffusion_sampling

Conversation

@CharlelieLrt
Copy link
Copy Markdown
Collaborator

@CharlelieLrt CharlelieLrt commented Apr 17, 2026

PhysicsNeMo Pull Request

Description

Adds MultiDiffusionPredictor, the sampling-time counterpart of MultiDiffusionMSEDSMLoss. It wraps a trained MultiDiffusionModel2D and satisfies the Predictor protocol, so it plugs straight into sample() / get_denoiser() / standard solvers with no other changes. Condition and positional embedding are pre-patched once at construction; per-step calls only patch x and t. New unit, end-to-end, and compile tests live under test/diffusion/test_multi_diffusion_{predictor,sampling}.py. Drive-by fixes: .contiguous() on t_cur / t_next in EulerSolver and EDMStochasticEulerSolver (parity with HeunSolver), .contiguous() on the broadcast pos_embd before patching, and F.pad in place of nn.ReflectionPad2d(...)(input) in image_batching.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 17, 2026

Greptile Summary

This PR introduces MultiDiffusionPredictor, a Predictor-protocol-compatible wrapper for test-time patch-based diffusion sampling using MultiDiffusionModel2D. It pre-patches the condition and positional embedding once at construction time to avoid redundant work per step, refactors the duplicate PE injection logic into a shared _inject_patched_pos_embd helper, and generalises the DDP/compile-unwrapping helper into a reusable _unwrap_module utility. A thorough test suite covering constructor validation, non-regression, gradient flow, torch.compile, and checkpoint round-trips is included.

  • P1 (predictor.py line 406): _skip_positional_embedding_injection = True is permanently written to the underlying MultiDiffusionModel2D and never restored. If the same model object is used for loss computation after the predictor is created (e.g., continued fine-tuning), PE injection is silently skipped, producing wrong training outputs without any error.

Important Files Changed

Filename Overview
physicsnemo/diffusion/multi_diffusion/predictor.py New file implementing MultiDiffusionPredictor — a Predictor-protocol wrapper for patch-based sampling. Correctly pre-patches condition and PE at construction, but permanently mutates the underlying model's _skip_positional_embedding_injection flag with no restore mechanism (P1), and the dead is_compiling() guard in init is misleading.
physicsnemo/diffusion/multi_diffusion/models.py Adds _skip_positional_embedding_injection flag and refactors PE injection into a shared _inject_patched_pos_embd helper. The refactoring is clean and the logic is preserved correctly.
physicsnemo/diffusion/utils/utils.py Extracts the DDP/torch.compile unwrapping logic into a reusable generic _unwrap_module function. Clean generalization with correct TypeVar usage.
physicsnemo/diffusion/multi_diffusion/losses.py Replaces the local _unwrap_multi_diffusion helper with the new generic _unwrap_module utility. Equivalent behavior, no logic changes.
physicsnemo/diffusion/guidance/dps_guidance.py One-line change: DPSScorePredictor now explicitly inherits from Predictor, correctly aligning it with the protocol.
physicsnemo/diffusion/multi_diffusion/init.py Adds MultiDiffusionPredictor to the public module exports.
test/diffusion/test_multi_diffusion_predictor.py Comprehensive test suite for MultiDiffusionPredictor covering constructor validation, non-regression outputs, gradient flow, torch.compile compatibility, and checkpoint round-trips.
test/diffusion/test_multi_diffusion_sampling.py End-to-end sampling tests for MultiDiffusionPredictor integrated with all three noise schedulers (EDM, VE, VP), two solvers, and compiled predictor paths.

Comments Outside Diff (1)

  1. physicsnemo/diffusion/multi_diffusion/predictor.py, line 405-406 (link)

    P1 Persistent flag side-effect on underlying model

    _skip_positional_embedding_injection = True is written to _md_model and never restored. If the same MultiDiffusionModel2D is used for loss computation after the predictor is created — for example, after sampling a checkpoint to validate a mid-training result or during continued fine-tuning — the model's forward will silently skip PE injection in both the "no patching" and "with patching" paths, producing wrong outputs from the loss without any error or warning.

    A minimal safeguard would be to save and restore the original flag, or at minimum include a runtime warning that the model is permanently modified:

    # Save original value so callers can restore it
    self._prev_skip_pe = self._md_model._skip_positional_embedding_injection
    self._md_model._skip_positional_embedding_injection = True

    Or add a close() / __del__ to restore the flag, or document that the model must not be used for training after predictor construction.

Reviews (1): Last reviewed commit: "Implements Predictor specialization for ..." | Re-trigger Greptile

Comment thread physicsnemo/diffusion/multi_diffusion/predictor.py Outdated
Comment thread physicsnemo/diffusion/multi_diffusion/predictor.py
Comment thread physicsnemo/diffusion/multi_diffusion/predictor.py
Compiling the predictor instance directly was producing divergent results
under torch 2.10 in the sample() loop (euler cases only). Follow the same
pattern as test_samplers.py::TestSampleCompile and compile the denoiser
closure instead — tracing through it still verifies that the predictor's
__call__ path is compile-compatible.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
torch 2.10 Dynamo crashes with Fatal Python error: Aborted when tracing
the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
inside sample() with fullgraph=True. Allow graph breaks here; the
predictor compile contract is still tested in isolation by
test_multi_diffusion_predictor.py::TestCompile.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Dispatch on pos_embd presence and model_kwargs is now resolved once at
__init__ into a specialized closure, so __call__ is branch-free and the
no-kwargs path avoids ** expansion. This keeps fullgraph=True compile
cleanly traceable under torch 2.10 (which was hitting a Dynamo abort on
the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
when the denoiser closure was compiled in the sample() loop).

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Reverts the two earlier CI-fix attempts (compile-denoiser switch, predictor
hot-path flatten) since neither actually fixed the divergence. The
underlying issue is an upstream torch>=2.10 Dynamo bug: euler + compiled
MultiDiffusionPredictor produces numerically divergent results. Heun works,
predictor compiles correctly in isolation. For euler we now assert only
shape + isfinite until the upstream bug is resolved.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
sample() passes t.expand(B) (a stride-0 non-contiguous tensor) into
solver.step(). HeunSolver already forces .contiguous() on both tensors to
prevent torch.compile from specializing on the stride pattern of the first
call and then either mis-firing guards or silently recompiling on
subsequent calls with different underlying storage.

EulerSolver and EDMStochasticEulerSolver had no such guard, which was a
latent bug exposed by torch 2.10 (stricter stride tracking) in the
multi-diffusion compiled sample loop — producing 90%+ element divergence
vs eager on the first call and a Dynamo abort on the second call. Apply
the same fix uniformly across all four solver steps.

Also revert the temporary loosened euler assertion in
test_multi_diffusion_sampling.py now that the real fix is in place.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
…sionPredictor

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Move the _patching None check out of the is_compiling guard in
MultiDiffusionModel2D so the type checker narrows self._patching
to RandomPatching2D | GridPatching2D for the rest of each method,
and route fuse/reset_patch_indices through isinstance.

Streamline TestConstructor to only exercise the public contract
(.fuse, .model, setter round-trip) and drop assertions on private
caches. Compile the denoiser instead of the predictor in
TestMultiDiffusionSampleCompile and add TestMultiDiffusionFullSamplerCompile
mirroring test_samplers.py.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
pos_embd.unsqueeze(0).expand(B, -1, -1, -1) produces a stride-0 view
(all B copies share storage). Passing this through nn.ReflectionPad2d
and F.unfold inside image_batching triggers a glibc heap corruption
on torch 2.10 (CI, not locally on torch 2.8) when the first non-regression
posembd_sin test runs. Same class of fix as the earlier euler solver.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Instantiating torch.nn.ReflectionPad2d inside image_batching on every
call creates a fresh nn.Module each time, which torch.compile / AOT
autograd struggles to trace cleanly under fullgraph=True on torch 2.10.
Switch to torch.nn.functional.pad which is a plain functional call and
traces without allocating a module. Same result semantically.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
einops.rearrange goes through a pattern-matched lowering path that
torch.compile / inductor on torch 2.10 handles fragilely in the
image_batching / image_fuse hot paths. The underlying transform is a
plain view + permute + view, so express it directly: this gives inductor
a straightforward sequence of ops to trace, and drops the einops
dependency from this module.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Under torch.compile / inductor on torch 2.10, a compiled sample() call
through MultiDiffusionPredictor was returning a tensor whose metadata
was valid but whose data pointer was dangling (use-after-free) — the
caller SIGABRTed on the first read of the tensor data. Add .contiguous()
at the two boundaries that returned a view: image_fuse returns
x_folded[...] / overlap_count[...], and MultiDiffusionModel2D.forward
returns the (possibly fused) inner-model output. Forcing fresh storage
on each boundary prevents the returned tensor from aliasing a buffer
whose lifetime ends with the compiled frame.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
The second torch.compile call of a fused MultiDiffusionPredictor was
segfaulting (SIGSEGV) while the first succeeded. .contiguous() is a
no-op when the tensor is already contiguous, so inductor could still
see the returned tensor as aliasing an internal buffer across calls.
.clone() always allocates fresh storage, so successive compiled calls
get independent outputs. Also drop the redundant .contiguous() added
earlier in MultiDiffusionModel2D.forward now that image_fuse owns that
boundary.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
…e on torch>=2.10

Revert commits 3dfcdb5, 746518f and a007c46 (native-torch rearrange
in image_batching/image_fuse, .contiguous() on returned tensors, .clone()
at fuse boundary) since they did not resolve the torch 2.10 inductor
codegen segfault in TestMultiDiffusionFullSamplerCompile. Keep commits
7e1db11 (pos_embd .contiguous() for the glibc heap corruption in
posembd_sin non-regression tests) and feb0d9e (ReflectionPad2d → F.pad).

Gate TestMultiDiffusionFullSamplerCompile with xfail(run=False) when
torch>=2.10 so the SIGSEGV does not bring down the pytest process.
TestMultiDiffusionSampleCompile (per-step denoiser compile) still runs.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
@CharlelieLrt
Copy link
Copy Markdown
Collaborator Author

/blossom-ci

Comment thread physicsnemo/diffusion/multi_diffusion/models.py
Comment thread physicsnemo/diffusion/multi_diffusion/predictor.py Outdated
Copy link
Copy Markdown
Collaborator

@pzharrington pzharrington left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving with minor comments. I'll note the test coverage is good, but this is adding another fairly large set of data files to the repo. We probably eventually need a better solution then keeping everything here, but that's a problem for another day.

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 29, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@CharlelieLrt
Copy link
Copy Markdown
Collaborator Author

/ok to test 0874c90

@CharlelieLrt CharlelieLrt enabled auto-merge May 1, 2026 01:49
@CharlelieLrt
Copy link
Copy Markdown
Collaborator Author

/blossom-ci

@CharlelieLrt CharlelieLrt added this pull request to the merge queue May 1, 2026
Merged via the queue into NVIDIA:main with commit ed855da May 1, 2026
6 checks passed
peterdsharpe added a commit to peterdsharpe/physicsnemo that referenced this pull request May 4, 2026
commit 91a942b
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 21:12:32 2026 -0400

    Adds Greptile minor fixes

commit b24f9b6
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:56:24 2026 -0400

    Back-merges dataset interrogate fix

commit 6ddfb5a
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:43:01 2026 -0400

    Removes accidentally-commited benchmarks; these will come later

commit 9fa0b5d
Merge: 3e67057 4c52a45
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:37:57 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-drivaerml-standalone

commit 3e67057
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 20:35:32 2026 -0400

    Partial merge from add-GLOBE-3D-BarnesHut

commit 4c52a45
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 16:38:28 2026 -0400

    Synchronizes `GLOBE` model progress for 26.05 (NVIDIA#1595)

    * Migrate cached_dataset.py

    * verified model arch new features (self_regularization_beta)

    * minor formatting syncs

    * Adds nonregression testing

    * Adds compile_logging utilities and prefetching utilities

    * Adds self to pade.py codeowners

    * Syncs AirFRANS updates

    * corrects a docstring

    * Strips out broken ram caching

    * Adds helpful error messages

    * Adds helpful error messages

    * docs

    * Refactor compile logging in training script

    - Removed the CompileDiagnosticsCollector and replaced it with a new utility function, silence_compile_logs_on_non_zero_ranks, to suppress non-error logs from torch.compile on all ranks except rank 0.
    - Updated the training script to call this new function, improving log clarity during distributed training.
    - Adjusted logging levels for the globe logger to ensure proper diagnostics are captured only during the first launch.

    * Enhance DataLoader worker configuration for distributed training

    - Updated the logic for auto-computing `num_workers` in the `AirFRANSDataSet` class to consider CPU affinity and local world size, improving efficiency in distributed environments.
    - Adjusted logging to provide detailed information about the computed `num_workers`, including CPU count and GPU visibility.
    - Modified the run script comments to reflect the new method of calculating `num_workers`, ensuring clarity on process-level parallelism.

    * Partial merge from add-GLOBE-3D-BarnesHut

commit 4cb586a
Merge: 645701f ed855da
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 15:23:50 2026 -0400

    Merge branch 'main' into psharpe/add-GLOBE-model-progress

commit 645701f
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 15:06:55 2026 -0400

    Partial merge from add-GLOBE-3D-BarnesHut

commit 15d7913
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Sun May 3 13:45:06 2026 -0400

    Enhance DataLoader worker configuration for distributed training

    - Updated the logic for auto-computing `num_workers` in the `AirFRANSDataSet` class to consider CPU affinity and local world size, improving efficiency in distributed environments.
    - Adjusted logging to provide detailed information about the computed `num_workers`, including CPU count and GPU visibility.
    - Modified the run script comments to reflect the new method of calculating `num_workers`, ensuring clarity on process-level parallelism.

commit 65675a4
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 1 17:53:47 2026 -0400

    Refactor compile logging in training script

    - Removed the CompileDiagnosticsCollector and replaced it with a new utility function, silence_compile_logs_on_non_zero_ranks, to suppress non-error logs from torch.compile on all ranks except rank 0.
    - Updated the training script to call this new function, improving log clarity during distributed training.
    - Adjusted logging levels for the globe logger to ensure proper diagnostics are captured only during the first launch.

commit 948da86
Author: Peter Sharpe <peterdsharpe@gmail.com>
Date:   Fri May 1 15:40:31 2026 -0400

    docs

commit ed855da
Author: Charlelie Laurent <84199758+CharlelieLrt@users.noreply.github.com>
Date:   Thu Apr 30 20:44:37 2026 -0700

    Implements Predictor specialization for multi-diffusion (NVIDIA#1573)

    * Implements Predictor specialization for multi-diffusion

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Compile denoiser in multi-diffusion sampling compile tests

    Compiling the predictor instance directly was producing divergent results
    under torch 2.10 in the sample() loop (euler cases only). Follow the same
    pattern as test_samplers.py::TestSampleCompile and compile the denoiser
    closure instead — tracing through it still verifies that the predictor's
    __call__ path is compile-compatible.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Avoid fullgraph compile in multi-diffusion sampling test

    torch 2.10 Dynamo crashes with Fatal Python error: Aborted when tracing
    the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
    inside sample() with fullgraph=True. Allow graph breaks here; the
    predictor compile contract is still tested in isolation by
    test_multi_diffusion_predictor.py::TestCompile.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Flatten MultiDiffusionPredictor hot path for torch.compile

    Dispatch on pos_embd presence and model_kwargs is now resolved once at
    __init__ into a specialized closure, so __call__ is branch-free and the
    no-kwargs path avoids ** expansion. This keeps fullgraph=True compile
    cleanly traceable under torch 2.10 (which was hitting a Dynamo abort on
    the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
    when the denoiser closure was compiled in the sample() loop).

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Loosen torch.compile euler check in multi-diffusion sampling tests

    Reverts the two earlier CI-fix attempts (compile-denoiser switch, predictor
    hot-path flatten) since neither actually fixed the divergence. The
    underlying issue is an upstream torch>=2.10 Dynamo bug: euler + compiled
    MultiDiffusionPredictor produces numerically divergent results. Heun works,
    predictor compiles correctly in isolation. For euler we now assert only
    shape + isfinite until the upstream bug is resolved.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Force contiguous t_cur/t_next in Euler solvers

    sample() passes t.expand(B) (a stride-0 non-contiguous tensor) into
    solver.step(). HeunSolver already forces .contiguous() on both tensors to
    prevent torch.compile from specializing on the stride pattern of the first
    call and then either mis-firing guards or silently recompiling on
    subsequent calls with different underlying storage.

    EulerSolver and EDMStochasticEulerSolver had no such guard, which was a
    latent bug exposed by torch 2.10 (stricter stride tracking) in the
    multi-diffusion compiled sample loop — producing 90%+ element divergence
    vs eager on the first call and a Dynamo abort on the second call. Apply
    the same fix uniformly across all four solver steps.

    Also revert the temporary loosened euler assertion in
    test_multi_diffusion_sampling.py now that the real fix is in place.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Drop dead is_compiling guard and inherit from Predictor in MultiDiffusionPredictor

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Narrow _patching type and tighten multi-diffusion tests

    Move the _patching None check out of the is_compiling guard in
    MultiDiffusionModel2D so the type checker narrows self._patching
    to RandomPatching2D | GridPatching2D for the rest of each method,
    and route fuse/reset_patch_indices through isinstance.

    Streamline TestConstructor to only exercise the public contract
    (.fuse, .model, setter round-trip) and drop assertions on private
    caches. Compile the denoiser instead of the predictor in
    TestMultiDiffusionSampleCompile and add TestMultiDiffusionFullSamplerCompile
    mirroring test_samplers.py.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Force contiguous pos_embd before patching

    pos_embd.unsqueeze(0).expand(B, -1, -1, -1) produces a stride-0 view
    (all B copies share storage). Passing this through nn.ReflectionPad2d
    and F.unfold inside image_batching triggers a glibc heap corruption
    on torch 2.10 (CI, not locally on torch 2.8) when the first non-regression
    posembd_sin test runs. Same class of fix as the earlier euler solver.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Use functional F.pad in image_batching

    Instantiating torch.nn.ReflectionPad2d inside image_batching on every
    call creates a fresh nn.Module each time, which torch.compile / AOT
    autograd struggles to trace cleanly under fullgraph=True on torch 2.10.
    Switch to torch.nn.functional.pad which is a plain functional call and
    traces without allocating a module. Same result semantically.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Replace einops.rearrange with native torch reshape+permute

    einops.rearrange goes through a pattern-matched lowering path that
    torch.compile / inductor on torch 2.10 handles fragilely in the
    image_batching / image_fuse hot paths. The underlying transform is a
    plain view + permute + view, so express it directly: this gives inductor
    a straightforward sequence of ops to trace, and drops the einops
    dependency from this module.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Materialise returned tensors in multi-diffusion fuse path

    Under torch.compile / inductor on torch 2.10, a compiled sample() call
    through MultiDiffusionPredictor was returning a tensor whose metadata
    was valid but whose data pointer was dangling (use-after-free) — the
    caller SIGABRTed on the first read of the tensor data. Add .contiguous()
    at the two boundaries that returned a view: image_fuse returns
    x_folded[...] / overlap_count[...], and MultiDiffusionModel2D.forward
    returns the (possibly fused) inner-model output. Forcing fresh storage
    on each boundary prevents the returned tensor from aliasing a buffer
    whose lifetime ends with the compiled frame.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Use clone instead of contiguous at fuse boundary

    The second torch.compile call of a fused MultiDiffusionPredictor was
    segfaulting (SIGSEGV) while the first succeeded. .contiguous() is a
    no-op when the tensor is already contiguous, so inductor could still
    see the returned tensor as aliasing an internal buffer across calls.
    .clone() always allocates fresh storage, so successive compiled calls
    get independent outputs. Also drop the redundant .contiguous() added
    earlier in MultiDiffusionModel2D.forward now that image_fuse owns that
    boundary.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Revert speculative fuse-boundary copies and xfail full-sampler compile on torch>=2.10

    Revert commits 3dfcdb5, 746518f and a007c46 (native-torch rearrange
    in image_batching/image_fuse, .contiguous() on returned tensors, .clone()
    at fuse boundary) since they did not resolve the torch 2.10 inductor
    codegen segfault in TestMultiDiffusionFullSamplerCompile. Keep commits
    7e1db11 (pos_embd .contiguous() for the glibc heap corruption in
    posembd_sin non-regression tests) and feb0d9e (ReflectionPad2d → F.pad).

    Gate TestMultiDiffusionFullSamplerCompile with xfail(run=False) when
    torch>=2.10 so the SIGSEGV does not bring down the pytest process.
    TestMultiDiffusionSampleCompile (per-step denoiser compile) still runs.

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Minor updates to predictor.py

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    * Drop redundant _patching_type and add test-time-only docstring warning

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

    ---------

    Signed-off-by: Charlelie Laurent <claurent@nvidia.com>

commit 16a336f
Author: Peter Harrington <48932392+pzharrington@users.noreply.github.com>
Date:   Thu Apr 30 00:03:24 2026 -0700

    FSDP optimizer state channels last fix (NVIDIA#1597)

    * Fix channels last FSDP optimizer state load bug

    * lint

    * Catch use_orig_params=True case

commit 845906f
Author: Peter Harrington <48932392+pzharrington@users.noreply.github.com>
Date:   Wed Apr 29 23:59:29 2026 -0700

    Add HealDA dataloader protocols and init recipe (NVIDIA#1555)

    * Add healda protocols and loaders to experimental

    * Cleanup and address imports

    * Update precommit for examples tests

    * integrate restartable sampler, other updates, migrate tests

    * move imports, cleanup

    * ruff check fix

    * skip prefetch on CPU

    * Rename to local_platform

    * Revert precommit change

    * greptile feedback

    * Migrate CSVs and deps to example

    * lockfile fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants