Skip to content
4 changes: 2 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,11 @@ TVM_DLL Pass BF16ComputeLegalize();
/*!
* \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
* before Ops, then add a cast back to fp8.
* \param promote_dtype_str The data type used for type promotion, defaults to float16
* \param promote_dtype The data type used for type promotion, defaults to float16
* \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs
* \return The pass.
*/
TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype_str = "float16");
TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype = "float16");

/*!
* \brief Legalize bf16 storage types to u16.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def BF16ComputeLegalize():
return _ffi_api.BF16ComputeLegalize() # type: ignore


def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
def FP8ComputeLegalize(promote_dtype: str = "float32"):
"""Legalize fp8 compute Ops.

Parameters
Expand All @@ -257,7 +257,7 @@ def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore
return _ffi_api.FP8ComputeLegalize(promote_dtype) # type: ignore


def BF16StorageLegalize():
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/unsupported_dtype_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -780,13 +780,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef().def("tir.transform.BF16StorageLegalize", BF16StorageLegalize);
}

Pass FP8ComputeLegalize(ffi::String promote_dtype_str) {
Pass FP8ComputeLegalize(ffi::String promote_dtype) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype_str))).Legalize(f);
return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f);
};
return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {});
}
Expand Down
Loading