Skip to content

Fix FP8 block scaling with sequence parallel#2637

Merged
ksivaman merged 8 commits intoNVIDIA:mainfrom
cuichenx:chcui/fix_subchannel_fp8+sp
Mar 8, 2026
Merged

Fix FP8 block scaling with sequence parallel#2637
ksivaman merged 8 commits intoNVIDIA:mainfrom
cuichenx:chcui/fix_subchannel_fp8+sp

Conversation

@cuichenx
Copy link
Copy Markdown
Contributor

@cuichenx cuichenx commented Jan 31, 2026

Description

Problem

Using Float8BlockQuantizer with sequence parallel fails with AssertionError: All-gather requires quantizable tensor for quantizer Float8BlockQuantizer when local tensor dimensions aren't divisible by 128.

Solution

Skip the assert_dim_for_all_gather check for Float8BlockQuantizer since gather_along_first_dim already has a fallback path
Fix the fallback in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before high-precision all-gather

###Note
The fallback path (high-precision all-gather → quantize) may increase the communication overhead.

Verification

The code change does not alter convergence behavior
image

When SP is True, the previous code did not run. When SP is False, this change doesn't affect anything.
image

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Chen Cui <chcui@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jan 31, 2026

Greptile Summary

This PR successfully fixes FP8 block scaling with sequence parallel by adding fallback handling for non-quantizable tensors. The core changes are sound:

  • Adds high-precision all-gather fallback in _start_all_gather_fp8_blockwise, _all_gather_nvfp4, and _all_gather_mxfp8 for tensors whose dimensions aren't divisible by the quantization block size
  • Fixes .dequantize() calls to explicitly pass dtype=dtype, preserving high-precision types during fallback paths
  • Removes now-redundant assert_dim_for_all_gather checks and function from utils

The PR addresses the reported crash when using Float8BlockQuantizer with sequence parallelism and does not alter convergence behavior as verified in the test results.

Confidence Score: 5/5

  • PR is safe to merge. The fix correctly addresses the reported crash with FP8 block scaling and sequence parallelism through well-scoped fallback mechanisms.
  • All changes are focused and correct. The fallback paths handle non-quantizable tensors appropriately by dequantizing before high-precision all-gather. The explicit dtype parameter ensures correct precision preservation. Module-level cleanup delegating to gather helpers' own fallback logic is sound. No edge cases identified that would cause issues in practice.
  • No files require special attention

Last reviewed commit: ccd46cb

greptile-apps[bot]

This comment was marked as outdated.

@cyanguwa cyanguwa requested a review from timmoon10 February 2, 2026 18:48
@timmoon10

This comment was marked as outdated.

timmoon10

This comment was marked as outdated.

greptile-apps[bot]

This comment was marked as outdated.

@timmoon10 timmoon10 self-requested a review February 2, 2026 19:38
Comment thread transformer_engine/pytorch/utils.py Outdated
Perform all-gather in high-precision if the input tensor is too small to quantize.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

timmoon10
timmoon10 previously approved these changes Feb 2, 2026
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tweaked the module logic so that it no longer checks that the local tensor can be quantized. In addition to FP8 block scaling, MXFP8 and NVFP4 will also support high-precision all-gathers. That said, this may not be fully robust for LayerNormLinear and LayerNormMLP, since the norm kernel may still expect to output to a quantized local tensor.

@timmoon10

This comment was marked as outdated.

@timmoon10
Copy link
Copy Markdown
Collaborator

/te-ci pytorch L1

Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

out = torch.empty(out_shape, dtype=dtype, device=device)
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize() # Dequantize if needed
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dequantize() method defaults to dtype=torch.float32. Consider whether this is always appropriate for the fallback path, especially when the original tensor might have been in a different precision (e.g., bfloat16).

):
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize() # Dequantize if needed
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.

):
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize() # Dequantize if needed
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented Feb 10, 2026

/te-ci pytorch L1

Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines 1102 to 1108
if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize(dtype=dtype) # Dequantize if needed
out = torch.empty(out_shape, dtype=dtype, device=device)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
out = quantizer(out)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-contiguous gather input

In the new high-precision fallback (if not quantizer.is_quantizable(inp) ...), all_gather_into_tensor(out, inp, ...) passes inp directly. Elsewhere in this same module the plain-tensor path uses inp.contiguous() (distributed.py:1737-1742) and the FP8 path uses _data.contiguous() (distributed.py:1031-1035), which strongly suggests the collective expects contiguous inputs. If inp is a non-contiguous view (common after transpose/slicing), this fallback can raise at runtime. This same issue also appears in the NVFP4 and MXFP8 high-precision fallbacks (distributed.py:1353 and :1523).

@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented Feb 12, 2026

/te-ci pytorch

Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize(dtype=dtype) # Dequantize if needed
out = torch.empty(out_shape, dtype=dtype, device=device)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing .contiguous() call on inp before all-gather

Other all-gather paths in this file use .contiguous() (lines 1739, 1033). Non-contiguous tensors (from transpose/slicing) can cause runtime errors.

Suggested change
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group, async_op=False)

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Feb 12, 2026

Additional Comments (2)

transformer_engine/pytorch/distributed.py
Missing .contiguous() call on inp before all-gather

        torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group)

transformer_engine/pytorch/distributed.py
Missing .contiguous() call on inp before all-gather

        torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group)

@ksivaman
Copy link
Copy Markdown
Member

ksivaman commented Mar 7, 2026

/te-ci pytorch

@ksivaman ksivaman merged commit 5fd5c35 into NVIDIA:main Mar 8, 2026
20 of 24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants