From cdf1c5dcf752ca23263eab9132ade8b0edd1d0e5 Mon Sep 17 00:00:00 2001 From: Eric Junyuan Xie Date: Wed, 12 Jul 2017 10:04:40 -0700 Subject: [PATCH] Refactor Stateful operator and custom op (#6928) * refactor create layer * fix * refactor custom op * fix * fix * fix * fix * fix OpState * remove superfluous infershape * fix * fix * fix lint * fix * fix * fix * Update CMakeLists.txt * delete * fix * fix scala --- Jenkinsfile | 12 +- cpp-package/include/mxnet-cpp/MxNetCpp.h | 6 +- cpp-package/include/mxnet-cpp/base.h | 6 +- cpp-package/include/mxnet-cpp/executor.h | 6 +- cpp-package/include/mxnet-cpp/executor.hpp | 6 +- cpp-package/include/mxnet-cpp/initializer.h | 6 +- cpp-package/include/mxnet-cpp/io.h | 6 +- cpp-package/include/mxnet-cpp/io.hpp | 6 +- cpp-package/include/mxnet-cpp/kvstore.h | 6 +- cpp-package/include/mxnet-cpp/kvstore.hpp | 6 +- cpp-package/include/mxnet-cpp/metric.h | 6 +- cpp-package/include/mxnet-cpp/model.h | 6 +- cpp-package/include/mxnet-cpp/monitor.h | 6 +- cpp-package/include/mxnet-cpp/monitor.hpp | 6 +- cpp-package/include/mxnet-cpp/ndarray.h | 6 +- cpp-package/include/mxnet-cpp/ndarray.hpp | 6 +- cpp-package/include/mxnet-cpp/op_map.h | 6 +- cpp-package/include/mxnet-cpp/op_suppl.h | 6 +- cpp-package/include/mxnet-cpp/op_util.h | 6 +- cpp-package/include/mxnet-cpp/operator.h | 6 +- cpp-package/include/mxnet-cpp/operator.hpp | 6 +- cpp-package/include/mxnet-cpp/optimizer.h | 6 +- cpp-package/include/mxnet-cpp/optimizer.hpp | 6 +- cpp-package/include/mxnet-cpp/shape.h | 6 +- cpp-package/include/mxnet-cpp/symbol.h | 6 +- cpp-package/include/mxnet-cpp/symbol.hpp | 6 +- cpp-package/scripts/OpWrapperGenerator.py | 6 +- dmlc-core | 2 +- include/mxnet/base.h | 6 + include/mxnet/c_api.h | 8 +- include/mxnet/op_attr_types.h | 160 +++++- include/mxnet/operator.h | 70 +-- python/mxnet/ndarray.py | 4 +- python/mxnet/operator.py | 2 +- .../scala/ml/dmlc/mxnet/OperatorSuite.scala | 2 + src/c_api/c_api.cc | 2 +- src/c_api/c_api_ndarray.cc | 106 ++-- src/common/utils.h | 17 + src/engine/naive_engine.cc | 8 +- src/engine/stream_manager.h | 14 +- src/engine/threaded_engine_perdevice.cc | 38 +- src/executor/attach_op_execs_pass.cc | 228 +++++---- src/executor/exec_pass.h | 6 +- src/executor/graph_executor.cc | 30 +- src/executor/graph_executor.h | 7 +- src/ndarray/autograd.cc | 18 +- src/ndarray/autograd.h | 8 +- src/ndarray/ndarray.cc | 6 +- src/ndarray/ndarray_function.cu | 8 +- src/nnvm/legacy_op_util.cc | 116 ++++- src/operator/activation.cc | 4 - src/operator/batch_norm.cc | 16 - src/operator/bilinear_sampler.cc | 4 - src/operator/convolution.cc | 3 - src/operator/cross_device_copy.cc | 12 +- src/operator/custom/custom-inl.h | 270 +--------- src/operator/custom/custom.cc | 463 +++++++++++++----- src/operator/custom/ndarray_op-inl.h | 8 +- src/operator/deconvolution.cc | 2 - src/operator/dropout-inl.h | 4 +- src/operator/dropout.cc | 4 - src/operator/fully_connected.cc | 2 - src/operator/grid_generator.cc | 4 - src/operator/instance_norm.cc | 4 - src/operator/lrn.cc | 14 +- src/operator/pad.cc | 4 - src/operator/pooling.cc | 4 - src/operator/rnn.cc | 4 - src/operator/roi_pooling.cc | 4 - src/operator/sequence_last.cc | 4 - src/operator/sequence_mask.cc | 4 - src/operator/sequence_reverse.cc | 4 - src/operator/softmax_output.cc | 4 - src/operator/spatial_transformer.cc | 4 - src/operator/svm_output.cc | 5 - src/operator/swapaxis.cc | 4 - src/operator/upsampling.cc | 4 - tests/cpp/include/test_op.h | 6 +- tests/cpp/include/test_perf.h | 6 +- tests/cpp/include/test_util.h | 6 +- 
tests/python/unittest/test_operator.py | 16 +- 81 files changed, 1012 insertions(+), 919 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 48f6251a0be2..26a96a4843bc 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -215,9 +215,11 @@ del /Q *.7z // Python unittest for CPU def python_ut(docker_type) { timeout(time: max_time, unit: 'MINUTES') { + sh "${docker_run} ${docker_type} find . -name '*.pyc' -type f -delete" sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/unittest" - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest" sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/train" + sh "${docker_run} ${docker_type} find . -name '*.pyc' -type f -delete" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest" } } @@ -225,7 +227,9 @@ def python_ut(docker_type) { // both CPU and GPU def python_gpu_ut(docker_type) { timeout(time: max_time, unit: 'MINUTES') { + sh "${docker_run} ${docker_type} find . -name '*.pyc' -type f -delete" sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/gpu" + sh "${docker_run} ${docker_type} find . -name '*.pyc' -type f -delete" sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/gpu" } } @@ -312,11 +316,13 @@ stage('Unit Test') { xcopy C:\\mxnet\\model model /E /I /Y call activate py3 set PYTHONPATH=${env.WORKSPACE}\\pkg_vc14_cpu\\python +del /S /Q ${env.WORKSPACE}\\pkg_vc14_cpu\\python\\*.pyc C:\\mxnet\\test_cpu.bat""" bat """xcopy C:\\mxnet\\data data /E /I /Y xcopy C:\\mxnet\\model model /E /I /Y call activate py2 set PYTHONPATH=${env.WORKSPACE}\\pkg_vc14_cpu\\python +del /S /Q ${env.WORKSPACE}\\pkg_vc14_cpu\\python\\*.pyc C:\\mxnet\\test_cpu.bat""" } } @@ -332,11 +338,13 @@ C:\\mxnet\\test_cpu.bat""" xcopy C:\\mxnet\\model model /E /I /Y call activate py3 set PYTHONPATH=${env.WORKSPACE}\\pkg_vc14_gpu\\python +del /S /Q ${env.WORKSPACE}\\pkg_vc14_gpu\\python\\*.pyc C:\\mxnet\\test_gpu.bat""" bat """xcopy C:\\mxnet\\data data /E /I /Y xcopy C:\\mxnet\\model model /E /I /Y call activate py2 set PYTHONPATH=${env.WORKSPACE}\\pkg_vc14_gpu\\python +del /S /Q ${env.WORKSPACE}\\pkg_vc14_gpu\\python\\*.pyc C:\\mxnet\\test_gpu.bat""" } } @@ -390,4 +398,4 @@ stage('Deploy') { } } } -} \ No newline at end of file +} diff --git a/cpp-package/include/mxnet-cpp/MxNetCpp.h b/cpp-package/include/mxnet-cpp/MxNetCpp.h index 8ed90e3c751a..5d61b823baa2 100644 --- a/cpp-package/include/mxnet-cpp/MxNetCpp.h +++ b/cpp-package/include/mxnet-cpp/MxNetCpp.h @@ -5,8 +5,8 @@ * \author Chuntao Hong, Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_MXNETCPP_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_MXNETCPP_H_ +#ifndef MXNET_CPP_MXNETCPP_H_ +#define MXNET_CPP_MXNETCPP_H_ #include "mxnet-cpp/executor.hpp" #include "mxnet-cpp/symbol.hpp" @@ -21,4 +21,4 @@ #include "mxnet-cpp/metric.h" #include "mxnet-cpp/initializer.h" -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_MXNETCPP_H_ +#endif // MXNET_CPP_MXNETCPP_H_ diff --git a/cpp-package/include/mxnet-cpp/base.h b/cpp-package/include/mxnet-cpp/base.h index 18f268a8a85a..b684986a6f54 100644 --- a/cpp-package/include/mxnet-cpp/base.h +++ b/cpp-package/include/mxnet-cpp/base.h @@ -5,8 +5,8 @@ * \author Chuntao Hong, Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_BASE_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_BASE_H_ +#ifndef 
MXNET_CPP_BASE_H_ +#define MXNET_CPP_BASE_H_ #include #include "mxnet/c_api.h" @@ -35,4 +35,4 @@ enum OpReqType { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_BASE_H_ +#endif // MXNET_CPP_BASE_H_ diff --git a/cpp-package/include/mxnet-cpp/executor.h b/cpp-package/include/mxnet-cpp/executor.h index e4343a19a50d..822344b7efee 100644 --- a/cpp-package/include/mxnet-cpp/executor.h +++ b/cpp-package/include/mxnet-cpp/executor.h @@ -5,8 +5,8 @@ * \author Chuntao Hong, Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_EXECUTOR_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_EXECUTOR_H_ +#ifndef MXNET_CPP_EXECUTOR_H_ +#define MXNET_CPP_EXECUTOR_H_ #include #include @@ -135,4 +135,4 @@ class Executor { }; } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_EXECUTOR_H_ +#endif // MXNET_CPP_EXECUTOR_H_ diff --git a/cpp-package/include/mxnet-cpp/executor.hpp b/cpp-package/include/mxnet-cpp/executor.hpp index 4cae684f8881..1a452a1610db 100644 --- a/cpp-package/include/mxnet-cpp/executor.hpp +++ b/cpp-package/include/mxnet-cpp/executor.hpp @@ -5,8 +5,8 @@ * \author Zhang Chen, Chuntao Hong */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_EXECUTOR_HPP_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_EXECUTOR_HPP_ +#ifndef MXNET_CPP_EXECUTOR_HPP_ +#define MXNET_CPP_EXECUTOR_HPP_ #include #include @@ -89,4 +89,4 @@ inline void Executor::UpdateAll(Optimizer *opt, float lr, float wd, } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_EXECUTOR_HPP_ +#endif // MXNET_CPP_EXECUTOR_HPP_ diff --git a/cpp-package/include/mxnet-cpp/initializer.h b/cpp-package/include/mxnet-cpp/initializer.h index 843965256df1..f28656577482 100644 --- a/cpp-package/include/mxnet-cpp/initializer.h +++ b/cpp-package/include/mxnet-cpp/initializer.h @@ -5,8 +5,8 @@ * \author Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_INITIALIZER_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_INITIALIZER_H_ +#ifndef MXNET_CPP_INITIALIZER_H_ +#define MXNET_CPP_INITIALIZER_H_ #include #include @@ -179,4 +179,4 @@ class Xavier : public Initializer { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_INITIALIZER_H_ +#endif // MXNET_CPP_INITIALIZER_H_ diff --git a/cpp-package/include/mxnet-cpp/io.h b/cpp-package/include/mxnet-cpp/io.h index 171803831109..727a96467c63 100644 --- a/cpp-package/include/mxnet-cpp/io.h +++ b/cpp-package/include/mxnet-cpp/io.h @@ -4,8 +4,8 @@ * \brief definition of io, such as DataIter * \author Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_IO_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_IO_H_ +#ifndef MXNET_CPP_IO_H_ +#define MXNET_CPP_IO_H_ #include #include @@ -124,5 +124,5 @@ class MXDataIter : public DataIter { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_IO_H_ +#endif // MXNET_CPP_IO_H_ diff --git a/cpp-package/include/mxnet-cpp/io.hpp b/cpp-package/include/mxnet-cpp/io.hpp index 61e575e949a9..1be7993fbe4f 100644 --- a/cpp-package/include/mxnet-cpp/io.hpp +++ b/cpp-package/include/mxnet-cpp/io.hpp @@ -4,8 +4,8 @@ * \brief implementation of data iter * \author Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_IO_HPP_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_IO_HPP_ +#ifndef MXNET_CPP_IO_HPP_ +#define MXNET_CPP_IO_HPP_ #include #include @@ -86,5 +86,5 @@ inline MXDataIter MXDataIter::CreateDataIter() { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_IO_HPP_ +#endif // MXNET_CPP_IO_HPP_ diff --git a/cpp-package/include/mxnet-cpp/kvstore.h 
b/cpp-package/include/mxnet-cpp/kvstore.h index 6d3987ecf030..9bb33a4733dd 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.h +++ b/cpp-package/include/mxnet-cpp/kvstore.h @@ -5,8 +5,8 @@ * \author Chuntao Hong */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_KVSTORE_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_KVSTORE_H_ +#ifndef MXNET_CPP_KVSTORE_H_ +#define MXNET_CPP_KVSTORE_H_ #include #include @@ -46,4 +46,4 @@ class KVStore { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_KVSTORE_H_ +#endif // MXNET_CPP_KVSTORE_H_ diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index d9effcf82f3c..4f66c1d637a5 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -14,8 +14,8 @@ #include "mxnet-cpp/kvstore.h" #include "mxnet-cpp/optimizer.h" -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_KVSTORE_HPP_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_KVSTORE_HPP_ +#ifndef MXNET_CPP_KVSTORE_HPP_ +#define MXNET_CPP_KVSTORE_HPP_ namespace mxnet { namespace cpp { @@ -175,4 +175,4 @@ inline std::string KVStore::GetRole() { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_KVSTORE_HPP_ +#endif // MXNET_CPP_KVSTORE_HPP_ diff --git a/cpp-package/include/mxnet-cpp/metric.h b/cpp-package/include/mxnet-cpp/metric.h index 24b3d73bae00..eda927199ca8 100644 --- a/cpp-package/include/mxnet-cpp/metric.h +++ b/cpp-package/include/mxnet-cpp/metric.h @@ -5,8 +5,8 @@ * \author Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_METRIC_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_METRIC_H_ +#ifndef MXNET_CPP_METRIC_H_ +#define MXNET_CPP_METRIC_H_ #include #include @@ -187,5 +187,5 @@ class PSNR : public EvalMetric { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_METRIC_H_ +#endif // MXNET_CPP_METRIC_H_ diff --git a/cpp-package/include/mxnet-cpp/model.h b/cpp-package/include/mxnet-cpp/model.h index 7bfe1980f095..e4cb1a9aee95 100644 --- a/cpp-package/include/mxnet-cpp/model.h +++ b/cpp-package/include/mxnet-cpp/model.h @@ -5,8 +5,8 @@ * \author Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_MODEL_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_MODEL_H_ +#ifndef MXNET_CPP_MODEL_H_ +#define MXNET_CPP_MODEL_H_ #include #include @@ -54,5 +54,5 @@ class FeedForward { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_MODEL_H_ +#endif // MXNET_CPP_MODEL_H_ diff --git a/cpp-package/include/mxnet-cpp/monitor.h b/cpp-package/include/mxnet-cpp/monitor.h index 2ce4e9590794..afe030cbd5d8 100644 --- a/cpp-package/include/mxnet-cpp/monitor.h +++ b/cpp-package/include/mxnet-cpp/monitor.h @@ -5,8 +5,8 @@ * \author Xin Li */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_H_ +#ifndef MXNET_CPP_MONITOR_H_ +#define MXNET_CPP_MONITOR_H_ #include #include @@ -85,4 +85,4 @@ class Monitor { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_H_ +#endif // MXNET_CPP_MONITOR_H_ diff --git a/cpp-package/include/mxnet-cpp/monitor.hpp b/cpp-package/include/mxnet-cpp/monitor.hpp index d37652dd2c05..eef218bff41d 100644 --- a/cpp-package/include/mxnet-cpp/monitor.hpp +++ b/cpp-package/include/mxnet-cpp/monitor.hpp @@ -5,8 +5,8 @@ * \author Xin Li */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_HPP_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_HPP_ +#ifndef MXNET_CPP_MONITOR_HPP_ +#define MXNET_CPP_MONITOR_HPP_ #include #include @@ -103,4 +103,4 @@ inline void 
Monitor::executor_callback(const char *name, NDArrayHandle handle, } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_HPP_ +#endif // MXNET_CPP_MONITOR_HPP_ diff --git a/cpp-package/include/mxnet-cpp/ndarray.h b/cpp-package/include/mxnet-cpp/ndarray.h index f908b4ff38eb..52451faa94cc 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.h +++ b/cpp-package/include/mxnet-cpp/ndarray.h @@ -5,8 +5,8 @@ * \author Chuntao Hong, Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_NDARRAY_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_NDARRAY_H_ +#ifndef MXNET_CPP_NDARRAY_H_ +#define MXNET_CPP_NDARRAY_H_ #include #include @@ -428,4 +428,4 @@ std::ostream& operator<<(std::ostream& out, const NDArray &ndarray); } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_NDARRAY_H_ +#endif // MXNET_CPP_NDARRAY_H_ diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index 6157a6600cb4..ba0954b3f815 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -5,8 +5,8 @@ * \author Zhang Chen, Chuntao Hong */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_NDARRAY_HPP_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_NDARRAY_HPP_ +#ifndef MXNET_CPP_NDARRAY_HPP_ +#define MXNET_CPP_NDARRAY_HPP_ #include #include @@ -378,4 +378,4 @@ inline std::ostream & operator<<(std::ostream &out, const NDArray &ndarray) { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_NDARRAY_HPP_ +#endif // MXNET_CPP_NDARRAY_HPP_ diff --git a/cpp-package/include/mxnet-cpp/op_map.h b/cpp-package/include/mxnet-cpp/op_map.h index 2a2ae50a4e84..ea75a8ca7b4c 100644 --- a/cpp-package/include/mxnet-cpp/op_map.h +++ b/cpp-package/include/mxnet-cpp/op_map.h @@ -5,8 +5,8 @@ * \author Chuntao Hong */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_MAP_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_MAP_H_ +#ifndef MXNET_CPP_OP_MAP_H_ +#define MXNET_CPP_OP_MAP_H_ #include #include @@ -89,4 +89,4 @@ class OpMap { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_MAP_H_ +#endif // MXNET_CPP_OP_MAP_H_ diff --git a/cpp-package/include/mxnet-cpp/op_suppl.h b/cpp-package/include/mxnet-cpp/op_suppl.h index c40449cc9f89..b66521bc0654 100644 --- a/cpp-package/include/mxnet-cpp/op_suppl.h +++ b/cpp-package/include/mxnet-cpp/op_suppl.h @@ -5,8 +5,8 @@ * \author Zhang Chen, zhubuntu, Xin Li */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_SUPPL_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_SUPPL_H_ +#ifndef MXNET_CPP_OP_SUPPL_H_ +#define MXNET_CPP_OP_SUPPL_H_ #include #include @@ -157,5 +157,5 @@ inline Symbol Activation(const std::string& symbol_name, } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_SUPPL_H_ +#endif // MXNET_CPP_OP_SUPPL_H_ diff --git a/cpp-package/include/mxnet-cpp/op_util.h b/cpp-package/include/mxnet-cpp/op_util.h index bf67eab4c1ae..5a737480d469 100644 --- a/cpp-package/include/mxnet-cpp/op_util.h +++ b/cpp-package/include/mxnet-cpp/op_util.h @@ -5,8 +5,8 @@ * \author Chris Olivier */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_UTIL_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_UTIL_H_ +#ifndef MXNET_CPP_OP_UTIL_H_ +#define MXNET_CPP_OP_UTIL_H_ #include @@ -43,4 +43,4 @@ inline StreamType& operator << (StreamType& os, const ::caffe::LayerParameter& o } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_UTIL_H_ +#endif // MXNET_CPP_OP_UTIL_H_ diff --git a/cpp-package/include/mxnet-cpp/operator.h 
b/cpp-package/include/mxnet-cpp/operator.h index 9a492576d104..4fc45bbc9f04 100644 --- a/cpp-package/include/mxnet-cpp/operator.h +++ b/cpp-package/include/mxnet-cpp/operator.h @@ -5,8 +5,8 @@ * \author Chuntao Hong, Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OPERATOR_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_OPERATOR_H_ +#ifndef MXNET_CPP_OPERATOR_H_ +#define MXNET_CPP_OPERATOR_H_ #include #include @@ -188,4 +188,4 @@ class Operator { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OPERATOR_H_ +#endif // MXNET_CPP_OPERATOR_H_ diff --git a/cpp-package/include/mxnet-cpp/operator.hpp b/cpp-package/include/mxnet-cpp/operator.hpp index 8a421d7b6b4f..17f4885133fc 100644 --- a/cpp-package/include/mxnet-cpp/operator.hpp +++ b/cpp-package/include/mxnet-cpp/operator.hpp @@ -5,8 +5,8 @@ * \author Chuntao Hong, Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OPERATOR_HPP_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_OPERATOR_HPP_ +#ifndef MXNET_CPP_OPERATOR_HPP_ +#define MXNET_CPP_OPERATOR_HPP_ #include #include @@ -155,4 +155,4 @@ inline Operator &Operator::SetInput(const std::string &name, NDArray ndarray) { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OPERATOR_HPP_ +#endif // MXNET_CPP_OPERATOR_HPP_ diff --git a/cpp-package/include/mxnet-cpp/optimizer.h b/cpp-package/include/mxnet-cpp/optimizer.h index 8dbbbf7f39ea..76f8a3564fbe 100644 --- a/cpp-package/include/mxnet-cpp/optimizer.h +++ b/cpp-package/include/mxnet-cpp/optimizer.h @@ -5,8 +5,8 @@ * \author Chuntao Hong, Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_H_ +#ifndef MXNET_CPP_OPTIMIZER_H_ +#define MXNET_CPP_OPTIMIZER_H_ #include #include @@ -176,4 +176,4 @@ class AdaDeltaOptimizer : public Optimizer { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_H_ +#endif // MXNET_CPP_OPTIMIZER_H_ diff --git a/cpp-package/include/mxnet-cpp/optimizer.hpp b/cpp-package/include/mxnet-cpp/optimizer.hpp index c86476f65417..9dcb158b9e14 100644 --- a/cpp-package/include/mxnet-cpp/optimizer.hpp +++ b/cpp-package/include/mxnet-cpp/optimizer.hpp @@ -5,8 +5,8 @@ * \author Chuntao Hong, Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_HPP_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_HPP_ +#ifndef MXNET_CPP_OPTIMIZER_HPP_ +#define MXNET_CPP_OPTIMIZER_HPP_ #include #include @@ -387,4 +387,4 @@ inline void AdaDeltaOptimizer::CreateState_(int index, NDArray weight) { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_HPP_ +#endif // MXNET_CPP_OPTIMIZER_HPP_ diff --git a/cpp-package/include/mxnet-cpp/shape.h b/cpp-package/include/mxnet-cpp/shape.h index d8e3f2c95282..d30ea9df2531 100644 --- a/cpp-package/include/mxnet-cpp/shape.h +++ b/cpp-package/include/mxnet-cpp/shape.h @@ -5,8 +5,8 @@ * \author Chuntao Hong, Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_SHAPE_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_SHAPE_H_ +#ifndef MXNET_CPP_SHAPE_H_ +#define MXNET_CPP_SHAPE_H_ #include #include @@ -386,4 +386,4 @@ inline std::istream &operator>>(std::istream &is, Shape &shape) { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_SHAPE_H_ +#endif // MXNET_CPP_SHAPE_H_ diff --git a/cpp-package/include/mxnet-cpp/symbol.h b/cpp-package/include/mxnet-cpp/symbol.h index e853c2617ea4..c04ae2a03d29 100644 --- a/cpp-package/include/mxnet-cpp/symbol.h +++ b/cpp-package/include/mxnet-cpp/symbol.h @@ -5,8 
+5,8 @@ * \author Chuntao Hong, Zhang Chen */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_H_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_H_ +#ifndef MXNET_CPP_SYMBOL_H_ +#define MXNET_CPP_SYMBOL_H_ #include #include @@ -257,4 +257,4 @@ Symbol operator/(mx_float lhs, const Symbol &rhs); Symbol operator%(mx_float lhs, const Symbol &rhs); } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_H_ +#endif // MXNET_CPP_SYMBOL_H_
diff --git a/cpp-package/include/mxnet-cpp/symbol.hpp b/cpp-package/include/mxnet-cpp/symbol.hpp index 26962ba5c99b..7f88e485830f 100644 --- a/cpp-package/include/mxnet-cpp/symbol.hpp +++ b/cpp-package/include/mxnet-cpp/symbol.hpp @@ -5,8 +5,8 @@ * \author Zhang Chen, Chuntao Hong */ -#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_HPP_ -#define CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_HPP_ +#ifndef MXNET_CPP_SYMBOL_HPP_ +#define MXNET_CPP_SYMBOL_HPP_ #include #include @@ -347,4 +347,4 @@ inline Symbol operator%(mx_float lhs, const Symbol &rhs) { } // namespace cpp } // namespace mxnet -#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_HPP_ +#endif // MXNET_CPP_SYMBOL_HPP_
diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py index 392e07f9caa4..8f762368d0a4 100644 --- a/cpp-package/scripts/OpWrapperGenerator.py +++ b/cpp-package/scripts/OpWrapperGenerator.py @@ -372,8 +372,8 @@ def ParseAllOps(): "* \\author Chuntao Hong, Xin Li\n" "*/\n" "\n" - "#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_H_\n" - "#define CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_H_\n" + "#ifndef MXNET_CPP_OP_H_\n" + "#define MXNET_CPP_OP_H_\n" "\n" "#include \n" "#include \n" @@ -389,7 +389,7 @@ def ParseAllOps(): "%s" "} //namespace cpp\n" "} //namespace mxnet\n" - "#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OP_H_\n") + "#endif // MXNET_CPP_OP_H_\n") # Generate a temporary file name tf = tempfile.NamedTemporaryFile()
diff --git a/dmlc-core b/dmlc-core index a6c5701219e6..b647be2dee98 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit a6c5701219e635fea808d264aefc5b03c3aec314 +Subproject commit b647be2dee985d77a12e8e41bc27382221938290
diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 0c4c9d3daa77..739105b388bc 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -211,6 +211,8 @@ struct Context { * The information needed in runtime for actual execution. */ struct RunContext { + /*! \brief base Context */ + Context ctx; /*! * \brief the stream of the device, can be NULL or Stream<gpu>* in GPU mode */ @@ -224,6 +226,10 @@ inline mshadow::Stream<xpu>* get_stream() const { return static_cast<mshadow::Stream<xpu>*>(stream); } + /*!
\brief get the base Context from RunContext */ + inline const Context& get_ctx() const { + return ctx; + } }; } // namespace mxnet
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index b8f8411353bf..8bc1451ba90d 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -139,12 +139,12 @@ typedef int (*CustomOpBwdDepFunc)(const int* /*out_grad*/, const int* /*in_data*/, const int* /*out_data*/, int* /*num_deps*/, int** /*rdeps*/, void* /*state*/); typedef int (*CustomOpCreateFunc)(const char* /*ctx*/, int /*num_inputs*/, - unsigned** /*shapes*/, int* /*ndims*/, - int* /*dtypes*/, struct MXCallbackList* /*ret*/, + unsigned** /*shapes*/, const int* /*ndims*/, + const int* /*dtypes*/, struct MXCallbackList* /*ret*/, void* /*state*/); typedef int (*CustomOpPropCreator)(const char* /*op_type*/, const int /*num_kwargs*/, - const char** /*keys*/, const char** /*values*/, - struct MXCallbackList* /*ret*/); + const char** /*keys*/, const char** /*values*/, + struct MXCallbackList* /*ret*/); /*! * \brief return str message of the last error
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 316a90fe0841..dbf9a07e0bcb 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -15,27 +15,173 @@ #include #include "./base.h" -#include "./operator.h" #include "./ndarray.h" +#include "./engine.h" namespace mxnet { using nnvm::NodeAttrs; + +/*! \brief operation request type to Forward and Backward */ +enum OpReqType { + /*! \brief no operation, do not write anything */ + kNullOp, + /*! \brief write gradient to provided space */ + kWriteTo, + /*! + * \brief perform an inplace write; + * the target shares memory with one of the input arguments. + */ + kWriteInplace, + /*! \brief add to the provided space */ + kAddTo +}; + +/*! + * \brief All the possible information needed by Operator.Forward and Backward. + * This is the superset of RunContext. + * We use this data structure to bookkeep everything needed by Forward and Backward. + * \sa Resource + */ +struct OpContext { + /*! \brief whether it is training phase */ + int is_train; + /*! \brief RunContext related resources */ + RunContext run_ctx; + /*! \brief the callback when operation completes, used by asynchronous ops */ + engine::CallbackOnComplete async_on_complete; + /*! \brief Resources requested by the operator */ + std::vector<Resource> requested; + /*! + * \brief get mshadow stream from Context + * \return the mshadow stream + * \tparam xpu the device type of the stream + */ + template <typename xpu> + inline mshadow::Stream<xpu>* get_stream() const { + return run_ctx.get_stream<xpu>(); + } +}; + +/*! \brief the execution type of the operator */ +enum class ExecType { + /*! \brief Forward/Backward are synchronous calls */ + kSync, + /*! + * \brief Forward/Backward are asynchronous and + * will call OpContext.async_on_complete when the operation finishes. + */ + kAsync, + /*! \brief Run this operator on the scheduling thread without pushing it to the engine. */ + kLocal, + /*! + * \brief Cross-device copy operation, a special operator + * that indicates a copy across devices; the input and output can sit on different devices. + * In the current implementation, the copy operator is specially handled by the executor. + * This flag is used for special-case treatment and future extension of different copy ops. + */ + kCrossDeviceCopy +}; + +/*! + * \brief Operator state. This is a pointer type; its content is mutable + * even if OpStatePtr is const.
+ */ +class OpStatePtr { + public: + /* \brief Create an OpStatePtr with state of type T. + * \param args Arguments passed to T's constructor. + */ + template <typename T, typename... Args> + static OpStatePtr Create(Args&&... args) { + OpStatePtr ret; + ret.ptr_ = std::make_shared<OpState>(); + ret.ptr_->var_ = Engine::Get()->NewVariable(); + ret.ptr_->state_.construct<T>(std::forward<Args>(args)...); + + return ret; + } + /* \brief Get the engine variable associated with this state */ + engine::VarHandle get_var() const { + return ptr_->var_; + } + /* \brief Get state of type T */ + template <typename T> + T& get_state() const { + return dmlc::get<T>(ptr_->state_); + } + /* \brief clear state */ + void reset() { + ptr_.reset(); + } + /* \brief Whether state is empty */ + explicit operator bool() const { + return ptr_ ? true : false; + } + + private: + /* \brief state structure */ + struct OpState { + OpState() {} + OpState(const OpState& other) = delete; + OpState& operator=(const OpState& other) = delete; + + ~OpState() { + Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_); + } + + engine::VarHandle var_; + dmlc::any state_; + }; + /* \brief shared pointer to state */ + std::shared_ptr<OpState> ptr_; +}; + /*! * \brief Create a layer-style forward/backward operator. * This makes it easy to write code that contains state. + * OpStatePtr is a pointer type; its content is mutable even if + * OpStatePtr is const. + * * * This is not the only way to register an op execution function. * Simpler or more specialized operator forms can also be registered. * * \note Register under "FCreateOpState" */ -using FCreateLayerOp = std::function< - Operator* (const NodeAttrs& n, - Context ctx, - const std::vector<TShape>& in_shape, - const std::vector<int>& in_type)>; - +using FCreateOpState = std::function<OpStatePtr (const NodeAttrs& attrs, + Context ctx, + const std::vector<TShape>& in_shape, + const std::vector<int>& in_type)>; +/*! + * \brief Execution mode of this operator. + */ +using FExecType = std::function<ExecType (const NodeAttrs& attrs)>; +/*! + * \brief Register a compute function for a stateful operator. + * OpStatePtr is a pointer type; its content is mutable even if + * OpStatePtr is const. + * + * \note Register under "FStatefulCompute<cpu>" and "FStatefulCompute<gpu>" + */ +using FStatefulCompute = std::function<void (const OpStatePtr& state, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs)>; +/*! + * \brief Register a compute function for a stateful operator using the NDArray interface. + * OpStatePtr is a pointer type; its content is mutable even if + * OpStatePtr is const. + * + * \note Register under "FStatefulComputeEx<cpu>" and "FStatefulComputeEx<gpu>" + */ +using FStatefulComputeEx = std::function<void (const OpStatePtr& state, + const OpContext& ctx, + const std::vector<NDArray>& inputs, + const std::vector<OpReqType>& req, + const std::vector<NDArray>& outputs)>; /*! * \brief The resource request from the operator *
diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index fe5c3de0279f..09a643390342 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -18,50 +18,9 @@ #include #include "./base.h" #include "./resource.h" +#include "./op_attr_types.h" namespace mxnet { -/*! \brief operation request type to Forward and Backward */ -enum OpReqType { - /*! \brief no operation, do not write anything */ - kNullOp, - /*! \brief write gradient to provided space */ - kWriteTo, - /*! - * \brief perform an inplace write, - * Target shares memory with one of input arguments. - * This option only happen when - */ - kWriteInplace, - /*! \brief add to the provided space */ - kAddTo -}; - -/*! - * \brief All the possible information needed by Operator.Forward and Backward - * This is the superset of RunContext.
- * We use this data structure to bookkeep everything needed by Forward and Backward. - * \sa Resource - */ -struct OpContext { - /*! \brief whether it is training phase */ - int is_train; - /*! \brief RunContext related resources */ - RunContext run_ctx; - /*! \brief the callback when operation completes, used by asynchronize ops */ - engine::CallbackOnComplete async_on_complete; - /*! \brief Resources requested by the operator */ - std::vector<Resource> requested; - /*! - * \brief get mshadow stream from Context - * \return the mshadow stream - * \tparam xpu the device type of the stream - */ - template <typename xpu> - inline mshadow::Stream<xpu>* get_stream() const { - return run_ctx.get_stream<xpu>(); - } -}; - /*! * \brief Operator interface. * Operator defines basic operation unit of optimized computation graph in mxnet. @@ -76,23 +35,6 @@ struct OpContext { */ class Operator { public: - /*! \brief the execution type of the operator */ - enum ExecType { - /*! \brief Forward/Backward are synchronize calls */ - kSync, - /*! - * \brief Forward/Backward are asynchronize, - * will call OpContext.async_on_complete when operation finishes. - */ - kAsync, - /*! - * \brief Cross device copy operation, this is a special operator - * That indicates copy across devices, the input and output can sit on different device. - * In current implementation, copy operator is specially handled by executor. - * This flag is used for special case treatment and future extension of different copy ops. - */ - kCrossDeviceCopy - }; /*! \brief destructor */ virtual ~Operator() {} /*! @@ -148,9 +90,9 @@ class Operator { const std::vector<TBlob> &aux_states) { LOG(FATAL) << "Backward is not implemented"; } - /*! \return execution type of the operator */ - virtual ExecType exec_type() const { - return kSync; + /*! \return [Deprecated] execution type of the operator */ + virtual ExecType exec_type() const final { // NOLINT(*) exec_type has been moved to OperatorProperty + return ExecType::kSync; } }; @@ -478,6 +420,10 @@ class OperatorProperty { * \return a new constructed OperatorProperty */ static OperatorProperty *Create(const char* type_name); + /*! \return execution type of the operator */ + virtual ExecType exec_type() const { + return ExecType::kSync; + } }; /*! \brief typedef the factory function of operator property */
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 31b7d7cfb944..001400db95b8 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -258,7 +258,9 @@ def __le__(self, other): return lesser_equal(self, other) def __bool__(self): - raise ValueError("The truth value of an NDArray with more than one element is ambiguous.") + raise ValueError("The truth value of an NDArray is ambiguous. " \ + "Please convert to number with asscalar() first.") + __nonzero__ = __bool__ def __getstate__(self): diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py index a08e764088a5..d57ee717fcf6 100644 --- a/python/mxnet/operator.py +++ b/python/mxnet/operator.py @@ -471,7 +471,7 @@ def infer_shape(self, in_shape): List of aux shapes calculated from in_shape, in the same order as declared in list_auxiliary_states. """ - return in_shape, [in_shape[0]], [] + return in_shape, (in_shape[0],)*len(self.list_outputs()), () def infer_type(self, in_type): """infer_type interface.
override to create new operators
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala index a870cabb568b..dfbc864785f1 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala @@ -881,6 +881,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll NDArray.zeros(Shape(numFilter, inputShape(1), kernel._1, kernel._2))) val exeConv = conv.bind(Context.cpu(), args = convArgs, argsGrad = convArgsGrad) val convOutGrad = Random.normal(0, 2, exeConv.outputs.head.shape) + exeConv.forward() exeConv.backward(convOutGrad) val deconvData = convOutGrad @@ -889,6 +890,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll NDArray.zeros(Shape(numFilter, inputShape(1), kernel._1, kernel._2))) val exeDeconv = deconv.bind(Context.cpu(), args = deconvArgs, argsGrad = deconvArgsGrad) val deconvOutGrad = convData + exeDeconv.forward() exeDeconv.backward(deconvOutGrad) assert(reldiff(convArgsGrad(1), deconvArgsGrad(1)) < 1e-5) }
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index bea6437b4c64..a376b3b6802c 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -974,6 +974,6 @@ int MXRtcFree(RtcHandle handle) { int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator) { API_BEGIN(); - mxnet::op::CustomOpProp::Register(op_type, creator); + mxnet::op::custom::Registry::Get()->Register(op_type, creator); API_END(); }
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index dfdd46b6aa90..98fbe760854e 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -279,59 +279,70 @@ void PushFCompute(const FCompute& fn, 0, PROFILER_MESSAGE(op->name.c_str())); } -void PushOperator(std::shared_ptr<Operator> opr, +void PushOperator(const OpStatePtr& state, const nnvm::Op* op, const nnvm::NodeAttrs& attrs, const Context& ctx, const std::vector<engine::VarHandle>& read_vars, const std::vector<engine::VarHandle>& write_vars, const std::vector<Resource>& requested, - const std::vector<uint32_t>& auxidx, const std::vector<NDArray>& ndinputs, const std::vector<NDArray>& ndoutputs) { - struct Capture { - engine::CallbackOnComplete on_complete; - std::shared_ptr<Operator> opr; - }; + static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType"); bool is_train = AutogradRuntime::Get()->IsTraining(); - Engine::Get()->PushAsync( - [ctx, opr, auxidx, ndinputs, ndoutputs, requested, is_train]( RunContext rctx, engine::CallbackOnComplete on_complete) { - std::vector<TBlob> input_blobs, aux_blobs, output_blobs; - auto atop = auxidx.begin(); - for (size_t i = 0; i < ndinputs.size(); ++i) { - if (atop != auxidx.end() && i == *atop) { - aux_blobs.push_back(ndinputs[i].data()); - ++atop; - } else { - input_blobs.push_back(ndinputs[i].data()); + ExecType exec_type = ExecType::kSync; + if (fexec_type.count(op)) { + exec_type = fexec_type[op](attrs); + } + + auto fcompute = common::GetFCompute<FStatefulCompute>(op, "FStatefulCompute", ctx); + if (fcompute != nullptr) { + CHECK(exec_type == ExecType::kSync || exec_type == ExecType::kAsync); + Engine::Get()->PushAsync( + [state, fcompute, ndinputs, ndoutputs, requested, is_train, exec_type]( + RunContext rctx, + engine::CallbackOnComplete on_complete) { + OpContext opctx{is_train, rctx, on_complete, requested}; + std::vector<TBlob> input_blobs, output_blobs; + for (const auto& i : ndinputs) input_blobs.push_back(i.data()); + for (const auto& i : ndoutputs) output_blobs.push_back(i.data()); + std::vector<OpReqType> req(output_blobs.size(), kWriteTo); + fcompute(state, opctx, input_blobs, req, output_blobs); + if (exec_type == ExecType::kSync) { + if (rctx.get_ctx().dev_mask() == gpu::kDevMask) { + rctx.get_stream<gpu>()->Wait(); + } + on_complete(); } - } - for (auto& i : ndoutputs) { - output_blobs.push_back(i.data()); - } - Capture* capture = new Capture({on_complete, opr}); - OpContext opctx{is_train, rctx, - Engine::Get()->CreateCallback( - [](Engine* engine, void *cpt_handle) { - Capture* cpt = static_cast<Capture*>(cpt_handle); - cpt->on_complete(); - delete cpt; - }, static_cast<void*>(capture)), - requested}; - std::vector<OpReqType> req(output_blobs.size(), kWriteTo); - opr->Forward(opctx, input_blobs, req, output_blobs, aux_blobs); - if (opr->exec_type() != Operator::kAsync) { - if (ctx.dev_mask() == gpu::kDevMask) { - rctx.get_stream<gpu>()->Wait(); + }, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, PROFILER_MESSAGE(op->name.c_str())); + } else { + auto fcompute_ex = common::GetFCompute<FStatefulComputeEx>( + op, "FStatefulComputeEx", ctx); + CHECK(fcompute_ex != nullptr) + << "One of FStatefulCompute and FStatefulComputeEx must be registered " + << "for stateful operator " << op->name; + const auto& run = [state, fcompute_ex, ndinputs, ndoutputs, requested, is_train, exec_type]( RunContext rctx, engine::CallbackOnComplete on_complete) { + OpContext opctx{is_train, rctx, on_complete, requested}; + std::vector<OpReqType> req(ndoutputs.size(), kWriteTo); + fcompute_ex(state, opctx, ndinputs, req, ndoutputs); + if (exec_type == ExecType::kSync) { + if (rctx.get_ctx().dev_mask() == gpu::kDevMask) { + rctx.get_stream<gpu>()->Wait(); } + on_complete(); } - delete capture; - on_complete(); - } - }, ctx, read_vars, write_vars, FnProperty::kNormal, - 0, PROFILER_MESSAGE(op->name.c_str())); + }; + if (exec_type == ExecType::kLocal) { + run(RunContext{ctx, nullptr}, engine::CallbackOnComplete()); + } else { + Engine::Get()->PushAsync(run, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, PROFILER_MESSAGE(op->name.c_str())); + } + } } void ImperativeInvokeImpl(const Context& default_ctx, @@ -341,7 +352,7 @@ void ImperativeInvokeImpl(const Context& default_ctx, static auto& fcpu = nnvm::Op::GetAttr<FCompute>("FCompute<cpu>"); static auto& fgpu = nnvm::Op::GetAttr<FCompute>("FCompute<gpu>"); static auto& ndfunc = nnvm::Op::GetAttr<FNDArrayFunction>("FNDArrayFunction"); - static auto& createop = nnvm::Op::GetAttr<FCreateLayerOp>("FCreateLayerOp"); + static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState"); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); const nnvm::Op *op = attrs.op; @@ -378,14 +389,15 @@ void ImperativeInvokeImpl(const Context& default_ctx, PushFCompute(fn, op, attrs, ctx, read_vars, write_vars, requested, ndinputs, ndoutputs); } else if (createop.count(op)) { - std::shared_ptr<Operator> opr( - createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types)); + auto state = + createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types); if (AutogradRuntime::Get()->IsTraining()) { - AutogradRuntime::Get()->RecordImperativeOperator(opr, op, + AutogradRuntime::Get()->RecordImperativeOperator(state, op, attrs, &ndinputs, &ndoutputs); } - PushOperator(opr, op, attrs, ctx, read_vars, write_vars, - requested, auxidx, ndinputs, ndoutputs); + write_vars.push_back(state.get_var()); + PushOperator(state, op, attrs, ctx, read_vars, write_vars, + requested, ndinputs, ndoutputs); } else { LOG(FATAL) << "Operator " << op->name << " is not implemented for "
diff --git a/src/common/utils.h b/src/common/utils.h index 789b4d14b9f2..5f50aab4781f 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #endif
// DMLC_USE_CXX11 @@ -124,6 +125,22 @@ typename helper::UniqueIf<T>::UnknownBound MakeUnique(size_t n) { template <typename T, typename... Args> typename helper::UniqueIf<T>::KnownBound MakeUnique(Args&&... args) = delete; +template <typename FCompType> +FCompType GetFCompute(const nnvm::Op* op, const std::string& name, + const Context& ctx) { + static auto& fcompute_cpu = nnvm::Op::GetAttr<FCompType>(name + "<cpu>"); + static auto& fcompute_gpu = nnvm::Op::GetAttr<FCompType>(name + "<gpu>"); + + if (ctx.dev_mask() == cpu::kDevMask) { + return fcompute_cpu.get(op, nullptr); + } else if (ctx.dev_mask() == gpu::kDevMask) { + return fcompute_gpu.get(op, nullptr); + } else { + LOG(FATAL) << "Unknown device mask"; + return nullptr; + } +} + #endif // DMLC_USE_CXX11 } // namespace common
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index efb7bd44981b..11ff7c8138bf 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -138,14 +138,12 @@ class NaiveEngine final : public Engine { if (streams_[dev_id] == nullptr) { streams_[dev_id] = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0); } - ctx_.stream = streams_[dev_id]; - exec_fun(ctx_, callback); + exec_fun(RunContext{exec_ctx, streams_[dev_id]}, callback); #else LOG(FATAL) << "GPU is not enabled"; #endif } else { - ctx_.stream = &cpu_stream_; - exec_fun(ctx_, callback); + exec_fun(RunContext{exec_ctx, &cpu_stream_}, callback); } CHECK(this->req_completed_) << "NaiveEngine only support synchronize Push so far"; @@ -176,8 +174,6 @@ class NaiveEngine final : public Engine { static void OnComplete(Engine *engine, void *param) { static_cast<NaiveEngine*>(engine)->req_completed_ = true; } - // runtime context - RunContext ctx_; // whether action is completed bool req_completed_; // counter
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index 313db6d2010b..2d684bbb7b9a 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -46,9 +46,10 @@ template <std::size_t kNumGpus, std::size_t kStreams> RunContext StreamManager<kNumGpus, kStreams>::GetRunContext( Context const& ctx) { RunContext ret; - ret.stream = nullptr; switch (ctx.dev_mask()) { - case cpu::kDevMask: break; + case cpu::kDevMask: + ret = RunContext{ctx, nullptr}; + break; case gpu::kDevMask: { #if MXNET_USE_CUDA std::size_t use_counter; @@ -65,7 +66,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext( use_counter = counter; counter = (counter + 1) % kStreams; } - ret.stream = gpu_streams_.at(ctx.dev_id).at(use_counter); + ret = RunContext{ctx, gpu_streams_.at(ctx.dev_id).at(use_counter)}; break; #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; @@ -79,9 +80,10 @@ template <std::size_t kNumGpus, std::size_t kStreams> RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext( Context const& ctx) { RunContext ret; - ret.stream = nullptr; switch (ctx.dev_mask()) { - case cpu::kDevMask: break; + case cpu::kDevMask: + ret = RunContext{ctx, nullptr}; + break; case gpu::kDevMask: { #if MXNET_USE_CUDA CUDA_CALL(cudaSetDevice(ctx.dev_id)); @@ -91,7 +93,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext( gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(false, false); } } - ret.stream = gpu_io_streams_.at(ctx.dev_id); + ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id)}; break; #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 2b333d60647a..97356ae91e0d 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -39,7 +39,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { cpu_priority_worker_.reset(new ThreadWorkerBlock<kPriorityQueue>()); cpu_priority_worker_->pool.reset(new ThreadPool( cpu_priority_nthreads,
[this]() { - this->CPUWorker(cpu_priority_worker_.get()); + this->CPUWorker(Context(), cpu_priority_worker_.get()); })); // GPU tasks will be created lazily } @@ -60,9 +60,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { MSHADOW_CATCH_ERROR(mshadow::SetDevice<gpu>(ctx.dev_id)); #endif } - RunContext run_ctx; - run_ctx.stream = nullptr; - this->ExecuteOprBlock(run_ctx, opr_block); + this->ExecuteOprBlock(RunContext{ctx, nullptr}, opr_block); } else { if (ctx.dev_mask() == cpu::kDevMask) { if (opr_block->opr->prop == FnProperty::kCPUPrioritized) { @@ -71,10 +69,10 @@ class ThreadedEnginePerDevice : public ThreadedEngine { int dev_id = ctx.dev_id; int nthread = cpu_worker_nthreads_; auto ptr = - cpu_normal_workers_.Get(dev_id, [this, dev_id, nthread]() { + cpu_normal_workers_.Get(dev_id, [this, ctx, nthread]() { auto blk = new ThreadWorkerBlock<kWorkerQueue>(); - blk->pool.reset(new ThreadPool(nthread, [this, blk] () { - this->CPUWorker(blk); + blk->pool.reset(new ThreadPool(nthread, [this, ctx, blk] () { + this->CPUWorker(ctx, blk); })); return blk; }); @@ -89,16 +87,15 @@ class ThreadedEnginePerDevice : public ThreadedEngine { bool is_copy = (prop == FnProperty::kCopyFromGPU || prop == FnProperty::kCopyToGPU); int nthread = gpu_worker_nthreads_; - int dev_id = ctx.dev_id; if (is_copy) { auto ptr = - gpu_copy_workers_.Get(dev_id, [this, dev_id, is_copy, nthread]() { + gpu_copy_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() { auto blk = new ThreadWorkerBlock<kCopyQueue>(); blk->pool.reset(new ThreadPool( nthread, - [this, dev_id, is_copy, blk] + [this, ctx, is_copy, blk] (std::shared_ptr<ThreadPool::SimpleEvent> ready_event) { - this->GPUWorker(dev_id, is_copy, blk, ready_event); + this->GPUWorker(ctx, is_copy, blk, ready_event); }, true)); return blk; }); @@ -106,13 +103,13 @@ class ThreadedEnginePerDevice : public ThreadedEngine { ptr->task_queue.Push(opr_block, opr_block->priority); } } else { - auto ptr = gpu_normal_workers_.Get(dev_id, [this, dev_id, is_copy, nthread]() { + auto ptr = gpu_normal_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() { auto blk = new ThreadWorkerBlock<kWorkerQueue>(); blk->pool.reset(new ThreadPool( nthread, - [this, dev_id, is_copy, blk] + [this, ctx, is_copy, blk] (std::shared_ptr<ThreadPool::SimpleEvent> ready_event) { - this->GPUWorker(dev_id, is_copy, blk, ready_event); + this->GPUWorker(ctx, is_copy, blk, ready_event); }, true)); return blk; }); @@ -157,26 +154,25 @@ class ThreadedEnginePerDevice : public ThreadedEngine { * \param block The task block of the worker. */ template <dmlc::ConcurrentQueueType type> - inline void GPUWorker(int dev_id, + inline void GPUWorker(Context ctx, bool is_copy_worker, ThreadWorkerBlock<type> *block, std::shared_ptr<ThreadPool::SimpleEvent> ready_event) { #if MXNET_USE_CUDA mshadow::Stream<gpu> *stream; - RunContext run_ctx; do { ThreadPool::SimpleEvent::SetReadyOnDestroy setReady(ready_event); // allocate stream - mshadow::SetDevice<gpu>(dev_id); + mshadow::SetDevice<gpu>(ctx.dev_id); if (is_copy_worker) { stream = mshadow::NewStream<gpu>(false, false); } else { stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0); } - run_ctx.stream = stream; } while (false); // execute task OprBlock* opr_block; + RunContext run_ctx{ctx, stream}; auto* task_queue = &(block->task_queue); while (task_queue->Pop(&opr_block)) { this->ExecuteOprBlock(run_ctx, opr_block); @@ -192,10 +188,10 @@ class ThreadedEnginePerDevice : public ThreadedEngine { * \param block The task block of the worker. */ template <dmlc::ConcurrentQueueType type> - inline void CPUWorker(ThreadWorkerBlock<type> *block) { + inline void CPUWorker(Context ctx, + ThreadWorkerBlock<type> *block) { auto* task_queue = &(block->task_queue); - RunContext run_ctx; - run_ctx.stream = nullptr; + RunContext run_ctx{ctx, nullptr}; // execute task OprBlock* opr_block; while (task_queue->Pop(&opr_block)) {
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 16b55adc15e8..6a0c489a1ec5 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -7,6 +7,7 @@ #include #include #include +#include "../common/utils.h" #include "./exec_pass.h" #if MXNET_USE_MKL2017 == 1 #include @@ -22,116 +23,81 @@ const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs); namespace exec { // forward executor -class ForwardOpExecutor : public OpExecutor { +class StatefulComputeExecutor : public OpExecutor { public: void Run(RunContext rctx) override { op_ctx.run_ctx = rctx; - op_->Forward(op_ctx, in_data_, req, out_data_, aux_data_); + fcompute_(state_, op_ctx, in_data_, req, out_data_); #if MKL_EXPERIMENTAL == 1 mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); - mkl_tblobs_prv_to_cpu(aux_data_); #endif } void Setup() override { - in_data_.clear(); aux_data_.clear(); + in_data_.clear(); for (size_t i = 0; i < in_array.size(); ++i) { - if (!std::binary_search(aux_index_.begin(), aux_index_.end(), i)) { - in_data_.push_back(in_array[i].data()); - } else { - aux_data_.push_back(in_array[i].data()); - } + in_data_.push_back(in_array[i].data()); + } + out_data_.clear(); + for (size_t i = 0; i < out_array.size(); ++i) { + out_data_.push_back(out_array[i].data()); } - out_data_.resize(out_array.size()); - std::transform(out_array.begin(), out_array.end(), out_data_.begin(), [](const NDArray& nd) { - return nd.data(); - }); } - Operator::ExecType exec_type() const override { - return op_->exec_type(); + + ExecType exec_type() const override { + return exec_type_; } - explicit ForwardOpExecutor(std::shared_ptr<Operator> op, - std::vector<uint32_t> aux_index) - : op_(op), aux_index_(aux_index) { - std::sort(aux_index_.begin(), aux_index_.end()); + + virtual engine::VarHandle var() const { + return state_.get_var(); } + explicit StatefulComputeExecutor(const OpStatePtr& state, + const FStatefulCompute& fcompute, + ExecType exec_type) + : state_(state), fcompute_(fcompute), exec_type_(exec_type) {} + private: friend Graph AttachOpExecs(Graph g); - std::shared_ptr<Operator> op_; - std::vector<uint32_t> aux_index_; - std::vector<TBlob> in_data_, out_data_, aux_data_; + OpStatePtr state_; + FStatefulCompute fcompute_; + ExecType exec_type_; + std::vector<TBlob> in_data_, out_data_; }; -// backward executor -class BackwardOpExecutor : public OpExecutor { + +// forward executor (NDArray interface) +class StatefulComputeExExecutor : public OpExecutor { public: void Run(RunContext rctx) override { op_ctx.run_ctx = rctx; - op_->Backward(op_ctx, out_grad_, in_data_, out_data_, - req, in_grad_, aux_data_); -#if MKL_EXPERIMENTAL == 1 - mkl_tblobs_prv_to_cpu(out_grad_); - mkl_tblobs_prv_to_cpu(in_data_); - mkl_tblobs_prv_to_cpu(out_data_); - mkl_tblobs_prv_to_cpu(in_grad_); - mkl_tblobs_prv_to_cpu(aux_data_); -#endif + fcompute_(state_, op_ctx, in_array, req, out_array); } - void Setup() override { - size_t arg_top = 0, aux_top = 0; - aux_data_.resize(aux_index_.size()); - for (size_t i = 0; i < in_array.size(); ++i) { - if (!std::binary_search(aux_index_.begin(), aux_index_.end(), i)) { - CHECK_GT(arg_data_ptr_.size(), arg_top); - *arg_data_ptr_[arg_top++] = in_array[i].data();
- } else { - aux_data_.at(aux_top++) = in_array[i].data(); - } - } - CHECK_EQ(out_array.size(), in_grad_.size()); - std::transform(out_array.begin(), out_array.end(), - in_grad_.begin(), [](const NDArray& nd) { - return nd.data(); - }); - } - Operator::ExecType exec_type() const override { - return op_->exec_type(); + + void Setup() override {} + + ExecType exec_type() const override { + return exec_type_; } - explicit BackwardOpExecutor(std::shared_ptr<Operator> op, - const OperatorProperty* prop, - std::vector<uint32_t> aux_index) - : op_(op), aux_index_(aux_index) { - std::sort(aux_index_.begin(), aux_index_.end()); - out_grad_.resize(prop->NumVisibleOutputs()); - in_data_.resize(prop->ListArguments().size()); - in_grad_.resize(in_data_.size()); - out_data_.resize(prop->NumOutputs()); - - std::vector<TBlob*> out_grad_ptr(out_grad_.size()); - for (size_t i = 0; i < out_grad_.size(); ++i) { - out_grad_ptr[i] = &out_grad_[i]; - } - std::vector<TBlob*> in_data_ptr(in_data_.size()); - for (size_t i = 0; i < in_data_.size(); ++i) { - in_data_ptr[i] = &in_data_[i]; - } - std::vector<TBlob*> out_data_ptr(out_data_.size()); - for (size_t i = 0; i < out_data_.size(); ++i) { - out_data_ptr[i] = &out_data_[i]; - } - arg_data_ptr_ = prop->BackwardInputs( - out_grad_ptr, in_data_ptr, out_data_ptr); + + virtual engine::VarHandle var() const { + return state_.get_var(); } + explicit StatefulComputeExExecutor(const OpStatePtr& state, + const FStatefulComputeEx& fcompute, + ExecType exec_type) + : state_(state), fcompute_(fcompute), exec_type_(exec_type) {} + private: - std::shared_ptr<Operator> op_; - std::vector<uint32_t> aux_index_; - std::vector<TBlob> out_grad_, in_grad_, in_data_, out_data_, aux_data_; - std::vector<TBlob*> arg_data_ptr_; + friend Graph AttachOpExecs(Graph g); + OpStatePtr state_; + FStatefulComputeEx fcompute_; + ExecType exec_type_; }; + // fcompute executor class FComputeExecutor : public OpExecutor { public: @@ -143,6 +109,7 @@ class FComputeExecutor : public OpExecutor { mkl_tblobs_prv_to_cpu(out_data_); #endif } + void Setup() override { in_data_.resize(in_array.size()); out_data_.resize(out_array.size()); @@ -152,29 +119,20 @@ class FComputeExecutor : public OpExecutor { std::transform(in_array.begin(), in_array.end(), in_data_.begin(), get_blob); std::transform(out_array.begin(), out_array.end(), out_data_.begin(), get_blob); } - Operator::ExecType exec_type() const override { - return Operator::kSync; - } - explicit FComputeExecutor(FCompute fcompute, const NodeAttrs& attrs) - : fcompute_(fcompute), attrs_(attrs) { + + ExecType exec_type() const override { + return exec_type_; } - static FCompute GetFCompute(const Op* op, Context ctx) { - static auto& fcompute_cpu = nnvm::Op::GetAttr<FCompute>("FCompute<cpu>"); - static auto& fcompute_gpu = nnvm::Op::GetAttr<FCompute>("FCompute<gpu>"); - if (ctx.dev_mask() == cpu::kDevMask) { - return fcompute_cpu.get(op, nullptr); - } else if (ctx.dev_mask() == gpu::kDevMask) { - return fcompute_gpu.get(op, nullptr); - } else { - LOG(FATAL) << "Unknown device mask"; - return nullptr; - } + explicit FComputeExecutor(const NodeAttrs& attrs, FCompute fcompute, + ExecType exec_type) + : attrs_(attrs), fcompute_(fcompute), exec_type_(exec_type) { } private: - FCompute fcompute_; NodeAttrs attrs_; + FCompute fcompute_; + ExecType exec_type_; std::vector<TBlob> in_data_, out_data_; }; @@ -184,15 +142,16 @@ Graph AttachOpExecs(Graph g) { using nnvm::ShapeVector; using nnvm::FMutateInputs; - auto& fcreate_layer_op = nnvm::Op::GetAttr<FCreateLayerOp>("FCreateLayerOp"); + auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState"); auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs"); + auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType"); auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward"); const auto& vdtype = g.GetAttr<DTypeVector>("dtype"); const auto& vshape = g.GetAttr<ShapeVector>("shape"); const auto& vctx = g.GetAttr<ContextVector>("context"); - const auto& saved_opr = g.GetAttr< - std::unordered_map<const nnvm::Node*, std::shared_ptr<Operator>>>("saved_opr"); + const auto& saved_states = g.GetAttr< + std::unordered_map<const nnvm::Node*, OpStatePtr> >("saved_states"); // get the graph const auto& idx = g.indexed_graph(); @@ -202,39 +161,72 @@ Graph AttachOpExecs(Graph g) { for (size_t i = 0; i < idx.num_nodes(); ++i) { const auto& inode = idx[i]; if (inode.source->is_variable()) continue; + const nnvm::Op *op = inode.source->op(); + ExecType exec_type = ExecType::kSync; std::vector<uint32_t> mutate_index; - if (fmutate_inputs.count(inode.source->op())) { - mutate_index = fmutate_inputs[inode.source->op()](inode.source->attrs); + if (fmutate_inputs.count(op)) { + mutate_index = fmutate_inputs[op](inode.source->attrs); + } + if (fexec_type.count(op)) { + exec_type = fexec_type[op](inode.source->attrs); } - FCompute fcompute = FComputeExecutor::GetFCompute(inode.source->op(), vctx[i]); - if (fcreate_layer_op.count(inode.source->op())) { + + if (fcreate_op_state.count(op)) { std::vector<TShape> ishape; std::vector<int> itype; for (const auto& e : inode.inputs) { ishape.emplace_back(vshape[idx.entry_id(e)]); itype.emplace_back(vdtype[idx.entry_id(e)]); } - std::shared_ptr<Operator> opr; - if (saved_opr.count(inode.source)) { - opr = saved_opr.at(inode.source); + + OpStatePtr state; + if (saved_states.count(inode.source)) { + state = saved_states.at(inode.source); } else { - opr.reset(fcreate_layer_op[inode.source->op()]( - inode.source->attrs, vctx[i], ishape, itype)); + state = fcreate_op_state[op]( + inode.source->attrs, vctx[i], ishape, itype); } - ret[i] = std::make_shared<ForwardOpExecutor>(opr, mutate_index); - } else if (is_layer_backward.get(inode.source->op(), false)) { + FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>( + op, "FStatefulCompute", vctx[i]); + if (fcompute != nullptr) { + ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute, exec_type); + } else { + FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>( + op, "FStatefulComputeEx", vctx[i]); + CHECK(fcompute_ex != nullptr) + << "One of FStatefulCompute and FStatefulComputeEx must be registered " + << "for stateful operator " << op->name; + ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type); + } + } else if (is_layer_backward.get(op, false)) { CHECK_GE(inode.control_deps.size(), 1); uint32_t fwd_id = inode.control_deps[0]; CHECK(vctx[fwd_id] == vctx[i]); CHECK(ret[fwd_id] != nullptr); - ret[i] = std::make_shared<BackwardOpExecutor>( - dynamic_cast<ForwardOpExecutor*>(ret[fwd_id].get())->op_, - mxnet::op::OpPropGetOpProperty(inode.source->attrs), - mutate_index); - } else if (fcompute != nullptr) { - ret[i] = std::make_shared<FComputeExecutor>(fcompute, inode.source->attrs); + FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>( + op, "FStatefulCompute", vctx[i]); + if (fcompute != nullptr) { + ret[i] = std::make_shared<StatefulComputeExecutor>( + dynamic_cast<StatefulComputeExecutor*>(ret[fwd_id].get())->state_, + fcompute, exec_type); + } else { + FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>( + op, "FStatefulComputeEx", vctx[i]); + CHECK(fcompute_ex != nullptr) + << "One of FStatefulCompute and FStatefulComputeEx must be registered " + << "for stateful operator " << op->name; + ret[i] = std::make_shared<StatefulComputeExExecutor>( + dynamic_cast<StatefulComputeExExecutor*>(ret[fwd_id].get())->state_, + fcompute_ex, exec_type); + } } else { - LOG(INFO) << "FCompute not registered " << inode.source->op()->name; + FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
+ if (fcompute != nullptr) { + ret[i] = std::make_shared( + inode.source->attrs, fcompute, exec_type); + } else { + LOG(INFO) << "FCompute not registered " << op->name; + } } } g.attrs["op_execs"] = std::make_shared(ret); diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 8df6a3c5d3bb..76b02de736e9 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -49,7 +49,11 @@ class OpExecutor { */ virtual void Run(RunContext rctx) = 0; /*! \return the execution type */ - virtual Operator::ExecType exec_type() const = 0; + virtual ExecType exec_type() const = 0; + /*! \return return engine variable for operator states */ + virtual engine::VarHandle var() const { + return nullptr; + } }; /*! diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 2be000112711..add1d36434a8 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -707,7 +707,7 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, } g = DetectInplaceAddTo(g); - g.attrs["saved_opr"] = std::make_shared(std::move(saved_opr_)); + g.attrs["saved_states"] = std::make_shared(std::move(saved_states_)); g = AttachOpExecs(g); g = AttachOpResources(g); graph_ = std::move(g); @@ -1037,7 +1037,7 @@ void GraphExecutor::InitCachedOps() { if (inode.source->is_variable()) continue; if (op_nodes_[nid].skip_exec_node) continue; auto& exec = op_nodes_[nid].exec; - bool is_async = op_nodes_[nid].exec->exec_type() == Operator::kAsync; + bool is_async = op_nodes_[nid].exec->exec_type() == ExecType::kAsync; bool is_gpu = op_nodes_[nid].ctx.dev_mask() == gpu::kDevMask; // the variables @@ -1052,6 +1052,9 @@ void GraphExecutor::InitCachedOps() { for (auto& nd : exec->out_array) { mutate_vars.push_back(nd.var()); } + if (exec->var() != nullptr) { + mutate_vars.push_back(exec->var()); + } // dedup vars Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars); // all vars include both mutate vars and use vars @@ -1100,16 +1103,15 @@ void GraphExecutor::InitOpSegs() { // Generate segments based on the graph structure bool prefer_bulk_exec_inference = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true); - if (prefer_bulk_exec_inference && num_forward_nodes_ == total_num_nodes) { - // bulk the whole graph for inference - cached_seg_opr_[0] = this->CreateCachedSegOpr(0, num_forward_nodes_); - return; - } - // Whether to perform bulk exec for training bool prefer_bulk_exec = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1); // The maximum number of node in a segment executed in bulk size_t num_nodes_threshold = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15); + if (prefer_bulk_exec_inference && num_forward_nodes_ == total_num_nodes) { + // bulk the whole graph for inference + num_nodes_threshold = std::numeric_limits::max(); + } + // create forward segments for training if (prefer_bulk_exec > 0) { size_t topo_start = 0; @@ -1119,7 +1121,7 @@ void GraphExecutor::InitOpSegs() { // check if the segment relies on external input, or exceeds maxinum number of node, // or requires async ops if (node->is_variable() || nid - topo_start > num_nodes_threshold || - op_node.exec->exec_type() != Operator::kSync) { + op_node.exec->exec_type() != ExecType::kSync) { // create a new segment for the previous nodes if the current one cannot be bulked cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); topo_start = nid + 1; @@ -1146,7 +1148,7 @@ void GraphExecutor::InitOpSegs() { continue; } if (idx[nid].source->is_variable() || nid - topo_start > 
num_nodes_threshold || - op_node.exec->exec_type() != Operator::kSync) { + op_node.exec->exec_type() != ExecType::kSync) { cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); topo_start = nid + 1; } else { @@ -1224,11 +1226,13 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { OpNode& opnode = op_nodes_[nid]; if (op_nodes_[nid].skip_exec_node) continue; opnode.exec->op_ctx.is_train = is_train; - if (opnode.exec->exec_type() == Operator::kCrossDeviceCopy) { + if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) { CHECK_EQ(inode.inputs.size(), 1U); CHECK_EQ(opnode.exec->in_array.size(), 1U); CHECK_EQ(opnode.exec->out_array.size(), 1U); CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0])); + } else if (opnode.exec->exec_type() == ExecType::kLocal) { + opnode.exec->Run(RunContext{opnode.ctx, nullptr}); } else if (opnode.cached_opr != nullptr) { #if MXNET_USE_PROFILER bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; @@ -1271,7 +1275,7 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, OpNode& op_node = op_nodes_[nid]; if (op_node.skip_exec_node) continue; if (inode.source->is_variable()) continue; - if (op_node.exec->exec_type() != Operator::kSync) { + if (op_node.exec->exec_type() != ExecType::kSync) { return ret; } if (pctx == nullptr) pctx = &(op_node.ctx); @@ -1283,7 +1287,7 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, std::inserter(mutate_vars, mutate_vars.end())); std::copy(op_node.use_vars.begin(), op_node.use_vars.end(), std::inserter(use_vars, use_vars.end())); - ret.exec_list.push_back(exec.get()); + ret.exec_list.push_back(exec); #if MXNET_USE_PROFILER opr_names += inode.source->op()->name + ","; #endif diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index d5a4e8c3aa6c..5b6fa395b242 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -21,9 +21,6 @@ namespace mxnet { -using NodeOperatorMap = std::unordered_map>; - // forward declaration namespace exec { class GraphExecutor; @@ -120,7 +117,7 @@ class GraphExecutor : public Executor { // the cached operator Engine::OprHandle opr = nullptr; // list of op executors - std::vector exec_list; + std::vector > exec_list; }; // Initialize in_args, arg_grads, and aux_states void InitArguments(const nnvm::IndexedGraph& idx, @@ -211,7 +208,7 @@ class GraphExecutor : public Executor { // number of forward nodes size_t num_forward_nodes_{0}; // saved operator for autograd - NodeOperatorMap saved_opr_; + std::unordered_map saved_states_; // monitor call back std::function monitor_callback_{nullptr}; // whether to enable bulk execution diff --git a/src/ndarray/autograd.cc b/src/ndarray/autograd.cc index ce1b98f095d8..b35364d0c70f 100644 --- a/src/ndarray/autograd.cc +++ b/src/ndarray/autograd.cc @@ -83,15 +83,15 @@ void AutogradRuntime::RecordImperativeFCompute(const nnvm::Op* op, const nnvm::NodeAttrs& attrs, std::vector *p_inputs, std::vector *p_outputs) { - RecordOp(op, attrs, p_inputs, p_outputs, nullptr); + RecordOp(op, attrs, p_inputs, p_outputs, OpStatePtr()); } -void AutogradRuntime::RecordImperativeOperator(const std::shared_ptr& opr, +void AutogradRuntime::RecordImperativeOperator(const OpStatePtr& state, const nnvm::Op* op, const nnvm::NodeAttrs& attrs, std::vector *p_inputs, std::vector *p_outputs) { - RecordOp(op, attrs, p_inputs, p_outputs, opr); + RecordOp(op, attrs, p_inputs, p_outputs, state); } 
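// --------------------------------------------------------------------------
// [Annotation, not part of the patch] RecordOp now attaches one OpStatePtr
// to the autograd node so that backward reuses the exact state instance the
// forward pass created (via the saved_states map). The sketch below is a
// minimal standalone model of such a handle, assuming it is essentially a
// type-erased shared pointer; OpStateModel and DemoState are invented names,
// not the real class from include/mxnet/op_attr_types.h.
#include <cassert>
#include <memory>
#include <utility>

class OpStateModel {
 public:
  template <typename T, typename... Args>
  static OpStateModel Create(Args&&... args) {
    OpStateModel s;
    s.ptr_ = std::make_shared<T>(std::forward<Args>(args)...);
    return s;
  }
  template <typename T>
  T& get_state() const { return *std::static_pointer_cast<T>(ptr_); }
  explicit operator bool() const { return ptr_ != nullptr; }
  void reset() { ptr_.reset(); }

 private:
  std::shared_ptr<void> ptr_;  // type-erased state, shared by fwd and bwd
};

struct DemoState { int counter = 0; };

inline void op_state_model_demo() {
  OpStateModel s = OpStateModel::Create<DemoState>();
  s.get_state<DemoState>().counter += 1;
  assert(s);    // a default handle, as passed for stateless ops, tests false
  s.reset();    // clear_history() drops the state exactly this way
  assert(!s);
}
// --------------------------------------------------------------------------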
std::shared_ptr AutogradRuntime::_GetSharedRef() { @@ -108,7 +108,7 @@ AGNodePtr AutogradRuntime::RecordOp(const nnvm::Op* op, const nnvm::NodeAttrs& attrs, std::vector *p_inputs, std::vector *p_outputs, - const std::shared_ptr& opr) { + const OpStatePtr& state) { std::vector& inputs = *p_inputs; std::vector& outputs = *p_outputs; @@ -117,7 +117,7 @@ AGNodePtr AutogradRuntime::RecordOp(const nnvm::Op* op, nn_node->attrs.name = "node_" + std::to_string(node_count_++); AGNodePtr ag_node = AGNode::Create(nn_node); - ag_node->opr = opr; + ag_node->state = state; for (uint32_t i = 0; i < outputs.size(); ++i) { CHECK(outputs[i].entry_.is_none()) @@ -167,13 +167,13 @@ void AutogradRuntime::ComputeGradient(const std::vector& outputs, std::vector args, args_grad; std::vector aux_states; std::vector grad_reqs; - std::unordered_map> saved_opr; + std::unordered_map saved_states; AGDFSVisit(heads, [&](const AGNodePtr& n) { if (n->nn_node->is_variable()) { vlist.push_back(n); } else { - if (n->opr != nullptr) { - saved_opr.insert({n->nn_node.get(), n->opr}); + if (n->state) { + saved_states.insert({n->nn_node.get(), n->state}); } if (fmutate_inputs.count(n->nn_node->op())) { for (uint32_t i : fmutate_inputs[n->nn_node->op()](n->nn_node->attrs)) { @@ -203,7 +203,7 @@ void AutogradRuntime::ComputeGradient(const std::vector& outputs, std::map ctx_map; auto exec = new exec::GraphExecutor(); // (TODO) too hack here - exec->saved_opr_ = saved_opr; + exec->saved_states_ = saved_states; exec->Init(sym, args[0].ctx(), ctx_map, args, args_grad, grad_reqs, aux_states, nullptr, feed_dict); diff --git a/src/ndarray/autograd.h b/src/ndarray/autograd.h index e6868064ca0d..baf843dbd4e1 100644 --- a/src/ndarray/autograd.h +++ b/src/ndarray/autograd.h @@ -25,7 +25,7 @@ class AGNode { public: OpReqType grad_req; nnvm::NodePtr nn_node; - std::shared_ptr opr; + OpStatePtr state; std::vector inputs; std::vector outputs; std::vector out_grads; @@ -40,7 +40,7 @@ class AGNode { void clear_history() { if (out_grads.size()) return; - opr.reset(); + state.reset(); outputs.clear(); nn_node.reset(); for (auto& i : inputs) i.ag_node->clear_history(); @@ -73,7 +73,7 @@ class AutogradRuntime { std::vector* p_inputs, std::vector* p_outputs); /*! \brief record imperative operator which is executed by operator. */ - void RecordImperativeOperator(const std::shared_ptr& opr, + void RecordImperativeOperator(const OpStatePtr& state, const nnvm::Op* op, const nnvm::NodeAttrs& attrs, std::vector* p_inputs, @@ -103,7 +103,7 @@ class AutogradRuntime { const nnvm::NodeAttrs& attrs, std::vector* p_inputs, std::vector* p_outputs, - const std::shared_ptr& opr); + const OpStatePtr& state); /*! \brief AutogradRuntime singleton. */ static AutogradRuntime* instance_; /*! \brief indicate whether is training. 
*/ diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 6f1795d6f368..9999f9c8307b 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -757,8 +757,7 @@ void NDArray::SyncCopyFromCPU(const void *data, size_t size) const { if (this->ctx().dev_mask() == cpu::kDevMask) { this->WaitToWrite(); - RunContext rctx; - rctx.stream = nullptr; + RunContext rctx{this->ctx(), nullptr}; TBlob dst = this->data(); ndarray::Copy(src, &dst, Context::CPU(), Context::CPU(), rctx); } else { @@ -786,8 +785,7 @@ void NDArray::SyncCopyToCPU(void *data, size_t size) const { if (this->ctx().dev_mask() == cpu::kDevMask) { this->WaitToRead(); - RunContext rctx; - rctx.stream = nullptr; + RunContext rctx{this->ctx(), nullptr}; ndarray::Copy(this->data(), &dst, Context::CPU(), Context::CPU(), rctx); } else { diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index ff6702f2f41b..13d36a2c4293 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -20,7 +20,7 @@ void Copy(const TBlob &from, TBlob *to, MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { mshadow::Copy(to->FlatTo1D(), from.FlatTo1D(), - static_cast*>(ctx.stream)); + ctx.get_stream()); }); } @@ -33,7 +33,7 @@ void Copy(const TBlob &from, TBlob *to, MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { mshadow::Copy(to->FlatTo1D(), from.FlatTo1D(), - static_cast*>(ctx.stream)); + ctx.get_stream()); }); } @@ -42,7 +42,7 @@ void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx) { if (from_ctx.dev_id == to_ctx.dev_id) { - mshadow::Stream* s = static_cast*>(ctx.stream); + mshadow::Stream* s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { if (to->type_flag_ == from.type_flag_) { mshadow::Copy(to->FlatTo1D(s), @@ -60,7 +60,7 @@ void Copy(const TBlob &from, TBlob *to, << "copy across only support continugous memory"; CHECK_EQ(to->type_flag_, from.type_flag_) << "Source and target must have the same data type when copying across devices."; - mshadow::Stream *s = static_cast*>(ctx.stream); + mshadow::Stream *s = ctx.get_stream(); CHECK(s != NULL) << "need stream in GPU context"; cudaMemcpyPeerAsync(to->dptr_, to_ctx.dev_id, diff --git a/src/nnvm/legacy_op_util.cc b/src/nnvm/legacy_op_util.cc index 9b39794b4782..2bba5f1c3655 100644 --- a/src/nnvm/legacy_op_util.cc +++ b/src/nnvm/legacy_op_util.cc @@ -54,6 +54,97 @@ class ParsedOpProp { } }; +class OperatorState { + public: + OperatorState(Operator *opr, const OperatorProperty *prop) { + opr_ = opr; + fwd_init_ = bwd_init_ = false; + + in_data_.resize(prop->ListArguments().size()); + out_data_.resize(prop->NumOutputs()); + aux_data_.resize(prop->ListAuxiliaryStates().size()); + in_grad_.resize(in_data_.size()); + out_grad_.resize(prop->NumVisibleOutputs()); + + std::vector out_grad_ptr(out_grad_.size()); + for (size_t i = 0; i < out_grad_.size(); ++i) { + out_grad_ptr[i] = &out_grad_[i]; + } + std::vector in_data_ptr(in_data_.size()); + for (size_t i = 0; i < in_data_.size(); ++i) { + in_data_ptr[i] = &in_data_[i]; + } + std::vector out_data_ptr(out_data_.size()); + for (size_t i = 0; i < out_data_.size(); ++i) { + out_data_ptr[i] = &out_data_[i]; + } + arg_data_ptr_ = prop->BackwardInputs( + out_grad_ptr, in_data_ptr, out_data_ptr); + } + + ~OperatorState() { delete opr_; } + + void Forward(const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (!fwd_init_) { + CHECK_EQ(inputs.size(), in_data_.size() + aux_data_.size()); + 
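// [Annotation, not part of the patch] This lazy capture runs once, on the
// first Forward call: the flat `inputs` vector from the new FStatefulCompute
// interface is split back into the in_data_/aux_data_/out_data_ vectors that
// a legacy Operator expects, and the cached handles are reused on every
// subsequent call. Backward below repeats the trick through arg_data_ptr_,
// which BackwardInputs laid out in the constructor.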
CHECK_EQ(outputs.size(), out_data_.size()); + for (size_t i = 0; i < in_data_.size(); ++i) in_data_[i] = inputs[i]; + for (size_t i = 0; i < aux_data_.size(); ++i) { + aux_data_[i] = inputs[i + in_data_.size()]; + } + for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i]; + fwd_init_ = true; + } + opr_->Forward(ctx, in_data_, req, out_data_, aux_data_); + } + + void Backward(const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (!bwd_init_) { + CHECK(fwd_init_); + CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size()); + for (size_t i = 0; i < arg_data_ptr_.size(); ++i) { + *arg_data_ptr_[i] = inputs[i]; + } + for (size_t i = 0; i < aux_data_.size(); ++i) { + aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i]; + } + CHECK_EQ(outputs.size(), in_grad_.size()); + for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i]; + bwd_init_ = true; + } + opr_->Backward(ctx, out_grad_, in_data_, out_data_, req, in_grad_, aux_data_); + } + + private: + Operator *opr_; + bool fwd_init_, bwd_init_; + std::vector in_data_, aux_data_, out_data_, in_grad_, out_grad_; + std::vector arg_data_ptr_; +}; + +void LegacyOpForward(const OpStatePtr& state, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + auto& op = state.get_state(); + op.Forward(ctx, inputs, req, outputs); +} + +void LegacyOpBackward(const OpStatePtr& state, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + auto& op = state.get_state(); + op.Backward(ctx, inputs, req, outputs); +} // function to use operator property to infer attr // get op property from the attribute @@ -182,14 +273,15 @@ std::vector OpBackResourceRequest(const NodeAttrs& attrs) { return prop.ptr->BackwardResource(ishape); } -Operator* OpPropCreateLayerOp(const NodeAttrs& attrs, - Context ctx, - const std::vector& ishape, - const std::vector& itype) { +OpStatePtr OpPropCreateLayerOp(const NodeAttrs& attrs, + Context ctx, + const std::vector& ishape, + const std::vector& itype) { auto& prop = nnvm::get(attrs.parsed); std::vector is(ishape.begin(), ishape.begin() + prop.arguments.size()); std::vector it(itype.begin(), itype.begin() + prop.arguments.size()); - return prop.ptr->CreateOperatorEx(ctx, &is, &it); + return OpStatePtr::Create(prop.ptr->CreateOperatorEx(ctx, &is, &it), + prop.ptr.get()); } inline std::vector OpPropGradient( @@ -300,6 +392,11 @@ std::vector > OpBackInplaceOption(const NodeAttrs& attrs) { return remap; } +inline ExecType OpExecType(const NodeAttrs& attrs) { + auto& prop = nnvm::get(attrs.parsed); + return prop.ptr->exec_type(); +} + // register the legacy operator properties under NNVM registry. 
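// --------------------------------------------------------------------------
// [Annotation, not part of the patch] The adapters above recover the concrete
// OperatorState from the type-erased handle and forward into the old
// Operator. In the registrations below, the adapter is installed once per
// device key; the sketch models why the same function can back both keys,
// on the assumption (consistent with the GetFCompute lookups earlier in this
// patch) that the keys are "FStatefulCompute<cpu>" and
// "FStatefulCompute<gpu>". OpModel and RegisterLegacyModel are invented.
#include <functional>
#include <map>
#include <string>
#include <utility>

struct OpModel {
  std::map<std::string, std::function<void()>> attrs;
  OpModel& set_attr(const std::string& key, std::function<void()> fn) {
    attrs[key] = std::move(fn);
    return *this;  // chaining, in the style of the registrations below
  }
};

// The legacy adapter is device-agnostic at this layer (the state object
// already knows its context), so one function serves both devices.
inline void RegisterLegacyModel(OpModel* op, std::function<void()> forward) {
  op->set_attr("FStatefulCompute<cpu>", forward)
     .set_attr("FStatefulCompute<gpu>", forward);
}
// --------------------------------------------------------------------------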
void RegisterLegacyOpProp() { for (auto reg : dmlc::Registry::List()) { @@ -328,10 +425,14 @@ void RegisterLegacyOpProp() { op.set_attr("FMutateInputs", OpPropMutateInputs); op.set_attr("FInplaceOption", OpPropInplaceOption); op.set_attr("FResourceRequest", OpPropResourceRequest); - op.set_attr("FCreateLayerOp", OpPropCreateLayerOp); + op.set_attr("FExecType", OpExecType); + op.set_attr("FCreateOpState", OpPropCreateLayerOp); + op.set_attr("FStatefulCompute", LegacyOpForward); + op.set_attr("FStatefulCompute", LegacyOpForward); if (reg->key_var_num_args.length() != 0) { op.set_attr("key_var_num_args", reg->key_var_num_args); } + // register BackwardOps std::string back_op_name = "_backward_" + reg->name; Op& back_op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER__(back_op_name); @@ -348,6 +449,9 @@ void RegisterLegacyOpProp() { "FResourceRequest", OpBackResourceRequest); back_op.set_attr("TIsLayerOpBackward", true); back_op.set_attr("TIsBackward", true); + back_op.set_attr("FExecType", OpExecType); + back_op.set_attr("FStatefulCompute", LegacyOpBackward); + back_op.set_attr("FStatefulCompute", LegacyOpBackward); } } diff --git a/src/operator/activation.cc b/src/operator/activation.cc index 0b1562925398..c8b8c3b5acb2 100644 --- a/src/operator/activation.cc +++ b/src/operator/activation.cc @@ -55,10 +55,6 @@ Operator *CreateOp(ActivationParam param, int dtype) { // DO_BIND_DISPATCH comes from operator_common.h Operator *ActivationProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/batch_norm.cc b/src/operator/batch_norm.cc index 0ef5733f9f8c..1bc6fd08e2ea 100644 --- a/src/operator/batch_norm.cc +++ b/src/operator/batch_norm.cc @@ -315,21 +315,11 @@ Operator *CreateOp(BatchNormParam param, const int dtype, const TShape& sha break; } } -#define BATCHNORM_LOG_MKL_INFO() \ - do { \ - if (!mxnet::op::batchnorm::disable_mkl) { \ - LOG(INFO) << MKLBatchNormOp::getName() \ - << " Skipping MKL optimization (unsupported dimension, axis or type)"; \ - } \ - } while (0) -#else -#define BATCHNORM_LOG_MKL_INFO() ((void)0) #endif if (!op) { MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, { - BATCHNORM_LOG_MKL_INFO(); op = new BatchNormOp(param); }); } return op; @@ -338,11 +328,6 @@ Operator *CreateOp(BatchNormParam param, const int dtype, const TShape& sha // DO_BIND_DISPATCH comes from operator_common.h Operator *BatchNormProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - CHECK_GE(in_shape->size(), 1U); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_shape)[0]); } @@ -415,4 +400,3 @@ NNVM_REGISTER_OP(BatchNorm) } // namespace op } // namespace mxnet - diff --git a/src/operator/bilinear_sampler.cc b/src/operator/bilinear_sampler.cc index 7cc94c50982f..fd2bff824fd7 100644 --- a/src/operator/bilinear_sampler.cc +++ b/src/operator/bilinear_sampler.cc @@ -142,10 +142,6 @@ Operator* CreateOp(BilinearSamplerParam param, int dtype) { Operator *BilinearSamplerProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector 
out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/convolution.cc b/src/operator/convolution.cc index b8fc49021d77..fd604d90c546 100644 --- a/src/operator/convolution.cc +++ b/src/operator/convolution.cc @@ -44,7 +44,6 @@ Operator* CreateOp(ConvolutionParam param, int dtype, break; } } - LOG(INFO) << MKLConvolutionOp::getName() << " Skip MKL optimization"; #endif #if MXNET_USE_NNPACK == 1 const size_t batch_size = (*in_shape)[0][0]; @@ -72,8 +71,6 @@ Operator *ConvolutionProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], in_shape, &out_shape, ctx); } diff --git a/src/operator/cross_device_copy.cc b/src/operator/cross_device_copy.cc index ce618c97fa05..a9a5f475f0bc 100644 --- a/src/operator/cross_device_copy.cc +++ b/src/operator/cross_device_copy.cc @@ -20,12 +20,6 @@ class CrossDeviceCopyOp : public Operator { // We still re-use things such as InferShape in OperatorProperty LOG(FATAL) << "Not Reached"; } - - ExecType exec_type() const override { - // TODO(tianqi) Think of other way to blend cross device op into operator interface. - // declare the op as cross device, - return kCrossDeviceCopy; - } }; class CrossDeviceCopyProp : public OperatorProperty { @@ -58,6 +52,12 @@ class CrossDeviceCopyProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const override { return new CrossDeviceCopyOp(); } + + ExecType exec_type() const override { + // TODO(tianqi) Think of other way to blend cross device op into operator interface. 
+ // declare the op as cross device, + return ExecType::kCrossDeviceCopy; + } }; diff --git a/src/operator/custom/custom-inl.h b/src/operator/custom/custom-inl.h index f640c3abd7a6..3c688feb05a1 100644 --- a/src/operator/custom/custom-inl.h +++ b/src/operator/custom/custom-inl.h @@ -25,273 +25,33 @@ namespace mxnet { namespace op { +namespace custom { -struct CustomOpParam { - std::string op_type; - std::vector > kwargs; -}; - -template -class CustomOp : public Operator { - public: - explicit CustomOp(MXCallbackList* op_info) { - op_info_.reset(op_info, [](MXCallbackList *ptr){ - reinterpret_cast(ptr->callbacks[kCustomOpDelete])( - ptr->contexts[kCustomOpDelete]); - delete ptr; - }); - if (std::string("NaiveEngine") == dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) { - sync_mode_ = true; - } else { - sync_mode_ = false; - destructing_ = false; - worker_ = std::thread([&]() { - std::unique_lock lock(mtx_); - while (!q_.empty() || !destructing_) { - cv_.wait(lock, [&] {return !q_.empty() || destructing_;}); - while (!q_.empty()) { - q_.front()(); - q_.pop(); - } - } - }); - } - } - - ~CustomOp() { - if (!sync_mode_) { - { - std::unique_lock lock(mtx_); - destructing_ = true; - cv_.notify_all(); - } - worker_.join(); - } - } - - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args); - - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args); - - virtual ExecType exec_type() const { - return kAsync; - } - - private: - Context get_ctx(); - std::shared_ptr op_info_; - std::mutex mtx_; - std::condition_variable cv_; - std::thread worker_; - std::queue > q_; - bool destructing_; - bool sync_mode_; -}; // CustomOp - -template -Operator* CreateOp(MXCallbackList *op_info); - -class CustomOpProp : public OperatorProperty { +class Registry { public: - static void Register(const std::string &op_type, CustomOpPropCreator creator) { + void Register(const std::string &op_type, CustomOpPropCreator creator) { + std::lock_guard lock(mutex_); if (registry_.find(op_type) != registry_.end()) { LOG(WARNING) << "New registration is overriding existing custom operator " << op_type; } registry_[op_type] = creator; } - void Init(const std::vector >& kwargs) override { - kwargs_ = kwargs; - param_.op_type = ""; - param_.kwargs.clear(); - std::vector keys, vals; - for (auto &p : kwargs) { - if (p.first == "op_type") { - param_.op_type = p.second; - } else { - param_.kwargs.push_back(p); - keys.push_back(p.first.c_str()); - vals.push_back(p.second.c_str()); - } - } - CHECK_NE(param_.op_type, "") << "Custom operator type missing"; - CHECK(registry_.find(param_.op_type) != registry_.end()) - << "Cannot find custom operator type " << param_.op_type; - CustomOpPropCreator creator = registry_[param_.op_type]; - info_.reset(new MXCallbackList, [](MXCallbackList* ptr){ - reinterpret_cast(ptr->callbacks[kCustomOpPropDelete])( - ptr->contexts[kCustomOpPropDelete]); - delete ptr; - }); - CHECK(creator(param_.op_type.c_str(), keys.size(), keys.data(), vals.data(), info_.get())); - num_inputs_ = ListArguments().size(); - num_outputs_ = ListOutputs().size(); - num_auxs_ = ListAuxiliaryStates().size(); - } - - std::vector ListArguments() const override { - char ** args = NULL; - CHECK(reinterpret_cast(info_->callbacks[kCustomOpPropListArguments])( - &args, 
info_->contexts[kCustomOpPropListArguments])); - std::vector ret; - for (int i = 0; args[i] != NULL; ++i) { - ret.push_back(args[i]); - } - return ret; - } - - std::vector ListOutputs() const override { - char ** args = NULL; - CHECK(reinterpret_cast(info_->callbacks[kCustomOpPropListOutputs])( - &args, info_->contexts[kCustomOpPropListOutputs])); - std::vector ret; - for (int i = 0; args[i] != NULL; ++i) { - ret.push_back(args[i]); - } - return ret; - } - - std::vector ListAuxiliaryStates() const override { - char ** args = NULL; - CHECK(reinterpret_cast(info_->callbacks[kCustomOpPropListAuxiliaryStates])( - &args, info_->contexts[kCustomOpPropListAuxiliaryStates])); - std::vector ret; - for (int i = 0; args[i] != NULL; ++i) { - ret.push_back(args[i]); - } - return ret; - } - - int NumOutputs() const override { - return ListOutputs().size(); - } - - std::map GetParams() const override { - return std::map(kwargs_.begin(), kwargs_.end()); + CustomOpPropCreator Find(const std::string &op_type) { + std::lock_guard lock(mutex_); + auto it = registry_.find(op_type); + if (it != registry_.end()) return it->second; + return nullptr; } - - bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { - std::vector shapes; - std::vector ndims; - size_t size = 0; - for (const auto& s : *in_shape) size += s.ndim(); - std::vector shapes_buffer(size); - shapes_buffer.resize(size); - uint32_t *ptr = shapes_buffer.data(); - for (auto iter = in_shape->begin(); iter != in_shape->end(); ++iter) { - shapes.push_back(ptr); - ndims.push_back(iter->ndim()); - ptr = nnvm::ShapeTypeCast(iter->begin(), iter->end(), ptr); - } - shapes.resize(num_inputs_+num_outputs_+num_auxs_); - ndims.resize(num_inputs_+num_outputs_+num_auxs_); - - CHECK(reinterpret_cast(info_->callbacks[kCustomOpPropInferShape])( - shapes.size(), ndims.data(), shapes.data(), info_->contexts[kCustomOpPropInferShape])); - for (unsigned i = 0; i < in_shape->size(); ++i) { - SHAPE_ASSIGN_CHECK(*in_shape, i, TShape(shapes[i], shapes[i]+ndims[i])); - } - out_shape->clear(); - for (unsigned i = num_inputs_; i < num_inputs_+num_outputs_; ++i) { - out_shape->push_back(TShape(shapes[i], shapes[i]+ndims[i])); - } - aux_shape->clear(); - for (unsigned i = num_inputs_+num_outputs_; i < shapes.size(); ++i) { - aux_shape->push_back(TShape(shapes[i], shapes[i]+ndims[i])); - } - return true; - } - - bool InferType(std::vector *in_type, - std::vector *out_type, - std::vector *aux_type) const override { - if (info_->num_callbacks <= kCustomOpPropInferType) { - return OperatorProperty::InferType(in_type, out_type, aux_type); - } - - std::vector types; - for (const auto &i : *in_type) types.push_back(i); - for (const auto &i : *out_type) types.push_back(i); - for (const auto &i : *aux_type) types.push_back(i); - - CHECK(reinterpret_cast(info_->callbacks[kCustomOpPropInferType])( - types.size(), types.data(), info_->contexts[kCustomOpPropInferType])); - for (unsigned i = 0; i < num_inputs_; ++i) { - TYPE_ASSIGN_CHECK(*in_type, i, types[i]); - } - for (unsigned i = 0; i < num_outputs_; ++i) { - TYPE_ASSIGN_CHECK(*out_type, i, types[i+num_inputs_]); - } - for (unsigned i = 0; i < num_auxs_; ++i) { - TYPE_ASSIGN_CHECK(*aux_type, i, types[i+num_inputs_+num_outputs_]); - } - return true; - } - - - OperatorProperty* Copy() const override { - CustomOpProp *prop_sym = new CustomOpProp(); - prop_sym->Init(kwargs_); - return prop_sym; - } - - std::string TypeString() const override { - return "Custom"; - } - - std::vector 
DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { - int num_dep; - int *rdeps; - CHECK(reinterpret_cast( - info_->callbacks[kCustomOpPropDeclareBackwardDependency])( - out_grad.data(), in_data.data(), out_data.data(), &num_dep, - &rdeps, info_->contexts[kCustomOpPropDeclareBackwardDependency])); - std::vector deps; - deps.insert(deps.end(), rdeps, rdeps+num_dep); - return deps; - } - - std::vector > BackwardInplaceOption( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &in_grad) const override { - return {}; - } - - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not Implemented."; - return NULL; - } - - Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const override; - + static Registry* Get(); private: - static std::map registry_; + Registry() {} + std::mutex mutex_; + std::map registry_; +}; - CustomOpParam param_; - std::shared_ptr info_; - std::vector > kwargs_; - unsigned num_inputs_, num_outputs_, num_auxs_; - mutable std::vector shapes_buffer_; -}; // class CustomOpProp +} // namespace custom } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_CUSTOM_CUSTOM_INL_H_ diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc index 8fb324c1f5c2..1854bb7f05d0 100644 --- a/src/operator/custom/custom.cc +++ b/src/operator/custom/custom.cc @@ -8,196 +8,387 @@ #include #include +#include "../../ndarray/autograd.h" +#include "../elemwise_op_common.h" + namespace mxnet { namespace op { -std::map CustomOpProp::registry_; +namespace custom { + +Registry* Registry::Get() { + static Registry inst; + return &inst; +} + +struct CustomParam { + std::string op_type; + size_t num_args, num_outs, num_auxs; + std::vector bwd_idx; + std::shared_ptr info; +}; + + +template +std::vector List(const NodeAttrs& attrs) { + const CustomParam& params = nnvm::get(attrs.parsed); + char ** args = NULL; + CHECK(reinterpret_cast( + params.info->callbacks[Type])( + &args, params.info->contexts[Type])); + std::vector ret; + for (int i = 0; args[i] != NULL; ++i) { + ret.push_back(args[i]); + } + return ret; +} + +void AttrParser(NodeAttrs* attrs) { + attrs->parsed = CustomParam(); + CustomParam& params = nnvm::get(attrs->parsed); + + std::vector keys, vals; + for (auto &p : attrs->dict) { + if (p.first == "op_type") { + params.op_type = p.second; + } else { + keys.push_back(p.first.c_str()); + vals.push_back(p.second.c_str()); + } + } + CHECK(!params.op_type.empty()) << "Required argument `op_type` is missing."; + CustomOpPropCreator creator = Registry::Get()->Find(params.op_type); + CHECK(Registry::Get()->Find(params.op_type) != nullptr) + << "Cannot find custom operator " << params.op_type; + params.info.reset(new MXCallbackList, [](MXCallbackList* ptr){ + reinterpret_cast(ptr->callbacks[kCustomOpPropDelete])( + ptr->contexts[kCustomOpPropDelete]); + delete ptr; + }); + CHECK(creator(params.op_type.c_str(), keys.size(), keys.data(), + vals.data(), params.info.get())); + + params.num_args = List(*attrs).size(); + params.num_outs = List(*attrs).size(); + params.num_auxs = List(*attrs).size(); + + int num_dep, *rdeps, counter = 0; + std::vector out_grad, in_data, out_data; + for (size_t i = 0; i < params.num_outs; ++i) out_grad.push_back(counter++); + for (size_t i = 0; i < params.num_args; ++i) in_data.push_back(counter++); + for (size_t i = 0; i < params.num_outs; ++i) 
out_data.push_back(counter++); + CHECK(reinterpret_cast( + params.info->callbacks[kCustomOpPropDeclareBackwardDependency])( + out_grad.data(), in_data.data(), out_data.data(), &num_dep, + &rdeps, params.info->contexts[kCustomOpPropDeclareBackwardDependency])); + params.bwd_idx.insert(params.bwd_idx.end(), rdeps, rdeps+num_dep); +} + +bool InferShape(const NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + const CustomParam& params = nnvm::get(attrs.parsed); + + size_t total = params.num_args + params.num_outs + params.num_auxs; + std::vector shapes(total); + std::vector ndims(total); + size_t buff_size = 0; + for (const auto& i : *in_shape) buff_size += i.ndim(); + std::vector buff(buff_size); + uint32_t *ptr = buff.data(); + for (size_t i = 0; i < in_shape->size(); ++i) { + shapes[i] = ptr; + ndims[i] = (*in_shape)[i].ndim(); + for (size_t j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) { + *ptr = static_cast((*in_shape)[i][j]); + } + } + + CHECK(reinterpret_cast( + params.info->callbacks[kCustomOpPropInferShape])( + shapes.size(), ndims.data(), shapes.data(), + params.info->contexts[kCustomOpPropInferShape])); + + for (size_t i = 0; i < params.num_args; ++i) { + SHAPE_ASSIGN_CHECK(*in_shape, i, TShape(shapes[i], shapes[i]+ndims[i])); + } -template<> -Context CustomOp::get_ctx() { - return Context::CPU(); + size_t base = params.num_args; + for (size_t i = 0; i < params.num_outs; ++i) { + SHAPE_ASSIGN_CHECK(*out_shape, i, + TShape(shapes[base+i], shapes[base+i]+ndims[base+i])); + } + + base = params.num_args + params.num_outs; + for (size_t i = 0; i < params.num_auxs; ++i) { + SHAPE_ASSIGN_CHECK(*in_shape, params.num_args+i, + TShape(shapes[base+i], shapes[base+i]+ndims[base+i])); + } + return true; } -template<> -Operator *CreateOp(MXCallbackList *op_info) { - return new CustomOp(op_info); +bool InferType(const NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const CustomParam& params = nnvm::get(attrs.parsed); + + if (params.info->num_callbacks <= kCustomOpPropInferType) { + return ElemwiseAttr( + attrs, in_type, out_type, -1); + } + + std::vector types; + types.reserve(params.num_args + params.num_outs + params.num_auxs); + for (size_t i = 0; i < params.num_args; ++i) { + types.push_back((*in_type)[i]); + } + for (const auto &i : *out_type) { + types.push_back(i); + } + for (size_t i = 0; i < params.num_auxs; ++i) { + types.push_back((*in_type)[params.num_args+i]); + } + + CHECK(reinterpret_cast( + params.info->callbacks[kCustomOpPropInferType])( + types.size(), types.data(), params.info->contexts[kCustomOpPropInferType])); + + for (size_t i = 0; i < params.num_args; ++i) { + TYPE_ASSIGN_CHECK(*in_type, i, types[i]); + } + for (size_t i = 0; i < params.num_outs; ++i) { + TYPE_ASSIGN_CHECK(*out_type, i, types[params.num_args+i]); + } + for (size_t i = 0; i < params.num_auxs; ++i) { + TYPE_ASSIGN_CHECK(*in_type, params.num_args+i, + types[params.num_args+params.num_outs+i]); + } + return true; } -#if MXNET_USE_CUDA -template<> -Context CustomOp::get_ctx() { - int dev_id; - CHECK_EQ(cudaGetDevice(&dev_id), cudaSuccess); - return Context::GPU(dev_id); +std::vector Gradient( + const nnvm::NodePtr& n, + const std::vector& out_grads) { + const CustomParam& params = nnvm::get(n->attrs.parsed); + + nnvm::NodePtr g = nnvm::Node::Create(); + g->attrs.op = nnvm::Op::Get("_backward_Custom"); + g->attrs.name = n->attrs.name; + g->attrs.parsed = params; + g->control_deps.emplace_back(n); + + g->inputs.reserve(params.bwd_idx.size()); + for (const int& t : 
params.bwd_idx) { + size_t i = static_cast(t); + if (i >= params.num_outs + params.num_args) { + uint32_t idx = static_cast(i-params.num_outs-params.num_args); + g->inputs.push_back(nnvm::NodeEntry{n, idx, 0}); + } else if (i >= params.num_outs) { + g->inputs.push_back(n->inputs[i-params.num_outs]); + } else { + g->inputs.push_back(out_grads[i]); + } + } + + for (size_t i = 0; i < params.num_auxs; ++i) { + g->inputs.push_back(n->inputs[i+params.num_args]); + } + + std::vector ret; + for (index_t i = 0; i < g->num_outputs(); ++i) { + ret.emplace_back(nnvm::NodeEntry{g, i, 0}); + } + + return ret; } -template<> -Operator* CreateOp(MXCallbackList *op_info) { - return new CustomOp(op_info); + +OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx, + const std::vector& in_shape, + const std::vector& in_type) { + const CustomParam& params = nnvm::get(attrs.parsed); + + size_t total = params.num_args + params.num_outs + params.num_auxs; + std::vector shapes(total); + std::vector ndims(total); + size_t buff_size = 0; + for (const auto& i : in_shape) buff_size += i.ndim(); + std::vector buff(buff_size); + uint32_t *ptr = buff.data(); + for (size_t i = 0; i < in_shape.size(); ++i) { + shapes[i] = ptr; + ndims[i] = in_shape[i].ndim(); + for (size_t j = 0; j < in_shape[i].ndim(); ++j, ++ptr) { + *ptr = static_cast(in_shape[i][j]); + } + } + + std::string str_ctx; + if (ctx.dev_mask() == cpu::kDevMask) { + str_ctx = "cpu"; + } else { + str_ctx = "gpu"; + } + + MXCallbackList *op_info = new MXCallbackList; + CHECK(reinterpret_cast( + params.info->callbacks[kCustomOpPropCreateOperator])( + str_ctx.c_str(), shapes.size(), shapes.data(), ndims.data(), in_type.data(), + op_info, params.info->contexts[kCustomOpPropCreateOperator])); + + CustomParam state = params; + state.info.reset(op_info, [](MXCallbackList *ptr){ + reinterpret_cast(ptr->callbacks[kCustomOpDelete])( + ptr->contexts[kCustomOpDelete]); + delete ptr; + }); + + return OpStatePtr::Create(state); } -#endif // MXNET_USE_CUDA - -template -void CustomOp::Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - Context ndctx = get_ctx(); + +void Forward(const OpStatePtr& state, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const CustomParam& params = state.get_state(); std::vector ptrs; - std::vector ndcpy; - std::vector ndvar; std::vector tags; - std::vector reqs(req.begin(), req.end()); - for (auto& blob : in_data) { - ptrs.push_back(reinterpret_cast(new NDArray(blob, ndctx.dev_id))); + for (size_t i = 0; i < params.num_args; ++i) { + NDArray *nd = new NDArray(inputs[i].Detach()); + ptrs.push_back(reinterpret_cast(nd)); tags.push_back(0); } - for (auto& blob : out_data) { - NDArray* nd = new NDArray(blob, ndctx.dev_id); + + for (size_t i = 0; i < params.num_outs; ++i) { + NDArray *nd = new NDArray(outputs[i].Detach()); ptrs.push_back(reinterpret_cast(nd)); - ndcpy.push_back(*nd); - ndvar.push_back(nd->var()); tags.push_back(1); } - for (auto& blob : aux_args) { - NDArray* nd = new NDArray(blob, ndctx.dev_id); + + for (size_t i = 0; i < params.num_auxs; ++i) { + NDArray *nd = new NDArray(inputs[i+params.num_args].Detach()); ptrs.push_back(reinterpret_cast(nd)); - ndcpy.push_back(*nd); - ndvar.push_back(nd->var()); tags.push_back(4); } - std::sort(ndvar.begin(), ndvar.end()); - ndvar.resize(std::unique(ndvar.begin(), ndvar.end()) - ndvar.begin()); - auto 
compute = [=]() mutable { - CHECK(reinterpret_cast(op_info_->callbacks[kCustomOpForward])( - ptrs.size(), ptrs.data(), tags.data(), reqs.data(), - static_cast(ctx.is_train), op_info_->contexts[kCustomOpForward])); + bool old = autograd::AutogradRuntime::Get()->SetIsTraining(false); - // NDArray* in ptrs is freed by frontend side. We keep a copy in ndcpy to keep ndvar alive - Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx) { - ctx.async_on_complete(); - }, ndctx, ndvar, {}, - FnProperty::kNormal, 0, PROFILER_MESSAGE("CustomOpForward")); - }; + CHECK(reinterpret_cast(params.info->callbacks[kCustomOpForward])( + ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast(req.data()), + static_cast(ctx.is_train), params.info->contexts[kCustomOpForward])); - if (sync_mode_) { - compute(); - } else { - std::unique_lock lock(mtx_); - q_.push(compute); - cv_.notify_all(); - } + autograd::AutogradRuntime::Get()->SetIsTraining(old); } -template -void CustomOp::Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { - using namespace mshadow; - Context ndctx = get_ctx(); - std::vector ptrs; - std::vector ndcpy; - std::vector ndvar; + +void Backward(const OpStatePtr& state, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const CustomParam& params = state.get_state(); + + size_t total = 2*params.num_args + 2*params.num_outs + params.num_auxs; + std::vector ptrs(params.num_args + 2*params.num_outs, nullptr); std::vector tags; - std::vector reqs(req.begin(), req.end()); + ptrs.reserve(total); + tags.reserve(total); + for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(3); + for (size_t i = 0; i < params.num_args; ++i) tags.push_back(0); + for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(1); - for (auto& blob : in_data) { - ptrs.push_back(reinterpret_cast(new NDArray(blob, ndctx.dev_id))); - tags.push_back(0); + for (size_t i = 0; i < params.bwd_idx.size(); ++i) { + NDArray *nd = new NDArray(inputs[i].Detach()); + ptrs[params.bwd_idx[i]] = reinterpret_cast(nd); } - for (auto& blob : out_data) { - ptrs.push_back(reinterpret_cast(new NDArray(blob, ndctx.dev_id))); - tags.push_back(1); + for (size_t i = 0; i < ptrs.size(); ++i) { + if (ptrs[i] == nullptr) ptrs[i] = reinterpret_cast(new NDArray()); } - for (auto& blob : in_grad) { - NDArray* nd = new NDArray(blob, ndctx.dev_id); + for (const auto& i : outputs) { + NDArray* nd = new NDArray(i.Detach()); ptrs.push_back(reinterpret_cast(nd)); - ndcpy.push_back(*nd); - ndvar.push_back(nd->var()); tags.push_back(2); } - for (auto& blob : aux_args) { - NDArray* nd = new NDArray(blob, ndctx.dev_id); + for (size_t i = 0; i < params.num_auxs; ++i) { + NDArray* nd = new NDArray(inputs[inputs.size()-params.num_auxs+i].Detach()); ptrs.push_back(reinterpret_cast(nd)); - ndcpy.push_back(*nd); - ndvar.push_back(nd->var()); tags.push_back(4); } - std::sort(ndvar.begin(), ndvar.end()); - ndvar.resize(std::unique(ndvar.begin(), ndvar.end()) - ndvar.begin()); - for (auto& blob : out_grad) { - ptrs.push_back(reinterpret_cast(new NDArray(blob, ndctx.dev_id))); - tags.push_back(3); - } - auto compute = [=]() mutable { - CHECK(reinterpret_cast(op_info_->callbacks[kCustomOpBackward])( - ptrs.size(), ptrs.data(), tags.data(), reqs.data(), 1, - op_info_->contexts[kCustomOpBackward])); + bool old = 
autograd::AutogradRuntime::Get()->SetIsTraining(false); - // NDArray* in ptrs is freed by frontend side. We keep a copy in ndcpy to keep ndvar alive - Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx){ - ctx.async_on_complete(); - }, ndctx, ndvar, {}, - FnProperty::kNormal, 0, PROFILER_MESSAGE("CustomOpBackward")); - }; + CHECK(reinterpret_cast(params.info->callbacks[kCustomOpBackward])( + ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast(req.data()), 1, + params.info->contexts[kCustomOpBackward])); - if (sync_mode_) { - compute(); - } else { - std::unique_lock lock(mtx_); - q_.push(compute); - cv_.notify_all(); - } + autograd::AutogradRuntime::Get()->SetIsTraining(old); } -Operator* CustomOpProp::CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const { - std::vector shapes; - std::vector ndims; - size_t size = 0; - for (const auto& s : *in_shape) size += s.ndim(); - shapes_buffer_.resize(size); - uint32_t *ptr = shapes_buffer_.data(); - for (auto iter = in_shape->begin(); iter != in_shape->end(); ++iter) { - shapes.push_back(ptr); - ndims.push_back(iter->ndim()); - ptr = nnvm::ShapeTypeCast(iter->begin(), iter->end(), ptr); - } - std::string str_ctx; - if (ctx.dev_mask() == cpu::kDevMask) { - str_ctx = "cpu"; - } else { - str_ctx = "gpu"; - } - MXCallbackList *op_info = new MXCallbackList; - CHECK(reinterpret_cast(info_->callbacks[kCustomOpPropCreateOperator])( - str_ctx.c_str(), shapes.size(), shapes.data(), ndims.data(), in_type->data(), op_info, - info_->contexts[kCustomOpPropCreateOperator])); - DO_BIND_DISPATCH(CreateOp, op_info); -} -MXNET_REGISTER_OP_PROPERTY(Custom, CustomOpProp) +NNVM_REGISTER_OP(Custom) .describe(R"code(Apply a custom operator implemented in a frontend language (like Python). Custom operators should override required methods like `forward` and `backward`. The custom operator must be registered before it can be used. Please check the tutorial here: http://mxnet.io/how_to/new_op.html. -)code") +)code" ADD_FILELINE) +.set_num_inputs([](const NodeAttrs& attrs){ + const CustomParam& params = nnvm::get(attrs.parsed); + return params.num_args + params.num_auxs; + }) +.set_num_outputs([](const NodeAttrs& attrs){ + const CustomParam& params = nnvm::get(attrs.parsed); + return params.num_outs; + }) +.set_attr_parser(AttrParser) +.set_attr("FInferShape", InferShape) +.set_attr("FInferType", InferType) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + std::vector args = List(attrs); + std::vector auxs = List(attrs); + args.insert(args.end(), auxs.begin(), auxs.end()); + return args; + }) +.set_attr("FListOutputNames", List) +.set_attr("FMutateInputs", [](const NodeAttrs& attrs) { + const CustomParam& params = nnvm::get(attrs.parsed); + std::vector ret; + for (size_t i = 0; i < params.num_auxs; ++i) ret.push_back(i+params.num_args); + return ret; + }) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kLocal; + }) +.set_attr("FGradient", Gradient) +.set_attr("FCreateOpState", CreateState) +.set_attr("FStatefulComputeEx", Forward) +.set_attr("FStatefulComputeEx", Forward) .add_argument("data", "NDArray-or-Symbol[]", "Input data for the custom operator.") .add_argument("op_type", "string", "Name of the custom operator. 
" "This is the name that is passed to `mx.operator.register` " "to register the operator."); +NNVM_REGISTER_OP(_backward_Custom) +.set_num_inputs([](const NodeAttrs& attrs){ + const CustomParam& params = nnvm::get(attrs.parsed); + return params.bwd_idx.size(); + }) +.set_num_outputs([](const NodeAttrs& attrs){ + const CustomParam& params = nnvm::get(attrs.parsed); + return params.num_args; + }) +.set_attr("TIsLayerOpBackward", true) +.set_attr("TIsBackward", true) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kLocal; + }) +.set_attr("FStatefulComputeEx", Backward) +.set_attr("FStatefulComputeEx", Backward); + +} // namespace custom } // namespace op } // namespace mxnet diff --git a/src/operator/custom/ndarray_op-inl.h b/src/operator/custom/ndarray_op-inl.h index 05b1a3a902e8..fa4208f1da89 100644 --- a/src/operator/custom/ndarray_op-inl.h +++ b/src/operator/custom/ndarray_op-inl.h @@ -52,10 +52,6 @@ class NDArrayOp : public Operator { const std::vector &in_grad, const std::vector &aux_args); - virtual ExecType exec_type() const { - return kAsync; - } - private: NDArrayOpParam param_; Context get_ctx(); @@ -169,6 +165,10 @@ class NDArrayOpProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const override; + ExecType exec_type() const override { + return ExecType::kAsync; + } + private: NDArrayOpParam param_; }; // class PythonProp diff --git a/src/operator/deconvolution.cc b/src/operator/deconvolution.cc index 83af00903919..397bd0065f80 100644 --- a/src/operator/deconvolution.cc +++ b/src/operator/deconvolution.cc @@ -24,8 +24,6 @@ Operator* CreateOp(DeconvolutionParam param, int dtype, Operator* DeconvolutionProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0), in_shape, &out_shape, ctx); } diff --git a/src/operator/dropout-inl.h b/src/operator/dropout-inl.h index 47bb0a3dffd3..e77d61351be0 100644 --- a/src/operator/dropout-inl.h +++ b/src/operator/dropout-inl.h @@ -88,7 +88,7 @@ class DropoutOp : public Operator { Tensor out = out_data[dropout::kOut].FlatTo2D(s); if (ctx.is_train) { Tensor mask = out_data[dropout::kMask].FlatTo2D(s); -#if defined(USE_MKL) && defined(_OPENMP) +#if !defined(__CUDACC__) && defined(USE_MKL) && defined(_OPENMP) DType* outptr = out.dptr_; DType* dataptr = data.dptr_; int* maskptr = reinterpret_cast(mask.dptr_); @@ -124,7 +124,7 @@ class DropoutOp : public Operator { Tensor grad = out_grad[dropout::kOut].FlatTo2D(s); Tensor mask = out_data[dropout::kMask].FlatTo2D(s); Tensor gdata = in_grad[dropout::kData].FlatTo2D(s); -#if defined(USE_MKL) && defined(_OPENMP) +#if !defined(__CUDACC__) && defined(USE_MKL) && defined(_OPENMP) DType* ingradptr = gdata.dptr_; DType* outgradptr = grad.dptr_; int* maskptr = reinterpret_cast(mask.dptr_); diff --git a/src/operator/dropout.cc b/src/operator/dropout.cc index 20afef2c63c8..74a50baf80a4 100644 --- a/src/operator/dropout.cc +++ b/src/operator/dropout.cc @@ -21,10 +21,6 @@ Operator *CreateOp(DropoutParam param, int dtype) { // DO_BIND_DISPATCH comes from operator_common.h Operator *DropoutProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, 
&out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); } diff --git a/src/operator/fully_connected.cc b/src/operator/fully_connected.cc index cec2015425c6..56cf4f6dbdde 100644 --- a/src/operator/fully_connected.cc +++ b/src/operator/fully_connected.cc @@ -49,8 +49,6 @@ Operator* CreateOp(FullyConnectedParam param, int dtype, Operator *FullyConnectedProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { std::vector out_shape(1, TShape()), aux_shape; - std::vector out_type(1, -1), aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], in_shape, &out_shape, ctx); } diff --git a/src/operator/grid_generator.cc b/src/operator/grid_generator.cc index 8625d1ba971a..62ff75a88359 100644 --- a/src/operator/grid_generator.cc +++ b/src/operator/grid_generator.cc @@ -22,10 +22,6 @@ Operator* CreateOp(GridGeneratorParam param, int dtype) { Operator *GridGeneratorProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/instance_norm.cc b/src/operator/instance_norm.cc index bf3285a7a9d0..cc2bd6b93e8b 100644 --- a/src/operator/instance_norm.cc +++ b/src/operator/instance_norm.cc @@ -18,10 +18,6 @@ Operator* CreateOp(InstanceNormParam param, int dtype) { Operator* InstanceNormProp::CreateOperatorEx(Context ctx, std::vector* in_shape, std::vector* in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/lrn.cc b/src/operator/lrn.cc index e896e16b443a..ac4a309cbe05 100644 --- a/src/operator/lrn.cc +++ b/src/operator/lrn.cc @@ -28,10 +28,6 @@ Operator* CreateOp(LRNParam param, int dtype) { // DO_BIND_DISPATCH comes from operator_common.h Operator* LocalResponseNormProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } @@ -42,14 +38,14 @@ MXNET_REGISTER_OP_PROPERTY(LRN, LocalResponseNormProp) .add_arguments(LRNParam::__FIELDS__()) .describe(R"code(Applies local response normalization to the input. -The local response normalization layer performs "lateral inhibition" by normalizing -over local input regions. +The local response normalization layer performs "lateral inhibition" by normalizing +over local input regions. If :math:`a_{x,y}^{i}` is the activity of a neuron computed by applying kernel :math:`i` at position -:math:`(x, y)` and then applying the ReLU nonlinearity, the response-normalized -activity :math:`b_{x,y}^{i}` is given by the expression: +:math:`(x, y)` and then applying the ReLU nonlinearity, the response-normalized +activity :math:`b_{x,y}^{i}` is given by the expression: -.. math:: +.. 
math:: b_{x,y}^{i} = \frac{a_{x,y}^{i}}{\Bigg({k + \alpha \sum_{j=max(0, i-\frac{n}{2})}^{min(N-1, i+\frac{n}{2})} (a_{x,y}^{j})^{2}}\Bigg)^{\beta}} where the sum runs over :math:`n` "adjacent" kernel maps at the same spatial position, and :math:`N` is the total diff --git a/src/operator/pad.cc b/src/operator/pad.cc index ded48c99f608..5d1afca588fb 100644 --- a/src/operator/pad.cc +++ b/src/operator/pad.cc @@ -634,10 +634,6 @@ Operator *CreateOp(PadParam param, int dtype) { // DO_BIND_DISPATCH comes from operator_common.h Operator *PadProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/pooling.cc b/src/operator/pooling.cc index c66543d711bf..f26c2e8b199e 100644 --- a/src/operator/pooling.cc +++ b/src/operator/pooling.cc @@ -70,10 +70,6 @@ Operator *CreateOp(PoolingParam param, int dtype) { // DO_BIND_DISPATCH comes from operator_common.h Operator* PoolingProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index f43379fdd8dd..f19c3bbad04b 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -22,10 +22,6 @@ Operator *CreateOp(RNNParam param, int dtype) { Operator *RNNProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index 35fe94c33242..0faca1e463bc 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -217,10 +217,6 @@ Operator *CreateOp(ROIPoolingParam param, int dtype) { Operator *ROIPoolingProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); } diff --git a/src/operator/sequence_last.cc b/src/operator/sequence_last.cc index 7c796613efa8..8a50ff73ec64 100644 --- a/src/operator/sequence_last.cc +++ b/src/operator/sequence_last.cc @@ -20,10 +20,6 @@ Operator *CreateOp(SequenceLastParam param, int dtype) { Operator *SequenceLastProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/sequence_mask.cc b/src/operator/sequence_mask.cc index 763bc17171ae..0ac782e51c3c 100644 --- a/src/operator/sequence_mask.cc +++ b/src/operator/sequence_mask.cc @@ -33,10 +33,6 @@ Operator *CreateOp(SequenceMaskParam param, int dtype) { Operator *SequenceMaskProp::CreateOperatorEx(Context ctx, std::vector 
*in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/sequence_reverse.cc b/src/operator/sequence_reverse.cc index 871db9b3d486..01dcb6810e62 100644 --- a/src/operator/sequence_reverse.cc +++ b/src/operator/sequence_reverse.cc @@ -20,10 +20,6 @@ Operator *CreateOp(SequenceReverseParam param, int dtype) { Operator *SequenceReverseProp::CreateOperatorEx( Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 06225a3b0be7..08580e9328cd 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -20,10 +20,6 @@ Operator *CreateOp(SoftmaxOutputParam param, int dtype) { // DO_BIND_DISPATCH comes from operator_common.h Operator *SoftmaxOutputProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/spatial_transformer.cc b/src/operator/spatial_transformer.cc index fe91a143c23e..45c1d8588776 100644 --- a/src/operator/spatial_transformer.cc +++ b/src/operator/spatial_transformer.cc @@ -116,10 +116,6 @@ Operator* CreateOp(SpatialTransformerParam param, int dtype) { Operator *SpatialTransformerProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } diff --git a/src/operator/svm_output.cc b/src/operator/svm_output.cc index ead853e214b8..5f1f77ad9fc1 100644 --- a/src/operator/svm_output.cc +++ b/src/operator/svm_output.cc @@ -62,10 +62,6 @@ Operator *CreateOp(SVMOutputParam param, int dtype) { // DO_BIND_DISPATCH comes from operator_common.h Operator *SVMOutputProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } @@ -84,4 +80,3 @@ This tutorial demonstrates using SVM as output layer for classification instead } // namespace op } // namespace mxnet - diff --git a/src/operator/swapaxis.cc b/src/operator/swapaxis.cc index 24ea807ef9ce..097f9837025f 100644 --- a/src/operator/swapaxis.cc +++ b/src/operator/swapaxis.cc @@ -21,10 +21,6 @@ Operator* CreateOp(SwapAxisParam param, int dtype) { Operator* SwapAxisProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - CHECK(InferType(in_type, &out_type, &aux_type)); DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); } diff --git 
diff --git a/src/operator/upsampling.cc b/src/operator/upsampling.cc
index cc9861346825..ad89d4ace137 100644
--- a/src/operator/upsampling.cc
+++ b/src/operator/upsampling.cc
@@ -44,10 +44,6 @@ Operator *CreateOp(UpSamplingParam param, int dtype) {
 Operator* UpSamplingProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                                            std::vector<int> *in_type) const {
-  std::vector<TShape> out_shape, aux_shape;
-  std::vector<int> out_type, aux_type;
-  CHECK(InferType(in_type, &out_type, &aux_type));
-  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
   DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
 }
diff --git a/tests/cpp/include/test_op.h b/tests/cpp/include/test_op.h
index 2c96092db81c..57fda19e4c9e 100644
--- a/tests/cpp/include/test_op.h
+++ b/tests/cpp/include/test_op.h
@@ -17,8 +17,8 @@
  * test_perf.h: Performance-related classes
  * test_op.h: Operator-specific testing classes
  */
-#ifndef TESTS_CPP_INCLUDE_TEST_OP_H_
-#define TESTS_CPP_INCLUDE_TEST_OP_H_
+#ifndef TEST_OP_H_
+#define TEST_OP_H_
 
 #include "test_perf.h"
 #include "test_util.h"
@@ -705,4 +705,4 @@ static test::op::OpInfo createOpAndInfoF(const boo
 }  // namespace test
 }  // namespace mxnet
 
-#endif  // TESTS_CPP_INCLUDE_TEST_OP_H_
+#endif  // TEST_OP_H_
diff --git a/tests/cpp/include/test_perf.h b/tests/cpp/include/test_perf.h
index 6343863db16e..93b7863de694 100644
--- a/tests/cpp/include/test_perf.h
+++ b/tests/cpp/include/test_perf.h
@@ -5,8 +5,8 @@
  * \author Chris Olivier
  */
-#ifndef TESTS_CPP_INCLUDE_TEST_PERF_H_
-#define TESTS_CPP_INCLUDE_TEST_PERF_H_
+#ifndef TEST_PERF_H_
+#define TEST_PERF_H_
 
 #include
 #include
@@ -286,4 +286,4 @@ class TimingItem {
 }  // namespace test
 }  // namespace mxnet
 
-#endif  // TESTS_CPP_INCLUDE_TEST_PERF_H_
+#endif  // TEST_PERF_H_
diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h
index b0e4c866f9de..3fa82688c115 100644
--- a/tests/cpp/include/test_util.h
+++ b/tests/cpp/include/test_util.h
@@ -4,8 +4,8 @@
  * \brief unit test performance analysis functions
  * \author Chris Olivier
  */
-#ifndef TESTS_CPP_INCLUDE_TEST_UTIL_H_
-#define TESTS_CPP_INCLUDE_TEST_UTIL_H_
+#ifndef TEST_UTIL_H_
+#define TEST_UTIL_H_
 
 #include
 #include
@@ -413,4 +413,4 @@ struct ScopeSet {
 }  // namespace test
 }  // namespace mxnet
 
-#endif  // TESTS_CPP_INCLUDE_TEST_UTIL_H_
+#endif  // TEST_UTIL_H_
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index e345326632f3..7a958f7de01b 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1633,7 +1633,7 @@ def dot_sym_xT_yT(data_type):
 
 def test_batch_dot():
     dtypes = ['float32', 'float64']
-    
+
     for data_type in dtypes:
         for batch_size in range(1, 5):
             for m in range(1, 5):
@@ -3227,6 +3227,12 @@ def create_operator(self, ctx, shapes, dtypes):
     x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
     check_numeric_gradient(op, [x])
 
+    dx = mx.nd.zeros_like(x)
+    mx.contrib.autograd.mark_variables([x], [dx])
+    with mx.contrib.autograd.train_section():
+        y = mx.nd.Custom(x, op_type='sqr')
+    y.backward()
+
 
 def test_psroipooling():
     for num_rois in [1, 2]:
@@ -3306,10 +3312,10 @@ def test_deformable_psroipooling():
         im_data_var = mx.symbol.Variable(name="im_data")
         rois_data_var = mx.symbol.Variable(name="rois_data")
         offset_data_var = mx.symbol.Variable(name="offset_data")
-        op = mx.contrib.sym.DeformablePSROIPooling(data=im_data_var, rois=rois_data_var,
-                                                   trans=offset_data_var, spatial_scale=spatial_scale,
-                                                   sample_per_part=4, group_size=num_group,
-                                                   pooled_size=num_group, output_dim=num_classes,
+        op = mx.contrib.sym.DeformablePSROIPooling(data=im_data_var, rois=rois_data_var,
+                                                   trans=offset_data_var, spatial_scale=spatial_scale,
+                                                   sample_per_part=4, group_size=num_group,
+                                                   pooled_size=num_group, output_dim=num_classes,
                                                    trans_std=0.1, no_trans=False, name='test_op')
         if grad_nodes[0] == 'offset_data':
             # wider tolerance needed for coordinate differential
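The new autograd block added to test_operator.py (hunk @@ -3227,6 +3227,12 @@ above) drives a custom operator registered as 'sqr'. Its definition lies outside the visible hunks; a plausible minimal version, consistent with the create_operator(self, ctx, shapes, dtypes) signature shown in that hunk's header, might look like the sketch below (an illustration, not the repository's actual test code):

```python
import mxnet as mx
import numpy as np

class Sqr(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        # y = x^2, written via assign() so the write respects `req`
        self.assign(out_data[0], req[0], in_data[0] * in_data[0])

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # d(x^2)/dx = 2x, chained with the incoming gradient
        self.assign(in_grad[0], req[0], 2 * in_data[0] * out_grad[0])

@mx.operator.register('sqr')
class SqrProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(SqrProp, self).__init__(need_top_grad=True)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        # one input, one output of the same shape, no auxiliary states
        return in_shape, [in_shape[0]], []

    def create_operator(self, ctx, shapes, dtypes):
        return Sqr()

# Usage mirroring the added test: record a forward pass, then backprop through it.
x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
dx = mx.nd.zeros_like(x)
mx.contrib.autograd.mark_variables([x], [dx])
with mx.contrib.autograd.train_section():
    y = mx.nd.Custom(x, op_type='sqr')
y.backward()
```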