Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Pass capabilities to the decomposer #749

Merged
merged 36 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
0ede63d
Make a test tomle-schema-2-compatible
May 3, 2024
1babc99
Address formatting issues
May 3, 2024
30de70c
Make test toml-schema-independant
May 6, 2024
9f985b0
Merge branch 'main' into toml-schema-2-update-test
May 6, 2024
f97e7e9
Address codecov errors
May 6, 2024
40499b4
Address codecov errors
May 6, 2024
0ed92b6
Address codecov errors
May 6, 2024
22cfe45
Merge remote-tracking branch 'origin/main' into toml-schema-2-update-…
May 9, 2024
d00c7e4
Fix wrong field name
May 9, 2024
d3ffae5
Address review suggestions; Move code around
May 16, 2024
aa37716
Merge remote-tracking branch 'origin/main' into toml-schema-2-update-…
May 16, 2024
6537b2a
Address formatting issues
May 16, 2024
468c3ca
Address formatting issues
May 16, 2024
72c9f24
Address pylint issues
May 16, 2024
611e5fc
Add missing paths module
May 16, 2024
dd15715
Pass capabilities to the decomposer
May 16, 2024
9a5e30f
Address pylint issues
May 16, 2024
485230a
Merge remote-tracking branch 'origin/main' into decompose-to-matrix-u…
May 21, 2024
13ace0f
Address formatting issues
May 21, 2024
67efb9c
Merge branch 'toml-schema-2-update-test' into decompose-to-matrix-usi…
May 21, 2024
f26a0e1
Fix tests
May 21, 2024
508457d
Address codefactor issues
May 21, 2024
d578c62
Fix formatting issues
May 21, 2024
3448527
Address review suggestions: add a requested test
May 22, 2024
5cf74d1
Merge branch 'main' into decompose-to-matrix-using-capabilities
lillian542 May 22, 2024
3e8d720
Merge remote-tracking branch 'origin/main' into decompose-to-matrix-u…
May 27, 2024
55e2d3e
Address formatting issues
May 27, 2024
c24ad8f
Revert abstract-method
May 27, 2024
e871ebd
Revert paths.py
May 27, 2024
e2bbe64
Rever runtime.py
May 27, 2024
653a0e0
Rever validate_device_capabilities
May 27, 2024
99aa125
Rever validate_device_capabilities
May 27, 2024
2aaef4b
Rever validate_device_capabilities
May 27, 2024
486316a
Revert old device api shots
May 27, 2024
18a3413
Revert abstract-method
May 27, 2024
2a218dc
Add a lit-test
May 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions frontend/catalyst/device/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
compilation & execution on devices.
"""

from functools import partial

import jax
import pennylane as qml
from pennylane import transform
Expand All @@ -37,15 +39,16 @@
from catalyst.jax_tracer import HybridOpRegion, has_nested_tapes
from catalyst.tracing.contexts import EvaluationContext
from catalyst.utils.exceptions import CompileError
from catalyst.utils.toml import DeviceCapabilities


def catalyst_decomposer(op):
def catalyst_decomposer(op, capabilities: DeviceCapabilities):
"""A decomposer for catalyst, to be passed to the decompose transform. Takes an operator and
returns the default decomposition, unless the operator should decompose to a QubitUnitary.
Raises a CompileError for MidMeasureMP"""
if isinstance(op, MidMeasureMP):
raise CompileError("Must use 'measure' from Catalyst instead of PennyLane.")
if op.name in {"MultiControlledX", "BlockEncode"} or isinstance(op, qml.ops.Controlled):
if capabilities.to_matrix_ops.get(op.name) or isinstance(op, qml.ops.Controlled):
return _decompose_to_matrix(op)
return op.decomposition()

Expand All @@ -55,7 +58,7 @@ def catalyst_decompose(
tape: qml.tape.QuantumTape,
ctx,
stopping_condition,
decomposer=catalyst_decomposer,
capabilities,
max_expansion=None,
):
"""Decompose operations until the stopping condition is met.
Expand All @@ -76,7 +79,7 @@ def catalyst_decompose(
tape,
stopping_condition,
skip_initial_state_prep=False,
decomposer=decomposer,
decomposer=partial(catalyst_decomposer, capabilities=capabilities),
max_expansion=max_expansion,
name="catalyst on this device",
error=CompileError,
Expand All @@ -85,7 +88,7 @@ def catalyst_decompose(
new_ops = []
for op in toplevel_tape.operations:
if has_nested_tapes(op):
op = _decompose_nested_tapes(op, ctx, stopping_condition, decomposer, max_expansion)
op = _decompose_nested_tapes(op, ctx, stopping_condition, capabilities, max_expansion)
new_ops.append(op)
tape = qml.tape.QuantumScript(new_ops, tape.measurements, shots=tape.shots)

Expand All @@ -103,7 +106,7 @@ def _decompose_to_matrix(op):
return [op]


def _decompose_nested_tapes(op, ctx, stopping_condition, decomposer, max_expansion):
def _decompose_nested_tapes(op, ctx, stopping_condition, capabilities, max_expansion):
new_regions = []
for region in op.regions:
if region.quantum_tape is None:
Expand All @@ -114,7 +117,7 @@ def _decompose_nested_tapes(op, ctx, stopping_condition, decomposer, max_expansi
region.quantum_tape,
ctx=ctx,
stopping_condition=stopping_condition,
decomposer=decomposer,
capabilities=capabilities,
max_expansion=max_expansion,
)
new_tape = tapes[0]
Expand Down
7 changes: 6 additions & 1 deletion frontend/catalyst/device/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,66 +104,66 @@
lpath: str
kwargs: Dict[str, Any]


def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabilities) -> BackendInfo:
"""Extract the backend info from a quantum device. The device is expected to carry a reference
to a valid TOML config file."""
# pylint: disable=too-many-branches

dname = device.name
if isinstance(device, qml.Device):
dname = device.short_name

device_name = ""
device_lpath = ""
device_kwargs = {}

if dname in SUPPORTED_RT_DEVICES:
# Support backend devices without `get_c_interface`
device_name = SUPPORTED_RT_DEVICES[dname][0]
device_lpath = get_lib_path("runtime", "RUNTIME_LIB_DIR")
sys_platform = platform.system()

if sys_platform == "Linux":
device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".so")
elif sys_platform == "Darwin": # pragma: no cover
device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".dylib")
else: # pragma: no cover
raise NotImplementedError(f"Platform not supported: {sys_platform}")
elif hasattr(device, "get_c_interface"):
# Support third party devices with `get_c_interface`
device_name, device_lpath = device.get_c_interface()
else:
raise CompileError(f"The {dname} device does not provide C interface for compilation.")

if not pathlib.Path(device_lpath).is_file():
raise CompileError(f"Device at {device_lpath} cannot be found!")

if hasattr(device, "shots"):
if isinstance(device, qml.Device):
device_kwargs["shots"] = device.shots if device.shots else 0
else:
# TODO: support shot vectors
device_kwargs["shots"] = device.shots.total_shots if device.shots else 0

if dname == "braket.local.qubit": # pragma: no cover
device_kwargs["device_type"] = dname
device_kwargs["backend"] = (
# pylint: disable=protected-access
device._device._delegate.DEVICE_ID
)
elif dname == "braket.aws.qubit": # pragma: no cover
device_kwargs["device_type"] = dname
device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access
if device._s3_folder: # pylint: disable=protected-access
device_kwargs["s3_destination_folder"] = str(
device._s3_folder # pylint: disable=protected-access
)

for k, v in capabilities.options.items():
if hasattr(device, v):
device_kwargs[k] = getattr(device, v)

Check notice on line 166 in frontend/catalyst/device/qjit_device.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/device/qjit_device.py#L107-L166

Complex Method
return BackendInfo(dname, device_name, device_lpath, device_kwargs)


Expand Down Expand Up @@ -393,7 +393,12 @@
program = TransformProgram()

ops_acceptance = partial(catalyst_acceptance, operations=self.operations)
program.add_transform(catalyst_decompose, ctx=ctx, stopping_condition=ops_acceptance)
program.add_transform(
catalyst_decompose,
ctx=ctx,
stopping_condition=ops_acceptance,
capabilities=self.qjit_capabilities,
)

if self.measurement_processes == {"Counts"}:
program.add_transform(measurements_from_counts)
Expand Down
74 changes: 57 additions & 17 deletions frontend/test/lit/test_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,29 @@
# RUN: %PYTHON %s | FileCheck %s
# pylint: disable=line-too-long

import platform
from copy import deepcopy

import jax
import pennylane as qml

from catalyst import cond, for_loop, measure, qjit, while_loop
from catalyst.compiler import get_lib_path
from catalyst.utils.toml import (
OperationProperties,
ProgramFeatures,
get_device_capabilities,
pennylane_operation_set,
)


def get_custom_device_without(num_wires, discards):
def get_custom_device_without(num_wires, discards=frozenset(), force_matrix=frozenset()):
"""Generate a custom device without gates in discards."""

class CustomDevice(qml.QubitDevice):
class CustomDevice(qml.devices.Device):
"""Custom Gate Set Device"""

name = "Custom Device"
short_name = "lightning.qubit"
pennylane_requires = "0.35.0"
version = "0.0.2"
author = "Tester"
Expand All @@ -55,12 +57,13 @@ def __init__(self, shots=None, wires=None):
)
custom_capabilities = deepcopy(lightning_capabilities)
for gate in discards:
if gate in custom_capabilities.native_ops:
custom_capabilities.native_ops.pop(gate)
if gate in custom_capabilities.to_decomp_ops:
custom_capabilities.to_decomp_ops.pop(gate)
if gate in custom_capabilities.to_matrix_ops:
custom_capabilities.to_matrix_ops.pop(gate)
custom_capabilities.native_ops.pop(gate, None)
custom_capabilities.to_decomp_ops.pop(gate, None)
custom_capabilities.to_matrix_ops.pop(gate, None)
for gate in force_matrix:
custom_capabilities.native_ops.pop(gate, None)
custom_capabilities.to_decomp_ops.pop(gate, None)
custom_capabilities.to_matrix_ops[gate] = OperationProperties(False, False, False)
self.qjit_capabilities = custom_capabilities

def apply(self, operations, **kwargs):
Expand All @@ -81,12 +84,27 @@ def observables(self):
"""Return PennyLane observables"""
return pennylane_operation_set(self.qjit_capabilities.native_obs)

@staticmethod
def get_c_interface():
"""Returns a tuple consisting of the device name, and
the location to the shared object with the C/C++ device implementation.
"""
system_extension = ".dylib" if platform.system() == "Darwin" else ".so"
lib_path = (
get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/librtd_dummy" + system_extension
)
return "dummy.remote", lib_path

def execute(self, circuits, execution_config):
"""Execution."""
return circuits, execution_config

return CustomDevice(wires=num_wires)


def test_decompose_multicontrolledx():
"""Test decomposition of MultiControlledX."""
dev = get_custom_device_without(5, {"MultiControlledX"})
dev = get_custom_device_without(5, discards={"MultiControlledX"})

@qjit(target="mlir")
@qml.qnode(dev)
Expand All @@ -107,7 +125,7 @@ def decompose_multicontrolled_x1(theta: float):

def test_decompose_multicontrolledx_in_conditional():
"""Test decomposition of MultiControlledX in conditional."""
dev = get_custom_device_without(5, {"MultiControlledX"})
dev = get_custom_device_without(5, discards={"MultiControlledX"})

@qjit(target="mlir")
@qml.qnode(dev)
Expand All @@ -133,7 +151,7 @@ def cond_fn():

def test_decompose_multicontrolledx_in_while_loop():
"""Test decomposition of MultiControlledX in while loop."""
dev = get_custom_device_without(5, {"MultiControlledX"})
dev = get_custom_device_without(5, discards={"MultiControlledX"})

@qjit(target="mlir")
@qml.qnode(dev)
Expand All @@ -160,7 +178,7 @@ def loop(v):

def test_decompose_multicontrolledx_in_for_loop():
"""Test decomposition of MultiControlledX in for loop."""
dev = get_custom_device_without(5, {"MultiControlledX"})
dev = get_custom_device_without(5, discards={"MultiControlledX"})

@qjit(target="mlir")
@qml.qnode(dev)
Expand All @@ -186,7 +204,7 @@ def loop(_):

def test_decompose_rot():
"""Test decomposition of Rot gate."""
dev = get_custom_device_without(1, {"Rot", "C(Rot)"})
dev = get_custom_device_without(1, discards={"Rot", "C(Rot)"})

@qjit(target="mlir")
@qml.qnode(dev)
Expand Down Expand Up @@ -216,7 +234,7 @@ def decompose_rot(phi: float, theta: float, omega: float):

def test_decompose_s():
"""Test decomposition of S gate."""
dev = get_custom_device_without(1, {"S", "C(S)"})
dev = get_custom_device_without(1, discards={"S", "C(S)"})

@qjit(target="mlir")
@qml.qnode(dev)
Expand All @@ -238,7 +256,7 @@ def decompose_s():

def test_decompose_qubitunitary():
"""Test decomposition of QubitUnitary"""
dev = get_custom_device_without(1, {"QubitUnitary"})
dev = get_custom_device_without(1, discards={"QubitUnitary"})

@qjit(target="mlir")
@qml.qnode(dev)
Expand All @@ -260,7 +278,7 @@ def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)):

def test_decompose_singleexcitationplus():
"""Test decomposition of single excitation plus."""
dev = get_custom_device_without(2, {"SingleExcitationPlus", "C(SingleExcitationPlus)"})
dev = get_custom_device_without(2, discards={"SingleExcitationPlus", "C(SingleExcitationPlus)"})

@qjit(target="mlir")
@qml.qnode(dev)
Expand Down Expand Up @@ -304,3 +322,25 @@ def decompose_singleexcitationplus(theta: float):


test_decompose_singleexcitationplus()


def test_decompose_to_matrix():
"""Test decomposition of QubitUnitary"""
dev = get_custom_device_without(1, force_matrix={"PauliY"})

@qjit(target="mlir")
@qml.qnode(dev)
# CHECK-LABEL: public @jit_decompose_to_matrix
def decompose_to_matrix():
# CHECK: quantum.custom "PauliX"
qml.PauliX(wires=0)
# CHECK: quantum.unitary
qml.PauliY(wires=0)
# CHECK: quantum.custom "PauliZ"
qml.PauliZ(wires=0)
return measure(wires=0)

print(decompose_to_matrix.mlir)


test_decompose_to_matrix()
2 changes: 1 addition & 1 deletion frontend/test/pytest/test_config_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import pytest

from catalyst.device import QJITDevice, validate_device_capabilities
from catalyst.device.qjit_device import check_no_overlap
from catalyst.device.qjit_device import check_no_overlap, validate_device_capabilities
from catalyst.utils.exceptions import CompileError
from catalyst.utils.toml import (
DeviceCapabilities,
Expand Down
Loading
Loading