Mxfp8 grouped and multi quantize#598

Open: alextmagro wants to merge 1 commit into IFU-dev-20260315-v2.14 from mxfp8_grouped_quantize

@alextmagro
Contributor

Introduces grouped and multi quantize kernels for MXFP8. The grouped kernel still needs further optimization; the multi kernel is a stand-in replacement for MoE models.

Improves the quantize kernels via nontemporal stores.
Moves the ROCm quantize kernel body into an .inc file to avoid repeated code; forceinline causes register spills and hurts performance for the grouped/multi kernels.
Branched off of IFU 2.14; the merge target will be updated after the IFU merge.
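
As a rough illustration of the nontemporal-store idea, here is a minimal host-side sketch. It assumes a Clang-family compiler (hipcc is Clang-based) that exposes `__builtin_nontemporal_store`, and falls back to a plain store elsewhere. `stream_store` and `store_row` are hypothetical helper names for this sketch, not the `NTVec::nt_store()` used in this PR.

```cpp
#include <cstddef>
#include <cstdint>

// Detect the Clang builtin; other compilers fall back to a regular store.
#if defined(__has_builtin)
#  if __has_builtin(__builtin_nontemporal_store)
#    define SKETCH_HAS_NT_STORE 1
#  endif
#endif
#ifndef SKETCH_HAS_NT_STORE
#  define SKETCH_HAS_NT_STORE 0
#endif

// Hypothetical helper: write one value that is not expected to be re-read soon.
template <typename T>
inline void stream_store(T value, T* dst) {
#if SKETCH_HAS_NT_STORE
  __builtin_nontemporal_store(value, dst);  // hint that the data has no temporal reuse
#else
  *dst = value;  // portable fallback, same observable result
#endif
}

// Hypothetical helper: copy a write-only output row (e.g. quantized bytes).
inline void store_row(const std::uint8_t* src, std::uint8_t* dst, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    stream_store(src[i], dst + i);
  }
}
```

The stored data is identical either way; only the cache hint differs, which is why the change can be evaluated purely on timing.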

Single Kernel Optimization Results

CAST_ONLY

IType Scaling Shape Baseline (us) Optimized (us) Speedup B200 (us) vs B200
FP32 rowwise 8192×8192 68.43 56.7 1.21x 60.57 107%
BF16 rowwise 8192×8192 41.63 38.0 1.10x 34.84 92%
FP16 rowwise 8192×8192 40.87 38.1 1.07x 34.94 92%
FP32 colwise 8192×8192 76.85 62.4 1.23x 55.77 89%
BF16 colwise 8192×8192 47.88 41.6 1.15x 39.61 95%
FP16 colwise 8192×8192 46.43 41.0 1.13x 39.10 95%
FP32 both 8192×8192 100.76 89.6 1.12x 76.95 86%
BF16 both 8192×8192 86.60 56.6 1.53x 52.34 92%
FP16 both 8192×8192 82.40 56.4 1.46x 51.16 91%
FP32 rowwise 16384×8192 132.64 129.0 1.03x 114.42 89%
BF16 rowwise 16384×8192 83.00 74.0 1.12x 66.94 90%
FP16 rowwise 16384×8192 84.04 72.4 1.16x 67.58 93%
FP32 colwise 16384×8192 132.60 125.0 1.06x 103.54 83%
BF16 colwise 16384×8192 91.24 77.0 1.18x 72.57 94%
FP16 colwise 16384×8192 89.20 76.7 1.16x 71.86 94%
FP32 both 16384×8192 175.92 173.0 1.02x 146.52 85%
BF16 both 16384×8192 147.40 122.0 1.21x 98.68 81%
FP16 both 16384×8192 140.16 119.0 1.18x 96.44 81%
FP32 rowwise 8192×28672 247.40 226.0 1.09x 194.30 86%
BF16 rowwise 8192×28672 145.28 140.0 1.04x 111.58 80%
FP16 rowwise 8192×28672 145.92 138.0 1.06x 111.67 81%
FP32 colwise 8192×28672 231.40 216.0 1.07x 175.54 81%
BF16 colwise 8192×28672 152.92 138.0 1.11x 122.26 89%
FP16 colwise 8192×28672 148.40 141.0 1.05x 120.82 86%
FP32 both 8192×28672 302.88 307.0 0.99x 251.93 82%
BF16 both 8192×28672 244.68 200.0 1.22x 167.42 84%
FP16 both 8192×28672 229.12 201.0 1.14x 163.64 81%
Average 1.14x 88%
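
The derived columns appear to follow directly from the timing columns: Speedup matches baseline time divided by optimized time, and "vs B200" matches B200 time divided by optimized time, so values above 100% mean the optimized kernel finished sooner than the B200 measurement. The sketch below reproduces that reading; the function names are illustrative only.

```cpp
#include <cmath>

// Speedup of the optimized kernel over the baseline (times in microseconds).
inline double speedup(double baseline_us, double optimized_us) {
  return baseline_us / optimized_us;
}

// Relative performance against a reference device, in percent.
// Above 100 means the optimized kernel finished sooner than the reference.
inline double percent_of_reference(double reference_us, double optimized_us) {
  return 100.0 * reference_us / optimized_us;
}

// Round to a fixed number of decimals, as the tables do.
inline double round_to(double value, int decimals) {
  const double scale = std::pow(10.0, decimals);
  return std::round(value * scale) / scale;
}
```

For the FP32 rowwise 8192×8192 row, `speedup(68.43, 56.7)` rounds to 1.21 and `percent_of_reference(60.57, 56.7)` rounds to 107, matching the table.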

DBIAS_DACT (GELU)

IType Scaling Shape Baseline (us) Optimized (us) Speedup B200 (us) vs B200
FP32 rowwise 8192×8192 149.04 147.0 1.01x 145.24 99%
BF16 rowwise 8192×8192 156.98 136.0 1.15x 128.79 95%
FP16 rowwise 8192×8192 139.99 126.0 1.11x 123.77 98%
FP32 colwise 8192×8192 190.28 194.0 0.98x 134.90 70%
BF16 colwise 8192×8192 226.50 170.0 1.33x 126.33 74%
FP16 colwise 8192×8192 178.74 161.0 1.11x 126.01 78%
FP32 both 8192×8192 274.71 202.0 1.36x 188.57 93%
BF16 both 8192×8192 275.38 163.0 1.69x 190.58 117%
FP16 both 8192×8192 270.46 148.0 1.83x 175.13 118%
FP32 rowwise 16384×8192 265.77 276.0 0.96x 280.86 102%
BF16 rowwise 16384×8192 289.15 279.0 1.04x 254.04 91%
FP16 rowwise 16384×8192 258.33 258.0 1.00x 235.61 91%
FP32 colwise 16384×8192 347.10 373.0 0.93x 258.49 69%
BF16 colwise 16384×8192 409.64 324.0 1.26x 247.77 76%
FP16 colwise 16384×8192 333.00 314.0 1.06x 247.86 79%
FP32 both 16384×8192 496.17 396.0 1.25x 363.93 92%
BF16 both 16384×8192 497.76 315.0 1.58x 371.96 118%
FP16 both 16384×8192 500.30 282.0 1.77x 358.68 127%
FP32 rowwise 8192×28672 475.55 447.0 1.06x 478.84 107%
BF16 rowwise 8192×28672 489.49 437.0 1.12x 433.14 99%
FP16 rowwise 8192×28672 432.49 392.0 1.10x 404.76 103%
FP32 colwise 8192×28672 610.30 603.0 1.01x 438.36 73%
BF16 colwise 8192×28672 694.42 509.0 1.36x 428.98 84%
FP16 colwise 8192×28672 578.20 504.0 1.15x 431.09 86%
FP32 both 8192×28672 860.63 658.0 1.31x 620.82 94%
BF16 both 8192×28672 873.07 504.0 1.73x 621.24 123%
FP16 both 8192×28672 854.57 448.0 1.91x 613.88 137%
Average 1.27x 96%

Gated SwiGLU (FWD)

IType Scaling Shape Baseline (us) Optimized (us) Speedup B200 (us) vs B200
FP32 rowwise 8192×28672 482.73 435.0 1.11x 477.94 110%
BF16 rowwise 8192×28672 302.57 290.0 1.04x 366.04 126%
FP16 rowwise 8192×28672 288.29 277.0 1.04x 361.94 131%
FP32 colwise 8192×28672 469.78 421.0 1.12x 380.53 90%
BF16 colwise 8192×28672 313.24 302.0 1.04x 338.78 112%
FP16 colwise 8192×28672 299.32 300.0 1.00x 337.62 113%
FP32 both 8192×28672 551.40 528.0 1.04x 733.97 139%
BF16 both 8192×28672 420.99 413.0 1.02x 534.10 129%
FP16 both 8192×28672 426.77 414.0 1.03x 537.11 130%
Average 1.05x 120%

Gated SwiGLU (BWD / dSwiGLU)

IType Scaling Shape Baseline (us) Optimized (us) Speedup B200 (us) vs B200
FP32 rowwise 8192×28672 753.29 699.0 1.08x 913.30 131%
BF16 rowwise 8192×28672 467.38 463.0 1.01x 569.05 123%
FP16 rowwise 8192×28672 470.34 436.0 1.08x 547.19 126%
FP32 colwise 8192×28672 717.36 692.0 1.04x 701.78 101%
BF16 colwise 8192×28672 647.80 655.0 0.99x 489.01 75%
FP16 colwise 8192×28672 472.62 541.0 0.87x 472.18 87%
FP32 both 8192×28672 1093.58 930.0 1.18x 1342.00 144%
BF16 both 8192×28672 919.08 798.0 1.15x 969.56 121%
FP16 both 8192×28672 878.03 736.0 1.19x 928.88 126%
Average 1.07x 115%

Grouped Quantize Kernel

Rowwise

Config Distribution Per-Tensor (us) Grouped (us) Speedup BW (TiB/s) % Peak
Qwen3 128exp H=4096 balanced 857 250 3.43x 3.31 45%
Qwen3 128exp I=1536 balanced 615 96 6.41x 3.21 44%
DeepSeek 256exp H=7168 balanced 1144 822 1.39x 3.55 49%
DeepSeek 256exp I=2048 balanced 1112 248 4.48x 3.37 46%
Qwen3 128exp H=4096 skewed 576 248 2.32x 3.31 45%
Qwen3 128exp I=1536 skewed 585 97 6.03x 3.21 44%
DeepSeek 256exp H=7168 skewed 1132 821 1.38x 3.54 49%
DeepSeek 256exp I=2048 skewed 1124 245 4.59x 3.36 46%
Average 3.75x 3.36 46%
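
The BW and % Peak columns can be read as effective memory throughput. A rough sketch of that accounting follows. It assumes the input is read once, each element is written as one FP8 byte per enabled direction, and one E8M0 scale byte is written per 32-element block. The peak bandwidth is a caller-supplied assumption, and the names are illustrative. The token counts behind the grouped rows are not stated in this PR, so the usage note uses a shape from the single-kernel tables instead.

```cpp
#include <cstddef>

struct QuantizeTraffic {
  double read_bytes;
  double write_bytes;
};

// Assumed byte accounting for one MXFP8 quantize of a rows x cols tensor.
// directions is 1 for rowwise-only or colwise-only, 2 for both.
inline QuantizeTraffic mxfp8_traffic(std::size_t rows, std::size_t cols,
                                     std::size_t input_elem_bytes, int directions) {
  const double elems = static_cast<double>(rows) * static_cast<double>(cols);
  const double fp8_bytes = elems;           // 1 byte per FP8 element
  const double scale_bytes = elems / 32.0;  // 1 E8M0 byte per 32 elements
  return {elems * input_elem_bytes, directions * (fp8_bytes + scale_bytes)};
}

// Effective throughput in TiB/s for a kernel time given in microseconds.
inline double tib_per_second(const QuantizeTraffic& t, double time_us) {
  const double bytes = t.read_bytes + t.write_bytes;
  return bytes / (time_us * 1e-6) / (1024.0 * 1024.0 * 1024.0 * 1024.0);
}

// Share of an assumed peak bandwidth, in percent.
inline double percent_of_peak(double achieved_tib_s, double peak_tib_s) {
  return 100.0 * achieved_tib_s / peak_tib_s;
}
```

Under this accounting, an FP32 rowwise 8192×8192 cast at 56.7 us works out to about 5.4 TiB/s. The PR's own BW figures may use a different byte count.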

Colwise

Config Distribution Per-Tensor (us) Grouped (us) Speedup BW (TiB/s) % Peak
Qwen3 128exp H=4096 balanced 587 265 2.22x 3.10 43%
Qwen3 128exp I=1536 balanced 591 104 5.68x 2.97 41%
DeepSeek 256exp H=7168 balanced 1134 878 1.29x 3.31 45%
DeepSeek 256exp I=2048 balanced 1108 268 4.13x 3.11 43%
Qwen3 128exp H=4096 skewed 590 266 2.22x 3.10 43%
Qwen3 128exp I=1536 skewed 423 104 4.07x 2.98 41%
DeepSeek 256exp H=7168 skewed 965 873 1.11x 3.31 45%
DeepSeek 256exp I=2048 skewed 816 266 3.07x 3.10 43%
Average 2.97x 3.12 43%

Both (Rowwise + Colwise)

Config Distribution Per-Tensor (us) Grouped (us) Speedup BW (TiB/s) % Peak
Qwen3 128exp H=4096 balanced 591 270 2.19x 4.07 56%
Qwen3 128exp I=1536 balanced 426 105 4.06x 3.97 55%
DeepSeek 256exp H=7168 balanced 1582 855 1.85x 4.56 63%
DeepSeek 256exp I=2048 balanced 821 276 2.97x 4.04 56%
Qwen3 128exp H=4096 skewed 652 268 2.43x 4.13 57%
Qwen3 128exp I=1536 skewed 424 104 4.08x 3.97 55%
DeepSeek 256exp H=7168 skewed 1595 849 1.88x 4.56 63%
DeepSeek 256exp I=2048 skewed 819 275 2.98x 4.05 56%
Average 2.80x 4.17 57%

Multi Quantize Kernel

Rowwise

Config Distribution Per-Tensor (us) Multi (us) Speedup BW (TiB/s) % Peak
Qwen3 128exp H=4096 balanced 857 226 3.79x 3.71 51%
Qwen3 128exp I=1536 balanced 615 72 8.54x 4.34 60%
DeepSeek 256exp H=7168 balanced 1144 781 1.46x 3.70 51%
DeepSeek 256exp I=2048 balanced 1112 212 5.25x 3.89 53%
Qwen3 128exp H=4096 skewed 576 263 2.19x 3.14 43%
Qwen3 128exp I=1536 skewed 585 92 6.36x 3.38 46%
DeepSeek 256exp H=7168 skewed 1132 870 1.30x 3.31 45%
DeepSeek 256exp I=2048 skewed 1124 240 4.68x 3.44 47%
Average 4.20x 3.61 50%

Colwise

Config Distribution Per-Tensor (us) Multi (us) Speedup BW (TiB/s) % Peak
Qwen3 128exp H=4096 balanced 587 200 2.94x 4.16 57%
Qwen3 128exp I=1536 balanced 591 79 7.48x 3.95 54%
DeepSeek 256exp H=7168 balanced 1134 695 1.63x 4.16 57%
DeepSeek 256exp I=2048 balanced 1108 207 5.35x 4.01 55%
Qwen3 128exp H=4096 skewed 590 241 2.45x 3.48 48%
Qwen3 128exp I=1536 skewed 423 99 4.27x 3.16 43%
DeepSeek 256exp H=7168 skewed 965 831 1.16x 3.47 48%
DeepSeek 256exp I=2048 skewed 816 236 3.46x 3.51 48%
Average 3.59x 3.74 51%

Both (Rowwise + Colwise)

Config Distribution Per-Tensor (us) Multi (us) Speedup BW (TiB/s) % Peak
Qwen3 128exp H=4096 balanced 591 317 1.86x 3.54 49%
Qwen3 128exp I=1536 balanced 426 108 3.94x 3.81 52%
DeepSeek 256exp H=7168 balanced 1582 1111 1.42x 3.50 48%
DeepSeek 256exp I=2048 balanced 821 316 2.60x 3.52 48%
Qwen3 128exp H=4096 skewed 652 333 1.96x 3.32 46%
Qwen3 128exp I=1536 skewed 424 133 3.19x 3.10 43%
DeepSeek 256exp H=7168 skewed 1595 1162 1.37x 3.34 46%
DeepSeek 256exp I=2048 skewed 819 329 2.49x 3.38 46%
Average 2.35x 3.44 47%

Claude change summary

What Changed

  • Replaced cached Vec::store_to() with NTVec::nt_store() via zero-copy reinterpret_cast for all rowwise output paths
  • bulk_tensor_2d_shared_to_global now uses NT stores for all colwise output paths
  • Applied to both rocm_quantize_mxfp8.cuh and rocm_gated_mxfp8.cuh
  • CDNA4 L2 is write-allocate: cached stores trigger read-for-ownership for write-only streams. NT stores bypass this.
  • NT loads were tested and rejected (20% to 45% regression on the rowwise path due to L2 prefetcher bypass)
  • NT stores for E8M0 scales were tested and rejected (+17.2% average regression on CAST_ONLY). Scales are small (~3% of total I/O) and interspersed with compute — NT stores disrupt write coalescing without meaningful bandwidth savings. Colwise-only was neutral (+0.5%), but rowwise (+25-37%) and both (+21-38%) regressed heavily.
  • Double-buffered LDS (ping-pong in_sh[2], out_colwise_sh[2]) was tested and rejected (+21.7% average regression on CAST_ONLY) — without async global→LDS copies (TDM/TMA), the 2x LDS footprint kills occupancy without providing true load/compute overlap; revisit on MI450 with TDM.
  • Activation caching in bidimensional path: rowwise pass writes activated+IType-rounded values back to in_sh, colwise pass skips IS_ACT/IS_DACT recomputation and act_in_sh reads. -29.8% average on DBIAS_DACT/both (FP16: -44%, BF16: -31%, FP32: -16%). Zero overhead for non-activation paths (constexpr guard).
  • Same activation caching applied to gated kernel (rocm_gated_mxfp8.cuh). Fixed the BF16/FP16 both-mode regression: FWD both BF16 485→413us (-15%), BWD both BF16 1019→798us (-22%). All gated both-mode configs now beat B200.
  • Column-major tile ordering for grouped kernel was tested and rejected — helps colwise (-2 to -7%) but regresses both mode (+7 to +29%) due to hurting L2 locality for the rowwise section.
  • Multi-tensor kernel (nvte_multi_quantize_mxfp8) added for the per-tensor pointer API used by Megatron's split_quantize. It uses a MultiQuantizeMXFP8Args struct with a prefix-sum block_range and a 2D grid (blockIdx.x = col tiles, blockIdx.y = row tiles with binary search). The kernel body is shared with the single and grouped kernels via .inc file inclusion (the compiler refused to inline the 500-line __device__ __forceinline__ function: 5 VGPRs). Averaged over the balanced and skewed configs: 4.20x rowwise, 3.59x colwise, 2.35x both vs the per-tensor baseline. "Both" mode is slower than grouped (2.35x vs 2.80x) because scattered per-tensor allocations hurt L2 locality, which is inherent to the per-tensor pointer API as opposed to grouped's contiguous buffer.
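
The prefix-sum block_range lookup described for the multi-tensor kernel can be sketched on the host as follows. This is an illustrative reconstruction, not the PR's code: the 64-row tile height, the `build_block_range` and `find_tensor` names, and the `TileOwner` struct are assumptions. Only the idea comes from the summary above: a prefix sum of per-tensor row-tile counts, then a binary search on the row-tile index (the role blockIdx.y plays in the kernel).

```cpp
#include <cstddef>
#include <vector>

// Assumed tile height in rows; the real kernel's tile shape is not given here.
constexpr std::size_t kTileRows = 64;

// block_range[i] is the first row-tile index owned by tensor i;
// block_range.back() is the total number of row tiles (exclusive prefix sum).
inline std::vector<std::size_t> build_block_range(
    const std::vector<std::size_t>& rows_per_tensor) {
  std::vector<std::size_t> block_range(rows_per_tensor.size() + 1, 0);
  for (std::size_t i = 0; i < rows_per_tensor.size(); ++i) {
    const std::size_t tiles = (rows_per_tensor[i] + kTileRows - 1) / kTileRows;
    block_range[i + 1] = block_range[i] + tiles;
  }
  return block_range;
}

struct TileOwner {
  std::size_t tensor_id;   // which tensor this row tile belongs to
  std::size_t local_tile;  // row-tile index inside that tensor
};

// Binary search for the last i with block_range[i] <= tile_y.
// Requires tile_y < block_range.back(). Tensors with zero rows own an
// empty range and are therefore never returned.
inline TileOwner find_tensor(const std::vector<std::size_t>& block_range,
                             std::size_t tile_y) {
  std::size_t lo = 0;
  std::size_t hi = block_range.size() - 1;  // invariant: answer lies in [lo, hi)
  while (hi - lo > 1) {
    const std::size_t mid = lo + (hi - lo) / 2;
    if (block_range[mid] <= tile_y) {
      lo = mid;
    } else {
      hi = mid;
    }
  }
  return {lo, tile_y - block_range[lo]};
}
```

With rows {128, 0, 100}, the ranges are {0, 2, 2, 4}, so row tile 2 maps to tensor 2 and the empty tensor 1 is skipped. The lookup costs O(log num_tensors) per block, which stays cheap even with skewed token distributions.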

Enable group_quantize for AMD and add multi tensor quantize kernel for mxfp8
@alextmagro alextmagro force-pushed the mxfp8_grouped_quantize branch from fe5fd03 to bcfc909 Compare May 27, 2026 21:44
@alextmagro alextmagro added the ci-level 3 CI test level 3 label May 27, 2026
const int block_id_Y = blockIdx.y;
const int block_id_X = blockIdx.x;
const int dbias_y_offset = blockIdx.y;
#include "rocm_quantize_mxfp8_body.inc"
Collaborator

Can it be an inline function instead?

* \param[in] stream CUDA stream used for the operation.
*/
void nvte_multi_quantize_mxfp8(size_t num_tensors, const NVTETensor *input_list,
NVTETensor *output_list, cudaStream_t stream);
Collaborator

There is already nvte_multi_tensor_quantize() here, so it should probably be used instead of a separate call.

Labels

ci-level 3 CI test level 3
