Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove norm check for jax.jit functions in QubitStateVector #1683

Merged
merged 18 commits into from Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions pennylane/devices/default_qubit.py
Expand Up @@ -26,6 +26,7 @@
from scipy.sparse import coo_matrix

import pennylane as qml
import jax
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
from pennylane import QubitDevice, DeviceError, QubitStateVector, BasisState
from pennylane.operation import DiagonalOperation
from pennylane.wires import WireError
Expand Down Expand Up @@ -626,12 +627,11 @@ def _apply_state_vector(self, state, device_wires):
raise ValueError("State vector must be of length 2**wires.")

norm_error_message = "Sum of amplitudes-squared does not equal one."
if qml.math.get_interface(state) == "torch":
if not isinstance(
qml.math.linalg.norm(state, ord=2), jax.interpreters.partial_eval.DynamicJaxprTracer
):
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)
else:
if not np.allclose(np.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)

if len(device_wires) == self.num_wires and sorted(device_wires) == device_wires:
# Initialize the entire wires with the state
Expand Down
72 changes: 72 additions & 0 deletions tests/devices/test_default_qubit_jax.py
Expand Up @@ -172,6 +172,78 @@ def inner_circuit():
np.testing.assert_array_equal(a, b)
assert not np.all(a == c)

@pytest.mark.parametrize(
"state_vector",
[np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0]), jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])],
)
def test_qubit_state_vector_arg_jax_jit(self, state_vector, tol):
"""Test that Qubit state vector works with a jax.jit"""
dev = qml.device("default.qubit.jax", wires=list(range(2)))

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(x):
wires = list(range(2))
qml.QubitStateVector(x, wires=wires)
return [qml.expval(qml.PauliX(wires=i)) for i in wires]

res = circuit(state_vector)
assert jnp.allclose(res, [0, 1], atol=tol, rtol=0)

@pytest.mark.parametrize(
"state_vector",
[np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0]), jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])],
)
def test_qubit_state_vector_jax_jit(self, state_vector, tol):
"""Test that Qubit state vector works with jax"""
dev = qml.device("default.qubit.jax", wires=list(range(2)))

@qml.qnode(dev, interface="jax")
def circuit(x):
wires = list(range(2))
qml.QubitStateVector(x, wires=wires)
return [qml.expval(qml.PauliX(wires=i)) for i in wires]

res = circuit(state_vector)
assert jnp.allclose(res, [0, 1], atol=tol, rtol=0)

@pytest.mark.parametrize(
"state_vector",
[np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0]), jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])],
)
def test_qubit_state_vector_jax_jit(self, state_vector, tol):
"""Test that Qubit state vector works with a jax.jit"""
dev = qml.device("default.qubit.jax", wires=list(range(2)))

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(x):
qml.QubitStateVector(state_vector, wires=dev.wires)
for w in dev.wires:
qml.RZ(x, wires=w, id="x")
return qml.expval(qml.PauliZ(wires=0))

res = circuit(0.1)
assert jnp.allclose(res, 1, atol=tol, rtol=0)

@pytest.mark.parametrize(
"state_vector",
[np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0]), jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])],
)
def test_qubit_state_vector_jit(self, state_vector, tol):
"""Test that Qubit state vector works with a jax.jit"""
dev = qml.device("default.qubit.jax", wires=list(range(2)))

@qml.qnode(dev, interface="jax")
def circuit(x):
qml.QubitStateVector(state_vector, wires=dev.wires)
for w in dev.wires:
qml.RZ(x, wires=w, id="x")
return qml.expval(qml.PauliZ(wires=0))

res = circuit(0.1)
assert jnp.allclose(res, 1, atol=tol, rtol=0)
rmoyard marked this conversation as resolved.
Show resolved Hide resolved

def test_sampling_op_by_op(self):
"""Test that op-by-op sampling works as a new user would expect"""
dev = qml.device("default.qubit.jax", wires=1, shots=1000)
Expand Down