Skip to content

Commit

Permalink
keep quantize pass in quantize namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Jun 5, 2019
1 parent c0d669c commit 921356f
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 64 deletions.
16 changes: 0 additions & 16 deletions include/tvm/relay/transform.h
Expand Up @@ -531,22 +531,6 @@ TVM_DLL Pass CanonicalizeOps();
*/
TVM_DLL Pass AlterOpLayout();

/*!
* \brief Rewrite a graph and return a graph that simulates the error introduced
* by the current quantization scheme.
*
* \return The pass.
*/
TVM_DLL Pass QuantizeAnnotate();

/*!
* \brief This pass transforms the simulated quantized graph to a low-bit
* integer graph.
*
* \return The pass.
*/
TVM_DLL Pass QuantizeRealize();

} // namespace transform
} // namespace relay
} // namespace tvm
Expand Down
40 changes: 35 additions & 5 deletions python/tvm/relay/quantize/quantize.py
Expand Up @@ -23,6 +23,7 @@
from .. import expr as _expr
from .. import module as _module
from .. import ir_pass as _ir_pass
from .. import transform as _transform
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
Expand Down Expand Up @@ -238,6 +239,33 @@ def _make_const(val):
return _expr.bind(graph, const_params)


def annotate():
"""Given a float32 graph, this pass will rewrite the graph and return
a graph which simulates the error brought by the current quantization
scheme.
Returns
-------
ret: tvm.relay.Pass
The registered pass for quantization annotation.
"""
return _quantize.QuantizeAnnotate()


def realize():
"""The realize pass will transform the simulated quantized graph, which
actually computes with float32, to a real low-bit integer graph. It will
replace the `simulated_quantize` with several fine-grained operators like
add, multiply, and shift as much as possible for better performance.
Returns
-------
ret: tvm.relay.Pass
The registered pass for quantization realization.
"""
return _quantize.QuantizeRealize()


def _bind_params(func, params):
"""Bind the params to the expression.
"""
Expand Down Expand Up @@ -295,14 +323,16 @@ def quantize(graph, params=None, dataset=None):
_transform.FoldConstant()])

calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="calibrate")
name="QuantizeCalibrate")
_set_conv_counter(0) # reset counter
quantize_seq = _transform.Sequential([_transform.QuantizeAnnotate(),
quantize_seq = _transform.Sequential([annotate(),
calibrate_pass,
_transform.QuantizeRealize(),
realize(),
_transform.FoldConstant()])

with _transform.PassContext(opt_level=3, required_pass=["calibrate"]):
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
mod = quantize_seq(mod)
return mod[mod.entry_func.name_hint]
27 changes: 0 additions & 27 deletions python/tvm/relay/transform.py
Expand Up @@ -593,30 +593,3 @@ def PartialEvaluate():
The registered pass that performs partial evaluation on an expression.
"""
return _transform.PartialEvaluate()


def QuantizeAnnotate():
"""Given a float32 graph, this pass will rewrite the graph and return
a graph which simulates the error brought by the current quantization
scheme.
Returns
-------
ret: tvm.relay.Pass
The registered pass for quantization annotation.
"""
return _transform.QuantizeAnnotate()


def QuantizeRealize():
"""The realize pass will transform the simulated quantized graph, which
actually computes with float32, to a real low-bit integer graph. It will
replace the `simulated_quantize` with several fine-grained operators like
add, multiply, and shift as much as possible for better performance.
Returns
-------
ret: tvm.relay.Pass
The registered pass for quantization realization.
"""
return _transform.QuantizeRealize()
4 changes: 0 additions & 4 deletions src/relay/pass/pass_manager.cc
Expand Up @@ -64,10 +64,6 @@ Pass GetPass(const std::string& pass_name) {
return FoldScaleAxis();
} else if (pass_name == "PartialEvaluate") {
return SimplifyInference();
} else if (pass_name == "QuantizeAnnotate") {
return QuantizeAnnotate();
} else if (pass_name == "QuantizeRealize") {
return QuantizeRealize();
} else if (pass_name == "SimplifyInference") {
return SimplifyInference();
} else if (pass_name == "ToANormalForm") {
Expand Down
19 changes: 7 additions & 12 deletions src/relay/pass/quantize.cc
Expand Up @@ -43,6 +43,8 @@ namespace tvm {
namespace relay {
namespace quantize {

using namespace relay::transform;

/*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind;
Expand Down Expand Up @@ -588,12 +590,6 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body_typed(QConfig::ExitQConfigScope);

} // namespace quantize

namespace transform {

using namespace relay::quantize;

Pass QuantizeAnnotate() {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
Expand All @@ -615,10 +611,10 @@ Pass QuantizeAnnotate() {
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}

TVM_REGISTER_API("relay._transform.QuantizeAnnotate")
TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
.set_body_typed(QuantizeAnnotate);

Pass QuantizeRealize() {
Pass QuantizeRealizePass() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
Expand All @@ -627,10 +623,9 @@ Pass QuantizeRealize() {
return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
}

TVM_REGISTER_API("relay._transform.QuantizeRealize")
.set_body_typed(QuantizeRealize);

} // namespace transform
TVM_REGISTER_API("relay._quantize.QuantizeRealize")
.set_body_typed(QuantizeRealizePass);

} // namespace quantize
} // namespace relay
} // namespace tvm

0 comments on commit 921356f

Please sign in to comment.