From e9ac480552f6069c9753a1823422b107e14569f4 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Sat, 2 Jun 2018 01:03:00 +0000 Subject: [PATCH 01/12] add support for subgraphs. --- nnvm/include/nnvm/node.h | 9 ++ nnvm/include/nnvm/op_attr_types.h | 12 +++ nnvm/src/core/symbolic.cc | 81 +++++++++++---- nnvm/src/pass/saveload_json.cc | 165 ++++++++++++++++++++---------- 4 files changed, 196 insertions(+), 71 deletions(-) diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 15db77ee6043..1629a70cfc3d 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -18,6 +18,7 @@ namespace nnvm { // Forward declare node. class Node; +class Symbol; /*! * \brief we always used NodePtr for a reference pointer @@ -90,6 +91,14 @@ struct NodeAttrs { * The object can be used to quickly access attributes. */ any parsed; + /*! + * \brief Some operators take graphs as input. These operators include + * control flow operators and high-order functions. + * These graphs don't change when the operators are invoked for different + * mini-batches. In this sense, the subgraphs are kind of similar to + * the parameters and show be kept as node attributes. + */ + std::vector > subgraphs; }; /*! diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index e58e9ceb3581..b7f6be408a16 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -202,6 +202,18 @@ using FCorrectLayout = std::function *last_ilayouts, std::vector *olayouts)>; +/*! + * \brief Get a list of inputs that represent graphs instead of data. + * Normally, input symbols are considered as data to the operator. However, + * control flow operators and high-order functions need to interpret symbols + * as graphs. + * \param attrs The attributes of this node. + * \return a list of input index that are interpreted as symbols by the operator. + * + * \note Register under "FInputGraph". + */ +using FInputGraph = std::function(const NodeAttrs& attrs)>; + } // namespace nnvm #endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index 2a2f5be50bc0..927dd2b70e44 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -267,14 +267,36 @@ void Symbol::Compose(const array_view& args, const std::string& name) { static auto& flist_inputs = Op::GetAttr("FListInputNames"); static auto& fset_attrs = Op::GetAttr("FSetInputVarAttrOnCompose"); + static auto& fgraph = Op::GetAttr("FInputGraph"); + + // The arguments that contain graphs. + Node* n = outputs[0].node.get(); + FInputGraph fng = fgraph.get(n->op(), nullptr); + std::vector garg_idx; + if (fng != nullptr) + garg_idx = fng(n->attrs); + + // The names of the arguments that contain graphs. + FListInputNames name_fn = flist_inputs.get(n->op(), nullptr); + auto arg_names = (name_fn == nullptr) ? std::vector{"data"} : name_fn(n->attrs); + std::vector garg_names(garg_idx.size()); + for (size_t i = 0; i < garg_idx.size(); i++) { + size_t idx = garg_idx[i]; + if (idx < arg_names.size()) + garg_names[i] = arg_names[idx]; + } // parameter check. for (size_t i = 0; i < args.size(); ++i) { - CHECK_EQ(args[i]->outputs.size(), 1U) + // If the argument isn't a graph, it should have only one output. + if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end()) + CHECK_EQ(args[i]->outputs.size(), 1U) << "Argument " << i << " is a tuple, single value is required"; } for (const auto& kv : kwargs) { - CHECK_EQ(kv.second->outputs.size(), 1U) + if (garg_names.empty() + || std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end()) + CHECK_EQ(kv.second->outputs.size(), 1U) << "Keyword Argument " << kv.first << " is a tuple, single value is required"; } // assign new name @@ -282,28 +304,49 @@ void Symbol::Compose(const array_view& args, // Atomic functor composition. if (IsAtomic(outputs)) { - Node* n = outputs[0].node.get(); uint32_t n_req = n->num_inputs(); + std::vector arg_vec(args.begin(), args.end()); + std::unordered_map kwarg_map(kwargs.begin(), kwargs.end()); + // If one of the input arguments is a graph, we need to remove it from the + // list. + if (fng != nullptr) { + std::vector idxes = fng(n->attrs); + for (auto idx : idxes) { + const Symbol *sym; + if (idx < arg_vec.size()) { + sym = arg_vec[idx]; + arg_vec.erase(arg_vec.begin() + idx); + } else { + auto it = kwarg_map.find(arg_names[idx]); + CHECK(it != kwarg_map.end()); + sym = it->second; + kwarg_map.erase(it); + } + + if (n_req != kVarg) + n_req--; + arg_names.erase(arg_names.begin() + idx); + n->attrs.subgraphs.push_back(std::make_shared(*sym)); + } + } if (n_req != kVarg) { n->inputs.resize(n_req); - CHECK_LE(args.size(), n_req) + CHECK_LE(arg_vec.size(), n_req) << "Incorrect number of arguments, requires " << n_req - << ", provided " << args.size(); - for (size_t i = 0; i < args.size(); ++i) { - n->inputs[i] = args[i]->outputs[0]; + << ", provided " << arg_vec.size(); + for (size_t i = 0; i < arg_vec.size(); ++i) { + n->inputs[i] = arg_vec[i]->outputs[0]; } // switch to keyword argument matching - if (args.size() != n_req) { - FListInputNames fn = flist_inputs.get(n->op(), nullptr); - auto arg_names = (fn == nullptr) ? std::vector{"data"} : fn(n->attrs); + if (arg_vec.size() != n_req) { if (arg_names.size() != n_req) { LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name; } size_t nmatched = 0; - for (size_t i = args.size(); i < n_req; ++i) { - auto it = kwargs.find(arg_names[i]); - if (it != kwargs.end() && it->first == arg_names[i]) { + for (size_t i = arg_vec.size(); i < n_req; ++i) { + auto it = kwarg_map.find(arg_names[i]); + if (it != kwarg_map.end() && it->first == arg_names[i]) { n->inputs[i] = it->second->outputs[0]; ++nmatched; } else { @@ -314,18 +357,18 @@ void Symbol::Compose(const array_view& args, } } - if (nmatched != kwargs.size()) { + if (nmatched != kwarg_map.size()) { n->inputs.clear(); - std::vector keys = GetKeys(kwargs); - array_view view(dmlc::BeginPtr(arg_names) + args.size(), + std::vector keys = GetKeys(kwarg_map); + array_view view(dmlc::BeginPtr(arg_names) + arg_vec.size(), dmlc::BeginPtr(arg_names) + arg_names.size()); KeywordArgumentMismatch("Symbol.Compose", keys, view); } } } else { - CHECK_EQ(kwargs.size(), 0U) << "Variable length function do not accept kwargs"; - n->inputs.reserve(args.size()); - for (const Symbol* s : args) { + CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs"; + n->inputs.reserve(arg_vec.size()); + for (const Symbol* s : arg_vec) { n->inputs.push_back(s->outputs[0]); } } diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 3170d245ca7a..f1d99616f58f 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -29,6 +29,11 @@ namespace nnvm { namespace pass { namespace { +// JSONNode represents an nnvm::Node in JSON +struct JSONNode; +// JSONGraph represents an nnvm::Graph or nnvm::Symbol in JSON +struct JSONGraph; + // auxiliary node structure for serialization. struct JSONNode { // the node entry structure in serialized format @@ -36,6 +41,10 @@ struct JSONNode { uint32_t node_id; uint32_t index; uint32_t version; + Entry() = default; + Entry(uint32_t node_id, uint32_t index, uint32_t version): + node_id(node_id), index(index), version(version) { + } void Save(dmlc::JSONWriter *writer) const { writer->BeginArray(false); writer->WriteArrayItem(node_id); @@ -64,6 +73,8 @@ struct JSONNode { std::vector inputs; // control flow dependencies std::vector control_deps; + // subgraphs + std::vector subgraphs; // function to save JSON node. void Save(dmlc::JSONWriter *writer) const { @@ -85,6 +96,9 @@ struct JSONNode { if (control_deps.size() != 0) { writer->WriteObjectKeyValue("control_deps", control_deps); } + if (subgraphs.size() != 0) { + writer->WriteObjectKeyValue("subgraphs", subgraphs); + } writer->EndObject(); } @@ -99,6 +113,7 @@ struct JSONNode { helper.DeclareOptionalField("attrs", &(node->attrs.dict)); helper.DeclareOptionalField("attr", &(node->attrs.dict)); helper.DeclareOptionalField("control_deps", &control_deps); + helper.DeclareOptionalField("subgraphs", &subgraphs); // backward compatible code with mxnet graph. int backward_source_id; std::unordered_map param; @@ -129,6 +144,8 @@ struct JSONGraph { std::vector node_row_ptr; std::vector heads; std::unordered_map > attrs; + // all subgraphs of the JSONGraph, used only in saving. + std::vector subgraphs; void Save(dmlc::JSONWriter *writer) const { writer->BeginObject(); @@ -154,86 +171,130 @@ struct JSONGraph { } }; -// Load a graph from JSON file. -Graph LoadJSON(Graph src) { - CHECK_NE(src.attrs.count("json"), 0U) - << "Load JSON require json to be presented."; - const std::string &json_str = - nnvm::get(*src.attrs.at("json")); - bool no_parse = false; - if (src.attrs.count("load_json_no_parse")) { - no_parse = nnvm::get(*src.attrs.at("load_json_no_parse")); +void Symbol2JSONGraph(std::shared_ptr src, JSONGraph &jgraph) { + std::unordered_map node2index; + jgraph.node_row_ptr.push_back(0); + DFSVisit(src->outputs, [&node2index, &jgraph](const NodePtr& n) { + uint32_t nid = static_cast(jgraph.nodes.size()); + node2index[n.get()] = nid; + if (n->is_variable()) { + jgraph.arg_nodes.push_back(nid); + } + JSONNode jnode; + jnode.node = n; + jnode.inputs.reserve(n->inputs.size()); + for (const NodeEntry& e : n->inputs) { + jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version); + } + for (const NodePtr& c : n->control_deps) { + jnode.control_deps.push_back(node2index.at(c.get())); + } + jgraph.node_row_ptr.push_back(jgraph.node_row_ptr.back() + n->num_outputs()); + jgraph.nodes.emplace_back(std::move(jnode)); + }); + for (const NodeEntry& e : src->outputs) { + jgraph.heads.emplace_back(node2index.at(e.node.get()), e.index, e.version); } - std::istringstream is(json_str); - dmlc::JSONReader reader(&is); - JSONGraph jgraph; - // load in json graph. - jgraph.Load(&reader); - // connects the nodes +} + +std::shared_ptr JSONGraph2Symbol(JSONGraph &jgraph, bool no_parse) { for (JSONNode &n : jgraph.nodes) { n.node->inputs.reserve(n.inputs.size()); for (const JSONNode::Entry &e : n.inputs) { - n.node->inputs.emplace_back( - NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); + n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } n.node->control_deps.reserve(n.control_deps.size()); for (uint32_t nid : n.control_deps) { n.node->control_deps.push_back(jgraph.nodes[nid].node); } // rebuild attribute parser - if (!no_parse && n.node->op() != nullptr && - n.node->op()->attr_parser != nullptr) { + if (!no_parse && n.node->op() != nullptr && n.node->op()->attr_parser != nullptr) { n.node->op()->attr_parser(&(n.node->attrs)); } + for (JSONGraph &subgraph : n.subgraphs) { + // The "no_parse" option here, is to be compatible with + // commit cfd3075e85807dcd8f9534c37e053583dee87524 + // (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524), + // where the parsing of main graph is deferred until + // incubator-mxnet/src/nnvm/legacy_json_util.cc:UpgradeJSON_Parse + n.node->attrs.subgraphs.push_back(JSONGraph2Symbol(subgraph, false)); + } } // consistent check for (uint32_t nid : jgraph.arg_nodes) { CHECK(jgraph.nodes[nid].node->is_variable()); } + std::shared_ptr symbol = std::make_shared(); + symbol->outputs.reserve(jgraph.heads.size()); + for (const JSONNode::Entry &e : jgraph.heads) { + symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); + } + return symbol; +} + +void DFSSubgraph(std::shared_ptr root, JSONGraph &jroot) { + // a standard DFS. + // the JGraphPtr in the stack forms a path from root to current node. + // the uint32_t indicates how many children of the corresponding JGraphPtr have been pushed to stack. + Symbol2JSONGraph(root, jroot); + std::vector > stack; + stack.emplace_back(&jroot, 0U); + while (!stack.empty()) { + std::pair &top = stack.back(); + JSONGraph *jgraph = top.first; + uint32_t &next_subgraph = top.second; + if (next_subgraph == 0U) { + // this is the first time we visit this jgraph + // convert each jnode's symbolic subgraphs to JSONGraph + for (JSONNode &jnode : jgraph->nodes) { + const std::vector> &subgraphs = jnode.node->attrs.subgraphs; + jnode.subgraphs.resize(subgraphs.size()); + for (uint32_t i = 0; i < subgraphs.size(); ++i) { + Symbol2JSONGraph(subgraphs[i], jnode.subgraphs[i]); + jgraph->subgraphs.push_back(&jnode.subgraphs[i]); + } + } + } + if (next_subgraph == jgraph->subgraphs.size()) { + stack.pop_back(); + } + else { + JSONGraph *subgraph = jgraph->subgraphs[next_subgraph++]; + stack.emplace_back(subgraph, 0U); + } + } +} +// Load a graph from JSON file. +Graph LoadJSON(Graph src) { + CHECK_NE(src.attrs.count("json"), 0U) + << "Load JSON require json to be presented."; + const std::string &json_str = + nnvm::get(*src.attrs.at("json")); + bool no_parse = false; + if (src.attrs.count("load_json_no_parse")) { + no_parse = nnvm::get(*src.attrs.at("load_json_no_parse")); + } + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JSONGraph jgraph; + // load in json graph. + jgraph.Load(&reader); + std::shared_ptr symbol = JSONGraph2Symbol(jgraph, no_parse); // return the graph Graph ret; ret.attrs = std::move(jgraph.attrs); - ret.outputs.reserve(jgraph.heads.size()); - for (const JSONNode::Entry &e : jgraph.heads) { - ret.outputs.emplace_back( - NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); - } + ret.outputs = symbol->outputs; return ret; } // save a graph to json Graph SaveJSON(Graph src) { + std::shared_ptr src_symbol = std::make_shared(); + src_symbol->outputs = src.outputs; JSONGraph jgraph; + DFSSubgraph(src_symbol, jgraph); jgraph.attrs = src.attrs; - std::unordered_map node2index; - jgraph.node_row_ptr.push_back(0); - DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) { - uint32_t nid = static_cast(jgraph.nodes.size()); - node2index[n.get()] = nid; - if (n->is_variable()) { - jgraph.arg_nodes.push_back(nid); - } - JSONNode jnode; - jnode.node = n; - jnode.inputs.reserve(n->inputs.size()); - for (const NodeEntry& e : n->inputs) { - jnode.inputs.emplace_back( - JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version}); - } - for (const NodePtr& c : n->control_deps) { - jnode.control_deps.push_back(node2index.at(c.get())); - } - jgraph.node_row_ptr.push_back( - jgraph.node_row_ptr.back() + n->num_outputs()); - jgraph.nodes.emplace_back(std::move(jnode)); - }); - - for (const NodeEntry& e : src.outputs) { - jgraph.heads.push_back( - JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version}); - } - std::ostringstream os; dmlc::JSONWriter writer(&os); jgraph.Save(&writer); From b3cfddee0d7b237bbbb03891738f563b9a169b1b Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Sat, 2 Jun 2018 01:11:08 +0000 Subject: [PATCH 02/12] fix. --- nnvm/src/pass/saveload_json.cc | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index f1d99616f58f..d58ef1aaccc3 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -171,14 +171,14 @@ struct JSONGraph { } }; -void Symbol2JSONGraph(std::shared_ptr src, JSONGraph &jgraph) { +void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { std::unordered_map node2index; - jgraph.node_row_ptr.push_back(0); - DFSVisit(src->outputs, [&node2index, &jgraph](const NodePtr& n) { - uint32_t nid = static_cast(jgraph.nodes.size()); + jgraph->node_row_ptr.push_back(0); + DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) { + uint32_t nid = static_cast(jgraph->nodes.size()); node2index[n.get()] = nid; if (n->is_variable()) { - jgraph.arg_nodes.push_back(nid); + jgraph->arg_nodes.push_back(nid); } JSONNode jnode; jnode.node = n; @@ -189,15 +189,15 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph &jgraph) { for (const NodePtr& c : n->control_deps) { jnode.control_deps.push_back(node2index.at(c.get())); } - jgraph.node_row_ptr.push_back(jgraph.node_row_ptr.back() + n->num_outputs()); - jgraph.nodes.emplace_back(std::move(jnode)); + jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs()); + jgraph->nodes.emplace_back(std::move(jnode)); }); for (const NodeEntry& e : src->outputs) { - jgraph.heads.emplace_back(node2index.at(e.node.get()), e.index, e.version); + jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version); } } -std::shared_ptr JSONGraph2Symbol(JSONGraph &jgraph, bool no_parse) { +std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) { for (JSONNode &n : jgraph.nodes) { n.node->inputs.reserve(n.inputs.size()); for (const JSONNode::Entry &e : n.inputs) { @@ -232,13 +232,13 @@ std::shared_ptr JSONGraph2Symbol(JSONGraph &jgraph, bool no_parse) { return symbol; } -void DFSSubgraph(std::shared_ptr root, JSONGraph &jroot) { +void DFSSubgraph(std::shared_ptr root, JSONGraph *jroot) { // a standard DFS. // the JGraphPtr in the stack forms a path from root to current node. // the uint32_t indicates how many children of the corresponding JGraphPtr have been pushed to stack. Symbol2JSONGraph(root, jroot); std::vector > stack; - stack.emplace_back(&jroot, 0U); + stack.emplace_back(jroot, 0U); while (!stack.empty()) { std::pair &top = stack.back(); JSONGraph *jgraph = top.first; @@ -250,7 +250,7 @@ void DFSSubgraph(std::shared_ptr root, JSONGraph &jroot) { const std::vector> &subgraphs = jnode.node->attrs.subgraphs; jnode.subgraphs.resize(subgraphs.size()); for (uint32_t i = 0; i < subgraphs.size(); ++i) { - Symbol2JSONGraph(subgraphs[i], jnode.subgraphs[i]); + Symbol2JSONGraph(subgraphs[i], &jnode.subgraphs[i]); jgraph->subgraphs.push_back(&jnode.subgraphs[i]); } } From 115405f9baef39560f52ec8f5b19e7f3fa038f14 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Sat, 2 Jun 2018 01:12:57 +0000 Subject: [PATCH 03/12] fix. --- nnvm/src/pass/saveload_json.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index d58ef1aaccc3..dd237b545b73 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -235,7 +235,8 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) void DFSSubgraph(std::shared_ptr root, JSONGraph *jroot) { // a standard DFS. // the JGraphPtr in the stack forms a path from root to current node. - // the uint32_t indicates how many children of the corresponding JGraphPtr have been pushed to stack. + // the uint32_t indicates how many children of the corresponding JGraphPtr + // have been pushed to stack. Symbol2JSONGraph(root, jroot); std::vector > stack; stack.emplace_back(jroot, 0U); @@ -257,8 +258,7 @@ void DFSSubgraph(std::shared_ptr root, JSONGraph *jroot) { } if (next_subgraph == jgraph->subgraphs.size()) { stack.pop_back(); - } - else { + } else { JSONGraph *subgraph = jgraph->subgraphs[next_subgraph++]; stack.emplace_back(subgraph, 0U); } From b3887c19beb7f31da3b9dc5bf11197d12e3d8e91 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 4 Jun 2018 12:09:36 -0700 Subject: [PATCH 04/12] Fix compilation error --- nnvm/src/pass/saveload_json.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index dd237b545b73..eb309fa65f6d 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -198,7 +198,7 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { } std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) { - for (JSONNode &n : jgraph.nodes) { + for (const JSONNode &n : jgraph.nodes) { n.node->inputs.reserve(n.inputs.size()); for (const JSONNode::Entry &e : n.inputs) { n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); @@ -211,7 +211,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) if (!no_parse && n.node->op() != nullptr && n.node->op()->attr_parser != nullptr) { n.node->op()->attr_parser(&(n.node->attrs)); } - for (JSONGraph &subgraph : n.subgraphs) { + for (const JSONGraph &subgraph : n.subgraphs) { // The "no_parse" option here, is to be compatible with // commit cfd3075e85807dcd8f9534c37e053583dee87524 // (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524), From af3177ab90ac25a8bd7c0f593118d347258ef889 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 4 Jun 2018 13:16:29 -0700 Subject: [PATCH 05/12] Fix compilation error --- nnvm/src/pass/saveload_json.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index eb309fa65f6d..245945915410 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -293,7 +293,7 @@ Graph SaveJSON(Graph src) { std::shared_ptr src_symbol = std::make_shared(); src_symbol->outputs = src.outputs; JSONGraph jgraph; - DFSSubgraph(src_symbol, jgraph); + DFSSubgraph(src_symbol, &jgraph); jgraph.attrs = src.attrs; std::ostringstream os; dmlc::JSONWriter writer(&os); From 3ec43c93a6d87990a6dd1a106355d7017e2b2348 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 7 Jun 2018 18:07:44 +0000 Subject: [PATCH 06/12] add comments. --- nnvm/include/nnvm/node.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 1629a70cfc3d..61647ea77256 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -97,6 +97,11 @@ struct NodeAttrs { * These graphs don't change when the operators are invoked for different * mini-batches. In this sense, the subgraphs are kind of similar to * the parameters and show be kept as node attributes. + * + * Users need to make sure the subgraphs are disjoint with the main graph. + * Otherwise, loading a graph with subgraphs by using LoadJSON may generate + * a graph that has a different structure from the original graph (some of + * the nodes are duplicated). */ std::vector > subgraphs; }; From 0de1441d6b7b20fcbb323352b0673ccee2ab4402 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 7 Jun 2018 18:28:56 +0000 Subject: [PATCH 07/12] update comments. --- nnvm/include/nnvm/node.h | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 61647ea77256..57afb0c5587a 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -99,9 +99,11 @@ struct NodeAttrs { * the parameters and show be kept as node attributes. * * Users need to make sure the subgraphs are disjoint with the main graph. - * Otherwise, loading a graph with subgraphs by using LoadJSON may generate - * a graph that has a different structure from the original graph (some of - * the nodes are duplicated). + * If a graph shares nodes with subgraphs, loading the graph from LoadJSON + * may generate a graph that has a different structure from the original graph + * (some of the nodes are duplicated). If nodes are shared between two graphs, + * shared nodes might be executed multiple times, which can be a problem for + * stateful operators. */ std::vector > subgraphs; }; From 18432e18ac247b5728d5d16a44866e84dc463f5c Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 7 Jun 2018 17:09:38 -0700 Subject: [PATCH 08/12] Sanity check on subgraphs when creating IndexedGraph --- nnvm/src/core/graph.cc | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 62c7085c1210..9780ffbfb9cf 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -20,7 +20,37 @@ const IndexedGraph& Graph::indexed_graph() const { IndexedGraph::IndexedGraph(const Graph &g) { entry_rptr_.push_back(0); std::vector inputs_rptr{0}, control_rptr{0}; - + // sanity check here + // a subgraph should not refer to any nodes with higher level + // where "level" refers to the nested depth of the subgraph + // e.g. the main graph is level 0 + // subgraphs of the main graph is level 1 + // subgraphs of the subgraphs of the main graph is level 2 + std::vector*> curr_level; + std::vector*> next_level; + std::unordered_map node2level; + next_level.push_back(&g.outputs); + for (uint32_t level = 0; !next_level.empty(); ++level) { + curr_level.swap(next_level); + next_level.clear(); + for (const std::vector *graph_ptr : curr_level) { + const std::vector &graph = *graph_ptr; + DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) { + nnvm::Node *node = n.get(); + // if the node is visited, but on a different level, then check failed + // if check failed here or before, we stop doing anything, but raise an error + CHECK(!node2level.count(node) || node2level[node] == level) + << "A subgraph should not depend on the outputs of nodes on higher levels"; + // otherwise, this node belongs to the current level + node2level[node] = level; + // subgraphs of current node belongs to next level + for (const auto& subgraph : n->attrs.subgraphs) { + next_level.push_back(&subgraph->outputs); + } + }); + } + } + // sanity check finishes, then we build the indexed graph DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr] (const NodePtr& n) { CHECK_LT(nodes_.size(), std::numeric_limits::max()); From 2fcb5fe54b57f88555bbc2754b3aa39daf4adada Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 8 Jun 2018 23:49:27 +0000 Subject: [PATCH 09/12] avoid the overhead of sanity check. --- nnvm/src/core/graph.cc | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 9780ffbfb9cf..b8bcae70f2e0 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -16,25 +16,22 @@ const IndexedGraph& Graph::indexed_graph() const { return *indexed_graph_; } -// implement constructor from graph -IndexedGraph::IndexedGraph(const Graph &g) { - entry_rptr_.push_back(0); - std::vector inputs_rptr{0}, control_rptr{0}; - // sanity check here - // a subgraph should not refer to any nodes with higher level - // where "level" refers to the nested depth of the subgraph - // e.g. the main graph is level 0 - // subgraphs of the main graph is level 1 - // subgraphs of the subgraphs of the main graph is level 2 +// a subgraph should not refer to any nodes with higher level +// where "level" refers to the nested depth of the subgraph +// e.g. the main graph is level 0 +// subgraphs of the main graph is level 1 +// subgraphs of the subgraphs of the main graph is level 2 +static void SubgraphSanityCheck(const std::vector> &subgraphs) { std::vector*> curr_level; std::vector*> next_level; std::unordered_map node2level; - next_level.push_back(&g.outputs); + for (auto &subgraph : subgraphs) + next_level.push_back(&subgraph->outputs); for (uint32_t level = 0; !next_level.empty(); ++level) { curr_level.swap(next_level); next_level.clear(); - for (const std::vector *graph_ptr : curr_level) { - const std::vector &graph = *graph_ptr; + for (const std::vector *graph_ptr : curr_level) { + const std::vector &graph = *graph_ptr; DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) { nnvm::Node *node = n.get(); // if the node is visited, but on a different level, then check failed @@ -50,11 +47,20 @@ IndexedGraph::IndexedGraph(const Graph &g) { }); } } - // sanity check finishes, then we build the indexed graph - DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr] +} + +// implement constructor from graph +IndexedGraph::IndexedGraph(const Graph &g) { + entry_rptr_.push_back(0); + std::vector inputs_rptr{0}, control_rptr{0}; + std::vector> subgraphs; + + DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] (const NodePtr& n) { CHECK_LT(nodes_.size(), std::numeric_limits::max()); uint32_t nid = static_cast(nodes_.size()); + for (const auto &subgraph : n->attrs.subgraphs) + subgraphs.push_back(subgraph); // nodes_ IndexedGraph::Node new_node; new_node.source = n.get(); @@ -83,6 +89,8 @@ IndexedGraph::IndexedGraph(const Graph &g) { } control_rptr.push_back(control_deps_.size()); }); + if (!subgraphs.empty()) + SubgraphSanityCheck(subgraphs); for (const auto& e : g.outputs) { outputs_.emplace_back(NodeEntry{ From c29886754482774ab5f1a39e87b0596cbb562657 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 11 Jun 2018 12:23:11 -0700 Subject: [PATCH 10/12] Stop using non-recursive DFS --- nnvm/src/pass/saveload_json.cc | 49 +++++++++------------------------- 1 file changed, 12 insertions(+), 37 deletions(-) diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 245945915410..195d49bfb9b4 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -144,8 +144,6 @@ struct JSONGraph { std::vector node_row_ptr; std::vector heads; std::unordered_map > attrs; - // all subgraphs of the JSONGraph, used only in saving. - std::vector subgraphs; void Save(dmlc::JSONWriter *writer) const { writer->BeginObject(); @@ -195,6 +193,16 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { for (const NodeEntry& e : src->outputs) { jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version); } + // recursively construct subgraphs + for (JSONNode &jnode : jgraph->nodes) { + // construct jnode's subgraphs + const std::vector> &subgraphs = jnode.node->attrs.subgraphs; + std::vector &jsubgraphs = jnode.subgraphs; + jsubgraphs.resize(subgraphs.size()); + for (uint32_t i = 0; i < subgraphs.size(); ++i) { + Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]); + } + } } std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) { @@ -220,7 +228,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) n.node->attrs.subgraphs.push_back(JSONGraph2Symbol(subgraph, false)); } } - // consistent check + // consistency check for (uint32_t nid : jgraph.arg_nodes) { CHECK(jgraph.nodes[nid].node->is_variable()); } @@ -232,39 +240,6 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) return symbol; } -void DFSSubgraph(std::shared_ptr root, JSONGraph *jroot) { - // a standard DFS. - // the JGraphPtr in the stack forms a path from root to current node. - // the uint32_t indicates how many children of the corresponding JGraphPtr - // have been pushed to stack. - Symbol2JSONGraph(root, jroot); - std::vector > stack; - stack.emplace_back(jroot, 0U); - while (!stack.empty()) { - std::pair &top = stack.back(); - JSONGraph *jgraph = top.first; - uint32_t &next_subgraph = top.second; - if (next_subgraph == 0U) { - // this is the first time we visit this jgraph - // convert each jnode's symbolic subgraphs to JSONGraph - for (JSONNode &jnode : jgraph->nodes) { - const std::vector> &subgraphs = jnode.node->attrs.subgraphs; - jnode.subgraphs.resize(subgraphs.size()); - for (uint32_t i = 0; i < subgraphs.size(); ++i) { - Symbol2JSONGraph(subgraphs[i], &jnode.subgraphs[i]); - jgraph->subgraphs.push_back(&jnode.subgraphs[i]); - } - } - } - if (next_subgraph == jgraph->subgraphs.size()) { - stack.pop_back(); - } else { - JSONGraph *subgraph = jgraph->subgraphs[next_subgraph++]; - stack.emplace_back(subgraph, 0U); - } - } -} - // Load a graph from JSON file. Graph LoadJSON(Graph src) { CHECK_NE(src.attrs.count("json"), 0U) @@ -293,7 +268,7 @@ Graph SaveJSON(Graph src) { std::shared_ptr src_symbol = std::make_shared(); src_symbol->outputs = src.outputs; JSONGraph jgraph; - DFSSubgraph(src_symbol, &jgraph); + Symbol2JSONGraph(src_symbol, &jgraph); jgraph.attrs = src.attrs; std::ostringstream os; dmlc::JSONWriter writer(&os); From b27d6a08f0c6340162b536e4a0f1062412bab296 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 11 Jun 2018 16:00:51 -0700 Subject: [PATCH 11/12] Trigger CI From 17734a5ae0391b3c6859f22e2bd8af2c12ca1601 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 11 Jun 2018 17:54:14 -0700 Subject: [PATCH 12/12] trigger CI