This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

dynamic custom operator support #15921

Merged
merged 139 commits on Dec 6, 2019
Changes from 125 commits
Commits
139 commits
5030a65
fixed example to use absolute path
Aug 15, 2019
23a226a
added example for custom ops, added support for custom op registration
Aug 16, 2019
67c22c0
added fcompute registration for loaded operators
Aug 17, 2019
915c1d5
changed dynamic ops to be contrib
Aug 17, 2019
f568e3d
added num in/out
Aug 18, 2019
8e12588
removed contrib op registration
Aug 20, 2019
1e27a47
added support for infer shape, updated example to call operator
Aug 20, 2019
9aecf86
fixed whitespace
Aug 20, 2019
02deacf
fixed whitespace
Aug 20, 2019
cf9350d
fixed whitespace
Aug 20, 2019
ada3895
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Aug 23, 2019
38e77a5
added temporary support for operator multi-registration
Aug 23, 2019
7b8f6a2
insanity checked
Aug 23, 2019
5b817bd
update docblocks
rondogency Aug 23, 2019
3bccfbe
small format fix
rondogency Aug 23, 2019
a8c19c8
fix unittest with correct library
rondogency Aug 23, 2019
2f34471
implement InferType
rondogency Aug 27, 2019
3502aa9
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Aug 27, 2019
52e687b
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
Aug 27, 2019
5438a35
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Aug 27, 2019
592249a
initial support for resource manager, temp space
Aug 27, 2019
3186d60
fixed formatting
Aug 27, 2019
e8b413b
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Aug 29, 2019
bf549b4
changed decltype to typedef
Aug 29, 2019
439ee20
fixed whitespace
Aug 29, 2019
bba25db
Added windows declaration types, change APIs to return MXReturnValue …
Aug 29, 2019
a681f61
added library version number, API to get, and check to validate
Aug 29, 2019
711f9a3
Changed CMakeLists to build lib_ops instead of lib_api, updated lib_a…
Aug 29, 2019
172129f
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Aug 29, 2019
5af1736
add prototype of subgraph op
rondogency Aug 29, 2019
33d9cd7
implement FMutateInput as optional attribute
rondogency Aug 30, 2019
4576570
fix sanity check
rondogency Aug 30, 2019
6f3e3d9
replace fcompute to fcomputeEx and implement simple finferstoragetype
rondogency Sep 3, 2019
9587483
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Sep 4, 2019
34efb2b
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
Sep 4, 2019
ff9a868
changed fcompute to forward
Sep 4, 2019
0be218b
initial commit with fgradient support
Sep 4, 2019
570a059
enabled gradient registration
Sep 4, 2019
4b01932
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Sep 4, 2019
e4be175
fixed whitespace
Sep 4, 2019
8cfcc85
fixed example to use absolute path
Aug 15, 2019
9884ec6
added example for custom ops, added support for custom op registration
Aug 16, 2019
8e21600
added fcompute registration for loaded operators
Aug 17, 2019
794e30b
changed dynamic ops to be contrib
Aug 17, 2019
8fbf664
added num in/out
Aug 18, 2019
e7c6e8f
removed contrib op registration
Aug 20, 2019
6047378
added support for infer shape, updated example to call operator
Aug 20, 2019
d1587ab
fixed whitespace
Aug 20, 2019
0ee56c9
fixed whitespace
Aug 20, 2019
adc9770
fixed whitespace
Aug 20, 2019
5c06d47
added temporary support for operator multi-registration
Aug 23, 2019
9136839
insanity checked
Aug 23, 2019
ffe7623
update docblocks
rondogency Aug 23, 2019
435e01e
small format fix
rondogency Aug 23, 2019
0de79a9
fix unittest with correct library
rondogency Aug 23, 2019
0d6f7b0
implement InferType
rondogency Aug 27, 2019
18b028e
initial support for resource manager, temp space
Aug 27, 2019
a4690b4
fixed formatting
Aug 27, 2019
c901828
changed decltype to typedef
Aug 29, 2019
5ddb919
fixed whitespace
Aug 29, 2019
7b4c4e6
Added windows declaration types, change APIs to return MXReturnValue …
Aug 29, 2019
18117ec
added library version number, API to get, and check to validate
Aug 29, 2019
ee65419
Changed CMakeLists to build lib_ops instead of lib_api, updated lib_a…
Aug 29, 2019
c66438c
add prototype of subgraph op
rondogency Aug 29, 2019
698a0b6
implement FMutateInput as optional attribute
rondogency Aug 30, 2019
bd55612
fix sanity check
rondogency Aug 30, 2019
35ff973
replace fcompute to fcomputeEx and implement simple finferstoragetype
rondogency Sep 3, 2019
f243e2f
changed fcompute to forward
Sep 4, 2019
efbb858
initial commit with fgradient support
Sep 4, 2019
0032143
enabled gradient registration
Sep 4, 2019
14ef3a7
fixed whitespace
Sep 4, 2019
eec71d6
prototype of createopstate and fstatefulcompute
rondogency Sep 6, 2019
abcb8cb
make custom state op interface work
rondogency Sep 6, 2019
9cf0455
subgraph forward
rondogency Sep 9, 2019
82f1bff
refactor stateful forward and add op resource
rondogency Sep 10, 2019
f7ff481
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Sep 10, 2019
ba563d2
wip gemm backward
Sep 10, 2019
a9b7215
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
Sep 10, 2019
7bf4f7a
stateful backward and subgraph test
rondogency Sep 11, 2019
8aec7ac
implement gemm and state gemm, refactor test files
rondogency Sep 12, 2019
39e3d6b
add body to pure virtual destructor
rondogency Sep 12, 2019
b3ba028
subgraph passing from python to custom lib
rondogency Sep 23, 2019
c9d8498
Merge branch 'master' into dynamic_op
rondogency Sep 24, 2019
1686273
rm lib_api c++11 dep, rm warpctc, add rm flag
rondogency Sep 26, 2019
7009ad4
fix conflict
rondogency Sep 26, 2019
4b73179
subgraph json parsing utility
rondogency Sep 28, 2019
dca521e
add data size and fix unsigned warnings
rondogency Sep 29, 2019
baed04e
use c++ struct and fix cpplint
rondogency Sep 30, 2019
aedcf91
refactor op registry
rondogency Sep 30, 2019
75102a3
fix line length and win array of ci; condense lines
rondogency Sep 30, 2019
9c29deb
Merge remote-tracking branch 'upstream/master' into dynamic_op
rondogency Sep 30, 2019
c5a3ed6
add mxnet_extension dir
rondogency Oct 1, 2019
44683f1
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Oct 1, 2019
d1b6c8e
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
Oct 1, 2019
44affc7
fixed extension to be dll for windows
Oct 1, 2019
ef1d4cf
updated examples to use the same format as the example in the top-lev…
Oct 1, 2019
24d8cc3
removed destructor for CustomStatefulOp
Oct 1, 2019
279a989
fix error in gemm test and clear up subgraph test
rondogency Oct 2, 2019
5db9e97
merge with dynamic_op
rondogency Oct 2, 2019
75b1169
lib path fix
rondogency Oct 2, 2019
79c0e3a
add unittest for custom op
rondogency Oct 2, 2019
11d3344
update Makefile revolve merge
rondogency Oct 2, 2019
de157a8
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Oct 2, 2019
28450b5
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
Oct 2, 2019
9504b33
fix test and rename folder
rondogency Oct 2, 2019
cf27d57
fix makefile rename
rondogency Oct 2, 2019
7f456d4
fix cmake rename
rondogency Oct 2, 2019
e50819b
add explicit cpu context
rondogency Oct 2, 2019
5984f3a
Merge remote-tracking branch 'upstream/master' into dynamic_op
rondogency Oct 2, 2019
bd2c3a0
wkcn feedback: change mxtensor func name. use c++11 flag
rondogency Oct 3, 2019
b07e46b
add operator keyward test and refine info print
rondogency Oct 3, 2019
2466d67
using typedef in forward
rondogency Oct 3, 2019
e041400
small refine of docblock
rondogency Oct 4, 2019
f16942c
change names
rondogency Oct 8, 2019
50a6b64
add separate stateful compute and pass state_op ptr
rondogency Oct 8, 2019
adb0415
user example using opresource alloc
rondogency Oct 14, 2019
6148ef8
Merge remote-tracking branch 'upstream/master' into dynamic_op
rondogency Oct 14, 2019
6d9ac54
Merge remote-tracking branch 'upstream/master' into dynamic_op
rondogency Oct 15, 2019
7c256cd
Merge remote-tracking branch 'upstream/master' into dynamic_op
rondogency Oct 15, 2019
6e824fb
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Oct 18, 2019
40c471b
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
Oct 18, 2019
dfb5946
added DLTensor into MXTensor
Oct 18, 2019
5146fd5
fixed whitespace
Oct 18, 2019
1b9fee2
added error check when DLTensor does not support MXNet data type
Oct 18, 2019
5761891
changed to throw runtime exception
Oct 18, 2019
ef840b4
changed include to stdexcept
Oct 18, 2019
bba61b3
retrigger CI
wkcn Oct 18, 2019
53d18ec
empty commit
Oct 18, 2019
e0c778c
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
Oct 18, 2019
141328f
empty commit
Oct 18, 2019
deacae2
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Oct 18, 2019
2b2c6a4
Merge branch 'master' into dynamic_op
szha Oct 22, 2019
ed8ac16
remove merge conflict
rondogency Oct 23, 2019
56b0e28
add setdltensor for easy use and add docs
rondogency Oct 23, 2019
1bd166e
Merge branch 'master' into dynamic_op
wkcn Nov 26, 2019
50c8aea
CI
wkcn Nov 26, 2019
34a9ee9
re-trigger CI
wkcn Nov 28, 2019
9910c39
ci
wkcn Dec 5, 2019
5fd4314
ci
wkcn Dec 6, 2019
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -698,7 +698,7 @@ else()

endif()

add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/lib_api/mylib.cc)
add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
target_include_directories(sample_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
set(MXNET_INSTALL_TARGETS mxnet)
if(UNIX)
10 changes: 6 additions & 4 deletions Makefile
@@ -672,6 +672,10 @@ cpplint:
pylint:
python3 -m pylint --rcfile=$(ROOTDIR)/ci/other/pylintrc --ignore-patterns=".*\.so$$,.*\.dll$$,.*\.dylib$$" python/mxnet tools/caffe_converter/*.py

# sample lib for MXNet extension dynamically loading custom operator
sample_lib:
$(CXX) -shared -fPIC -std=c++11 example/extensions/lib_custom_op/gemm_lib.cc -o libsample_lib.so -I include/mxnet

# Cython build
cython:
cd python; $(PYTHON) setup.py build_ext --inplace --with-cython
@@ -721,10 +725,6 @@ rpkgtest:
Rscript -e 'require(testthat);res<-test_dir("R-package/tests/testthat");if(!testthat:::all_passed(res)){stop("Test failures", call. = FALSE)}'
Rscript -e 'res<-covr:::package_coverage("R-package");fileConn<-file(paste("r-package_coverage_",toString(runif(1)),".json"));writeLines(covr:::to_codecov(res), fileConn);close(fileConn)'


sample_lib:
$(CXX) -shared -fPIC example/lib_api/mylib.cc -o libsample_lib.so -I include/mxnet

scalaclean:
(cd $(ROOTDIR)/scala-package && mvn clean)

@@ -776,6 +776,7 @@ clean: rclean cyclean $(EXTRA_PACKAGES_CLEAN)
cd $(NNVM_PATH); $(MAKE) clean; cd -
cd $(TVM_PATH); $(MAKE) clean; cd -
cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
$(RM) libsample_lib.so
$(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS))
$(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS))
else
@@ -786,6 +787,7 @@ clean: rclean mkldnn_clean cyclean testclean $(EXTRA_PACKAGES_CLEAN)
cd $(PS_PATH); $(MAKE) clean; cd -
cd $(NNVM_PATH); $(MAKE) clean; cd -
cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
$(RM) libsample_lib.so
endif

clean_all: clean
@@ -16,16 +16,16 @@
# under the License.

all:
g++ -shared -fPIC mylib.cc -o mylib.so -I ../../include/mxnet
g++ -std=c++11 -shared -fPIC init_lib.cc -o libinit_lib.so -I ../../../include/mxnet

test:
g++ -std=c++11 -O3 -o libtest libtest.cc -ldl -I ../../include/mxnet
g++ -std=c++11 -O3 -o libtest libtest.cc -ldl -I ../../../include/mxnet

windows:
cl /LD mylib.cc
cl /LD init_lib.cc

win_test:
cl libtest.cc

clean:
rm -rf mylib.so libtest
rm -rf *.so libtest
@@ -19,19 +19,19 @@

/*!
* Copyright (c) 2015 by Contributors
* \file mylib.cc
* \file init_lib.cc
* \brief Sample library file
*/

#include <iostream>
#include "lib_api.h"

int initialize(int version) {
MXReturnValue initialize(int version) {
if (version >= 10400) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return 1;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
return 0;
return MX_FAIL;
}
}
@@ -40,10 +40,10 @@ int main(void) {
// Get a handle to the library.
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
HINSTANCE handle;
handle = LoadLibrary(TEXT("mylib.dll"));
handle = LoadLibrary(TEXT("libinit_lib.dll"));
#else
void *handle;
handle = dlopen("mylib.so", RTLD_LAZY);
handle = dlopen("libinit_lib.so", RTLD_LAZY);
#endif

if (!handle) {
@@ -26,6 +26,8 @@
import os

if (os.name=='posix'):
mx.library.load('mylib.so')
path = os.path.abspath('libinit_lib.so')
mx.library.load(path)
elif (os.name=='nt'):
mx.library.load('mylib.dll')
path = os.path.abspath('libinit_lib.dll')
mx.library.load(path)
27 changes: 27 additions & 0 deletions example/extensions/lib_custom_op/Makefile
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

all: subgraph_lib gemm_lib

gemm_lib:
g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet

subgraph_lib:
g++ -shared -fPIC -std=c++11 subgraph_lib.cc -o libsubgraph_lib.so -I ../../../include/mxnet

clean:
rm -rf libsubgraph_lib.so libgemm_lib.so
233 changes: 233 additions & 0 deletions example/extensions/lib_custom_op/gemm_lib.cc
@@ -0,0 +1,233 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file gemm_lib.cc
* \brief Sample 2D gemm custom operator implementation library file
*/

#include <iostream>
#include "lib_api.h"

// main matrix multiplication routine
void gemm(const float* A, const float* B, float* C,
const unsigned n, const unsigned k, const unsigned m) {
unsigned i, j, kk;
for (i = 0; i < n; i++) {
for (j = 0; j < m; j++) {
C[i*m+j] = 0;
for (kk = 0; kk < k; kk++) {
C[i*m+j] += A[i*k+kk] * B[kk*m+j];
}
}
}
}
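As a side note, the flat row-major indexing used in the gemm routine above (`C[i*m+j] += A[i*k+kk] * B[kk*m+j]`) can be sketched in plain Python; this is an illustrative check only, not part of the PR:

```python
# Hypothetical standalone check of the row-major gemm indexing used above.
def gemm(A, B, n, k, m):
    # A is n*k, B is k*m, C is n*m; all are flat row-major lists
    C = [0.0] * (n * m)
    for i in range(n):
        for j in range(m):
            for kk in range(k):
                C[i*m + j] += A[i*k + kk] * B[kk*m + j]
    return C

# [[1,2],[3,4]] x [[5,6],[7,8]] = [[19,22],[43,50]]
print(gemm([1, 2, 3, 4], [5, 6, 7, 8], 2, 2, 2))
```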

void transpose(const float* A, float* At, const unsigned n, const unsigned m) {
unsigned i, j;
for (i = 0; i < n; i++) {
for (j = 0; j < m; j++) {
At[i*m+j] = A[j*n+i];
}
}
}

/*
* Executes C = A * B
* inputs[0] = A; inputs[1] = B; outputs[0] = C
*/
MXReturnValue forward(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
// simple example of using runtime data type
if (inputs[0].dtype == kFloat32) {
typedef float DType;
// extract data pointers from tensors
DType* A = inputs[0].data<DType>();
DType* B = inputs[1].data<DType>();
DType* C = outputs[0].data<DType>();
// set tensor shapes
unsigned n = inputs[0].shape[0];
unsigned k = inputs[0].shape[1];
unsigned m = inputs[1].shape[1];

gemm(A, B, C, n, k, m);
}
return MX_SUCCESS;
}

/*
* Executes dA = dC * B.T; Executes dB = A.T * dC
***** gradient inputs
* inputs[0] = dC
***** original inputs
* inputs[1] = A; inputs[2] = B
***** original outputs
* inputs[3] = C
***** gradient outputs
* outputs[0] = dA; outputs[1] = dB
*/
MXReturnValue backward(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
// extract data pointers from tensors
float* dC = inputs[0].data<float>();
float* A = inputs[1].data<float>();
float* B = inputs[2].data<float>();
float* dA = outputs[0].data<float>();
float* dB = outputs[1].data<float>();
// set tensor shapes
unsigned n = inputs[1].shape[0];
unsigned k = inputs[1].shape[1];
unsigned m = inputs[2].shape[1];
// allocate temporary workspace memory through resource manager
// for multiple arrays better to request a big memory pool
void *workspace = res.alloc((k*n + m*k) * sizeof(float));
float *At = static_cast<float*>(workspace);
float *Bt = static_cast<float*>(workspace) + (k*n);

transpose(A, At, k, n);
transpose(B, Bt, m, k);
gemm(dC, Bt, dA, n, m, k);
gemm(At, dC, dB, k, n, m);

return MX_SUCCESS;
}
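The gradient formulas used in backward above (dA = dC * B.T and dB = A.T * dC) can be spot-checked numerically against a finite difference; a minimal Python sketch, illustrative only and not part of the PR, using sum(C) as the loss so that dC is all ones:

```python
# Hypothetical numeric spot-check of dA = dC * B^T for C = A * B.
def matmul(X, Y):
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

def transpose(X):
    return [list(col) for col in zip(*X)]

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
dC = [[1.0, 1.0], [1.0, 1.0]]          # d(sum(C))/dC is all ones
dA = matmul(dC, transpose(B))          # analytic gradient w.r.t. A

def loss(M):
    return sum(sum(row) for row in matmul(M, B))

# finite-difference estimate of d(sum(C))/dA[0][0]
eps = 1e-5
Ap = [row[:] for row in A]
Ap[0][0] += eps
numeric = (loss(Ap) - loss(A)) / eps
assert abs(numeric - dA[0][0]) < 1e-3  # both are about 11.0
```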

MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
*num_in = 2;
*num_out = 1;
return MX_SUCCESS;
}

MXReturnValue inferType(std::map<std::string, std::string> attrs,
std::vector<int> &intypes,
std::vector<int> &outtypes) {
// validate inputs
if (intypes.size() != 2) {
std::cout << "Expected 2 inputs to inferType" << std::endl;
return MX_FAIL;
}
for (unsigned i = 0; i < intypes.size(); i++) {
if (intypes[i] != kFloat32) {
std::cout << "Expected input " << i << " to have float32 type" << std::endl;
return MX_FAIL;
}
}

outtypes[0] = intypes[0];
return MX_SUCCESS;
}

MXReturnValue inferShape(std::map<std::string, std::string> attrs,
std::vector<std::vector<unsigned int>> &inshapes,
std::vector<std::vector<unsigned int>> &outshapes) {
// validate inputs
if (inshapes.size() != 2) {
std::cout << "Expected 2 inputs to inferShape" << std::endl;
return MX_FAIL;
}
if (inshapes[0].size() != 2 || inshapes[1].size() != 2) {
std::cout << "Expected 2D matrices for both inputs to inferShape" << std::endl;
return MX_FAIL;
}

unsigned n = inshapes[0][0];
unsigned k = inshapes[0][1];
unsigned kk = inshapes[1][0];
unsigned m = inshapes[1][1];
if (k != kk) {
std::cout << "Expected first input axis 1 to equal second input axis 0" << std::endl;
return MX_FAIL;
}

outshapes[0] = {n, m};
return MX_SUCCESS;
}

REGISTER_OP(my_gemm)
.setForward(forward)
.setBackward(backward)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape);

/* ------------------------------------------------------------------------- */

class MyStatefulGemm : public CustomStatefulOp {
public:
explicit MyStatefulGemm(int count) : count(count) {}

MXReturnValue Forward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
++count;
std::cout << "Info: keyword + number of forward: " << count << std::endl;
std::map<std::string, std::string> attrs;
return forward(attrs, inputs, outputs, op_res);
}

MXReturnValue Backward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
return backward(attrs, inputs, outputs, op_res);
}

~MyStatefulGemm() {}

private:
int count;
};

MXReturnValue createOpState(std::map<std::string, std::string> attrs,
CustomStatefulOp** op_inst) {
int count = 0;
if (attrs.count("test_kw") > 0)
count = std::stoi(attrs["test_kw"]);
*op_inst = new MyStatefulGemm(count);
std::cout << "Info: stateful operator created" << std::endl;
return MX_SUCCESS;
}

MXReturnValue mutateInputs(std::map<std::string, std::string> attrs,
std::vector<int> &input_indices) {
// input_indices.push_back(1); // mark mutate input
return MX_SUCCESS;
}

REGISTER_OP(state_gemm)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setMutateInputs(mutateInputs)
.setCreateOpState(createOpState);

MXReturnValue initialize(int version) {
if (version >= 10400) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
return MX_FAIL;
}
}
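For intuition, the stateful-op lifecycle shown above — createOpState seeding a counter from the "test_kw" attribute, then each Forward call incrementing it — can be mimicked in plain Python. This is a hypothetical analogue for illustration only, not the actual MXNet extension API:

```python
# Hypothetical Python analogue of MyStatefulGemm / createOpState above.
class MyStatefulGemm:
    def __init__(self, count):
        self.count = count

    def forward(self):
        # mirrors Forward(): bump the per-instance counter on every call
        self.count += 1
        return self.count

def create_op_state(attrs):
    # mirrors createOpState(): seed the counter from the "test_kw" attribute
    count = int(attrs.get("test_kw", "0"))
    return MyStatefulGemm(count)

op = create_op_state({"test_kw": "3"})
print(op.forward())  # 4
```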