Skip to content

Commit

Permalink
passes -> pipelines and documentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
erick-xanadu committed Apr 13, 2023
1 parent c863745 commit 96f53eb
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 29 deletions.
24 changes: 18 additions & 6 deletions frontend/catalyst/compilation_pipelines.py
Expand Up @@ -432,17 +432,20 @@ class QJIT:
after the Python process ends. If ``False``, these representations
will instead be stored in a temporary folder, which will be deleted
as soon as the QJIT instance is deleted.
pipelines (Optional(List[AnyType]): A list of pipelines to be executed. The elements of
the list are asked to implement a run method which takes the output of the previous run
as an input to the next element, and so on.
compile_options (Optional[CompileOptions]): Common compilation options
"""

# pylint: disable=too-many-arguments
def __init__(self, fn, target, keep_intermediate, passes, compile_options=None):
def __init__(self, fn, target, keep_intermediate, pipelines, compile_options=None):
self.qfunc = fn
self.c_sig = None
functools.update_wrapper(self, fn)
self.compile_options = compile_options
self._compiler = Compiler()
self.passes = passes
self.pipelines = pipelines
self._jaxpr = None
self._mlir = None
self._llvmir = None
Expand Down Expand Up @@ -518,7 +521,7 @@ def compile(self):
shared_object = self._compiler.run(
self.mlir_module,
keep_intermediate=self.keep_intermediate,
passes=self.passes,
pipelines=self.pipelines,
options=self.compile_options,
)
self._llvmir = self._compiler.get_output_of("LLVMDialectToLLVMIR")
Expand Down Expand Up @@ -559,7 +562,13 @@ def __call__(self, *args, **kwargs):


def qjit(
fn=None, *, target="binary", keep_intermediate=False, verbose=False, logfile=None, passes=None
fn=None,
*,
target="binary",
keep_intermediate=False,
verbose=False,
logfile=None,
pipelines=None,
):
"""A just-in-time decorator for PennyLane and JAX programs using Catalyst.
Expand All @@ -582,6 +591,9 @@ def qjit(
verbosity (int): Verbosity level (0 - disabled, >0 - enabled)
logfile (Optional[TextIOWrapper]): File object to write verose messages to (default -
sys.stderr).
pipelines (Optional(List[AnyType]): A list of pipelines to be executed. The elements of
the list are asked to implement a run method which takes the output of the previous run
as an input to the next element, and so on.
Returns:
QJIT object.
Expand Down Expand Up @@ -648,9 +660,9 @@ def circuit(x: complex, z: ShapedArray(shape=(3,), dtype=jnp.float64)):
"""

if fn is not None:
return QJIT(fn, target, keep_intermediate, passes, CompileOptions(verbose, logfile))
return QJIT(fn, target, keep_intermediate, pipelines, CompileOptions(verbose, logfile))

def wrap_fn(fn):
return QJIT(fn, target, keep_intermediate, passes, CompileOptions(verbose, logfile))
return QJIT(fn, target, keep_intermediate, pipelines, CompileOptions(verbose, logfile))

return wrap_fn
38 changes: 19 additions & 19 deletions frontend/catalyst/compiler.py
Expand Up @@ -407,7 +407,7 @@ def __init__(self):
# pylint: disable=consider-using-with
self.workspace = tempfile.TemporaryDirectory()

def run(self, mlir_module, keep_intermediate=False, passes=None, options=None):
def run(self, mlir_module, keep_intermediate=False, pipelines=None, options=None):
"""Compile an MLIR module to a shared object.
.. note::
Expand All @@ -417,7 +417,7 @@ def run(self, mlir_module, keep_intermediate=False, passes=None, options=None):
Args:
mlir_module (Module): the MLIR module
passes (List[Any]): the list of compilation passes
pipelines (List[Any]): the list of compilation pipelines
Returns:
Shared object
Expand All @@ -436,8 +436,8 @@ def run(self, mlir_module, keep_intermediate=False, passes=None, options=None):
else:
workspace_name = self.workspace.name

if passes is None:
passes = [
if pipelines is None:
pipelines = [
MHLOPass,
QuantumCompilationPass,
BufferizationPass,
Expand All @@ -453,41 +453,41 @@ def run(self, mlir_module, keep_intermediate=False, passes=None, options=None):
with open(filename, "w", encoding="utf-8") as f:
mlir_module.operation.print(f, print_generic_op_form=False, assume_verified=True)

for _pass in passes:
output = _pass.run(filename, options=options)
self.pass_pipeline_output[_pass] = output
for pipeline in pipelines:
output = pipeline.run(filename, options=options)
self.pass_pipeline_output[pipeline] = output
filename = os.path.abspath(output)

return filename

@staticmethod
def _get_class_from_string(_pass):
return getattr(sys.modules[__name__], _pass)
def _get_class_from_string(pipeline):
return getattr(sys.modules[__name__], pipeline)

def _get_output_file_of(self, _pass):
cls = Compiler._get_class_from_string(_pass)
def _get_output_file_of(self, pipeline):
cls = Compiler._get_class_from_string(pipeline)
return self.pass_pipeline_output.get(cls)

def get_output_of(self, _pass):
"""Get the output IR of a pass.
def get_output_of(self, pipeline):
"""Get the output IR of a pipeline.
Args:
_pass (str): name of pass class
pipeline (str): name of pass class
Returns
(str): output IR
"""
try:
fname = self._get_output_file_of(_pass)
fname = self._get_output_file_of(pipeline)
except AttributeError as e:
raise ValueError(f"Output for pass {_pass} not found.") from e
raise ValueError(f"Output for pass {pipeline} not found.") from e
with open(fname, "r", encoding="utf-8") as f:
txt = f.read()
return txt

def print(self, _pass):
def print(self, pipeline):
"""Print the output IR of pass.
Args:
_pass (str): name of pass class
pipeline (str): name of pass class
"""
txt = self.get_output_of(_pass)
txt = self.get_output_of(pipeline)
print(txt)
9 changes: 5 additions & 4 deletions frontend/test/pytest/test_compiler.py
Expand Up @@ -81,6 +81,7 @@ def test_no_executable(self):
"""Test that executable was set from a custom PassPipeline."""

class CustomClassWithNoExecutable(PassPipeline):
# pylint: disable=missing-class-docstring
_default_flags = ["some-command-but-it-is-actually-a-flag"]

with pytest.raises(ValueError, match="Executable not specified."):
Expand Down Expand Up @@ -165,9 +166,9 @@ def workflow():

mlir_module, _, _ = get_mlir(workflow)
# This means that we are not running any pass.
passes = []
pipelines = []
identity_compiler = Compiler()
identity_compiler.run(mlir_module, keep_intermediate=True, passes=passes)
identity_compiler.run(mlir_module, keep_intermediate=True, pipelines=pipelines)
directory = os.path.join(os.getcwd(), workflow.__name__)
assert os.path.exists(directory)
files = os.listdir(directory)
Expand All @@ -185,9 +186,9 @@ def workflow():

mlir_module, _, _ = get_mlir(workflow)
# This means that we are not running any pass.
passes = []
pipelines = []
identity_compiler = Compiler()
identity_compiler.run(mlir_module, passes=passes)
identity_compiler.run(mlir_module, pipelines=pipelines)
files = os.listdir(identity_compiler.workspace.name)
# The directory is non-empty. Should at least contain the original .mlir file
assert files

0 comments on commit 96f53eb

Please sign in to comment.