[Frontend] Add AutoGraph support for Python for loops #258
Codecov Report

Coverage diff, main vs. #258:

| Metric | main | #258 | +/- |
| --- | --- | --- | --- |
| Coverage | 99.31% | 99.47% | +0.16% |
| Files | 41 | 41 | |
| Lines | 7141 | 7284 | +143 |
| Branches | 377 | 393 | +16 |
| Hits | 7092 | 7246 | +154 |
| Misses | 27 | 20 | -7 |
| Partials | 22 | 18 | -4 |
Catalyst for loop conversion works directly when iterating over arrays. For non-array iteration targets, a conversion to array is attempted; if this fails, we fall back to Python loops.
A custom range class is added that acts identically to the Python range class, except that it allows tracers to be used as arguments for start, stop, and step. The for loop conversion looks for the new range class and can automatically incorporate it into the iteration bounds of the Catalyst `for_loop` function. The advantages are that a static Python range does not need to be materialized into a constant array, and that dynamic ranges are now also supported. A downside is that currently a conversion to Catalyst's `for_loop` always takes place when a range is encountered. However, the user may use the indices obtained from the range to index non-array objects, which is not detectable by us and would lead to a tracing error.
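As a rough illustration of the idea described above, here is a minimal sketch of a range-like class that stores its bounds without requiring them to be integers. `SketchRange`, `bounds`, and `FakeTracer` are hypothetical names invented for this example; this is not the class added in the PR, only one way such a class could be shaped.

```python
class SketchRange:
    """Range-like container that stores its bounds without requiring ints."""

    def __init__(self, start_stop, stop=None, step=None):
        # Mirror the builtin signature: range(stop) or range(start, stop[, step]).
        if stop is None:
            self.start, self.stop = 0, start_stop
        else:
            self.start, self.stop = start_stop, stop
        self.step = 1 if step is None else step

    def bounds(self):
        # A loop converter could pass these straight to a for_loop-style primitive.
        return self.start, self.stop, self.step

    def __iter__(self):
        # Python fallback path: only valid when all bounds are concrete integers.
        return iter(range(self.start, self.stop, self.step))


class FakeTracer:
    """Stand-in for a traced value with no concrete integer behind it."""


n = FakeTracer()
dynamic = SketchRange(0, n, 2)  # accepted, whereas range(0, n, 2) raises TypeError
static = SketchRange(5)         # bounds are kept as-is, no array is materialized
```

With concrete integers the sketch iterates like the builtin `range`; with a tracer-like bound it can only hand its bounds to a converter, which matches the downside noted above.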
.. when any exception is raised during tracing of the Catalyst loop. A warning is raised when an exception occurs within the Catalyst loop, allowing users to correct certain mistakes (such as wrapping a list into an array). However, the conversion remains safe in the sense that it always falls back on the code as the user wrote it.
.. in warnings/errors arising inside of converted code.
No longer requires the pytest-mock package dependency.
The tests were more precise this way, but too brittle as the line numbers can change frequently.
Thanks @dime10, great work.
I'm approving a bit early because I don't want to be a bottleneck here with my upcoming travel, but I only reviewed the tests and the docs, so the implementation should still be reviewed by someone else :)
Minor docs comment, but the `qjit` docstring should also be updated to specify `for` is now supported:
autograph (bool): Experimental support for automatically converting Python control
flow statements to Catalyst-compatible control flow. Currently only supports Python
``if``, ``elif``, and ``else`` statements. Note that this feature requires an
available TensorFlow installation.
Thanks for the review @josh146 :) Good point on the documentation, I'll have to polish it up a bit before merging!
Great PR 👍 I am just wondering if CRange could inherit from range like you did with enumerate. This would remove issues with testing. Happy to approve after that
- `catalyst.autograph_ignore_fallbacks`: Silences warnings resulting from Python fallbacks. This is useful if the user doesn't want to see any warnings, or in situations where the warning cannot be fixed.
- `catalyst.autograph_strict_conversion`: Turns Python fallback warnings into errors and produces the full traceback of the error that caused the fallback. This can be useful when debugging why a fallback happened.
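To make the interaction of the two flags concrete, here is a minimal sketch of how a fallback could be reported. `report_fallback` and its keyword arguments are hypothetical and only mirror the flag names above; the precedence of strict mode over the silencing flag is an assumption of this sketch, not something stated in the PR.

```python
import warnings


def report_fallback(exc, ignore_fallbacks=False, strict_conversion=False):
    """Decide what happens when a converted loop falls back to Python."""
    if strict_conversion:
        # Strict mode: turn the fallback into an error that chains the
        # original exception, so its full traceback is shown.
        raise RuntimeError("loop conversion failed") from exc
    if not ignore_fallbacks:
        # Default mode: warn, then let the caller run the Python loop.
        warnings.warn(f"Falling back to Python loop: {exc!r}", UserWarning)
    # With ignore_fallbacks=True the fallback is silent.
```

Chaining with `raise ... from exc` is what keeps the causing traceback visible in strict mode.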
Depending on when the tracing fails, variables tracked by autograph could have been modified already. To prevent an invalid state, we restore all autograph tracked variables to their original state before running the loop in Python.
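The restore-before-fallback step can be sketched in plain Python as follows. All names here (`run_with_fallback`, `failing_trace`, `python_loop`, and the `state` dictionary standing in for the tracked variables) are hypothetical; the sketch shows only the snapshot-and-restore pattern, not Catalyst's actual implementation.

```python
import copy
import warnings


def run_with_fallback(traced_loop, python_loop, state):
    """Try the traced loop first; on failure, restore state and run in Python."""
    snapshot = copy.deepcopy(state)  # taken before tracing can mutate anything
    try:
        return traced_loop(state)
    except Exception as exc:
        # Tracing may have partially updated the tracked variables: undo that.
        state.clear()
        state.update(snapshot)
        warnings.warn(f"Tracing failed ({exc!r}); running the loop in Python.")
        return python_loop(state)


def failing_trace(state):
    state["total"] = "partially traced"  # simulate a half-finished update
    raise TypeError("cannot index a list with a tracer")


def python_loop(state):
    for x in [1, 2, 3]:
        state["total"] += x
    return state["total"]
```

Without the restore, `python_loop` would start from the half-updated `"total"` value left behind by the failed trace.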
This bug resulted in loop carried values (like a sum value) not being updated during the execution of the loop. This commit also adds a variety of tests around various uses of "iteration arguments".
I've added an additional user-facing feature here: 8219ffa. It provides two flags that the user can set around the conversion strictness (I used them during debugging and in the tests):
Thanks @dime10! I think these are good options to include. Non-blocking for this PR, but for end-users I worry about having these new options be module variables as opposed to an option you pass to In Strawberry Fields, we used to have It mainly showed up in several places:
So before usage gets baked in, might be good to turn these into arguments.
It looks good to me 💯 I suggest an autograph configuration kwarg for qjit but not blocking
I was originally thinking of similar configurations in JAX which are also global, like @qjit(autograph=True, ag_mode=...) Where
However, there is a problem because there is nothing actually linking the qjit object to the eventual invocation of the autograph transformed for loop. I think this is true in general for all our primitives, so there is no obvious way to pass compilation options into any of them. catalyst/frontend/catalyst/jax_tracer.py Line 114 in 64be9d2
I really like this 🙌 I'll add it to Q4 as a potential story.
Introduces support for Python `for .. in ...:` statement capture as part of the compiled program.

Similar to PR #235, AutoGraph is used to convert such statements into the equivalent Catalyst version before tracing occurs. Specifically, the following constructs are supported via AutoGraph:

- `for elem in iterable:` - These get converted into a `for_loop(0, len(iterable), 1)` with `elem = iterable[i]` automatically assigned using the iteration index, assuming `iterable` is convertible to a JAX array. If this is not the case, the loop is executed as is in Python.
- `for i in range(start, stop, step):` - These are converted directly into their equivalent `for_loop(start, stop, step)`. Contrary to the default Python `range`, when AutoGraph is enabled `range` can also accept dynamic tracers as `start`, `stop`, `step` values. If any exception is raised during the tracing of the `for_loop` body, Catalyst will fall back to Python with a warning.
- `for i, elem in enumerate(iterable):` - These get converted into `for_loop(0, len(iterable), 1)` with the iteration index assigned to the variable chosen by the user (in this case `i`), and `elem = iterable[i]`. This also assumes that `iterable` is convertible to an array, and that the loop body traces without exception, otherwise the loop is executed in Python.

Note that a warning is raised when a Python fallback is triggered due to a tracing exception. Python fallbacks caused by the iterable not being convertible to array are silent.
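The conversions listed above can be illustrated with a small pure-Python sketch. The `for_loop` defined here is a stand-in written for this example so that it runs without Catalyst installed; it only imitates a decorator-style `for_loop(lower, upper, step)` and makes no claim about how Catalyst implements it. The function names `sum_python`, `sum_converted`, and `enumerate_converted` are likewise hypothetical.

```python
def for_loop(lower, upper, step):
    """Pure-Python stand-in for a decorator-style for_loop(lower, upper, step)."""
    def decorator(body):
        def run(*carried):
            for i in range(lower, upper, step):
                carried = body(i, *carried)
            return carried
        return run
    return decorator


def sum_python(iterable):
    # What the user writes.
    total = 0
    for elem in iterable:
        total = total + elem
    return total


def sum_converted(iterable):
    # Roughly what `for elem in iterable:` turns into after conversion.
    @for_loop(0, len(iterable), 1)
    def loop(i, total):
        elem = iterable[i]      # element assigned from the iteration index
        return (total + elem,)  # loop-carried values are returned explicitly

    (total,) = loop(0)
    return total


def enumerate_converted(iterable):
    # Roughly what `for i, elem in enumerate(iterable):` turns into.
    @for_loop(0, len(iterable), 1)
    def loop(i, acc):
        elem = iterable[i]
        return (acc + i * elem,)

    (acc,) = loop(0)
    return acc
```

Passing loop-carried values such as `total` through the body's arguments and return value is the part that makes a running sum survive across iterations once the loop is no longer an ordinary Python loop.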
[sc-41287]