Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Support for transforms #280

Merged
merged 76 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
bfd00aa
success for capturing co-transform.
erick-xanadu Sep 7, 2023
5111aac
Mid circuit measurements work.
erick-xanadu Sep 7, 2023
253f59e
Solves pytree issue.
erick-xanadu Sep 8, 2023
8832762
Use transforms.
erick-xanadu Sep 12, 2023
7c55d8c
Style.
erick-xanadu Sep 12, 2023
7887c66
Keep identity transform.
erick-xanadu Sep 12, 2023
7605ed6
Conditional qnode or None.
erick-xanadu Sep 12, 2023
76d8351
Set trainable parameters and expand before transforms.
erick-xanadu Sep 13, 2023
ebebdcb
Add support for optional wrapping in list depending on output.
erick-xanadu Sep 14, 2023
3910e6a
Rules for return types for transforms.
erick-xanadu Sep 15, 2023
897c9b1
Support for all most used transforms.
erick-xanadu Sep 18, 2023
00a7d27
fix import
erick-xanadu Oct 24, 2023
e953af8
rebase fix.
erick-xanadu Oct 24, 2023
04e0699
Remove unflattening.
erick-xanadu Oct 25, 2023
a53019a
some tests fail
erick-xanadu Oct 25, 2023
b0984f4
Fix custom decomposition.
erick-xanadu Oct 25, 2023
8d207aa
style
erick-xanadu Oct 25, 2023
8954d7f
isort
erick-xanadu Oct 25, 2023
01cee77
isort
erick-xanadu Oct 25, 2023
f33857c
Fixes
erick-xanadu Oct 25, 2023
5f70498
black
erick-xanadu Oct 25, 2023
4a0a27b
Temporary skip.
erick-xanadu Oct 25, 2023
be0466b
Add issue as a reason.
erick-xanadu Oct 25, 2023
1b7349d
Improving tests by using backend more consistently.
erick-xanadu Oct 25, 2023
034d460
Mark skip.
erick-xanadu Oct 25, 2023
0c32e88
Improve mc_cut test.
erick-xanadu Oct 25, 2023
e70546e
Improve hamiltonian_expand test.
erick-xanadu Oct 25, 2023
b168eac
Better skips
erick-xanadu Oct 25, 2023
fdb632b
Appease linter.
erick-xanadu Oct 25, 2023
8d366bc
Use name.
erick-xanadu Oct 25, 2023
c3d8a5c
Add test for merge rotations.
erick-xanadu Oct 26, 2023
1c06b0f
Fix name
erick-xanadu Oct 26, 2023
f10ddfa
import
erick-xanadu Oct 26, 2023
c8d133e
Uncomment and document reason for skipping.
erick-xanadu Oct 26, 2023
0d65158
Remove exp_fn
erick-xanadu Oct 26, 2023
03dcb54
group parameters
erick-xanadu Oct 26, 2023
d9e2d3f
Attempt one at refactoring.
erick-xanadu Oct 26, 2023
401b301
rename variables.
erick-xanadu Oct 26, 2023
c4664a1
Better message.
erick-xanadu Oct 26, 2023
3e9a3da
Sample documentation.
erick-xanadu Oct 26, 2023
1498367
Simplify.
erick-xanadu Oct 26, 2023
39009f9
Cleaning up a bit.
erick-xanadu Oct 26, 2023
df1e52a
Remove delete quantum_tape.
erick-xanadu Oct 26, 2023
3812373
Add test for invalid transforms
erick-xanadu Oct 27, 2023
7a15d3f
Add test for valid transform with measure.
erick-xanadu Oct 27, 2023
2ba31c3
Remove unused variable.
erick-xanadu Oct 27, 2023
05a356d
Rename variable.
erick-xanadu Oct 27, 2023
e9155ea
Variable renaming.
erick-xanadu Oct 27, 2023
634109b
Remove unused variable
erick-xanadu Oct 27, 2023
8bb6ab7
simplify.
erick-xanadu Oct 27, 2023
3245fe6
Simplify
erick-xanadu Oct 27, 2023
10555d2
Simplify.
erick-xanadu Oct 27, 2023
6076f19
Simplify.
erick-xanadu Oct 27, 2023
656eb33
Add documentation and separate post processing tracing.
erick-xanadu Oct 27, 2023
7e5cda7
Add comment.
erick-xanadu Oct 27, 2023
5decf4a
pylint.
erick-xanadu Oct 27, 2023
b22b046
pylint.
erick-xanadu Oct 27, 2023
677da1e
4731 passes.
erick-xanadu Oct 27, 2023
e865c5a
Link to original issue.
erick-xanadu Oct 27, 2023
f4deb36
Reproducibility.
erick-xanadu Oct 27, 2023
5fe647a
Raise error for informative transforms.
erick-xanadu Oct 27, 2023
778fd86
Fix
erick-xanadu Oct 27, 2023
74d4baa
Add comments.
erick-xanadu Oct 27, 2023
35fb759
Sink error message.
erick-xanadu Oct 27, 2023
a87485a
Merge branch 'main' into eochoa/2023-09-18/transform-iv
erick-xanadu Nov 9, 2023
b662903
Apply review suggestion.
erick-xanadu Nov 9, 2023
18cf111
Add review suggestion.
erick-xanadu Nov 9, 2023
0ef5d9a
Raise error if transformed and returns something other than measureme…
erick-xanadu Nov 9, 2023
a438eba
xfail instead of skip.
erick-xanadu Nov 9, 2023
2c2aca2
Change test names.
erick-xanadu Nov 9, 2023
c257374
Codefactor.
erick-xanadu Nov 9, 2023
07b9f16
style.
erick-xanadu Nov 9, 2023
d8b271b
Fix
erick-xanadu Nov 9, 2023
06de41a
Add unroll_ccrz
erick-xanadu Nov 9, 2023
7761ddc
Coverage.
erick-xanadu Nov 9, 2023
597f2a4
Changelog
erick-xanadu Nov 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

<h3>New features</h3>

* Initial support for transforms. QFunc transforms are supported. QNode transforms have limited
support. QNode transforms cannot be composed, and transforms are limited to what is currently
available in PennyLane. This means that operations defined in Catalyst like `cond`, `for_loop`,
and `while_loop` are not supported by transforms. Additionally, transforms can only return
`MeasurementProcess`es.
[(#280)](https://github.com/PennyLaneAI/catalyst/pull/280)

<h3>Improvements</h3>

<h3>Breaking changes</h3>
Expand All @@ -12,6 +19,8 @@

This release contains contributions from (in alphabetical order):

Erick Ochoa Lopez.

# Release 0.3.2

<h3>New features</h3>
Expand Down
237 changes: 198 additions & 39 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from dataclasses import dataclass
from functools import partial, reduce
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
Expand All @@ -26,6 +26,7 @@
from pennylane.operation import AnyWires, Operation, Wires
from pennylane.tape import QuantumTape

import catalyst
from catalyst.jax_primitives import (
AbstractQreg,
compbasis_p,
Expand All @@ -50,6 +51,7 @@
var_p,
)
from catalyst.utils.contexts import EvaluationContext, EvaluationMode, JaxTracingContext
from catalyst.utils.exceptions import CompileError
from catalyst.utils.jax_extras import (
ClosedJaxpr,
DynamicJaxprTrace,
Expand Down Expand Up @@ -500,6 +502,11 @@ def pauli_word_to_tensor_obs(obs, qrp: QRegPromise) -> List[DynamicJaxprTracer]:
return tensorobs_p.bind(*nested_obs)


def identity_qnode_transform(tape: QuantumTape) -> (Sequence[QuantumTape], Callable):
"""Identity transform"""
return [tape], lambda res: res[0]


def trace_quantum_measurements(
device: QubitDevice,
qrp: QRegPromise,
Expand Down Expand Up @@ -571,8 +578,135 @@ def trace_quantum_measurements(
return out_classical_tracers, out_tree


def is_transform_valid_for_batch_transforms(tape, flat_results):
"""Not all transforms are valid for batch transforms.
Batch transforms will increase the number of tapes from 1 to N.
However, if the wave function collapses or there is any other non-deterministic behaviour
such as noise present, then each tape would have different results.

Also, MidCircuitMeasure is a HybridOp, which PL does not handle at the moment.
Let's wait until mid-circuit measurements are better integrated into both PL
and Catalyst and discussed more as well."""
class_tracers, meas_tracers = split_tracers_and_measurements(flat_results)

# Can transforms be applied?
# Since transforms are a PL feature and PL does not support the same things as
# Catalyst, transforms may have invariants that rely on PL invariants.
# For example:
# * mid-circuit measurements (for batch-transforms)
# * that the output will be only a sequence of `MeasurementProcess`es.
def is_measurement(op):
"""Only to avoid 100 character per line limit."""
return isinstance(op, MeasurementProcess)

is_out_measurements = map(is_measurement, meas_tracers)
is_all_out_measurements = all(is_out_measurements) and not class_tracers
is_out_measurement_sequence = is_all_out_measurements and isinstance(meas_tracers, Sequence)
is_out_single_measurement = is_all_out_measurements and is_measurement(meas_tracers)

def is_midcircuit_measurement(op):
"""Only to avoid 100 character per line limit."""
return isinstance(op, catalyst.pennylane_extensions.MidCircuitMeasure)

is_valid_output = is_out_measurement_sequence or is_out_single_measurement
if not is_valid_output:
msg = (
"A transformed quantum function must return either a single measurement, "
"or a nonempty sequence of measurements."
)
raise CompileError(msg)

is_wave_function_collapsed = any(map(is_midcircuit_measurement, tape.operations))
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
are_batch_transforms_valid = is_valid_output and not is_wave_function_collapsed
return are_batch_transforms_valid


def apply_transform(qnode, tape, flat_results):
"""Apply transform."""

# Some transforms use trainability as a basis for transforming.
# See batch_params
params = tape.get_parameters(trainable_only=False)
tape.trainable_params = qml.math.get_trainable_indices(params)

is_program_transformed = qnode and qnode.transform_program

if is_program_transformed and qnode.transform_program.is_informative:
msg = "Catalyst does not support informative transforms."
raise CompileError(msg)

if is_program_transformed:
is_valid_for_batch = is_transform_valid_for_batch_transforms(tape, flat_results)
tapes, post_processing = qnode.transform_program([tape])
if not is_valid_for_batch and len(tapes) > 1:
msg = "Multiple tapes are generated, but each run might produce different results."
raise CompileError(msg)
else:
# Apply the identity transform in order to keep generalization
tapes, post_processing = identity_qnode_transform(tape)
return tapes, post_processing


def split_tracers_and_measurements(flat_values):
"""Return classical tracers and measurements"""
classical = []
measurements = []
for flat_value in flat_values:
if isinstance(flat_value, DynamicJaxprTracer):
# classical should remain empty for all valid cases at the moment.
# This is because split_tracers_and_measurements is only called
# when checking the validity of transforms. And transforms cannot
# return any tracers.
classical.append(flat_value) # pragma: no cover
else:
measurements.append(flat_value)

return classical, measurements


def trace_post_processing(ctx, trace, post_processing, args_types, args):
"""Trace post processing function.

Args:
ctx (EvaluationContext): context
trace (DynamicJaxprTrace): trace
post_processing: post_processing function
args_types: unflattened args
args: input tracers

Returns:
closed_jaxpr: JAXPR expression for the whole frame
post_processing_results: Output
"""

with EvaluationContext.frame_tracing_context(ctx, trace):
# What is the input to the post_processing function?
# The input to the post_processing function is going to be a list of values
# One for each tape.

# The tracers are all flat in args.
# The shape is in a list of args_types.
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved

# We need to deduce the type/shape/tree of the post_processing.
wffa, _, out_tree_promise = deduce_avals(post_processing, (args_types,), {})

# wffa will take as an input a flatten tracers.
post_processing_retval_flat = wffa.call_wrapped(*args)

# After wffa is called, then the shape becomes available in out_tree_promise.
post_processing_tracers = [trace.full_raise(t) for t in post_processing_retval_flat]
jaxpr, _, consts = ctx.frames[trace].to_jaxpr2(post_processing_tracers)
closed_jaxpr = ClosedJaxpr(jaxpr, consts)
post_processing_results = tree_unflatten(
out_tree_promise(),
[ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in post_processing_tracers],
)

return closed_jaxpr, post_processing_results


def trace_quantum_function(
f: Callable, device: QubitDevice, args, kwargs
f: Callable, device: QubitDevice, args, kwargs, qnode=None
) -> Tuple[ClosedJaxpr, Any]:
"""Trace quantum function in a way that allows building a nested quantum tape describing the
quantum algorithm.
Expand Down Expand Up @@ -601,7 +735,7 @@ def trace_quantum_function(
in_classical_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
with QueuingManager.stop_recording(), quantum_tape:
# Quantum tape transformations happen at the end of tracing
ans = wffa.call_wrapped(*in_classical_tracers)
return_values_flat = wffa.call_wrapped(*in_classical_tracers)

# Ans contains the leaves of the pytree (empty for measurement without
# data https://github.com/PennyLaneAI/pennylane/pull/4607)
Expand All @@ -610,47 +744,72 @@ def trace_quantum_function(

# 1. Recompute the original return
with QueuingManager.stop_recording():
ans = tree_unflatten(out_tree_promise(), ans)
return_values = tree_unflatten(out_tree_promise(), return_values_flat)

def is_leaf(obj):
return isinstance(obj, qml.measurements.MeasurementProcess)

# 2. Create a new tree that has measurements as leaves
ans, out_tree = jax.tree_util.tree_flatten(ans, is_leaf=is_leaf)

out_classical_tracers_or_measurements = [
(trace.full_raise(t) if isinstance(t, DynamicJaxprTracer) else t) for t in ans
]

# (2) - Quantum tracing
with EvaluationContext.frame_tracing_context(ctx, trace):
qdevice_p.bind(spec="kwargs", val=str(device.backend_kwargs))
qdevice_p.bind(spec="backend", val=device.backend_name)
qreg_in = qalloc_p.bind(len(device.wires))
qrp_out = trace_quantum_tape(quantum_tape, device, qreg_in, ctx, trace)
out_classical_tracers, out_classical_tree = trace_quantum_measurements(
device,
qrp_out,
out_classical_tracers_or_measurements,
out_tree,
return_values_flat, return_values_tree = jax.tree_util.tree_flatten(
return_values, is_leaf=is_leaf
)
out_quantum_tracers = [qrp_out.actualize()]
qdealloc_p.bind(qreg_in)

out_classical_tracers = [trace.full_raise(t) for t in out_classical_tracers]
# TODO: In order to compose transforms, we would need to recursively
# call apply_transform while popping the latest transform applied,
# until there are no more transforms to be applied.
# But first we should clean this up this method a bit more.
tapes, post_processing = apply_transform(qnode, quantum_tape, return_values_flat)
dime10 marked this conversation as resolved.
Show resolved Hide resolved

jaxpr, out_type, consts = ctx.frames[trace].to_jaxpr2(
out_classical_tracers + out_quantum_tracers
)
jaxpr._outvars = jaxpr._outvars[:-1] # pylint: disable=protected-access
out_type = out_type[:-1]
# TODO: `check_jaxpr` complains about the `AbstractQreg` type. Consider fixing.
# check_jaxpr(jaxpr)

closed_jaxpr = ClosedJaxpr(jaxpr, consts)
out_avals, _ = unzip2(out_type)

abstract_results = tree_unflatten(
out_classical_tree, [ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
)
return closed_jaxpr, abstract_results
# (2) - Quantum tracing
results_tracers, results_abstract = [], []
is_program_transformed = qnode and qnode.transform_program
for tape in tapes:
# If the program is batched, that means that it was transformed.
# If it was transformed, that means that the program might have
# changed the output. See `split_non_commuting`
if is_program_transformed:
# TODO: In the future support arbitrary output from the user function.
output = tape.measurements
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
_, trees = jax.tree_util.tree_flatten(output, is_leaf=is_leaf)
else:
output = return_values_flat
trees = return_values_tree

with EvaluationContext.frame_tracing_context(ctx, trace):
qdevice_p.bind(spec="kwargs", val=str(device.backend_kwargs))
qdevice_p.bind(spec="backend", val=device.backend_name)
qreg_in = qalloc_p.bind(len(device.wires))
qrp_out = trace_quantum_tape(tape, device, qreg_in, ctx, trace)
meas, meas_trees = trace_quantum_measurements(device, qrp_out, output, trees)
out_quantum_tracers = [qrp_out.actualize()]
qdealloc_p.bind(qreg_in)

tracers = [trace.full_raise(m) for m in meas]
results_tracers += tracers

jaxpr, out_type, _ = ctx.frames[trace].to_jaxpr2(tracers + out_quantum_tracers)
jaxpr._outvars = jaxpr._outvars[:-1] # pylint: disable=protected-access
out_type = out_type[:-1]

out_avals, _ = unzip2(out_type)
abstract_results = tree_unflatten(
meas_trees,
[ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals],
)
# This mimics the return type from qnodes.
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
# I would prefer if qnodes didn't have special rules about whether they return a
# tuple, list, or value.
# TODO: Allow the user to return whatever types they specify.
if is_program_transformed and len(abstract_results) == 1:
results_abstract.append(abstract_results[0])
elif is_program_transformed:
results_abstract.append(tuple(abstract_results))
else:
results_abstract.append(abstract_results)
dime10 marked this conversation as resolved.
Show resolved Hide resolved
# TODO: `check_jaxpr` complains about the `AbstractQreg` type. Consider fixing.
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
# check_jaxpr(jaxpr)

closed_jaxpr, unflattened_callback_results = trace_post_processing(
ctx, trace, post_processing, results_abstract, results_tracers
)
return closed_jaxpr, unflattened_callback_results
4 changes: 3 additions & 1 deletion frontend/catalyst/pennylane_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def __init__(self, fn, device):
update_wrapper(self, fn)

def __call__(self, *args, **kwargs):
qnode = None
if isinstance(self, qml.QNode):
qnode = self
if isinstance(self.device, qml.Device):
name = self.device.short_name
else:
Expand Down Expand Up @@ -165,7 +167,7 @@ def __call__(self, *args, **kwargs):
device = self.device

with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION):
jaxpr, shape = trace_quantum_function(self.func, device, args, kwargs)
jaxpr, shape = trace_quantum_function(self.func, device, args, kwargs, qnode)

retval_tree = tree_structure(shape)

Expand Down
Loading
Loading