From 298449f452bd6701e8a0ecf923104c9438079d71 Mon Sep 17 00:00:00 2001 From: David Ittah Date: Fri, 21 Jun 2024 16:26:03 -0400 Subject: [PATCH] Support aliased controlled gates via their base operation (#792) Small improvements to the gate decomposition strategy with the new device API. Controlled operations defined via specialized classes (like `Toffoli` or `ControlledQubitUnitary`) are now implemented as controlled versions of their base operation if the device supports it. _In particular, `MultiControlledX` is no longer executed as a `QubitUnitary` with Lightning._ [sc-65228] --- .pylintrc | 2 +- doc/changelog.md | 116 ++++++++++-------- doc/dev/custom_devices.rst | 6 +- doc/dev/quick_start.rst | 9 +- frontend/catalyst/__init__.py | 1 - frontend/catalyst/compiled_functions.py | 2 +- frontend/catalyst/device/decomposition.py | 98 ++++++++++++++- frontend/catalyst/device/qjit_device.py | 7 +- frontend/catalyst/jax_primitives.py | 2 +- frontend/catalyst/jax_tracer.py | 4 +- frontend/catalyst/jit.py | 15 --- frontend/catalyst/programs/verification.py | 5 +- .../catalyst/third_party/cuda/__init__.py | 1 - frontend/catalyst/utils/toml.py | 30 ++--- frontend/test/lit/test_decomposition.py | 87 +------------ .../pytest/{ => device}/test_decomposition.py | 83 +++++++++---- frontend/test/pytest/test_custom_devices.py | 1 + frontend/test/pytest/test_operations.py | 20 +++ frontend/test/pytest/test_quantum_control.py | 3 - runtime/lib/backend/dummy/dummy_device.toml | 11 +- 20 files changed, 286 insertions(+), 217 deletions(-) rename frontend/test/pytest/{ => device}/test_decomposition.py (61%) diff --git a/.pylintrc b/.pylintrc index dfeaad00b2..fb5c51f915 100644 --- a/.pylintrc +++ b/.pylintrc @@ -37,7 +37,7 @@ enable=useless-suppression # it should appear only once). # Cyclical import checks are disabled for now as they are frequently used in # the code base, but this can be removed in the future once cycles are resolved. -disable=too-few-public-methods,invalid-name,too-many-locals,cyclic-import,import-error,no-else-return,unnecessary-ellipsis,duplicate-code +disable=too-few-public-methods,invalid-name,too-many-locals,cyclic-import,import-error,no-else-return,unnecessary-ellipsis,duplicate-code,abstract-method,no-name-in-module [MISCELLANEOUS] diff --git a/doc/changelog.md b/doc/changelog.md index 09b29a942c..2116457136 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -5,7 +5,8 @@ * `qjit` adheres to user-specified `mcm_method` given to the `QNode`. [(#798)](https://github.com/PennyLaneAI/catalyst/pull/798) -* The `dynamic_one_shot` transform uses a single auxiliary tape which is repeatedly simulated `n_shots` times to simulate hardware-like results. +* The `dynamic_one_shot` transform uses a single auxiliary tape which is repeatedly simulated + `n_shots` times to simulate hardware-like results. The loop over shots is executed with `catalyst.vmap`. [(#5617)](https://github.com/PennyLaneAI/pennylane/pull/5617) @@ -91,7 +92,7 @@ ``` -* Support for using `catalyst.value_and_grad` with a `qjit`-ted function. +* Support for using `catalyst.value_and_grad` with a `qjit`-ted function. [(#804)](https://github.com/PennyLaneAI/catalyst/pull/804) ```py @@ -237,11 +238,16 @@ [(#784)](https://github.com/PennyLaneAI/catalyst/pull/784) * Catalyst's adjoint and ctrl methods are now fully compatible with the PennyLane equivalent when - applied to a single Operator. This should lead to improved compatibility with PennyLane library code, - as well when reusing quantum functions with both Catalyst and PennyLane. + applied to a single Operator. This should lead to improved compatibility with PennyLane library + code, as well when reusing quantum functions with both Catalyst and PennyLane. [(#768)](https://github.com/PennyLaneAI/catalyst/pull/768) [(#771)](https://github.com/PennyLaneAI/catalyst/pull/771) +* Controlled operations defined via specialized classes (like `Toffoli` or `ControlledQubitUnitary`) + are now implemented as controlled versions of their base operation if the device supports it. + In particular, `MultiControlledX` is no longer executed as a `QubitUnitary` with Lightning. + [(#792)](https://github.com/PennyLaneAI/catalyst/pull/792) + * Catalyst now has support for `qml.sample(m)` where `m` is the result of a mid-circuit measurement. For now the feature is equivalent to returning `m` directly from a quantum function, but will be improved to return an array with one measurement result for each @@ -280,7 +286,8 @@ both be provided as keyword arguments. [(#790)](https://github.com/PennyLaneAI/catalyst/pull/790) -* Finite difference is now always possible regardless of whether the differentiated function has a valid gradient for autodiff or not. +* Finite difference is now always possible regardless of whether the differentiated function has a + valid gradient for autodiff or not. [(#789)](https://github.com/PennyLaneAI/catalyst/pull/789) * A new GitHub workflow makes available a binary distribution for Linux Arm64. @@ -295,7 +302,9 @@

Bug fixes

-* `device_shots` is modified to `0` on the fly in `Measure` (and set back to its original value after the call to `PartialProbs`) to compute mid-circuit probabilities analytically, even when the device has finite shots. +* `device_shots` is modified to `0` on the fly in `Measure` (and set back to its original value + after the call to `PartialProbs`) to compute mid-circuit probabilities analytically, even when the + device has finite shots. [(#801)](https://github.com/PennyLaneAI/catalyst/pull/801) * The Catalyst runtime now raises an error if an qubit is accessed out of bounds from the allocated @@ -306,7 +315,9 @@ [(#733)](https://github.com/PennyLaneAI/catalyst/pull/733) * Correctly linking openblas routines necessary for `jax.scipy.linalg.expm`. - In this bug fix, four openblas routines were newly linked and are now discoverable by `stablehlo.custom_call@`. They are `blas_dtrsm`, `blas_ztrsm`, `lapack_dgetrf`, `lapack_zgetrf`. + In this bug fix, four openblas routines were newly linked and are now discoverable by + `stablehlo.custom_call@`. They are `blas_dtrsm`, `blas_ztrsm`, `lapack_dgetrf`, + `lapack_zgetrf`. [(#752)](https://github.com/PennyLaneAI/catalyst/pull/752) * Correctly recording types of constant array when lowering `catalyst.grad` to mlir @@ -320,7 +331,8 @@

Internal changes

-* Catalyst uses the `collapse` method of Lightning simulators in `Measure` to select a state vector branch and normalize. +* Catalyst uses the `collapse` method of Lightning simulators in `Measure` to select a state vector + branch and normalize. [(#801)](https://github.com/PennyLaneAI/catalyst/pull/801) * The `QCtrl` class in Catalyst has been renamed to `HybridCtrl`, indicating its capability @@ -413,48 +425,56 @@ interface and allows for multiple `MemrefCallable` to be defined for a single callback, which is necessary for custom gradient of `pure_callbacks`. -* A new `catalyst::gradient::GradientOpInterface` is available when querying the gradient method in the mlir c++ api. +* A new `catalyst::gradient::GradientOpInterface` is available when querying the gradient method in + the mlir c++ api. [(#800)](https://github.com/PennyLaneAI/catalyst/pull/800) - `catalyst::gradient::GradOp`, `ValueAndGradOp`, `JVPOp`, and `VJPOp` now inherits traits in this new `GradientOpInterface`. The supported attributes are now `getMethod()`, `getCallee()`, `getDiffArgIndices()`, `getDiffArgIndicesAttr()`, `getFiniteDiffParam()`, and `getFiniteDiffParamAttr()`. - - - There are operations that could potentially be used as `GradOp`, `ValueAndGradOp`, `JVPOp` or `VJPOp`. When trying to get the gradient method, instead of doing - ```C++ - auto gradOp = dyn_cast(op); - auto jvpOp = dyn_cast(op); - auto vjpOp = dyn_cast(op); - - llvm::StringRef MethodName; - if (gradOp) - MethodName = gradOp.getMethod(); - else if (jvpOp) - MethodName = jvpOp.getMethod(); - else if (vjpOp) - MethodName = vjpOp.getMethod(); - ``` - to identify which op it actually is and protect against segfaults (calling `nullptr.getMethod()`), in the new interface we just do - ```C++ - auto gradOpInterface = cast(op); - llvm::StringRef MethodName = gradOpInterface.getMethod(); - ``` - - - Another advantage is that any concrete gradient operation object can behave like a `GradientOpInterface` : - ```C++ - GradOp op; // or ValueAndGradOp op, ... - auto foo = [](GradientOpInterface op){ - llvm::errs() << op.getCallee(); - }; - foo(op); // this works! - ``` - - - Finally, concrete op specific methods can still be called by "reinterpret"-casting the interface back to a concrete op (provided the concrete op type is correct): - ```C++ - auto foo = [](GradientOpInterface op){ - size_t numGradients = cast(&op)->getGradients().size(); - }; - ValueAndGradOp op; - foo(op); // this works! - ``` + `catalyst::gradient::GradOp`, `ValueAndGradOp`, `JVPOp`, and `VJPOp` now inherits traits in this + new `GradientOpInterface`. The supported attributes are now `getMethod()`, `getCallee()`, + `getDiffArgIndices()`, `getDiffArgIndicesAttr()`, `getFiniteDiffParam()`, and + `getFiniteDiffParamAttr()`. + + - There are operations that could potentially be used as `GradOp`, `ValueAndGradOp`, `JVPOp` or + `VJPOp`. When trying to get the gradient method, instead of doing + ```C++ + auto gradOp = dyn_cast(op); + auto jvpOp = dyn_cast(op); + auto vjpOp = dyn_cast(op); + + llvm::StringRef MethodName; + if (gradOp) + MethodName = gradOp.getMethod(); + else if (jvpOp) + MethodName = jvpOp.getMethod(); + else if (vjpOp) + MethodName = vjpOp.getMethod(); + ``` + to identify which op it actually is and protect against segfaults (calling + `nullptr.getMethod()`), in the new interface we just do + ```C++ + auto gradOpInterface = cast(op); + llvm::StringRef MethodName = gradOpInterface.getMethod(); + ``` + + - Another advantage is that any concrete gradient operation object can behave like a + `GradientOpInterface`: + ```C++ + GradOp op; // or ValueAndGradOp op, ... + auto foo = [](GradientOpInterface op){ + llvm::errs() << op.getCallee(); + }; + foo(op); // this works! + ``` + + - Finally, concrete op specific methods can still be called by "reinterpret"-casting the interface + back to a concrete op (provided the concrete op type is correct): + ```C++ + auto foo = [](GradientOpInterface op){ + size_t numGradients = cast(&op)->getGradients().size(); + }; + ValueAndGradOp op; + foo(op); // this works! + ```

Contributors

diff --git a/doc/dev/custom_devices.rst b/doc/dev/custom_devices.rst index 0ae1d51c18..6857c666c0 100644 --- a/doc/dev/custom_devices.rst +++ b/doc/dev/custom_devices.rst @@ -262,7 +262,7 @@ headers and fields are generally required, unless stated otherwise. CY = { properties = [ "invertible" ] } CZ = { properties = [ "invertible" ] } PhaseShift = { properties = [ "controllable", "invertible" ] } - ControlledPhaseShift = { properties = [ "controllable", "invertible" ] } + ControlledPhaseShift = { properties = [ "invertible" ] } RX = { properties = [ "controllable", "invertible" ] } RY = { properties = [ "controllable", "invertible" ] } RZ = { properties = [ "controllable", "invertible" ] } @@ -294,7 +294,7 @@ headers and fields are generally required, unless stated otherwise. QubitStateVector = {} StatePrep = {} ControlledQubitUnitary = {} - DiagonalQubitUnitary = {} + MultiControlledX = {} SingleExcitation = {} SingleExcitationPlus = {} SingleExcitationMinus = {} @@ -310,7 +310,7 @@ headers and fields are generally required, unless stated otherwise. # Gates which should be translated to QubitUnitary [operators.gates.matrix] - MultiControlledX = {} + DiagonalQubitUnitary = {} # Observables supported by the device [operators.observables] diff --git a/doc/dev/quick_start.rst b/doc/dev/quick_start.rst index 39e6faf973..3c63a06d71 100644 --- a/doc/dev/quick_start.rst +++ b/doc/dev/quick_start.rst @@ -106,12 +106,13 @@ more complex quantum circuits; see below for the list of currently supported ope .. important:: - Most decomposition logic will be equivalent to PennyLane's decomposition. - However, decomposition logic will differ in the following cases: + Decomposition will generally happen in accordance with the specification provided by devices, + which can vary from device to device (e.g. ``default.qubit`` and ``lightning.qubit`` might + decompose quite differently.) + However, Catalyst's decomposition logic will differ in the following cases: 1. All :class:`qml.Controlled ` operations will decompose to :class:`qml.QubitUnitary ` operations. - 2. :class:`qml.ControlledQubitUnitary ` operations will decompose to :class:`qml.QubitUnitary ` operations. - 3. The list of device-supported gates employed by Catalyst is currently different than that of the ``lightning.qubit`` device, as defined by the :class:`~.qjit_device.QJITDevice`. + 2. The set of operations supported by Catalyst itself can in some instances lead to additional decompositions compared to the device itself. .. raw:: html diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index ec80e5d7a4..a1680b06fe 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -39,7 +39,6 @@ try: if INSTALLED: - # pylint: disable=no-name-in-module from catalyst._revision import __revision__ # pragma: no cover else: from subprocess import check_output diff --git a/frontend/catalyst/compiled_functions.py b/frontend/catalyst/compiled_functions.py index 382c6b6891..07f417461b 100644 --- a/frontend/catalyst/compiled_functions.py +++ b/frontend/catalyst/compiled_functions.py @@ -36,7 +36,7 @@ get_decomposed_signature, typecheck_signatures, ) -from catalyst.utils import wrapper # pylint: disable=no-name-in-module +from catalyst.utils import wrapper from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type from catalyst.utils.filesystem import Directory from catalyst.utils.jnp_to_memref import get_ranked_memref_descriptor diff --git a/frontend/catalyst/device/decomposition.py b/frontend/catalyst/device/decomposition.py index 7ae34f83a8..6aaabfcf61 100644 --- a/frontend/catalyst/device/decomposition.py +++ b/frontend/catalyst/device/decomposition.py @@ -47,18 +47,112 @@ logger.addHandler(logging.NullHandler()) +def check_alternative_control_support(op, capabilities): + """Verify that aliased controlled operations aren't supported via alternative definitions.""" + + if ( + isinstance(op, qml.ControlledQubitUnitary) + and capabilities.native_ops.get("QubitUnitary") + and capabilities.native_ops.get("QubitUnitary").controllable + ): + decomp = qml.ops.Controlled( + qml.QubitUnitary(*op.data, wires=op.target_wires), op.control_wires, op.control_values + ) + elif ( + isinstance(op, qml.ControlledPhaseShift) + and capabilities.native_ops.get("PhaseShift") + and capabilities.native_ops.get("PhaseShift").controllable + ): + decomp = qml.ops.Controlled( + qml.PhaseShift(*op.data, wires=op.target_wires), op.control_wires + ) + elif ( + isinstance(op, (qml.CNOT, qml.Toffoli, qml.MultiControlledX)) + and capabilities.native_ops.get("PauliX") + and capabilities.native_ops.get("PauliX").controllable + ): + decomp = qml.ops.Controlled( + qml.PauliX(wires=op.target_wires), op.control_wires, op.control_values, op.work_wires + ) + elif ( + isinstance(op, qml.CY) + and capabilities.native_ops.get("PauliY") + and capabilities.native_ops.get("PauliY").controllable + ): + decomp = qml.ops.Controlled(qml.PauliY(wires=op.target_wires), op.control_wires) + elif ( + isinstance(op, (qml.CZ, qml.CCZ)) + and capabilities.native_ops.get("PauliZ") + and capabilities.native_ops.get("PauliZ").controllable + ): + decomp = qml.ops.Controlled(qml.PauliZ(wires=op.target_wires), op.control_wires) + elif ( + isinstance(op, qml.CRX) + and capabilities.native_ops.get("RX") + and capabilities.native_ops.get("RX").controllable + ): + decomp = qml.ops.Controlled(qml.RX(*op.data, wires=op.target_wires), op.control_wires) + elif ( + isinstance(op, qml.CRY) + and capabilities.native_ops.get("RY") + and capabilities.native_ops.get("RY").controllable + ): + decomp = qml.ops.Controlled(qml.RY(*op.data, wires=op.target_wires), op.control_wires) + elif ( + isinstance(op, qml.CRZ) + and capabilities.native_ops.get("RZ") + and capabilities.native_ops.get("RZ").controllable + ): + decomp = qml.ops.Controlled(qml.RZ(*op.data, wires=op.target_wires), op.control_wires) + elif ( + isinstance(op, qml.CRot) + and capabilities.native_ops.get("Rot") + and capabilities.native_ops.get("Rot").controllable + ): + decomp = qml.ops.Controlled(qml.Rot(*op.data, wires=op.target_wires), op.control_wires) + elif ( + isinstance(op, qml.CH) + and capabilities.native_ops.get("Hadamard") + and capabilities.native_ops.get("Hadamard").controllable + ): + decomp = qml.ops.Controlled(qml.Hadamard(wires=op.target_wires), op.control_wires) + elif ( + isinstance(op, qml.CSWAP) + and capabilities.native_ops.get("SWAP") + and capabilities.native_ops.get("SWAP").controllable + ): + decomp = qml.ops.Controlled(qml.SWAP(wires=op.target_wires), op.control_wires) + else: + decomp = None + + return [decomp] if decomp else decomp + + +def check_alternative_support(op, capabilities): + """Verify that aliased operations aren't supported via alternative definitions.""" + + if isinstance(op, qml.ops.Controlled): + return check_alternative_control_support(op, capabilities) + + return None + + def catalyst_decomposer(op, capabilities: DeviceCapabilities): """A decomposer for catalyst, to be passed to the decompose transform. Takes an operator and returns the default decomposition, unless the operator should decompose to a QubitUnitary. Raises a CompileError for MidMeasureMP""" if isinstance(op, MidMeasureMP): raise CompileError("Must use 'measure' from Catalyst instead of PennyLane.") - # TODO: remove hardcoded controlled to matrix decomp. - # Check op.has_matrix to support controlled ops without matrices: + + alternative_decomp = check_alternative_support(op, capabilities) + if alternative_decomp is not None: + return alternative_decomp + if capabilities.to_matrix_ops.get(op.name) or ( op.has_matrix and isinstance(op, qml.ops.Controlled) ): return _decompose_to_matrix(op) + return op.decomposition() diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index bb0d1b425f..c02fccfcd9 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -85,7 +85,6 @@ "PhaseShift", "PSWAP", "QubitUnitary", - "ControlledQubitUnitary", "Rot", "RX", "RY", @@ -341,11 +340,9 @@ def default_expand_fn(self, circuit, max_expansion=10): Most decomposition logic will be equivalent to PennyLane's decomposition. However, decomposition logic will differ in the following cases: - 1. All :class:`qml.QubitUnitary ` operations + 1. All unsupported :class:`qml.Controlled ` instances will decompose to :class:`qml.QubitUnitary ` operations. - 2. :class:`qml.ControlledQubitUnitary ` operations - will decompose to :class:`qml.QubitUnitary ` operations. - 3. The list of device-supported gates employed by Catalyst is currently different than + 2. The list of device-supported gates employed by Catalyst is currently different than that of the ``lightning.qubit`` device, as defined by the :class:`~.qjit_device.QJITDevice`. diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 87c8965d64..2508beb5df 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -93,7 +93,7 @@ from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp from catalyst.utils.types import convert_shaped_arrays_to_tensors -# pylint: disable=unused-argument,too-many-lines,too-many-statements,too-many-arguments,protected-access +# pylint: disable=unused-argument,too-many-lines,too-many-statements,too-many-function-args,protected-access ######### # Types # diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 7d34ac653f..f71742cb33 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -28,7 +28,7 @@ from pennylane import QubitDevice, QubitUnitary, QueuingManager from pennylane.measurements import MeasurementProcess from pennylane.operation import AnyWires, Operation, Operator, Wires -from pennylane.ops import Adjoint, Controlled, ControlledOp, ControlledQubitUnitary +from pennylane.ops import Adjoint, Controlled, ControlledOp from pennylane.tape import QuantumTape from pennylane.transforms.core import TransformProgram @@ -597,7 +597,7 @@ def trace_quantum_operations( def bind_native_operation(qrp, op, controlled_wires, controlled_values, adjoint=False): # For named-controlled operations (e.g. CNOT, CY, CZ) - bind directly by name. For # Controlled(OP) bind OP with native quantum control syntax, and similarly for Adjoint(OP). - if type(op) in (Controlled, ControlledOp, ControlledQubitUnitary): + if type(op) in (Controlled, ControlledOp): return bind_native_operation(qrp, op.base, op.control_wires, op.control_values, adjoint) elif isinstance(op, Adjoint): return bind_native_operation(qrp, op.base, controlled_wires, controlled_values, True) diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index d50c9a5a66..b1949883ff 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -175,21 +175,6 @@ def circuit(x: complex, z: ShapedArray(shape=(3,), dtype=jnp.float64)): For more details on compilation and debugging, please see :doc:`/dev/sharp_bits`. - .. important:: - - Most decomposition logic will be equivalent to PennyLane's decomposition. - However, decomposition logic will differ in the following cases: - - 1. All :class:`qml.Controlled ` operations will decompose - to :class:`qml.QubitUnitary ` operations. - - 2. :class:`qml.ControlledQubitUnitary ` operations will - decompose to :class:`qml.QubitUnitary ` operations. - - 3. The list of device-supported gates employed by Catalyst is currently different than that - of the ``lightning.qubit`` device, as defined by the - :class:`~.qjit_device.QJITDevice`. - .. details:: :title: AutoGraph and Python control flow diff --git a/frontend/catalyst/programs/verification.py b/frontend/catalyst/programs/verification.py index a5418f46dc..f726c66596 100644 --- a/frontend/catalyst/programs/verification.py +++ b/frontend/catalyst/programs/verification.py @@ -24,7 +24,6 @@ CompositeOp, Controlled, ControlledOp, - ControlledQubitUnitary, Hamiltonian, SymbolicOp, ) @@ -109,7 +108,6 @@ def verify_no_state_variance_returns(tape: QuantumTape) -> None: return (tape,), lambda x: x[0] -# pylint: disable=too-many-statements @transform def verify_operations(tape: QuantumTape, grad_method, qjit_device): """verify the quantum program against Catalyst requirements. This transform makes no @@ -179,8 +177,7 @@ def _inv_op_checker(op, in_inverse): return in_inverse # If its a PL Controlled we also want to check its base to catch C(Adjoint(base)). # PL simplification should mean pure PL operators will not be more nested than this. - # TODO: remove ControlledQubitUnitary to treat it as independant gate everywhere - if type(op) in (Controlled, ControlledOp, ControlledQubitUnitary): + if type(op) in (Controlled, ControlledOp): _inv_op_checker(op.base, in_inverse) return in_inverse # Early exit when not in inverse, only determine the inverse status for recursing later. diff --git a/frontend/catalyst/third_party/cuda/__init__.py b/frontend/catalyst/third_party/cuda/__init__.py index 8c44614348..cf61ec2abd 100644 --- a/frontend/catalyst/third_party/cuda/__init__.py +++ b/frontend/catalyst/third_party/cuda/__init__.py @@ -96,7 +96,6 @@ def wrap_fn(fn): # Do we need to reimplement apply for every child? -# pylint: disable=abstract-method class BaseCudaInstructionSet(qml.QubitDevice): """Base instruction set for CUDA-Quantum devices""" diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py index fe23cf93fb..8b2006b61e 100644 --- a/frontend/catalyst/utils/toml.py +++ b/frontend/catalyst/utils/toml.py @@ -16,7 +16,7 @@ """ import importlib.util -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import reduce from itertools import repeat from typing import Any, Dict, List, Set @@ -58,9 +58,9 @@ def read_toml_file(toml_file: str) -> TOMLDocument: class OperationProperties: """Capabilities of a single operation""" - invertible: bool - controllable: bool - differentiable: bool + invertible: bool = False + controllable: bool = False + differentiable: bool = False def intersect_properties(a: OperationProperties, b: OperationProperties) -> OperationProperties: @@ -76,17 +76,17 @@ def intersect_properties(a: OperationProperties, b: OperationProperties) -> Oper class DeviceCapabilities: # pylint: disable=too-many-instance-attributes """Quantum device capabilities""" - native_ops: Dict[str, OperationProperties] - to_decomp_ops: Dict[str, OperationProperties] - to_matrix_ops: Dict[str, OperationProperties] - native_obs: Dict[str, OperationProperties] - measurement_processes: Set[str] - qjit_compatible_flag: bool - mid_circuit_measurement_flag: bool - runtime_code_generation_flag: bool - dynamic_qubit_management_flag: bool - non_commuting_observables_flag: bool - options: Dict[str, bool] + native_ops: Dict[str, OperationProperties] = field(default_factory=dict) + to_decomp_ops: Dict[str, OperationProperties] = field(default_factory=dict) + to_matrix_ops: Dict[str, OperationProperties] = field(default_factory=dict) + native_obs: Dict[str, OperationProperties] = field(default_factory=dict) + measurement_processes: Set[str] = field(default_factory=dict) + qjit_compatible_flag: bool = False + mid_circuit_measurement_flag: bool = False + runtime_code_generation_flag: bool = False + dynamic_qubit_management_flag: bool = False + non_commuting_observables_flag: bool = False + options: Dict[str, bool] = field(default_factory=dict) def intersect_operations( diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 7696ca9a28..fb3951060c 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -21,7 +21,7 @@ import jax import pennylane as qml -from catalyst import cond, for_loop, measure, qjit, while_loop +from catalyst import measure, qjit from catalyst.compiler import get_lib_path from catalyst.device import get_device_capabilities from catalyst.utils.toml import ( @@ -103,7 +103,7 @@ def execute(self, circuits, execution_config): def test_decompose_multicontrolledx(): - """Test decomposition of MultiControlledX.""" + """Test decomposition of MultiControlledX as an aliased gate.""" dev = get_custom_device_without(5, discards={"MultiControlledX"}) @qjit(target="mlir") @@ -112,9 +112,9 @@ def test_decompose_multicontrolledx(): def decompose_multicontrolled_x1(theta: float): qml.RX(theta, wires=[0]) # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary + # CHECK: quantum.custom "PauliX"() {{%[a-zA-Z0-9_]+}} ctrls({{%[a-zA-Z0-9_]+}}, {{%[a-zA-Z0-9_]+}}, {{%[a-zA-Z0-9_]+}}) # CHECK-NOT: name = "MultiControlledX" - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) + qml.MultiControlledX(wires=[0, 1, 2, 3]) return qml.state() print(decompose_multicontrolled_x1.mlir) @@ -123,85 +123,6 @@ def decompose_multicontrolled_x1(theta: float): test_decompose_multicontrolledx() -def test_decompose_multicontrolledx_in_conditional(): - """Test decomposition of MultiControlledX in conditional.""" - dev = get_custom_device_without(5, discards={"MultiControlledX"}) - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: @jit_decompose_multicontrolled_x2 - def decompose_multicontrolled_x2(theta: float, n: int): - qml.RX(theta, wires=[0]) - - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - @cond(n > 1) - def cond_fn(): - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - - cond_fn() - return qml.state() - - print(decompose_multicontrolled_x2.mlir) - - -test_decompose_multicontrolledx_in_conditional() - - -def test_decompose_multicontrolledx_in_while_loop(): - """Test decomposition of MultiControlledX in while loop.""" - dev = get_custom_device_without(5, discards={"MultiControlledX"}) - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: @jit_decompose_multicontrolled_x3 - def decompose_multicontrolled_x3(theta: float, n: int): - qml.RX(theta, wires=[0]) - - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - @while_loop(lambda v: v[0] < 10) - def loop(v): - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - return v[0] + 1, v[1] - - loop((0, n)) - return qml.state() - - print(decompose_multicontrolled_x3.mlir) - - -test_decompose_multicontrolledx_in_while_loop() - - -def test_decompose_multicontrolledx_in_for_loop(): - """Test decomposition of MultiControlledX in for loop.""" - dev = get_custom_device_without(5, discards={"MultiControlledX"}) - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: @jit_decompose_multicontrolled_x4 - def decompose_multicontrolled_x4(theta: float, n: int): - qml.RX(theta, wires=[0]) - - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - @for_loop(0, n, 1) - def loop(_): - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - - loop() - return qml.state() - - print(decompose_multicontrolled_x4.mlir) - - -test_decompose_multicontrolledx_in_for_loop() - - def test_decompose_rot(): """Test decomposition of Rot gate.""" dev = get_custom_device_without(1, discards={"Rot", "C(Rot)"}) diff --git a/frontend/test/pytest/test_decomposition.py b/frontend/test/pytest/device/test_decomposition.py similarity index 61% rename from frontend/test/pytest/test_decomposition.py rename to frontend/test/pytest/device/test_decomposition.py index ef55e910fc..fc0979b420 100644 --- a/frontend/test/pytest/test_decomposition.py +++ b/frontend/test/pytest/device/test_decomposition.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 Xanadu Quantum Technologies Inc. +# Copyright 2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,15 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Unit test module for catalyst/device/decomposition.py""" + from copy import deepcopy import pennylane as qml import pytest -from jax import numpy as jnp -from catalyst import CompileError, ctrl, measure, qjit +from catalyst import CompileError, ctrl, qjit from catalyst.device import get_device_capabilities -from catalyst.utils.toml import ProgramFeatures, pennylane_operation_set +from catalyst.device.decomposition import catalyst_decomposer +from catalyst.utils.toml import ( + DeviceCapabilities, + OperationProperties, + ProgramFeatures, + pennylane_operation_set, +) class CustomDevice(qml.QubitDevice): @@ -67,24 +74,56 @@ def observables(self): return pennylane_operation_set(self.qjit_capabilities.native_obs) -@pytest.mark.parametrize("param,expected", [(0.0, True), (jnp.pi, False)]) -def test_decomposition(param, expected): - dev = CustomDevice(wires=2) - - @qjit - @qml.qnode(dev) - def mid_circuit(x: float): - qml.Hadamard(wires=0) - qml.Rot(0, 0, x, wires=0) - qml.Hadamard(wires=0) - m = measure(wires=0) - b = m ^ 0x1 - qml.Hadamard(wires=1) - qml.Rot(0, 0, b * jnp.pi, wires=1) - qml.Hadamard(wires=1) - return measure(wires=1) - - assert mid_circuit(param) == expected +class TestGateAliases: + """Test the decomposition of gates wich are in fact supported via aliased or equivalent + op definitions.""" + + special_control_ops = ( + qml.CNOT([0, 1]), + qml.Toffoli([0, 1, 2]), + qml.MultiControlledX([1, 2], 0, [True, False]), + qml.CZ([0, 1]), + qml.CCZ([0, 1, 2]), + qml.CY([0, 1]), + qml.CSWAP([0, 1, 2]), + qml.CH([0, 1]), + qml.CRX(0.1, [0, 1]), + qml.CRY(0.1, [0, 1]), + qml.CRZ(0.1, [0, 1]), + qml.CRot(0.1, 0.2, 0.3, [0, 1]), + qml.ControlledPhaseShift(0.1, [0, 1]), + qml.ControlledQubitUnitary([[1, 0], [0, 1j]], 1, 0), + ) + control_base_ops = ( + qml.PauliX, + qml.PauliX, + qml.PauliX, + qml.PauliZ, + qml.PauliZ, + qml.PauliY, + qml.SWAP, + qml.Hadamard, + qml.RX, + qml.RY, + qml.RZ, + qml.Rot, + qml.PhaseShift, + qml.QubitUnitary, + ) + assert len(special_control_ops) == len(control_base_ops) + + @pytest.mark.parametrize("gate, base", zip(special_control_ops, control_base_ops)) + def test_control_aliases(self, gate, base): + """Test the decomposition of specialized control operations.""" + + capabilities = DeviceCapabilities( + native_ops={base.__name__: OperationProperties(controllable=True)} + ) + decomp = catalyst_decomposer(gate, capabilities) + + assert len(decomp) == 1 + assert type(decomp[0]) is qml.ops.ControlledOp + assert type(decomp[0].base) is base class TestControlledDecomposition: diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py index bce54edb7d..9fca66345c 100644 --- a/frontend/test/pytest/test_custom_devices.py +++ b/frontend/test/pytest/test_custom_devices.py @@ -83,6 +83,7 @@ "MultiControlledX", "SISWAP", "ControlledPhaseShift", + "C(QubitUnitary)", "C(PauliY)", "C(RY)", "C(PauliX)", diff --git a/frontend/test/pytest/test_operations.py b/frontend/test/pytest/test_operations.py index 998820765c..0090333e37 100644 --- a/frontend/test/pytest/test_operations.py +++ b/frontend/test/pytest/test_operations.py @@ -222,5 +222,25 @@ def interpreted(x): assert np.allclose(compiled(inp), interpreted(inp)) +def test_multicontrolledx_via_paulix(): + """Test that lightning executes multicontrolled x via paulix rather than qubit unitary.""" + + dev = qml.device("lightning.qubit", wires=4) + + @qjit + @qml.qnode(dev) + def circuit(): + qml.Hadamard(0) + qml.Hadamard(1) + qml.Hadamard(2) + qml.MultiControlledX(control_wires=[0, 1, 2], wires=[3], control_values=[True, False, True]) + return qml.state() + + assert "QubitUnitary" not in str(circuit.jaxpr) + assert "PauliX" in str(circuit.jaxpr) + + assert np.allclose(circuit(), circuit.original_function()) + + if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_quantum_control.py b/frontend/test/pytest/test_quantum_control.py index 48d8d66503..636c9bcce8 100644 --- a/frontend/test/pytest/test_quantum_control.py +++ b/frontend/test/pytest/test_quantum_control.py @@ -462,9 +462,6 @@ def native_controlled(): ) return qml.state() - # The code will be lowered to `QubitUnitary` of an updated - # matrix that represents the `ControlledQubitUnitary`. - compiled = qjit()(native_controlled) result = compiled() expected = native_controlled() diff --git a/runtime/lib/backend/dummy/dummy_device.toml b/runtime/lib/backend/dummy/dummy_device.toml index d8ecde79f4..33eea473ef 100644 --- a/runtime/lib/backend/dummy/dummy_device.toml +++ b/runtime/lib/backend/dummy/dummy_device.toml @@ -2,8 +2,7 @@ schema = 2 [operators.gates.native] -QubitUnitary = { properties = [ "invertible", "differentiable" ] } -ControlledQubitUnitary = { properties = [ "invertible", "differentiable" ] } +QubitUnitary = { properties = [ "invertible", "controllable", "differentiable" ] } PauliX = { properties = [ "controllable", "invertible", "differentiable" ] } PauliY = { properties = [ "controllable", "invertible", "differentiable" ] } PauliZ = { properties = [ "controllable", "invertible", "differentiable" ] } @@ -56,7 +55,8 @@ SQISW = {} BasisState = {} QubitStateVector = {} StatePrep = {} -DiagonalQubitUnitary = {} +ControlledQubitUnitary = {} +MultiControlledX = {} QubitCarry = {} QubitSum = {} OrbitalRotation = {} @@ -66,7 +66,7 @@ ECR = {} # Gates which should be translated to QubitUnitary [operators.gates.matrix] -MultiControlledX = {} +DiagonalQubitUnitary = {} # Observables supported by the device @@ -107,7 +107,7 @@ mid_circuit_measurement = true # determining if the device supports dynamic qubit allocation/deallocation. dynamic_qubit_management = false -# whether the device can support non-commuting measurements together +# whether the device can support non-commuting measurements together # in a single execution non_commuting_observables = true @@ -115,4 +115,3 @@ non_commuting_observables = true option1 = "_option1" option2 = "_option2" -