Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove norm check for jax.jit functions in QubitStateVector #1683

Merged
merged 18 commits into from Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Expand Up @@ -277,6 +277,9 @@

<h3>Bug fixes</h3>

* Fix a bug where it was not possible to use `jax.jit` on a `QNode` when using `QubitStateVector`.
[(#1683)](https://github.com/PennyLaneAI/pennylane/pull/1683)

* The device suite tests can now execute successfully if no shots configuration variable is given.
[(#1641)](https://github.com/PennyLaneAI/pennylane/pull/1641)

Expand All @@ -296,5 +299,5 @@

This release contains contributions from (in alphabetical order):

Utkarsh Azad, Olivia Di Matteo, Andrew Gardhouse, Josh Izaac, Christina Lee,
Utkarsh Azad, Olivia Di Matteo, Andrew Gardhouse, Josh Izaac, Christina Lee, Romain Moyard
Ingrid Strandberg, Antal Száva, David Wierichs.
8 changes: 5 additions & 3 deletions pennylane/devices/default_qubit.py
Expand Up @@ -626,12 +626,14 @@ def _apply_state_vector(self, state, device_wires):
raise ValueError("State vector must be of length 2**wires.")

norm_error_message = "Sum of amplitudes-squared does not equal one."
if qml.math.get_interface(state) == "torch":
if qml.math.get_interface(state) != "jax":
if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)
else:
if not np.allclose(np.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)
# Case for jax without jit, full_lower is an attribute for abstract tracers
if not hasattr(qml.math.linalg.norm(state, ord=2), "full_lower"):
if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)
Comment on lines +633 to +636
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is much clearer now 👍


if len(device_wires) == self.num_wires and sorted(device_wires) == device_wires:
# Initialize the entire wires with the state
Expand Down
94 changes: 92 additions & 2 deletions tests/devices/test_default_qubit_jax.py
Expand Up @@ -172,6 +172,96 @@ def inner_circuit():
np.testing.assert_array_equal(a, b)
assert not np.all(a == c)

@pytest.mark.parametrize(
"state_vector",
[np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0]), jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])],
)
def test_qubit_state_vector_arg_jax_jit(self, state_vector, tol):
"""Test that Qubit state vector as argument works with a jax.jit"""
dev = qml.device("default.qubit.jax", wires=list(range(2)))

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(x):
wires = list(range(2))
qml.QubitStateVector(x, wires=wires)
return [qml.expval(qml.PauliX(wires=i)) for i in wires]

res = circuit(state_vector)
assert jnp.allclose(jnp.array(res), jnp.array([0, 1]), atol=tol, rtol=0)

@pytest.mark.parametrize(
"state_vector",
[np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0]), jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])],
)
def test_qubit_state_vector_arg_jax(self, state_vector, tol):
"""Test that Qubit state vector as argument works with jax"""
dev = qml.device("default.qubit.jax", wires=list(range(2)))

@qml.qnode(dev, interface="jax")
def circuit(x):
wires = list(range(2))
qml.QubitStateVector(x, wires=wires)
return [qml.expval(qml.PauliX(wires=i)) for i in wires]

res = circuit(state_vector)
assert jnp.allclose(jnp.array(res), jnp.array([0, 1]), atol=tol, rtol=0)

@pytest.mark.parametrize(
"state_vector",
[np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0]), jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])],
)
def test_qubit_state_vector_jax_jit(self, state_vector, tol):
"""Test that Qubit state vector works with a jax.jit"""
dev = qml.device("default.qubit.jax", wires=list(range(2)))

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(x):
qml.QubitStateVector(state_vector, wires=dev.wires)
for w in dev.wires:
qml.RZ(x, wires=w, id="x")
return qml.expval(qml.PauliZ(wires=0))

res = circuit(0.1)
assert jnp.allclose(jnp.array(res), 1, atol=tol, rtol=0)

@pytest.mark.parametrize(
"state_vector",
[np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0]), jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])],
)
def test_qubit_state_vector_jax(self, state_vector, tol):
"""Test that Qubit state vector works with a jax"""
dev = qml.device("default.qubit.jax", wires=list(range(2)))

@qml.qnode(dev, interface="jax")
def circuit(x):
qml.QubitStateVector(state_vector, wires=dev.wires)
for w in dev.wires:
qml.RZ(x, wires=w, id="x")
return qml.expval(qml.PauliZ(wires=0))

res = circuit(0.1)
assert jnp.allclose(jnp.array(res), 1, atol=tol, rtol=0)

@pytest.mark.parametrize(
"state_vector",
[np.array([0.1 + 0.1j, 0.2 + 0.2j, 0, 0]), jnp.array([0.1 + 0.1j, 0.2 + 0.2j, 0, 0])],
)
def test_qubit_state_vector_jax_not_normed(self, state_vector, tol):
"""Test that an error is raised when Qubit state vector is not normed works with a jax"""
dev = qml.device("default.qubit.jax", wires=list(range(2)))

@qml.qnode(dev, interface="jax")
def circuit(x):
qml.QubitStateVector(state_vector, wires=dev.wires)
for w in dev.wires:
qml.RZ(x, wires=w, id="x")
return qml.expval(qml.PauliZ(wires=0))

with pytest.raises(ValueError, match="Sum of amplitudes-squared does not equal one."):
circuit(0.1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

really nice tests @rmoyard!


def test_sampling_op_by_op(self):
"""Test that op-by-op sampling works as a new user would expect"""
dev = qml.device("default.qubit.jax", wires=1, shots=1000)
Expand Down Expand Up @@ -372,7 +462,7 @@ def cost(a, b):

grad = jax.jit(jax.grad(cost, argnums=(0, 1)))(a, b)
expected = [jnp.sin(a) * jnp.cos(b), jnp.cos(a) * jnp.sin(b)]
assert jnp.allclose(grad, expected, atol=tol, rtol=0)
assert jnp.allclose(jnp.array(grad), jnp.array(expected), atol=tol, rtol=0)

def test_backprop_gradient(self, tol):
"""Tests that the gradient of the qnode is correct"""
Expand All @@ -394,7 +484,7 @@ def circuit(a, b):
expected_grad = jnp.array(
[-0.5 * jnp.sin(a) * (jnp.cos(b) + 1), 0.5 * jnp.sin(b) * (1 - jnp.cos(a))]
)
assert jnp.allclose(res, expected_grad, atol=tol, rtol=0)
assert jnp.allclose(jnp.array(res), expected_grad, atol=tol, rtol=0)

@pytest.mark.parametrize("operation", [qml.U3, qml.U3.decomposition])
@pytest.mark.parametrize("diff_method", ["backprop"])
Expand Down