[Frontend] Add AutoGraph support for Python for loops #258

Merged
merged 25 commits into from
Sep 21, 2023

dime10
Collaborator

@dime10 dime10 commented Aug 23, 2023

Introduces support for capturing Python for ... in ...: statements as part of the compiled program.
Similar to PR #235, AutoGraph is used to convert such statements into the equivalent Catalyst version before tracing occurs. Specifically, the following constructs are supported via AutoGraph:

  • for elem in iterable: These are converted into for_loop(0, len(iterable), 1), with elem = iterable[i] assigned automatically from the iteration index, assuming iterable is convertible to a JAX array. If it is not, the loop is executed as is in Python.
  • for i in range(start, stop, step): These are converted directly into the equivalent for_loop(start, stop, step). Unlike the default Python range, when AutoGraph is enabled range also accepts dynamic tracers as start, stop, and step values. If any exception is raised while tracing the for_loop body, Catalyst falls back to Python with a warning.
  • for i, elem in enumerate(iterable): These are converted into for_loop(0, len(iterable), 1), with the iteration index assigned to the variable chosen by the user (here i) and elem = iterable[i]. This also assumes that iterable is convertible to an array and that the loop body traces without exception; otherwise the loop is executed in Python.

Note that a warning is raised when a Python fallback is triggered by a tracing exception. Python fallbacks caused by the iterable not being convertible to an array are silent.
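
For readers who want a concrete picture of the three conversions, here is a minimal pure-Python sketch. It does no tracing and does not use Catalyst: `for_loop_sketch` is a hypothetical stand-in that only mimics the decorator shape of `for_loop(start, stop, step)`, and all function names are illustrative assumptions.

```python
def for_loop_sketch(start, stop, step):
    """Hypothetical stand-in for a Catalyst-style for_loop (no tracing involved)."""
    def decorator(body):
        def run(*carried):
            i = start
            # Walk the index exactly like range(start, stop, step) would
            while (step > 0 and i < stop) or (step < 0 and i > stop):
                out = body(i, *carried)
                carried = out if isinstance(out, tuple) else (out,)
                i += step
            return carried[0] if len(carried) == 1 else carried
        return run
    return decorator


def sum_elements(iterable):
    # `for elem in iterable: total += elem` as an index loop over len(iterable)
    @for_loop_sketch(0, len(iterable), 1)
    def loop(i, total):
        elem = iterable[i]  # element assigned from the iteration index
        return total + elem
    return loop(0)


def sum_range(start, stop, step):
    # `for i in range(start, stop, step): total += i` maps directly onto the bounds
    @for_loop_sketch(start, stop, step)
    def loop(i, total):
        return total + i
    return loop(0)


def weighted_sum(iterable):
    # `for i, elem in enumerate(iterable): total += i * elem`
    @for_loop_sketch(0, len(iterable), 1)
    def loop(i, total):
        elem = iterable[i]
        return total + i * elem
    return loop(0)
```

In each case the loop-carried value (`total`) is passed into the body and returned from it, which is the shape the converted loop needs in place of in-place mutation.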

[sc-41287]

@codecov

codecov bot commented Aug 23, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.16% 🎉

Comparison is base (ffdd0ab) 99.31% compared to head (90b1a49) 99.47%.
Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #258      +/-   ##
==========================================
+ Coverage   99.31%   99.47%   +0.16%     
==========================================
  Files          41       41              
  Lines        7141     7284     +143     
  Branches      377      393      +16     
==========================================
+ Hits         7092     7246     +154     
+ Misses         27       20       -7     
+ Partials       22       18       -4     
Files Changed Coverage Δ
frontend/catalyst/compilation_pipelines.py 100.00% <ø> (ø)
frontend/catalyst/pennylane_extensions.py 99.14% <ø> (+<0.01%) ⬆️
frontend/catalyst/__init__.py 95.83% <100.00%> (+0.37%) ⬆️
frontend/catalyst/ag_primitives.py 100.00% <100.00%> (ø)
frontend/catalyst/autograph.py 100.00% <100.00%> (ø)
frontend/catalyst/jax_primitives.py 97.04% <100.00%> (ø)
frontend/catalyst/jax_tracer.py 99.27% <100.00%> (+3.79%) ⬆️

... and 4 files with indirect coverage changes

Base automatically changed from python_cf to main August 25, 2023 19:27
Catalyst for loop conversion works directly when iterating over arrays.
Non-array iteration targets will attempt a conversion to array;
if this fails, we fall back to Python loops.
A custom range class is added that acts identically to the Python range
class, however it allows tracers to be used as arguments for start,
stop, and step.

The for loop conversion looks for the new range class and can
automatically incorporate it into the iteration bounds of the Catalyst
`for_loop` function.

Some advantages include that a static Python range does not need to be
materialized into a constant array, and dynamic ranges are now also
supported.
A downside is that currently a conversion to Catalyst's for_loop always
takes place when a range is encountered. However, the user may use the
indices obtained from the range to index non-array objects, which is not
detectable by us and would lead to a tracing error.
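
A rough sketch of the idea in this commit message, under stated assumptions: `RangeSketch` and `loop_bounds` are hypothetical names, not the classes added in the PR. The point is only that a range-like object can record its bounds (which could be tracers) so the conversion reads them off instead of materializing an array.

```python
class RangeSketch:
    """Hypothetical range-like container that only records its bounds."""

    def __init__(self, start_stop, stop=None, step=None):
        # Mirror the range(stop) / range(start, stop[, step]) argument handling
        if stop is None:
            self.start, self.stop = 0, start_stop
        else:
            self.start, self.stop = start_stop, stop
        self.step = 1 if step is None else step

    def __iter__(self):
        # Plain Python iteration only works for concrete integer bounds
        return iter(range(self.start, self.stop, self.step))


def loop_bounds(iteration_target):
    """Return (start, stop, step) for a loop over the given target."""
    if isinstance(iteration_target, RangeSketch):
        # Bounds are taken as-is; nothing is materialized into an array
        return iteration_target.start, iteration_target.stop, iteration_target.step
    # Any other target is treated as a sequence indexed from 0
    return 0, len(iteration_target), 1
```

Because the bounds are stored without inspection, a placeholder such as a traced value can sit in `stop` and still be forwarded to the loop.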
.. when any exception is raised during tracing of the Catalyst loop.
A warning is raised when an exception occurs within the Catalyst loop,
allowing users to correct certain mistakes (such as wrapping a list into
an array). However, the conversion remains safe in the sense that it
always falls back on the code as the user wrote it.
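
The fallback behaviour described here can be sketched as a try/except wrapper. This is an illustration only, assuming the converted loop and the user's original loop are both available as callables; `run_with_fallback`, `converted`, and `original` are hypothetical names.

```python
import warnings


def run_with_fallback(converted_loop, python_loop):
    """Try the converted loop first; on any exception, warn and run the original."""
    try:
        return converted_loop()
    except Exception as exc:
        # The warning lets the user fix the cause, but execution still succeeds
        warnings.warn(f"Loop conversion failed ({exc!r}); falling back to Python.")
        return python_loop()


def converted():
    # Stand-in for a traced loop that fails, e.g. indexing a list with a tracer
    raise TypeError("cannot index a Python list with a traced value")


def original():
    # The loop exactly as the user wrote it
    return sum([1, 2, 3])
```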
No longer requires the pytest-mock package dependency.
The tests were more precise this way, but too brittle as the line
numbers can change frequently.
@dime10 dime10 marked this pull request as ready for review September 13, 2023 00:22
Member

@josh146 josh146 left a comment

Thanks @dime10, great work.

I'm approving a bit early because I don't want to be a bottleneck here with my upcoming travel, but I only reviewed the tests and the docs, so the implementation should still be reviewed by someone else :)

Minor docs comment, but the qjit docstring should also be updated to specify for is now supported:

autograph (bool): Experimental support for automatically converting Python control
    flow statements to Catalyst-compatible control flow. Currently only supports Python
    ``if``, ``elif``, and ``else`` statements. Note that this feature requires an
    available TensorFlow installation.

@dime10
Collaborator Author

dime10 commented Sep 13, 2023

Thanks for the review @josh146 :) Good point on the documentation, I'll have to polish it up a bit before merging!

Contributor

@rmoyard rmoyard left a comment

Great PR 👍 I am just wondering if CRange could inherit from range like you did with enumerate. This would remove issues with testing. Happy to approve after that

- catalyst.autograph_ignore_fallbacks:
  Silences warnings resulting from Python fallbacks.
  This is useful if the user doesn't want to see any warnings, or in
  situations where the warning cannot be fixed.
- catalyst.autograph_strict_conversion:
  Turns Python fallback warnings into errors and produces the full
  traceback of the error that caused the fallback.
  This can be useful when debugging why a fallback happened.
Depending on when the tracing fails, variables tracked by autograph
could have been modified already. To prevent an invalid state, we
restore all autograph tracked variables to their original state before
running the loop in Python.
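
To illustrate why the restore step matters, here is a small sketch, assuming the tracked variables live in a dictionary. The names `run_loop_safely`, `failing_conversion`, and `plain_python` are hypothetical and not taken from the PR.

```python
import copy


def run_loop_safely(converted_loop, python_loop, state):
    """Run converted_loop on state; on failure restore state, then run python_loop."""
    snapshot = copy.deepcopy(state)  # original values of the tracked variables
    try:
        converted_loop(state)
    except Exception:
        state.clear()
        state.update(snapshot)  # undo partial modifications before the fallback
        python_loop(state)
    return state


def failing_conversion(state):
    state["total"] += 100  # a partial update happens before the failure
    raise RuntimeError("tracing failed")


def plain_python(state):
    for x in [1, 2, 3]:
        state["total"] += x


result = run_loop_safely(failing_conversion, plain_python, {"total": 0})
```

Without the restore, the Python loop would start from the partially updated value (100) rather than the original one (0).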
This bug resulted in loop carried values (like a sum value) not being
updated during the execution of the loop.

This commit also adds a variety of tests around various uses of
"iteration arguments".
@dime10
Collaborator Author

dime10 commented Sep 19, 2023

I've added an additional user facing feature here: 8219ffa

It provides two flags that the user can set around the conversion strictness (I used them during debugging and in the tests):

  • catalyst.autograph_ignore_fallbacks: Silences warnings resulting from Python fallbacks.
    This is useful if the user doesn't want to see any warnings, or in situations where the warning cannot be fixed.
  • catalyst.autograph_strict_conversion: Turns Python fallback warnings into errors and produces the full
    traceback of the error that caused the fallback.
    This can be useful when debugging why a fallback happened.
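
A hedged sketch of how two such flags could steer the fallback handling. The `flags` namespace and `handle_conversion_failure` are hypothetical stand-ins for the module-level variables, and checking the strict flag first is an assumption, since the precedence between the two flags is not stated here.

```python
import warnings
from types import SimpleNamespace

# Hypothetical stand-ins for the two module-level flags
flags = SimpleNamespace(ignore_fallbacks=False, strict_conversion=False)


def handle_conversion_failure(exc):
    """Decide what to do after a loop conversion raised `exc`."""
    if flags.strict_conversion:
        # Strict: turn the fallback into an error, chaining the original traceback
        raise RuntimeError("AutoGraph conversion failed") from exc
    if not flags.ignore_fallbacks:
        warnings.warn(f"Falling back to Python: {exc!r}")
    return "fallback"
```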

@josh146
Member

josh146 commented Sep 19, 2023

Thanks @dime10! I think these are good options to include.

Non-blocking for this PR, but for end-users I worry about having these new options be module variables as opposed to an option you pass to @qjit.

In Strawberry Fields, we used to have strawberryfields.hbar be a module level variable that users could set, and we ran into a tonne of problems and edge cases due to this global state/the fact that functions weren't pure.

It mainly showed up in several places:

  • Parallelization. E.g., users using threads, or dask, or even pytest running tests in parallel, would cause race conditions.

  • there would sometimes be bugs in tests because the preceding test wouldn't properly do a tear down

  • developers would sometimes alter these values internally, and forget to change them back. Or there would be an exception before it could change back (and the developer didn't use try-finally).

So before usage gets baked in, might be good to turn these into arguments.
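
The third pitfall above (a global changed and never changed back) is usually guarded against with try/finally. A minimal sketch, assuming a hypothetical module-level `settings` object and helper `temporary_setting`:

```python
from contextlib import contextmanager
from types import SimpleNamespace

settings = SimpleNamespace(strict=False)  # hypothetical module-level setting


@contextmanager
def temporary_setting(name, value):
    """Set settings.<name> for the duration of a block, then restore it."""
    previous = getattr(settings, name)
    setattr(settings, name, value)
    try:
        yield
    finally:
        # Runs even if the block raises, so the changed value never leaks out
        setattr(settings, name, previous)
```

This only addresses the forgotten-reset case. It does not help with the parallelization issue, since concurrent threads would still share the same global while the block is active.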

Contributor

@rmoyard rmoyard left a comment

It looks good to me 💯 I suggest an autograph configuration kwarg for qjit, but it's not blocking

@dime10
Collaborator Author

dime10 commented Sep 19, 2023

I was originally thinking of similar configurations in JAX which are also global, like jax.config.update("jax_debug_nans", True). However, you raise some good points so I'm thinking we could add something like this:

@qjit(autograph=True, ag_mode=...)

Where ag_mode can be one of the following:

  • "strict": Always raise an error with the full traceback when a conversion fails. Enable this
    to debug why a conversion may have failed, or when successful conversion is critical.
  • "permissive": The default & safe option. Catalyst will never abort compilation because
    an AutoGraph conversion failed, but some errors are re-raised as warnings.
  • "silent": Catalyst will not raise any warnings or errors and will fall back to Python
    silently whenever AutoGraph conversion fails.
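
As a rough illustration of the three proposed modes, the dispatch could look like the sketch below when the mode is passed explicitly. `on_conversion_failure` is a hypothetical helper, not an existing Catalyst function, and `ag_mode` is only the option name proposed in this comment.

```python
import warnings

AG_MODES = ("strict", "permissive", "silent")


def on_conversion_failure(exc, ag_mode="permissive"):
    """Handle a failed conversion according to the proposed ag_mode values."""
    if ag_mode not in AG_MODES:
        raise ValueError(f"ag_mode must be one of {AG_MODES}, got {ag_mode!r}")
    if ag_mode == "strict":
        raise exc  # re-raise the original error
    if ag_mode == "permissive":
        warnings.warn(f"AutoGraph conversion failed, using Python: {exc!r}")
    # "permissive" and "silent" both continue with the Python fallback
    return "fallback"
```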

However, there is a problem because there is nothing actually linking the qjit object to the eventual invocation of the autograph transformed for loop. I think this is true in general for all our primitives, so there is no obvious way to pass compilation options into any of them.
(I just discovered another use of a global here actually, but it's only used internally:

JaxTape.device = device
)

@dime10 dime10 merged commit 8106beb into main Sep 21, 2023
18 checks passed
@dime10 dime10 deleted the python_for branch September 21, 2023 16:37
@josh146
Member

josh146 commented Sep 21, 2023

  • "strict": Always raise an error with the full traceback when a conversion fails. Enable this to debug why a conversion may have failed, or when successful conversion is critical.
  • "permissive": The default & safe option. Catalyst will never abort compilation because an AutoGraph conversion failed, but some errors are re-raised as warnings.
  • "silent": Catalyst will not raise any warnings or errors and will fall back to Python silently whenever AutoGraph conversion fails.

I really like this 🙌 I'll add it to Q4 as a potential story.

3 participants