
Commit

Set monitor callback basic support
anirudh2290 committed Aug 20, 2019
1 parent b77f524 commit 75a1321
Showing 11 changed files with 284 additions and 8 deletions.
12 changes: 12 additions & 0 deletions include/mxnet/c_api.h
@@ -111,6 +111,11 @@ typedef void (*EngineFuncParamDeleter)(void*);
typedef void (*ExecutorMonitorCallback)(const char*,
NDArrayHandle,
void*);
/*! \brief Monitor callback called at operator level for cached op */
typedef void (*CachedOpMonitorCallback)(const char*,
const char*,
NDArrayHandle);


struct NativeOpInfo {
void (*forward)(int, float**, int*, unsigned**, int*, void*);
@@ -1222,6 +1227,13 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
NDArrayHandle **outputs,
const int** out_stypes);

/*!
* \brief Set a monitor callback on a cached op
*/
MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
bool monitor_all);
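
For illustration, a frontend would drive this new entry point roughly as in the ctypes sketch below. This is an assumption-laden sketch, not part of the change: libmxnet is assumed to be loaded as _LIB, and cached_op is a hypothetical existing frontend CachedOp wrapper. The Python binding added later in this same commit does essentially the same thing.

    import ctypes
    from mxnet.base import _LIB, check_call, NDArrayHandle

    # Callback type mirroring the CachedOpMonitorCallback typedef above
    CB_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p, NDArrayHandle)

    def print_hook(node_name, op_name, array):
        # names arrive as bytes from the C side; array is a raw NDArrayHandle
        print(node_name.decode(), op_name.decode())

    hook = CB_TYPE(print_hook)  # keep a reference so ctypes does not collect it
    # cached_op is assumed to be an existing frontend CachedOp wrapper (hypothetical)
    check_call(_LIB.MXCachedOpRegisterOpHook(cached_op.handle, hook, ctypes.c_int(True)))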

//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
28 changes: 27 additions & 1 deletion python/mxnet/_ctypes/ndarray.py
@@ -29,6 +29,13 @@
from ..base import check_call


def _monitor_callback_wrapper(callback):
"""A wrapper for the user-defined handle."""
def callback_handle(name, opr_name, array, _):
""" ctypes function """
callback(name, opr_name, array)
return callback_handle

class NDArrayBase(object):
"""Base data structure for ndarray"""
__slots__ = ["handle", "writable"]
@@ -112,10 +119,11 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):

class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle", "is_np_sym"]
__slots__ = ["handle", "is_np_sym", "_monitor_callback"]

def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
self._monitor_callback = None

from ..symbol.numpy._symbol import _Symbol
self.is_np_sym = bool(isinstance(sym, _Symbol))
@@ -170,3 +178,21 @@ def __call__(self, *args, **kwargs):
else:
return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
stype=out_stypes[i]) for i in range(num_output.value)]

def _register_op_hook(self, callback, monitor_all=False):
"""Install callback for monitor.
Parameters
----------
callback : function
Takes a string for node_name, a string for op_name, and an NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
if callback:
self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
check_call(_LIB.MXCachedOpRegisterOpHook(
self.handle,
self._monitor_callback,
ctypes.c_int(monitor_all)))
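
A minimal usage sketch for this private hook, assuming the ctypes frontend; note that in this revision the user callback receives the node name and op name as bytes plus the raw NDArrayHandle, not a wrapped NDArray.

    import mxnet as mx
    from mxnet import ndarray

    def monitor(node_name, op_name, array):
        # node_name and op_name arrive as bytes; array is the raw NDArrayHandle
        print(node_name.decode(), op_name.decode())

    sym = mx.sym.Variable('data') * 2 + 1
    op = ndarray.CachedOp(sym)
    op._register_op_hook(monitor, monitor_all=True)
    out = op(mx.nd.ones((2, 3)))  # the hook fires for each executed operator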
37 changes: 37 additions & 0 deletions python/mxnet/gluon/block.py
@@ -590,6 +590,19 @@ def forward(self, *args):
# pylint: disable= invalid-name
raise NotImplementedError

def register_op_hook(self, callback, monitor_all=False):
"""Install callback monitor.
Parameters
----------
callback : function
Takes a string for node_name, a string for op_name, and an NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
for cld in self._children.values():
cld.register_op_hook(callback, monitor_all)
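
A sketch of the callback shape this recursive installer expects. Wrapping the raw handle back into an NDArray is an assumption about handle ownership (the backend passes a fresh copy to the callback), not something this diff does, and net stands for any Gluon block.

    from mxnet.base import NDArrayHandle
    from mxnet.ndarray import NDArray

    def monitor(node_name, op_name, array):
        # node_name/op_name arrive as bytes through ctypes; array is an NDArrayHandle
        data = NDArray(NDArrayHandle(array))  # assumed ownership transfer, for inspection
        print(node_name.decode(), op_name.decode(), data.shape)

    net.register_op_hook(monitor, monitor_all=True)  # net is assumed to be any Gluon block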

def summary(self, *inputs):
"""Print the summary of the model's output and parameters.
@@ -754,6 +767,8 @@ def __init__(self, prefix=None, params=None):
self._in_format = None
self._active = False
self._flags = []
self._callback = None
self._monitor_all = False

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -833,6 +848,12 @@ def _deferred_infer_shape(self, *args):
def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)
assert self._cached_op, "_build_cache() failed to create a CachedOp"
if self._callback:
self._cached_op._register_op_hook(self._callback, self._monitor_all)
if len(self._flags) >= 2 and dict(self._flags).get('static_shape', False):
warnings.warn("Callback is not supported when static_shape=True "
"and may not work correctly")

args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
@@ -938,6 +959,22 @@ def export(self, path, epoch=0, remove_amp_cast=True):
save_fn = _mx_npx.save if is_np_array() else ndarray.save
save_fn('%s-%04d.params'%(path, epoch), arg_dict)

def register_op_hook(self, callback, monitor_all=False):
"""Install op hook for block recursively.
Parameters
----------
callback : function
Takes a string for node_name, a string for op_name, and an NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
self._callback = callback
self._monitor_all = monitor_all
for cld in self._children.values():
cld._callback = callback
cld._monitor_all = monitor_all
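
An end-to-end sketch of the intended Gluon flow, with illustrative layers and shapes: register the hook, hybridize, and let _call_cached_op (above) attach it to the CachedOp on the first forward call.

    import mxnet as mx
    from mxnet.gluon import nn

    def monitor(node_name, op_name, array):
        print('node:', node_name.decode(), 'op:', op_name.decode())

    net = nn.HybridSequential()
    net.add(nn.Dense(4), nn.Activation('relu'))
    net.initialize()
    net.register_op_hook(monitor, monitor_all=True)
    net.hybridize()            # the hook is attached to the CachedOp on the next call
    net(mx.nd.ones((1, 8)))    # monitor fires for each operator in the cached graph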

def forward(self, x, *args):
"""Defines the forward computation. Arguments can be either
:py:class:`NDArray` or :py:class:`Symbol`."""
20 changes: 20 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -378,3 +378,23 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) {
*out = reinterpret_cast<SymbolHandle>(sym);
API_END();
}

int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
bool monitor_all) {
API_BEGIN();
CachedOpMonitorCallback callback_temp = nullptr;
std::function<void(const char *, const char *, void*)> clbk;
if (callback) {
callback_temp = callback;
clbk = [callback_temp](const char *name, const char *opr_name,
void *handle) {
callback_temp(name, opr_name, handle);
};
} else {
clbk = nullptr;
}
CachedOpPtr op = *static_cast<CachedOpPtr *>(handle);
op->RegisterOpHook(clbk, monitor_all);
API_END();
}
57 changes: 57 additions & 0 deletions src/common/utils.cc
@@ -51,5 +51,62 @@ void CastStorageDispatch<cpu>(const OpContext& ctx,
mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
}

void ExecuteMonInputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback) {
static const auto &flist_inputs =
nnvm::Op::GetAttr<nnvm::FListInputNames>("FListInputNames");
std::vector<std::string> input_names;
const nnvm::IndexedGraph::Node &inode = idx[nid];
const nnvm::Node *node = inode.source;
if (flist_inputs.count(node->op())) {
input_names = flist_inputs[node->op()](node->attrs);
} else {
for (size_t i = 0; i < node->num_inputs(); ++i) {
input_names.emplace_back("input" + std::to_string(i));
}
}

for (size_t i = 0; i < node->num_inputs(); ++i) {
const nnvm::NodeEntry &input = node->inputs[i];
if (state_arrays[idx.entry_id(input)]->is_none()) {
continue;
}
NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(input)]);
std::string name = inode.source->attrs.name + "_" + input_names[i];
monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
reinterpret_cast<void *>(cpy));
}
}

void ExecuteMonOutputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback) {
static const auto &flist_outputs =
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
std::vector<std::string> output_names;
const nnvm::IndexedGraph::Node &inode = idx[nid];
const nnvm::Node *node = inode.source;
if (flist_outputs.count(node->op())) {
output_names = flist_outputs[node->op()](node->attrs);
} else {
for (size_t i = 0; i < node->num_outputs(); ++i) {
output_names.emplace_back(std::to_string(i));
}
}

for (size_t i = 0; i < node->num_outputs(); ++i) {
if (state_arrays[idx.entry_id(nid, i)]->is_none()) {
continue;
}
NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(nid, i)]);
std::string name = inode.source->attrs.name + "_" + output_names[i];
monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
reinterpret_cast<void *>(cpy));
}
}

} // namespace common
} // namespace mxnet
9 changes: 9 additions & 0 deletions src/common/utils.h
@@ -791,6 +791,15 @@ inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) {
ConvertToLegacyShape(&(shapes->at(i)));
}
}
void ExecuteMonInputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback);

void ExecuteMonOutputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback);

} // namespace common
} // namespace mxnet
23 changes: 20 additions & 3 deletions src/imperative/cached_op.cc
@@ -697,6 +697,9 @@ void CachedOp::StaticRunOps(
ndinputs.emplace_back(state_arrays[idx.entry_id(j)]);
CHECK(!ndinputs.back()->is_none());
}
if (monitor_callback_ && monitor_all_) {
mxnet::common::ExecuteMonInputCallback(idx, state_arrays, i, monitor_callback_);
}
ndoutputs.clear();
ndoutputs.reserve(num_outputs);
req.clear();
@@ -708,6 +711,7 @@
CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
}
const DispatchMode dispatch_mode = dispatch_modes[i];

if (createop.count(node.source->op())) {
arg_shapes.clear();
arg_dtypes.clear();
@@ -735,6 +739,9 @@
default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode);
}
if (monitor_callback_) {
mxnet::common::ExecuteMonOutputCallback(idx, state_arrays, i, monitor_callback_);
}
}
}
}
@@ -883,12 +890,12 @@ OpStatePtr CachedOp::DynamicForward(
// So if it's not the inline mode, we disable recording.
RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes,
recording && inlining_);
recording && inlining_, nullptr, monitor_callback_, monitor_all_);
} else {
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states,
dispatch_modes, recording && inlining_, &shapes);
dispatch_modes, recording && inlining_, &shapes, monitor_callback_, monitor_all_);
{
auto state_ptr = GetCachedOpState(default_ctx);
auto& state = state_ptr.get_state<CachedOpState>();
@@ -1028,7 +1035,7 @@ void CachedOp::DynamicBackward(

RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
Imperative::Get()->is_recording());
Imperative::Get()->is_recording(), nullptr, monitor_callback_);

if (retain_graph) {
buff.resize(num_forward_entries);
@@ -1295,6 +1302,16 @@ void CachedOpBackward(const OpStatePtr& state_ptr,
CopyFromTo(out_bufs[i], outputs[i]);
}

/*
* Register the callback to be called when the operator is executed
*/
void CachedOp::RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
bool monitor_all) {
CHECK(callback) << "invalid callback";
monitor_callback_ = callback;
monitor_all_ = monitor_all;
}

OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
Context ctx,
const mxnet::ShapeVector& in_shapes,
8 changes: 8 additions & 0 deletions src/imperative/cached_op.h
@@ -74,6 +74,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
};

class CachedOp {
using CachedOpMonCallback =
std::function<void(const char *, const char *, void *)>;

public:
CachedOp(
const nnvm::Symbol& sym,
@@ -134,6 +137,8 @@
sym.outputs = fwd_graph_.outputs;
return sym;
}
void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
bool monitor_all = false);

private:
struct GraphInfo;
@@ -203,6 +208,9 @@
std::vector<bool> save_inputs_, save_outputs_;
std::vector<OpReqType> bwd_output_reqs_;

std::function<void(const char*, const char*, NDArrayHandle)> monitor_callback_{nullptr};
bool monitor_all_{false};

std::mutex mutex_;
std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;
};
