diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index d87cf9811bc7..283a9073390f 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -30,6 +30,7 @@
 #include
 #include
 #include
+#include <tvm/relay/qnn/transform.h>
 #include
 #include
 #include
@@ -903,7 +904,7 @@ IRModule Prepare(IRModule mod, Device device, Target target) {
   tec::DeviceMap device_map;
 
   // Run minimal transforms on module to establish invariants needed by interpreter.
-  transform::Sequential seq({transform::SimplifyInference(),
+  transform::Sequential seq({transform::SimplifyInference(), qnn::transform::Legalize(),
                              // FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
                              // attribute.
                              transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index d545518c1c3c..2093cea88aea 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -26,6 +26,7 @@
 #include
 #include
 #include
+#include <tvm/relay/qnn/transform.h>
 #include
 #include
 #include
@@ -168,9 +169,13 @@
     // We should think about potentially constant evaluation over these ops too.
     static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
+    static auto qnn_canonicalize = Op::GetAttrMap<FTVMLegalize>("FTVMQnnCanonicalize");
     if (const auto* call_node = call->op.as<OpNode>()) {
       Op op = GetRef<Op>(call_node);
-      if ((fnoncomputational.count(op) && fnoncomputational[op]) || (call->op == device_copy_op_)) {
+
+      bool is_no_qnn_canonicalized = !qnn_canonicalize.count(op);
+      bool is_no_computational = fnoncomputational.count(op) && fnoncomputational[op];
+      if ((is_no_computational && is_no_qnn_canonicalized) || call->op == device_copy_op_) {
         return GetRef<Call>(call);
       }
     }
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index 7b4eb5231a2c..7d1529de6c04 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -298,6 +298,31 @@ def before():
     assert tvm.ir.structural_equal(run_infer_type(before_mod["main"]), after_mod["main"])
 
 
+def test_fold_qnn_quantize():
+    t = relay.TensorType([1, 2, 3], "int8")
+
+    def before():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0], dtype="float32"))
+        const_fp = relay.const(data, dtype="float32")
+        const_i8 = relay.qnn.op.quantize(const_fp, output_scale=relay.const(0.5), output_zero_point=relay.const(0))
+        x = relay.var("x", t)
+        add = relay.op.add(x, const_i8)
+        func = relay.Function([x], add)
+        return func
+
+    def expected():
+        data = tvm.nd.array(np.array([2, 4, 6], dtype="int8"))
+        const_i8 = relay.const(data, dtype="int8")
+        x = relay.var("x", t)
+        add = relay.op.add(x, const_i8)
+        func = relay.Function([x], add)
+        return func
+
+    zz = run_opt_pass(before(), transform.FoldConstant())
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(zz, zexpected)
+
+
 if __name__ == "__main__":
     test_fold_const()
     test_fold_let()
@@ -307,3 +332,4 @@ def before():
     test_fold_batch_norm()
     test_fold_ndarray_size()
     test_fold_dropout()
+    test_fold_qnn_quantize()