Skip to content

Commit

Permalink
[Graph] Fix inputs from different frame bug for DIEN model when enabl…
Browse files Browse the repository at this point in the history
…e auto op fusion. (#153)
  • Loading branch information
JackMoriarty committed Apr 11, 2022
1 parent c2d88f5 commit 8db8689
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tensorflow/core/graph/optimizer_fusion_engine.cc
Expand Up @@ -26,10 +26,10 @@ limitations under the License.
#include "tensorflow/core/graph/template_select_pruning_else_const.h"
#include "tensorflow/core/graph/template_select_pruning_then_const.h"
#include "tensorflow/core/graph/template_sparse_inner_flatten.h"

namespace tensorflow {

bool OptimizeFusion(Graph* g) {

bool changed = false;
std::vector<std::unique_ptr<TemplateBase>> templates;
templates.emplace_back(new TemplateSparseInnerFlatten());
Expand All @@ -46,6 +46,7 @@ bool OptimizeFusion(Graph* g) {
new OptimizerFusionImpl(g, t.get()));
changed |= opt->Optimize();
}

return changed;
}

Expand Down
51 changes: 51 additions & 0 deletions tensorflow/core/graph/optimizer_fusion_engine_impl.cc
@@ -1,5 +1,6 @@
#include <algorithm>
#include <tuple>
#include <queue>
#include "tensorflow/core/graph/optimizer_fusion_engine_impl.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/graph/graph.h"
Expand All @@ -25,6 +26,34 @@ OptimizerFusionImpl::OptimizerFusionImpl(Graph* g, TemplateBase* t)
fused_op_outputs_.resize(t_->num_outputs_);
use_dynamic_output_keys_ = false;
use_dynamic_input_keys_ = false;

std::unordered_set<Node *> enter_nodes;
for (Node *node : g->nodes()) {
node_frame_map_[node] = "";
if (node->IsEnter()) {
enter_nodes.insert(node);
}
}

std::unordered_set<Node *> has_visited;
for (Node *node : enter_nodes) {
const std::string frame_name = node->def().attr().at("frame_name").s();
std::queue<Node *> q;
q.push(node);
while (!q.empty()) {
Node *n = q.front();
q.pop();
has_visited.insert(n);
node_frame_map_[n] = frame_name;
for (auto e : n->out_edges()) {
Node *dst = e->dst();
if (has_visited.find(dst) == has_visited.end() &&
(!dst->IsExit() || !dst->IsNextIteration())) {
q.push(dst);
}
}
}
}
}

bool OptimizerFusionImpl::VisitMatchedNodes() {
Expand Down Expand Up @@ -265,6 +294,21 @@ bool OptimizerFusionImpl::CheckInputs(const Node* node,
return true;
}

bool OptimizerFusionImpl::CheckMatchedNodeInSameFrame() {
// TODO: only op in default frame can be fused
const Node *first_key_node = matched_node_map_[t_->first_key_].node;
std::string frame_name = node_frame_map_[first_key_node];
if (frame_name != "")
return false;
for (auto matched_node_it : matched_node_map_) {
const Node * node = std::get<1>(matched_node_it).node;
if (node_frame_map_[node] != frame_name)
return false;
}

return true;
}

bool OptimizerFusionImpl::Optimize() {
bool changed = false;
// TODO(minmin) check Template consistency before really optimizing
Expand Down Expand Up @@ -329,6 +373,12 @@ bool OptimizerFusionImpl::Optimize() {
<< temp_node_map_.size();
continue;
}

// double check the matched nodes are in same frame
if (!CheckMatchedNodeInSameFrame()) {
VLOG(2) << "Failed double check the matched nodes, they are not in same frame";
continue;
}
// double check the matched inputs
bool passed = true;
for (int i = 0; i < t_->num_inputs_; ++i) {
Expand Down Expand Up @@ -363,6 +413,7 @@ bool OptimizerFusionImpl::Optimize() {
VLOG(2) << "Failed double check the matched outputs";
continue;
}

++num_matched_;
VLOG(2) << "Matched: " << num_matched_;
for (auto iter = matched_node_map_.begin();
Expand Down
5 changes: 4 additions & 1 deletion tensorflow/core/graph/optimizer_fusion_engine_impl.h
Expand Up @@ -20,7 +20,9 @@ class OptimizerFusionImpl {
const TempNode* temp_node);
bool CheckInputs(const Node* node,
const TempNode* temp_node);
private:
bool CheckMatchedNodeInSameFrame();

private:
Graph* g_;
TemplateBase* t_;
std::map<const std::string, TempNode> temp_node_map_;
Expand All @@ -36,6 +38,7 @@ class OptimizerFusionImpl {
int dynamic_input_port_cur_;
std::vector<std::vector<const Edge*>> fused_op_outputs_dynamic_;
std::vector<const Edge*> fused_op_input_dynamic_;
std::map<const Node *, std::string> node_frame_map_;
};

}
Expand Down

0 comments on commit 8db8689

Please sign in to comment.