diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 15db77ee6043..57afb0c5587a 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -18,6 +18,7 @@ namespace nnvm { // Forward declare node. class Node; +class Symbol; /*! * \brief we always used NodePtr for a reference pointer @@ -90,6 +91,21 @@ struct NodeAttrs { * The object can be used to quickly access attributes. */ any parsed; + /*! + * \brief Some operators take graphs as input. These operators include + * control flow operators and high-order functions. + * These graphs don't change when the operators are invoked for different + * mini-batches. In this sense, the subgraphs are kind of similar to + * the parameters and show be kept as node attributes. + * + * Users need to make sure the subgraphs are disjoint with the main graph. + * If a graph shares nodes with subgraphs, loading the graph from LoadJSON + * may generate a graph that has a different structure from the original graph + * (some of the nodes are duplicated). If nodes are shared between two graphs, + * shared nodes might be executed multiple times, which can be a problem for + * stateful operators. + */ + std::vector > subgraphs; }; /*! diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index e58e9ceb3581..b7f6be408a16 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -202,6 +202,18 @@ using FCorrectLayout = std::function *last_ilayouts, std::vector *olayouts)>; +/*! + * \brief Get a list of inputs that represent graphs instead of data. + * Normally, input symbols are considered as data to the operator. However, + * control flow operators and high-order functions need to interpret symbols + * as graphs. + * \param attrs The attributes of this node. + * \return a list of input index that are interpreted as symbols by the operator. + * + * \note Register under "FInputGraph". + */ +using FInputGraph = std::function(const NodeAttrs& attrs)>; + } // namespace nnvm #endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 62c7085c1210..b8bcae70f2e0 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -16,15 +16,51 @@ const IndexedGraph& Graph::indexed_graph() const { return *indexed_graph_; } +// a subgraph should not refer to any nodes with higher level +// where "level" refers to the nested depth of the subgraph +// e.g. the main graph is level 0 +// subgraphs of the main graph is level 1 +// subgraphs of the subgraphs of the main graph is level 2 +static void SubgraphSanityCheck(const std::vector> &subgraphs) { + std::vector*> curr_level; + std::vector*> next_level; + std::unordered_map node2level; + for (auto &subgraph : subgraphs) + next_level.push_back(&subgraph->outputs); + for (uint32_t level = 0; !next_level.empty(); ++level) { + curr_level.swap(next_level); + next_level.clear(); + for (const std::vector *graph_ptr : curr_level) { + const std::vector &graph = *graph_ptr; + DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) { + nnvm::Node *node = n.get(); + // if the node is visited, but on a different level, then check failed + // if check failed here or before, we stop doing anything, but raise an error + CHECK(!node2level.count(node) || node2level[node] == level) + << "A subgraph should not depend on the outputs of nodes on higher levels"; + // otherwise, this node belongs to the current level + node2level[node] = level; + // subgraphs of current node belongs to next level + for (const auto& subgraph : n->attrs.subgraphs) { + next_level.push_back(&subgraph->outputs); + } + }); + } + } +} + // implement constructor from graph IndexedGraph::IndexedGraph(const Graph &g) { entry_rptr_.push_back(0); std::vector inputs_rptr{0}, control_rptr{0}; + std::vector> subgraphs; - DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr] + DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] (const NodePtr& n) { CHECK_LT(nodes_.size(), std::numeric_limits::max()); uint32_t nid = static_cast(nodes_.size()); + for (const auto &subgraph : n->attrs.subgraphs) + subgraphs.push_back(subgraph); // nodes_ IndexedGraph::Node new_node; new_node.source = n.get(); @@ -53,6 +89,8 @@ IndexedGraph::IndexedGraph(const Graph &g) { } control_rptr.push_back(control_deps_.size()); }); + if (!subgraphs.empty()) + SubgraphSanityCheck(subgraphs); for (const auto& e : g.outputs) { outputs_.emplace_back(NodeEntry{ diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index 2a2f5be50bc0..927dd2b70e44 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -267,14 +267,36 @@ void Symbol::Compose(const array_view& args, const std::string& name) { static auto& flist_inputs = Op::GetAttr("FListInputNames"); static auto& fset_attrs = Op::GetAttr("FSetInputVarAttrOnCompose"); + static auto& fgraph = Op::GetAttr("FInputGraph"); + + // The arguments that contain graphs. + Node* n = outputs[0].node.get(); + FInputGraph fng = fgraph.get(n->op(), nullptr); + std::vector garg_idx; + if (fng != nullptr) + garg_idx = fng(n->attrs); + + // The names of the arguments that contain graphs. + FListInputNames name_fn = flist_inputs.get(n->op(), nullptr); + auto arg_names = (name_fn == nullptr) ? std::vector{"data"} : name_fn(n->attrs); + std::vector garg_names(garg_idx.size()); + for (size_t i = 0; i < garg_idx.size(); i++) { + size_t idx = garg_idx[i]; + if (idx < arg_names.size()) + garg_names[i] = arg_names[idx]; + } // parameter check. for (size_t i = 0; i < args.size(); ++i) { - CHECK_EQ(args[i]->outputs.size(), 1U) + // If the argument isn't a graph, it should have only one output. + if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end()) + CHECK_EQ(args[i]->outputs.size(), 1U) << "Argument " << i << " is a tuple, single value is required"; } for (const auto& kv : kwargs) { - CHECK_EQ(kv.second->outputs.size(), 1U) + if (garg_names.empty() + || std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end()) + CHECK_EQ(kv.second->outputs.size(), 1U) << "Keyword Argument " << kv.first << " is a tuple, single value is required"; } // assign new name @@ -282,28 +304,49 @@ void Symbol::Compose(const array_view& args, // Atomic functor composition. if (IsAtomic(outputs)) { - Node* n = outputs[0].node.get(); uint32_t n_req = n->num_inputs(); + std::vector arg_vec(args.begin(), args.end()); + std::unordered_map kwarg_map(kwargs.begin(), kwargs.end()); + // If one of the input arguments is a graph, we need to remove it from the + // list. + if (fng != nullptr) { + std::vector idxes = fng(n->attrs); + for (auto idx : idxes) { + const Symbol *sym; + if (idx < arg_vec.size()) { + sym = arg_vec[idx]; + arg_vec.erase(arg_vec.begin() + idx); + } else { + auto it = kwarg_map.find(arg_names[idx]); + CHECK(it != kwarg_map.end()); + sym = it->second; + kwarg_map.erase(it); + } + + if (n_req != kVarg) + n_req--; + arg_names.erase(arg_names.begin() + idx); + n->attrs.subgraphs.push_back(std::make_shared(*sym)); + } + } if (n_req != kVarg) { n->inputs.resize(n_req); - CHECK_LE(args.size(), n_req) + CHECK_LE(arg_vec.size(), n_req) << "Incorrect number of arguments, requires " << n_req - << ", provided " << args.size(); - for (size_t i = 0; i < args.size(); ++i) { - n->inputs[i] = args[i]->outputs[0]; + << ", provided " << arg_vec.size(); + for (size_t i = 0; i < arg_vec.size(); ++i) { + n->inputs[i] = arg_vec[i]->outputs[0]; } // switch to keyword argument matching - if (args.size() != n_req) { - FListInputNames fn = flist_inputs.get(n->op(), nullptr); - auto arg_names = (fn == nullptr) ? std::vector{"data"} : fn(n->attrs); + if (arg_vec.size() != n_req) { if (arg_names.size() != n_req) { LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name; } size_t nmatched = 0; - for (size_t i = args.size(); i < n_req; ++i) { - auto it = kwargs.find(arg_names[i]); - if (it != kwargs.end() && it->first == arg_names[i]) { + for (size_t i = arg_vec.size(); i < n_req; ++i) { + auto it = kwarg_map.find(arg_names[i]); + if (it != kwarg_map.end() && it->first == arg_names[i]) { n->inputs[i] = it->second->outputs[0]; ++nmatched; } else { @@ -314,18 +357,18 @@ void Symbol::Compose(const array_view& args, } } - if (nmatched != kwargs.size()) { + if (nmatched != kwarg_map.size()) { n->inputs.clear(); - std::vector keys = GetKeys(kwargs); - array_view view(dmlc::BeginPtr(arg_names) + args.size(), + std::vector keys = GetKeys(kwarg_map); + array_view view(dmlc::BeginPtr(arg_names) + arg_vec.size(), dmlc::BeginPtr(arg_names) + arg_names.size()); KeywordArgumentMismatch("Symbol.Compose", keys, view); } } } else { - CHECK_EQ(kwargs.size(), 0U) << "Variable length function do not accept kwargs"; - n->inputs.reserve(args.size()); - for (const Symbol* s : args) { + CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs"; + n->inputs.reserve(arg_vec.size()); + for (const Symbol* s : arg_vec) { n->inputs.push_back(s->outputs[0]); } } diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 3170d245ca7a..195d49bfb9b4 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -29,6 +29,11 @@ namespace nnvm { namespace pass { namespace { +// JSONNode represents an nnvm::Node in JSON +struct JSONNode; +// JSONGraph represents an nnvm::Graph or nnvm::Symbol in JSON +struct JSONGraph; + // auxiliary node structure for serialization. struct JSONNode { // the node entry structure in serialized format @@ -36,6 +41,10 @@ struct JSONNode { uint32_t node_id; uint32_t index; uint32_t version; + Entry() = default; + Entry(uint32_t node_id, uint32_t index, uint32_t version): + node_id(node_id), index(index), version(version) { + } void Save(dmlc::JSONWriter *writer) const { writer->BeginArray(false); writer->WriteArrayItem(node_id); @@ -64,6 +73,8 @@ struct JSONNode { std::vector inputs; // control flow dependencies std::vector control_deps; + // subgraphs + std::vector subgraphs; // function to save JSON node. void Save(dmlc::JSONWriter *writer) const { @@ -85,6 +96,9 @@ struct JSONNode { if (control_deps.size() != 0) { writer->WriteObjectKeyValue("control_deps", control_deps); } + if (subgraphs.size() != 0) { + writer->WriteObjectKeyValue("subgraphs", subgraphs); + } writer->EndObject(); } @@ -99,6 +113,7 @@ struct JSONNode { helper.DeclareOptionalField("attrs", &(node->attrs.dict)); helper.DeclareOptionalField("attr", &(node->attrs.dict)); helper.DeclareOptionalField("control_deps", &control_deps); + helper.DeclareOptionalField("subgraphs", &subgraphs); // backward compatible code with mxnet graph. int backward_source_id; std::unordered_map param; @@ -154,86 +169,107 @@ struct JSONGraph { } }; -// Load a graph from JSON file. -Graph LoadJSON(Graph src) { - CHECK_NE(src.attrs.count("json"), 0U) - << "Load JSON require json to be presented."; - const std::string &json_str = - nnvm::get(*src.attrs.at("json")); - bool no_parse = false; - if (src.attrs.count("load_json_no_parse")) { - no_parse = nnvm::get(*src.attrs.at("load_json_no_parse")); +void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { + std::unordered_map node2index; + jgraph->node_row_ptr.push_back(0); + DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) { + uint32_t nid = static_cast(jgraph->nodes.size()); + node2index[n.get()] = nid; + if (n->is_variable()) { + jgraph->arg_nodes.push_back(nid); + } + JSONNode jnode; + jnode.node = n; + jnode.inputs.reserve(n->inputs.size()); + for (const NodeEntry& e : n->inputs) { + jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version); + } + for (const NodePtr& c : n->control_deps) { + jnode.control_deps.push_back(node2index.at(c.get())); + } + jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs()); + jgraph->nodes.emplace_back(std::move(jnode)); + }); + for (const NodeEntry& e : src->outputs) { + jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version); } - std::istringstream is(json_str); - dmlc::JSONReader reader(&is); - JSONGraph jgraph; - // load in json graph. - jgraph.Load(&reader); - // connects the nodes - for (JSONNode &n : jgraph.nodes) { + // recursively construct subgraphs + for (JSONNode &jnode : jgraph->nodes) { + // construct jnode's subgraphs + const std::vector> &subgraphs = jnode.node->attrs.subgraphs; + std::vector &jsubgraphs = jnode.subgraphs; + jsubgraphs.resize(subgraphs.size()); + for (uint32_t i = 0; i < subgraphs.size(); ++i) { + Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]); + } + } +} + +std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) { + for (const JSONNode &n : jgraph.nodes) { n.node->inputs.reserve(n.inputs.size()); for (const JSONNode::Entry &e : n.inputs) { - n.node->inputs.emplace_back( - NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); + n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } n.node->control_deps.reserve(n.control_deps.size()); for (uint32_t nid : n.control_deps) { n.node->control_deps.push_back(jgraph.nodes[nid].node); } // rebuild attribute parser - if (!no_parse && n.node->op() != nullptr && - n.node->op()->attr_parser != nullptr) { + if (!no_parse && n.node->op() != nullptr && n.node->op()->attr_parser != nullptr) { n.node->op()->attr_parser(&(n.node->attrs)); } + for (const JSONGraph &subgraph : n.subgraphs) { + // The "no_parse" option here, is to be compatible with + // commit cfd3075e85807dcd8f9534c37e053583dee87524 + // (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524), + // where the parsing of main graph is deferred until + // incubator-mxnet/src/nnvm/legacy_json_util.cc:UpgradeJSON_Parse + n.node->attrs.subgraphs.push_back(JSONGraph2Symbol(subgraph, false)); + } } - // consistent check + // consistency check for (uint32_t nid : jgraph.arg_nodes) { CHECK(jgraph.nodes[nid].node->is_variable()); } + std::shared_ptr symbol = std::make_shared(); + symbol->outputs.reserve(jgraph.heads.size()); + for (const JSONNode::Entry &e : jgraph.heads) { + symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); + } + return symbol; +} +// Load a graph from JSON file. +Graph LoadJSON(Graph src) { + CHECK_NE(src.attrs.count("json"), 0U) + << "Load JSON require json to be presented."; + const std::string &json_str = + nnvm::get(*src.attrs.at("json")); + bool no_parse = false; + if (src.attrs.count("load_json_no_parse")) { + no_parse = nnvm::get(*src.attrs.at("load_json_no_parse")); + } + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JSONGraph jgraph; + // load in json graph. + jgraph.Load(&reader); + std::shared_ptr symbol = JSONGraph2Symbol(jgraph, no_parse); // return the graph Graph ret; ret.attrs = std::move(jgraph.attrs); - ret.outputs.reserve(jgraph.heads.size()); - for (const JSONNode::Entry &e : jgraph.heads) { - ret.outputs.emplace_back( - NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); - } + ret.outputs = symbol->outputs; return ret; } // save a graph to json Graph SaveJSON(Graph src) { + std::shared_ptr src_symbol = std::make_shared(); + src_symbol->outputs = src.outputs; JSONGraph jgraph; + Symbol2JSONGraph(src_symbol, &jgraph); jgraph.attrs = src.attrs; - std::unordered_map node2index; - jgraph.node_row_ptr.push_back(0); - DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) { - uint32_t nid = static_cast(jgraph.nodes.size()); - node2index[n.get()] = nid; - if (n->is_variable()) { - jgraph.arg_nodes.push_back(nid); - } - JSONNode jnode; - jnode.node = n; - jnode.inputs.reserve(n->inputs.size()); - for (const NodeEntry& e : n->inputs) { - jnode.inputs.emplace_back( - JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version}); - } - for (const NodePtr& c : n->control_deps) { - jnode.control_deps.push_back(node2index.at(c.get())); - } - jgraph.node_row_ptr.push_back( - jgraph.node_row_ptr.back() + n->num_outputs()); - jgraph.nodes.emplace_back(std::move(jnode)); - }); - - for (const NodeEntry& e : src.outputs) { - jgraph.heads.push_back( - JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version}); - } - std::ostringstream os; dmlc::JSONWriter writer(&os); jgraph.Save(&writer);