Fix SongUNet with ShardTensor when using zero embedding#1432

Merged
pzharrington merged 12 commits into NVIDIA:2.0.0-rc from jleinonen:unet-shardtensor
Feb 19, 2026
Conversation

@jleinonen
Collaborator

@jleinonen jleinonen commented Feb 19, 2026

PhysicsNeMo Pull Request

When training regression models (i.e., with no time step embedding, embedding_type == "zero") using ShardTensor, SongUNet raised this error:

  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/diffusion_unets/song_unet.py", line 663, in forward
    x = block(x, emb)
        ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/nn/module/unet_layers.py", line 242, in forward
    params = self.affine(emb).unsqueeze(2).unsqueeze(3)
             ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/nn/module/fully_connected_layers.py", line 527, in forward
    x = x @ weight.t()

caused by the emb tensor being a plain torch.Tensor:

else:
    emb = torch.zeros(
        (noise_labels.shape[0], self.emb_channels),
        device=x.device,
        dtype=x.dtype,
    )

I added a conversion of emb to a ShardTensor when x is a ShardTensor. With this fix, it is possible to train StormCastUNet-style regression models with ShardTensor.
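As a rough sketch of the pattern (using stand-in classes, not the actual PhysicsNeMo code — the real ShardTensor lives in physicsnemo.domain_parallel.shard_tensor and carries a device mesh), the conversion looks like:

```python
# Stand-in for torch.Tensor, for illustration only.
class Tensor:
    def __init__(self, data):
        self.data = data


# Stand-in for PhysicsNeMo's ShardTensor: a tensor plus a device mesh.
class ShardTensor(Tensor):
    def __init__(self, data, mesh):
        super().__init__(data)
        self.mesh = mesh

    @classmethod
    def from_local(cls, local, device_mesh):
        # Wrap a plain tensor so it participates in sharded ops.
        return cls(local.data, device_mesh)


def make_zero_embedding(x, emb_channels):
    # embedding_type == "zero": build a plain zero embedding...
    emb = Tensor([0.0] * emb_channels)
    # ...and, as in the fix, promote it to a ShardTensor whenever the
    # network input is sharded, so the later `x @ weight.t()` inside
    # the Linear layers sees matching tensor types.
    if isinstance(x, ShardTensor):
        emb = ShardTensor.from_local(emb, device_mesh=x.mesh)
    return emb
```

The key point is that the branch keys off the type of x: plain inputs keep a plain zero embedding, while sharded inputs get an embedding wrapped on the same device mesh.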

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@jleinonen jleinonen self-assigned this Feb 19, 2026
@greptile-apps
Contributor

greptile-apps bot commented Feb 19, 2026

Greptile Summary

This PR fixes a tensor type mismatch error that occurred when training regression models with SongUNet using embedding_type == "zero" and ShardTensor. The error was caused by the emb tensor being a plain torch.Tensor while the input x was a ShardTensor, causing incompatibility in the matrix multiplication operation in the Linear layer (x @ weight.t()).

The fix adds a conditional check: when x is a ShardTensor, the emb tensor (created as zeros) is converted to a ShardTensor using ShardTensor.from_local() with the same device mesh as x. This ensures type consistency between tensors in subsequent operations.

  • Adds ShardTensor import from physicsnemo.domain_parallel.shard_tensor
  • Adds conversion logic after creating zero embedding tensor (lines 634-635)
  • The fix is minimal and targeted, only affecting the zero embedding path
  • The non-zero embedding path doesn't need this fix since the mapping layers (map_layer0, map_layer1) are Linear modules that preserve the tensor type through operations

Important Files Changed

Filename Overview
physicsnemo/models/diffusion_unets/song_unet.py Adds ShardTensor conversion for zero embedding case to fix tensor type mismatch error

Last reviewed commit: 8db65bf


@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

@pzharrington
Copy link
Collaborator

/blossom-ci

@pzharrington pzharrington changed the base branch from main to 2.0.0-rc February 19, 2026 19:46
@pzharrington pzharrington merged commit c3a8248 into NVIDIA:2.0.0-rc Feb 19, 2026
4 checks passed
ktangsali pushed a commit that referenced this pull request Feb 25, 2026
* Bug fixes for ShardTensor+SongUNet

* Handle dtensor spec in sharded view

* Fix SongUNet with ShardTensor when using zero embedding

* Use buffer for zero embed

---------

Co-authored-by: Peter Harrington <48932392+pzharrington@users.noreply.github.com>
Co-authored-by: Peter Harrington <pharrington@nvidia.com>
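One of the commits above is "Use buffer for zero embed". A hedged, torch-free sketch of that idea (stand-in class, illustrative names only — the real code would register the zeros as a module buffer so they follow the model's device and dtype) is:

```python
# Illustrative sketch of caching the zero embedding instead of
# allocating a fresh zeros tensor on every forward pass.
class ZeroEmbedding:
    def __init__(self, emb_channels):
        self.emb_channels = emb_channels
        self._zero = None  # lazily-built cached zeros

    def forward(self, batch_size):
        # Reuse the cached zeros when the batch size matches;
        # rebuild only when it changes.
        if self._zero is None or len(self._zero) != batch_size:
            self._zero = [
                [0.0] * self.emb_channels for _ in range(batch_size)
            ]
        return self._zero
```

In the real model a registered buffer additionally avoids the device/dtype bookkeeping, since buffers move with the module.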
