
Fix zero input shape for bgrad_group_quantize#2854

Merged
vthumbe1503 merged 2 commits into NVIDIA:main from vthumbe1503:bug_fix_zero_tensor on Apr 8, 2026

Conversation

@vthumbe1503 (Collaborator)

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 (Collaborator, Author)

/te-ci pytorch

@zhongbozhu (Collaborator) left a comment

LGTM

@greptile-apps (Contributor)

greptile-apps bot commented Apr 8, 2026

Greptile Summary

This PR fixes a crash in bgrad_group_quantize when the input tensor has zero rows by replacing the hard NVTE_CHECK rejection with a graceful early-return path that skips kernel execution and returns a zero-filled dbias. A new test exercises the (0, 1024) input shape to guard against regressions.
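
The early-return contract described above can be illustrated with a hedged, pure-Python sketch (no CUDA or Transformer Engine required). The function name, the `empty_input_buffer` flag, and the list-based stand-ins for tensors are illustrative only; the real extension operates on torch tensors and launches a CUDA kernel on the non-empty path.

```python
def bgrad_group_quantize_sketch(first_dim, last_dim, num_tensors):
    """Sketch of the fixed contract: on empty input, skip the kernel path
    and return an empty grouped output plus a zero-filled dbias."""
    empty_input_buffer = first_dim == 0 or last_dim == 0
    if empty_input_buffer:
        # Early return: no quantization kernel runs; dbias has shape
        # (num_tensors, last_dim) and is all zeros, matching the fix.
        grouped_output = []  # placeholder for the (empty) grouped tensor
        dbias = [[0.0] * last_dim for _ in range(num_tensors)]
        return grouped_output, dbias
    raise NotImplementedError("non-empty path would call the CUDA kernel")

# The shape exercised by the new test: (0, 1024) input, here with 4 tensors.
out, dbias = bgrad_group_quantize_sketch(0, 1024, 4)
```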

Confidence Score: 5/5

Safe to merge — all remaining findings are P2 style/coverage suggestions with no correctness impact.

The fix is minimal and targeted: it replaces a hard error with a graceful zero-return for empty input. The logic is correct, the dbias dtype/device match the input, and a new test covers the fixed case. Only P2 observations remain (early-return placement and incomplete output assertion in the test), neither of which affects correctness.

No files require special attention.

Vulnerabilities

No security concerns identified.

Important Files Changed

  • transformer_engine/pytorch/csrc/extensions/cast.cpp: Removes the NVTE_CHECK that rejected zero-dimension inputs and instead adds an early-return path for empty tensors, returning a zero-filled dbias. The early return is placed after GroupedTensorWrapper construction and create_grouped_tensor, meaning those calls still execute on empty input.
  • tests/pytorch/test_grouped_tensor.py: Adds a new test that passes a (0, 1024) input to bgrad_group_quantize; verifies dbias shape and that all values are zero. Does not validate grouped_output shape/content in the zero-size case.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[bgrad_group_quantize called] --> B[Validate tensor is 2D]
    B --> C[Extract logical_first_dim / logical_last_dim]
    C --> D{empty_input_buffer?\nfirst_dim==0 or last_dim==0}
    D -->|No - original path| E[Check MXFP8 quantizer]
    E --> F[Build GroupedTensorWrapper for input]
    F --> G[create_grouped_tensor → grouped_output_py]
    G --> H[Build GroupedTensorWrapper for dbias]
    H --> I[nvte_group_quantize_dbias workspace query]
    I --> J[nvte_group_quantize_dbias execute]
    J --> K[Return grouped_output_py + dbias_torch]
    D -->|Yes - new early-return path| E2[Check MXFP8 quantizer]
    E2 --> F2[Build GroupedTensorWrapper for input]
    F2 --> G2[create_grouped_tensor → grouped_output_py]
    G2 --> L[Return grouped_output_py +\nzeros dbias shape num_tensors×last_dim]

Reviews (1): Last reviewed commit: "Merge branch 'main' into bug_fix_zero_te..."

Comment on lines +266 to +272
if (empty_input_buffer) {
  at::Tensor dbias_torch =
      at::zeros({static_cast<int64_t>(num_tensors), static_cast<int64_t>(logical_last_dim)},
                tensor.options());
  return py::make_tuple(py::reinterpret_borrow<py::object>(grouped_output_py),
                        py::cast(std::move(dbias_torch)));
}
greptile-apps (Contributor)

P2 Early return placed after expensive setup calls

The empty_input_buffer early return fires only after GroupedTensorWrapper construction, set_rowwise_data, and create_grouped_tensor have already been called on the empty tensor. If any of those operations have hidden allocations or non-trivial paths on zero-size input, the early return provides no protection. Consider moving the guard to immediately after the shape is known (before GroupedTensorWrapper grouped_input_tensor) and constructing a minimal empty grouped_output_py separately, so the fast path truly skips kernel-setup work.
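
The reviewer's suggestion can be sketched in Python for brevity: test the empty-input case as soon as the shape is known, before any setup objects are built. Here `setup_calls` records which setup steps ran; every name is illustrative, not the real C++ API.

```python
def quantize_with_early_guard(shape, num_tensors):
    """Sketch: guard on empty input before any wrapper construction."""
    first_dim, last_dim = shape
    setup_calls = []
    # Guard first: on empty input, none of the setup below executes.
    if first_dim == 0 or last_dim == 0:
        dbias = [[0.0] * last_dim for _ in range(num_tensors)]
        return dbias, setup_calls
    # Non-empty path: wrapper construction and kernel launch happen here.
    setup_calls.append("GroupedTensorWrapper")
    setup_calls.append("create_grouped_tensor")
    setup_calls.append("nvte_group_quantize_dbias")
    dbias = [[1.0] * last_dim for _ in range(num_tensors)]  # stand-in result
    return dbias, setup_calls

# Empty input: the setup list stays empty, proving the fast path skips setup.
dbias, calls = quantize_with_early_guard((0, 1024), 2)
```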

Comment on lines +431 to +432
assert dbias.shape == (num_tensors, last_dim)
assert torch.all(dbias == 0)
greptile-apps (Contributor)

P2 grouped_output not validated in zero-size test

The test verifies dbias shape and values but makes no assertions on grouped_output. At minimum, checking that grouped_output is not None (and ideally that its underlying data tensor has shape (0, last_dim)) would ensure the output contract is correct for the empty path, not just the dbias.
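
The stronger assertions the reviewer suggests can be sketched as follows. A stand-in dictionary mimics the result the real test would get back from bgrad_group_quantize; the names and the `data_shape` key are illustrative, not the actual grouped-tensor API.

```python
num_tensors, last_dim = 4, 1024

# Stand-ins for the values the real test would receive on the empty path.
grouped_output = {"data_shape": (0, last_dim)}
dbias = [[0.0] * last_dim for _ in range(num_tensors)]

# Existing checks: dbias shape and zero values.
assert len(dbias) == num_tensors and len(dbias[0]) == last_dim
assert all(v == 0.0 for row in dbias for v in row)

# Suggested additions: grouped_output must exist and carry a (0, last_dim)
# underlying buffer, so the output contract is verified, not just dbias.
assert grouped_output is not None
assert grouped_output["data_shape"] == (0, last_dim)
```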

@ksivaman (Member) left a comment

LGTM

@vthumbe1503 vthumbe1503 merged commit a30a126 into NVIDIA:main Apr 8, 2026
21 of 24 checks passed


4 participants