Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callbacks with parameters. #596

Merged
merged 117 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
117 commits
Select commit Hold shift + click to select a range
32e115a
Add initial registry.
erick-xanadu Nov 29, 2023
ca19af6
Add example python module.
erick-xanadu Nov 29, 2023
b2a37b6
Add build commands.
erick-xanadu Nov 29, 2023
7f11485
More general.
erick-xanadu Nov 29, 2023
91b7294
Simple example of registering a function.
erick-xanadu Nov 29, 2023
846e217
Get primitives and fail to do the lowering.
erick-xanadu Nov 29, 2023
1fe10e4
wip
erick-xanadu Nov 30, 2023
68b47e7
Comment.
erick-xanadu Nov 30, 2023
cc41be0
Lower callback to CustomCallOp.
erick-xanadu Nov 30, 2023
345620c
Proof of concept.
erick-xanadu Dec 1, 2023
0850ef9
Fix merge.
erick-xanadu Feb 23, 2024
0e9fe62
Add pybind11 to build time dependencies for runtime.
erick-xanadu Feb 23, 2024
191aba1
Add build time dependencies for runtime tests.
erick-xanadu Feb 23, 2024
ae16cfe
Add build time dependencies for runtime coverage
erick-xanadu Feb 23, 2024
1e8bf35
Move pyregistry definition into lib/capi/RuntimeCapi.cpp
erick-xanadu Feb 23, 2024
02387d3
WIP
erick-xanadu Feb 23, 2024
b2e3c79
WIP
erick-xanadu Mar 5, 2024
88a264e
Depend on pyregistry.
erick-xanadu Mar 5, 2024
ad096cd
WIP
erick-xanadu Mar 5, 2024
845f380
Simple leaky implementation
erick-xanadu Mar 5, 2024
fc676d8
Add test
erick-xanadu Mar 5, 2024
845d18c
Rpath is correctly set.
erick-xanadu Mar 5, 2024
f22f9ec
style
erick-xanadu Mar 6, 2024
c054c27
dlopen for any build configuration
erick-xanadu Mar 6, 2024
83e9be0
Bake in the library name.
erick-xanadu Mar 6, 2024
78f53fb
Remove makefile rule
erick-xanadu Mar 6, 2024
7ab87cb
Adding registry into wheel
erick-xanadu Mar 7, 2024
cb85ceb
Black
erick-xanadu Mar 7, 2024
4f4d25e
Black
erick-xanadu Mar 7, 2024
effe085
Style
erick-xanadu Mar 7, 2024
0e75cdf
Style
erick-xanadu Mar 7, 2024
e9c17a0
Isort
erick-xanadu Mar 7, 2024
ef49931
Code factor
erick-xanadu Mar 7, 2024
4a1d33e
CodeFactor
erick-xanadu Mar 7, 2024
b85ff04
Black
erick-xanadu Mar 7, 2024
ca3620e
Style
erick-xanadu Mar 7, 2024
f56ed7b
Isort
erick-xanadu Mar 7, 2024
91c271c
Pylint
erick-xanadu Mar 7, 2024
e5e406c
Coverage
erick-xanadu Mar 7, 2024
753a100
Add PyCustomCall
erick-xanadu Mar 7, 2024
6e73659
Initial skeleton for MLIR pass.
erick-xanadu Mar 7, 2024
66240da
Skeleton
erick-xanadu Mar 7, 2024
1c7fd0b
Revert "Skeleton"
erick-xanadu Mar 7, 2024
62deee0
Revert "Initial skeleton for MLIR pass."
erick-xanadu Mar 7, 2024
b63cc22
Create PythonCallOp and lowering.
erick-xanadu Mar 7, 2024
31232d0
Creates a new op for python callbacks.
erick-xanadu Mar 7, 2024
ccae140
Style
erick-xanadu Mar 7, 2024
f6a1ee7
Codefactor
erick-xanadu Mar 7, 2024
18b2ccf
Add error handling for pyregistry
erick-xanadu Mar 7, 2024
74fa4b1
fixup runtime
erick-xanadu Mar 7, 2024
25c1995
Fix condition.
erick-xanadu Mar 7, 2024
554d2a8
Use a wrapper to improve UI/UX
erick-xanadu Mar 7, 2024
b7cac32
Code factor
erick-xanadu Mar 8, 2024
bbebe8e
Typo
erick-xanadu Mar 8, 2024
0a22ad1
Move include.
erick-xanadu Mar 11, 2024
50c3d26
Comment
erick-xanadu Mar 11, 2024
953377f
No output during compilation.
erick-xanadu Mar 11, 2024
bf79111
CMake
erick-xanadu Mar 11, 2024
0e1ed63
Makefile
erick-xanadu Mar 11, 2024
e3c81e8
Test
erick-xanadu Mar 11, 2024
e799350
Remove include add comment
erick-xanadu Mar 13, 2024
a650111
Test
erick-xanadu Mar 13, 2024
f6027fb
line length
erick-xanadu Mar 13, 2024
3f0cc15
Test
erick-xanadu Mar 13, 2024
7b17177
Change registry to catalyst_callback_registry.
erick-xanadu Mar 13, 2024
560a83d
Rename
erick-xanadu Mar 13, 2024
9cdfa7b
Add changelog
erick-xanadu Mar 18, 2024
a757186
WIP
erick-xanadu Mar 18, 2024
764f1a2
Evaluate the function if not in tracing context.
erick-xanadu Mar 18, 2024
885f8b8
Merge branch 'main' into eochoa/2024-02-23/callbacks
erick-xanadu Mar 18, 2024
83c745b
black
erick-xanadu Mar 18, 2024
2ac35fc
Isort
erick-xanadu Mar 18, 2024
ef6177f
Add mutex for python
erick-xanadu Mar 18, 2024
9b6f441
Comments
erick-xanadu Mar 18, 2024
3a90bba
Comments
erick-xanadu Mar 18, 2024
1f3bd6a
Merge branch 'main' into eochoa/2024-02-23/callbacks
erick-xanadu Mar 19, 2024
683b608
Fix python lock
erick-xanadu Mar 19, 2024
6114629
Add parameters to callbacks
erick-xanadu Mar 8, 2024
898978e
Add parameters in MLIR
erick-xanadu Mar 8, 2024
57b1266
Add argc attribute
erick-xanadu Mar 8, 2024
ccf645d
Custom builder
erick-xanadu Mar 8, 2024
22432bb
Custom builder
erick-xanadu Mar 8, 2024
9422c99
Stub for runtime...
erick-xanadu Mar 8, 2024
2df0b1f
Typo
erick-xanadu Mar 8, 2024
f2bb274
Add arguments to python call
erick-xanadu Mar 11, 2024
f946b41
WIP
erick-xanadu Mar 12, 2024
7367abc
Add support for translating ShapedArray's to ctypes unintialized values
erick-xanadu Mar 12, 2024
cbdadb2
Proof of concept
erick-xanadu Mar 12, 2024
29e028c
Typo
erick-xanadu Mar 12, 2024
81e2344
Documentation
erick-xanadu Mar 13, 2024
4663b06
Test kwargs
erick-xanadu Mar 13, 2024
4e95be9
isort
erick-xanadu Mar 13, 2024
4aecd50
Isort
erick-xanadu Mar 13, 2024
ee4384b
Do not keep intermediate
erick-xanadu Mar 13, 2024
48357e4
Unnecessary import.
erick-xanadu Mar 13, 2024
fe80bf3
Delete dead code
erick-xanadu Mar 13, 2024
f0a8959
Additional test
erick-xanadu Mar 13, 2024
da762b2
f
erick-xanadu Mar 13, 2024
2a793cb
Comments
erick-xanadu Mar 13, 2024
f548acd
Remove number_original_arg
erick-xanadu Mar 13, 2024
79197a1
Dead code
erick-xanadu Mar 13, 2024
871d455
Mock np_to_memref
erick-xanadu Mar 13, 2024
049fde7
Empty
erick-xanadu Mar 15, 2024
8cf8b3a
Fixing rebase
erick-xanadu Mar 19, 2024
3c9f6f2
Fix changelog
erick-xanadu Mar 20, 2024
f4ca039
Merge branch 'main' into eochoa/2024-03-08/callbacks-parameters
erick-xanadu Mar 20, 2024
9b5552d
Code factor
erick-xanadu Mar 20, 2024
3a5965b
Typo and maxsize
erick-xanadu Mar 20, 2024
0d3b7cf
Use cached property
erick-xanadu Mar 20, 2024
9db215e
Fix
erick-xanadu Mar 20, 2024
5ff9a7c
Fix
erick-xanadu Mar 20, 2024
65a1c44
Do not use cache
erick-xanadu Mar 20, 2024
f416fc1
Update doc/changelog.md
erick-xanadu Mar 20, 2024
0948fbd
Rename function
erick-xanadu Mar 22, 2024
0bc1a61
Comments
erick-xanadu Mar 22, 2024
3034a3a
Constrain type
erick-xanadu Mar 22, 2024
bed566f
Comments
erick-xanadu Mar 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,28 @@

* Support for callbacks in Catalyst.
[(#540)](https://github.com/PennyLaneAI/catalyst/pull/540)
[(#596)](https://github.com/PennyLaneAI/catalyst/pull/596)

Catalyst now supports callbacks without parameters nor return values.
Catalyst now supports callbacks with parameters but no return values.
This is the very first step in supporting callbacks.
The following is now possible.
The following is now possible:
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved

```py
@callback
def foo():
print("Hello world")
def foo(val):
print("Hello world", val)

@qjit
def circuit(*args, **kwargs):
...
foo()
foo(123)
...

```

```pycon
>>> circuit()
Hello world 123
```

* The python layer of the OQC-Catalyst device is now available.
Expand Down
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __getattr__(cls, name):
MOCK_MODULES = [
"mlir_quantum",
"mlir_quantum.runtime",
"mlir_quantum.runtime.np_to_memref",
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
"mlir_quantum.dialects",
"mlir_quantum.dialects.arith",
"mlir_quantum.dialects.tensor",
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/compiled_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from jax.tree_util import PyTreeDef, tree_flatten, tree_unflatten
from mlir_quantum.runtime import (
as_ctype,
get_ranked_memref_descriptor,
make_nd_memref_descriptor,
make_zero_d_memref_descriptor,
)
Expand All @@ -38,6 +37,7 @@
from catalyst.utils import wrapper # pylint: disable=no-name-in-module
from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type
from catalyst.utils.filesystem import Directory
from catalyst.utils.jnp_to_memref import get_ranked_memref_descriptor


class SharedObjectManager:
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _python_callback_lowering(jax_ctx: mlir.LoweringRuleContext, *args, callback
ctx = jax_ctx.module_context.context
i64_type = ir.IntegerType.get_signless(64, ctx)
identifier = ir.IntegerAttr.get(i64_type, callback_id)
return PythonCallOp(identifier).results
return PythonCallOp(args, identifier).results


#
Expand Down
60 changes: 53 additions & 7 deletions frontend/catalyst/pennylane_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=too-many-lines

import copy
import ctypes
import inspect
import numbers
from collections.abc import Sequence, Sized
Expand Down Expand Up @@ -95,6 +96,10 @@
JaxTracingContext,
)
from catalyst.utils.exceptions import DifferentiableCompileError
from catalyst.utils.jnp_to_memref import (
get_ranked_memref_descriptor,
ranked_memref_to_numpy,
)
from catalyst.utils.runtime import (
BackendInfo,
device_get_toml_config,
Expand Down Expand Up @@ -2555,6 +2560,42 @@ def _get_batch_size(args_flat, axes_flat, axis_size):
return batch_size


class CallbackClosure:
"""This is just a class containing data that is important for the callback."""

def __init__(self, *absargs, **abskwargs):
self.absargs = absargs
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
self.abskwargs = abskwargs

@property
def tree_flatten(self):
"""Flatten args and kwargs."""
return tree_flatten((self.absargs, self.abskwargs))

@property
def low_level_sig(self):
"""Get the memref descriptor types"""
flat_params, _ = self.tree_flatten
low_level_flat_params = []
for param in flat_params:
empty_memref_descriptor = get_ranked_memref_descriptor(param)
memref_type = type(empty_memref_descriptor)
ptr_ty = ctypes.POINTER(memref_type)
low_level_flat_params.append(ptr_ty)
return low_level_flat_params

def getArgsAsJAXArrays(self, flat_args):
"""Get arguments as JAX arrays. Since our integration is mostly compatible with JAX,
it is best for the user if we continue with that idea and forward JAX arrays."""
jnpargs = []
for void_ptr, ty in zip(flat_args, self.low_level_sig):
memref_ty = ctypes.cast(void_ptr, ty)
nparray = ranked_memref_to_numpy(memref_ty)
jnparray = jnp.asarray(nparray)
jnpargs.append(jnparray)
return jnpargs


def callback(func):
"""Decorator that will correctly pass the signature as arguments to the callback
implementation.
Expand Down Expand Up @@ -2587,15 +2628,20 @@ def callback_implementation(
"""

flat_args, in_tree = tree_flatten((args, kwargs))
metadata = CallbackClosure(args, kwargs)

def _flat_callback(flat_args):
"""Each flat_arg is a pointer.

It is a pointer to a memref object.
To find out which element type it has, we use the signature obtained previously.
"""
jnpargs = metadata.getArgsAsJAXArrays(flat_args)

def _flat_callback(*flat_args):
"""This function packages flat arguments back into the shapes expected by the function."""
_args, _kwargs = tree_unflatten(in_tree, flat_args)
assert not _args, "Args are not yet expected here."
assert not _kwargs, "Kwargs are not yet supported here."
return tree_leaves(cb())
args, kwargs = tree_unflatten(in_tree, jnpargs)
return tree_leaves(cb(*args, **kwargs))

# TODO(@erick-xanadu): Change back once we support arguments.
# TODO(@erick-xanadu): Change back once we support return values.
# I am leaving this as a to-do because otherwise the coverage will complain.
# results_aval = tree_map(lambda x: jax.core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)
results_aval = []
Expand Down
74 changes: 74 additions & 0 deletions frontend/catalyst/utils/jnp_to_memref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# This file is essentially a re-implementation of np_to_memref.py
# But we make sure that all the functions also work with JAX's numpy.
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains functions to convert between Memrefs and NumPy arrays and vice-versa.

"""
This file is a wrapper around MLIR's np_to_memref to allow for abstract types and JAX arrays
to be converted to and from memrefs.
"""

import numpy as np
from mlir_quantum.runtime import as_ctype
from mlir_quantum.runtime import (
get_ranked_memref_descriptor as mlir_get_ranked_memref_descriptor,
)
from mlir_quantum.runtime import ranked_memref_to_numpy as mlir_ranked_memref_to_numpy
from mlir_quantum.runtime.np_to_memref import (
make_nd_memref_descriptor,
make_zero_d_memref_descriptor,
move_aligned_ptr_by_offset,
to_numpy,
)

from catalyst.jax_extras import DynamicJaxprTracer, ShapedArray


def get_ranked_memref_descriptor_from_shaped_array(array: ShapedArray):
"""Get a ranked memref descriptor from a shaped array.

Unlike MLIR's implementation, all values are left uninitialized.
This is because the values are not yet known. We only have a description
of the type.
"""

ctp = as_ctype(array.dtype)
if array.ndim == 0:
return make_zero_d_memref_descriptor(ctp)()

return make_nd_memref_descriptor(array.ndim, ctp)()


def get_ranked_memref_descriptor(array):
"""Wrapper around MLIR's get_ranked_memref_descriptor."""

if isinstance(array, DynamicJaxprTracer):
array = array.aval

if isinstance(array, (int, float, bool, complex)):
# This is necessary for keyword arguments
array = np.array(array)

if isinstance(array, ShapedArray):
# If input is ShapedArray
return get_ranked_memref_descriptor_from_shaped_array(array)

# Use default implementation from MLIR's library.
return mlir_get_ranked_memref_descriptor(array)


def ranked_memref_to_numpy(ranked_memref):
"""Wrapper around MLIR's ranked_memref_to_numpy.

This wrapper succeeds when the ranked_memref is a scalar tensor.
"""
try:
return mlir_ranked_memref_to_numpy(ranked_memref)
except AttributeError:
# zero dimensional tensor...
content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset)
np_arr = np.ctypeslib.as_array(content_ptr, shape=[])
return to_numpy(np_arr)
42 changes: 42 additions & 0 deletions frontend/test/pytest/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

@pytest.mark.parametrize("arg", [1, 2, 3])
def test_callback_no_tracing(arg):
"""Test that when there's no tracing the behaviour of identity
stays the same."""

@callback
def identity(x):
return x
Expand Down Expand Up @@ -78,3 +81,42 @@ def cir2():
cir2()
captured = capsys.readouterr()
assert captured.out.strip() == "Hello erick"


def test_callback_send_param(capsys):
"""Test callback with parameters no returns"""

@callback
def my_callback(n) -> None:
print(n)

@qml.qjit
def cir(n):
my_callback(n)
return None

cir(0)
captured = capsys.readouterr()
assert captured.out.strip() == "0"


def test_kwargs(capsys):
"""Test kwargs returns"""

@callback
def my_callback(**kwargs) -> None:
for k, v in kwargs.items():
print(k, v)

@qml.qjit
def cir(a, b, c):
my_callback(a=a, b=b, c=c, d=3, e=4)
return None

captured = capsys.readouterr()
assert captured.out.strip() == ""

cir(0, 1, 2)
captured = capsys.readouterr()
for string in ["a 0", "b 1", "c 2", "d 3", "e 4"]:
assert string in captured.out
61 changes: 61 additions & 0 deletions frontend/test/pytest/test_np_to_memref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests memref descriptor functions with JAX's ShapedArrays."""

import ctypes

import jax.numpy as jnp
import pytest
from jax.core import ShapedArray
from mlir_quantum.runtime import (
C64,
C128,
as_ctype,
make_nd_memref_descriptor,
make_zero_d_memref_descriptor,
)

from catalyst.utils.jnp_to_memref import get_ranked_memref_descriptor


@pytest.mark.parametrize(
"inp, exp",
[
(jnp.dtype(jnp.float64), ctypes.c_double),
(jnp.dtype(jnp.float32), ctypes.c_float),
(jnp.dtype(jnp.int64), ctypes.c_long),
(jnp.dtype(jnp.bool_), ctypes.c_bool),
(jnp.dtype(jnp.complex128), C128),
(jnp.dtype(jnp.complex64), C64),
],
)
def test_as_ctype(inp, exp):
"""Tests that JAX's dtypes behave the same as numpy's dtypes"""
obs = as_ctype(inp)
assert exp == obs


@pytest.mark.parametrize(
"inp, exp",
[
(1, make_zero_d_memref_descriptor(ctypes.c_long)),
(ShapedArray([], float), make_zero_d_memref_descriptor(ctypes.c_double)),
(ShapedArray([1], float), make_nd_memref_descriptor(1, ctypes.c_double)),
(ShapedArray([2, 2], float), make_nd_memref_descriptor(2, ctypes.c_double)),
],
)
def test_get_ranked_memref_descriptor(inp, exp):
"""Tests that the structure has the expected fields."""
obs = get_ranked_memref_descriptor(inp)
assert exp._fields_ == obs._fields_ # pylint: disable=protected-access
3 changes: 2 additions & 1 deletion mlir/include/Catalyst/IR/CatalystOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,12 @@ def PythonCallOp: Catalyst_Op<"pycallback",
}];

let arguments = (ins
Variadic<AnyTypeOf<[AnyRankedTensor, MemRefOf<[AnyType]>]>>:$inputs,
I64Attr: $identifier
);

let assemblyFormat = [{
attr-dict
`(` $inputs `)` attr-dict `:` functional-type(operands, results)
}];
}

Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,26 @@ struct BufferizeCustomCallOp : public OpConversionPattern<CustomCallOp> {
}
};

struct BufferizePythonCallOp : public OpConversionPattern<PythonCallOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(PythonCallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override
{
rewriter.create<PythonCallOp>(op.getLoc(), adaptor.getOperands(), adaptor.getIdentifier());
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
rewriter.eraseOp(op);
return success();
}
};

} // namespace

namespace catalyst {

void populateBufferizationPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)
{
patterns.add<BufferizeCustomCallOp>(typeConverter, patterns.getContext());
patterns.add<BufferizePythonCallOp>(typeConverter, patterns.getContext());
patterns.add<BufferizePrintOp>(typeConverter, patterns.getContext());
}

Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ struct CatalystBufferizationPass : impl::CatalystBufferizationPassBase<CatalystB
[&](PrintOp op) { return typeConverter.isLegal(op); });
target.addDynamicallyLegalOp<CustomCallOp>(
[&](CustomCallOp op) { return typeConverter.isLegal(op); });
target.addDynamicallyLegalOp<PythonCallOp>(
[&](PythonCallOp op) { return typeConverter.isLegal(op); });

if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
Expand Down
Loading
Loading