Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import topi
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig
from .quantize import _conv_counter, _set_conv_counter
from .quantize import _conv_counter, _set_conv_counter, _get_scale_counter, _set_scale_counter
from .. import expr as _expr
from .. import op as _op
from ..op import op as _reg
Expand All @@ -23,6 +23,12 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):

data, scale, clip_min, clip_max = inputs

if attrs.passthrough:
# if original value should be passed through
assert attrs.kind != QAnnotateKind.WEIGHT
rdata = topi.identity(data)
return [rdata]

# simulate rounding error
scaled_data = topi.divide(data, scale)
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
Expand Down Expand Up @@ -112,11 +118,17 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
kind: QAnnotateKind
the kind of annotation field.
"""
dom_scale = _expr.var("dom_scale")
counter = _get_scale_counter()
dom_scale = _expr.var("dom_scale" + str(counter))
clip_min = _expr.var("clip_min")
clip_max = _expr.var("clip_max")
passthrough = 0
passthrough_bound = current_qconfig().passthrough_bound
if kind != QAnnotateKind.WEIGHT:
passthrough = counter > passthrough_bound
_set_scale_counter(counter + 1)
return _quantize.simulated_quantize(
data, dom_scale, clip_min, clip_max, kind, sign, rounding)
data, dom_scale, clip_min, clip_max, kind, sign, rounding, passthrough)


@register_annotate_function("nn.contrib_conv2d_NCHWc")
Expand Down
38 changes: 28 additions & 10 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class QConfig(NodeBase):
"dtype_activation": "int32",
"global_scale": 8.0,
"skip_k_conv": 1,
"passthrough_bound": 1e9,
"round_for_shift": True,
"store_lowbit_output": True,
"debug_enabled_ops": None,
Expand Down Expand Up @@ -152,6 +153,19 @@ def _set_conv_counter(n):
global CONV_COUNTER
CONV_COUNTER = n

# Module-level counter; read through _get_scale_counter() and rebound
# through _set_scale_counter().
SCALE_COUNTER = 0


def _get_scale_counter():
    """Return the current value of the module-level scale counter."""
    current = SCALE_COUNTER
    return current


def _set_scale_counter(n):
    """Overwrite the module-level scale counter with ``n``."""
    # Declare the name global so the assignment rebinds the module variable
    # instead of creating a function-local one.
    global SCALE_COUNTER
    SCALE_COUNTER = n


def annotate(graph):
"""Given a float32 graph, annotate will rewrite the graph
Expand All @@ -169,10 +183,11 @@ def annotate(graph):
The graph after annotation
"""
_set_conv_counter(0) # reset counter
_set_scale_counter(0) # reset scale counter
return _quantize.annotate(graph)


def calibrate(graph, dataset=None):
def calibrate(graph, dataset=None, profile_mode=False):
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
Expand All @@ -198,6 +213,7 @@ def power2_scale(arr):
cfg = current_qconfig()
const_params = {}
quantize_op = _op.get("relay.op.annotation.simulated_quantize")
profile_data = []

def visit_func(expr):
"""Internal visit function"""
Expand All @@ -208,24 +224,26 @@ def visit_func(expr):
nbit = cfg.get_nbit_by_kind(kind)

valid_bit = nbit - attrs.sign
valid_range = 2**valid_bit

def _make_const(val):
return _expr.const(val, 'float32')

if kind == QAnnotateKind.WEIGHT:
var = expr.args[0]
assert isinstance(var, _expr.Constant)
scale = power2_scale(var.data)
const_params[ndom_scale] = _make_const(scale / valid_range)
else:
scale = cfg.global_scale

def _make_const(val):
return _expr.const(val, 'float32')

valid_range = 2**valid_bit
const_params[ndom_scale] = _make_const(scale / valid_range)
if profile_mode:
profile_data.append((ndom_scale.name_hint, expr.args[0]))
const_params[nclip_min] = _make_const(- (valid_range - 1))
const_params[nclip_max] = _make_const((valid_range - 1))

_ir_pass.post_order_visit(graph, visit_func)
return _expr.bind(graph, const_params)
if profile_mode:
for i, data in enumerate(profile_data):
profile_data[i] = (data[0], _expr.bind(data[1], const_params))
return _expr.bind(graph, const_params), profile_data


def realize(graph):
Expand Down
17 changes: 11 additions & 6 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind;
bool sign;
std::string rounding;
int passthrough;

TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
TVM_ATTR_FIELD(kind)
Expand All @@ -36,6 +37,8 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
.describe("whether to use signed data type.");
TVM_ATTR_FIELD(rounding).set_default("round")
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
TVM_ATTR_FIELD(passthrough).set_default(false)
.describe("whether to passthrough full precision value");
}
};

Expand Down Expand Up @@ -72,13 +75,14 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);

TVM_REGISTER_API("relay._quantize.simulated_quantize")
.set_body_typed<Expr(Expr, Expr, Expr, Expr, int, bool, std::string)>(
.set_body_typed<Expr(Expr, Expr, Expr, Expr, int, bool, std::string, bool)>(
[](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max,
int kind, bool sign, std::string rounding) {
int kind, bool sign, std::string rounding, int passthrough) {
auto attrs = make_node<SimulatedQuantizeAttrs>();
attrs->kind = kind;
attrs->sign = sign;
attrs->rounding = rounding;
attrs->passthrough = passthrough;
static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
return CallNode::make(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
});
Expand Down Expand Up @@ -527,10 +531,11 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "nbit_weight=" << op->nbit_weight << ", ";
p->stream << "nbit_activation=" << op->nbit_activation << ", ";
p->stream << "global_scale=" << op->global_scale << ", ";
p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
// Print the remaining config fields as comma-separated "name=value" pairs.
p->stream << "skip_k_conv=" << op->skip_k_conv << ", ";
p->stream << "passthrough_bound=" << op->passthrough_bound << ", ";
p->stream << "round_for_shift=" << op->round_for_shift << ", ";
p->stream << "store_lowbit_output=" << op->store_lowbit_output << ", ";
p->stream << "debug_enabled_ops=" << op->debug_enabled_ops;
p->stream << ")";
});

Expand Down
2 changes: 2 additions & 0 deletions src/relay/pass/quantize.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class QConfigNode : public Node {
DataType dtype_activation = Int(32);
double global_scale = 8.0;
int skip_k_conv = 1;
int passthrough_bound = 1e9;
bool round_for_shift = true;
bool store_lowbit_output = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
Expand All @@ -119,6 +120,7 @@ class QConfigNode : public Node {
v->Visit("dtype_weight", &dtype_weight);
v->Visit("dtype_activation", &dtype_activation);
v->Visit("global_scale", &global_scale);
v->Visit("passthrough_bound", &passthrough_bound);
v->Visit("skip_k_conv", &skip_k_conv);
v->Visit("round_for_shift", &round_for_shift);
v->Visit("store_lowbit_output", &store_lowbit_output);
Expand Down