Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,100 @@ def test_nvfp4_quantization_noncontiguous_inputs(
torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0)

torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # Shapes that are whole multiples of the 128-wide tile
        (128, 128),
        (256, 256),
        (512, 512),
        (2048, 2048),
        # Shapes with at least one dimension that is not a multiple of 128
        (256, 272),
        (304, 304),
        (320, 256),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
def test_nvfp4_2d_columnwise_only_matches_both_directions(
    x_dtype: torch.dtype,
    M: int,
    N: int,
):
    """Check that columnwise-only 2D NVFP4 output equals the columnwise half of a
    combined (rowwise + columnwise) 2D NVFP4 quantization, bit for bit.

    The same random ``(M, N)`` input is quantized twice with 2D quantization
    enabled: once requesting both directions (the reference) and once requesting
    only the columnwise output (the path under test). The columnwise data, the
    in-bounds part of the columnwise scale tensor, and the columnwise amax are
    then compared with zero tolerance.

    Per the change this test accompanies, the columnwise-only configuration used
    to be rejected by ``NVTE_CHECK(return_identity || !use_2d_quantization)`` and
    is now served by an amax-only pass in
    ``quantize_transpose_vector_blockwise_fp4.cu``.
    """
    fp4_dtype = tex.DType.kFloat4E2M1

    # Fixed seeds so every parametrization sees a reproducible input.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    x = torch.randn((M, N), dtype=x_dtype, device="cuda")

    # Options shared by both quantizers; only the requested directions differ.
    fixed_options = dict(
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
        with_2d_quantization=True,
        row_scaled_nvfp4=False,
    )

    def _as_bytes(tensor):
        # Reinterpret the storage as raw bytes so comparisons are bitwise.
        return tensor.view(dtype=torch.uint8)

    # Reference: rowwise and columnwise outputs requested together.
    both_quantizer = NVFP4Quantizer(
        fp4_dtype=fp4_dtype, rowwise=True, columnwise=True, **fixed_options
    )
    out_both = both_quantizer(x)

    # Path under test: columnwise output only.
    col_only_quantizer = NVFP4Quantizer(
        fp4_dtype=fp4_dtype, rowwise=False, columnwise=True, **fixed_options
    )
    out_col_only = col_only_quantizer(x)

    # 1) Packed FP4 columnwise data must match byte for byte. A differently
    #    populated amax would change the scales and therefore the cast bytes.
    assert out_both._columnwise_data is not None
    assert out_col_only._columnwise_data is not None
    torch.testing.assert_close(
        _as_bytes(out_col_only._columnwise_data),
        _as_bytes(out_both._columnwise_data),
        atol=0,
        rtol=0,
    )

    # 2) Columnwise scales: compare only the in-bounds region, i.e. the first N
    #    rows and the first ceil(M / 16) columns. According to the original
    #    author's note, the padded tail beyond that region exists for cuBLAS
    #    alignment, is never written by the kernel, and therefore holds
    #    whatever the uninitialized allocation contained.
    block_len = 16
    scale_rows = N  # one scale row per input column
    scale_cols = -(-M // block_len)  # ceil(M / block_len)
    assert out_both._columnwise_scale_inv is not None
    assert out_col_only._columnwise_scale_inv is not None
    ref_scales = _as_bytes(out_both._columnwise_scale_inv)
    sut_scales = _as_bytes(out_col_only._columnwise_scale_inv)
    torch.testing.assert_close(
        sut_scales[:scale_rows, :scale_cols],
        ref_scales[:scale_rows, :scale_cols],
        atol=0,
        rtol=0,
    )

    # 3) Columnwise amax must be exactly equal.
    assert out_both._amax_columnwise is not None
    assert out_col_only._amax_columnwise is not None
    torch.testing.assert_close(
        out_col_only._amax_columnwise,
        out_both._amax_columnwise,
        atol=0,
        rtol=0,
    )

    # The columnwise-only request must leave the rowwise data unset.
    assert out_col_only._rowwise_data is None
Original file line number Diff line number Diff line change
Expand Up @@ -353,14 +353,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
extern __shared__ char smem_base[];
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);

// 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode.
// Instead of static_assert, return early if these invalid modes are detected.
// 2D block scaling is not supported for E8 scaling MXFP4.
// Instead of static_assert, return early if this invalid mode is detected.
if constexpr (kIs2DBlockScaling && kIsE8Scaling) {
return;
}
if constexpr (kIs2DBlockScaling && !kReturnIdentity) {
return;
}
// for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4
// use constexpr to define the size, when not using 2D, use minimal size 1x1
constexpr int kFP4BlockScalingSize = 16;
Expand Down Expand Up @@ -576,6 +573,67 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}

// Step 2.5: 2D-amax-only pass for columnwise-only mode.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Step label collision with existing substep

The new outer-level block is labeled "Step 2.5" at line 576, but the same label is already used at line 522 for the "Write scale_inv" substep inside Step 2's loop (under `if constexpr (kReturnIdentity)`). A future reader scanning the file will find two distinct "Step 2.5" sections with different semantics. Consider renaming the new block to something like "Step 2b" or "Step 2.5 (outer)" to distinguish it from the `// Step 2.5: Write scale_inv` substep inside the inner loop.

// When only the transposed output is requested but 2D block scaling is enabled, the columnwise
// reads in Step 3 (line ~660 below) still need amax_smem populated. Re-run the load + local-amax
// + 2D warp/smem reduction from Step 2 (steps 2.1-2.3), skipping the rowwise scale/quantize/store
// writes that Step 2 normally does. Same amax_smem values as the rowwise-enabled path, so the
// dgrad/wgrad columnwise output of (rowwise=False, columnwise=True, 2D) is bitwise identical to
// the columnwise half of (rowwise=True, columnwise=True, 2D).
if constexpr (!kReturnIdentity && kIs2DBlockScaling) {
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride; // 4 iterations for kTileDim=128
const int c_s =
(threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut / kNVecSMem];
// Step 2.1 (amax-only): Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
// Step 2.2 (amax-only): Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3 (amax-only): 2D warp + smem amax reduction (mirrors Step 2's 2D path)
constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32
int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7
int tid_in_warp_x = threadIdx.x % kNumThreadsStore;
int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp;
CType amax_warp_reduced = groupMax<kNumRowsPerWarp, kNumThreadsStore>(
amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]);
int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y;
if (tid_in_warp_y == 0) {
amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]
[warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced;
}
__syncthreads();

if (data_row_idx % kFP4BlockScalingSize == 0) {
CType amax_2d = 0.0;
for (int i = 0; i < k2DBlockAmaxReduceDim; i++) {
amax_2d =
fmaxf(amax_2d, amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]);
}
amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d;
}
__syncthreads();
r_s += r_stride;
}
}

// Step 3: Transpose, cast and store to output_t
if constexpr (kReturnTranspose) {
constexpr int c_stride =
Expand Down Expand Up @@ -731,8 +789,6 @@ void quantize_transpose_vector_blockwise_fp4(
NVTE_CHECK(return_identity || return_transpose,
"At least one of return_identity or return_transpose must be true.");

NVTE_CHECK(return_identity || !use_2d_quantization,
"2D block quantization is only supported when return_identity is true.");
NVTE_CHECK(!row_scaled_nvfp4 || (return_identity && !return_transpose),
"Row-scaled NVFP4 quantization only supports rowwise quantization.");
NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization,
Expand Down
Loading