From e36c9f075aa7ee14e1a8ccf245df5d1e1648515b Mon Sep 17 00:00:00 2001 From: bartekkuncer Date: Thu, 7 Jul 2022 16:35:29 +0200 Subject: [PATCH] Refactor fc_sum_fuse (#21077) * Refactor fc_sum_fuse * Fix sanity * Simplify Selector * Restore ConnectSubgraphOutputs in fc_sum_fuse_property * Fix node name --- ...sum_fuse.h => dnnl_fc_sum_fuse_property.h} | 100 +++++++----------- .../subgraph/dnnl/dnnl_subgraph_property.cc | 2 +- 2 files changed, 38 insertions(+), 64 deletions(-) rename src/operator/subgraph/dnnl/{dnnl_fc_sum_fuse.h => dnnl_fc_sum_fuse_property.h} (71%) diff --git a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h similarity index 71% rename from src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h rename to src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h index c65711493c8b..2c19b7b68d7e 100644 --- a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h +++ b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h @@ -26,8 +26,8 @@ this output is scaled to the proper range. */ -#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_ -#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_ +#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_ #if MXNET_USE_ONEDNN == 1 #include @@ -55,8 +55,7 @@ inline bool EndsWith(std::string const& value, std::string const& ending) { class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 { private: bool quantized_; - SelectStatus status_ = kFail; - std::vector matched_list_; + bool patternFound = false; public: explicit SgDNNLFCSumFuseSelector(bool quantized) : quantized_(quantized) {} @@ -64,18 +63,13 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 { bool Select(const BiDirectedNode& seed_node, const std::shared_ptr& node_attr) override { const auto n = seed_node.node; - if (n->op() == Op::Get("_sg_onednn_fully_connected")) { - if (SupportDNNLAttr(node_attr) && (seed_node.outputs.size() == 1)) { - auto const& fc_param = nnvm::get(n->attrs.parsed); - if ((!quantized_) || (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.with_eltwise)) { - // Start subgraph when fusing for floats (quantized_ is false for ONEDNN backend) or - // when FC is already quantized (second pass for ONEDNN_QUANTIZE) but not already fuzed - // with elemwise operator. - status_ = kStart; - matched_list_.clear(); - matched_list_.push_back(&seed_node); - return true; - } + if (n->op() == Op::Get("_sg_onednn_fully_connected") && seed_node.outputs.size() == 1) { + auto const& fc_param = nnvm::get(n->attrs.parsed); + if (!quantized_ || (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.with_eltwise)) { + // Start subgraph when fusing for floats (quantized_ is false for ONEDNN backend) or + // when FC is already quantized (second pass for ONEDNN_QUANTIZE) but not already fused + // with elemwise operator. + return true; } } return false; @@ -88,46 +82,29 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 { bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& output_node) override { const auto cur_n = cur_node.node; const auto output_n = output_node.node; - if (status_ == kFail || status_ == kSuccess || output_n->is_variable()) { + if (patternFound || output_n->is_variable()) { return false; } - // If n isn't the last matched node, then we encoutered an internal - // branch, we should pop out the node behind n and stop fusion. - if (matched_list_.back() != &cur_node) { - if (std::find(matched_list_.begin(), matched_list_.end(), &cur_node) != matched_list_.end()) { - while (matched_list_.back() != &cur_node) { - matched_list_.pop_back(); + + // Find _contrib_quantized_elemwise_add or elemwise_add + if (EndsWith(output_n->op()->name, "elemwise_add")) { + if (quantized_) { + auto const& fc_param = nnvm::get(cur_n->attrs.parsed); + if (!fc_param.dnnl_param.enable_float_output) { + // For quantized graph, when FC floating point output is not enabled elementwise add must + // also be quantized (min and max value have to be already stored in elementwise add). + CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1); } } - status_ = kSuccess; + patternFound = true; + return true; + } else { return false; } - - switch (status_) { - case kStart: - // Find _contrib_quantized_elemwise_add or elemwise_add - if (EndsWith(output_n->op()->name, "elemwise_add")) { - if (quantized_) { - auto const& fc_param = nnvm::get(cur_n->attrs.parsed); - if (!fc_param.dnnl_param.enable_float_output) { - // For quantized graph, when FC floating point output is not enabled - // elementwise add must also be quantized (min and max value have to be already stored - // in elementwise add). - CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1); - } - } - matched_list_.push_back(&output_node); - status_ = kSuccess; - return true; - } - default: - status_ = kFail; - return false; - } } std::vector Filter(const std::vector& candidates) override { - if (status_ == kSuccess) { + if (patternFound) { return candidates; } else { return std::vector(0); @@ -135,10 +112,7 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 { } void Reset() override { - CHECK_GE(matched_list_.size(), 1); - auto new_selector = SgDNNLFCSumFuseSelector(quantized_); - new_selector.Select(*matched_list_[0], nullptr); - *this = new_selector; + patternFound = false; } }; @@ -147,11 +121,11 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty { SgDNNLFCSumFuseProperty() {} static SubgraphPropertyPtr Create() { - static const std::string& name = "DNNL fuse FullyConnected with sum"; + static const std::string& name = "oneDNN fuse FullyConnected with sum"; auto property = std::make_shared(); property->SetAttr("property_name", name); property->SetAttr("inference_only", true); - if (dmlc::GetEnv("MXNET_DISABLE_DNNL_FC_SUM", 0)) { + if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_FC_SUM", 0)) { property->SetAttr("disable", true); } return property; @@ -207,33 +181,33 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty { return selector; } - void ConnectSubgraphOutputs(const nnvm::ObjectPtr n, + void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node, std::vector* output_entries) const override { // Connect all extern output entries to output[0] for (size_t i = 0; i < output_entries->size(); ++i) { auto entry_ptr = output_entries->at(i); - *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0}; + *entry_ptr = nnvm::NodeEntry{subgraph_node, entry_ptr->index, 0}; } } - void ConnectSubgraphInputs(const nnvm::ObjectPtr n, + void ConnectSubgraphInputs(const nnvm::ObjectPtr subgraph_node, std::vector* input_entries, std::vector* orig_input_entries) const override { - auto sym = n->attrs.subgraphs[0]; - auto const& fc_param = nnvm::get(n->attrs.parsed); - std::unordered_set node_sets; + auto sym = subgraph_node->attrs.subgraphs[0]; + auto const& fc_param = nnvm::get(subgraph_node->attrs.parsed); + std::unordered_set node_set; DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr& node) { if (node->is_variable()) { return; } - node_sets.insert(node.get()); + node_set.insert(node.get()); if (EndsWith(node->op()->name, "elemwise_add")) { const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4; // Make sure fc output is the left operand of the add operator, if not: // - swap inputs of add operator // - switch add operands sequence to ensure that // the tensor (sum_tensor) to which FC output is added is the last input. - if (node_sets.count(node->inputs[1].node.get())) { + if (node_set.count(node->inputs[1].node.get())) { // Example of input_entries reordering for channel-wise quantized graph: // sum_tensor.data --> fc.data // fc.data --> fc.weight0 @@ -272,7 +246,7 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty { } } }); - n->inputs = *orig_input_entries; + subgraph_node->inputs = *orig_input_entries; } }; @@ -280,4 +254,4 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty { } // namespace mxnet #endif // if MXNET_USE_ONEDNN == 1 -#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_ +#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_ diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc index 69fae1c97d36..86e08020eecb 100644 --- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc +++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc @@ -30,7 +30,7 @@ #include "dnnl_pow_mul_scalar_property.h" #include "dnnl_transformer_qk_property.h" #include "dnnl_transformer_valatt_property.h" -#include "dnnl_fc_sum_fuse.h" +#include "dnnl_fc_sum_fuse_property.h" #include "dnnl_remove_casts_property.h" namespace mxnet {