Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Implement quantum device capabilities as a data structure #609

Merged
merged 51 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
af40cf6
Clean-up unconditional addition of the ControlledQubitUnitary to QJIT…
Mar 14, 2024
a4bbd6b
Clean-up unconditional addition of the ControlledQubitUnitary to QJIT…
Mar 14, 2024
f4ee1a8
Explicitly list supported controlled gates
Mar 14, 2024
5d288e8
Fix missing Projector observable in the kokkos device config
Mar 14, 2024
5174342
Read toml files in binary mode
Mar 14, 2024
6ef4188
[FIX] Apply ASAN runner_mu fix
Mar 15, 2024
d367d12
Merge branch 'main' into native-quantum-control-toml-fixup
Mar 18, 2024
797fd2d
Program verification initial commit
Mar 18, 2024
54d7808
Ongoing work
Mar 20, 2024
fc23d2d
Implement device capability datastructure
Mar 20, 2024
728171e
Implement device capability datastructure
Mar 21, 2024
e49f521
Implement device capability datastructure
Mar 21, 2024
fc1fdfc
Implement device capability datastructure
Mar 21, 2024
649437f
Implement device capability datastructure
Mar 21, 2024
9d0c1c9
Implement device capability datastructure
Mar 21, 2024
97bfa29
Implement device capability datastructure
Mar 21, 2024
dd6d469
Implement device capability datastructure
Mar 21, 2024
f5c8d21
Implement device capability datastructure
Mar 21, 2024
5d15a81
Clean up the code
Mar 21, 2024
83b2f8a
Temporary revert verification draft
Mar 21, 2024
96d1f81
Address CodeCov issues
Mar 21, 2024
1b51b59
Merge remote-tracking branch 'origin/main' into program-verification
Mar 21, 2024
e8cf0db
Address CodeCov issues
Mar 21, 2024
1cc4bf2
Address CodeCov issues
Mar 21, 2024
b765320
Address CodeCov issues
Mar 21, 2024
0668639
Address CodeCov issues
Mar 21, 2024
6855293
Address CodeCov issues
Mar 21, 2024
dbb62df
Address CodeCov issues
Mar 21, 2024
986e682
Fix kokkos tests
Mar 22, 2024
e47402e
Merge remote-tracking branch 'origin/main' into program-verification
Mar 22, 2024
29d9258
Cleanup the code
Mar 22, 2024
171f0ba
Address codecov issues
Mar 22, 2024
46cd001
Address codecov issues
Mar 22, 2024
258a9fc
Address codecov issues
Mar 22, 2024
26af78c
Address codecov issues
Mar 22, 2024
3bc8911
Merge remote-tracking branch 'origin/main' into program-verification
Apr 16, 2024
394099a
Address pylint issues
Apr 16, 2024
78d65cc
Update frontend/catalyst/qjit_device.py
Apr 17, 2024
c803ecf
Address review suggestion: remove unsafe_hash flag
Apr 17, 2024
44a2726
Update frontend/catalyst/utils/toml.py
Apr 17, 2024
8d381fb
Address review suggestion: split a schema1 patching logic into a sepa…
Apr 17, 2024
3029a66
Address review suggestion: rename get_gates and improve its description
Apr 17, 2024
56438da
Address pylint issues
Apr 17, 2024
2e98ace
Address review suggestion: rename device config fields
Apr 23, 2024
5c76483
Merge remote-tracking branch 'origin/main' into program-verification
Apr 23, 2024
b7f35a4
Address review suggestion: rename device config fields
Apr 23, 2024
d2955f3
Merge branch 'main' into program-verification
rmoyard Apr 24, 2024
49dff91
Merge remote-tracking branch 'origin/main' into program-verification
Apr 29, 2024
28cc69b
Address review suggestion: rename caps -> capabilities
Apr 29, 2024
8239b36
Address review suggestion: rename caps -> capabilities
Apr 29, 2024
04a5813
Address formatting issues
Apr 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontend/catalyst/qfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __call__(self, *args, **kwargs):
backend_info = QFunc.extract_backend_info(self.device, config)

if isinstance(self.device, qml.devices.Device):
device = QJITDeviceNewAPI(self.device, config, backend_info)
device = QJITDeviceNewAPI(self.device, backend_info)
else:
device = QJITDevice(config, self.device.shots, self.device.wires, backend_info)

Expand Down
179 changes: 94 additions & 85 deletions frontend/catalyst/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""This module contains the qjit device classes.
"""
from copy import deepcopy
from functools import partial
from typing import Optional, Set

Expand All @@ -23,95 +24,101 @@
from catalyst.preprocess import catalyst_acceptance, decompose
from catalyst.utils.exceptions import CompileError
from catalyst.utils.patching import Patcher
from catalyst.utils.runtime import (
BackendInfo,
get_pennylane_observables,
get_pennylane_operations,
)
from catalyst.utils.runtime import BackendInfo, device_get_toml_config
from catalyst.utils.toml import (
DeviceCapabilities,
OperationProperties,
ProgramFeatures,
TOMLDocument,
check_adjoint_flag,
check_mid_circuit_measurement_flag,
get_device_capabilities,
intersect_operations,
pennylane_operation_set,
)

# fmt:off
RUNTIME_OPERATIONS = {
"CNOT",
"ControlledPhaseShift",
"CRot",
"CRX",
"CRY",
"CRZ",
"CSWAP",
"CY",
"CZ",
"Hadamard",
"Identity",
"IsingXX",
"IsingXY",
"IsingYY",
"ISWAP",
"MultiRZ",
"PauliX",
"PauliY",
"PauliZ",
"PhaseShift",
"PSWAP",
"QubitUnitary",
"Rot",
"RX",
"RY",
"RZ",
"S",
"SWAP",
"T",
"Toffoli",
"GlobalPhase",
"C(GlobalPhase)",
"C(Hadamard)",
"C(IsingXX)",
"C(IsingXY)",
"C(IsingYY)",
"C(ISWAP)",
"C(MultiRZ)",
"ControlledQubitUnitary",
"C(PauliX)",
"C(PauliY)",
"C(PauliZ)",
"C(PhaseShift)",
"C(PSWAP)",
"C(Rot)",
"C(RX)",
"C(RY)",
"C(RZ)",
"C(S)",
"C(SWAP)",
"C(T)",
'CNOT': OperationProperties(invertible=True, controllable=True, differentiable=True),
'ControlledPhaseShift':
OperationProperties(invertible=True, controllable=True, differentiable=True),
'CRot': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CRX': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CRY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CRZ': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CSWAP': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CZ': OperationProperties(invertible=True, controllable=True, differentiable=True),
'Hadamard': OperationProperties(invertible=True, controllable=True, differentiable=True),
'Identity': OperationProperties(invertible=True, controllable=True, differentiable=True),
'IsingXX': OperationProperties(invertible=True, controllable=True, differentiable=True),
'IsingXY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'IsingYY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'ISWAP': OperationProperties(invertible=True, controllable=True, differentiable=True),
'MultiRZ': OperationProperties(invertible=True, controllable=True, differentiable=True),
'PauliX': OperationProperties(invertible=True, controllable=True, differentiable=True),
'PauliY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'PauliZ': OperationProperties(invertible=True, controllable=True, differentiable=True),
'PhaseShift': OperationProperties(invertible=True, controllable=True, differentiable=True),
'PSWAP': OperationProperties(invertible=True, controllable=True, differentiable=True),
'QubitUnitary': OperationProperties(invertible=True, controllable=True, differentiable=True),
'ControlledQubitUnitary':
OperationProperties(invertible=True, controllable=True, differentiable=True),
'Rot': OperationProperties(invertible=True, controllable=True, differentiable=True),
'RX': OperationProperties(invertible=True, controllable=True, differentiable=True),
'RY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'RZ': OperationProperties(invertible=True, controllable=True, differentiable=True),
'S': OperationProperties(invertible=True, controllable=True, differentiable=True),
'SWAP': OperationProperties(invertible=True, controllable=True, differentiable=True),
'T': OperationProperties(invertible=True, controllable=True, differentiable=True),
'Toffoli': OperationProperties(invertible=True, controllable=True, differentiable=True),
'GlobalPhase': OperationProperties(invertible=True, controllable=True, differentiable=True),
}
# fmt:on


def get_qjit_pennylane_operations(
config: TOMLDocument, shots_present: bool, device_name: str
) -> Set[str]:
def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Set[str]:
"""Calculate the set of supported quantum gates for the QJIT device from the gates
allowed on the target quantum device."""
# Supported gates of the target PennyLane's device
native_gates = get_pennylane_operations(config, shots_present, device_name)
qjit_config = deepcopy(target_capabilities)

# Gates that Catalyst runtime supports
qir_gates = RUNTIME_OPERATIONS
supported_gates = set.intersection(native_gates, qir_gates)

# Intersection of the above
qjit_config.native_ops = intersect_operations(target_capabilities.native_ops, qir_gates)

# Control-flow gates to be lowered down to the LLVM control-flow instructions
supported_gates.update({"Cond", "WhileLoop", "ForLoop"})
qjit_config.native_ops.update(
{
"Cond": OperationProperties(invertible=True, controllable=True, differentiable=True),
"WhileLoop": OperationProperties(
invertible=True, controllable=True, differentiable=True
),
"ForLoop": OperationProperties(invertible=True, controllable=True, differentiable=True),
}
)

# Optionally enable runtime-powered mid-circuit measurments
if check_mid_circuit_measurement_flag(config): # pragma: no branch
supported_gates.update({"MidCircuitMeasure"})
if target_capabilities.mid_circuit_measurement_flag: # pragma: no branch
qjit_config.native_ops.update(
{
"MidCircuitMeasure": OperationProperties(
invertible=True, controllable=True, differentiable=True
)
}
)

# Optionally enable runtime-powered quantum gate adjointing (inversions)
if check_adjoint_flag(config, shots_present):
supported_gates.update({"Adjoint"})
if all(ng.invertible for ng in target_capabilities.native_ops.values()):
qjit_config.native_ops.update(
{
"Adjoint": OperationProperties(
invertible=True, controllable=True, differentiable=True
)
}
)

return supported_gates
return qjit_config


class QJITDevice(qml.QubitDevice):
Expand All @@ -137,7 +144,7 @@ class QJITDevice(qml.QubitDevice):
author = ""

@staticmethod
def _get_operations_to_convert_to_matrix(_config: TOMLDocument) -> Set[str]:
def _get_operations_to_convert_to_matrix(_capabilities: DeviceCapabilities) -> Set[str]:
# We currently override and only set a few gates to preserve existing behaviour.
# We could choose to read from config and use the "matrix" gates.
# However, that affects differentiability.
Expand All @@ -154,25 +161,26 @@ def __init__(
):
super().__init__(wires=wires, shots=shots)

self.target_config = target_config
self.backend_name = backend.c_interface_name if backend else "default"
self.backend_lib = backend.lpath if backend else ""
self.backend_kwargs = backend.kwargs if backend else {}
device_name = backend.device_name if backend else "default"

shots_present = shots is not None
self._operations = get_qjit_pennylane_operations(target_config, shots_present, device_name)
self._observables = get_pennylane_observables(target_config, shots_present, device_name)
program_features = ProgramFeatures(shots is not None)
target_device_capabilities = get_device_capabilities(
target_config, program_features, device_name
)
self.capabilities = get_qjit_device_capabilities(target_device_capabilities)

@property
def operations(self) -> Set[str]:
"""Get the device operations"""
return self._operations
"""Get the device operations using PennyLane's syntax"""
return pennylane_operation_set(self.capabilities.native_ops)

@property
def observables(self) -> Set[str]:
"""Get the device observables"""
return self._observables
return pennylane_operation_set(self.capabilities.native_obs)

def apply(self, operations, **kwargs):
"""
Expand Down Expand Up @@ -202,7 +210,7 @@ def default_expand_fn(self, circuit, max_expansion=10):
raise CompileError("Must use 'measure' from Catalyst instead of PennyLane.")

decompose_to_qubit_unitary = QJITDevice._get_operations_to_convert_to_matrix(
self.target_config
self.capabilities
)

def _decomp_to_unitary(self, *_args, **_kwargs):
Expand Down Expand Up @@ -251,7 +259,6 @@ class QJITDeviceNewAPI(qml.devices.Device):
def __init__(
self,
original_device,
target_config: TOMLDocument,
backend: Optional[BackendInfo] = None,
):
self.original_device = original_device
Expand All @@ -264,25 +271,27 @@ def __init__(

super().__init__(wires=original_device.wires, shots=original_device.shots)

self.target_config = target_config
self.backend_name = backend.c_interface_name if backend else "default"
self.backend_lib = backend.lpath if backend else ""
self.backend_kwargs = backend.kwargs if backend else {}
device_name = backend.device_name if backend else "default"

shots_present = original_device.shots is not None
self._operations = get_qjit_pennylane_operations(target_config, shots_present, device_name)
self._observables = get_pennylane_observables(target_config, shots_present, device_name)
target_config = device_get_toml_config(original_device)
program_features = ProgramFeatures(original_device.shots is not None)
target_device_capabilities = get_device_capabilities(
target_config, program_features, device_name
)
self.capabilities = get_qjit_device_capabilities(target_device_capabilities)

@property
def operations(self) -> Set[str]:
"""Get the device operations"""
return self._operations
return pennylane_operation_set(self.capabilities.native_ops)

@property
def observables(self) -> Set[str]:
"""Get the device observables"""
return self._observables
return pennylane_operation_set(self.capabilities.native_obs)

def preprocess(
self,
Expand Down
Loading
Loading