Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Add AutoGraph support for Python for loops #258

Merged
merged 25 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e9a2154
Add basic support for Python for loops
dime10 Aug 22, 2023
e05cd6e
Add support for `range` in for loops
dime10 Aug 22, 2023
3b94c0d
Add reference to for_loop function
dime10 Aug 23, 2023
e5a9965
Fix construction of captured ranges
dime10 Aug 23, 2023
e9c91e1
Add support for enumerate inside of for loops
dime10 Aug 23, 2023
a08ca7a
Test support for unpacking loop values
dime10 Aug 25, 2023
53e2848
Fix various issues reported by Pylint
dime10 Aug 28, 2023
63d2980
Typos
dime10 Sep 12, 2023
5140b2e
Perform retracing on Python for loop
dime10 Sep 12, 2023
78a31eb
Add extraction of original source code information
dime10 Sep 12, 2023
85e5a6a
Add tensorflow-cpu to requirements.txt
dime10 Sep 12, 2023
1f2f2e9
Fix missing autograph_artifact in ag_primitives
dime10 Sep 12, 2023
3274162
Linting
dime10 Sep 12, 2023
8a1b854
Simplify pytest mocking
dime10 Sep 12, 2023
4a158ba
Remove hardcoded line numbers in tests
dime10 Sep 13, 2023
515dbcf
Improve clarity of code comment
dime10 Sep 15, 2023
063e547
Improve changelog & docstrings
dime10 Sep 15, 2023
5341c3a
Improve patch coverage to 100%
dime10 Sep 15, 2023
8219ffa
Add user flags for AG strictness
dime10 Sep 19, 2023
baec087
Restore the initial variable state before fallback
dime10 Sep 19, 2023
dfd98a7
Fix variable tracking in loop body tracing
dime10 Sep 19, 2023
8d2a40c
Add error checks around uninitialized loop values
dime10 Sep 19, 2023
c86b957
Add test involving if/cond/for/for_loop
dime10 Sep 19, 2023
3bd9ed2
Satisfy Pylint
dime10 Sep 19, 2023
90b1a49
Fix patch coverage
dime10 Sep 19, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/check-catalyst.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ jobs:
run: |
sudo apt-get install -y python3 python3-pip libomp-dev
python3 -m pip install -r requirements.txt
python3 -m pip install tensorflow-cpu # for autograph tests
python3 -m pip install .

- name: Get Cached LLVM Build
Expand Down Expand Up @@ -369,7 +368,6 @@ jobs:
run: |
sudo apt-get install -y python3 python3-pip libomp-dev
python3 -m pip install -r requirements.txt
python3 -m pip install tensorflow-cpu # for autograph tests
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
python3 -m pip install .

- name: Get Cached LLVM Build
Expand Down
6 changes: 5 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ extension-pkg-allow-list=catalyst.utils.wrapper
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=pennylane.ops,jaxlib.mlir.ir
ignored-modules=pennylane.ops,jaxlib.mlir.ir,jaxlib.xla_extension,tensorflow

# List of classes names for which member attributes should not be checked
# (useful for classes with attributes dynamically set). This supports can work
Expand All @@ -20,6 +20,10 @@ ignored-modules=pennylane.ops,jaxlib.mlir.ir
# List of file/directory patters to ignore.
ignore=test,llvm-project,mlir-hlo,.git,__pycache__,build,doc

# List of decorators that change a function's signature. Will ignore certain errors
# like 'no-value-for-parameter' on functions with these decorators.
signature-mutators=catalyst.pennylane_extensions.for_loop

[MESSAGES CONTROL]

# Enable the message, report, category or checker with the given id(s). You can
Expand Down
4 changes: 4 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

<h3>New features</h3>

* Catalyst users can now use Python for loop statements in their programs without having to
explicitly use the functional `catalyst.for_loop` form!
[#258](https://github.com/PennyLaneAI/catalyst/pull/258)

<h3>Improvements</h3>

* Update the Lightning backend device to work with the PL-Lightning monorepo.
Expand Down
4 changes: 2 additions & 2 deletions doc/dev/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ additionally need to install `Rust <https://www.rust-lang.org/tools/install>`_ a
If the CMake version available in your system is too old, you can also install up-to-date
versions of it via ``pip install cmake``.

All additional build and developer depencies are managed via the repository's ``requirements.txt``
and can be installed as follows:
All additional build and developer dependencies are managed via the repository's
dime10 marked this conversation as resolved.
Show resolved Hide resolved
``requirements.txt`` and can be installed as follows:

.. code-block:: console

Expand Down
279 changes: 275 additions & 4 deletions frontend/catalyst/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
functions. The purpose is to convert imperative style code to functional or graph-style code."""

import functools
from typing import Any, Callable, Tuple
import warnings
from typing import Any, Callable, Iterator, SupportsIndex, Tuple

import jax.numpy as jnp

# Use tensorflow implementations for handling function scopes and calls,
# as well as various utility objects.
Expand All @@ -29,11 +32,13 @@
FunctionScope,
with_function_scope,
)
from tensorflow.python.autograph.impl.api import autograph_artifact
from tensorflow.python.autograph.impl.api import converted_call as tf_converted_call
from tensorflow.python.autograph.operators.variables import (
Undefined,
UndefinedReturnValue,
)
from tensorflow.python.autograph.pyct.origin_info import LineLocation

import catalyst
from catalyst.ag_utils import AutoGraphError
Expand All @@ -44,6 +49,7 @@
"ConversionOptions",
"Undefined",
"UndefinedReturnValue",
"autograph_artifact",
"FunctionScope",
"with_function_scope",
"if_stmt",
Expand Down Expand Up @@ -102,23 +108,202 @@
set_state(results)


def _call_catalyst_for(start, stop, step, body_fn, get_state, enum_start=None, array_iterable=None):
"""Dispatch to a Catalyst implementation of for loops."""

@catalyst.for_loop(start, stop, step)
def functional_for(i):
if enum_start is None and array_iterable is None:
# for i in range(..)
body_fn(i)
elif enum_start is None:
# for x in array
body_fn(array_iterable[i])
else:
# for (i, x) in enumerate(array)
body_fn((i + enum_start, array_iterable[i]))

return get_state()

return functional_for()


def _call_python_for(body_fn, get_state, non_array_iterable):
"""Fallback to a Python implementation of for loops."""

for elem in non_array_iterable:
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
body_fn(elem)

return get_state()


def for_stmt(
iteration_target: Any,
_extra_test: Callable[[], bool] | None,
body_fn: Callable[[int], None],
get_state: Callable[[], Tuple],
set_state: Callable[[Tuple], None],
_symbol_names: Tuple[str],
_opts: dict,
):
"""An implementation of the AutoGraph 'for .. in ..' statement. The interface is defined by
AutoGraph, here we merely provide an implementation of it in terms of Catalyst primitives."""

assert _extra_test is None

# The general approach is to convert as much code as possible into a graph-based form:
# - For loops over iterables will attempt a conversion of the iterable to array, and fall back
# to Python otherwise.
dime10 marked this conversation as resolved.
Show resolved Hide resolved
# - For loops over a Python range will be converted to a native Catalyst for loop. However,
# since the now dynamic iteration variable can cause issues in downstream user code, any
# errors raised during the tracing of the loop body will restart the tracing process using
# a Python loop instead.
dime10 marked this conversation as resolved.
Show resolved Hide resolved
# - For loops over a Python enumeration use a combination of the above, providing a dynamic
# iteration variable and conversion of the iterable to array. If either fails, a fallback to
# Python is used.
# The fallback mechanism for tracing errors will raise a warning as the user may need to be
# aware that the graph conversion failed, for instance for lack of converting lists into arrays,
# but the conversion from iterable to array will fall back to Python silently.
dime10 marked this conversation as resolved.
Show resolved Hide resolved
fallback = False

if isinstance(iteration_target, CRange):
start, stop, step = iteration_target.get_raw_range()
enum_start = None
iteration_array = None
elif isinstance(iteration_target, CEnumerate):
start, stop, step = 0, len(iteration_target.iteration_target), 1
enum_start = iteration_target.start_idx
try:
iteration_array = jnp.asarray(iteration_target.iteration_target)
except: # pylint: disable=bare-except
iteration_array = None
fallback = True
else:
start, stop, step = 0, len(iteration_target), 1
enum_start = None
try:
iteration_array = jnp.asarray(iteration_target)
except: # pylint: disable=bare-except
iteration_array = None
fallback = True

# Attempt to trace the Catalyst for loop.
if not fallback:
try:
results = _call_catalyst_for(
start, stop, step, body_fn, get_state, enum_start, iteration_array
)
except Exception as e: # pylint: disable=broad-exception-caught
fallback = True
# pylint: disable=import-outside-toplevel
import inspect
import textwrap

for_loop_info = get_source_code_info(inspect.stack()[1])

warnings.warn(
f"Tracing of an AutoGraph converted for loop failed with the following exception:\n"
f" {type(e).__name__}:{textwrap.indent(str(e), ' ')}\n"
f"\n"
f"The error ocurred within the body of the following for loop statement:\n"
f"{for_loop_info}"
f"\n"
f"If you intended for the conversion to happen, make sure that the (now dynamic) "
f"loop variable is not used in tracing-incompatible ways, for instance by indexing "
f"a Python list with it. In that case, the list should be wrapped into an array.\n"
f"To understand different types of JAX tracing errors, please refer to the guide "
f"at: https://jax.readthedocs.io/en/latest/errors.html\n"
f"\n"
f"If you did not intend for the conversion to happen, you may safely ignore this "
f"warning."
dime10 marked this conversation as resolved.
Show resolved Hide resolved
)

# If anything goes wrong, we fall back to Python.
if fallback:
results = _call_python_for(body_fn, get_state, iteration_target)

# Sometimes we unpack the results of nested tracing scopes so that the user doesn't have to
# manipulate tuples when they don't expect it. Ensure set_state receives a tuple regardless.
if not isinstance(results, tuple):
results = (results,)
set_state(results)


def get_source_code_info(tb_frame):
"""Attempt to obtain original source code information for an exception raised within AutoGraph
transformed code.

Uses introspection on the call stack to extract the source map record from within AutoGraph
statements. However, it is not guaranteed to find the source map and may return nothing.
"""
import inspect # pylint: disable=import-outside-toplevel

ag_source_map = None

# Traverse frames in reverse to find caller with `ag_source_map` property:
# - function: directly on the callable object
# - qnode method: on the self object
# - qjit method: on the self.user_function object
try:
for frame in inspect.stack():
if frame.function == "converted_call" and "converted_f" in frame.frame.f_locals:
obj = frame.frame.f_locals["converted_f"]
ag_source_map = obj.ag_source_map
break
if "self" in frame.frame.f_locals:
obj = frame.frame.f_locals["self"]
if isinstance(obj, qml.QNode):
ag_source_map = obj.ag_source_map
break
if isinstance(obj, catalyst.QJIT):
ag_source_map = obj.user_function.ag_source_map
break
except: # nosec B110 # pylint: disable=bare-except
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
pass

Check warning on line 262 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L261-L262

Added lines #L261 - L262 were not covered by tests

loc = LineLocation(tb_frame.filename, tb_frame.lineno)
if ag_source_map is not None and loc in ag_source_map:
function_name = ag_source_map[loc].function_name
filename = ag_source_map[loc].loc.filename
lineno = ag_source_map[loc].loc.lineno
source_code = ag_source_map[loc].source_code_line.strip()
else:
function_name = tb_frame.name
filename = tb_frame.filename
lineno = tb_frame.lineno
source_code = tb_frame.line

Check warning on line 274 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L271-L274

Added lines #L271 - L274 were not covered by tests

info = f' File "{filename}", line {lineno}, in {function_name}\n' f" {source_code}\n"

return info


# Prevent autograph from converting PennyLane and Catalyst library code, this can lead to many
# issues such as always tracing through code that should only be executed conditionally. We might
# have to be even more restrictive in the future to prevent issues if necessary.
module_allowlist = (
config.DoNotConvert("pennylane"),
config.DoNotConvert("catalyst"),
config.DoNotConvert("jax"),
) + config.CONVERSION_RULES


def converted_call(fn, *args, **kwargs):
def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None):
"""We want AutoGraph to use our own instance of the AST transformer when recursively
transforming functions, but otherwise duplicate the same behaviour."""

with Patcher(
(tf_autograph_api, "_TRANSPILER", catalyst.autograph.TRANSFORMER),
(config, "CONVERSION_RULES", module_allowlist),
):
# Dispatch range calls to a custom range class that enables constructs like
# `for .. in range(..)` to be converted natively to `for_loop` calls. This is beneficial
# since the Python range function does not allow tracers as arguments.
if fn is range:
return CRange(*args, **(kwargs if kwargs is not None else {}))
elif fn is enumerate:
return CEnumerate(*args, **(kwargs if kwargs is not None else {}))

# We need to unpack nested QNode and QJIT calls as autograph will have trouble handling
# them. Ideally, we only want the wrapped function to be transformed by autograph, rather
# than the QNode or QJIT call method.
Expand All @@ -134,9 +319,95 @@

@functools.wraps(fn.func)
def qnode_call_wrapper():
return tf_converted_call(fn.func, *args, **kwargs)
return tf_converted_call(fn.func, args, kwargs, caller_fn_scope, options)

new_qnode = qml.QNode(qnode_call_wrapper, device=fn.device, diff_method=fn.diff_method)
return new_qnode()

return tf_converted_call(fn, *args, **kwargs)
return tf_converted_call(fn, args, kwargs, caller_fn_scope, options)


class CRange:
josh146 marked this conversation as resolved.
Show resolved Hide resolved
"""Catalyst range object.

Can be passed to a Python for loop for native conversion to a for_loop call.
Otherwise this class behaves exactly like the Python range class.

Without this native conversion, all iteration targets in a Python for loop must be convertible
to arrays. For all other inputs the loop will be treated as a regular Python loop.
"""

def __init__(self, start_stop, stop=None, step=None):
self._py_range = None
self._start = start_stop if stop is not None else 0
self._stop = stop if stop is not None else start_stop
self._step = step if step is not None else 1

def get_raw_range(self):
"""Get the raw values defining this range: start, stop, step."""
return self._start, self._stop, self._step

@property
def py_range(self):
"""Access the underlying Python range object. If it doesn't exist, create one."""
if self._py_range is None:
self._py_range = range(self._start, self._stop, self._step)
return self._py_range

# Interface of the Python range class.
# pylint: disable=missing-function-docstring

@property
def start(self) -> int:
return self.py_range.start

Check warning on line 362 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L362

Added line #L362 was not covered by tests
dime10 marked this conversation as resolved.
Show resolved Hide resolved

@property
def stop(self) -> int:
return self.py_range.stop

Check warning on line 366 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L366

Added line #L366 was not covered by tests

@property
def step(self) -> int:
return self.py_range.step

Check warning on line 370 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L370

Added line #L370 was not covered by tests

def count(self, __value: int) -> int:
return self.py_range.count(__value)

Check warning on line 373 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L373

Added line #L373 was not covered by tests

def index(self, __value: int) -> int:
return self.py_range.index(__value)

Check warning on line 376 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L376

Added line #L376 was not covered by tests

def __len__(self) -> int:
return self.py_range.__len__()

Check warning on line 379 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L379

Added line #L379 was not covered by tests

def __eq__(self, __value: object) -> bool:
return self.py_range.__eq__(__value)

Check warning on line 382 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L382

Added line #L382 was not covered by tests

def __hash__(self) -> int:
return self.py_range.__hash__()

Check warning on line 385 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L385

Added line #L385 was not covered by tests

def __contains__(self, __key: object) -> bool:
return self.py_range.__contains__(__key)

Check warning on line 388 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L388

Added line #L388 was not covered by tests

def __iter__(self) -> Iterator[int]:
return self.py_range.__iter__()

def __getitem__(self, __key: SupportsIndex | slice) -> int | range:
self.py_range.__getitem__(__key)

Check warning on line 394 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L394

Added line #L394 was not covered by tests

def __reversed__(self) -> Iterator[int]:
self.py_range.__reversed__()

Check warning on line 397 in frontend/catalyst/ag_primitives.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/ag_primitives.py#L397

Added line #L397 was not covered by tests


class CEnumerate(enumerate):
"""Catalyst enumeration object.

Can be passed to a Python for loop for conversion into a for_loop call. The loop index, as well
as the iterable element will be provided to the loop body.
Otherwise this class behaves exactly like the Python enumerate class.

Note that the iterable must be convertible to an array, otherwise the loop will be treated as a
regular Python loop.
"""

def __init__(self, iterable, start=0):
self.iteration_target = iterable
self.start_idx = start
Loading
Loading