diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 8f7333051a4c..3b4d97576cd7 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -220,9 +220,8 @@ def check_to_skip(self, ref_call): if current_qconfig().skip_conv_layers is not None: # check skip conv layers skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if self._conv2d_counter in skipped_indices: - if ref_call.op.name == "nn.conv2d": - self._conv2d_counter += 1 + if self._conv2d_counter in skipped_indices and ref_call.op.name == "nn.conv2d": + self._conv2d_counter += 1 return True if ref_call.op.name == "nn.conv2d": self._conv2d_counter += 1 diff --git a/src/relay/quantize/quantize.cc b/src/relay/quantize/quantize.cc index 846367c9c8a9..afd5e522657d 100644 --- a/src/relay/quantize/quantize.cc +++ b/src/relay/quantize/quantize.cc @@ -126,6 +126,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "global_scale=" << op->global_scale << ", "; p->stream << "weight_scale=" << op->weight_scale << ", "; p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; + p->stream << "skip_dense_layer==" << op->skip_dense_layer << ", "; p->stream << "do_simulation==" << op->do_simulation << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index 6ebff0e6ac8b..c123b182c0b6 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -370,6 +370,40 @@ def visit_call(self, call): assert shift_amount >= 0, "Shift amount must be non-negative." 
def test_dense_conv2d_rewrite():
    """Quantize a graph with a dense branch and a conv2d branch and check dtypes.

    The graph is a tuple of two independent branches: ``dense -> bias_add`` and
    a single ``conv2d``.  It is quantized with ``skip_dense_layer=False`` and
    without overriding ``skip_conv_layers``.  The test expects ``nn.dense`` to
    take int8 operands and produce int32, and ``nn.conv2d`` to stay float32 on
    both operands and on its result.

    Both ops must actually be found in the quantized graph; otherwise the dtype
    checks would never run and the test would pass without checking anything.
    """
    n, c, h, w = 1, 16, 64, 64
    data = relay.var("data", relay.TensorType((n, c, h, w)))
    inp = relay.var("inp", relay.TensorType((n, c * h * w)))
    weight_T = relay.const(np.random.random((n, c * h * w)), dtype="float32")
    bias = relay.const(np.random.random((n,)), dtype="float32")
    conv_w = relay.const(np.random.random((16, 16, 3, 3)), dtype="float32")

    dense_o = relay.nn.dense(inp, weight_T)
    linear_o = relay.nn.bias_add(dense_o, bias)
    conv2d_o = relay.nn.conv2d(data, conv_w, kernel_size=(3, 3), padding=(1, 1), channels=16)
    result = relay.Tuple((linear_o, conv2d_o))

    mod = tvm.IRModule.from_expr(result)
    with tvm.transform.PassContext(opt_level=3):
        with relay.quantize.qconfig(
            calibrate_mode="global_scale", global_scale=8.0, skip_dense_layer=False
        ):
            qnn_mod = relay.quantize.quantize(mod)

    # op name -> (dtype of args[0], dtype of args[1], dtype of the call result)
    expected_dtypes = {
        "nn.dense": ("int8", "int8", "int32"),
        "nn.conv2d": ("float32", "float32", "float32"),
    }
    checked_ops = set()

    def _check_dense(node):
        if not isinstance(node, Call):
            return
        op_name = node.op.name
        if op_name not in expected_dtypes:
            return
        lhs_dtype, rhs_dtype, out_dtype = expected_dtypes[op_name]
        assert node.args[0].checked_type.dtype == lhs_dtype
        assert node.args[1].checked_type.dtype == rhs_dtype
        assert node.checked_type.dtype == out_dtype
        checked_ops.add(op_name)

    relay.analysis.post_order_visit(qnn_mod["main"], _check_dense)

    # Guard against a vacuous pass: every expected op must have been visited.
    missing_ops = sorted(set(expected_dtypes) - checked_ops)
    assert not missing_ops, "ops not found in quantized graph: {}".format(missing_ops)


if __name__ == "__main__":
    test_mul_rewrite()
    test_batch_flatten_rewrite()
    # NOTE(review): the diff hunk boundary hides the calls that sit between
    # these two groups in the real file; only the visible calls are listed.
    test_unquantizable_core_partition()
    test_unquantizable_suffix_partition()
    test_left_shift_negative()
    test_dense_conv2d_rewrite()