Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add from_plxpr conversion function #837

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,10 @@ Paul Haochen Wang,
(array([0.5, 0. , 0.5, 0. ]),)
```

* Adds `catalyst.from_plxpr.from_plxpr` for converting a PennyLane variant jaxpr into a
Catalyst variant jaxpr.
[(#663)](https://github.com/PennyLaneAI/catalyst/pull/663)

<h3>Breaking changes</h3>

* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting
Expand Down
371 changes: 371 additions & 0 deletions frontend/catalyst/from_plxpr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,371 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule defines a utility for converting plxpr into Catalyst jaxpr.
"""
from dataclasses import dataclass, field
from functools import partial
from typing import Callable

import jax
from jax.extend.linear_util import wrap_init
from pennylane.capture import AbstractMeasurement, AbstractOperator, qnode_prim

from catalyst.device import extract_backend_info, get_device_capabilities
from catalyst.jax_primitives import (
AbstractQreg,
AbstractQbit,
compbasis_p,
expval_p,
func_p,
namedobs_p,
probs_p,
qalloc_p,
qdealloc_p,
qdevice_p,
qextract_p,
qinsert_p,
qinst_p,
sample_p,
state_p,
var_p,
)
from catalyst.utils.toml import ProgramFeatures

measurement_map = {
"sample_wires": sample_p,
"expval_obs": expval_p,
"var_obs": var_p,
"probs_wires": probs_p,
"state_wires": state_p,
}


def _get_shapes_for(*measurements, shots=None, num_device_wires=0):
if jax.config.jax_enable_x64:
dtype_map = {
float: jax.numpy.float64,
int: jax.numpy.int64,
complex: jax.numpy.complex128,
}
else:
dtype_map = {

Check warning on line 63 in frontend/catalyst/from_plxpr.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/from_plxpr.py#L63

Added line #L63 was not covered by tests
float: jax.numpy.float32,
int: jax.numpy.int32,
complex: jax.numpy.complex64,
}

shapes = []
if not shots:
shots = [None]

for s in shots:
for m in measurements:
shape, dtype = m.abstract_eval(shots=s, num_device_wires=num_device_wires)
shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This map is interesting, I wonder if it is actually needed? Abstract JAX arrays can already hold Python types like int, float, complex which are considered "weak types" and will adapt to the bitwidth of "strong types" like int64, or fall back to the same types defined in the map above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like tests still pass when getting rid of it, so I'm going to say we don't need it until proven otherwise.

return shapes


# pylint: disable=unidiomatic-typecheck
def _read(var, env: dict):
return var.val if type(var) is jax.core.Literal else env[var]


def _get_device_kwargs(device: "pennylane.devices.Device") -> dict:
"""Calulcate the params for a device equation."""
features = ProgramFeatures(device.shots is not None)
capabilities = get_device_capabilities(device, features)
info = extract_backend_info(device, capabilities)
# Note that the value of rtd_kwargs is a string version of
# the info kwargs, not the info kwargs itself
Comment on lines +90 to +91

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nice for future development to mention why we do this :)

Copy link
Contributor Author

@albi3ro albi3ro Jun 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah... I'm not sure why. Think it has to do with mlir.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay! @dime10 , do you have context here? I think it might be nice to comment on it :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! This is just the easiest way to serialize the data since it is going through the IR. The alternative would be to provide typed attributes for each config option which is more work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the info. Adding that as a clarifying source code comment.

return {
"rtd_kwargs": str(info.kwargs),
"rtd_lib": info.lpath,
"rtd_name": info.c_interface_name,
}


# code example has long lines
# pylint: disable=line-too-long
def from_plxpr(plxpr: jax.core.Jaxpr) -> Callable[..., jax.core.Jaxpr]:
"""Convert PennyLane variant jaxpr to Catalyst variant jaxpr.

Args:
jaxpr (jax.core.Jaxpr): PennyLane variant jaxpr

Returns:
Callable: A function that accepts the same arguments as the plxpr and returns catalyst
variant jaxpr.

Note that the input jaxpr should be workflow level and contain qnode primitives, rather than
qfunc level with individual operators.

.. code-block:: python

from catalyst.from_plxpr import from_plxpr

qml.capture.enable()

@qml.qnode(qml.device('lightning.qubit', wires=2))
def circuit(x):
qml.RX(x, 0)
return qml.probs(wires=(0, 1))

def f(x):
return circuit(2 * x) ** 2

plxpr = jax.make_jaxpr(circuit)(0.5)

print(from_plxpr(plxpr)(0.5))

.. code-block:: none

{ lambda ; a:f64[]. let
b:f64[4] = func[
call_jaxpr={ lambda ; c:f64[]. let
qdevice[
rtd_kwargs={'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}
rtd_lib=***
rtd_name=LightningSimulator
]
d:AbstractQreg() = qalloc 2
e:AbstractQbit() = qextract d 0
f:AbstractQbit() = qinst[
adjoint=False
ctrl_len=0
op=RX
params_len=1
qubits_len=1
] e c
g:AbstractQbit() = qextract d 1
h:AbstractObs(num_qubits=2,primitive=compbasis) = compbasis f g
i:f64[4] = probs[shape=(4,) shots=None] h
j:AbstractQreg() = qinsert d 0 f
qdealloc j
in (i,) }
fn=<QNode: device='<lightning.qubit device (wires=2) at 0x302761c90>', interface='auto', diff_method='best'>
] a
in (b,) }

"""
return jax.make_jaxpr(partial(from_plxpr_interpreter, plxpr.jaxpr, plxpr.consts))


# docstring link too long
# pylint: disable=line-too-long
def from_plxpr_interpreter(jaxpr: jax.core.Jaxpr, consts, *args) -> list:
"""Convert PennyLane variant jaxpr to Catalyst variant jaxpr.

See the documentation on
`Writing custom interpreters in JAX <https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html>`_
for a walkthrough on the general architecture and behavior of this function.

Given that ``catalyst.jax_primitives.func_p`` does not define a concrete implementation, this
function will fail outside of an abstract evaluation call.

"""
env = {} # dict mapping var "variables" to val "values"

# Bind args and consts to environment
for arg, invar in zip(args, jaxpr.invars):
env[invar] = arg
for const, constvar in zip(consts, jaxpr.constvars):
env[constvar] = const

Check warning on line 184 in frontend/catalyst/from_plxpr.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/from_plxpr.py#L184

Added line #L184 was not covered by tests

# Loop through equations and evaluate primitives using `bind`
for eqn in jaxpr.eqns:
# Read inputs to equation from environment
invals = [_read(invar, env) for invar in eqn.invars]
if eqn.primitive == qnode_prim:
if eqn.params["device"].shots != eqn.params["shots"]:
raise NotImplementedError("catalyst does not yet support dynamic shots")

f = partial(
_bind_catalxpr,
eqn.params["qfunc_jaxpr"].jaxpr,
eqn.params["qfunc_jaxpr"].consts,
eqn.params["device"],
)
# func_p is a CallPrimitive, so interpreter passed as first arg
# wrap_init turns the function into a WrappedFun, which can store
# transformations
outvals = func_p.bind(wrap_init(f), *invals, fn=eqn.params["qnode"])
else:
outvals = eqn.primitive.bind(*invals, **eqn.params)
# Primitives may return multiple outputs or not
if not eqn.primitive.multiple_results:
outvals = [outvals]
# Write the results of the primitive into the environment
for outvar, outval in zip(eqn.outvars, outvals):
env[outvar] = outval
return [env[outvar] for outvar in jaxpr.outvars]


@dataclass
class _InterpreterState:
"""This dataclass stores the mutable variables modified
over the course of interpreting the plxpr as catalxpr."""

qreg: AbstractQreg
"""The current quantum register."""

env: dict = field(default_factory=dict)
"""A dictionary mapping variables to values."""

wire_map: dict = field(default_factory=dict)
"""A map from wire values to ``AbstractQbit`` instances.

If a value is not present in this dictionary, it needs to be extracted
from the ``qreg`` property.
"""

op_math_cache: dict = field(default_factory=dict)
"""A cache of operations that will be consumed by later operations.
This is a map from the ``AbstractOperator`` variables to the corresponding
equation. The equation will need to be interpreted when the abstract
operator is consumed.
"""

def read(self, var):
Copy link
Collaborator

@dime10 dime10 Jul 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is really nicely structured! The only thing I'm wondering is whether we can't use the interpreter state for the whole conversion procedure, since it seems to duplicate environment and free standing read function above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking of defining a base class for jaxpr interpreters and reusing the structure for a variety of different algorithms using the template design pattern.

You can see a prototype here:

https://github.com/PennyLaneAI/pennylane/blob/plxpr-interpreter/pennylane/capture/interpreters.py

The design and implementation is going to take a little bit of work, but in the end, we will be able to use that design to make this code much nicer :) I do see this code getting restructured once that happens.

Right now we free standing functions with a mutable input, but we could also make them all class methods. I'm fine with promoting everything to class methods now too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewriting te code to follow the structure I'm thinking of for a "Plxpr interpreter template".

"""Extract the value corresponding to a variable."""
return var.val if type(var) is jax.core.Literal else self.env[var]

def get_wire(self, wire_value) -> AbstractQbit:
"""Get the ``AbstractQbit`` corresponding to a wire value."""
if wire_value in self.wire_map:
return self.wire_map[wire_value]
return qextract_p.bind(self.qreg, wire_value)


def _deallocate(state: _InterpreterState) -> None:
"""Reinsert all active wires into the quantum register and deallocate the register."""
for orig_wire, wire in state.wire_map.items():
state.qreg = qinsert_p.bind(state.qreg, orig_wire, wire)
qdealloc_p.bind(state.qreg)


def _operator_eqn(eqn: jax.core.JaxprEqn, state: _InterpreterState) -> None:
"""Interpret a plxpr equation describing an operation as a catalxpr equation."""
if not isinstance(eqn.outvars[0], jax.core.DropVar):
state.op_math_cache[eqn.outvars[0]] = eqn
return

if "n_wires" not in eqn.params:
raise NotImplementedError(
f"Operator {eqn.primitive.name} not yet supported for catalyst conversion."
)
n_wires = eqn.params["n_wires"]

wire_values = [state.read(w) for w in eqn.invars[-n_wires:]]
wires = [state.get_wire(w) for w in wire_values]

invals = [state.read(invar) for invar in eqn.invars[:-n_wires]]
outvals = qinst_p.bind(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have two custom primitives because they deviate from the form qinst_p has, whose conversion isn't included here:

  • qunitary_p
  • gphase_p

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding.

*wires,
*invals,
op=eqn.primitive.name,
qubits_len=eqn.params["n_wires"],
params_len=len(eqn.invars) - eqn.params["n_wires"],
ctrl_len=0,
adjoint=False,
)

for wire_values, new_wire in zip(wire_values, outvals):
state.wire_map[wire_values] = new_wire


def _obs(eqn: jax.core.JaxprEqn, state: _InterpreterState):
"""Interpret the observable equation corresponding to a measurement equation's input."""
obs_eqn = state.op_math_cache[eqn.invars[0]]
if "n_wires" not in obs_eqn.params:
raise NotImplementedError(
f"from_plxpr can not yet interpret observables of type {obs_eqn.primitive}"
)

n_wires = obs_eqn.params["n_wires"]
wires = [state.get_wire(state.read(w)) for w in obs_eqn.invars[-n_wires:]]
invals = [state.read(invar) for invar in obs_eqn.invars[:-n_wires]]
return namedobs_p.bind(*wires, *invals, kind=obs_eqn.primitive.name)


def _compbasis_obs(eqn: jax.core.JaxprEqn, state: _InterpreterState, device: "qml.devices.Device"):
"""Add a computational basis sampling observable."""
if eqn.invars:
w_vals = [state.read(w_var) for w_var in eqn.invars]
else:
w_vals = device.wires # broadcast across all wires
wires = [state.get_wire(w) for w in w_vals]
return compbasis_p.bind(*wires)


def _measurement_eqn(eqn: jax.core.JaxprEqn, state: _InterpreterState, device):
if eqn.primitive.name not in measurement_map:
raise NotImplementedError(
f"measurement {eqn.primitive.name} not yet supported for conversion."
)
if eqn.params.get("has_eigvals", False):
raise NotImplementedError("from_plxpr does not yet support measurements with eigenvalues.")

if "_wires" in eqn.primitive.name:
obs = _compbasis_obs(eqn, state, device)
else:
obs = _obs(eqn, state)
# mcm based measurements wont be in measurement map yet
# so we can assume observable based

shaped_array = _get_shapes_for(
eqn.outvars[0].aval, shots=device.shots, num_device_wires=len(device.wires)
)[0]

primitive = measurement_map[eqn.primitive.name]
mval = primitive.bind(obs, shape=shaped_array.shape, shots=device.shots.total_shots)

# sample_p returns floats, so we need to converted it back to the expected integers here
if shaped_array.dtype != mval.dtype:
return jax.lax.convert_element_type(mval, shaped_array.dtype)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return mval


def _bind_catalxpr(jaxpr: jax.core.Jaxpr, consts, device, *args) -> list:
"""Interpret plxpr as jaxpr."""

qdevice_p.bind(**_get_device_kwargs(device))
qreg = qalloc_p.bind(len(device.wires))
state = _InterpreterState(qreg=qreg)

for arg, invar in zip(args, jaxpr.invars):
state.env[invar] = arg
for const, constvar in zip(consts, jaxpr.constvars):
state.env[constvar] = const

measurements = []
for eqn in jaxpr.eqns:
if isinstance(eqn.outvars[0].aval, AbstractOperator):
_operator_eqn(eqn, state)

elif isinstance(eqn.outvars[0].aval, AbstractMeasurement):
mval = _measurement_eqn(eqn, state, device)
state.env[eqn.outvars[0]] = mval
measurements.append(eqn.outvars[0])
else:
invals = [state.read(invar) for invar in eqn.invars]
outvals = eqn.primitive.bind(*invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
for outvar, outval in zip(eqn.outvars, outvals):
state.env[outvar] = outval

_deallocate(state)
# Read the final result of the Jaxpr from the environment
return [state.read(outvar) for outvar in measurements]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing this is expected for now, but just noting that here we can only ever return MP results, not values from classical ops.

Loading
Loading