This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-684] Add cond operator #11760

Merged
merged 8 commits into from
Jul 24, 2018

junrushao
Member

@junrushao junrushao commented Jul 14, 2018

Waiting for the while_loop operator to be merged so that I could rebase to master. Please do not merge for now.

Description

This PR is part of the proposal of adding a set of control flow operators to MXNet. Link to proposal. See also foreach (#11531) and while_loop (#11566).

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
      • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
      • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
      • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
      • For user-facing API changes, API doc string has been updated.
      • For new C++ functions in header files, their functionalities and arguments are documented.
      • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
      • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Add ifelse operator in src/operator/control_flow.cc

TODO

  • Wait for while_loop to be merged
  • Rebase to master
  • Fix flaky tests in while_loop because of numeric overflow
  • Remove duplicated benchmark codes

@junrushao junrushao force-pushed the if_pr branch 2 times, most recently from 9fa6ec2 to 322cb05 Compare July 17, 2018 21:16
@junrushao junrushao changed the title [WIP] Add ifelse operator [MXNET-684] Add ifelse operator Jul 17, 2018
@junrushao junrushao force-pushed the if_pr branch 5 times, most recently from dbcaae6 to 3bcaf8e Compare July 19, 2018 15:55
@junrushao
Member Author

@zheng-da @piiswrong @szha @eric-haibin-lin Hey could you help review this PR?

@@ -363,3 +362,97 @@ def _func_wrapper(loop_vars):
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)]
))
return stacked_outputs, list(loop_vars)

def ifelse(cond, then_func, else_func, inputs):
"""Run a if-then-else using user-defined condition and computation
Contributor

a => an

Member Author

Fixed

This operator simulates a if-like branch which chooses to do one of
the two customized computations according to the specified condition.

`inputs` is a list of NDArrays on which the condition and computations reply on.
Contributor

reply => rely

Member Author

Fixed

@@ -556,3 +556,154 @@ def _union_inputs(*graphs):
outputs = [result[i] for i in range(num_out_data)]
final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
return outputs, final_loop_vars

def ifelse(cond, then_func, else_func, inputs, name="ifelse"):
Contributor

please fix the same typos as the one above.

Member Author

Fixed

raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, ))
return inputs

def _create_subgraph(graph_vars, graph_func, subgraph_name):
Contributor

it seems this function and the function below are the same as the one in while_loop. Can you move them out and reuse them?

Member Author

They are not exactly the same. One would search for var_locs, another doesn't.

outputs = _to_ndarray_tuple(outputs, "outputs of then_func")
else:
outputs = else_func(*inputs)
outputs = _to_ndarray_tuple(outputs, "outputs of else_func")
Contributor

Is there a way to check that the if branch and the else branch produce the same number of outputs, with the same types, etc.?

Member Author

Let's give up ><

auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
ShapeVector *_subg_out,
const nnvm::Tuple<dim_t> &input_locs,
bool fill_out_shape) {
Contributor

can you also reuse this function?

Member Author

They are not identical either.

params.then_input_locs, true);
bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \
params.else_input_locs, true);
return succ_0 && succ_1 && succ_2;
Contributor

you need to check then_out_shape and else_out_shape and see if they are the same.

Member Author

My bad, fixed
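
For readers following this thread, the requested consistency check can be sketched in plain Python. This is only an illustration under assumptions, not the C++ code that was added to control_flow.cc: `merge_branch_shapes` is a hypothetical helper, and shapes are modelled as plain tuples rather than MXNet shape objects.

```python
def merge_branch_shapes(then_out_shape, else_out_shape):
    """Return the common output shapes, or raise if the two branches disagree.

    Hypothetical helper: both arguments are lists of shape tuples, one entry
    per output of the corresponding branch.
    """
    if len(then_out_shape) != len(else_out_shape):
        raise ValueError(
            "then/else branches produce %d vs %d outputs"
            % (len(then_out_shape), len(else_out_shape))
        )
    merged = []
    for i, (t, e) in enumerate(zip(then_out_shape, else_out_shape)):
        if tuple(t) != tuple(e):
            raise ValueError(
                "output %d: then shape %s != else shape %s" % (i, tuple(t), tuple(e))
            )
        merged.append(tuple(t))
    return merged


# Matching branches yield the shared shapes.
shapes = merge_branch_shapes([(2, 3), (4,)], [(2, 3), (4,)])
```

A mismatch such as `merge_branch_shapes([(2, 3)], [(3, 2)])` raises `ValueError` instead of letting shape inference succeed.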

bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &then_in_type, out_type);
CHECK(sync_in_in(params.then_input_locs, in_type, &then_in_type, is_udf));
bool succ_2 = InferSubgraphDataType(*attrs.subgraphs[2], &else_in_type, out_type);
CHECK(sync_in_in(params.else_input_locs, in_type, &else_in_type, is_udf));
Contributor

the same here.

Member Author

Two subgraphs write to the same out_type, so we don't have to worry in this case.
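
One way to read this reply: when both subgraphs assign into the same output vector, a conflicting second assignment is caught at write time, so no separate comparison is needed. The sketch below illustrates that idea in plain Python. `UNKNOWN`, `assign_type` and the string dtypes are assumptions made for illustration and are not taken from the MXNet sources.

```python
UNKNOWN = -1  # assumed placeholder meaning "type not inferred yet"


def assign_type(slots, index, dtype):
    """Write dtype into slots[index]; return False if it conflicts with a known type."""
    if slots[index] == UNKNOWN:
        slots[index] = dtype
        return True
    return slots[index] == dtype


# Both branches write into the same shared list.
out_type = [UNKNOWN]
then_ok = assign_type(out_type, 0, "float32")      # first write fills the slot
else_ok = assign_type(out_type, 0, "float32")      # same type: consistent
conflict_ok = assign_type(out_type, 0, "float64")  # different type: rejected
```

In this sketch a disagreement between the branches shows up as a failed assignment rather than as two vectors that have to be compared afterwards.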

CHECK(sync_in_in(params.then_input_locs, in_attrs, &then_in_attrs, is_udf));
bool succ_2 = InferSubgraphStorage(*attrs.subgraphs[2], dev_mask, \
&else_mode, &else_in_attrs, out_attrs);
CHECK(sync_in_in(params.else_input_locs, in_attrs, &else_in_attrs, is_udf));
Contributor

the same here.

Member Author

Two subgraphs write to the same out_attrs, so we don't have to worry in this case.

@zheng-da
Contributor

Do you need to rebase onto master, since the flaky test has been fixed?

@junrushao
Member Author

junrushao commented Jul 20, 2018

We are going to rename the operator from ifelse to condition according to discussion with @zheng-da offline

@junrushao junrushao changed the title [MXNET-684] Add ifelse operator [MXNET-684] Add cond operator Jul 20, 2018
@@ -363,3 +362,97 @@ def _func_wrapper(loop_vars):
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)]
))
return stacked_outputs, list(loop_vars)

def ifelse(cond, then_func, else_func, inputs):
Member

Why do we have both cond and inputs? Can we just have cond which could be true/false?

Contributor

Member Author

It may help improve user experience I think.

@@ -363,3 +362,97 @@ def _func_wrapper(loop_vars):
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)]
))
return stacked_outputs, list(loop_vars)

def cond(cond_func, then_func, else_func, inputs): # pylint: disable=redefined-outer-name
Member

can we remove inputs and have tf style interface?

Member Author

Good idea. Just finished the API change. Thanks!

@junrushao junrushao changed the title [MXNET-684] Add cond operator [MXNET-684] Add condition operator Jul 21, 2018
@junrushao
Member Author

junrushao commented Jul 21, 2018

@zheng-da and I have done an API change, according to the valuable comments from @eric-haibin-lin.

Here is the signature of our new API: condition(cond: Symbol, then_func: Function, else_func: Function) -> List[Symbol], where then_func and else_func are functions with no arguments.

Although there is no actual difference between the old and new APIs in the backend, we believe that this change will make our API easier to use for customers.
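
To make the difference between the two interfaces concrete, here is a minimal pure-Python sketch. It does not use MXNet: `cond_with_inputs` and `cond_closure` are hypothetical stand-ins for the old and new signatures, and plain floats stand in for Symbols or NDArrays.

```python
def cond_with_inputs(cond_func, then_func, else_func, inputs):
    """Old style (stand-in): every function receives `inputs` explicitly."""
    if cond_func(*inputs):
        return then_func(*inputs)
    return else_func(*inputs)


def cond_closure(pred, then_func, else_func):
    """New style (stand-in): `pred` is already computed, branches take no arguments."""
    return then_func() if pred else else_func()


a, b = 1.0, 2.0

# Old style: the inputs are passed through a list.
old = cond_with_inputs(
    lambda x, y: x * y < 5,
    lambda x, y: (x + 5) * (y + 5),
    lambda x, y: (x - 5) * (y - 5),
    [a, b],
)

# New style: the branches capture `a` and `b` from the enclosing scope.
new = cond_closure(
    a * b < 5,
    lambda: (a + 5) * (b + 5),
    lambda: (a - 5) * (b - 5),
)
```

Both calls select the then-branch here, since 1.0 * 2.0 < 5, and return the same value. The closure form drops the `inputs` list, which is the usability gain described above.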

Thanks again to @eric-haibin-lin and @zheng-da for the valuable discussion!

@junrushao
Member Author

So could someone help merge the code?

@zheng-da
Contributor

@eric-haibin-lin @szha @piiswrong Do you have more comments? If not, can you merge it?

@junrushao
Member Author

We are trying to get this in 1.3. So could someone help merge this PR?

@@ -28,7 +28,7 @@
except ImportError:
pass

-__all__ = ["rand_zipfian", "foreach", "while_loop"]
+__all__ = ["rand_zipfian", "foreach", "while_loop", "condition"]
Contributor

condition is not a good name. Maybe cond, or conditional or if_else or something like that

@@ -363,3 +362,87 @@ def _func_wrapper(loop_vars):
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)]
))
return stacked_outputs, list(loop_vars)

def condition(cond, then_func, else_func):
Contributor

change signature to be the same with other packages?

@junrushao
Member Author

@piiswrong Our initial name was ifelse. @zheng-da proposed renaming it to ‘cond’, but ‘cond’ does not make pylint happy, because it conflicts with the first argument of ‘while_loop’. So I changed it to ‘condition’. Could you let me know the best name so that we can get this merged today? @zheng-da @piiswrong
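
For context, the naming clash can be reproduced in a few lines. The sketch below is a simplified stand-in, not the MXNet code: it mirrors only the parameter name `cond` of `while_loop`, and the loop body and return value are invented for illustration. Pylint's redefined-outer-name check (W0621) flags a parameter that reuses a module-level name, which is what the inline disable comment silences.

```python
def cond(pred, then_func, else_func):
    """Module-level function named `cond` (stand-in for the new operator)."""
    return then_func() if pred else else_func()


def while_loop(cond, func, loop_vars, max_iterations):  # pylint: disable=redefined-outer-name
    """Inside this body, `cond` is the parameter, not the module-level function."""
    steps = 0
    while steps < max_iterations and cond(*loop_vars):
        loop_vars = func(*loop_vars)
        steps += 1
    return loop_vars


# Count i up to 3; the lambda passed as `cond` shadows the module-level `cond`.
result = while_loop(lambda i: i < 3, lambda i: (i + 1,), (0,), max_iterations=10)
branch = cond(True, lambda: "then", lambda: "else")
```

The code runs correctly either way. The warning is about readability, which is why the options are renaming the function or disabling the check on the affected line.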

@junrushao
Member Author

I have changed condition to cond for now.

@junrushao junrushao changed the title [MXNET-684] Add condition operator [MXNET-684] Add cond operator Jul 23, 2018
@piiswrong piiswrong merged commit 4bb141d into apache:master Jul 24, 2018
@junrushao
Member Author

junrushao commented Jul 24, 2018

Thank you so much guys for offering me valuable suggestions, and making this PR possible!

XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* Initial commit for `Ifelse`

* Address comments

* Rename ifelse to condition

* API change

* Trigger CI

* Rename condition to cond

* Fix lint