Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Make tests toml-schema independent #712

Merged
merged 27 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
0ede63d
Make a test tomle-schema-2-compatible
May 3, 2024
1babc99
Address formatting issues
May 3, 2024
30de70c
Make test toml-schema-independant
May 6, 2024
9f985b0
Merge branch 'main' into toml-schema-2-update-test
May 6, 2024
f97e7e9
Address codecov errors
May 6, 2024
40499b4
Address codecov errors
May 6, 2024
0ed92b6
Address codecov errors
May 6, 2024
22cfe45
Merge remote-tracking branch 'origin/main' into toml-schema-2-update-…
May 9, 2024
d00c7e4
Fix wrong field name
May 9, 2024
d3ffae5
Address review suggestions; Move code around
May 16, 2024
aa37716
Merge remote-tracking branch 'origin/main' into toml-schema-2-update-…
May 16, 2024
6537b2a
Address formatting issues
May 16, 2024
468c3ca
Address formatting issues
May 16, 2024
72c9f24
Address pylint issues
May 16, 2024
611e5fc
Add missing paths module
May 16, 2024
e94c55d
Merge branch 'main' into toml-schema-2-update-test
May 21, 2024
5bd08cb
Fix re-added runtime.py
May 22, 2024
f59e6bc
Fix a test
May 22, 2024
8aadcf4
Merge remote-tracking branch 'origin/main' into toml-schema-2-update-…
May 22, 2024
1ba9149
Update frontend/catalyst/jax_primitives.py
May 24, 2024
4735b85
Update frontend/catalyst/third_party/cuda/primitives/__init__.py
May 24, 2024
74ec4ff
Rename paths -> runtime_environment
May 24, 2024
8793214
Address review suggestions: move validate_device_requirements -> qjit…
May 24, 2024
04a5b41
Address review suggestions: remove self.qjit_device attribute from qn…
May 24, 2024
2bb3ede
Fix a test import
May 24, 2024
722886c
Merge branch 'main' into toml-schema-2-update-test
May 24, 2024
eadeb91
Update changelog
May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from catalyst.utils.exceptions import CompileError
from catalyst.utils.filesystem import Directory
from catalyst.utils.runtime import get_lib_path
from catalyst.utils.toml import get_lib_path

package_root = os.path.dirname(__file__)

Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ def get_jaxpr(self, *args):
an MLIR module
"""

def cudaq_backend_info(device, _config) -> BackendInfo:
def cudaq_backend_info(device, _capabilities) -> BackendInfo:
"""The extract_backend_info should not be run by the cuda compiler as it is
catalyst-specific. We need to make this API a bit nicer for third-party compilers.
"""
Expand Down
33 changes: 10 additions & 23 deletions frontend/catalyst/device/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,10 @@
)
from catalyst.utils.exceptions import CompileError
from catalyst.utils.patching import Patcher
from catalyst.utils.runtime import BackendInfo, device_get_toml_config
from catalyst.utils.runtime import BackendInfo
from catalyst.utils.toml import (
DeviceCapabilities,
OperationProperties,
ProgramFeatures,
TOMLDocument,
get_device_capabilities,
intersect_operations,
pennylane_operation_set,
)
Expand Down Expand Up @@ -164,7 +161,7 @@ def _get_operations_to_convert_to_matrix(_capabilities: DeviceCapabilities) -> S

def __init__(
self,
target_config: TOMLDocument,
original_device_capabilities: DeviceCapabilities,
shots=None,
wires=None,
backend: Optional[BackendInfo] = None,
Expand All @@ -174,23 +171,18 @@ def __init__(
self.backend_name = backend.c_interface_name if backend else "default"
self.backend_lib = backend.lpath if backend else ""
self.backend_kwargs = backend.kwargs if backend else {}
device_name = backend.device_name if backend else "default"

program_features = ProgramFeatures(shots is not None)
target_device_capabilities = get_device_capabilities(
target_config, program_features, device_name
)
self.capabilities = get_qjit_device_capabilities(target_device_capabilities)
self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities)

@property
def operations(self) -> Set[str]:
"""Get the device operations using PennyLane's syntax"""
return pennylane_operation_set(self.capabilities.native_ops)
return pennylane_operation_set(self.qjit_capabilities.native_ops)

@property
def observables(self) -> Set[str]:
"""Get the device observables"""
return pennylane_operation_set(self.capabilities.native_obs)
return pennylane_operation_set(self.qjit_capabilities.native_obs)

def apply(self, operations, **kwargs):
"""
Expand Down Expand Up @@ -269,6 +261,7 @@ class QJITDeviceNewAPI(qml.devices.Device):
def __init__(
self,
original_device,
original_device_capabilities: DeviceCapabilities,
backend: Optional[BackendInfo] = None,
):
self.original_device = original_device
Expand All @@ -284,29 +277,23 @@ def __init__(
self.backend_name = backend.c_interface_name if backend else "default"
self.backend_lib = backend.lpath if backend else ""
self.backend_kwargs = backend.kwargs if backend else {}
device_name = backend.device_name if backend else "default"

target_config = device_get_toml_config(original_device)
program_features = ProgramFeatures(original_device.shots is not None)
target_device_capabilities = get_device_capabilities(
target_config, program_features, device_name
)
self.capabilities = get_qjit_device_capabilities(target_device_capabilities)
self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities)

@property
def operations(self) -> Set[str]:
"""Get the device operations"""
return pennylane_operation_set(self.capabilities.native_ops)
return pennylane_operation_set(self.qjit_capabilities.native_ops)

@property
def observables(self) -> Set[str]:
"""Get the device observables"""
return pennylane_operation_set(self.capabilities.native_obs)
return pennylane_operation_set(self.qjit_capabilities.native_obs)

@property
def measurement_processes(self) -> Set[str]:
"""Get the device measurement processes"""
return self.capabilities.measurement_processes
return self.qjit_capabilities.measurement_processes

def preprocess(
self,
Expand Down
35 changes: 23 additions & 12 deletions frontend/catalyst/qfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@
from catalyst.jax_tracer import trace_quantum_function
from catalyst.utils.runtime import (
BackendInfo,
device_get_toml_config,
extract_backend_info,
validate_config_with_device,
validate_device_capabilities,
)
from catalyst.utils.toml import (
DeviceCapabilities,
ProgramFeatures,
get_device_capabilities,
)
from catalyst.utils.toml import TOMLDocument


class QFunc:
Expand All @@ -54,26 +57,34 @@ def __new__(cls):
raise NotImplementedError() # pragma: no-cover

@staticmethod
def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> BackendInfo:
def extract_backend_info(
device: qml.QubitDevice, capabilities: DeviceCapabilities
) -> BackendInfo:
"""Wrapper around extract_backend_info in the runtime module."""
return extract_backend_info(device, config)
return extract_backend_info(device, capabilities)

# pylint: disable=no-member
# pylint: disable=no-member, attribute-defined-outside-init
def __call__(self, *args, **kwargs):
assert isinstance(self, qml.QNode)

config = device_get_toml_config(self.device)
validate_config_with_device(self.device, config)
backend_info = QFunc.extract_backend_info(self.device, config)
device = self.device
program_features = ProgramFeatures(device.shots is not None)
device_capabilities = get_device_capabilities(device, program_features)
backend_info = QFunc.extract_backend_info(device, device_capabilities)

# Validate decive operations against the declared capabilities
validate_device_capabilities(device, device_capabilities)
dime10 marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(self.device, qml.devices.Device):
device = QJITDeviceNewAPI(self.device, backend_info)
self.qjit_device = QJITDeviceNewAPI(device, device_capabilities, backend_info)
else:
device = QJITDevice(config, self.device.shots, self.device.wires, backend_info)
self.qjit_device = QJITDevice(
device_capabilities, device.shots, device.wires, backend_info
)

def _eval_quantum(*args):
closed_jaxpr, out_type, out_tree = trace_quantum_function(
self.func, device, args, kwargs, qnode=self
self.func, self.qjit_device, args, kwargs, qnode=self
)
args_expanded = get_implicit_and_explicit_flat_args(None, *args)
res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded)
Expand Down
67 changes: 8 additions & 59 deletions frontend/catalyst/utils/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,13 @@

import pennylane as qml

from catalyst._configuration import INSTALLED
from catalyst.utils.exceptions import CompileError
from catalyst.utils.toml import (
ProgramFeatures,
TOMLDocument,
get_device_capabilities,
DeviceCapabilities,
get_lib_path,
pennylane_operation_set,
read_toml_file,
)

package_root = os.path.dirname(__file__)


# Default paths to dep libraries
DEFAULT_LIB_PATHS = {
"llvm": os.path.join(package_root, "../../../mlir/llvm-project/build/lib"),
"runtime": os.path.join(package_root, "../../../runtime/build/lib"),
"enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"),
"oqc_runtime": os.path.join(package_root, "../../catalyst/oqc/src/build"),
}


# TODO: This should be removed after implementing `get_c_interface`
# for the following backend devices:
SUPPORTED_RT_DEVICES = {
Expand All @@ -58,13 +43,6 @@
}


def get_lib_path(project, env_var):
"""Get the library path."""
if INSTALLED:
return os.path.join(package_root, "..", "lib") # pragma: no cover
return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, ""))


def check_no_overlap(*args, device_name):
"""Check items in *args are mutually exclusive.

Expand Down Expand Up @@ -109,7 +87,9 @@ def is_not_adj(op):
return set(operations_no_adj)


def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) -> None:
def validate_device_capabilities(
device: qml.QubitDevice, device_capabilities: DeviceCapabilities
) -> None:
"""Validate configuration document against the device attributes.
Raise CompileError in case of mismatch:
* If device is not qjit-compatible.
Expand All @@ -125,15 +105,13 @@ def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) -
Raises: CompileError
"""

if not config["compilation"]["qjit_compatible"]:
if not device_capabilities.qjit_compatible_flag:
raise CompileError(
f"Attempting to compile program for incompatible device '{device.name}': "
f"Config is not marked as qjit-compatible"
)

device_name = device.short_name if isinstance(device, qml.Device) else device.name
program_features = ProgramFeatures(device.shots is not None)
device_capabilities = get_device_capabilities(config, program_features, device_name)

native = pennylane_operation_set(device_capabilities.native_ops)
decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops)
Expand Down Expand Up @@ -163,34 +141,6 @@ def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) -
)


def device_get_toml_config(device) -> TOMLDocument:
"""Get the contents of the device config file."""
if hasattr(device, "config"):
# The expected case: device specifies its own config.
toml_file = device.config
else:
# TODO: Remove this section when `qml.Device`s are guaranteed to have their own config file
# field.
device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR"))

name = device.short_name if isinstance(device, qml.Device) else device.name
# The toml files name convention we follow is to replace
# the dots with underscores in the device short name.
toml_file_name = name.replace(".", "_") + ".toml"
# And they are currently saved in the following directory.
toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name

try:
config = read_toml_file(toml_file)
except FileNotFoundError as e:
raise CompileError(
"Attempting to compile program for incompatible device: "
f"Config file ({toml_file}) does not exist"
) from e

return config


@dataclass
class BackendInfo:
"""Backend information"""
Expand All @@ -201,7 +151,7 @@ class BackendInfo:
kwargs: Dict[str, Any]


def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> BackendInfo:
def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabilities) -> BackendInfo:
"""Extract the backend info from a quantum device. The device is expected to carry a reference
to a valid TOML config file."""

Expand Down Expand Up @@ -255,8 +205,7 @@ def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> Backe
device._s3_folder # pylint: disable=protected-access
)

options = config.get("options", {})
for k, v in options.items():
for k, v in capabilities.options.items():
if hasattr(device, v):
device_kwargs[k] = getattr(device, v)

Expand Down
Loading
Loading