diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 5dfdec8ea034..6b00a4010554 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -72,6 +72,7 @@
 from .lazy_transform_params import LazyTransformParams
 from .optimize_layout_transform import OptimizeLayoutTransform
 from .remove_redundant_reshape import RemoveRedundantReshape
+from .fast_math import FastMathTransform
 
 # Import to register the legalization functions.
 from . import legalize_ops, tuning_api
diff --git a/python/tvm/relax/transform/fast_math.py b/python/tvm/relax/transform/fast_math.py
new file mode 100644
index 000000000000..2aebd96db343
--- /dev/null
+++ b/python/tvm/relax/transform/fast_math.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
+"""Relax Use Fast Math pass."""
+import tvm
+from tvm import topi
+from tvm.ir.module import IRModule
+from tvm.relax import Expr, Call, expr_functor, PyExprMutator
+
+
+@expr_functor.mutator
+class FastMathCodeGenerator(PyExprMutator):
+    """
+    Convert expensive non-linear operators to their fast but approximate counterparts.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to be transformed
+    """
+
+    def __init__(self, mod):
+        super().__init__(mod)
+
+    def visit_call_(self, call: Call) -> Expr:
+        """Rewrite softmax/exp/erf/tanh operator calls to their topi fast_* equivalents."""
+        # Guard: only built-in operators carry ``name``; a GlobalVar/Var callee
+        # (e.g. a call to another Relax function) must be left to the base visitor.
+        if not isinstance(call.op, tvm.ir.Op):
+            return super().visit_call_(call)
+        if call.op.name == "relax.nn.softmax":
+            return self.builder_.call_te(topi.nn.fast_softmax, call.args[0], call.attrs.axis)
+        if call.op.name == "relax.exp":
+            return self.builder_.call_te(topi.fast_exp, call.args[0])
+        if call.op.name == "relax.erf":
+            return self.builder_.call_te(topi.fast_erf, call.args[0])
+        if call.op.name == "relax.tanh":
+            return self.builder_.call_te(topi.fast_tanh, call.args[0])
+
+        return super().visit_call_(call)
+
+
+@tvm.transform.module_pass(opt_level=0, name="FastMathTransform")
+class FastMathTransform:
+    """
+    Pass to convert expensive non-linear operators to their fast but approximate counterparts.
+    """
+
+    def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
+        """Rewrite every Relax function in ``mod``; non-Relax functions are untouched."""
+        fast_math_codegen = FastMathCodeGenerator(mod)
+        for gv in mod.functions:
+            func = mod[gv]
+            # Skip PrimFuncs / extern funcs: only Relax functions are mutated.
+            if not isinstance(func, tvm.relax.Function):
+                continue
+            func = fast_math_codegen.visit_expr(func)
+            fast_math_codegen.builder_.update_func(gv, func)
+
+        return fast_math_codegen.builder_.get()
diff --git a/tests/python/relax/test_fast_math_transform.py b/tests/python/relax/test_fast_math_transform.py
new file mode 100644
index 000000000000..f5b88f312cab
--- /dev/null
+++ b/tests/python/relax/test_fast_math_transform.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. 
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests to validate relax fast math transform pass."""
+
+import pytest
+import tvm.testing
+from tvm import relax, topi
+from tvm.ir.base import assert_structural_equal
+from tvm.relax.transform import FastMathTransform
+from tvm.script import ir as I, relax as R
+
+
+def _run_pass_compare_output(Before, Expected):
+    fast_mod = FastMathTransform()(Before)
+    # Fail the test instead of merely printing if the transformed IR is malformed.
+    assert relax.analysis.well_formed(fast_mod), "IRModule is not well-formed"
+    assert_structural_equal(Expected, fast_mod)
+
+
+def test_optimize_transform_layout_pass_one_arg():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+            lv1: R.Tensor((16,), dtype="float32") = R.nn.softmax(x)
+            lv2: R.Tensor((16,), dtype="float32") = R.exp(lv1)
+            lv3: R.Tensor((16,), dtype="float32") = R.erf(lv2)
+            lv4: R.Tensor((16,), dtype="float32") = R.tanh(lv3)
+            return lv4
+
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((16,), "float32"))
+    with bb.function("main", [x]):
+        lv1 = bb.emit_te(topi.nn.fast_softmax, x)
+        lv2 = bb.emit_te(topi.fast_exp, lv1)
+        lv3 = bb.emit_te(topi.fast_erf, lv2)
+        lv4 = bb.emit_te(topi.fast_tanh, lv3)
+        bb.emit_func_output(lv4)
+    Expected = bb.get()
+
+    _run_pass_compare_output(Before, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()