Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix shape inference bug #7682

Merged
merged 1 commit into from
Aug 31, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/executor/infer_graph_attr_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
uint32_t eid = idx.entry_id(nid, igrad[i].index);
if (fis_none(rshape[eid])) {
rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
} else {
} else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
// Need to skip empty forward shape, because it may not be
// available now and it is possible to infer the forward
// shape in one of the next a few passes
CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
<< "Backward shape inconsistent with the forward shape";
}
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,31 @@ def test_zero_prop2():
assert False


def test_simple_bind_special_case():
"""This is a special case that results in shape inference
failure after moving simple_bind logic from frontend to backend.
Added here for testing against the network similar to the following one.

Network diagram:
weight --> abs_op --> sum_op --
|--> add_op
data --> fc_op --> sum_op --

Given data's shape, if the shape inference starts from weight node,
then the node entries of negative_op and sum_op are unknown in the
forward pass. Therefore, there are several unknown shapes after the
first forward pass is done. Now the backward inference pass starts with
the assumption that there are no unknown-shape node entries in the forward
pass, and consequently, leads to CHECK_EQ failure.
"""
data_shape = (5, 13)
data = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(data=data, num_hidden=1, no_bias=True, name='fc')
modified_weight = mx.sym.abs(fc.get_internals()['fc_weight'])
net = mx.sym.sum(modified_weight) + mx.sym.sum(fc)
net.simple_bind(ctx=mx.cpu(), data=data_shape)


if __name__ == '__main__':
import nose
nose.runmodule()