5 changes: 3 additions & 2 deletions python/mxnet/executor.py
@@ -34,14 +34,15 @@ class Executor:
>>> c = 2 * a + b
>>> texec = c._bind(mx.cpu(), {'a': mx.nd.array([1,2]), 'b':mx.nd.array([2,3])})
"""
- def __init__(self, sym, ctx, args, args_grad, grad_req, aux_states):
+ def __init__(self, sym, ctx, args, args_grad, grad_req, aux_states, static_alloc=False):
self.outputs = None
self._input_names = sym.list_inputs()
self._aux_names = sym.list_auxiliary_states()
self._arg_names = sym.list_arguments()
self._output_names = sym.list_outputs()
self._ctx = ctx
self._grad_req = grad_req
self.static_alloc = static_alloc
# grad_req
self._requires_grad = False
if isinstance(grad_req, dict):
@@ -121,7 +122,7 @@ def __init__(self, sym, ctx, args, args_grad, grad_req, aux_states):
with self._ctx:
self._args[i].attach_grad(req, stype=g.stype)
self._args[i].grad[:] = g
- self._cached_op = ndarray.CachedOp(sym)
+ self._cached_op = ndarray.CachedOp(sym, flags=[("static_alloc", self.static_alloc)])

def get_optimized_symbol(self):
"""Get an optimized version of the symbol from the executor.
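For illustration, a minimal sketch of what the updated `Executor.__init__` now does with the flag: it simply forwards `static_alloc` to the `CachedOp` it builds. This is not part of the PR; the toy symbol mirrors the docstring example above.

```python
import mxnet as mx
from mxnet import ndarray

# Toy symbol, mirroring the Executor docstring example.
a = mx.sym.var('a')
b = mx.sym.var('b')
c = 2 * a + b

static_alloc = True  # the value _bind would pass through to Executor

# Equivalent to the new line in Executor.__init__: the flag becomes a CachedOp flag.
cached_op = ndarray.CachedOp(c, flags=[("static_alloc", static_alloc)])
```

Executor.forward then runs this cached op, so the static-allocation behaviour applies to every subsequent execution of the bound graph.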
7 changes: 5 additions & 2 deletions python/mxnet/symbol/symbol.py
@@ -1793,7 +1793,7 @@ def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
return Executor(self, ctx, args, args_grad, grad_req, aux_states)

def _bind(self, ctx, args, args_grad=None, grad_req='write',
- aux_states=None):
+ aux_states=None, static_alloc=False):
"""Binds the current symbol to an executor and returns it.

We first declare the computation and then bind to the data to run.
@@ -1856,6 +1856,9 @@ def _bind(self, ctx, args, args_grad=None, grad_req='write',
`auxiliary_states` to the corresponding `NDArray`,
- In either case, all the auxiliary states need to be provided.

static_alloc : bool, default False
Statically allocate memory to improve speed. Memory usage may increase.

Returns
-------
executor : Executor
@@ -1874,7 +1877,7 @@ def _bind(self, ctx, args, args_grad=None, grad_req='write',
gradient they interested in.
"""
assert isinstance(grad_req, (str, dict))
- return Executor(self, ctx, args, args_grad, grad_req, aux_states)
+ return Executor(self, ctx, args, args_grad, grad_req, aux_states, static_alloc)

def gradient(self, wrt):
"""Gets the autodiff of current symbol.
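A hedged usage sketch of the new keyword, adapted from the `Executor` docstring example above (not part of the PR):

```python
import mxnet as mx

a = mx.sym.var('a')
b = mx.sym.var('b')
c = 2 * a + b

# static_alloc=True trades extra memory for faster repeated execution,
# as described by the new docstring entry.
texec = c._bind(mx.cpu(),
                {'a': mx.nd.array([1, 2]), 'b': mx.nd.array([2, 3])},
                static_alloc=True)
texec.forward(is_train=True)
print(texec.outputs)   # expected: [4. 7.], i.e. 2 * [1, 2] + [2, 3]
```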
21 changes: 17 additions & 4 deletions src/imperative/cached_op.cc
@@ -308,6 +308,21 @@ bool CachedOp::SetBackwardGraph(GraphInfo* info,
g.attrs[AddPrefix(BACKWARD, REF_COUNT)] = std::make_shared<dmlc::any>(std::move(ref_count));
}

// Set AddTo Entry based on the req that users provide
if (detect_inplace_addto) {
std::vector<int> addto_entry(idx.num_node_entries(), 0);
for (size_t i = 0; i < info->grad_graph.outputs.size(); ++i) {
if (reqs[i] == kAddTo) {
auto entry = info->grad_graph.outputs[i];
if (!idx.exist(entry.node.get()))
continue;
auto eid = idx.entry_id(entry);
addto_entry[eid] = 1;
}
}
g.attrs["addto_entry"] = std::make_shared<nnvm::any>(std::move(addto_entry));
}

auto shapes = info->fwd_graph.GetAttr<mxnet::ShapeVector>("shape");
shapes.resize(idx.num_node_entries(), mxnet::TShape());
auto dtypes = info->fwd_graph.GetAttr<DTypeVector>("dtype");
@@ -1047,8 +1062,7 @@ void CachedOp::StaticBackward(const bool retain_graph,
auto entry = state.info.grad_graph.outputs[iter->second];
if (!idx.exist(entry.node.get()))
continue;
- auto eid = idx.entry_id(entry);
- state.array_reqs[eid] = reqs[iter->second];
+ auto eid = idx.entry_id(entry);
// An input and an output may share the same array.
INIT_DETACHED(outputs[iter->second], arrays[eid]);
arrays[eid] = outputs[iter->second];
@@ -1058,8 +1072,7 @@ void CachedOp::StaticBackward(const bool retain_graph,
auto entry = state.info.grad_graph.outputs[i];
if (!idx.exist(entry.node.get()))
continue;
- auto eid = idx.entry_id(entry);
- state.array_reqs[eid] = reqs[i];
+ auto eid = idx.entry_id(entry);
// An input and an output may share the same array.
INIT_DETACHED(outputs[i], arrays[eid]);
arrays[eid] = outputs[i];
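For context, a hedged Python-level sketch (not part of the PR) of where a kAddTo request comes from: binding with a user-level `grad_req` of `'add'` asks backward to accumulate into the supplied gradient buffer, and the block added above records that request in the graph's "addto_entry" attribute so the inplace-addto pass can honour it.

```python
import mxnet as mx

x = mx.sym.var('x')
y = mx.sym.var('y')
z = mx.sym.elemwise_add(x, y)

grad_y = mx.nd.ones((1,))   # pre-existing gradient to accumulate into
ex = z._bind(mx.cpu(),
             {'x': mx.nd.array([0.4]), 'y': mx.nd.array([0.5])},
             args_grad={'x': mx.nd.zeros((1,)), 'y': grad_y},
             grad_req={'x': 'null', 'y': 'add'})

ex.forward(is_train=True)
ex.backward(out_grads=mx.nd.array([1]))
# With req 'add', the incoming gradient of 1 is accumulated on top of the seeded
# buffer (Executor copies args_grad into the grad array before backward, as shown
# in the executor.py hunk above), so y's gradient is expected to end up at [2].
print(ex.grad_arrays)
```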
7 changes: 6 additions & 1 deletion src/imperative/inplace_addto_detect_pass.cc
@@ -39,7 +39,12 @@ Graph DetectInplaceAddTo(Graph g) {
auto& idx = g.indexed_graph();
// reference cont.
std::vector<int> ref_count(idx.num_node_entries(), 0);
- std::vector<int> addto_entry(idx.num_node_entries(), 0);
+ std::vector<int> addto_entry;
+ if (g.attrs.count("addto_entry")) {
+   addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
+ } else {
+   addto_entry = std::vector<int>(idx.num_node_entries(), 0);
+ }
std::vector<int> skip_plus_node(idx.num_nodes(), 0);

for (auto& e : idx.outputs()) {
18 changes: 18 additions & 0 deletions tests/python/unittest/test_executor.py
@@ -159,3 +159,21 @@ def check_init(static_alloc, static_shape):
check_init(False, False)
check_init(True, False)
check_init(True, True)

def test_elemwise_add_grad():
json = "{\"nodes\": [{\"op\":\"null\",\"name\":\".Inputs.Input1\",\"inputs\":[]},{\"op\":\"null\",\"name\":\".Inputs.Input2\",\"inputs\":[]},{\"op\":\"elemwise_add\",\"name\":\".$0\",\"inputs\":[[0,0,0],[1,0,0]]},{\"op\":\"_copy\",\"name\":\".Outputs.Output\",\"inputs\":[[2,0,0]]}],\"arg_nodes\":[0,1],\"heads\":[[3,0,0]]}"
sym = mx.symbol.fromjson(json)

ex = sym._bind(
mx.cpu(),
{'.Inputs.Input1': mx.nd.array([0.4]), '.Inputs.Input2': mx.nd.array([0.5])},
args_grad={
'.Inputs.Input1': mx.ndarray.zeros((1)),
'.Inputs.Input2': mx.ndarray.zeros((1))
},
grad_req={'.Inputs.Input1': 'null', '.Inputs.Input2': 'write'}
)
ex.forward(is_train=True)
print(ex.outputs)
ex.backward(out_grads=mx.nd.array([1]))
print(ex.grad_arrays)
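As a follow-up, an equivalent self-contained check (a sketch, not part of the PR) that builds the graph from plain symbols instead of JSON, omits the `_copy` head, and states the expected numbers:

```python
import mxnet as mx

x = mx.sym.var('x')
y = mx.sym.var('y')
z = mx.sym.elemwise_add(x, y)

ex = z._bind(
    mx.cpu(),
    {'x': mx.nd.array([0.4]), 'y': mx.nd.array([0.5])},
    args_grad={'x': mx.nd.zeros((1,)), 'y': mx.nd.zeros((1,))},
    grad_req={'x': 'null', 'y': 'write'})

ex.forward(is_train=True)
ex.backward(out_grads=mx.nd.array([1]))

# Forward: 0.4 + 0.5 = 0.9. Backward: d(x + y)/dy = 1 is written to y's grad,
# while x's slot is skipped because its grad_req is 'null' -- the configuration
# the new test above exercises.
assert abs(ex.outputs[0].asscalar() - 0.9) < 1e-6
```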