Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dmlc-core
4 changes: 2 additions & 2 deletions python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class DataBatch(object):
----------
data : list of NDArray
A list of input data.
label : list of NDArray
label : list of NDArray, optional
A list of input labels.
pad : int, optional
The number of examples padded at the batch end. It is used when the
Expand All @@ -100,7 +100,7 @@ class DataBatch(object):
provide_label : list of (name, shape), optional
The *i*-th elements describes the name and shape of ``label[i]``.
"""
def __init__(self, data, label, pad=None, index=None,
def __init__(self, data, label=None, pad=None, index=None,
bucket_key=None, provide_data=None, provide_label=None):
if data is not None:
assert isinstance(data, (list, tuple)), "Data must be list of NDArrays"
Expand Down
7 changes: 3 additions & 4 deletions python/mxnet/module/executor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ def _collect_arrays(self):

data_names = [x[0] for x in self.data_shapes]
if self.inputs_need_grad:
self.input_grad_arrays = [[exec_.grad_arrays[i] for exec_ in self.execs]
for i, name in enumerate(self.arg_names)
if name in data_names]
self.input_grad_arrays = [[exec_.grad_arrays[self.arg_names.index(name)]
for exec_ in self.execs]
for name in data_names if name in self.arg_names]
else:
self.input_grad_arrays = None

Expand Down Expand Up @@ -505,7 +505,6 @@ def backward(self, out_grads=None):
out_grads_slice.append(og_my_slice.as_in_context(self.contexts[i]))
else:
out_grads_slice.append(grad.copyto(self.contexts[i]))

exec_.backward(out_grads=out_grads_slice)

def update_metric(self, eval_metric, labels):
Expand Down
47 changes: 37 additions & 10 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,50 @@
import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
from functools import reduce

def test_module_dtype():
dtype = np.float16
dshape = (3, 8, 7)
dtype = np.float16
dshape = (3, 8, 7)

sym = mx.sym.Variable('data')
sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')
sym = mx.sym.Variable('data')
sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')

mod = mx.mod.Module(sym, ('data',), None, context=[mx.cpu(0), mx.cpu(1)])
mod.bind(data_shapes=[mx.io.DataDesc('data', dshape, dtype, layout='TNC')])
mod.init_params()
mod.forward(mx.io.DataBatch(data=[mx.nd.ones(dshape, dtype=dtype)],
mod = mx.mod.Module(sym, ('data',), None, context=[mx.cpu(0), mx.cpu(1)])
mod.bind(data_shapes=[mx.io.DataDesc('data', dshape, dtype, layout='TNC')])
mod.init_params()
mod.forward(mx.io.DataBatch(data=[mx.nd.ones(dshape, dtype=dtype)],
label=None))
mod.backward([mx.nd.ones(dshape, dtype=dtype)])
mod.backward([mx.nd.ones(dshape, dtype=dtype)])

for x in mod.get_outputs():
for x in mod.get_outputs():
assert x.dtype == dtype


def test_module_input_grads():
    """Gradients from get_input_grads() must follow data_names order.

    The module is built with data_names=['b', 'c', 'a'] (deliberately not
    the symbol's argument order), so the returned input gradients must be
    [d/db, d/dc, d/da] = [2, 3, 1] for out = a + 2*b + 3*c.
    """
    layout = 'NC'
    sym_a = mx.sym.Variable('a', __layout__=layout)
    sym_b = mx.sym.Variable('b', __layout__=layout)
    sym_c = mx.sym.Variable('c', __layout__=layout)

    # d(out)/da = 1, d(out)/db = 2, d(out)/dc = 3
    out = sym_a + 2 * sym_b + 3 * sym_c
    net = mx.mod.Module(out, data_names=['b', 'c', 'a'], label_names=None,
                        context=[mx.cpu(0), mx.cpu(1)])
    net.bind(data_shapes=[['b', (5, 5)], ['c', (5, 5)], ['a', (5, 5)]],
             label_shapes=None, inputs_need_grad=True)
    net.init_params()

    net.forward(data_batch=mx.io.DataBatch(data=[nd.ones((5, 5)),
                                                 nd.ones((5, 5)),
                                                 nd.ones((5, 5))]))
    net.backward(out_grads=[nd.ones((5, 5))])

    # Returned in data_names order: b, c, a.
    grads = [g.asnumpy() for g in net.get_input_grads()]
    for grad, expected in zip(grads, (2, 3, 1)):
        assert np.all(grad == expected), grad

def test_module_layout():
sym = mx.sym.Variable('data')
sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')
Expand Down Expand Up @@ -230,6 +256,7 @@ def mean_abs(x):

if __name__ == '__main__':
test_module_dtype()
test_module_input_grads()
test_module_states()
test_module_reshape()
test_save_load()
Expand Down