Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Allow stop of arange to be inferred from dims. #12064

Merged
merged 5 commits into from Aug 24, 2018

Conversation

taliesinb
Copy link
Contributor

Description

This PR adds the ability for an arange operator to leave the stop value unspecified, so that it will be inferred from the output shape (via backward shape inference). This is important to achieve shape polymorphism for efficient bucketing.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • allow start and stop to be the same

Comments

  • The more natural way of doing this would be for the stop parameter be None, but this was already used for a syntactic feature that allows arange(5) to be used to mean arange(0, 5).

@taliesinb
Copy link
Contributor Author

CC @sbodenstein @ThomasDelteil @szha

@taliesinb
Copy link
Contributor Author

@anirudh2290 how long should I expect for the CI check to take?

@eric-haibin-lin eric-haibin-lin added Operator pr-awaiting-review PR is waiting for code review labels Aug 9, 2018
@@ -471,6 +473,11 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
<< "Range does not support step=0, received " << param.step;
CHECK(param.repeat > 0)
<< "Range only supports repeat > 0, received " << param.repeat;
if (param.start == param.stop.value()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

though I understand that end=None is already taken, this condition still feels a bit like a hack...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@szha what's the best way to proceed? it seems that everyone is very busy, so doing a discussion over PR comments might take a very long time. ideal would be some kind of realtime chat where we can discuss design alternatives and quickly come to a consensus. is there such a forum?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if that is not desired, here's a simple proposal: introduce a new parameter called "infer_range" that defaults to false. if it is true, then exactly one of param.stop, param.start, or param.step can be None, and will be inferred from the others and the output dimensions. I may only implement something more limited that makes param.stop be the only inferrable parameter, and leave that to others to implement in the future.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a flag sounds good to me

@taliesinb
Copy link
Contributor Author

@szha implemented!

@taliesinb
Copy link
Contributor Author

@yzhliu @nswamy @anirudh2290 hi folks! it would be fantastic if we could get an idea of when you'll be able to review this PR further... it will help us plan other work that depend on this inside our company.

:as opts}]
(NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
(NDArray/arange (float start) ($/option (float stop)) step repeat infer-range ctx dtype))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gigasquid could you help review this part?

:as opts}]
(Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype))
(Symbol/arange (float start) ($/option (float stop)) step repeat infer-range nil dtype))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
dtype=dtype, ctx=str(ctx))
return _internal._arange(start=start, stop=stop, step=step, infer_range=infer_range,
repeat=repeat, dtype=dtype, ctx=str(ctx))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: infer_range and repeat keyword arguments swapped place (not that it matters)

val params = Map("start" -> start, "step" -> step,
"repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> inferRange, "ctx" -> ctx.toString, "dtype" -> dType.toString())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lanking520 could you help review this part?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. @taliesinb Could you please try to move the newly introduced param to the end of the function in order to bring backward compatibility?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please update the doc string for this new parameter. same for symbol

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually prefer you create a new arange method with inferRange without default in the order you have here, the existing one one call with the default value(false).
Almost all of the methods have context as the last parameter, this one could cause confusion.

val params = Map("start" -> start, "step" -> step,
"repeat" -> repeat, "dtype" -> dType.toString())
"repeat" -> repeat, "infer_range" -> inferRange, "dtype" -> dType.toString())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@taliesinb Same applies to what I have mentioned above

val params = Map("start" -> start, "step" -> step,
"repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> inferRange, "ctx" -> ctx.toString, "dtype" -> dType.toString())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. @taliesinb Could you please try to move the newly introduced param to the end of the function in order to bring backward compatibility?

val params = Map("start" -> start, "step" -> step,
"repeat" -> repeat, "dtype" -> dType.toString())
"repeat" -> repeat, "infer_range" -> inferRange, "dtype" -> dType.toString())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@taliesinb Same applies to what I have mentioned above

@taliesinb taliesinb force-pushed the feature/arange_inference branch 2 times, most recently from b82fd38 to 56bb7fb Compare August 18, 2018 10:59
@taliesinb
Copy link
Contributor Author

@lanking520 @nswamy I've implemented the suggestion to have a second operator instead of passing this as an option. It wasn't clear whether this was desired just for Scala or for both Scala and Clojure, so I did it for both. If CI passes then this should be ready for merge unless there are more comments.

(arange start stop {})))

(defn arange-with-inference
"Behaves like arange operator, but infers the stop value from the output shape,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice job adding the clojure function with nice documentation. 👍 If you are feeling up to it you could also add the corresponding test for it here https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj#L141

(arange start stop {})))

(defn arange-with-inference
"Behaves like arange operator, but infers the stop value from the output shape,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here - Great job adding the clojure functions. If you want to add the corresponding test that would be awesome too https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj#L214. It can also be done in a follow up PR if that works better 😸

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gigasquid happy to try to do that! i haven't ever used Clojure or Scala before. But I've run into a problem even getting the first step, make scalapkg, to work on macOS. The make initially failed, comlpaining it couldn't find the mvn executable. I assumed that was maven, brew install maven had me first brew install java. Then make scalapkg seemed to be happy, and downloaded a bunch of stuff (including scala). But it failed with this:

[INFO] /Users/taliesinb/git/MXNet/scala-package/init/src/main/scala:-1: info: compiling
[INFO] Compiling 2 source files to /Users/taliesinb/git/MXNet/scala-package/init/target/classes at 1534710433224
Downloading from central: https://repo.maven.apache.org/maven2/org/scala-lang/modules/scala-xml_2.11/1.0.4/scala-xml_2.11-1.0.4.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scala-lang/modules/scala-xml_2.11/1.0.4/scala-xml_2.11-1.0.4.jar (648 kB at 755 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.4/scala-library-2.11.4.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.4/scala-library-2.11.4.jar (5.5 MB at 1.3 MB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.6/scala-library-2.11.6.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.6/scala-library-2.11.6.jar (5.6 MB at 1.2 MB/s)
[INFO] compiler plugin: BasicArtifact(org.scalamacros,paradise_2.11.8,2.1.0)
Downloading from central: https://repo.maven.apache.org/maven2/org/scalamacros/paradise_2.11.8/2.1.0/paradise_2.11.8-2.1.0.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scalamacros/paradise_2.11.8/2.1.0/paradise_2.11.8-2.1.0.jar (271 kB at 397 kB/s)
[ERROR] error: scala.reflect.internal.MissingRequirementError: object java.lang.Object in compiler mirror not found.
[ERROR] 	at scala.reflect.internal.MissingRequirementError$.signal(MissingRequirementError.scala:17)
[ERROR] 	at scala.reflect.internal.MissingRequirementError$.notFound(MissingRequirementError.scala:18)
[INFO] 	at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:53)

This seems to be related to https://issues.scala-lang.org/browse/SI-9103, but that issue is still open. I have no idea whats going on or how to make progress. I reran the make scalapkg to no effect, here's a gist with the full output: https://gist.github.com/taliesinb/d0f09e9f0202c3983298511383542f59. Do you have any suggestions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@taliesinb I'm impressed that you jumped in there on the Clojure and Scala 💯 - From the issue it seems like it is the JDK you are using. Using JDK 8 should solve the problems. If you have multiple versions of the JDK installed, you should just be able to switch by using an export of the right JAVA_HOME see here. Give it a try and see how it goes. If you don't want to hold up this PR, I'd be happy to assist on a follow up PR if you'd like 😸

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gigasquid great that worked! thanks for the help! I'll keep you posted as I (hopefully) make progress with this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@taliesinb are you making a new PR for Clojure or do you want to make changes to this one ? This is good for Scala APIs.
Thanks for the great work 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm @gigasquid I'm finding the instructions in the README.md file a little unclear in places. For example, under "Build from MXNET Source", I find this instruction a bit cryptic:

then replace the correct jar for your architecture in the project.clj, example [org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.3.0-SNAPSHOT"]

I would find this easier to understand if it was very explicit, such as "replace X with Y in section Z".

Here is what my project.clj contained out of the box:

                 ;; Jars from Nexus
                 ;[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.2.1"]
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.2.1"]
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.2.1"]

                 ;;; CI
                 [org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.3.0-SNAPSHOT"]

At this point, not knowing what to replace with what, I read the section "Cloning the repo and running from source", which mentions uncommenting rather than replacing. That section is also a bit confusing:

you will need to replace the native version of the line in the project dependencies with your configuration.

Which line? What is a "native version of the line"? Perhaps it could say "you will need to find and uncomment the appropriate line in the dependencies section of the project.clj file, and comment the rest". We could also make the project.clj section clearer so its more obvious what to do:

                 ;; default behavior, to be used by the CI bot on github; comment this line and
                 ;; uncomment the appropriate line in one of the other sections
                 [org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.3.0-SNAPSHOT"]

                 ;; use a prebuilt JAR from Nexus
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.2.1"]
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.2.1"]
                 ;[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.2.1"]

                 ;;; build a local JAR from source
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.3.0-SNAPSHOT"]
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.3.0-SNAPSHOT"]
                 ;[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.3.0-SNAPSHOT"]

Now, while the instructions could be a bit clearer, I did figure out the point eventually, and so I tried adding this line and commenting the rest:

[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.3.0-SNAPSHOT"]

After running lein clean and lein test I get this:

Generating symbol file
INFO  MXNetJVM: Try loading mxnet-scala from native path.
INFO  MXNetJVM: Try loading mxnet-scala-osx-x86_64-gpu from native path.
INFO  MXNetJVM: Try loading mxnet-scala-osx-x86_64-cpu from native path.
WARN  MXNetJVM: MXNet Scala native library not found in path. Copying native library from the archive. Consider installing the library somewhere in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH), or specifying by Java cmd option -Djava.library.path=[lib path].
INFO  org.apache.mxnet.util.NativeLibraryLoader: Replaced .dylib with .jnilib
INFO  org.apache.mxnet.util.NativeLibraryLoader: Loading libmxnet-scala.jnilib from /lib/native/ copying to mxnet-scala
[2

That WARN makes it sound like it's not using the library I built earlier using make scalainstall, which will mean I can't actually test my new functionality! Wasn't make scalainstall supposed to make the MXNet scala libraries available for everyone on my system?

How should I fix this?

Also, an ergonomics question: the test suite takes a while to run. With Python, it was very easy to just run the new tests I added using e.g. nosetests -v tests/python/unittest/test_operator.py. Is there a similar incantation for Clojure?

Thanks in advance for your help!

Copy link
Contributor Author

@taliesinb taliesinb Aug 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nswamy I'm now skeptical of my Scala changes. For example, I'm not sure that the additions to arange in NDArray.scala are correct. My concern is that the only way that you can even use the new inference feature is via backward inference, because the inference is based on the output shape of the tensor produced by arange, which must be inferred from a different part of the graph. So unless I'm missing something, calling the imperative arange function with infer_range = true will always fail as it has to produce an NDArray immediately, but this is not possible because backward inference is only relevant for symbols.

The new functionality should work in the symbolic context, however.
EDIT: to answer your original question, I'd prefer to add a Scala test to verify this last claim! If i did the wrong thing before... I don't trust myself now. Plus, if you agree, I should delete the imperative version of the new arange functionality from all language APIs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@taliesinb Thanks for the feedback on the wording. I'll update it to be more clear. It seems like you are doing everything exactly right. Once you do a scalainstall it will install locally in your maven a new 1.3.0-SNAPSHOT. Since you updated your project.clj to use this, it will load up the updated jar. The WARN is again misleading and can be improved, but it should be working :)
As far as running just one test, you can certainly do that with lein test :only org.apache.clojure-mxnet.ndarray-test and lein test :only org.apache.clojure-mxnet.operator-test.

Thanks again for the feedback and let me know if you have any other issues

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gigasquid thanks for the info. I'll get back to this tomorrow hopefully.

Copy link
Contributor Author

@taliesinb taliesinb Aug 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gigasquid ok we're good to go. I removed the pointless imperative function version of arange-with-inference, so I only had to add a test to operator_test.clj.

However, in doing this, I think I've picked up a problem with approx=, in which it incorrectly returns true if one of the comparisands (is that a word??) is shorter than the other, and differs in the remaining elements that the other does not have.

For example, try change the test starting on line 200 to the following:

(deftest ones
  (let [ones (sym/ones [2 2])
        exec (sym/simple-bind ones (context/default-context))]
    (is (approx= 1e-4
                 [1 1 1 1 9 9 9 9 9 9]
                 (-> exec (executor/forward) (executor/outputs) (first))))))

(I've introduced the 9 9 9 9 9 9 here). This test still passes.

I've reported the issue here: #12320, and fixed it in this PR. It doesn't produce any regressions, luckily!

If my new test looks good to you, we should be ready to merge!

@@ -407,11 +407,30 @@ object NDArray extends NDArrayBase {
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@taliesinb Thanks for making this change. There is a compile error and CI is breaking.
I also want to slightly change this, I fixed the compile error and modified NDArray/Symbol and pushed a commit to your branch

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nswamy oh thanks! my bad.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nswamy Looks good!

Copy link
Member

@lanking520 lanking520 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Include a test of this fix as well.
@gigasquid
Copy link
Member

@taliesinb Looks great! Thanks for the Clojure tests and fixing the helper function too. :shipit:

@gigasquid
Copy link
Member

Is there any more feedback? If not, I think this is good to merge.

@gigasquid gigasquid added pr-awaiting-merge Review and CI is complete. Ready to Merge and removed pr-awaiting-review PR is waiting for code review labels Aug 24, 2018
@nswamy nswamy merged commit 7bfe427 into apache:master Aug 24, 2018
@nswamy
Copy link
Member

nswamy commented Aug 24, 2018

Thanks for the great work and patience @taliesinb

@taliesinb taliesinb deleted the feature/arange_inference branch August 24, 2018 16:36
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* Allow stop of arange to be inferred from dims.

Enabled via a flag.

* modify NDArray/Symbol to add infer_range param

* Add test for arange-with-inference.

* Add a comment to readme about JDK 8.

* Fix approx=.

Include a test of this fix as well.
anirudh2290 pushed a commit to anirudh2290/mxnet that referenced this pull request Sep 19, 2018
* Allow stop of arange to be inferred from dims.

Enabled via a flag.

* modify NDArray/Symbol to add infer_range param

* Add test for arange-with-inference.

* Add a comment to readme about JDK 8.

* Fix approx=.

Include a test of this fix as well.
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Operator pr-awaiting-merge Review and CI is complete. Ready to Merge
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants