Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plxpr can capture operations #5511

Merged
merged 63 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
dd1f380
first pass
albi3ro Apr 12, 2024
83e4fe7
add module
dwierichs Apr 12, 2024
22fbadc
changelog
dwierichs Apr 12, 2024
156a749
Merge branch 'master' into add-capture-module
dwierichs Apr 12, 2024
ad7b637
import, fix
dwierichs Apr 12, 2024
d9b8b8a
git
dwierichs Apr 12, 2024
1634916
tests
dwierichs Apr 12, 2024
f6ba19d
Merge branch 'master' into add-capture-module
dwierichs Apr 12, 2024
15559d6
move switches to switches.py
dwierichs Apr 12, 2024
c511fb5
lint
dwierichs Apr 12, 2024
add82e9
identify all operators as jax primitives
albi3ro Apr 12, 2024
242cad1
Merge branch 'add-capture-module' into plxpr-capture-operations
albi3ro Apr 12, 2024
4c6c7bb
add dunder math support
albi3ro Apr 12, 2024
1d927b8
allow overriding primmitive bind call
albi3ro Apr 15, 2024
67a2069
improving testing
albi3ro Apr 15, 2024
45ab87a
fix up to allow using abc still
albi3ro Apr 15, 2024
c8de208
Update pennylane/capture/meta_type.py
albi3ro Apr 15, 2024
fe4ebcf
Update pennylane/operation.py
albi3ro Apr 15, 2024
27c78b8
adding some more documentation
albi3ro Apr 16, 2024
00c09e0
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro Apr 16, 2024
68020ca
Merge branch 'master' into plxpr-capture-operations
dwierichs Apr 16, 2024
9747e1e
Apply suggestions from code review
albi3ro Apr 16, 2024
21d12e0
pow support
albi3ro Apr 16, 2024
58367bc
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro Apr 16, 2024
c83a42a
minor fixes
albi3ro Apr 16, 2024
712a0cf
fix pauli rot
albi3ro Apr 16, 2024
f2fbe31
Update pennylane/capture/__init__.py
albi3ro Apr 17, 2024
8c2a4eb
responding to feedback
albi3ro Apr 17, 2024
121fbbe
move metaclass initialization to __init_subclass__
albi3ro Apr 18, 2024
cf8f2aa
responding to feedback, changelog
albi3ro Apr 19, 2024
136a5cd
improve testing for evaluating the jaxpr
albi3ro Apr 22, 2024
4ce8720
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 22, 2024
e35f602
Apply suggestions from code review
albi3ro Apr 23, 2024
d6bd42d
[skip ci] responding to feedback and polishing
albi3ro Apr 23, 2024
afd22f6
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 23, 2024
8b03ab0
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 24, 2024
a1bb1f9
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 25, 2024
c1c26c4
Update pennylane/capture/meta_type.py
albi3ro Apr 25, 2024
3b5cdc9
[skip ci] rename to CaptureMeta, create primitives file
albi3ro May 3, 2024
2b70e77
Merge branch 'master' into plxpr-capture-operations
albi3ro May 6, 2024
1c03b95
add source code clarification
albi3ro May 6, 2024
3792377
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro May 6, 2024
9bc466f
changelog
albi3ro May 6, 2024
8a1fead
Update tests/capture/test_operators.py
albi3ro May 6, 2024
43aba1d
Update pennylane/capture/__init__.py
albi3ro May 7, 2024
d2864a4
Merge branch 'master' into plxpr-capture-operations
albi3ro May 7, 2024
e269c9a
Update tests/capture/test_operators.py
albi3ro May 7, 2024
41a24ef
Update tests/capture/test_operators.py
albi3ro May 7, 2024
4f74c34
Apply suggestions from code review
albi3ro May 7, 2024
867d54f
responding to feedback
albi3ro May 7, 2024
8796db8
merge
albi3ro May 7, 2024
3220f10
final code review responses
albi3ro May 7, 2024
c0be87f
minor fixes
albi3ro May 8, 2024
ac7ea0a
Update pennylane/capture/primitives.py
albi3ro May 9, 2024
66a9d7d
Merge branch 'master' into plxpr-capture-operations
albi3ro May 9, 2024
63b5f4b
pylint
albi3ro May 9, 2024
d190afe
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro May 9, 2024
2fa7799
remove trailing whitespace
albi3ro May 9, 2024
ff9fe69
Update tests/capture/test_operators.py
albi3ro May 9, 2024
ab5f845
update changelog and some phrasing
albi3ro May 10, 2024
5108691
Merge branch 'master' into plxpr-capture-operations
albi3ro May 10, 2024
073b81d
Update pennylane/capture/__init__.py
albi3ro May 10, 2024
34133fe
Merge branch 'master' into plxpr-capture-operations
albi3ro May 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
* Sets up the framework for the development of an `assert_equal` function for testing operator comparison.
[(#5634)](https://github.com/PennyLaneAI/pennylane/pull/5634)

* PennyLane operators can now automatically be captured as instructions in JAXPR. See the experimental
`capture` module for more information.
[(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511)

* The `decompose` transform has an `error` kwarg to specify the type of error that should be raised,
allowing error types to be more consistent with the context the `decompose` function is used in.
[(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669)
Expand Down
73 changes: 73 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,78 @@
>>> qml.capture.enabled()
False

**Custom Operator Behaviour**

Any operator that inherits from :class:`~.Operator` gains a default ability to be captured
in a Jaxpr. Any positional argument is bound as a tracer, wires are processed out into individual tracers,
and any keyword arguments are passed as keyword metadata.

.. code-block:: python

class MyOp1(qml.operation.Operator):

def __init__(self, arg1, wires, key=None):
super().__init__(arg1, wires=wires)

def qfunc(a):
MyOp1(a, wires=(0,1), key="a")

qml.capture.enable()
print(jax.make_jaxpr(qfunc)(0.1))

.. code-block::

{ lambda ; a:f32[]. let
_:AbstractOperator() = MyOp1[key=a n_wires=2] a 0 1
in () }

But an operator developer may need to override custom behavior for calling ``cls._primitive.bind``
(where ``cls`` indicates the class) if:

* The operator does not accept wires, like :class:`~.SymbolicOp` or :class:`~.CompositeOp`.
* The operator needs to enforce a data/ metadata distinction, like :class:`~.PauliRot`.

In such cases, the operator developer can override ``cls._primitive_bind_call``, which
will be called when constructing a new class instance instead of ``type.__call__``. For example,

.. code-block:: python

class JustMetadataOp(qml.operation.Operator):

def __init__(self, metadata):
super().__init__(wires=[])
self._metadata = metadata

@classmethod
def _primitive_bind_call(cls, metadata):
return cls._primitive.bind(metadata=metadata)


def qfunc():
JustMetadataOp("Y")

qml.capture.enable()
print(jax.make_jaxpr(qfunc)())

.. code-block::

{ lambda ; . let _:AbstractOperator() = JustMetadataOp[metadata=Y] in () }

albi3ro marked this conversation as resolved.
Show resolved Hide resolved
As you can see, the input ``"Y"``, while being passed as a positional argument, is converted to
metadata within the custom ``_primitive_bind_call`` method.

If needed, developers can also override the implementation method of the primitive like was done with ``Controlled``.
``Controlled`` needs to do so to handle packing and unpacking the control wires.

.. code-block:: python

class MyCustomOp(qml.operation.Operator):
pass

@MyCustomOp._primitive.def_impl
def _(*args, **kwargs):
return type.__call__(MyCustomOp, *args, **kwargs)
"""
from .switches import disable, enable, enabled
from .capture_meta import CaptureMeta
from .primitives import create_operator_primitive
46 changes: 46 additions & 0 deletions pennylane/capture/capture_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Defines a metaclass for automatic integration of any ``Operator`` with plxpr program capture.

See ``explanations.md`` for technical explanations of how this works.
"""

from .switches import enabled


# pylint: disable=no-self-argument, too-few-public-methods
class CaptureMeta(type):
"""A metatype that dispatches class creation to ``cls._primitve_bind_call`` instead
of normal class creation.

See ``pennylane/capture/explanations.md`` for more detailed information on how this technically
works.
"""

def _primitive_bind_call(cls, *args, **kwargs):
raise NotImplementedError(
"Types using CaptureMeta must implement cls._primitive_bind_call to"
" gain integration with plxpr program capture."
)

def __call__(cls, *args, **kwargs):
# this method is called everytime we want to create an instance of the class.
# default behavior uses __new__ then __init__

if enabled():
# when tracing is enabled, we want to
# use bind to construct the class if we want class construction to add it to the jaxpr
return cls._primitive_bind_call(*args, **kwargs)
return type.__call__(cls, *args, **kwargs)
237 changes: 237 additions & 0 deletions pennylane/capture/explanations.md
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
This documentation explains the principles behind `qml.capture.CaptureMeta`.


```python
import jax
```

# Primitive basics


```python
my_func_prim = jax.core.Primitive("my_func")

@my_func_prim.def_impl
def _(x):
return x**2

@my_func_prim.def_abstract_eval
def _(x):
return jax.core.ShapedArray((1,), x.dtype)

def my_func(x):
return my_func_prim.bind(x)
```


```python
>>> jaxpr = jax.make_jaxpr(my_func)(0.1)
>>> jaxpr
{ lambda ; a:f32[]. let b:f32[1] = my_func a in (b,) }
>>> jaxpr.jaxpr.eqns
[a:f32[1] = my_func b]
```

## Metaprogramming


```python
class MyMetaClass(type):

def __init__(cls, *args, **kwargs):
print(f"Creating a new type {cls} with {args}, {kwargs}. ")

# giving every class a property
cls.a = "a"

def __call__(cls, *args, **kwargs):
print(f"creating an instance of type {cls} with {args}, {kwargs}. ")
inst = cls.__new__(cls, *args, **kwargs)
inst.__init__(*args, **kwargs)
return inst
```

Now let's define a class with this meta class.

You can see that when we *define* the class, we have called `MyMetaClass.__init__` to create the new type


```python
class MyClass(metaclass=MyMetaClass):

def __init__(self, *args, **kwargs):
print("now creating an instance in __init__")
self.args = args
self.kwargs = kwargs
```

Creating a new type <class '__main__.MyClass'> with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': <function MyClass.__init__ at 0x11c59cae0>}), {}.


And that we have set a class property `a`


```python
>>> MyClass.a
'a'
```

But can we actually create instances of these classes?


```python
>> obj = MyClass(0.1, a=2)
>>> obj
creating an instance of type <class '__main__.MyClass'> with (0.1,), {'a': 2}.
now creating an instance in __init__
<__main__.MyClass at 0x11c5a2810>
```


So far, we've just added print statements around default behavior. Let's try something more radical


```python
class MetaClass2(type):

def __call__(cls, *args, **kwargs):
return 2.0

class MyClass2(metaclass=MetaClass2):

def __init__(self, *args, **kwargs):
print("Am I here?")
self.args = args
```

You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`.

Using a metaclass, we can hijack what happens when a type is called.


```python
>>> out = MyClass2(1.0)
>>> out, out == 2.0
(2.0, True)
```

## Putting Primitives and Metaprogramming together

We have two goals that we need to accomplish with our meta class.

1. Create an associated primitive every time we define a new class type
2. Hijack creating a new instance to use `primitive.bind` instead


```python
class PrimitiveMeta(type):

def __init__(cls, *args, **kwargs):
# here we set up the primitive
primitive = jax.core.Primitive(cls.__name__)

@primitive.def_impl
def _(*inner_args, **inner_kwargs):
# just normal class creation if not tracing
return type.__call__(cls, *inner_args, **inner_kwargs)

@primitive.def_abstract_eval
def _(*inner_args, **inner_kwargs):
# here we say that we just return an array of type float32 and shape (1,)
# other abstract types could be used instead
return jax.core.ShapedArray((1,), jax.numpy.float32)

cls._primitive = primitive

def __call__(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)
```


```python
class PrimitiveClass(metaclass=PrimitiveMeta):

def __init__(self, a):
self.a = a

def __repr__(self):
return f"PrimitiveClass({self.a})"
```

What happens if we just create a class normally as is?


```python
>>> PrimitiveClass(1.0)
PrimitiveClass(1.0)
```

But now it can also be used in tracing as well


```python
>>> jax.make_jaxpr(PrimitiveClass)(1.0)
{ lambda ; a:f32[]. let b:f32[1] = PrimitiveClass a in (b,) }
```

Great!👍

Now you can see that the problem is that we lied in our definition of abstract evaluation. Jax thinks that `PrimitiveClass` returns something of shape `(1,)` and type `float32`.

But jax doesn't have an abstract type that really describes "PrimitiveClass". So we need to define an register our own.


```python
class AbstractPrimitiveClass(jax.core.AbstractValue):

def __eq__(self, other):
return isinstance(other, AbstractPrimitiveClass)

def __hash__(self):
return hash("AbstractPrimitiveClass")

jax.core.raise_to_shaped_mappings[AbstractPrimitiveClass] = lambda aval, _: aval
```

Now we can redefine our class to use this abstract class


```python
class PrimitiveMeta2(type):

def __init__(cls, *args, **kwargs):
# here we set up the primitive
primitive = jax.core.Primitive(cls.__name__)

@primitive.def_impl
def _(*inner_args, **inner_kwargs):
# just normal class creation if not tracing
return type.__call__(cls, *inner_args, **inner_kwargs)

@primitive.def_abstract_eval
def _(*inner_args, **inner_kwargs):
# here we say that we just return an array of type float32 and shape (1,)
# other abstract types could be used instead
return AbstractPrimitiveClass()

cls._primitive = primitive

def __call__(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

class PrimitiveClass2(metaclass=PrimitiveMeta2):

def __init__(self, a):
self.a = a

def __repr__(self):
return f"PrimitiveClass({self.a})"
```

Now in our jaxpr, we can see thet `PrimitiveClass2` returns something of type `AbstractPrimitiveClass`.


```python
>>> jax.make_jaxpr(PrimitiveClass2)(0.1)
{ lambda ; a:f32[]. let b:AbstractPrimitiveClass() = PrimitiveClass2 a in (b,) }
```
Loading
Loading