
Commit

Add deferred compute support
leezu committed Feb 10, 2020
1 parent b65db3c commit a7077de
Showing 12 changed files with 846 additions and 62 deletions.
38 changes: 38 additions & 0 deletions include/mxnet/c_api.h
@@ -1416,6 +1416,40 @@ MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
bool monitor_all);

/*!
* \brief Get current status of deferred compute mode
* \param curr returns the current status.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayIsDeferredComputeEnabled(int *curr);

/*!
* \brief Set whether deferred compute mode is enabled
* \param deferred_compute_enabled 1 to enable, 0 to disable.
* \param prev returns the previous status before this set.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArraySetDeferredComputeEnabled(int deferred_compute_enabled, int *prev);

/*!
* \brief Convert the graph constructed during deferred computation mode to a Symbol.
*
* Construct a Symbol for the deferred computation graph. input_handles must
* provide the ndarray handles of all non-deferred computed ndarrays used as
* arguments to operators inside the deferred compute scope. output_handles
* specifies the outputs of interest which the returned symbol will compute.
*
* \param input_handles ndarray handles of non-deferred computed inputs
* \param output_handles ndarray handles of outputs
* \param input_names names associated with the inputs of the returned Symbol
* \param num_inputs number of input handles and input names
* \param num_outputs number of output handles
* \param out grouped output symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle *input_handles,
NDArrayHandle *output_handles,
const char** input_names,
int num_inputs,
int num_outputs,
SymbolHandle *out);
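
To illustrate how the three new entry points fit together, here is a minimal
sketch of the intended call sequence, using only the functions declared above.
The NDArray handles a and out, the input name "data", and the operator
invocations inside the deferred compute scope are placeholders and are not
part of this change.

int prev = 0;
MXNDArraySetDeferredComputeEnabled(1, &prev);     // enter deferred compute mode
// ... invoke operators on NDArray handles here; outputs are recorded
// but not computed yet ...
MXNDArraySetDeferredComputeEnabled(prev, &prev);  // restore the previous mode

NDArrayHandle inputs[] = {a};                     // non-deferred computed inputs (placeholder)
NDArrayHandle outputs[] = {out};                  // outputs of interest (placeholder)
const char* input_names[] = {"data"};
SymbolHandle sym = nullptr;
MXNDArrayGetDeferredComputeSymbol(inputs, outputs, input_names,
                                  /*num_inputs=*/1, /*num_outputs=*/1, &sym);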

//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
@@ -1494,6 +1528,10 @@ MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **return_type DEFAULT(NULL));
/*!
* \brief Create an AtomicSymbol.
*
* A Symbol is said to be atomic if it is not composed of other Symbols. Atomic
* Symbols can be composed.
*
* \param creator the AtomicSymbolCreator
* \param num_param the number of parameters
* \param keys the keys to the params
94 changes: 92 additions & 2 deletions include/mxnet/imperative.h
@@ -56,7 +56,8 @@ class Imperative {
OpReqType grad_req;
OpStatePtr state;
std::vector<NDArray> outputs;
std::vector<NDArray> out_grads;
std::vector<NDArray> out_grads; // used to hold gradient arrays the user is
// interested in (marked variables)
bool fresh_out_grad;

AGInfo() :
@@ -79,7 +80,7 @@
}

static bool IsNone(const NDArray& arr) {
return arr.entry_.node == nullptr || arr.entry_.node->info.empty();
return arr.autograd_entry_.node == nullptr || arr.autograd_entry_.node->info.empty();
}

static bool IsVariable(const nnvm::ObjectPtr& node) {
@@ -88,6 +89,77 @@
&& info.out_grads.size() == 1;
}
};

/*! \brief DCInfo datastructure to enable deferred computation */
class DCInfo {
public:
explicit DCInfo() {
// Default constructor provided for the sake of any.h. Should not be used.
throw std::invalid_argument("Unsupported default constructor");
};
explicit DCInfo(const std::vector<NDArray *> &inputs,
const std::vector<NDArray *> &outputs);

/*! \brief Compute the outputs of the associated operator. */
static void Compute(const NDArray &arr);

static DCInfo &Get(const nnvm::ObjectPtr &node) {
return dmlc::get<DCInfo>(node->info);
}

static bool IsNone(const NDArray &arr) {
return arr.deferredcompute_entry_.node == nullptr ||
arr.deferredcompute_entry_.node->info.empty();
}

static bool IsComputed(const NDArray &arr) {
return IsNone(arr) ||
dmlc::get<DCInfo>(arr.deferredcompute_entry_.node->info).is_computed_;
}

static DCInfo &Create(const nnvm::ObjectPtr &node,
const std::vector<NDArray *> &inputs,
const std::vector<NDArray *> &outputs);

private:
friend class Imperative;

/*! \brief Copies of input NDArrays
*
* If the respective input NDArray is deallocated on the frontend, we still need
* to keep a copy around to facilitate deferred computation of this array.
* The copies share the chunk.
*
* They are automatically deallocated once computation has finished.
*/
std::vector<NDArray> inputs_;

/*! \brief Handles of input NDArrays used by frontend
*
* Frontend may request conversion to Symbol, specifying a list of NDArray
* handles corresponding to inputs and outputs of the Symbol. We store the
* handles used by frontend to facilitate matching in
* GetDeferredComputeSymbol.
*
* Note that the frontend may have deallocated the NDArray* and the
* input_handles stored here may point to invalid memory.
*/
std::vector<const NDArray *> input_handles_;

/*! \brief Copies of output NDArrays
*
* If the respective output NDArray is deallocated on the frontend, we still
* need to keep a copy around to facilitate deferred computation of arrays
* relying on the output array. The copies share the chunk.
*
* They are automatically deallocated once computation has finished.
*/
std::vector<NDArray> outputs_;

/*! \brief Remember if the outputs associated with this DCInfo have been computed already */
bool is_computed_ = false;
};
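
The static helpers above let callers lazily materialize an array that is
tracked by deferred compute. A hedged sketch of such a helper, using only the
interface declared above (the function itself is illustrative and not part of
the header):

// Materialize arr if it is tracked by deferred compute but not yet computed.
void MaterializeIfNeeded(const mxnet::NDArray& arr) {
  using DCInfo = mxnet::Imperative::DCInfo;
  if (!DCInfo::IsComputed(arr)) {
    DCInfo::Compute(arr);  // runs the recorded operators producing arr
  }
}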

/*! \brief whether training mode is on. */
bool is_training() const {
return is_train_;
@@ -108,6 +180,14 @@
is_recording_ = is_recording;
return old;
}
/*! \brief whether deferred compute mode is on. */
bool is_deferred_compute() const { return is_deferred_compute_; }
/*! \brief turn on or turn off deferred compute mode. */
bool set_is_deferred_compute(bool is_deferred_compute) {
bool old = is_deferred_compute_;
is_deferred_compute_ = is_deferred_compute;
return old;
}
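
Like set_is_recording, set_is_deferred_compute returns the previous value so
callers can restore it. A minimal RAII sketch built on this accessor (the
guard class is illustrative and assumes the existing Imperative::Get()
singleton accessor):

class DeferredComputeScope {
 public:
  DeferredComputeScope()
      : prev_(mxnet::Imperative::Get()->set_is_deferred_compute(true)) {}
  ~DeferredComputeScope() {
    mxnet::Imperative::Get()->set_is_deferred_compute(prev_);
  }
 private:
  bool prev_;  // mode to restore when the scope ends
};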
/*! \brief return current numpy compatibility status,
* GlobalOn(2), ThreadLocalOn(1), Off(0).
* */
@@ -143,6 +223,14 @@
const OpStatePtr& state = OpStatePtr(),
std::vector<bool>* p_save_inputs = nullptr,
std::vector<bool>* p_save_outputs = nullptr);
/*! \brief record an operator in the deferred compute graph. */
void RecordDeferredCompute(nnvm::NodeAttrs&& attrs,
std::vector<NDArray*>& inputs,
std::vector<NDArray*>& outputs);
/*! \brief obtain symbol representation of deferred compute session. */
nnvm::Symbol *GetDeferredComputeSymbol(
const std::vector<std::pair<NDArray *, std::string>> &inputs,
const std::vector<NDArray *> &outputs);
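
A hedged sketch of calling this backend entry point directly; the NDArray
pointers a and out and the name "data" are placeholders:

std::vector<std::pair<mxnet::NDArray*, std::string>> sym_inputs{{a, "data"}};
std::vector<mxnet::NDArray*> sym_outputs{out};
nnvm::Symbol* sym =
    mxnet::Imperative::Get()->GetDeferredComputeSymbol(sym_inputs, sym_outputs);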
/*! \brief */
OpStatePtr Invoke(const Context& default_ctx,
const nnvm::NodeAttrs& attrs,
@@ -204,12 +292,14 @@
#if DMLC_CXX11_THREAD_LOCAL
static thread_local bool is_train_;
static thread_local bool is_recording_;
static thread_local bool is_deferred_compute_;
// TODO(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static thread_local bool is_np_shape_thread_local_;
#else
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
static MX_THREAD_LOCAL bool is_deferred_compute_;
// TODO(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
52 changes: 24 additions & 28 deletions include/mxnet/ndarray.h
@@ -83,7 +83,7 @@ class NDArray {
public:
/*! \brief default constructor */
NDArray()
: entry_(nullptr) {
: autograd_entry_(nullptr) {
}
/*!
* \brief constructs a new dynamic NDArray
@@ -98,7 +98,7 @@
shape_(shape),
dtype_(dtype),
storage_type_(kDefaultStorage),
entry_(nullptr) {
autograd_entry_(nullptr) {
}
/*! \brief constructor for NDArray with storage type
*/
@@ -117,7 +117,7 @@
shape_(),
dtype_(dtype),
storage_type_(kDefaultStorage),
entry_(nullptr) {
autograd_entry_(nullptr) {
}
/*!
* \brief constructing a static NDArray that shares data with TBlob
@@ -131,7 +131,7 @@
shape_(data.shape_),
dtype_(data.type_flag_),
storage_type_(kDefaultStorage),
entry_(nullptr) {
autograd_entry_(nullptr) {
}

/*!
@@ -149,7 +149,7 @@
}),
shape_(data.shape_),
dtype_(data.type_flag_), storage_type_(kDefaultStorage),
entry_(nullptr) {
autograd_entry_(nullptr) {
}

/*! \brief create ndarray from shared memory */
@@ -158,7 +158,7 @@
shape_(shape),
dtype_(dtype),
storage_type_(kDefaultStorage),
entry_(nullptr) {
autograd_entry_(nullptr) {
}

/*!
@@ -177,7 +177,7 @@
shape_(shape),
dtype_(data.type_flag_),
storage_type_(stype),
entry_(nullptr) {
autograd_entry_(nullptr) {
}
/*!
* \brief initialize the NDArray, assuming it is not assigned a meaningful shape before
@@ -326,9 +326,9 @@
inline bool is_none() const {
return ptr_.get() == nullptr;
}
/*! \return updated grad state in entry_ */
/*! \return updated grad state in autograd_entry_ */
bool fresh_out_grad() const;
/*! \return updated grad state in entry_ */
/*! \return updated grad state in autograd_entry_ */
void set_fresh_out_grad(bool state) const;
/*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
* Throws an exception if the indices array shape is inconsistent
@@ -364,27 +364,19 @@
/*!
* \brief Block until all the pending write operations with respect
* to current NDArray are finished, and read can be performed.
*
* If the array has not been computed yet (deferred compute), this will
* trigger computation.
*/
inline void WaitToRead() const {
if (is_none()) return;
Engine::Get()->WaitForVar(ptr_->var);
}
void WaitToRead() const;
/*!
* \brief Block until all the pending read/write operations with respect
* to current NDArray are finished, and write can be performed.
*
* If the array has not been computed yet (deferred compute), this will
* trigger computation.
*/
inline void WaitToWrite() const {
if (is_none()) return;
/*!
* Push an empty mutable function to flush all preceding reads to the
* variable.
*/
Engine::Get()->PushAsync(
[](RunContext, Engine::CallbackOnComplete on_complete) {
on_complete();
}, Context{}, {}, {ptr_->var});
Engine::Get()->WaitForVar(ptr_->var);
}
void WaitToWrite() const;
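
The out-of-line definitions of WaitToRead and WaitToWrite are not shown in
this header. A plausible sketch of the new WaitToRead, assuming it first
materializes a deferred array via Imperative::DCInfo::Compute and then waits
on the engine variable exactly as the old inline version did:

void NDArray::WaitToRead() const {
  if (is_none()) return;
  // Trigger deferred computation first so there is concrete data to wait on.
  if (!Imperative::DCInfo::IsComputed(*this)) {
    Imperative::DCInfo::Compute(*this);
  }
  Engine::Get()->WaitForVar(ptr_->var);
}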
/*! \return the associated variable of the ndarray.*/
inline Engine::VarHandle var() const {
return ptr_->var;
@@ -645,11 +637,13 @@
*/
NDArray ReshapeWithRecord(const mxnet::TShape &shape);
/*!
* \brief Return a copy of this NDArray without autograd history
* \brief Return a copy of this NDArray without autograd and deferred compute
* history
*/
NDArray Detach() const {
NDArray ret(*this);
ret.entry_ = nnvm::NodeEntry(nullptr);
ret.autograd_entry_ = nnvm::NodeEntry(nullptr);
ret.deferredcompute_entry_ = nnvm::NodeEntry(nullptr);
return ret;
}

@@ -1100,7 +1094,9 @@
/*! \brief storage type of data */
NDArrayStorageType storage_type_ = kUndefinedStorage;
/*! \brief node entry for autograd */
nnvm::NodeEntry entry_;
nnvm::NodeEntry autograd_entry_;
/*! \brief node entry for deferred computation tracking */
nnvm::NodeEntry deferredcompute_entry_;
/*!
* \brief internal TBlob
* \note When user access tblob_ by some const methods like
2 changes: 2 additions & 0 deletions python/mxnet/__init__.py
@@ -88,6 +88,8 @@
from . import rnn
from . import gluon

from . import _deferred_compute

# With the native kvstore module (such as 'dist_sync_device'), the module launches a separate
# process when role is set to "server". This should be done after other modules are initialized.
# Otherwise this may result in errors when unpickling custom LR scheduler/optimizers.
