From f943db15b640fd8d97bd6baaa7910c5a3d030fe0 Mon Sep 17 00:00:00 2001 From: jikechao Date: Sat, 9 Sep 2023 12:50:36 +0800 Subject: [PATCH 1/6] Update test_forward.py --- tests/python/frontend/oneflow/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index cc9333cd03bd..17583b3c25d4 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -721,7 +721,7 @@ def forward(self, x): for device in ["llvm"]: verify_activation(model1, device=device) - # verify_activation(model2, device=device) # NO PASS + verify_activation(model2, device=device) verify_activation(model3, device=device) verify_activation(model4, device=device) verify_activation(model5, device=device) From afafd10e6dbf8068b44fff6ebc44b5e1032adc7d Mon Sep 17 00:00:00 2001 From: jikechao Date: Sat, 9 Sep 2023 12:52:32 +0800 Subject: [PATCH 2/6] fix a bug in softplus --- python/tvm/relay/frontend/oneflow.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 4f278d8249e8..ceb5da573999 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -1119,8 +1119,11 @@ class Softplus(OneFlowOpConverter): def _impl_v1(cls, inputs, attrs, params): data = inputs[0] data_dtype = infer_type(data).checked_type.dtype - data = _op.exp(data) + _expr.const(1, dtype=data_dtype) - return _op.log(data) + beta = _expr.const(int(inputs[1]), dtype=data_dtype) + threshold = int(inputs[2]) if inputs[2] else 20 + threshold_ = _op.full_like(data, fill_value=_expr.const(threshold)) + softplus_value = _op.log(_op.exp(data * beta) + _expr.const(1.0, dtype=data_dtype)) / beta + return _op.where(_op.greater(data * beta, threshold_), data, softplus_value) class Softsign(OneFlowOpConverter): From ce6bf304be2c53e74d0cf9886687386d5e38e9c5 Mon Sep 17 00:00:00 2001 From: jikechao Date: Sat, 9 Sep 2023 13:15:49 +0800 Subject: [PATCH 3/6] Update oneflow.py --- python/tvm/relay/frontend/oneflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index ceb5da573999..7a713e5e15ee 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -1119,8 +1119,8 @@ class Softplus(OneFlowOpConverter): def _impl_v1(cls, inputs, attrs, params): data = inputs[0] data_dtype = infer_type(data).checked_type.dtype - beta = _expr.const(int(inputs[1]), dtype=data_dtype) - threshold = int(inputs[2]) if inputs[2] else 20 + beta = _expr.const(float(attrs.get("beta", 1.0))) + threshold = float(attrs.get("threshold", 20.0)) threshold_ = _op.full_like(data, fill_value=_expr.const(threshold)) softplus_value = _op.log(_op.exp(data * beta) + _expr.const(1.0, dtype=data_dtype)) / beta return _op.where(_op.greater(data * beta, threshold_), data, softplus_value) From 23874bead14c2aa1bb45c1bcc525c590583fd3b9 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Tue, 23 Sep 2025 10:34:28 +0800 Subject: [PATCH 4/6] Update transform.py --- python/tvm/tir/transform/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index de11d30fbc6e..39105f21a23c 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -244,7 +244,7 @@ def BF16ComputeLegalize(): return _ffi_api.BF16ComputeLegalize() # type: ignore -def FP8ComputeLegalize(promote_dtype_str: str = "float32"): +def FP8ComputeLegalize(promote_dtype: str = "float32"): """Legalize fp8 compute Ops. Parameters @@ -257,7 +257,7 @@ def FP8ComputeLegalize(promote_dtype_str: str = "float32"): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore + return _ffi_api.FP8ComputeLegalize(promote_dtype) # type: ignore def BF16StorageLegalize(): From 07fb300e29a5b03b6f2da4c304236fef67872afc Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Tue, 23 Sep 2025 10:37:29 +0800 Subject: [PATCH 5/6] Update transform.h --- include/tvm/tir/transform.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index af59db38771d..bf100dc49c4c 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -357,11 +357,11 @@ TVM_DLL Pass BF16ComputeLegalize(); /*! * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32 * before Ops, then add a cast back to fp8. - * \param promote_dtype_str The data type used for type promotion, defaults to float16 + * \param promote_dtype The data type used for type promotion, defaults to float16 * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs * \return The pass. */ -TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype_str = "float16"); +TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype = "float16"); /*! * \brief Legalize bf16 storage types to u16. From fc9bf0287ee555803deb11f226e8ddaa970c75f3 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Tue, 23 Sep 2025 10:38:19 +0800 Subject: [PATCH 6/6] Update unsupported_dtype_legalize.cc --- src/tir/transforms/unsupported_dtype_legalize.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index ecdb9883d15f..d35caa4db966 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -780,13 +780,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("tir.transform.BF16StorageLegalize", BF16StorageLegalize); } -Pass FP8ComputeLegalize(ffi::String promote_dtype_str) { +Pass FP8ComputeLegalize(ffi::String promote_dtype) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto target = f->GetAttr(tvm::attr::kTarget).value(); if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { return f; } - return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype_str))).Legalize(f); + return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); }