diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu
index 33a9b8629..6855c9487 100644
--- a/tests/cpp/operator/test_cast_mxfp8.cu
+++ b/tests/cpp/operator/test_cast_mxfp8.cu
@@ -76,12 +76,12 @@ void scale_block(const ProcessingMethod processing_method,
         continue;
       }
       amax = std::max(amax, std::abs(elt));
-#else
+#else // #ifdef __HIP_PLATFORM_AMD__
       if (std::isinf(elt) || std::isnan(elt)) {
         continue;
       }
       amax = fmaxf(amax, fabsf(elt));
-#endif
+#endif // #ifdef __HIP_PLATFORM_AMD__
     }
   }
 
@@ -312,6 +312,23 @@ void performTest_x1(const ProcessingMethod processing_method,
                  block_size_cols,
                  scales_stride);
+
+#ifdef __HIP_PLATFORM_AMD__
+  if (processing_method != ProcessingMethod::CAST_ONLY) {
+    std::vector<std::tuple<int, int, int>> mismatch_idx;
+    compare_e8m0_scaling_factors("scales", output_c, ref_output_scales.get(),
+                                 unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, rowwise, mismatch_idx);
+
+    if (mismatch_idx.size()) {
+      adjust_ref(mismatch_idx, ref_output_c.get(), unpadded_blocks_Y, unpadded_blocks_X, rows, cols, otype);
+    }
+
+    auto [atol, rtol] = getTolerances(otype);
+    compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
+  }
+  else
+#endif // #ifdef __HIP_PLATFORM_AMD__
+  {
 
   auto [atol, rtol] = getTolerances(otype);
   compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
@@ -321,6 +338,7 @@ void performTest_x1(const ProcessingMethod processing_method,
   compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
                                unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
+  }
 
   if (processing_method == ProcessingMethod::CAST_DBIAS ||
       processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
     auto [atol_dbias, rtol_dbias] = getTolerances(itype);
@@ -454,7 +472,29 @@ void performTest_x2(const ProcessingMethod processing_method,
                  block_size_cols,
                  scales_stride_rowwise,
                  scales_stride_colwise);
 
+#ifdef __HIP_PLATFORM_AMD__
+  if (processing_method != ProcessingMethod::CAST_ONLY) {
+    std::vector<std::tuple<int, int, int>> mismatch_idx_r;
+    compare_e8m0_scaling_factors("scales_rowwise", output, ref_scales_rowwise.get(),
+                                 unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, 0.01, true, mismatch_idx_r);
+
+    if (mismatch_idx_r.size()) {
+      adjust_ref(mismatch_idx_r, ref_output_c_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, rows, cols, otype);
+    }
+    std::vector<std::tuple<int, int, int>> mismatch_idx_c;
+    compare_e8m0_scaling_factors("scales_colwise", output, ref_scales_colwise.get(),
+                                 unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, 0.01, false, mismatch_idx_c);
+
+    if (mismatch_idx_c.size()) {
+      adjust_ref(mismatch_idx_c, ref_output_c_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, rows, cols, otype);
+    }
+    auto [atol, rtol] = getTolerances(otype);
+    compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
+    compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
+  } else
+#endif // #ifdef __HIP_PLATFORM_AMD__
+  {
   auto [atol, rtol] = getTolerances(otype);
   compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
   compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
@@ -464,6 +504,7 @@
   compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(),
                                unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise);
+  }
 
   if (processing_method == ProcessingMethod::CAST_DBIAS ||
       processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
     auto [atol_dbias, rtol_dbias] = getTolerances(itype);
@@ -563,7 +604,7 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
   if (getDeviceComputeCapability() < blackwellComputeCapability) {
     GTEST_SKIP();
   }
-#endif
+#endif // #ifdef __HIP_PLATFORM_AMD__
 
   using namespace transformer_engine;
   using namespace test;
diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
index f93c8c9e0..4acbac4fb 100644
--- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
+++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
@@ -262,9 +262,24 @@ void performTest_x1(const size_t rows,
                  block_size_rows,
                  block_size_cols,
                  scales_stride);
+#ifdef __HIP_PLATFORM_AMD__
+  std::vector<std::tuple<int, int, int>> mismatch_idx;
+  if (rowwise) {
+    compare_e8m0_scaling_factors("rowwise scales", output, ref_output_scales.get(),
+                                 unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, true, mismatch_idx);
+  } else {
+    compare_e8m0_scaling_factors("colwise scales", output, ref_output_scales.get(),
+                                 unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, false, mismatch_idx);
+  }
+  if (mismatch_idx.size()) {
+    adjust_ref(mismatch_idx, ref_output.get(), unpadded_blocks_Y, unpadded_blocks_X, rows, cols, otype);
+  }
 
   auto [atol, rtol] = getTolerances(otype);
   compareResults("output", output, ref_output.get(), rowwise, atol, rtol);
+#else // #ifdef __HIP_PLATFORM_AMD__
+  auto [atol, rtol] = getTolerances(otype);
+  compareResults("output", output, ref_output.get(), rowwise, atol, rtol);
   const uint8_t * const gpu_scales_ptr = rowwise
                                          ? output.rowwise_cpu_scale_inv_ptr()
                                          : output.columnwise_cpu_scale_inv_ptr();
@@ -276,6 +291,7 @@ void performTest_x1(const size_t rows,
     compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
                                  unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
   }
+#endif // #ifdef __HIP_PLATFORM_AMD__
 }
 
 /**
@@ -361,17 +377,41 @@ void performTest_x2(const size_t rows,
                  block_size_cols,
                  scales_stride_rowwise,
                  scales_stride_colwise);
+#ifdef __HIP_PLATFORM_AMD__
+  std::vector<std::tuple<int, int, int>> mismatch_idx_r;
+  compare_e8m0_scaling_factors("scales_rowwise", output,
+                               ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
+                               unpadded_blocks_X_rowwise, scales_stride_rowwise, 0.01, true, mismatch_idx_r);
+
+  if (mismatch_idx_r.size()) {
+    adjust_ref(mismatch_idx_r, ref_output_colwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, rows, cols, otype);
+  }
+
+  std::vector<std::tuple<int, int, int>> mismatch_idx_c;
+  compare_e8m0_scaling_factors("scales_colwise", output,
+                               ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
+                               unpadded_blocks_X_colwise, scales_stride_colwise, 0.01, false, mismatch_idx_c);
+
+  if (mismatch_idx_c.size()) {
+    adjust_ref(mismatch_idx_c, ref_output_rowwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, rows, cols, otype);
+  }
 
   auto [atol, rtol] = getTolerances(otype);
   auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
   compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
   compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
+#else // #ifdef __HIP_PLATFORM_AMD__
+  auto [atol, rtol] = getTolerances(otype);
+  auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
+  compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
+  compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
   compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(),
                                ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
                                unpadded_blocks_X_rowwise, scales_stride_rowwise);
   compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(),
                                ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
                                unpadded_blocks_X_colwise, scales_stride_colwise);
+#endif // #ifdef __HIP_PLATFORM_AMD__
 }
 
 std::vector<std::pair<size_t, size_t>> matrix_sizes = {
@@ -418,12 +458,12 @@ class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam
 TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) {
 #ifdef __HIP_PLATFORM_AMD__
   omp_set_num_threads(std::min(128, omp_get_max_threads())); // Using threads = # of vcpus causes occasional errors.
-#else
+#else // #ifdef __HIP_PLATFORM_AMD__
   // Skip tests for pre-Blackwell architectures
   if (getDeviceComputeCapability() < blackwellComputeCapability) {
     GTEST_SKIP();
   }
-#endif
+#endif // #ifdef __HIP_PLATFORM_AMD__
 
   using namespace transformer_engine;
diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu
index 32eb1d63a..72ceb601c 100644
--- a/tests/cpp/test_common.cu
+++ b/tests/cpp/test_common.cu
@@ -711,6 +711,74 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
   }
 }
 
+#ifdef __HIP_PLATFORM_AMD__
+void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref,
+                                  const size_t row_blocks, const size_t col_blocks, const size_t stride,
+                                  double tol, bool rowwise, std::vector<std::tuple<int, int, int>> &mismatch_idx) {
+  const uint8_t *const test = rowwise ? output.rowwise_cpu_scale_inv_ptr()
+                                      : output.columnwise_cpu_scale_inv_ptr();
+
+  const float scale_tol = std::max<float>(1.f, row_blocks * col_blocks * tol);
+
+  for (int i = 0; i < row_blocks; i++) {
+    for (int j = 0; j < col_blocks; j++) {
+      const int idx = i * stride + j;
+      if (test[idx] != ref[idx]) {
+        int t_scale = static_cast<int>(test[idx]);
+        int r_scale = static_cast<int>(ref[idx]);
+        if (std::abs(t_scale - r_scale) == 1) {
+          mismatch_idx.emplace_back(i, j, r_scale-t_scale);
+        } else {
+          GTEST_FAIL() << "Error in " << name << std::endl
+                       << "Mismatch: " << t_scale << " vs "
+                       << r_scale << " at index " << idx;
+        }
+      }
+    }
+  }
+  const size_t scale_mismatches = mismatch_idx.size();
+
+  ASSERT_FALSE(scale_mismatches > scale_tol)
+    << "Error in " << name << std::endl << std::setprecision(4)
+    << "Total scale mismatches: " << scale_mismatches << " (" << 100.*(double)scale_mismatches/(double)(row_blocks*col_blocks)
+    << "%) Exceeds tolerance of " << scale_tol << " (" << 100.*tol << "%) mismatches";
+
+  if (scale_mismatches) {
+    std::cout << "\x1b[33mWARNING:\x1b[0m " << scale_mismatches
+              << " scale mismatches were found. This does not imply an accuracy issue." << std::endl;
+  }
+}
+
+void adjust_ref(std::vector<std::tuple<int, int, int>> mismatch_idx, void *ref, const size_t row_blocks,
+                const size_t col_blocks, const size_t rows, const size_t cols, DType otype) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( otype, T,
+    T *ref_data = reinterpret_cast<T *>(ref);
+    double scale_val;
+    const size_t col_blocks_size = cols / col_blocks;
+    const size_t row_blocks_size = rows / row_blocks;
+    for (const auto &[i, j, scale_diff] : mismatch_idx) {
+      if (scale_diff == 1) {
+        scale_val = 2.;
+      } else if (scale_diff == -1) {
+        scale_val = .5;
+      } else { // Shouldn't ever reach this
+        GTEST_FAIL() << "Error in adjust_ref, |scale_diff| > 1";
+      }
+      size_t ii_min = i * row_blocks_size;
+      const size_t ii_max = std::min(ii_min + row_blocks_size, rows);
+      for (; ii_min < ii_max; ii_min++) {
+        size_t jj_min = j * col_blocks_size;
+        const size_t jj_max = std::min(jj_min + col_blocks_size, cols);
+        for (; jj_min < jj_max; jj_min++) {
+          const size_t data_idx = ii_min * cols + jj_min;
+          ref_data[data_idx] = static_cast<T>(static_cast<double>(ref_data[data_idx]) * scale_val);
+        }
+      }
+    }
+  ); // NOLINT(*)
+}
+#endif // #ifdef __HIP_PLATFORM_AMD__
+
 std::pair<double, double> getTolerances(const DType type) {
   switch(type) {
     case DType::kFloat32:
diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h
index 7ac2b75a6..6b9514d38 100644
--- a/tests/cpp/test_common.h
+++ b/tests/cpp/test_common.h
@@ -19,6 +19,7 @@
 #else
 #include 
 #include "amd_detail/hip_float8.h"
+#include <tuple>
 #endif
 
 #include 
@@ -461,6 +462,14 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
                                   const size_t row_blocks, const size_t col_blocks, const size_t stride);
 void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                   const size_t N);
+#ifdef USE_ROCM
+void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref,
+                                  const size_t row_blocks, const size_t col_blocks, const size_t stride,
+                                  double tol, bool rowwise, std::vector<std::tuple<int, int, int>> &mismatch_idx);
+
+void adjust_ref(std::vector<std::tuple<int, int, int>> mismatch_idx, void *ref, const size_t row_blocks,
+                const size_t col_blocks, const size_t rows, const size_t cols, DType otype);
+#endif
 
 std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
                                             const size_t block_size_rows, const size_t block_size_cols);
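The ROCm-specific path added above works in two steps. The new compare_e8m0_scaling_factors overload in test_common.cu records every scaling-factor block whose E8M0 exponent differs from the reference by exactly one step, fails outright on any larger difference, and fails when the number of off-by-one blocks exceeds max(1, row_blocks * col_blocks * tol); adjust_ref then re-expresses the reference data of the recorded blocks in the GPU-chosen scale before the elementwise comparison. The sketch below restates only that rescaling step on plain float data with hypothetical names (rescale_reference_blocks, a row-major ref buffer); it is an illustration of the logic under those assumptions, not the helper the patch adds.

#include <algorithm>
#include <cstddef>
#include <tuple>
#include <vector>

// Illustrative sketch, not the patch's adjust_ref: re-express reference blocks whose
// E8M0 scale exponent differs from the GPU result by one step. 'ref' is a rows x cols
// row-major float buffer; each block covers (rows / row_blocks) x (cols / col_blocks) elements.
void rescale_reference_blocks(const std::vector<std::tuple<int, int, int>> &mismatch_idx,
                              float *ref, std::size_t row_blocks, std::size_t col_blocks,
                              std::size_t rows, std::size_t cols) {
  const std::size_t block_rows = rows / row_blocks;
  const std::size_t block_cols = cols / col_blocks;
  for (const auto &[bi, bj, scale_diff] : mismatch_idx) {
    // scale_diff is (reference exponent - test exponent); one step means the reference
    // values must be doubled (or halved) to be encoded in the test's scale.
    const float factor = (scale_diff == 1) ? 2.f : 0.5f;
    const std::size_t row_end = std::min<std::size_t>((bi + 1) * block_rows, rows);
    const std::size_t col_end = std::min<std::size_t>((bj + 1) * block_cols, cols);
    for (std::size_t r = bi * block_rows; r < row_end; ++r) {
      for (std::size_t c = bj * block_cols; c < col_end; ++c) {
        ref[r * cols + c] *= factor;
      }
    }
  }
}

The warning printed by the new comparison helper ("This does not imply an accuracy issue.") makes the same point: an off-by-one block scale changes how a block is encoded, not the values it represents, so the test only fails when a scale differs by more than one step or when off-by-one blocks are too numerous.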