Skip to content

Commit

Permalink
Merge branch 'master' into two-local-swap-network
Browse files Browse the repository at this point in the history
  • Loading branch information
obliviateandsurrender committed Dec 20, 2022
2 parents ac55b87 + e2e23e5 commit 29f160b
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 11 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Expand Up @@ -104,6 +104,9 @@
`Observable`, allowing it to return valid generators from `SymbolicOp` and `CompositeOp` classes.
[(#3485)](https://github.com/PennyLaneAI/pennylane/pull/3485)

* Added support for two-qubit unitary decomposition with JAX-JIT.
[(#3569)](https://github.com/PennyLaneAI/pennylane/pull/3569)

* Limit the `numpy` version to `<1.24`.
[(#3563)](https://github.com/PennyLaneAI/pennylane/pull/3563)

Expand Down
5 changes: 3 additions & 2 deletions pennylane/math/multi_dispatch.py
Expand Up @@ -30,7 +30,8 @@
def array(*args, like=None, **kwargs):
"""Creates an array or tensor object of the target framework.
This method preserves the Torch device used.
If the PyTorch interface is specified, this method preserves the Torch device used.
If the JAX interface is specified, this method uses JAX numpy arrays, which do not cause issues with jit tracers.
Returns:
tensor_like: the tensor_like object of the framework
Expand Down Expand Up @@ -154,7 +155,7 @@ def wrapper(*args, **kwargs):
return decorator


@multi_dispatch(argnum=[0])
@multi_dispatch(argnum=[0, 1])
def kron(*args, like=None, **kwargs):
"""The kronecker/tensor product of args."""
if like == "scipy":
Expand Down
39 changes: 30 additions & 9 deletions pennylane/transforms/decompositions/two_qubit_unitary.py
Expand Up @@ -170,18 +170,30 @@ def _su2su2_to_tensor_products(U):
# C1 C2^dag = a1 a2* I
C12 = math.dot(C1, math.conj(math.T(C2)))

if not math.allclose(a1 * math.conj(a2), C12[0, 0]):
a2 *= -1
if not math.is_abstract(C12):
if not math.allclose(a1 * math.conj(a2), C12[0, 0]):
a2 *= -1
else:
sign_is_correct = math.allclose(a1 * math.conj(a2), C12[0, 0])
sign = (-1) ** (sign_is_correct + 1) # True + 1 = 2, False + 1 = 1
a2 *= sign

# Construct A
A = math.stack([math.stack([a1, a2]), math.stack([-math.conj(a2), math.conj(a1)])])

# Next, extract B. Can do from any of the C, just need to be careful in
# case one of the elements of A is 0.
if not math.allclose(A[0, 0], 0.0, atol=1e-6):
B = C1 / math.cast_like(A[0, 0], 1j)
else:
B = C2 / math.cast_like(A[0, 1], 1j)
# We use B1 unless division by 0 would cause all elements to be inf.
use_B2 = math.allclose(A[0, 0], 0.0, atol=1e-6)
if not math.is_abstract(A):
B = C2 / math.cast_like(A[0, 1], 1j) if use_B2 else C1 / math.cast_like(A[0, 0], 1j)
elif qml.math.get_interface(A) == "jax":
B = qml.math.cond(
use_B2,
lambda x: C2 / math.cast_like(A[0, 1], 1j),
lambda x: C1 / math.cast_like(A[0, 0], 1j),
[0], # arbitrary value for x
)

return math.convert_like(A, U), math.convert_like(B, U)

Expand Down Expand Up @@ -428,8 +440,14 @@ def _decomposition_3_cnots(U, wires):
gammaU = math.dot(u, math.T(u))
evs, _ = math.linalg.eig(gammaU)

angles = [math.angle(ev) for ev in evs]

# We will sort the angles so that results are consistent across interfaces.
angles = math.sort([math.angle(ev) for ev in evs])
# This step is skipped when using JAX-JIT, because it cannot sort without making the
# magnitude of the angles concrete. This does not impact the validity of the resulting
# decomposition, but may result in a different decompositions for jitting vs not.
if not qml.math.is_abstract(U):
angles = math.sort(angles)

x, y, z = angles[0], angles[1], angles[2]

Expand Down Expand Up @@ -589,10 +607,13 @@ def two_qubit_decomposition(U, wires):

# The next thing we will do is compute the number of CNOTs needed, as this affects
# the form of the decomposition.
num_cnots = _compute_num_cnots(U)
if not qml.math.is_abstract(U):
num_cnots = _compute_num_cnots(U)

with qml.QueuingManager.stop_recording():
if num_cnots == 0:
if qml.math.is_abstract(U):
decomp = _decomposition_3_cnots(U, wires)
elif num_cnots == 0:
decomp = _decomposition_0_cnots(U, wires)
elif num_cnots == 1:
decomp = _decomposition_1_cnot(U, wires)
Expand Down
62 changes: 62 additions & 0 deletions tests/transforms/test_decompositions.py
Expand Up @@ -855,3 +855,65 @@ def test_two_qubit_decomposition_tensor_products_jax(self, U_pair, wires):
obtained_matrix = qml.matrix(tape, wire_order=wires)

assert check_matrix_equivalence(U, obtained_matrix, atol=1e-7)

@pytest.mark.jax
@pytest.mark.parametrize("wires", [[0, 1], ["a", "b"], [3, 2], ["c", 0]])
@pytest.mark.parametrize("U", samples_3_cnots + samples_2_cnots + samples_1_cnot)
def test_two_qubit_decomposition_jax_jit(self, U, wires):
"""Test that a two-qubit operation is correctly decomposed with JAX-JIT ."""
import jax
from jax.config import config

config.update("jax_enable_x64", True)

U = jax.numpy.array(U, dtype=jax.numpy.complex128)

def wrapped_decomposition(U):
# the jitted function cannot return objects like operators,
# so we cannot jit two_qubit_decomposition directly
obtained_decomposition = two_qubit_decomposition(U, wires=wires)

with qml.queuing.AnnotatedQueue() as q:
for op in obtained_decomposition:
qml.apply(op)

tape = qml.tape.QuantumScript.from_queue(q)
obtained_matrix = qml.matrix(tape, wire_order=wires)

return obtained_matrix

jitted_matrix = jax.jit(wrapped_decomposition)(U)

assert check_matrix_equivalence(U, jitted_matrix, atol=1e-7)

@pytest.mark.jax
@pytest.mark.parametrize("wires", [[0, 1], ["a", "b"], [3, 2], ["c", 0]])
@pytest.mark.parametrize("U_pair", samples_su2_su2)
def test_two_qubit_decomposition_tensor_products_jax_jit(self, U_pair, wires):
"""Test that a two-qubit tensor product is correctly decomposed with JAX-JIT."""
import jax
from jax.config import config

config.update("jax_enable_x64", True)

U1 = jax.numpy.array(U_pair[0], dtype=jax.numpy.complex128)
U2 = jax.numpy.array(U_pair[1], dtype=jax.numpy.complex128)
U = qml.math.kron(U1, U2)

def wrapped_decomposition(U):
# the jitted function cannot return objects like operators,
# so we cannot jit two_qubit_decomposition directly
obtained_decomposition = two_qubit_decomposition(U, wires=wires)

with qml.queuing.AnnotatedQueue() as q:
for op in obtained_decomposition:
qml.apply(op)

tape = qml.tape.QuantumScript.from_queue(q)
obtained_matrix = qml.matrix(tape, wire_order=wires)

return obtained_matrix

jitted_matrix = jax.jit(wrapped_decomposition)(U)

assert check_matrix_equivalence(U, jitted_matrix, atol=1e-7)

0 comments on commit 29f160b

Please sign in to comment.