Skip to content

Commit

Permalink
Add feed&fetch as default deny ops. (#44708)
Browse files Browse the repository at this point in the history
  • Loading branch information
wzzju committed Aug 5, 2022
1 parent d0cf9d9 commit d4ca7ff
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Expand Up @@ -62,6 +62,8 @@ const std::unordered_map<std::string, std::unordered_set<std::string>>
kDenyParamMap = {{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}};

const std::unordered_set<std::string> kDefaultDenyOps = {"feed", "fetch"};

std::unordered_set<std::string> GetDenyVarNames(const GraphNodeSet& cluster) {
std::unordered_set<std::string> deny_var_set;

Expand Down Expand Up @@ -560,22 +562,24 @@ void SearchAllSubgraphs(Graph* graph) {
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
auto teller = [&allow_ops, &deny_ops](const Node* node) {
const auto& node_name = node->Name();
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node->Name()) != nullptr;
node_name) != nullptr;
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
if (allow_ops.size()) {
return registered && allow_ops.count(node->Name());
if (!allow_ops.empty()) {
return registered && allow_ops.count(node_name);
}
// if the op type is registered in CINN and deny_ops is not empty, return
// true only when it is not in deny_ops
if (deny_ops.size()) {
return registered && !deny_ops.count(node->Name());
if (!deny_ops.empty()) {
return registered && !deny_ops.count(node_name);
}

// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
// return true only when it is registered in CINN
return registered && (node->IsOp() && !IsInplaceOp(*node->Op()));
return registered && !kDefaultDenyOps.count(node_name) &&
(node->IsOp() && !IsInplaceOp(*node->Op()));
};
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
Expand Down

0 comments on commit d4ca7ff

Please sign in to comment.