Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plxpr can capture operations #5511

Merged
merged 63 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
dd1f380
first pass
albi3ro Apr 12, 2024
83e4fe7
add module
dwierichs Apr 12, 2024
22fbadc
changelog
dwierichs Apr 12, 2024
156a749
Merge branch 'master' into add-capture-module
dwierichs Apr 12, 2024
ad7b637
import, fix
dwierichs Apr 12, 2024
d9b8b8a
git
dwierichs Apr 12, 2024
1634916
tests
dwierichs Apr 12, 2024
f6ba19d
Merge branch 'master' into add-capture-module
dwierichs Apr 12, 2024
15559d6
move switches to switches.py
dwierichs Apr 12, 2024
c511fb5
lint
dwierichs Apr 12, 2024
add82e9
identify all operators as jax primitives
albi3ro Apr 12, 2024
242cad1
Merge branch 'add-capture-module' into plxpr-capture-operations
albi3ro Apr 12, 2024
4c6c7bb
add dunder math support
albi3ro Apr 12, 2024
1d927b8
allow overriding primmitive bind call
albi3ro Apr 15, 2024
67a2069
improving testing
albi3ro Apr 15, 2024
45ab87a
fix up to allow using abc still
albi3ro Apr 15, 2024
c8de208
Update pennylane/capture/meta_type.py
albi3ro Apr 15, 2024
fe4ebcf
Update pennylane/operation.py
albi3ro Apr 15, 2024
27c78b8
adding some more documentation
albi3ro Apr 16, 2024
00c09e0
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro Apr 16, 2024
68020ca
Merge branch 'master' into plxpr-capture-operations
dwierichs Apr 16, 2024
9747e1e
Apply suggestions from code review
albi3ro Apr 16, 2024
21d12e0
pow support
albi3ro Apr 16, 2024
58367bc
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro Apr 16, 2024
c83a42a
minor fixes
albi3ro Apr 16, 2024
712a0cf
fix pauli rot
albi3ro Apr 16, 2024
f2fbe31
Update pennylane/capture/__init__.py
albi3ro Apr 17, 2024
8c2a4eb
responding to feedback
albi3ro Apr 17, 2024
121fbbe
move metaclass initialization to __init_subclass__
albi3ro Apr 18, 2024
cf8f2aa
responding to feedback, changelog
albi3ro Apr 19, 2024
136a5cd
improve testing for evaluating the jaxpr
albi3ro Apr 22, 2024
4ce8720
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 22, 2024
e35f602
Apply suggestions from code review
albi3ro Apr 23, 2024
d6bd42d
[skip ci] responding to feedback and polishing
albi3ro Apr 23, 2024
afd22f6
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 23, 2024
8b03ab0
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 24, 2024
a1bb1f9
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 25, 2024
c1c26c4
Update pennylane/capture/meta_type.py
albi3ro Apr 25, 2024
3b5cdc9
[skip ci] rename to CaptureMeta, create primitives file
albi3ro May 3, 2024
2b70e77
Merge branch 'master' into plxpr-capture-operations
albi3ro May 6, 2024
1c03b95
add source code clarification
albi3ro May 6, 2024
3792377
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro May 6, 2024
9bc466f
changelog
albi3ro May 6, 2024
8a1fead
Update tests/capture/test_operators.py
albi3ro May 6, 2024
43aba1d
Update pennylane/capture/__init__.py
albi3ro May 7, 2024
d2864a4
Merge branch 'master' into plxpr-capture-operations
albi3ro May 7, 2024
e269c9a
Update tests/capture/test_operators.py
albi3ro May 7, 2024
41a24ef
Update tests/capture/test_operators.py
albi3ro May 7, 2024
4f74c34
Apply suggestions from code review
albi3ro May 7, 2024
867d54f
responding to feedback
albi3ro May 7, 2024
8796db8
merge
albi3ro May 7, 2024
3220f10
final code review responses
albi3ro May 7, 2024
c0be87f
minor fixes
albi3ro May 8, 2024
ac7ea0a
Update pennylane/capture/primitives.py
albi3ro May 9, 2024
66a9d7d
Merge branch 'master' into plxpr-capture-operations
albi3ro May 9, 2024
63b5f4b
pylint
albi3ro May 9, 2024
d190afe
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro May 9, 2024
2fa7799
remove trailing whitespace
albi3ro May 9, 2024
ff9fe69
Update tests/capture/test_operators.py
albi3ro May 9, 2024
ab5f845
update changelog and some phrasing
albi3ro May 10, 2024
5108691
Merge branch 'master' into plxpr-capture-operations
albi3ro May 10, 2024
073b81d
Update pennylane/capture/__init__.py
albi3ro May 10, 2024
34133fe
Merge branch 'master' into plxpr-capture-operations
albi3ro May 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/code/qml_capture.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
qml.capture
===========

.. currentmodule:: pennylane.capture

.. automodule:: pennylane.capture

2 changes: 1 addition & 1 deletion doc/code/qml_compiler.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
qml.compiler
===============
============

.. currentmodule:: pennylane.compiler

Expand Down
1 change: 1 addition & 0 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ PennyLane is **free** and **open source**, released under the Apache License, Ve
:caption: Internals
:hidden:

code/qml_capture
code/qml_devices
code/qml_measurements
code/qml_operation
Expand Down
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

<h3>New features since last release</h3>

* Added a qml.capture module that will contain PennyLane's own capturing mechanism for hybrid
quantum-classical programs.
[(#5509)](https://github.com/PennyLaneAI/pennylane/pull/5509)

* The `FABLE` template is added for efficient block encoding of matrices. Users can now call FABLE to efficiently construct circuits according to a user-set approximation level.
[(#5107)](https://github.com/PennyLaneAI/pennylane/pull/5107)

Expand Down
1 change: 1 addition & 0 deletions pennylane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pennylane.numpy
from pennylane.queuing import QueuingManager, apply

import pennylane.capture
import pennylane.kernels
import pennylane.math
import pennylane.operation
Expand Down
101 changes: 101 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
.. currentmodule:: pennylane

This module implements PennyLane's capturing mechanism for hybrid
quantum-classical programs.

.. warning::

This module is experimental and will change significantly in the future.


To activate and deactivate the new PennyLane program capturing mechanism, use
the switches ``qml.capture.enable_plxpr`` and ``qml.capture.disable_plxpr``.
Whether or not the capturing mechanism is currently being used can be
queried with ``qml.capture.plxpr_enabled``.
By default, the mechanism is disabled:

.. code-block:: pycon

>>> import pennylane as qml
>>> qml.capture.plxpr_enabled()
False
>>> qml.capture.enable_plxpr()
>>> qml.capture.plxpr_enabled()
True
>>> qml.capture.disable_plxpr()
>>> qml.capture.plxpr_enabled()
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
False

**Custom Operator Behavior**
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

Any operation that inherits from :class:`~.Operator` gains a default ability to be captured
by jaxpr. Any positional argument is bound as a tracer, wires are processed out into individual tracers,
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
and any keyword args are passed as keyword metadata.

.. code-block:: python

class MyOp1(qml.operation.Operator):

def __init__(self, arg1, wires, key=None):
super().__init__(arg1, wires=wires)

def qfunc(a):
MyOp1(a, wires=(0,1), key="a")

qml.capture.enable_plxpr()
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
print(jax.make_jaxpr(qfunc)(0.1))

.. code-block::

{ lambda ; a:f32[]. let
_:AbstractOperator() = MyOp1[key=a n_wires=2] a 0 1
in () }
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

But an operator developer may need to override custom behavior for calling ``cls._primitive.bind`` if:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

* The operator does not accept wires like :class:`~.SymbolicOp` or :class:`~.CompositeOp`.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
* The operator allows metadata to be provided positionally, like :class:`~.PauliRot`.

In such cases, the operator developer can override ``cls._primitive_bind_call``. This is what
will be called when constructing a new class instance instead of ``type.__call__``. For example,
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python

class WeirdOp(qml.operation.Operator):
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, metadata="X"):
dime10 marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(wires=[])
self._metadata = metadata

@classmethod
def _primitive_bind_call(cls, metadata):
return cls._primitive.bind(metadata=metadata)


def qfunc():
WeirdOp("Y")

qml.capture.enable_plxpr()
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
print(jax.make_jaxpr(qfunc)())

.. code-block::

{ lambda ; . let _:AbstractOperator() = WeirdOp[metadata=Y] in () }

albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""
from .switches import enable_plxpr, disable_plxpr, plxpr_enabled
from .meta_type import PLXPRMeta
238 changes: 238 additions & 0 deletions pennylane/capture/explanations.md
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
This documentation explains the principles behind `qml.capture.PLXPRMeta`.


```python
import jax
```

# Primitive basics


```python
my_func_prim = jax.core.Primitive("my_func")

@my_func_prim.def_impl
def _(x):
return x**2

@my_func_prim.def_abstract_eval
def _(x):
return jax.core.ShapedArray((1,), x.dtype)

def my_func(x):
return my_func_prim.bind(x)
```


```python
>>> jaxpr = jax.make_jaxpr(my_func)(0.1)
>>> jaxpr
{ lambda ; a:f32[]. let b:f32[1] = my_func a in (b,) }
>>> jaxpr.jaxpr.eqns
[a:f32[1] = my_func b]
```

## Metaprogramming


```python
class MyMetaClass(type):

def __init__(cls, *args, **kwargs):
print(f"Creating a new type {cls} with {args}, {kwargs}. ")

# giving every class a property
cls.a = "a"

def __call__(cls, *args, **kwargs):
print(f"creating an instance of type {cls} with {args}, {kwargs}. ")
inst = cls.__new__(cls, *args, **kwargs)
inst.__init__(*args, **kwargs)
return inst
```

Now let's define a class with this meta class.

You can see that when we *define* the class, we have called `MyMetaClass.__init__` to create the new type


```python
class MyClass(metaclass=MyMetaClass):

def __init__(self, *args, **kwargs):
print("now creating an instance in __init__")
self.args = args
self.kwargs = kwargs
```

Creating a new type <class '__main__.MyClass'> with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': <function MyClass.__init__ at 0x11c59cae0>}), {}.


And that we have set a class property `a`


```python
>>> MyClass.a
'a'
```

But can we actually create instances of these classes?


```python
>> obj = MyClass(0.1, a=2)
>>> obj
creating an instance of type <class '__main__.MyClass'> with (0.1,), {'a': 2}.
now creating an instance in __init__
<__main__.MyClass at 0x11c5a2810>
```


So far, we've just added print statements around default behavior. Let's try something more radical


```python
class MetaClass2(type):

def __call__(cls, *args, **kwargs):
return 2.0

class WeirdClass(metaclass=MetaClass2):
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, *args, **kwargs):
print("Am I here?")
self.args = args
```

You can see now that instead of actually getting an instance of `WeirdClass`, we just get `2.0`.

Using a metaclass, we can hijack what happens when a type is called.


```python
>>> out = WeirdClass(1.0)
>>> out, out == 2.0
(2.0, True)
```

## Putting Primitives and Metaprogramming together

We have two goals that we need to accomplish with our meta class.

1. Create an associated primitive every time we define a new class type
2. Hijack creating a new instance to use `primitive.bind` instead


```python
class PrimitiveMeta(type):

def __init__(cls, *args, **kwargs):
# here we set up the primitive
primitive = jax.core.Primitive(cls.__name__)

@primitive.def_impl
def _(*inner_args, **inner_kwargs):
# just normal class creation if not tracing
return type.__call__(cls, *inner_args, **inner_kwargs)

@primitive.def_abstract_eval
def _(*inner_args, **inner_kwargs):
# here we say that we just return an array of type float32 and shape (1,)
# other abstract types could be used instead
return jax.core.ShapedArray((1,), jax.numpy.float32)

cls._primitive = primitive

def __call__(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)
```


```python
class PrimitiveClass(metaclass=PrimitiveMeta):

def __init__(self, a):
self.a = a

def __repr__(self):
return f"PrimitiveClass({self.a})"
```

What happens if we just create a class normally as is?


```python
>>> PrimitiveClass(1.0)
PrimitiveClass(1.0)
```

But now it can also be used in tracing as well


```python
>>> jax.make_jaxpr(PrimitiveClass)(1.0)
{ lambda ; a:f32[]. let b:f32[1] = PrimitiveClass a in (b,) }
```

Great!👍

Now you can see that the problem is that we lied in our definition of abstract evaluation. Jax thinks that `PrimitiveClass` returns something of shape `(1,)` and type `float32`.

But jax doesn't have an abstract type that really describes "PrimitiveClass". So we need to define an register our own.


```python
class AbstractPrimitiveClass(jax.core.AbstractValue):

def __eq__(self, other):
return isinstance(other, AbstractPrimitiveClass)

def __hash__(self):
return hash("AbstractPrimitiveClass")

# not quite sure what it does, but we have to do this...
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
jax.core.raise_to_shaped_mappings[AbstractPrimitiveClass] = lambda aval, _: aval
```

Now we can redefine our class to use this abstract class


```python
class PrimitiveMeta2(type):

def __init__(cls, *args, **kwargs):
# here we set up the primitive
primitive = jax.core.Primitive(cls.__name__)

@primitive.def_impl
def _(*inner_args, **inner_kwargs):
# just normal class creation if not tracing
return type.__call__(cls, *inner_args, **inner_kwargs)

@primitive.def_abstract_eval
def _(*inner_args, **inner_kwargs):
# here we say that we just return an array of type float32 and shape (1,)
# other abstract types could be used instead
return AbstractPrimitiveClass()

cls._primitive = primitive

def __call__(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

class PrimitiveClass2(metaclass=PrimitiveMeta2):

def __init__(self, a):
self.a = a

def __repr__(self):
return f"PrimitiveClass({self.a})"
```

Now in our jaxpr, we can see thet `PrimitiveClass2` returns something of type `AbstractPrimitiveClass`.


```python
>>> jax.make_jaxpr(PrimitiveClass2)(0.1)
{ lambda ; a:f32[]. let b:AbstractPrimitiveClass() = PrimitiveClass2 a in (b,) }
```
Loading
Loading