Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for subgraphs. #1221

Merged
merged 12 commits into from
Jun 12, 2018
Merged
16 changes: 16 additions & 0 deletions nnvm/include/nnvm/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace nnvm {

// Forward declare node.
class Node;
class Symbol;

/*!
* \brief we always used NodePtr for a reference pointer
Expand Down Expand Up @@ -90,6 +91,21 @@ struct NodeAttrs {
* The object can be used to quickly access attributes.
*/
any parsed;
/*!
* \brief Some operators take graphs as input. These operators include
* control flow operators and high-order functions.
* These graphs don't change when the operators are invoked for different
* mini-batches. In this sense, the subgraphs are kind of similar to
* the parameters and should be kept as node attributes.
*
* Users need to make sure the subgraphs are disjoint with the main graph.
* If a graph shares nodes with subgraphs, loading the graph from LoadJSON
* may generate a graph that has a different structure from the original graph
* (some of the nodes are duplicated). If nodes are shared between two graphs,
* shared nodes might be executed multiple times, which can be a problem for
* stateful operators.
*/
std::vector<std::shared_ptr<Symbol> > subgraphs;
};

/*!
Expand Down
12 changes: 12 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@ using FCorrectLayout = std::function<bool(
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts)>;

/*!
* \brief Get a list of inputs that represent graphs instead of data.
* Normally, input symbols are considered as data to the operator. However,
* control flow operators and higher-order functions need to interpret some
* of their input symbols as graphs.
* \param attrs The attributes of this node.
* \return A list of input indices that the operator interprets as graphs
* rather than as data.
*
* \note Register under "FInputGraph".
*/
using FInputGraph = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;

} // namespace nnvm

#endif // NNVM_OP_ATTR_TYPES_H_
40 changes: 39 additions & 1 deletion nnvm/src/core/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,51 @@ const IndexedGraph& Graph::indexed_graph() const {
return *indexed_graph_;
}

// a subgraph should not refer to any nodes with higher level
// where "level" refers to the nested depth of the subgraph
// e.g. the main graph is level 0
// subgraphs of the main graph is level 1
// subgraphs of the subgraphs of the main graph is level 2
static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>> &subgraphs) {
std::vector<const std::vector<nnvm::NodeEntry>*> curr_level;
std::vector<const std::vector<nnvm::NodeEntry>*> next_level;
std::unordered_map<nnvm::Node*, uint32_t> node2level;
for (auto &subgraph : subgraphs)
next_level.push_back(&subgraph->outputs);
for (uint32_t level = 0; !next_level.empty(); ++level) {
curr_level.swap(next_level);
next_level.clear();
for (const std::vector<NodeEntry> *graph_ptr : curr_level) {
const std::vector<NodeEntry> &graph = *graph_ptr;
DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) {
nnvm::Node *node = n.get();
// if the node is visited, but on a different level, then check failed
// if check failed here or before, we stop doing anything, but raise an error
CHECK(!node2level.count(node) || node2level[node] == level)
<< "A subgraph should not depend on the outputs of nodes on higher levels";
// otherwise, this node belongs to the current level
node2level[node] = level;
// subgraphs of current node belongs to next level
for (const auto& subgraph : n->attrs.subgraphs) {
next_level.push_back(&subgraph->outputs);
}
});
}
}
}

// implement constructor from graph
IndexedGraph::IndexedGraph(const Graph &g) {
entry_rptr_.push_back(0);
std::vector<size_t> inputs_rptr{0}, control_rptr{0};
std::vector<std::shared_ptr<Symbol>> subgraphs;

DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
(const NodePtr& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
for (const auto &subgraph : n->attrs.subgraphs)
subgraphs.push_back(subgraph);
// nodes_
IndexedGraph::Node new_node;
new_node.source = n.get();
Expand Down Expand Up @@ -53,6 +89,8 @@ IndexedGraph::IndexedGraph(const Graph &g) {
}
control_rptr.push_back(control_deps_.size());
});
if (!subgraphs.empty())
SubgraphSanityCheck(subgraphs);

for (const auto& e : g.outputs) {
outputs_.emplace_back(NodeEntry{
Expand Down
81 changes: 62 additions & 19 deletions nnvm/src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,43 +267,86 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
static auto& fgraph = Op::GetAttr<FInputGraph>("FInputGraph");

// The arguments that contain graphs.
Node* n = outputs[0].node.get();
FInputGraph fng = fgraph.get(n->op(), nullptr);
std::vector<uint32_t> garg_idx;
if (fng != nullptr)
garg_idx = fng(n->attrs);

// The names of the arguments that contain graphs.
FListInputNames name_fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (name_fn == nullptr) ? std::vector<std::string>{"data"} : name_fn(n->attrs);
std::vector<std::string> garg_names(garg_idx.size());
for (size_t i = 0; i < garg_idx.size(); i++) {
size_t idx = garg_idx[i];
if (idx < arg_names.size())
garg_names[i] = arg_names[idx];
}

// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i]->outputs.size(), 1U)
// If the argument isn't a graph, it should have only one output.
if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end())
CHECK_EQ(args[i]->outputs.size(), 1U)
<< "Argument " << i << " is a tuple, single value is required";
}
for (const auto& kv : kwargs) {
CHECK_EQ(kv.second->outputs.size(), 1U)
if (garg_names.empty()
|| std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
if (!name.empty()) outputs[0].node->attrs.name = name;

// Atomic functor composition.
if (IsAtomic(outputs)) {
Node* n = outputs[0].node.get();
uint32_t n_req = n->num_inputs();
std::vector<const Symbol *> arg_vec(args.begin(), args.end());
std::unordered_map<std::string, const Symbol*> kwarg_map(kwargs.begin(), kwargs.end());
// If one of the input arguments is a graph, we need to remove it from the
// list.
if (fng != nullptr) {
std::vector<uint32_t> idxes = fng(n->attrs);
for (auto idx : idxes) {
const Symbol *sym;
if (idx < arg_vec.size()) {
sym = arg_vec[idx];
arg_vec.erase(arg_vec.begin() + idx);
} else {
auto it = kwarg_map.find(arg_names[idx]);
CHECK(it != kwarg_map.end());
sym = it->second;
kwarg_map.erase(it);
}

if (n_req != kVarg)
n_req--;
arg_names.erase(arg_names.begin() + idx);
n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym));
}
}

if (n_req != kVarg) {
n->inputs.resize(n_req);
CHECK_LE(args.size(), n_req)
CHECK_LE(arg_vec.size(), n_req)
<< "Incorrect number of arguments, requires " << n_req
<< ", provided " << args.size();
for (size_t i = 0; i < args.size(); ++i) {
n->inputs[i] = args[i]->outputs[0];
<< ", provided " << arg_vec.size();
for (size_t i = 0; i < arg_vec.size(); ++i) {
n->inputs[i] = arg_vec[i]->outputs[0];
}
// switch to keyword argument matching
if (args.size() != n_req) {
FListInputNames fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
if (arg_vec.size() != n_req) {
if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
}
size_t nmatched = 0;
for (size_t i = args.size(); i < n_req; ++i) {
auto it = kwargs.find(arg_names[i]);
if (it != kwargs.end() && it->first == arg_names[i]) {
for (size_t i = arg_vec.size(); i < n_req; ++i) {
auto it = kwarg_map.find(arg_names[i]);
if (it != kwarg_map.end() && it->first == arg_names[i]) {
n->inputs[i] = it->second->outputs[0];
++nmatched;
} else {
Expand All @@ -314,18 +357,18 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
}
}

if (nmatched != kwargs.size()) {
if (nmatched != kwarg_map.size()) {
n->inputs.clear();
std::vector<std::string> keys = GetKeys(kwargs);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + args.size(),
std::vector<std::string> keys = GetKeys(kwarg_map);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_vec.size(),
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, view);
}
}
} else {
CHECK_EQ(kwargs.size(), 0U) << "Variable length function do not accept kwargs";
n->inputs.reserve(args.size());
for (const Symbol* s : args) {
CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs";
n->inputs.reserve(arg_vec.size());
for (const Symbol* s : arg_vec) {
n->inputs.push_back(s->outputs[0]);
}
}
Expand Down
Loading