This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support extra inputs for subgraph ops #18779

Merged (27 commits) on Aug 14, 2020
2 changes: 2 additions & 0 deletions example/extensions/lib_api/init_lib.cc
@@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

MXReturnValue initialize(int version) {
if (version >= 10700) {
std::cout << "MXNet version " << version << " supported" << std::endl;
2 changes: 2 additions & 0 deletions example/extensions/lib_custom_op/gemm_lib.cc
@@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

// main matrix multiplication routine
void gemm(const float* A, const float* B, float* C,
const unsigned n, const unsigned k, const unsigned m) {
2 changes: 2 additions & 0 deletions example/extensions/lib_custom_op/relu_lib.cu
@@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

#define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block

__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
2 changes: 2 additions & 0 deletions example/extensions/lib_custom_op/transposecsr_lib.cc
@@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
MXSparse* A = src.data<MXSparse>();
MXSparse* B = dst.data<MXSparse>();
2 changes: 2 additions & 0 deletions example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
MXSparse* A = src.data<MXSparse>();
MXSparse* B = dst.data<MXSparse>();
67 changes: 50 additions & 17 deletions example/extensions/lib_pass/README.md
@@ -80,36 +80,36 @@ sym_block.optimize_for(x, backend='myPass')

These APIs are available in both the Symbol and Gluon front ends. For the Symbol API, the `optimize_for` API can be called on Symbol objects to return a new Symbol after the graph pass runs.

```python
optimize_for(backend, args=None, aux=None, ctx=None, **kwargs)
```

The `optimize_for` API takes at least one argument, `backend`, a string identifying which backend to use to optimize the model. The `args` and `aux` arguments are optional and take a list of NDArrays or a dict of str to NDArray. They are used to infer shapes and types before executing the graph pass. The `ctx` argument is optional and takes a device context to infer storage types. The API also accepts any other user-specified options, which are passed through to the backend APIs.

For the Gluon API, the `hybridize` API can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol.

```python
hybridize(backend=None, backend_opts=None, **kwargs)
```

The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which pass will be executed on the model. The `backend_opts` argument takes other user-specified options that will be passed to the backend APIs. The actual pass runs once, just before the first forward pass.

If you just want to run a graph pass on the HybridBlock without running a complete forward pass, you can use the `optimize_for` API, which combines the work done in the `hybridize` API with part of the work done in the forward pass.

```python
optimize_for(x, backend=None, backend_opts=None, **kwargs)
```

When the `optimize_for` API is called on a HybridBlock it runs the graph pass immediately. This lets users export the modified model without running a complete forward pass.

```python
block.optimize_for(x, backend='myPass')
block.export('optimized')
```

You can also use `optimize_for` in place of `hybridize` and run inference immediately afterward.

```python
block.optimize_for(x, backend='myPass')
block(x)
```
@@ -120,50 +120,83 @@ There are several essential building blocks for making a custom pass:

* [initialize](./pass_lib.cc#44):
* This function is the library initialization function necessary for any dynamic library. It lets you check that the user is using a compatible version of MXNet. Note that this `version` parameter is passed from MXNet when the library is loaded.

```c++
MXReturnValue initialize(int version)

```
* [graphPass](./pass_lib.cc#31):
* This function provides a copy of the model graph as a JSON string and provides an interface for returning a modified model JSON string. This is also where a custom pass can validate the options specified by the user.

```c++
MXReturnValue graphPass(
const std::string& in_graph,
const std::string** out_graph,
const std::unordered_map<std::string, std::string>& options,
const std::unordered_map<std::string, MXTensor>& args,
const std::unordered_map<std::string, MXTensor>& aux,
const PassResource& res)

```
* [REGISTER_PASS(my_pass_name)](./pass_lib.cc#L41):
* This macro registers the custom pass and its properties to MXNet by its name. The argument to `setBody` is the `graphPass` function.

```c++
REGISTER_PASS(my_pass_name)
.setBody(graphPass);

```
Let’s take a closer look at those registry functions:

* **graphPass**: This function takes six arguments. The first is a JSON string of the model architecture graph, where nodes are inputs/params/weights and edges are data dependencies; the graph is pre-sorted in topological order. The second is a pointer to a pointer to a JSON model string: users are expected to dereference it and assign the address of their output string allocated with `new` (`delete` will be called on it automatically). The third is the map of options specified by the user; any custom options passed to the pass arrive in this `options` map. The fourth and fifth are the named tensor maps for the model's arg and aux params; they are populated if the user provides the params to the `optimize_for` API. The last is the `PassResource` object for memory allocation and other utilities. The details of `PassResource` are covered in a section below.
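The parse-modify-dump flow that a typical graph pass implements can be sketched in plain Python. The graph layout below is a simplified, hypothetical stand-in for MXNet's actual model JSON, and `my_pass`/`prefix` are illustrative names, not part of the real API:

```python
import json

# Hypothetical toy model graph: nodes listed in topological order,
# "null" ops are inputs/params, "inputs" entries are [node_id, output_index].
graph_json = json.dumps({
    "nodes": [
        {"op": "null", "name": "data", "inputs": []},
        {"op": "FullyConnected", "name": "fc1", "inputs": [[0, 0]]},
    ]
})

def my_pass(in_graph, options):
    """Parse the graph JSON, rename every operator node, dump it back."""
    graph = json.loads(in_graph)
    for node in graph["nodes"]:
        if node["op"] != "null":  # leave inputs/params untouched
            node["name"] = options.get("prefix", "") + node["name"]
    return json.dumps(graph)

out = my_pass(graph_json, {"prefix": "my_"})
```

The real `graphPass` does the same thing in C++: it receives the input string, parses it, mutates the structure, and writes the serialized result through `out_graph`.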

### Graph representation

The `Graph` class represents the model's architecture. Each `Node` in the graph represents an operator or weight (i.e., an arg/aux param). Since an operator in MXNet can take multiple inputs and produce multiple outputs, each input/output is represented by a `NodeEntry`. A `Node` contains the following:
- `op` - [string] operator name
- `name` - [string] unique node name
- `inputs` - [vector of NodeEntry] set of inputs to the node
- `outputs` - [vector of NodeEntry] set of outputs from the node
- `subgraph` - [vector of Graph] set of subgraphs in the node
- `attrs` - [map of string to string] set of attributes for the node

The `inputs` are a set of `NodeEntry` where each contains a pointer to the `Node` that produces the data and an `entry` that is the index of the output on that `Node`. Conversely, the `outputs` are a set of `NodeEntry` where each contains a pointer to the `Node` that consumes the data and an `entry` that is the index of the input on that `Node`. This bidirectional dependency enables you to easily traverse the graph.

A `Graph` contains the following:
- `nodes` - [vector of Node] set of nodes in the graph
- `inputs` - [vector of Node] set of inputs to the graph
- `outputs` - [vector of NodeEntry] set of outputs from the graph
- `attrs` - [map of string to JSON object] set of attributes for the graph

The `nodes` are all the nodes in the graph (a superset). The `inputs` are only those nodes that are model inputs (i.e., the input image) or weights (i.e., arg/aux params). The `outputs` are the operator outputs that are true outputs of the model (i.e., the prediction results).

Here's an example of creating a new node and adding it to the graph:
```c++
Node* n = new Node();
g->nodes.push_back(n);
```
Here's an example of creating an edge between two nodes:
```c++
n1->outputs.push_back({n2,1});
n2->inputs.push_back({n1,0});
```
Here, node `n1` produces an output at index 0 that is consumed by node `n2` as its input at index 1.
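The same bidirectional bookkeeping can be modeled in a few lines of Python. This is a hypothetical analogue of the C++ `Node`/`NodeEntry` classes for illustration, not MXNet's actual implementation:

```python
class NodeEntry:
    def __init__(self, node, entry):
        self.node = node    # the Node on the other end of the edge
        self.entry = entry  # the input/output index on that node

class Node:
    def __init__(self, op, name):
        self.op, self.name = op, name
        self.inputs, self.outputs = [], []

def connect(producer, out_idx, consumer, in_idx):
    # Record the edge on both endpoints, mirroring the two push_back calls.
    producer.outputs.append(NodeEntry(consumer, in_idx))
    consumer.inputs.append(NodeEntry(producer, out_idx))

n1 = Node("null", "data")
n2 = Node("FullyConnected", "fc1")
connect(n1, 0, n2, 1)  # n1's output 0 feeds n2's input 1
```

Because each edge is recorded on both endpoints, a pass can walk from any node to its producers or consumers without a separate adjacency index.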

### Pass Resource

Some graph passes require allocating new NDArrays to add or replace model params. The `alloc_arg` and `alloc_aux` APIs enable allocating new NDArrays and integrating them with the model's arg and aux params. Both APIs have the following signature:

```c++
MXTensor* alloc_xxx(const std::string& name,
const std::vector<int64_t>& shapes,
const MXContext &ctx,
MXDType dtype)
```

If the `name` provided matches the name of an existing param it replaces the previous one. Otherwise it adds a new param to the appropriate arg/aux set.
If the `name` provided matches the name of an existing param, it replaces the previous one. Otherwise it adds a new param to the appropriate arg/aux set. Be sure to also add a new node in the graph corresponding to this new param; otherwise it will never be used.

If you wish to remove an existing param, just remove the node in the graph corresponding to that param. It will be deleted after the pass completes and removed from the dictionary of args or aux (whichever it is a member of).
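The replace-or-add semantics of `alloc_arg`/`alloc_aux` can be sketched with a plain dict standing in for the arg/aux set. Here `alloc_param` and the tensor layout are hypothetical stand-ins for the real `MXTensor` allocation:

```python
from math import prod

def alloc_param(params, name, shape):
    """A name collision replaces the old param; otherwise a new one is added."""
    tensor = {"shape": tuple(shape), "data": [0.0] * prod(shape)}
    params[name] = tensor
    return tensor

# Existing model args, keyed by param name.
args = {"fc1_weight": {"shape": (2, 2), "data": [1.0] * 4}}

alloc_param(args, "fc1_weight", (4, 4))  # replaces the existing param
alloc_param(args, "test_arg", (1,))      # adds a new param
```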

### Parsing a JSON string

To simplify custom libraries, basic JSON parsing utility functions have been implemented in the `lib_api.h` header file. You parse a JSON string by calling the static `JsonVal::parse` function like:

```c++
JsonVal json_val = JsonVal::parse(json);
```

A `JsonVal` is a class that represents the nodes in a JSON structure. You can check the type of a node (num, str, list, or map) by comparing `JsonVal.type` to `NUM`, `STR`, `LIST`, or `MAP`. Then you can get that value from the node like:
@@ -187,4 +220,4 @@ switch(json_val.type) {
}
```

You call the `dump` function on a `JsonVal` object like `json_val.dump()` to get a JSON-compatible string. There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`.
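For comparison, here is the equivalent parse/access/dump cycle using Python's standard `json` module in place of `JsonVal` (an analogue only; the real pass code stays in C++):

```python
import json

json_string = '{"nodes": [{"op": "null", "name": "data"}]}'

val = json.loads(json_string)  # analogue of JsonVal::parse(json_string)
nodes = val["nodes"]           # analogue of json_val.map[JsonVal("nodes")]
dumped = json.dumps(val)       # analogue of json_val.dump()
```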
7 changes: 4 additions & 3 deletions example/extensions/lib_pass/pass_lib.cc
@@ -28,6 +28,8 @@
#include <algorithm>
#include "lib_api.h"

using namespace mxnet::ext;

/* \brief a basic pass that copies the input to the output */
MXReturnValue myPass(const std::string& in_graph, const std::string** out_graph,
const std::unordered_map<std::string, std::string>& options,
@@ -60,8 +62,7 @@ MXReturnValue jsonPass(const std::string& in_graph, const std::string** out_grap
MXTensor* aux_ = res.alloc_aux("test_aux",{1},MXContext::CPU(0),kFloat32);

// convert json string to json object
JsonVal json_val = JsonVal::parse(in_graph);

// get nodes list
JsonVal nodes = json_val.map[JsonVal("nodes")];
@@ -86,7 +87,7 @@ MXReturnValue jsonPass(const std::string& in_graph, const std::string** out_grap
}
}

*out_graph = new std::string(json_val.dump());
return MX_SUCCESS;
}
