diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu
index e7244bb0b..7d0597ef7 100644
--- a/tests/cpp/operator/test_cublaslt_gemm.cu
+++ b/tests/cpp/operator/test_cublaslt_gemm.cu
@@ -209,6 +209,21 @@ void performTest(const TestParams& params) {
   (void)cudaGetDeviceProperties(&prop, 0);
 
 #ifdef __HIP_PLATFORM_AMD__
+
+  // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm >= 7.0;
+  // this is the only configuration hipBLASLt currently supports.
+  bool fp8_gelu_fusion_config = false;
+  #if HIP_VERSION >= 70000000
+  if (prop.major == 9 && prop.minor == 4)
+  {
+    fp8_gelu_fusion_config = atype == DType::kFloat8E4M3 &&
+                             btype == DType::kFloat8E4M3 &&
+                             dtype == DType::kFloat8E4M3 &&
+                             (params.use_gelu && gelu_type == DType::kFloat16) &&
+                             (!params.use_bias || bias_type == DType::kFloat16);
+  }
+  #endif
+
   if (has_fp8)
   {
     bool fp8_supported = (prop.major == 9 && prop.minor >= 4);
@@ -227,8 +242,8 @@ void performTest(const TestParams& params) {
       }
     }
-    if (params.use_gelu) {
-      GTEST_SKIP() << "FP8 GEMM with GELU is not supported";
+    if (params.use_gelu && !fp8_gelu_fusion_config) {
+      GTEST_SKIP() << "FP8 GEMM with GELU is not supported in current config";
     }
     if (params.use_bias && dtype == DType::kFloat16) {
       GTEST_SKIP() << "FP8 GEMM with bias and FP16 output is not supported";
     }
@@ -252,7 +267,7 @@ void performTest(const TestParams& params) {
     if (params.use_gelu && dtype == DType::kBFloat16 && !params.transa) {
       GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
     }
-    if (has_fp8 && params.use_bias && dtype == DType::kFloat8E4M3) {
+    if (has_fp8 && params.use_bias && dtype == DType::kFloat8E4M3 && !fp8_gelu_fusion_config) {
       GTEST_SKIP() << "FP8 GEMM with bias and FP8 output is not supported in current config";
     }
   }
@@ -506,6 +521,7 @@
 MAKE_GEMM_TEST(Testbf8xfp8xbf16xbf16xfp8, bf8, fp8, bf16, bf16, fp8);
 MAKE_GEMM_TEST(Testbf8xfp8xbf16xbf16xbf8, bf8, fp8, bf16, bf16, bf8);
+MAKE_GEMM_TEST(Testfp8xfp8xfp16xfp16xfp8, fp8, fp8, fp16, fp16, fp8);
 
 
 INSTANTIATE_TEST_SUITE_P(
   OperatorTest,
diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu
index e73e3c1de..dcba674e4 100644
--- a/transformer_engine/common/gemm/rocm_gemm.cu
+++ b/transformer_engine/common/gemm/rocm_gemm.cu
@@ -522,7 +522,7 @@ static class GemmAlgoCache {
 public:
   struct Key {
     int deviceCap;
-    hipDataType a_type, b_type, d_type, bias_type;
+    hipDataType a_type, b_type, d_type, bias_type, aux_type;
     int m, n, k;
     int lda, ldb, ldd;
     hipblasOperation_t transa, transb;
@@ -532,13 +532,13 @@ public:
     Key(int deviceCap_,
         hipDataType a_type_, hipDataType b_type_,
-        hipDataType d_type_, hipDataType bias_type_,
+        hipDataType d_type_, hipDataType bias_type_, hipDataType aux_type_,
         int m_, int n_, int k_,
         int lda_, int ldb_, int ldd_,
         hipblasOperation_t transa_, hipblasOperation_t transb_,
         int scaling_mode_, hipblasLtEpilogue_t epilogue_):
       deviceCap(deviceCap_), a_type(a_type_), b_type(b_type_),
-      d_type(d_type_), bias_type(bias_type_),
+      d_type(d_type_), bias_type(bias_type_), aux_type(aux_type_),
       m(m_), n(n_), k(k_),
       lda(lda_), ldb(ldb_), ldd(ldd_),
       transa(transa_), transb(transb_), scaling_mode(scaling_mode_),
       epilogue(epilogue_) {}
@@ -550,6 +550,7 @@ public:
       return ((deviceCap == val.deviceCap)
         && (a_type == val.a_type) && (b_type == val.b_type)
         && (d_type == val.d_type) && (bias_type == val.bias_type)
+        && (aux_type == val.aux_type)
         && (m == val.m) && (n == val.n) && (k == val.k)
         && (lda == val.lda) && (ldb == val.ldb) && (ldd == val.ldd)
         && (transa == val.transa) && (transb == val.transb)
@@ -681,7 +682,7 @@ protected:
     {
       csv_helper fs(ofs, csv_sep);
       fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b"
-        << "type_a" << "type_b" << "type_d" << "bias_type"
+        << "type_a" << "type_b" << "type_d" << "bias_type" << "aux_type"
         << "lda" << "ldb" << "ldd" << "scale_mode" << "epi"
         << "comp" << "scale_type" << "ws_min" << "ws_max" << "algo_id" << "aidx";
     }
@@ -723,7 +724,7 @@ protected:
       if (line.empty() || line[0] == '#') continue;
       std::istringstream is(line);
       char c;
-      std::string type_a, type_b, type_d, bias_type, trans_a, trans_b, epi, comp, scale;
+      std::string type_a, type_b, type_d, bias_type, aux_type, trans_a, trans_b, epi, comp, scale;
       int64_t algo_id;
       int algo_idx;
       size_t ws_min, ws_max;
@@ -750,6 +751,7 @@ protected:
       std::getline(is, type_d, csv_sep);
       std::getline(is, bias_type, csv_sep);
+      // Read aux_type right after bias_type so the parse order matches the
+      // CSV header and the save() order (bias_type, aux_type, lda, ...).
+      std::getline(is, aux_type, csv_sep);
       is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c >> cfg.scaling_mode >> c;
       std::getline(is, epi, csv_sep);
       std::getline(is, comp, csv_sep);
       std::getline(is, scale, csv_sep);
@@ -801,6 +803,9 @@ protected:
       cfg.bias_type = (bias_type == "-")
                       ? (hipDataType)-1
                       : typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
+      cfg.aux_type = (aux_type == "-")
+                     ? (hipDataType)-1
+                     : typeNameMapper.getValue(aux_type, "aux_type", fp8_filter);
       cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
       cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
 
@@ -886,6 +891,7 @@ protected:
         << transposeNameMapper.getName(cfg.transa) << transposeNameMapper.getName(cfg.transb)
         << typeNameMapper.getName(cfg.a_type) << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
         << ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
+        << ((cfg.aux_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.aux_type))
"-" : typeNameMapper.getName(cfg.aux_type)) << cfg.lda << cfg.ldb << cfg.ldd << cfg.scaling_mode << epilogueNameMapper.getName(cfg.epilogue) << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F) << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n"; @@ -1003,19 +1009,35 @@ void hipblaslt_gemm(const Tensor *inputA, const hipDataType B_type = get_hipblaslt_dtype(param.Btype); const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype); const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype); - // const hipblasltDatatype_t aux_type = get_hipblaslt_dtype(outputPreGelu->data.dtype); + const hipDataType aux_type = get_hipblaslt_dtype(outputPreGelu->data.dtype); NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); - // check consistency of arguments: - // if fp8 is desired, context cannot be null +#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15 + if (use_fp8 && gelu) { + hipDeviceProp_t prop; + NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, 0)); + // Currently hipblasLT only supports fp8 gemm + gelu fusion only on MI300 + if (prop.major == 9 && prop.minor == 4) { + bool allow_fp8_gemm = (param.Atype == DType::kFloat8E4M3) && + (param.Btype == DType::kFloat8E4M3) && + (outputD->data.dtype == DType::kFloat8E4M3) && + (!bias || inputBias->data.dtype == DType::kFloat16) && + (outputPreGelu->data.dtype == DType::kFloat16 || outputPreGelu->data.dtype == outputD->data.dtype); + NVTE_CHECK(allow_fp8_gemm, "fp8 gemm + gelu fusion is unavailable with current config!"); + } else { + NVTE_CHECK(false, "fp8 gemm + gelu fusion is unavailable right now!"); + } + } +#else // fp8 + gelu fusion + fp8 aux is unavailable right now. if (use_fp8) { NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!"); } +#endif if (is_fp8_dtype(outputD->data.dtype)) { NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!"); } @@ -1064,7 +1086,7 @@ void hipblaslt_gemm(const Tensor *inputA, ¶m.transB, sizeof(param.transB))); // set fp8 attributes -- input and output types should already be set to fp8 as appropriate - // Note: gelu fusion isn't available right now, and we don't need + // Note: gelu fusion is available for certain config from rocm 7.0 // amax(D) either (next op is high precision). #if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15 hipblasLtMatmulMatrixScale_t scaling_mode; @@ -1116,6 +1138,14 @@ void hipblaslt_gemm(const Tensor *inputA, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type))); } +#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15 + if (gelu){ + NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, + HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, + &aux_type, + sizeof(aux_type))); + } +#endif } if (bias && gelu) { @@ -1167,6 +1197,7 @@ void hipblaslt_gemm(const Tensor *inputA, GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type, use_fp8 ? bias_type : (hipDataType)-1, + (use_fp8 && gelu) ? aux_type : (hipDataType)-1, m, n, k, param.lda, param.ldb, ldd, param.transA, param.transB, scaling_mode, epilogue ); GemmAlgoCache::Algo cached_algo; if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())