diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 4d47663669dcc..aa1092fa6bb91 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include "quantize.h" #include "pattern_util.h" @@ -23,24 +24,6 @@ namespace quantize { using runtime::TypedPackedFunc; -// SimulatedQuantize -struct SimulatedQuantizeAttrs : public tvm::AttrsNode { - bool sign; - std::string rounding; - int id; - int field_type; - - TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { - TVM_ATTR_FIELD(sign).set_default(true); - TVM_ATTR_FIELD(rounding).set_default("round") - .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); - TVM_ATTR_FIELD(id) - .describe("id"); - TVM_ATTR_FIELD(field_type) - .describe("field_type"); - } -}; - TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); bool SimulatedQuantizeRel(const Array& types, @@ -106,22 +89,7 @@ TVM_REGISTER_API("relay.op._make.simulated_quantize") // |->mul2 -// qtz_field -enum QField : int { - kFloat = 0, - kQInput = 1, - kQWeight = 2, - kQActivation = 3, -}; - -using FQFieldSpec = TypedPackedFunc(Attrs, Array)>; - -inline QField Int2Field(Integer x) { - return static_cast(x.operator int64_t()); -} - - -Expr MakeSimulatedQuantize(Expr x, QField field) { +Expr MakeSimulatedQuantize(Expr x, std::string field) { static const Op& op = Op::Get("simulated_quantize"); static int cnt = 0; std::string name_postfix = std::to_string(cnt++); @@ -134,117 +102,112 @@ Expr MakeSimulatedQuantize(Expr x, QField field) { auto attrs = make_node(); attrs->sign = true; attrs->rounding = "round"; - attrs->id = cnt; attrs->field_type = field; return CallNode::make(op, {x, dom_scale, bit, clip_min, clip_max}, Attrs(attrs), {}); } -class Annotator : public ExprMutator { - public: - Expr Annotate(Expr e) { - this->cnt_map_ = GetExprRefCount(e); - return this->Mutate(e); - } - - Expr VisitExpr_(const CallNode* n) final { - static const auto& fqfield_spec = - Op::GetAttr("FQFieldSpec"); - size_t ref_cnt = cnt_map_.at(n); - - Expr new_e = ExprMutator::VisitExpr_(n); - const auto* call = new_e.as(); - CHECK(call); - - size_t num_inputs = call->args.size(); - // prepare input fields - Array ifields; - for (size_t i = 0; i < num_inputs; ++i) { - ifields.push_back(field_map_.at(call->args[i].get())); - } - - auto f = GetFunc(fqfield_spec, call->op); - if (f != nullptr) { - // get fields spec - Array fields = f(call->attrs, ifields); - // insert simulated quantize - Array new_args; - for (size_t i = 0; i < num_inputs; ++i) { - if (Int2Field(ifields[i]) != Int2Field(fields[i])) { - new_args.push_back(MakeSimulatedQuantize(call->args[i], Int2Field(fields[i]))); - } else { - new_args.push_back(call->args[i]); - } - } - // mark output's field - Call ret = CallNode::make(call->op, new_args, call->attrs, call->type_args); - field_map_[ret.get()] = Int2Field(fields[num_inputs]); - return ret; +Array PrepareInputs(const Array& args) { + Array inputs; + for (Expr arg : args) { + if (const auto* n = arg.as()) { + inputs.push_back(QFieldExpr(arg.node_)); } else { - // default behavior for nodes like add, relu - // it will broadcast the previous node's field - QField field = SelectField(ifields); - // change to float field for multiple ref - field_map_[new_e.get()] = ref_cnt > 1 ? kFloat : field; - return new_e; + auto node = make_node(); + node->expr = arg; + node->field = "float"; + inputs.push_back(QFieldExpr(node)); } } + return inputs; +} - Expr VisitExpr_(const VarNode* n) final { - Expr new_e = ExprMutator::VisitExpr_(n); - field_map_[new_e.get()] = kFloat; - return new_e; - } - private: - std::unordered_map cnt_map_; - std::unordered_map field_map_; - - QField SelectField(Array ifields) { - for (auto field : ifields) { - // just return the first non-float field - if (Int2Field(field) != kFloat) { - return Int2Field(field); - } - } - return kFloat; +Expr Conv2DQFieldRewrite(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + auto rnode = make_node(); + Array inputs = PrepareInputs(new_args); + QFieldExpr lhs = inputs[0]; + QFieldExpr rhs = inputs[1]; + + Expr lhs_expr = MakeSimulatedQuantize(lhs->expr, "input"); + Expr rhs_expr = MakeSimulatedQuantize(rhs->expr, "weight"); + rnode->expr = CallNode::make(ref_call->op, + {lhs_expr, rhs_expr}, + ref_call->attrs, ref_call->type_args); + rnode->field = "activation"; + return Expr(rnode); +} + +RELAY_REGISTER_OP("nn.conv2d") +.set_attr("FQFieldRewrite", Conv2DQFieldRewrite); + + +Expr MulQFieldRewrite(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + auto rnode = make_node(); + Array inputs = PrepareInputs(new_args); + QFieldExpr lhs = inputs[0]; + QFieldExpr rhs = inputs[1]; + + if (lhs->field == "float" && rhs->field == "float") { + // execute the op on float domain + rnode->expr = CallNode::make(ref_call->op, + {lhs->expr, rhs->expr}, + ref_call->attrs, ref_call->type_args); + rnode->field = "float"; + } else if (lhs->field == "activation" && rhs->field == "float"){ + // quantize rhs first + Expr rhs_expr = MakeSimulatedQuantize(rhs->expr, "weight"); + rnode->expr = CallNode::make(ref_call->op, + {lhs->expr, rhs_expr}, + ref_call->attrs, ref_call->type_args); + rnode->field = "activation"; + } else { + LOG(FATAL) << "do not handle yet."; } -}; -Expr Annotate(const Expr& e) { - return Annotator().Annotate(e); + return Expr(rnode); } -TVM_REGISTER_API("relay._quantize.annotate") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Annotate(args[0]); - }); -// register attribute for annotator -Array Conv2dQFieldSpec(Attrs attrs, Array ifields) { - return {kQInput, kQWeight, kQActivation}; +RELAY_REGISTER_OP("multiply") +.set_attr("FQFieldRewrite", MulQFieldRewrite); + + +// share rewrite function for now +RELAY_REGISTER_OP("add") +.set_attr("FQFieldRewrite", MulQFieldRewrite); + + +Expr ReluQFieldRewrite(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + auto rnode = make_node(); + Array inputs = PrepareInputs(new_args); + QFieldExpr input = inputs[0]; + + rnode->expr = CallNode::make(ref_call->op, {input->expr}, + ref_call->attrs, ref_call->type_args); + rnode->field = input->field; + return Expr(rnode); } -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FQFieldSpec", Conv2dQFieldSpec); +RELAY_REGISTER_OP("nn.relu") +.set_attr("FQFieldRewrite", ReluQFieldRewrite); -Array MulQFieldSpec(Attrs attrs, Array ifields) { - CHECK(ifields.size() == 2); - QField lhs = Int2Field(ifields[0]); - QField rhs = Int2Field(ifields[1]); - if (lhs == kFloat && rhs == kFloat) { - return {kFloat, kFloat, kFloat}; - } else if (lhs == kQActivation || rhs == kQActivation) { - return {kQActivation, kQWeight, kQActivation}; - } else { - LOG(FATAL) << "wrong fields for mul"; - return {}; - } +Expr Annotate(Expr expr) { + return ForwardRewrite( + expr, "FQFieldRewrite"); } -RELAY_REGISTER_OP("multiply") -.set_attr("FQFieldSpec", MulQFieldSpec); +TVM_REGISTER_API("relay._quantize.annotate") +.set_body_typed(Annotate); + + // ============= @@ -319,7 +282,7 @@ QIntState RealizeQuantize(const Attrs attrs, if (static_cast(magnitude) == magnitude) { // int32->int8, idom_scale < odom_scale, right_shift data = RightShift(data, MakeConstantScalar(Int(32), static_cast(magnitude))); - // TODO do we need to clip + // TODO do we need to clip? DataType cast_dtype = Int(bit_imm); Expr cast_data = Cast(data, cast_dtype); return QIntStateNode::make(cast_data, odom_scale, bit); diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h index 84c6a454e2f4b..6dbdf47cd6cc6 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize.h @@ -15,6 +15,41 @@ namespace tvm { namespace relay { namespace quantize { +// SimulatedQuantize +struct SimulatedQuantizeAttrs : public tvm::AttrsNode { + bool sign; + std::string rounding; + std::string field_type; + + TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { + TVM_ATTR_FIELD(sign).set_default(true); + TVM_ATTR_FIELD(rounding).set_default("round") + .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); + TVM_ATTR_FIELD(field_type) + .describe("field_type"); + } +}; + +Expr MakeSimulatedQuantize(Expr x, std::string field); + + +class QFieldExprNode : public TempExprNode { + public: + Expr expr; + std::string field; + Expr Realize() const { + // dequantize + Expr ret = MakeSimulatedQuantize(expr, "float"); + return ret; + } + + static constexpr const char* _type_key = "relay.QFieldExpr"; + TVM_DECLARE_NODE_TYPE_INFO(QFieldExprNode, TempExprNode); +}; + +RELAY_DEFINE_NODE_REF(QFieldExpr, QFieldExprNode, TempExpr); + + class QState; class QIntState; class QRealState;