-
Notifications
You must be signed in to change notification settings - Fork 67
Have the pointwise scheduler process the Block Quantization Op. #5322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5805ff6
9a89804
6e3fd37
ad7b3db
ba21a4c
1a94cdc
138342b
8d8c115
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,8 +10,6 @@ | |||||||
|
|
||||||||
| #include <gmock/gmock-matchers.h> | ||||||||
| #include <gtest/gtest.h> | ||||||||
| #include <iomanip> | ||||||||
| #include <iostream> | ||||||||
|
|
||||||||
| #include <fusion.h> | ||||||||
| #include <ops/all_ops.h> | ||||||||
|
|
@@ -110,7 +108,10 @@ constexpr double F8E4M3_MAX = 448.0; | |||||||
| class NVFP4QuantizeTest : public BlackwellBase, | ||||||||
| public ::testing::WithParamInterface<DataType> {}; | ||||||||
| namespace { | ||||||||
| void createNVFP4QunatizationFusion(Fusion* fusion, DataType data_hp_dtype) { | ||||||||
| void createNVFP4QunatizationFusion( | ||||||||
| Fusion* fusion, | ||||||||
| DataType data_hp_dtype, | ||||||||
| bool swizzle_output = false) { | ||||||||
| auto tv_data_hp = makeContigTensor(2, data_hp_dtype); | ||||||||
| fusion->addInput(tv_data_hp); | ||||||||
|
|
||||||||
|
|
@@ -145,6 +146,28 @@ void createNVFP4QunatizationFusion(Fusion* fusion, DataType data_hp_dtype) { | |||||||
|
|
||||||||
| fusion->addOutput(tv_block_scale_fp8); | ||||||||
| fusion->addOutput(tv_data_lp); | ||||||||
|
|
||||||||
| if (swizzle_output) { | ||||||||
| tv_block_scale_fp8->split(0, 128); | ||||||||
| // m/128, 128, k | ||||||||
| tv_block_scale_fp8->split(1, 32); | ||||||||
| // m/128, 4(m_o), 32(m_i), k | ||||||||
| tv_block_scale_fp8->split(3, 4); | ||||||||
| // m/128, 4(m_o), 32(m_i), k/4, 4(k) | ||||||||
| std::vector<IterDomain*> tv_block_scale_fp8_alloc{ | ||||||||
| tv_block_scale_fp8->axis(0), | ||||||||
| tv_block_scale_fp8->axis(3), | ||||||||
| tv_block_scale_fp8->axis(2), | ||||||||
| tv_block_scale_fp8->axis(1), | ||||||||
| tv_block_scale_fp8->axis(4)}; | ||||||||
| // m/128, k/4, 32(m_i), 4(m_o), 4(k) | ||||||||
| tv_block_scale_fp8->setAllocationDomain(tv_block_scale_fp8_alloc, true); | ||||||||
|
|
||||||||
| // back to a 2D logical domain. | ||||||||
| tv_block_scale_fp8->merge(0); | ||||||||
| tv_block_scale_fp8->merge(0); | ||||||||
| tv_block_scale_fp8->merge(-1); | ||||||||
| } | ||||||||
| } | ||||||||
| } // namespace | ||||||||
|
|
||||||||
|
|
@@ -414,6 +437,137 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { | |||||||
| EXPECT_EQ(quantized_tensor_output.dim(), 3); | ||||||||
| } | ||||||||
|
|
||||||||
| TEST_F(BQTest, AutoScheduleSingleOpWithSwizzle) { | ||||||||
| const int m = 1024; | ||||||||
| const int n = 1024; | ||||||||
|
|
||||||||
| std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>(); | ||||||||
| FusionGuard fg(fusion.get()); | ||||||||
| createNVFP4QunatizationFusion( | ||||||||
| fusion.get(), DataType::Float, /*swizzled*/ true); | ||||||||
|
|
||||||||
| FusionExecutorCache fec(std::move(fusion)); | ||||||||
|
|
||||||||
| std::vector<at::Tensor> inputs; | ||||||||
| inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat))); | ||||||||
| auto outputs_baseline = fec.runFusionWithInputs(inputs); | ||||||||
|
|
||||||||
| // Print baseline outputs | ||||||||
| auto baseline_block_scales = outputs_baseline[0].as<at::Tensor>(); | ||||||||
| auto baseline_quantized_tensor = outputs_baseline[1].as<at::Tensor>(); | ||||||||
|
|
||||||||
| // Move baseline tensors from GPU to CPU | ||||||||
| auto baseline_block_scales_cpu = baseline_block_scales.cpu(); | ||||||||
| auto baseline_quantized_tensor_cpu = baseline_quantized_tensor.cpu(); | ||||||||
|
|
||||||||
| const uint8_t* baseline_block_scales_data = | ||||||||
| static_cast<const uint8_t*>(baseline_block_scales_cpu.data_ptr()); | ||||||||
| const uint8_t* baseline_quantized_data = | ||||||||
| static_cast<const uint8_t*>(baseline_quantized_tensor_cpu.data_ptr()); | ||||||||
|
|
||||||||
| std::unique_ptr<Fusion> fusion_new_op = std::make_unique<Fusion>(); | ||||||||
| FusionGuard fg2(fusion_new_op.get()); | ||||||||
|
|
||||||||
| auto tv_in_1 = makeContigTensor(2, DataType::Float); | ||||||||
| fusion_new_op->addInput(tv_in_1); | ||||||||
|
|
||||||||
| // t0 is 2D | ||||||||
| auto quantization_results = blockQuantize(tv_in_1); | ||||||||
|
|
||||||||
| // outputs are 3D | ||||||||
| fusion_new_op->addOutput(quantization_results.block_scales); | ||||||||
| fusion_new_op->addOutput(quantization_results.quantized_tensor); | ||||||||
|
|
||||||||
| auto temp_loop_domain = quantization_results.block_scales->getLoopDomain(); | ||||||||
|
|
||||||||
| quantization_results.block_scales->split(0, 128); | ||||||||
| // m/128, 128, k | ||||||||
| quantization_results.block_scales->split(1, 32); | ||||||||
| // m/128, 4(m_o), 32(m_i), k | ||||||||
| quantization_results.block_scales->split(3, 4); | ||||||||
| // m/128, 4(m_o), 32(m_i), k/4, 4(k) | ||||||||
| std::vector<IterDomain*> tv_block_scale_fp8_alloc{ | ||||||||
| quantization_results.block_scales->axis(0), | ||||||||
| quantization_results.block_scales->axis(3), | ||||||||
| quantization_results.block_scales->axis(2), | ||||||||
| quantization_results.block_scales->axis(1), | ||||||||
| quantization_results.block_scales->axis(4), | ||||||||
| quantization_results.block_scales->axis(5)}; | ||||||||
|
Comment on lines
+494
to
+495
|
||||||||
| quantization_results.block_scales->axis(4), | |
| quantization_results.block_scales->axis(5)}; | |
| quantization_results.block_scales->axis(4)}; |
Copilot
AI
Oct 7, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove commented-out code that is not being used.
| // auto t_out = set(quantization_results.quantized_tensor); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Corrected spelling of 'Qunatization' to 'Quantization'.