From 8db8689297c78f5db11c59a62195ccc8dbd79957 Mon Sep 17 00:00:00 2001 From: Chen Bangduo Date: Mon, 11 Apr 2022 16:48:32 +0800 Subject: [PATCH] [Graph] Fix inputs from different frame bug for DIEN model when enable auto op fusion. (#153) --- .../core/graph/optimizer_fusion_engine.cc | 3 +- .../graph/optimizer_fusion_engine_impl.cc | 51 +++++++++++++++++++ .../core/graph/optimizer_fusion_engine_impl.h | 5 +- 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/graph/optimizer_fusion_engine.cc b/tensorflow/core/graph/optimizer_fusion_engine.cc index 8b6603daf4d..d1bcdbfb8e8 100644 --- a/tensorflow/core/graph/optimizer_fusion_engine.cc +++ b/tensorflow/core/graph/optimizer_fusion_engine.cc @@ -26,10 +26,10 @@ limitations under the License. #include "tensorflow/core/graph/template_select_pruning_else_const.h" #include "tensorflow/core/graph/template_select_pruning_then_const.h" #include "tensorflow/core/graph/template_sparse_inner_flatten.h" - namespace tensorflow { bool OptimizeFusion(Graph* g) { + bool changed = false; std::vector> templates; templates.emplace_back(new TemplateSparseInnerFlatten()); @@ -46,6 +46,7 @@ bool OptimizeFusion(Graph* g) { new OptimizerFusionImpl(g, t.get())); changed |= opt->Optimize(); } + return changed; } diff --git a/tensorflow/core/graph/optimizer_fusion_engine_impl.cc b/tensorflow/core/graph/optimizer_fusion_engine_impl.cc index afc5ffcf755..47b976209eb 100644 --- a/tensorflow/core/graph/optimizer_fusion_engine_impl.cc +++ b/tensorflow/core/graph/optimizer_fusion_engine_impl.cc @@ -1,5 +1,6 @@ #include #include +#include #include "tensorflow/core/graph/optimizer_fusion_engine_impl.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/graph/graph.h" @@ -25,6 +26,34 @@ OptimizerFusionImpl::OptimizerFusionImpl(Graph* g, TemplateBase* t) fused_op_outputs_.resize(t_->num_outputs_); use_dynamic_output_keys_ = false; use_dynamic_input_keys_ = false; + + std::unordered_set enter_nodes; + for (Node *node : g->nodes()) { + node_frame_map_[node] = ""; + if (node->IsEnter()) { + enter_nodes.insert(node); + } + } + + std::unordered_set has_visited; + for (Node *node : enter_nodes) { + const std::string frame_name = node->def().attr().at("frame_name").s(); + std::queue q; + q.push(node); + while (!q.empty()) { + Node *n = q.front(); + q.pop(); + has_visited.insert(n); + node_frame_map_[n] = frame_name; + for (auto e : n->out_edges()) { + Node *dst = e->dst(); + if (has_visited.find(dst) == has_visited.end() && + (!dst->IsExit() || !dst->IsNextIteration())) { + q.push(dst); + } + } + } + } } bool OptimizerFusionImpl::VisitMatchedNodes() { @@ -265,6 +294,21 @@ bool OptimizerFusionImpl::CheckInputs(const Node* node, return true; } +bool OptimizerFusionImpl::CheckMatchedNodeInSameFrame() { + // TODO: only op in default frame can be fused + const Node *first_key_node = matched_node_map_[t_->first_key_].node; + std::string frame_name = node_frame_map_[first_key_node]; + if (frame_name != "") + return false; + for (auto matched_node_it : matched_node_map_) { + const Node * node = std::get<1>(matched_node_it).node; + if (node_frame_map_[node] != frame_name) + return false; + } + + return true; +} + bool OptimizerFusionImpl::Optimize() { bool changed = false; // TODO(minmin) check Template consistency before really optimizing @@ -329,6 +373,12 @@ bool OptimizerFusionImpl::Optimize() { << temp_node_map_.size(); continue; } + + // double check the matched nodes are in same frame + if (!CheckMatchedNodeInSameFrame()) { + VLOG(2) << "Failed double check the matched nodes, they are not in same frame"; + continue; + } // double check the matched inputs bool passed = true; for (int i = 0; i < t_->num_inputs_; ++i) { @@ -363,6 +413,7 @@ bool OptimizerFusionImpl::Optimize() { VLOG(2) << "Failed double check the matched outputs"; continue; } + ++num_matched_; VLOG(2) << "Matched: " << num_matched_; for (auto iter = matched_node_map_.begin(); diff --git a/tensorflow/core/graph/optimizer_fusion_engine_impl.h b/tensorflow/core/graph/optimizer_fusion_engine_impl.h index faee92deb07..a091eb91b63 100644 --- a/tensorflow/core/graph/optimizer_fusion_engine_impl.h +++ b/tensorflow/core/graph/optimizer_fusion_engine_impl.h @@ -20,7 +20,9 @@ class OptimizerFusionImpl { const TempNode* temp_node); bool CheckInputs(const Node* node, const TempNode* temp_node); - private: + bool CheckMatchedNodeInSameFrame(); + +private: Graph* g_; TemplateBase* t_; std::map temp_node_map_; @@ -36,6 +38,7 @@ class OptimizerFusionImpl { int dynamic_input_port_cur_; std::vector> fused_op_outputs_dynamic_; std::vector fused_op_input_dynamic_; + std::map node_frame_map_; }; }