Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Refactor compiler #38

Merged
merged 55 commits into from Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
4d04b68
Add MHLOPass.
erick-xanadu Mar 3, 2023
2d4ce70
Add BufferizationPass.
erick-xanadu Mar 3, 2023
17ebfee
Add MlirToLLVMDialect.
erick-xanadu Mar 3, 2023
471af80
Add LLVMDialectToLLVMIR.
erick-xanadu Mar 3, 2023
327ba73
Add LLVMIRToObjectFile.
erick-xanadu Mar 3, 2023
dba0244
Adds abstract class Pass.
erick-xanadu Mar 6, 2023
5904fd4
Use run instead of link.
erick-xanadu Mar 6, 2023
834bcb4
Refactor compiler.py
erick-xanadu Mar 6, 2023
b651de5
Move compile time signature to utilities.
erick-xanadu Mar 6, 2023
b262425
Move keep_intermediate logic to Compiler class.
erick-xanadu Mar 6, 2023
400db9d
Move appending helper functions.
erick-xanadu Mar 6, 2023
851119a
Allow passes being sent from qjit decorator.
erick-xanadu Mar 14, 2023
2e032ad
Allow returning None for passes not run.
erick-xanadu Mar 14, 2023
92735d2
Linting.
erick-xanadu Mar 14, 2023
13b3d34
Remove redundant functions.
erick-xanadu Apr 4, 2023
f9f8845
Small changes.
erick-xanadu Apr 4, 2023
ce78c2e
Some refactoring.
erick-xanadu Apr 4, 2023
fd6c3dc
Improving tests.
erick-xanadu Apr 5, 2023
41df848
Remove type annotation.
erick-xanadu Apr 13, 2023
1cf0e80
Change variable name.
erick-xanadu Apr 13, 2023
a73f5ab
Rename class.
erick-xanadu Apr 13, 2023
a4b508a
Test that an executable was specified.
erick-xanadu Apr 13, 2023
e203036
Change docstring.
erick-xanadu Apr 13, 2023
32ec9f9
Remove pylint disable pragma.
erick-xanadu Apr 13, 2023
8380092
Refactor function.
erick-xanadu Apr 13, 2023
9705df5
Change variable name.
erick-xanadu Apr 13, 2023
726f4c1
Remove pylint pragma.
erick-xanadu Apr 13, 2023
9d1f131
Default flags are computed before linking.
erick-xanadu Apr 13, 2023
8e994c9
Add type annotations.
erick-xanadu Apr 13, 2023
0841faa
Remove l from lrt.
erick-xanadu Apr 13, 2023
9df0366
Move utils to compilation_pipeline.py
erick-xanadu Apr 13, 2023
4333779
Add type annotations.
erick-xanadu Apr 13, 2023
9d597de
Remove pylint pragma.
erick-xanadu Apr 13, 2023
c863745
Move append_modules.
erick-xanadu Apr 13, 2023
f46294e
passes -> pipelines and documentation.
erick-xanadu Apr 13, 2023
9519697
Merge branch 'main' into eochoa/2022-03-07/compiler-refactor-2
erick-xanadu Apr 13, 2023
ac74e1c
Changelog.
erick-xanadu Apr 13, 2023
7919ab3
Update frontend/catalyst/compiler.py
erick-xanadu Apr 14, 2023
00eabe7
Docstring.
erick-xanadu Apr 14, 2023
d03637d
Add new options to CompilerOptions.
erick-xanadu Apr 14, 2023
581a843
Fix link error.
erick-xanadu Apr 14, 2023
755ad05
Modify compiler diagram.
erick-xanadu Apr 14, 2023
d92dcba
Nice error for float16.
erick-xanadu Apr 14, 2023
a39568c
Apply suggestions from code review
erick-xanadu Apr 14, 2023
6c94b2c
Fix docstring.
erick-xanadu Apr 14, 2023
72600f3
Rename function.
erick-xanadu Apr 14, 2023
26df791
Remove get_logfile function.
erick-xanadu Apr 14, 2023
40a3e05
No cover.
erick-xanadu Apr 14, 2023
0ca30d5
CompileOptions are required.
erick-xanadu Apr 14, 2023
851c7b6
Pylint.
erick-xanadu Apr 14, 2023
18a3658
Partial coverage.
erick-xanadu Apr 14, 2023
48798ee
Revert "CompileOptions are required."
erick-xanadu Apr 14, 2023
c0a0d7c
Fix tests.
erick-xanadu Apr 14, 2023
4e47da5
CodeFactor.
erick-xanadu Apr 14, 2023
f17b0a1
Merge branch 'main' into eochoa/2023-03-07/compiler-refactor-2
erick-xanadu Apr 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
763 changes: 1 addition & 762 deletions doc/_static/arch/compiler.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions doc/changelog.md
Expand Up @@ -4,6 +4,10 @@

<h3>Improvements</h3>

* Improve interface for adding and re-using flags to quantum-opt commands.
These are called pipelines, as they contain multiple passes.
[#38](https://github.com/PennyLaneAI/catalyst/pull/38)

<h3>Breaking changes</h3>

<h3>Bug fixes</h3>
Expand Down
10 changes: 7 additions & 3 deletions doc/dev/arch/compiler.puml
Expand Up @@ -15,6 +15,8 @@ Boundary(c1, "Compiler", $link="https://github.com/plantuml-stdlib/C4-PlantUML")

Boundary(c2, " ", $link="https://github.com/plantuml-stdlib/C4-PlantUML") {

Component(_opt_mlir_, "Gradient-to-Quantum", "quantum-opt", "Lower Gradient dialect to Quantum dialect..")

Component(_buff_mlir_, "Bufferization", "quantum-opt", "Bufferize scf, tensor, arith, quantum, etc.")

Component(_llvm_mlir_, "LLVM-MLIR", "quantum-opt", "Lower the code to the LLVM Dialect in MLIR")
Expand All @@ -27,7 +29,7 @@ Boundary(c1, "Compiler", $link="https://github.com/plantuml-stdlib/C4-PlantUML")

Component(_o_, "Compile", "llc", "Compile the LLVM code using the LLVM Static Compiler")

Component(_so_, "Link", "c99", "Link the compiled file to the Runtime")
Component(_so_, "Link", "Compiler available in system", "Link the compiled file to the Runtime")
}
}

Expand All @@ -36,9 +38,11 @@ System_Ext(_catalyst_runtime_, "Catalyst-Runtime")

Rel_R(_mlir_sys_, _nohlo_mlir_, "*.mlir", "")

Rel_R(_nohlo_mlir_, _buff_mlir_, "*.nohlo.mlir", "")
Rel_R(_nohlo_mlir_, _opt_mlir_, "*.nohlo.mlir", "")

Rel_D(_opt_mlir_, _buff_mlir_, "*.opt.mlir", "")

Rel(_buff_mlir_, _llvm_mlir_, "*.buff.mlir", "")
Rel_D(_buff_mlir_, _llvm_mlir_, "*.buff.mlir", "")

Rel_L(_llvm_mlir_, _ll_, "*.llvm.mlir", "")

Expand Down
3 changes: 2 additions & 1 deletion doc/dev/architecture.rst
Expand Up @@ -40,7 +40,8 @@ Compiler Workflow
^^^^^^^^^^^^^^^^^

To understand the workflow and tools being used in the Compiler class, we
present the following diagram.
present the following diagram. Please note that individual passes are selected
at runtime and may be configured by the user, but this is the default pipeline.

|br|

Expand Down
162 changes: 68 additions & 94 deletions frontend/catalyst/compilation_pipelines.py
dime10 marked this conversation as resolved.
Show resolved Hide resolved
Expand Up @@ -15,15 +15,11 @@
compiling of hybrid quantum-classical functions using Catalyst.
"""

# pylint: disable=missing-module-docstring

import ctypes
import functools
import warnings
import inspect
import os
import tempfile
import typing
import warnings

import jax
import numpy as np
Expand All @@ -35,11 +31,11 @@
to_numpy,
)

from catalyst.utils.gen_mlir import append_modules
import catalyst.jax_tracer as tracer
from catalyst import compiler
from catalyst.compiler import Compiler
from catalyst.compiler import CompileOptions
from catalyst.pennylane_extensions import QFunc
from catalyst.utils.gen_mlir import append_modules
from catalyst.utils.patching import Patcher
from catalyst.utils.tracing import TracingContext

Expand All @@ -52,7 +48,22 @@
jax.config.update("jax_array", True)


# pylint: disable=too-many-return-statements
def are_params_annotated(f: typing.Callable):
"""Return true if all parameters are typed-annotated."""
signature = inspect.signature(f)
parameters = signature.parameters
return all(p.annotation is not inspect.Parameter.empty for p in parameters.values())


def get_type_annotations(func: typing.Callable):
"""Get all type annotations if all parameters are typed-annotated."""
params_are_annotated = are_params_annotated(func)
if params_are_annotated:
return getattr(func, "__annotations__", {}).values()

return None


def mlir_type_to_numpy_type(t):
"""Convert an MLIR type to a Numpy type.

Expand All @@ -63,30 +74,33 @@
Raises:
TypeError
"""
retval = None
if ir.ComplexType.isinstance(t):
base = ir.ComplexType(t).element_type
if ir.F64Type.isinstance(base):
return np.complex128
retval = np.complex128
if ir.F32Type.isinstance(base):
return np.complex64
retval = np.complex64
elif ir.F64Type.isinstance(t):
retval = np.float64
elif ir.F32Type.isinstance(t):
retval = np.float32

Check warning on line 87 in frontend/catalyst/compilation_pipelines.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/compilation_pipelines.py#L87

Added line #L87 was not covered by tests
elif ir.F16Type.isinstance(t):
retval = np.float16

Check warning on line 89 in frontend/catalyst/compilation_pipelines.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/compilation_pipelines.py#L89

Added line #L89 was not covered by tests
elif ir.IntegerType(t).width == 1:
retval = np.bool_
elif ir.IntegerType(t).width == 8:
retval = np.int8

Check warning on line 93 in frontend/catalyst/compilation_pipelines.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/compilation_pipelines.py#L93

Added line #L93 was not covered by tests
elif ir.IntegerType(t).width == 16:
retval = np.int16

Check warning on line 95 in frontend/catalyst/compilation_pipelines.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/compilation_pipelines.py#L95

Added line #L95 was not covered by tests
elif ir.IntegerType(t).width == 32:
retval = np.int32

Check warning on line 97 in frontend/catalyst/compilation_pipelines.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/compilation_pipelines.py#L97

Added line #L97 was not covered by tests
elif ir.IntegerType(t).width == 64:
retval = np.int64

if retval is None:
raise TypeError("No known type")
if ir.F64Type.isinstance(t):
return np.float64
if ir.F32Type.isinstance(t):
return np.float32
if ir.F16Type.isinstance(t):
return np.float16
if ir.IntegerType(t).width == 1:
return np.bool_
if ir.IntegerType(t).width == 8:
return np.int8
if ir.IntegerType(t).width == 16:
return np.int16
if ir.IntegerType(t).width == 32:
return np.int32
if ir.IntegerType(t).width == 64:
return np.int64
raise TypeError("No known type")
return retval


class CompiledFunction:
Expand Down Expand Up @@ -217,37 +231,6 @@
arg_type = type(arg)
raise TypeError(f"Unsupported argument type: {arg_type}") from exc

@staticmethod
def are_all_signature_params_annotated(f: typing.Callable):
"""Determine if all parameters are typed.

Args:
f: callable, with possible annotation
Returns:
bool: whether all parameters are annotated
"""
signature = inspect.signature(f)
parameters = signature.parameters
return all(p.annotation is not inspect.Parameter.empty for p in parameters.values())

@staticmethod
def get_compile_time_signature(f: typing.Callable) -> typing.List[typing.Any]:
"""Get signature from parameter annotations.

Args:
f: callable, with possible annotations
Returns:
annotations for all parameters if possible

"""
can_validate = CompiledFunction.are_all_signature_params_annotated(f)

if can_validate:
# Needed instead of inspect.get_annotations for Python < 3.10.
return getattr(f, "__annotations__", {}).values()

return None

@staticmethod
def zero_ranked_memref_to_numpy(ranked_memref):
"""Cast a zero ranked memrefs to a numpy array.
Expand Down Expand Up @@ -442,48 +425,29 @@

Args:
fn (Callable): the quantum or classical function
target (str): the compilation target
keep_intermediate (bool): Whether or not to store the intermediate files throughout the
compilation. If ``True``, the current working directory keeps
readable representations of the compiled module which remain available
after the Python process ends. If ``False``, these representations
will instead be stored in a temporary folder, which will be deleted
as soon as the QJIT instance is deleted.
compile_options (Optional[CompileOptions]): Common compilation options
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, fn, target, keep_intermediate, compile_options=None):
# pylint: disable=too-many-arguments
def __init__(self, fn, compile_options):
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
self.qfunc = fn
self.c_sig = None
functools.update_wrapper(self, fn)
if keep_intermediate:
dirname = fn.__name__
parent_dir = os.getcwd()
path = os.path.join(parent_dir, dirname)
os.makedirs(path, exist_ok=True)
self.workspace_name = path
else:
# The temporary directory must be referenced by the wrapper class
# in order to avoid being garbage collected
# pylint: disable=consider-using-with
self.workspace = tempfile.TemporaryDirectory()
self.workspace_name = self.workspace.name
self.compile_options = compile_options
self.passes = {}
self._compiler = Compiler()
self._jaxpr = None
self._mlir = None
self._llvmir = None
self.mlir_module = None
self.compiled_function = None
parameter_types = get_type_annotations(self.qfunc)
self.runtime = fn.device.short_name if isinstance(fn, qml.QNode) else "best"

parameter_types = CompiledFunction.get_compile_time_signature(self.qfunc)
self.user_typed = False
if parameter_types is not None:
self.user_typed = True
if target in ("mlir", "binary"):
if self.compile_options.target in ("mlir", "binary"):
self.mlir_module = self.get_mlir(*parameter_types)
if target == "binary":
if self.compile_options.target == "binary":
self.compiled_function = self.compile()

def print_stage(self, stage):
Expand All @@ -493,9 +457,7 @@
Args:
stage: string corresponding with the name of the stage to be printed
"""
if self.passes.get(stage):
with open(self.passes[stage], "r", encoding="utf-8") as f:
print(f.read())
self._compiler.print(stage) # pragma: nocover

@property
def mlir(self):
Expand Down Expand Up @@ -534,21 +496,22 @@
):
mlir_module, ctx, jaxpr = tracer.get_mlir(self.qfunc, *self.c_sig)

# Inject setup and finalize functions.
append_modules(mlir_module, self.runtime, ctx)
mod = mlir_module.operation
self._jaxpr = jaxpr
self._mlir = mod.get_asm(binary=False, print_generic_op_form=False, assume_verified=True)
dime10 marked this conversation as resolved.
Show resolved Hide resolved

# Inject setup and finalize functions.
append_modules(mlir_module, self.runtime, ctx)

return mlir_module

def compile(self):
"""Compile the current MLIR module."""

shared_object, self._llvmir = compiler.compile(
self.mlir_module, self.workspace_name, self.passes, self.compile_options
shared_object = self._compiler.run(
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
self.mlir_module,
options=self.compile_options,
)
self._llvmir = self._compiler.get_output_of("LLVMDialectToLLVMIR")

# The function name out of MLIR has quotes around it, which we need to remove.
# The MLIR function name is actually a derived type from string which has no
Expand Down Expand Up @@ -585,7 +548,15 @@
return self.compiled_function(*args, **kwargs)


def qjit(fn=None, *, target="binary", keep_intermediate=False, verbose=False, logfile=None):
def qjit(
fn=None,
*,
target="binary",
keep_intermediate=False,
verbose=False,
logfile=None,
pipelines=None,
):
"""A just-in-time decorator for PennyLane and JAX programs using Catalyst.

This decorator enables both just-in-time and ahead-of-time compilation,
Expand All @@ -607,6 +578,9 @@
verbosity (int): Verbosity level (0 - disabled, >0 - enabled)
logfile (Optional[TextIOWrapper]): File object to write verose messages to (default -
sys.stderr).
pipelines (Optional(List[AnyType]): A list of pipelines to be executed. The elements of
the list are asked to implement a run method which takes the output of the previous run
as an input to the next element, and so on.

Returns:
QJIT object.
Expand Down Expand Up @@ -673,9 +647,9 @@
"""

if fn is not None:
return QJIT(fn, target, keep_intermediate, CompileOptions(verbose, logfile))
return QJIT(fn, CompileOptions(verbose, logfile, target, keep_intermediate, pipelines))

def wrap_fn(fn):
return QJIT(fn, target, keep_intermediate, CompileOptions(verbose, logfile))
return QJIT(fn, CompileOptions(verbose, logfile, target, keep_intermediate, pipelines))

return wrap_fn