diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index af59db38771d..bf100dc49c4c 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -357,11 +357,11 @@ TVM_DLL Pass BF16ComputeLegalize(); /*! * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32 * before Ops, then add a cast back to fp8. - * \param promote_dtype_str The data type used for type promotion, defaults to float16 + * \param promote_dtype The data type used for type promotion, defaults to float16 * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs * \return The pass. */ -TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype_str = "float16"); +TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype = "float16"); /*! * \brief Legalize bf16 storage types to u16. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index de11d30fbc6e..39105f21a23c 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -244,7 +244,7 @@ def BF16ComputeLegalize(): return _ffi_api.BF16ComputeLegalize() # type: ignore -def FP8ComputeLegalize(promote_dtype_str: str = "float32"): +def FP8ComputeLegalize(promote_dtype: str = "float32"): """Legalize fp8 compute Ops. Parameters @@ -257,7 +257,7 @@ def FP8ComputeLegalize(promote_dtype_str: str = "float32"): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore + return _ffi_api.FP8ComputeLegalize(promote_dtype) # type: ignore def BF16StorageLegalize(): diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index ecdb9883d15f..d35caa4db966 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -780,13 +780,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("tir.transform.BF16StorageLegalize", BF16StorageLegalize); } -Pass FP8ComputeLegalize(ffi::String promote_dtype_str) { +Pass FP8ComputeLegalize(ffi::String promote_dtype) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto target = f->GetAttr(tvm::attr::kTarget).value(); if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { return f; } - return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype_str))).Legalize(f); + return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); }