This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit: Refactor fc_sum_fuse (#21077)
* Refactor fc_sum_fuse

* Fix sanity

* Simplify Selector

* Restore ConnectSubgraphOutputs in fc_sum_fuse_property

* Fix node name
bartekkuncer committed Jul 7, 2022
1 parent 26243ee commit e36c9f0
Showing 2 changed files with 38 additions and 64 deletions.
src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h → src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
@@ -26,8 +26,8 @@
 this output is scaled to the proper range.
 */

-#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
-#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
 #if MXNET_USE_ONEDNN == 1

 #include <memory>
@@ -55,27 +55,21 @@ inline bool EndsWith(std::string const& value, std::string const& ending) {
 class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
  private:
   bool quantized_;
-  SelectStatus status_ = kFail;
-  std::vector<const BiDirectedNode*> matched_list_;
+  bool patternFound = false;

  public:
   explicit SgDNNLFCSumFuseSelector(bool quantized) : quantized_(quantized) {}

   bool Select(const BiDirectedNode& seed_node,
               const std::shared_ptr<NodeAttr>& node_attr) override {
     const auto n = seed_node.node;
-    if (n->op() == Op::Get("_sg_onednn_fully_connected")) {
-      if (SupportDNNLAttr(node_attr) && (seed_node.outputs.size() == 1)) {
-        auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
-        if ((!quantized_) || (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.with_eltwise)) {
-          // Start subgraph when fusing for floats (quantized_ is false for ONEDNN backend) or
-          // when FC is already quantized (second pass for ONEDNN_QUANTIZE) but not already fuzed
-          // with elemwise operator.
-          status_ = kStart;
-          matched_list_.clear();
-          matched_list_.push_back(&seed_node);
-          return true;
-        }
+    if (n->op() == Op::Get("_sg_onednn_fully_connected") && seed_node.outputs.size() == 1) {
+      auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
+      if (!quantized_ || (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.with_eltwise)) {
+        // Start subgraph when fusing for floats (quantized_ is false for ONEDNN backend) or
+        // when FC is already quantized (second pass for ONEDNN_QUANTIZE) but not already fused
+        // with elemwise operator.
+        return true;
+      }
     }
     return false;
@@ -88,57 +82,37 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
   bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& output_node) override {
     const auto cur_n = cur_node.node;
     const auto output_n = output_node.node;
-    if (status_ == kFail || status_ == kSuccess || output_n->is_variable()) {
+    if (patternFound || output_n->is_variable()) {
       return false;
     }
-    // If n isn't the last matched node, then we encoutered an internal
-    // branch, we should pop out the node behind n and stop fusion.
-    if (matched_list_.back() != &cur_node) {
-      if (std::find(matched_list_.begin(), matched_list_.end(), &cur_node) != matched_list_.end()) {
-        while (matched_list_.back() != &cur_node) {
-          matched_list_.pop_back();
-        }
-      }
-      status_ = kSuccess;
-      return false;
-    }

+    // Find _contrib_quantized_elemwise_add or elemwise_add
+    if (EndsWith(output_n->op()->name, "elemwise_add")) {
+      if (quantized_) {
+        auto const& fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
+        if (!fc_param.dnnl_param.enable_float_output) {
+          // For quantized graph, when FC floating point output is not enabled elementwise add must
+          // also be quantized (min and max value have to be already stored in elementwise add).
+          CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
+        }
+      }
+      patternFound = true;
+      return true;
+    } else {
+      return false;
+    }
-    switch (status_) {
-      case kStart:
-        // Find _contrib_quantized_elemwise_add or elemwise_add
-        if (EndsWith(output_n->op()->name, "elemwise_add")) {
-          if (quantized_) {
-            auto const& fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
-            if (!fc_param.dnnl_param.enable_float_output) {
-              // For quantized graph, when FC floating point output is not enabled
-              // elementwise add must also be quantized (min and max value have to be already stored
-              // in elementwise add).
-              CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
-            }
-          }
-          matched_list_.push_back(&output_node);
-          status_ = kSuccess;
-          return true;
-        }
-      default:
-        status_ = kFail;
-        return false;
-    }
   }

   std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
-    if (status_ == kSuccess) {
+    if (patternFound) {
       return candidates;
     } else {
       return std::vector<BiDirectedNode*>(0);
     }
   }

   void Reset() override {
-    CHECK_GE(matched_list_.size(), 1);
-    auto new_selector = SgDNNLFCSumFuseSelector(quantized_);
-    new_selector.Select(*matched_list_[0], nullptr);
-    *this = new_selector;
+    patternFound = false;
   }
 };
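
Note: the EndsWith helper named in the hunk header above is collapsed in this view; only its signature is visible. A plausible reconstruction of such a suffix check, for readers following the selector logic (an assumption, not lines from the diff):

#include <algorithm>
#include <string>

// Assumed body of the collapsed helper: true iff `value` ends with `ending`.
inline bool EndsWith(std::string const& value, std::string const& ending) {
  if (ending.size() > value.size()) return false;
  return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
}

With this check, both operator names the comment mentions are accepted: EndsWith("elemwise_add", "elemwise_add") and EndsWith("_contrib_quantized_elemwise_add", "elemwise_add") are both true.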

@@ -147,11 +121,11 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
   SgDNNLFCSumFuseProperty() {}

   static SubgraphPropertyPtr Create() {
-    static const std::string& name = "DNNL fuse FullyConnected with sum";
+    static const std::string& name = "oneDNN fuse FullyConnected with sum";
     auto property = std::make_shared<SgDNNLFCSumFuseProperty>();
     property->SetAttr<std::string>("property_name", name);
     property->SetAttr<bool>("inference_only", true);
-    if (dmlc::GetEnv("MXNET_DISABLE_DNNL_FC_SUM", 0)) {
+    if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_FC_SUM", 0)) {
       property->SetAttr<bool>("disable", true);
     }
     return property;
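
Since Create() reads the renamed switch through dmlc::GetEnv when the pass is constructed, the variable must be set in the environment before graph optimization starts. A minimal standalone probe of the same variable, mimicking the check without pulling in dmlc (illustrative only):

#include <cstdlib>
#include <iostream>

int main() {
  // dmlc::GetEnv("MXNET_DISABLE_ONEDNN_FC_SUM", 0) parses the value as an int;
  // any non-zero setting marks the property as disabled.
  const char* flag = std::getenv("MXNET_DISABLE_ONEDNN_FC_SUM");
  bool disabled = flag != nullptr && std::atoi(flag) != 0;
  std::cout << "FC+sum fusion disabled: " << std::boolalpha << disabled << '\n';
  return 0;
}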
@@ -207,33 +181,33 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
     return selector;
   }

-  void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+  void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node,
                               std::vector<nnvm::NodeEntry*>* output_entries) const override {
     // Connect all extern output entries to output[0]
     for (size_t i = 0; i < output_entries->size(); ++i) {
       auto entry_ptr = output_entries->at(i);
-      *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+      *entry_ptr = nnvm::NodeEntry{subgraph_node, entry_ptr->index, 0};
     }
   }

-  void ConnectSubgraphInputs(const nnvm::ObjectPtr n,
+  void ConnectSubgraphInputs(const nnvm::ObjectPtr subgraph_node,
                              std::vector<nnvm::NodeEntry*>* input_entries,
                              std::vector<nnvm::NodeEntry>* orig_input_entries) const override {
-    auto sym = n->attrs.subgraphs[0];
-    auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
-    std::unordered_set<const nnvm::Node*> node_sets;
+    auto sym = subgraph_node->attrs.subgraphs[0];
+    auto const& fc_param = nnvm::get<DNNLFCFullParam>(subgraph_node->attrs.parsed);
+    std::unordered_set<const nnvm::Node*> node_set;
     DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr& node) {
       if (node->is_variable()) {
         return;
       }
-      node_sets.insert(node.get());
+      node_set.insert(node.get());
       if (EndsWith(node->op()->name, "elemwise_add")) {
         const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4;
         // Make sure fc output is the left operand of the add operator, if not:
         // - swap inputs of add operator
         // - switch add operands sequence to ensure that
         //   the tensor (sum_tensor) to which FC output is added is the last input.
-        if (node_sets.count(node->inputs[1].node.get())) {
+        if (node_set.count(node->inputs[1].node.get())) {
           // Example of input_entries reordering for channel-wise quantized graph:
           // sum_tensor.data --> fc.data
           // fc.data --> fc.weight0
@@ -272,12 +246,12 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
         }
       }
     });
-    n->inputs = *orig_input_entries;
+    subgraph_node->inputs = *orig_input_entries;
   }
 };

 }  // namespace op
 }  // namespace mxnet

 #endif  // if MXNET_USE_ONEDNN == 1
-#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
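
Net effect of the header changes: the kStart/kSuccess/kFail state machine and the matched-node list are gone, because the fused pattern is exactly two nodes (an FC followed by a single elemwise_add consumer), so one boolean captures the selector's state. A self-contained toy that mirrors the simplified control flow (ToySelector and Node are illustrative stand-ins, not MXNet types):

#include <string>
#include <vector>

struct Node { std::string op; };  // stand-in for nnvm/BiDirectedNode

class ToySelector {
  bool pattern_found_ = false;

 public:
  // Seed only on the fusible producer.
  bool Select(const Node& seed) { return seed.op == "FullyConnected"; }

  // Accept exactly one consumer, which completes the two-node pattern.
  bool SelectOutput(const Node& /*cur*/, const Node& out) {
    if (pattern_found_) return false;
    if (out.op == "elemwise_add") {
      pattern_found_ = true;
      return true;
    }
    return false;
  }

  // Keep the candidate set only when the full pattern was matched.
  std::vector<Node*> Filter(const std::vector<Node*>& candidates) {
    return pattern_found_ ? candidates : std::vector<Node*>();
  }

  void Reset() { pattern_found_ = false; }  // ready for the next seed node
};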
src/operator/subgraph/dnnl/dnnl_subgraph_property.cc (1 addition & 1 deletion)
@@ -30,7 +30,7 @@
#include "dnnl_pow_mul_scalar_property.h"
#include "dnnl_transformer_qk_property.h"
#include "dnnl_transformer_valatt_property.h"
#include "dnnl_fc_sum_fuse.h"
#include "dnnl_fc_sum_fuse_property.h"
#include "dnnl_remove_casts_property.h"

namespace mxnet {
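The rest of dnnl_subgraph_property.cc is collapsed in this view. For orientation, MXNet attaches subgraph properties to backend passes with the MXNET_REGISTER_SUBGRAPH_PROPERTY macro; the lines below sketch the typical pattern and are not taken from this diff:

// Assumed registration pattern (collapsed above): the property is attached
// to both the float pass and the quantization pass.
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCSumFuseProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCSumFuseProperty);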
