18 changes: 18 additions & 0 deletions csrc/scheduler/pointwise.cpp
@@ -806,6 +806,17 @@ bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) {
return false;
}

// The block scales output of the Block Quantization Op
// should be a segment output, as it is written directly to
// global memory.
if (registry_utils::hasNonTerminalBlockQuantizeOp(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(),
"no support for block quantization where block scales is not a fusion "
"output");
return false;
}

return true;
}
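For context, the rejected pattern looks like the following minimal sketch. Hedged: it assumes the blockQuantize helper used in tests/cpp/test_low_precision_recipe.cpp below, and the castOp consumer is purely illustrative.

  // Block scales feed a downstream op instead of being a fusion output,
  // so hasNonTerminalBlockQuantizeOp() returns true and the pointwise
  // scheduler rejects the fusion at compile time.
  Fusion fusion;
  FusionGuard fg(&fusion);
  auto* tv_in = makeContigTensor(2, DataType::Float);
  fusion.addInput(tv_in);
  auto q = blockQuantize(tv_in);
  auto* scales_f32 = castOp(DataType::Float, q.block_scales);
  fusion.addOutput(scales_f32); // a derived value, not the raw block scales
  fusion.addOutput(q.quantized_tensor);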

@@ -1234,6 +1245,13 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
}
}
}

auto bq_ops = ir_utils::getOpsOfType<BlockQuantizationOp>(fusion);
for (auto bq_op : bq_ops) {
vectorized_tvs.emplace_back(bq_op->quantizedOutput()->as<TensorView>());
vectorized_tvs.emplace_back(bq_op->blockScales()->as<TensorView>());
}

if (!vectorized_tvs.empty()) {
// Aggressively mark with vectorized and cleanup later. That way we
// don't have to manually specify parallelization outside the reference.
13 changes: 13 additions & 0 deletions csrc/scheduler/registry_utils.cpp
@@ -561,6 +561,19 @@ PrimDataType getIndexTypeOfKernel(
return PrimDataType::Int32;
}

bool hasNonTerminalBlockQuantizeOp(Fusion* fusion) {
for (auto expr : fusion->exprs()) {
if (expr->isA<BlockQuantizationOp>()) {
auto block_scales =
expr->as<BlockQuantizationOp>()->blockScales()->as<TensorView>();
if (!block_scales->isFusionOutput()) {
return true;
}
}
}
return false;
}

bool SchedulerTopologyChecker::hasNonNormalizePostReductionBCast(
Fusion* fusion) {
auto all_vals = fusion->usedMathVals();
4 changes: 4 additions & 0 deletions csrc/scheduler/registry_utils.h
@@ -31,6 +31,10 @@ bool checkPatternEquivalence(
// hard to optimize problem and likely indicates we shouldn't be fusing.
bool hasNonUniqueBcast(Fusion* fusion);

// Returns true if the fusion contains a Block Quantization Op whose
// block scales output is not a segment output.
bool hasNonTerminalBlockQuantizeOp(Fusion* fusion);

// TODO: remove this requirement entirely
bool rejectScheduleForMemoryPromotion(
Fusion* fusion,
16 changes: 14 additions & 2 deletions csrc/scheduler/utils.cpp
@@ -1331,6 +1331,17 @@ std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
return cached_inputs;
}

namespace {
bool isBlockScaleOutput(TensorView* tv) {
if (tv->definition() == nullptr ||
!tv->definition()->isA<BlockQuantizationOp>()) {
return false;
}
auto bq_op = tv->definition()->as<BlockQuantizationOp>();
return bq_op->blockScales() == tv;
}
} // namespace

// Returns the pairs of <cache of each fusion output, corresponding output> for
// all outputs.
std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
@@ -1341,8 +1352,9 @@ std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
for (auto output : ir_utils::filterByType<TensorView>(fusion->outputs())) {
if (output->definition() == nullptr ||
// the output of ScatterOp must on the global memory due to the random
// or atomic access.
output->definition()->isA<ScatterOp>()) {
// or atomic access. We write back the block scale output of block
// scaling to global memory.
output->definition()->isA<ScatterOp>() || isBlockScaleOutput(output)) {
continue;
}
if (!output->uses().empty()) {
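Reviewer note: a hedged sketch of the intended effect (isBlockScaleOutput is file-local, so the check below is illustrative rather than a callable API):

  // After cacheAndForkOutputs(), no <cache, output> pair is created for a
  // block scales output; like a ScatterOp output, it stays bound to global
  // memory without a cache TensorView.
  auto cached = scheduler_utils::cacheAndForkOutputs(fusion, /*unroll=*/true);
  for (auto& [cache_tv, output_tv] : cached) {
    Expr* def = output_tv->definition();
    NVF_ERROR(
        def == nullptr || !def->isA<BlockQuantizationOp>() ||
        def->as<BlockQuantizationOp>()->blockScales() != output_tv);
  }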
22 changes: 21 additions & 1 deletion csrc/scheduler/vectorize_helper.cpp
@@ -51,6 +51,15 @@ Val* commonOrConstExtent(
return ca_map->getConcreteMappedID(id, IdMappingMode::ALMOSTEXACT)->extent();
}

// Is the TV the block scales output of quantization.
bool isTvBlockScalesOutputOfBlockQuantization(const TensorView* tv) {
if (!tv->definition() || !tv->definition()->isA<BlockQuantizationOp>()) {
return false;
}
auto bq_op = tv->definition()->as<BlockQuantizationOp>();
return bq_op->blockScales() == tv;
}

} // namespace

Val* ContiguousInnerDimensionsMapper::isFullyProjected(IterDomain* id) {
@@ -807,8 +816,19 @@ Val* ContiguousInnerDimensionsMapper::getContigMergeOfInnerSize(
{alloc_iid});
IterDomain* logical_id = alloc_iid;
Val* num_devices = of_tv->container()->oneVal();
auto is_block_scales_output =
isTvBlockScalesOutputOfBlockQuantization(of_tv);
for (Expr* expr : exprs | std::views::reverse) {
validateDeviceSplit(expr);
if (is_block_scales_output) {
NVF_ERROR(
expr->isA<Split>(),
"alloc domain of block quantization should only have splits");
if (!expr->as<Split>()->outer()->isDeviceDim()) {
continue;
}
} else {
validateDeviceSplit(expr);
}
auto* split = expr->as<Split>();
logical_id = split->in();
num_devices = SimplifyingIrBuilder::mulExpr(num_devices, split->factor());
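Reviewer note: a hedged mirror of the new branch above for a block scales output, assuming (per the NVF_ERROR) that every expr between the logical and allocation domain is a Split, as in the swizzled layouts built in the tests below. deviceFactor is a hypothetical name.

  // Walk the transforms from the allocation domain back toward the logical
  // domain; only a split whose outer output is a device dim shards the
  // tensor, so the swizzle splits are skipped.
  Val* deviceFactor(const std::vector<Expr*>& exprs, Val* num_devices) {
    for (Expr* expr : exprs | std::views::reverse) {
      NVF_ERROR(expr->isA<Split>(), "expected only splits");
      auto* split = expr->as<Split>();
      if (!split->outer()->isDeviceDim()) {
        continue; // swizzle split: contributes no device sharding
      }
      num_devices = SimplifyingIrBuilder::mulExpr(num_devices, split->factor());
    }
    return num_devices;
  }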
160 changes: 157 additions & 3 deletions tests/cpp/test_low_precision_recipe.cpp
@@ -10,8 +10,6 @@

#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <iomanip>
#include <iostream>

#include <fusion.h>
#include <ops/all_ops.h>
@@ -110,7 +108,10 @@ constexpr double F8E4M3_MAX = 448.0;
class NVFP4QuantizeTest : public BlackwellBase,
public ::testing::WithParamInterface<DataType> {};
namespace {
void createNVFP4QunatizationFusion(Fusion* fusion, DataType data_hp_dtype) {
void createNVFP4QunatizationFusion(
Copilot AI Oct 7, 2025
Corrected spelling of 'Qunatization' to 'Quantization'.
Suggested change:
void createNVFP4QunatizationFusion(
void createNVFP4QuantizationFusion(
Fusion* fusion,
DataType data_hp_dtype,
bool swizzle_output = false) {
auto tv_data_hp = makeContigTensor(2, data_hp_dtype);
fusion->addInput(tv_data_hp);

@@ -145,6 +146,28 @@ void createNVFP4QunatizationFusion(Fusion* fusion, DataType data_hp_dtype) {

fusion->addOutput(tv_block_scale_fp8);
fusion->addOutput(tv_data_lp);

if (swizzle_output) {
tv_block_scale_fp8->split(0, 128);
// m/128, 128, k
tv_block_scale_fp8->split(1, 32);
// m/128, 4(m_o), 32(m_i), k
tv_block_scale_fp8->split(3, 4);
// m/128, 4(m_o), 32(m_i), k/4, 4(k)
std::vector<IterDomain*> tv_block_scale_fp8_alloc{
tv_block_scale_fp8->axis(0),
tv_block_scale_fp8->axis(3),
tv_block_scale_fp8->axis(2),
tv_block_scale_fp8->axis(1),
tv_block_scale_fp8->axis(4)};
// m/128, k/4, 32(m_i), 4(m_o), 4(k)
tv_block_scale_fp8->setAllocationDomain(tv_block_scale_fp8_alloc, true);

// back to a 2D logical domain.
tv_block_scale_fp8->merge(0);
tv_block_scale_fp8->merge(0);
tv_block_scale_fp8->merge(-1);
}
}
} // namespace
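Since the swizzle is expressed only through splits and merges, it may help to spell out where a logical element lands. A hedged sketch of the equivalent linear offset for a 2D [M, K] scales tensor, assuming the [m/128, k/4, 32(m_i), 4(m_o), 4(k)] allocation order built above (swizzledScaleOffset is a hypothetical name):

  // Illustrative only: linear offset of logical element (m, k) in the
  // swizzled allocation domain [m/128, k/4, 32(m_i), 4(m_o), 4(k)].
  int64_t swizzledScaleOffset(int64_t m, int64_t k, int64_t K) {
    int64_t m_tile = m / 128;     // index into m/128
    int64_t m_o = (m % 128) / 32; // index into 4(m_o)
    int64_t m_i = m % 32;         // index into 32(m_i)
    int64_t k_tile = k / 4;       // index into k/4
    int64_t k_in = k % 4;         // index into 4(k)
    return (((m_tile * (K / 4) + k_tile) * 32 + m_i) * 4 + m_o) * 4 + k_in;
  }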

@@ -414,6 +437,137 @@ TEST_F(BQTest, ScheduleAsPointwise2D) {
EXPECT_EQ(quantized_tensor_output.dim(), 3);
}

TEST_F(BQTest, AutoScheduleSingleOpWithSwizzle) {
const int m = 1024;
const int n = 1024;

std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
createNVFP4QunatizationFusion(
fusion.get(), DataType::Float, /*swizzle_output=*/true);

FusionExecutorCache fec(std::move(fusion));

std::vector<at::Tensor> inputs;
inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat)));
auto outputs_baseline = fec.runFusionWithInputs(inputs);

// Extract the baseline outputs
auto baseline_block_scales = outputs_baseline[0].as<at::Tensor>();
auto baseline_quantized_tensor = outputs_baseline[1].as<at::Tensor>();

// Move baseline tensors from GPU to CPU
auto baseline_block_scales_cpu = baseline_block_scales.cpu();
auto baseline_quantized_tensor_cpu = baseline_quantized_tensor.cpu();

const uint8_t* baseline_block_scales_data =
static_cast<const uint8_t*>(baseline_block_scales_cpu.data_ptr());
const uint8_t* baseline_quantized_data =
static_cast<const uint8_t*>(baseline_quantized_tensor_cpu.data_ptr());

std::unique_ptr<Fusion> fusion_new_op = std::make_unique<Fusion>();
FusionGuard fg2(fusion_new_op.get());

auto tv_in_1 = makeContigTensor(2, DataType::Float);
fusion_new_op->addInput(tv_in_1);

// tv_in_1 is 2D
auto quantization_results = blockQuantize(tv_in_1);

// outputs are 3D
fusion_new_op->addOutput(quantization_results.block_scales);
fusion_new_op->addOutput(quantization_results.quantized_tensor);

auto temp_loop_domain = quantization_results.block_scales->getLoopDomain();

quantization_results.block_scales->split(0, 128);
// m/128, 128, k
quantization_results.block_scales->split(1, 32);
// m/128, 4(m_o), 32(m_i), k
quantization_results.block_scales->split(3, 4);
// m/128, 4(m_o), 32(m_i), k/4, 4(k)
std::vector<IterDomain*> tv_block_scale_fp8_alloc{
quantization_results.block_scales->axis(0),
quantization_results.block_scales->axis(3),
quantization_results.block_scales->axis(2),
quantization_results.block_scales->axis(1),
quantization_results.block_scales->axis(4),
quantization_results.block_scales->axis(5)};
Comment on lines +494 to +495
Copilot AI Oct 7, 2025
Accessing axis(5) when only 5 axes (0-4) exist after the splits. The block_scales tensor has 3 dimensions initially and after splits should have indices 0-4.
Suggested change:
quantization_results.block_scales->axis(4),
quantization_results.block_scales->axis(5)};
quantization_results.block_scales->axis(4)};
// m/128, k/4, 32(m_i), 4(m_o), 4(k)
quantization_results.block_scales->setAllocationDomain(
tv_block_scale_fp8_alloc, true);

quantization_results.block_scales->setLoopDomain(temp_loop_domain);

FusionExecutorCache executor_cache(std::move(fusion_new_op));
auto outputs_new_op = executor_cache.runFusionWithInputs(inputs);

// Verify we got the expected outputs
auto block_scales_output = outputs_new_op[0].as<at::Tensor>();
auto quantized_tensor_output = outputs_new_op[1].as<at::Tensor>();

// Move tensors from GPU to CPU
auto block_scales_cpu = block_scales_output.cpu();
auto quantized_tensor_cpu = quantized_tensor_output.cpu();

auto block_scales_bytes = (m * n) / block_size;
auto quantized_tensor_bytes = (m * n) / 2;

const uint8_t* block_scales_data =
static_cast<const uint8_t*>(block_scales_cpu.data_ptr());
for (int i = 0; i < block_scales_bytes; ++i) {
EXPECT_EQ(
block_scales_data[i],
baseline_block_scales_data[i]); // Compare with baseline
}

const uint8_t* quantized_data =
static_cast<const uint8_t*>(quantized_tensor_cpu.data_ptr());
for (int i = 0; i < quantized_tensor_bytes; ++i) {
EXPECT_EQ(
quantized_data[i],
baseline_quantized_data[i]); // Compare with baseline
}
}

TEST_F(BQTest, AutoScheduleMultipleOps) {
const int m = 1024;
const int n = 1024;
const int k = 16;
std::vector<at::Tensor> inputs;
inputs.push_back(at::randn({n, k}, at::device(at::kCUDA).dtype(at::kFloat)));
inputs.push_back(
at::randn({m, n, k}, at::device(at::kCUDA).dtype(at::kFloat)));

std::unique_ptr<Fusion> fusion_new_op = std::make_unique<Fusion>();
FusionGuard fg2(fusion_new_op.get());

auto tv_in_1 = makeContigTensor(2, DataType::Float);
auto tv_in_2 = makeContigTensor(3, DataType::Float);
fusion_new_op->addInput(tv_in_1);
fusion_new_op->addInput(tv_in_2);

// tv_in_1 = broadcast(tv_in_1, {false, false, true});

// tv_in_1 is 2D, tv_in_2 is 3D
auto t_add = add(tv_in_1, tv_in_2);
auto t_relu = relu(t_add);
auto quantization_results = blockQuantize(t_relu);
quantization_results.quantized_tensor->setMemoryType(MemoryType::Local);
// auto t_out = set(quantization_results.quantized_tensor);
Copilot AI Oct 7, 2025
Remove commented-out code that is not being used.
Suggested change:
// auto t_out = set(quantization_results.quantized_tensor);

// outputs are 3D
fusion_new_op->addOutput(quantization_results.block_scales);
fusion_new_op->addOutput(quantization_results.quantized_tensor);

FusionExecutorCache executor_cache(std::move(fusion_new_op));
auto outputs_new_op = executor_cache.runFusionWithInputs(inputs);

// Retrieve the outputs (this test only checks that scheduling and execution succeed)
auto block_scales_output = outputs_new_op[0].as<at::Tensor>();
auto quantized_tensor_output = outputs_new_op[1].as<at::Tensor>();
}

TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) {
auto data_hp_dtype = GetParam();
