[FIX] skip_conv_layers will affect quantization of nn.dense (#7795)
* [FIX] `skip_conv_layers` will affect quantization of `nn.dense`

* [ add ] quantization test case for dense & conv2d

* [ fix ] reformat

* [ reformat ] test file
AD1024 committed Apr 27, 2021
1 parent 82fecbf commit f681359
Showing 3 changed files with 38 additions and 3 deletions.
python/tvm/relay/quantize/quantize.py (5 changes: 2 additions & 3 deletions)
@@ -220,9 +220,8 @@ def check_to_skip(self, ref_call):
         if current_qconfig().skip_conv_layers is not None:
             # check skip conv layers
             skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-            if self._conv2d_counter in skipped_indices:
-                if ref_call.op.name == "nn.conv2d":
-                    self._conv2d_counter += 1
+            if self._conv2d_counter in skipped_indices and ref_call.op.name == "nn.conv2d":
+                self._conv2d_counter += 1
                 return True
             if ref_call.op.name == "nn.conv2d":
                 self._conv2d_counter += 1
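For context: before this change, check_to_skip returned True for any operator visited while self._conv2d_counter was still in skipped_indices, so with, say, skip_conv_layers=[0], an nn.dense encountered before the first nn.conv2d was skipped (left unquantized) even when skip_dense_layer=False. A minimal standalone sketch of the old versus fixed control flow, using hypothetical free functions standing in for the method (not TVM code):

def check_to_skip_old(op_name, conv2d_counter, skipped_indices):
    # Old behavior: returns True for ANY op while the counter is in
    # skipped_indices, so nn.dense was skipped alongside nn.conv2d.
    if conv2d_counter in skipped_indices:
        if op_name == "nn.conv2d":
            conv2d_counter += 1
        return True, conv2d_counter
    if op_name == "nn.conv2d":
        conv2d_counter += 1
    return False, conv2d_counter


def check_to_skip_fixed(op_name, conv2d_counter, skipped_indices):
    # Fixed behavior: the early return now also requires the op to be
    # nn.conv2d, so only conv layers are skipped by the counter check.
    if conv2d_counter in skipped_indices and op_name == "nn.conv2d":
        conv2d_counter += 1
        return True, conv2d_counter
    if op_name == "nn.conv2d":
        conv2d_counter += 1
    return False, conv2d_counter


# An nn.dense visited while conv layer 0 is marked as skipped:
print(check_to_skip_old("nn.dense", 0, [0]))    # (True, 0)  -> dense wrongly skipped
print(check_to_skip_fixed("nn.dense", 0, [0]))  # (False, 0) -> dense gets quantized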
src/relay/quantize/quantize.cc (1 change: 1 addition & 0 deletions)
@@ -126,6 +126,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "global_scale=" << op->global_scale << ", ";
       p->stream << "weight_scale=" << op->weight_scale << ", ";
       p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
+      p->stream << "skip_dense_layer==" << op->skip_dense_layer << ", ";
       p->stream << "do_simulation==" << op->do_simulation << ", ";
       p->stream << "round_for_shift==" << op->round_for_shift << ", ";
       p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
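The skip_dense_layer option itself already existed; this commit only adds it to the printed config. A minimal usage sketch of where the field surfaces, assuming the Python qconfig API exercised by the test below:

from tvm import relay

# Entering a qconfig scope and printing the active config goes through the
# ReprPrinter above, so the output now lists the skip_dense_layer flag.
# skip_dense_layer=False opts nn.dense into quantization (dense layers are
# skipped by default).
with relay.quantize.qconfig(skip_dense_layer=False):
    print(relay.quantize.current_qconfig())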
tests/python/relay/test_pass_auto_quantize.py (35 changes: 35 additions & 0 deletions)
@@ -370,6 +370,40 @@ def visit_call(self, call):
         assert shift_amount >= 0, "Shift amount must be non-negative."
 
 
+def test_dense_conv2d_rewrite():
+    n, c, h, w = 1, 16, 64, 64
+    data = relay.var("data", relay.TensorType((n, c, h, w)))
+    inp = relay.var("inp", relay.TensorType((n, c * h * w)))
+    weight_T = relay.const(np.random.random((n, c * h * w)), dtype="float32")
+    bias = relay.const(np.random.random((n,)), dtype="float32")
+    conv_w = relay.const(np.random.random((16, 16, 3, 3)), dtype="float32")
+
+    dense_o = relay.nn.dense(inp, weight_T)
+    linear_o = relay.nn.bias_add(dense_o, bias)
+    conv2d_o = relay.nn.conv2d(data, conv_w, kernel_size=(3, 3), padding=(1, 1), channels=16)
+    result = relay.Tuple((linear_o, conv2d_o))
+
+    mod = tvm.IRModule.from_expr(result)
+    with tvm.transform.PassContext(opt_level=3):
+        with relay.quantize.qconfig(
+            calibrate_mode="global_scale", global_scale=8.0, skip_dense_layer=False
+        ):
+            qnn_mod = relay.quantize.quantize(mod)
+
+    def _check_dense(node):
+        if isinstance(node, Call):
+            if node.op.name == "nn.dense":
+                assert node.args[0].checked_type.dtype == "int8"
+                assert node.args[1].checked_type.dtype == "int8"
+                assert node.checked_type.dtype == "int32"
+            if node.op.name == "nn.conv2d":
+                assert node.args[0].checked_type.dtype == "float32"
+                assert node.args[1].checked_type.dtype == "float32"
+                assert node.checked_type.dtype == "float32"
+
+    relay.analysis.post_order_visit(qnn_mod["main"], _check_dense)
+
+
 if __name__ == "__main__":
     test_mul_rewrite()
     test_batch_flatten_rewrite()
@@ -386,3 +420,4 @@ def visit_call(self, call):
     test_unquantizable_core_partition()
     test_unquantizable_suffix_partition()
     test_left_shift_negative()
+    test_dense_conv2d_rewrite()
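Besides running the file directly (the __main__ block now ends with the new case), the test can be selected on its own with pytest's node-id syntax; a sketch, assuming it is launched from the TVM repository root:

import sys

import pytest

# Run only the new test case; the file path comes from the diff header above.
sys.exit(
    pytest.main(["tests/python/relay/test_pass_auto_quantize.py::test_dense_conv2d_rewrite", "-v"])
)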
