Fix FP8 block scaling with sequence parallel #2637
Conversation
Signed-off-by: Chen Cui <chcui@nvidia.com>
Greptile Summary

This PR fixes FP8 block scaling with sequence parallel by adding fallback handling for non-quantizable tensors. The core changes are sound: the PR addresses the reported crash when using Float8BlockQuantizer with sequence parallelism and does not alter convergence behavior, as verified in the test results.

Confidence Score: 5/5
Last reviewed commit: ccd46cb
Perform all-gather in high-precision if the input tensor is too small to quantize. Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 left a comment:
I've tweaked the module logic so that it no longer checks that the local tensor can be quantized. In addition to FP8 block scaling, MXFP8 and NVFP4 will also support high-precision all-gathers. That said, this may not be fully robust for LayerNormLinear and LayerNormMLP, since the norm kernel may still expect to output to a quantized local tensor.
/te-ci pytorch L1
```python
out = torch.empty(out_shape, dtype=dtype, device=device)
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed
```
The dequantize() method defaults to dtype=torch.float32. Consider whether this is always appropriate for the fallback path, especially when the original tensor might have been in a different precision (e.g., bfloat16).
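The precision concern can be illustrated with a minimal mock. Everything below (`FakeQuantizedStorage`, the string dtypes) is a simplified stand-in for the reviewed behavior, not the real Transformer Engine API:

```python
# Hypothetical sketch: a quantized-storage stub whose dequantize()
# defaults to float32, mirroring the reviewer's concern.
class FakeQuantizedStorage:
    def __init__(self, data, orig_dtype):
        self.data = data              # stand-in for the quantized payload
        self.orig_dtype = orig_dtype  # precision the tensor started in

    def dequantize(self, dtype="float32"):
        # The default ignores the tensor's original precision.
        return (self.data, dtype)

inp = FakeQuantizedStorage([1.0, 2.0], orig_dtype="bfloat16")

# Fallback as originally written: silently promotes to float32.
_, got = inp.dequantize()
assert got == "float32"

# Safer fallback: pass the target dtype explicitly.
_, got = inp.dequantize(dtype=inp.orig_dtype)
assert got == "bfloat16"
```

The later revision of the PR addresses exactly this by calling `dequantize(dtype=dtype)`.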
```python
):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()  # Dequantize if needed
```
Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.
```python
):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()  # Dequantize if needed
```
Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
/te-ci pytorch L1
```python
if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
    out = torch.empty(out_shape, dtype=dtype, device=device)
    torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
    out = quantizer(out)
```
Non-contiguous gather input
In the new high-precision fallback (if not quantizer.is_quantizable(inp) ...), all_gather_into_tensor(out, inp, ...) passes inp directly. Elsewhere in this same module the plain-tensor path uses inp.contiguous() (distributed.py:1737-1742) and the FP8 path uses _data.contiguous() (distributed.py:1031-1035), which strongly suggests the collective expects contiguous inputs. If inp is a non-contiguous view (common after transpose/slicing), this fallback can raise at runtime. This same issue also appears in the NVFP4 and MXFP8 high-precision fallbacks (distributed.py:1353 and :1523).
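The contiguity hazard is easy to reproduce outside of distributed code. Here is an analogous NumPy illustration (NumPy shares PyTorch's strided-view semantics; `np.ascontiguousarray` plays the role of `tensor.contiguous()`, though the collective itself is PyTorch-specific):

```python
import numpy as np

a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = a.T                      # transpose is a strided view, not a copy
assert not b.flags["C_CONTIGUOUS"]

c = np.ascontiguousarray(b)  # analogous to tensor.contiguous() in PyTorch
assert c.flags["C_CONTIGUOUS"]
assert (c == b).all()        # same values, compacted memory layout
```

A non-contiguous `inp` reaching `all_gather_into_tensor` without this compaction step is exactly the failure mode the reviewer describes.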
/te-ci pytorch
```python
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
out = torch.empty(out_shape, dtype=dtype, device=device)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
```
Missing .contiguous() call on inp before all-gather
Other all-gather paths in this file use .contiguous() (lines 1739, 1033). Non-contiguous tensors (from transpose/slicing) can cause runtime errors.
```diff
- torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
+ torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group, async_op=False)
```
/te-ci pytorch
Description
Problem

Using Float8BlockQuantizer with sequence parallel fails with

AssertionError: All-gather requires quantizable tensor for quantizer Float8BlockQuantizer

when local tensor dimensions aren't divisible by 128.

Solution

- Skip the assert_dim_for_all_gather check for Float8BlockQuantizer, since gather_along_first_dim already has a fallback path.
- Fix the fallback in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before the high-precision all-gather.
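The fallback control flow can be sketched as follows. The function and class names below mirror the PR's vocabulary but are simplified stand-ins (a stub quantizer and a stub collective), not the actual distributed.py code:

```python
import warnings

def gather_with_fallback(inp, quantizer, all_gather):
    # Sketch of the PR's fallback: if the local shard cannot be quantized
    # (e.g. a dimension not divisible by the 128-element block), gather in
    # high precision and quantize the full tensor afterwards.
    if not quantizer.is_quantizable(inp):
        warnings.warn("Cannot quantize input tensor. "
                      "Performing all-gather in high precision.")
        out = all_gather(inp)            # high-precision collective
        return quantizer(out)            # quantize the gathered result
    return all_gather(quantizer(inp))    # fast path: gather quantized shards

class StubQuantizer:
    block = 128
    def is_quantizable(self, t):
        return len(t) % self.block == 0  # stand-in for the real divisibility check
    def __call__(self, t):
        return ("quantized", t)          # stand-in for actual quantization

gathered = gather_with_fallback(
    list(range(100)),                    # 100 not divisible by 128 -> fallback
    StubQuantizer(),
    all_gather=lambda t: t * 2,          # stub: pretend two ranks concatenate
)
```

The key property is that the full gathered tensor, unlike the local shard, can end up quantizable, so the quantize step moves after the collective.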
Note

The fallback path (high-precision all-gather → quantize) may increase communication overhead.
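As a rough back-of-envelope estimate (assuming a bf16 high-precision path at 2 bytes/element and a 1-byte FP8 payload, ignoring FP8 scale-factor traffic; the shard size is made up for illustration):

```python
# Rough communication-volume comparison for gathering one shard.
# Assumptions (not measured): bf16 = 2 bytes/elem, fp8 payload = 1 byte/elem.
elems = 4096 * 1024          # hypothetical local shard size
fp8_bytes = elems * 1        # quantized all-gather payload
bf16_bytes = elems * 2       # high-precision fallback payload
ratio = bf16_bytes / fp8_bytes
print(ratio)
```

So the fallback roughly doubles the bytes on the wire for the gather, which is why it is only taken when the shard cannot be quantized directly.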
Verification
The code change does not alter convergence behavior

When SP is True, the previous code did not run (it crashed with the assertion above), so there is no prior behavior to compare against. When SP is False, this change doesn't affect anything.
