22 changes: 19 additions & 3 deletions tests/cpp/operator/test_cublaslt_gemm.cu
@@ -209,6 +209,21 @@ void performTest(const TestParams& params) {
(void)cudaGetDeviceProperties(&prop, 0);

#ifdef __HIP_PLATFORM_AMD__

+// Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm >= 7.0;
+// hipBLASLt currently supports the fusion only in this configuration.
+bool fp8_gelu_fusion_config = false;
+#if HIP_VERSION >= 70000000
+if (prop.major == 9 && prop.minor == 4)
+{
+fp8_gelu_fusion_config = atype == DType::kFloat8E4M3 &&
+btype == DType::kFloat8E4M3 &&
+dtype == DType::kFloat8E4M3 &&
+(params.use_gelu && gelu_type == DType::kFloat16) &&
+(!params.use_bias || bias_type == DType::kFloat16);
+}
+#endif

if (has_fp8)
{
bool fp8_supported = (prop.major == 9 && prop.minor >= 4);
@@ -227,8 +242,8 @@ void performTest(const TestParams& params) {
}
}

-if (params.use_gelu) {
-GTEST_SKIP() << "FP8 GEMM with GELU is not supported";
+if (params.use_gelu && !fp8_gelu_fusion_config) {
+GTEST_SKIP() << "FP8 GEMM with GELU is not supported in current config";
}
if (params.use_bias && dtype == DType::kFloat16) {
GTEST_SKIP() << "FP8 GEMM with bias and FP16 output is not supported";
@@ -252,7 +267,7 @@ void performTest(const TestParams& params) {
if (params.use_gelu && dtype == DType::kBFloat16 && !params.transa) {
GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
}
-if (has_fp8 && params.use_bias && dtype == DType::kFloat8E4M3) {
+if (has_fp8 && params.use_bias && dtype == DType::kFloat8E4M3 && !fp8_gelu_fusion_config) {
GTEST_SKIP() << "FP8 GEMM with bias and FP8 output is not supported in current config";
}
}
@@ -506,6 +521,7 @@ MAKE_GEMM_TEST(Testbf8xfp8xbf16xbf16xfp8, bf8, fp8, bf16, bf16, fp8);

MAKE_GEMM_TEST(Testbf8xfp8xbf16xbf16xbf8, bf8, fp8, bf16, bf16, bf8);

+MAKE_GEMM_TEST(Testfp8xfp8xfp16xfp16xfp8, fp8, fp8, fp16, fp16, fp8);

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
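Taken together, the test-side gating above mirrors the library-side check added to rocm_gemm.cu below: the FP8 GEMM + GELU fusion path is exercised only on gfx942 (MI300) with ROCm >= 7.0, E4M3 inputs and output, an FP16 GELU aux tensor, and, when a bias is present, an FP16 bias. A minimal sketch of that predicate as a standalone helper (the helper name and free-function form are illustrative, not part of the PR; DType and hipDeviceProp_t come from the test's existing includes):

#include <hip/hip_runtime.h>

// Illustrative helper: true when the FP8 GEMM + GELU fusion path is
// expected to be available for the given test configuration.
static bool fp8_gelu_fusion_supported(const hipDeviceProp_t& prop,
                                      DType atype, DType btype, DType dtype,
                                      bool use_gelu, DType gelu_type,
                                      bool use_bias, DType bias_type) {
#if HIP_VERSION >= 70000000  // ROCm >= 7.0
  const bool is_gfx942 = (prop.major == 9 && prop.minor == 4);  // MI300
  return is_gfx942 &&
         atype == DType::kFloat8E4M3 &&
         btype == DType::kFloat8E4M3 &&
         dtype == DType::kFloat8E4M3 &&
         (use_gelu && gelu_type == DType::kFloat16) &&
         (!use_bias || bias_type == DType::kFloat16);
#else
  return false;
#endif
}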
49 changes: 40 additions & 9 deletions transformer_engine/common/gemm/rocm_gemm.cu
@@ -522,7 +522,7 @@ static class GemmAlgoCache {
public:
struct Key {
int deviceCap;
-hipDataType a_type, b_type, d_type, bias_type;
+hipDataType a_type, b_type, d_type, bias_type, aux_type;
int m, n, k;
int lda, ldb, ldd;
hipblasOperation_t transa, transb;
@@ -532,13 +532,13 @@ public:

Key(int deviceCap_,
hipDataType a_type_, hipDataType b_type_,
-hipDataType d_type_, hipDataType bias_type_,
+hipDataType d_type_, hipDataType bias_type_, hipDataType aux_type_,
int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
hipblasOperation_t transa_, hipblasOperation_t transb_,
int scaling_mode_, hipblasLtEpilogue_t epilogue_):
deviceCap(deviceCap_),
a_type(a_type_), b_type(b_type_),
-d_type(d_type_), bias_type(bias_type_),
+d_type(d_type_), bias_type(bias_type_), aux_type(aux_type_),
m(m_), n(n_), k(k_), lda(lda_), ldb(ldb_), ldd(ldd_),
transa(transa_), transb(transb_),
scaling_mode(scaling_mode_), epilogue(epilogue_) {}
@@ -550,6 +550,7 @@ public:
return ((deviceCap == val.deviceCap)
&& (a_type == val.a_type) && (b_type == val.b_type)
&& (d_type == val.d_type) && (bias_type == val.bias_type)
+&& (aux_type == val.aux_type)
&& (m == val.m) && (n == val.n) && (k == val.k)
&& (lda == val.lda) && (ldb == val.ldb) && (ldd == val.ldd)
&& (transa == val.transa) && (transb == val.transb)
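The equality operator now distinguishes aux_type. If the cache indexes keys through a hash map with a custom hasher (the hasher is outside this diff), equal keys still hash equally without the new field, but folding aux_type into the hash avoids collisions between configs that differ only there. A minimal sketch under that assumption (hash_key is illustrative, not part of the PR):

#include <cstddef>
#include <functional>

// Illustrative hash-combine in the boost::hash_combine style.
static inline void hash_combine(std::size_t& seed, std::size_t v) {
  seed ^= v + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

static std::size_t hash_key(const GemmAlgoCache::Key& k) {
  std::size_t h = std::hash<int>()(k.deviceCap);
  for (int t : {(int)k.a_type, (int)k.b_type, (int)k.d_type,
                (int)k.bias_type, (int)k.aux_type})  // include the new field
    hash_combine(h, std::hash<int>()(t));
  for (int v : {k.m, k.n, k.k, k.lda, k.ldb, k.ldd,
                (int)k.transa, (int)k.transb, k.scaling_mode, (int)k.epilogue})
    hash_combine(h, std::hash<int>()(v));
  return h;
}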
@@ -681,7 +682,7 @@ protected:
{
csv_helper fs(ofs, csv_sep);
fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b"
<< "type_a" << "type_b" << "type_d" << "bias_type"
<< "type_a" << "type_b" << "type_d" << "bias_type" << "aux_type"
<< "lda" << "ldb" << "ldd" << "scale_mode" << "epi" << "comp" << "scale_type"
<< "ws_min" << "ws_max" << "algo_id" << "aidx";
}
@@ -723,7 +724,7 @@ protected:
if (line.empty() || line[0] == '#') continue;
std::istringstream is(line);
char c;
-std::string type_a, type_b, type_d, bias_type, trans_a, trans_b, epi, comp, scale;
+std::string type_a, type_b, type_d, bias_type, aux_type, trans_a, trans_b, epi, comp, scale;
int64_t algo_id;
int algo_idx;
size_t ws_min, ws_max;
@@ -750,6 +751,7 @@ protected:
std::getline(is, type_d, csv_sep);
std::getline(is, bias_type, csv_sep);
+std::getline(is, aux_type, csv_sep);  // read aux_type right after bias_type, matching the header/writer column order
is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c >> cfg.scaling_mode >> c;
std::getline(is, epi, csv_sep);
std::getline(is, comp, csv_sep);
std::getline(is, scale, csv_sep);
@@ -801,6 +803,9 @@ protected:
cfg.bias_type = (bias_type == "-")
? (hipDataType)-1
: typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
+cfg.aux_type = (aux_type == "-")
+? (hipDataType)-1
+: typeNameMapper.getValue(aux_type, "aux_type", fp8_filter);

cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
@@ -886,6 +891,7 @@ protected:
<< transposeNameMapper.getName(cfg.transa) << transposeNameMapper.getName(cfg.transb)
<< typeNameMapper.getName(cfg.a_type) << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
<< ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
+<< ((cfg.aux_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.aux_type))
<< cfg.lda << cfg.ldb << cfg.ldd << cfg.scaling_mode << epilogueNameMapper.getName(cfg.epilogue)
<< computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
<< algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n";
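For a concrete picture of the resulting cache-file layout, here are two hypothetical rows (all values and token spellings are illustrative; the real names come from typeNameMapper and epilogueNameMapper): a non-fusion entry keeps the "-" sentinel in the new aux_type column, while an FP8 GEMM + GELU entry records the FP16 aux type between bias_type and lda:

# dev_cap,m,n,k,trans_a,trans_b,type_a,type_b,type_d,bias_type,aux_type,lda,ldb,ldd,scale_mode,epi,comp,scale_type,ws_min,ws_max,algo_id,aidx
94,4096,4096,4096,T,N,fp8,fp8,bf16,-,-,4096,4096,4096,0,default,f32,f32,0,33554432,222,1
94,4096,4096,4096,T,N,fp8,fp8,fp8,fp16,fp16,4096,4096,4096,0,gelu_aux_bias,f32,f32,0,33554432,222,1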
@@ -1003,19 +1009,35 @@ void hipblaslt_gemm(const Tensor *inputA,
const hipDataType B_type = get_hipblaslt_dtype(param.Btype);
const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
-// const hipblasltDatatype_t aux_type = get_hipblaslt_dtype(outputPreGelu->data.dtype);
+const hipDataType aux_type = get_hipblaslt_dtype(outputPreGelu->data.dtype);

NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
"FP8 input to GEMM requires inverse of scale!");

-// check consistency of arguments:
-// if fp8 is desired, context cannot be null
+#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
+if (use_fp8 && gelu) {
+hipDeviceProp_t prop;
+NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, 0));
+// hipblasLt currently supports FP8 GEMM + GELU fusion only on MI300 (gfx942)
+if (prop.major == 9 && prop.minor == 4) {
+bool allow_fp8_gemm = (param.Atype == DType::kFloat8E4M3) &&
+(param.Btype == DType::kFloat8E4M3) &&
+(outputD->data.dtype == DType::kFloat8E4M3) &&
+(!bias || inputBias->data.dtype == DType::kFloat16) &&
+(outputPreGelu->data.dtype == DType::kFloat16 || outputPreGelu->data.dtype == outputD->data.dtype);
+NVTE_CHECK(allow_fp8_gemm, "fp8 gemm + gelu fusion is unavailable with current config!");
+} else {
+NVTE_CHECK(false, "fp8 gemm + gelu fusion is unavailable right now!");
+}
+}
+#else
// fp8 + gelu fusion + fp8 aux is unavailable right now.
if (use_fp8) {
NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
}
+#endif
if (is_fp8_dtype(outputD->data.dtype)) {
NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
}
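For context on what the aux buffer carries: with a GELU epilogue the GEMM writes D = GELU(A·B + bias) and additionally stores the pre-GELU result A·B + bias in outputPreGelu for reuse in the backward pass. The new aux_type lets that buffer keep FP16 precision even when D is FP8. A reference sketch of the per-element semantics (scalar C++; the tanh approximation is the variant commonly used in GEMM epilogues — the exact formula hipBLASLt applies is an assumption here):

#include <cmath>

// Unfused reference for one output element, where acc = (A*B)[i][j]:
//   aux[i][j] = acc + bias[i]     -> stored at aux_type precision (e.g. FP16)
//   D[i][j]   = gelu(aux[i][j])   -> stored at d_type precision (e.g. FP8 E4M3)
static float gelu_tanh(float x) {
  // tanh-approximated GELU: 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3)))
  return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}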
@@ -1064,7 +1086,7 @@ void hipblaslt_gemm(const Tensor *inputA,
&param.transB, sizeof(param.transB)));

// set fp8 attributes -- input and output types should already be set to fp8 as appropriate
-// Note: gelu fusion isn't available right now, and we don't need
+// Note: gelu fusion is available for certain configs from ROCm 7.0 on; we still don't need
// amax(D) either (next op is high precision).
#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
hipblasLtMatmulMatrixScale_t scaling_mode;
@@ -1116,6 +1138,14 @@ void hipblaslt_gemm(const Tensor *inputA,
HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
&bias_type, sizeof(bias_type)));
}
+#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
+if (gelu) {
+NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
+HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE,
+&aux_type,
+sizeof(aux_type)));
+}
+#endif
}

if (bias && gelu) {
@@ -1167,6 +1197,7 @@ void hipblaslt_gemm(const Tensor *inputA,

GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
use_fp8 ? bias_type : (hipDataType)-1,
+(use_fp8 && gelu) ? aux_type : (hipDataType)-1,
m, n, k, param.lda, param.ldb, ldd, param.transA, param.transB, scaling_mode, epilogue );
GemmAlgoCache::Algo cached_algo;
if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())
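To close, a minimal sketch of the epilogue wiring this change completes, assuming an already-created operationDesc, a valid outputPreGelu tensor, and hipBLASLt >= 0.15 (the sequence and attribute names mirror the hunks above; aux_ld and error handling are simplified):

// Sketch: configure a GELU epilogue whose aux (pre-GELU) buffer has its own type.
hipblasLtEpilogue_t epi = bias ? HIPBLASLT_EPILOGUE_GELU_AUX_BIAS
                               : HIPBLASLT_EPILOGUE_GELU_AUX;
hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE,
                                &epi, sizeof(epi));
void* aux_ptr = outputPreGelu->data.dptr;  // pre-GELU output buffer
hipblasLtMatmulDescSetAttribute(operationDesc,
                                HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                &aux_ptr, sizeof(aux_ptr));
int64_t aux_ld = m;  // same leading dimension as D in this sketch
hipblasLtMatmulDescSetAttribute(operationDesc,
                                HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
                                &aux_ld, sizeof(aux_ld));
// New with hipBLASLt >= 0.15: the aux buffer may carry its own data type,
// e.g. HIP_R_16F while D is an FP8 type.
hipblasLtMatmulDescSetAttribute(operationDesc,
                                HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE,
                                &aux_type, sizeof(aux_type));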