
Shard tensor view, reshape, and unsqueeze fixes #1413

Merged
coreyjadams merged 7 commits into NVIDIA:main from coreyjadams:shard_tensor_view_and_squeeze_fixes-clean
Feb 17, 2026

Conversation

@coreyjadams
Collaborator

PhysicsNeMo Pull Request

Until this PR, ShardTensor has mostly relied on DTensor's ops for view, reshape, and unsqueeze. This PR implements those operations natively for ShardTensor instead, to provide a more flexible interplay with shaping operations that had started breaking in the `__torch_function__` and `__torch_dispatch__` layers of ShardTensor.
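The core trick behind a communication-free sharded view can be sketched as a purely local shape computation. The function and variable names below are illustrative only, not the actual PhysicsNeMo implementation: the view is decomposed into "dim groups" whose products match between the old and new shapes, and the sharded dimension's global extent is swapped for its local extent.

```python
# Hypothetical sketch (illustrative only, not the actual PhysicsNeMo code):
# a sharded view can often be computed locally, with no communication, by
# matching "dim groups" between the old and new shapes and substituting the
# local extent of the sharded dimension.

def match_dim_groups(old_shape, new_shape):
    """Greedily pair consecutive runs of dims whose products match.

    Returns [((old_start, old_end), (new_start, new_end)), ...], the
    standard dim-group matching used to decide which dims a view merges
    or splits. (Trailing size-1 dims are ignored in this sketch.)
    """
    groups = []
    i = j = 0
    while i < len(old_shape) and j < len(new_shape):
        start_i, start_j = i, j
        prod_i, prod_j = old_shape[i], new_shape[j]
        i += 1
        j += 1
        while prod_i != prod_j:
            if prod_i < prod_j:
                prod_i *= old_shape[i]
                i += 1
            else:
                prod_j *= new_shape[j]
                j += 1
        groups.append(((start_i, i), (start_j, j)))
    return groups


def sharded_view_shape(old_global, new_global, shard_dim, local_extent):
    """Local shape of the viewed tensor on one rank.

    Handles only the easy case where the sharded dim maps one-to-one onto
    a new dim; a real implementation must also handle shard dims that are
    split or merged by the view.
    """
    local = list(new_global)
    for (o0, o1), (n0, n1) in match_dim_groups(old_global, new_global):
        if o0 <= shard_dim < o1:
            assert o1 - o0 == 1 and n1 - n0 == 1, "shard dim must map 1:1"
            local[n0] = local_extent
    return tuple(local)


# Global (4, 6) sharded on dim 0 with a local extent of 2 per rank:
# viewing to (4, 2, 3) only splits the unsharded trailing dim, so the
# local view shape is simply (2, 2, 3).
print(sharded_view_shape((4, 6), (4, 2, 3), shard_dim=0, local_extent=2))
```

The hard cases (a view that splits or merges the sharded dimension itself) are what the PR's `_find_shard_in_new_dims` and `_compute_view_placements` handle; this sketch only shows the shape bookkeeping.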

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score. This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

We now explicitly apply both of those operations to make sure they work at both the dispatch and functional levels.
@greptile-apps
Contributor

greptile-apps bot commented Feb 13, 2026

Greptile Summary

This PR implements proper view, reshape, and unsqueeze operations natively in ShardTensor, replacing the previous approach that relied on DTensor's ops and was causing failures in the __torch_function__ and __torch_dispatch__ layers.

Key changes:

  • New view_ops.py module with a complete, locally-computed (no communication) implementation of sharded view/reshape, including dtype-reinterpretation (view(dtype)), differentiable autograd support via ShardedView, and registration at both __torch_function__ and __torch_dispatch__ levels for torch.Tensor.view, torch.Tensor.reshape, torch.reshape, aten.view.default, and aten.reshape.default.
  • unary_ops.py refactored: unsqueeze_wrapper is promoted to a __torch_function__-level handler (matching the pattern used in view_ops), with a separate _unsqueeze_dispatch bridging the __torch_dispatch__ level.
  • shard_tensor.py adds _dispatch_registry_by_name as a name-based fallback lookup for dispatch handlers, guarding against OpOverload identity mismatches across PyTorch versions.
  • Fixes two pre-existing bugs in test_unary_ops.py where assertions compared i_sharded_unsqueeze.shape against itself (a tautology) instead of against i_unsharded_unsqueeze.shape.
  • test_sharded_domino.py is simplified by replacing an inline duplicate of model_params with the canonical DEFAULT_MODEL_PARAMS from physicsnemo.models.domino.config.
  • One misleading comment found in test_view_ops.py: # 4 int64 -> 4 float32 (shape same, bytes 32) — 4 int64 (8 bytes each) = 32 bytes, which produces 8 float32 elements (not 4), and the shape is not preserved (different itemsize path yields 1D). The test body is correct; only the comment is wrong.
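The name-based fallback described above can be sketched roughly as follows. This is an illustrative reconstruction, not the actual `shard_tensor.py` code: handlers are registered under both the op object and its string name, so a lookup still succeeds when the OpOverload object's identity differs across PyTorch versions or import paths.

```python
# Illustrative sketch of a dispatch registry with a name-based fallback.
# OpOverload object identity can differ across PyTorch builds, so handlers
# are also keyed by str(op) as a fallback lookup.

_dispatch_registry = {}          # keyed by the op object itself
_dispatch_registry_by_name = {}  # keyed by str(op), e.g. "aten.view.default"

def register_dispatch_handler(op, handler):
    """Register a handler under both the op object and its string name."""
    _dispatch_registry[op] = handler
    _dispatch_registry_by_name[str(op)] = handler

def lookup_dispatch_handler(op):
    """Look up by object identity first, then fall back to the name."""
    handler = _dispatch_registry.get(op)
    if handler is None:
        handler = _dispatch_registry_by_name.get(str(op))
    return handler
```

The fallback only fires when the identity-based lookup misses, so the common path is unchanged and the change stays backward-compatible.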

Important Files Changed

  • `physicsnemo/domain_parallel/shard_utils/view_ops.py`: New file implementing sharded view/reshape/dtype-view for ShardTensor. Core logic (`_match_view_dim_groups`, `_find_shard_in_new_dims`, `_compute_local_view_shape`, `_compute_view_placements`) is well-structured. Autograd support via `ShardedView.apply`. Registers at both `__torch_function__` and `__torch_dispatch__` levels. Minor: `_sharded_view_dispatch` bypasses autograd wrappers, but it sits at the dispatch level (below autograd), so this is intentional. No critical logic bugs found.
  • `physicsnemo/domain_parallel/shard_utils/unary_ops.py`: Refactors `unsqueeze_wrapper` from the dispatch level to the function level, adds `_unsqueeze_dispatch` as a separate dispatch handler, and registers at both `__torch_function__` and `__torch_dispatch__` levels. The `dim` unpacking from the `__torch_function__` args is correct, and the `dim` default of 0 when absent from both args and kwargs is a safe fallback.
  • `physicsnemo/domain_parallel/shard_tensor.py`: Adds `_dispatch_registry_by_name` as a fallback name-based lookup for dispatch handlers when OpOverload identity differs (e.g. across modules or PyTorch versions), and populates this dict in `register_dispatch_handler`. Small, focused, backward-compatible change.
  • `physicsnemo/domain_parallel/shard_utils/__init__.py`: Adds imports of `reshape_wrapper` and `view_wrapper` from `view_ops` inside `register_shard_wrappers()`, triggering the module-level registrations in `view_ops.py` at call time. Simple, correct change.
  • `test/domain_parallel/ops/test_view_ops.py`: New test file for view/reshape operations. Tests cover merged/split dims, `-1` inference, round-trips, dtype reinterpretation, and sharding on non-viewed dims, with both forward and backward (gradient) checks. One test comment is misleading: the `int64_to_float32` parametrize entry says "4 int64 -> 4 float32 (shape same, bytes 32)" but should say "4 int64 -> 8 float32". The `ndim == 1` assertion in `test_view_dtype` also passes only by coincidence for same-itemsize cases, because the inputs happen to be 1D.
  • `test/domain_parallel/ops/test_unary_ops.py`: Fixes two assertion bugs: both compared `i_sharded_unsqueeze.shape` with itself (a tautology) rather than with `i_unsharded_unsqueeze.shape`. Now correctly compares the sharded and unsharded unsqueeze shapes.
  • `test/domain_parallel/test_function_registration.py`: Registers an `aten.add.default` handler alongside `aten.add.Tensor` for PyTorch version compatibility, and improves the failure message for the `torch_dispatch_paths` assertion. Minor doc fix for the `add_wrapper` comment. No issues.
  • `test/domain_parallel/models/test_sharded_domino.py`: Replaces the inlined `model_params` dataclass with `DEFAULT_MODEL_PARAMS` from the actual DoMINO config, with `processor_type` overridden to "conv" for faster tests. Cleaner and more maintainable; `generate_synthetic_data` now accepts and uses a config object.
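The arithmetic behind the mislabeled `int64_to_float32` comment is easy to check. NumPy is used here as a stand-in for `torch.Tensor.view(dtype)`, since both reinterpret the underlying bytes: 4 int64 values occupy 32 bytes, which reinterpret as 8 float32 values, not 4.

```python
import numpy as np

# 4 int64 elements = 4 * 8 = 32 bytes.
x = np.zeros(4, dtype=np.int64)

# Reinterpreting those 32 bytes as float32 (4 bytes each) yields 8
# elements, so the shape changes from (4,) to (8,).
y = x.view(np.float32)
print(y.shape)  # → (8,)
```

So the test body is right to expect 8 elements; only the comment's "4 float32 (shape same)" wording is wrong.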

Last reviewed commit: fa9c721

Contributor

@greptile-apps greptile-apps bot left a comment


7 files reviewed, 1 comment


@coreyjadams
Collaborator Author

/blossom-ci

Collaborator

@peterdsharpe peterdsharpe left a comment


Overall looks good. A few minor things to consider left in the comments, but not enough to block merge. Nice work, particularly in the DoMINO cleanup.

@coreyjadams
Collaborator Author

/blossom-ci

@coreyjadams coreyjadams added this pull request to the merge queue Feb 14, 2026
@coreyjadams coreyjadams deleted the shard_tensor_view_and_squeeze_fixes-clean branch February 14, 2026 21:55
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to a manual request Feb 14, 2026
@coreyjadams coreyjadams restored the shard_tensor_view_and_squeeze_fixes-clean branch February 17, 2026 20:08
@coreyjadams coreyjadams reopened this Feb 17, 2026
Contributor

@greptile-apps greptile-apps bot left a comment


8 files reviewed, 1 comment


@coreyjadams coreyjadams added this pull request to the merge queue Feb 17, 2026
Merged via the queue into NVIDIA:main with commit 1708820 Feb 17, 2026
7 checks passed