Skip to content

Commit

Permalink
Refine fix to handle the case output is a TupleWrapper
Browse files Browse the repository at this point in the history
Add a regression test guarding on original bug.
  • Loading branch information
Li Xiaoquan committed Mar 12, 2019
1 parent ae2d046 commit fd3b54a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
16 changes: 12 additions & 4 deletions nnvm/python/nnvm/to_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def to_relay(graph, shape_dict, dtype_dict, params):
graph = graph.apply(["InferShape", "InferType"])
shape = graph.json_attr("shape")
dtype = [graph_attr.TCODE_TO_DTYPE[di] for di in graph.json_attr("dtype")]
heads = [x[0] for x in json.loads(graph.json())['heads']]
heads = [x for x in json.loads(graph.json())['heads']]

gidx = graph.index
relay_map = {}
Expand All @@ -468,8 +468,9 @@ def to_relay(graph, shape_dict, dtype_dict, params):
fn_params.append(v)
relay_map[nid] = v
else:
if nid in heads:
output_ids.append(nid)
for head in heads:
if head[0] == nid:
output_ids.append(head)

if op_name in NNVM_OP_2_RELAY_OP:
str_attrs = StrAttrsDict(attrs)
Expand All @@ -479,7 +480,14 @@ def to_relay(graph, shape_dict, dtype_dict, params):
raise Exception(
"nnvm.to_relay: unsupported operator: {0}".format(op_name))

outputs = [relay_map[nid] for nid in output_ids]
outputs = []
for i in output_ids:
output = relay_map[i[0]]
if isinstance(output, expr.TupleWrapper):
outputs.append(output[i[1]])
else:
outputs.append(output)

if len(outputs) == 1:
body = outputs[0]
else:
Expand Down
17 changes: 17 additions & 0 deletions tests/python/frontend/nnvm_to_relay/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ def test_forward_dqn():
verify_nnvm_to_relay(model, params, data_shape=(1, 4, 84, 84))


def test_forward_split_concatenate():
shape = (2, 16)

tensor = nnvm.sym.Variable("data", shape=shape)

splited = nnvm.sym.split(tensor, indices_or_sections=2, axis=1)

concatenated = nnvm.sym.concatenate(*splited, axis=1)

params = {}

verify_nnvm_to_relay(splited[0], params, data_shape=shape)
verify_nnvm_to_relay(splited[1], params, data_shape=shape)
verify_nnvm_to_relay(splited, params, data_shape=shape)
verify_nnvm_to_relay(concatenated, params, data_shape=shape)


if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
Expand Down

0 comments on commit fd3b54a

Please sign in to comment.