Add JAX integration tests #1685

Merged · 18 commits · Oct 20, 2021
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
@@ -136,6 +136,7 @@
extended to the JAX interface for scalar functions, via the beta
`pennylane.interfaces.batch` module.
[(#1634)](https://github.com/PennyLaneAI/pennylane/pull/1634)
[(#1685)](https://github.com/PennyLaneAI/pennylane/pull/1685)

For example, using the `execute` function from the `pennylane.interfaces.batch` module:
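(The changelog's own example is collapsed in this diff. The following is a minimal illustrative sketch only; the device, tape contents, and `gradient_fn` are assumptions, not the original example.)

```python
import jax
import jax.numpy as jnp
import pennylane as qml
from pennylane.interfaces.batch import execute

dev = qml.device("default.qubit", wires=2)

def cost(x):
    # Build a tape manually and send it through the batch-execution pipeline
    # with the JAX interface, so the result remains differentiable by JAX.
    with qml.tape.JacobianTape() as tape:
        qml.RX(x[0], wires=0)
        qml.RY(x[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.expval(qml.PauliZ(0))

    return execute([tape], dev, gradient_fn=qml.gradients.param_shift, interface="jax")[0][0]

x = jnp.array([0.543, -0.654])
print(jax.grad(cost)(x))
```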

16 changes: 15 additions & 1 deletion pennylane/math/multi_dispatch.py
@@ -274,11 +274,25 @@ def get_trainable_indices(values):
Trainable: {0}
tensor(0.0899685, requires_grad=True)
"""
trainable = requires_grad
interface = _multi_dispatch(values)
trainable_params = set()

if interface == "jax":
import jax

if not any(isinstance(v, jax.core.Tracer) for v in values):
Contributor
This looks like a great change to have 🥇

How come it's placed here, instead of in the JAX branch of `requires_grad`?

Member Author

I think it needs to be here, since the `not any` check can only be done here; it cannot be done inside the `requires_grad` check (which only checks a single tensor at a time) 🤔

I could be wrong though, let me know if you see a way around this!

Contributor

No, I think you're right. 🤔 At least nothing else comes to mind that we could use here.

# No JAX tracing is occurring; treat all `DeviceArray` objects as trainable.
trainable = lambda p, **kwargs: isinstance(p, jax.numpy.DeviceArray)
else:
# JAX tracing is occurring; use the default behaviour (only traced arrays
# are treated as trainable). This is required to ensure that `jax.grad(func, argnums=...)`
# works correctly, as the argnums argument determines which parameters are
# traced arrays.
trainable = requires_grad

for idx, p in enumerate(values):
if requires_grad(p, interface=interface):
if trainable(p, interface=interface):
trainable_params.add(idx)

return trainable_params
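
A small illustrative snippet (not part of the PR) of the behaviour discussed above: inside `jax.grad(..., argnums=...)` only the selected arguments are tracers, while outside any JAX transform nothing is a tracer, which is why the `not any(...)` check has to look at all values at once.

```python
import jax
import jax.numpy as jnp

def which_are_tracers(*values):
    # Mirror the per-value check used above: is each input currently being traced?
    return [isinstance(v, jax.core.Tracer) for v in values]

x, y = jnp.array(0.1), jnp.array(0.2)

# Outside any transform there are no tracers, so all DeviceArrays
# must be treated as trainable.
print(which_are_tracers(x, y))  # [False, False]

def f(a, b):
    # Inside jax.grad with argnums=0, only `a` is a tracer, so the default
    # per-tensor check correctly marks just that argument as trainable.
    print(which_are_tracers(a, b))  # [True, False]
    return jnp.sin(a) * b

jax.grad(f, argnums=0)(x, y)
```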
2 changes: 1 addition & 1 deletion pennylane/math/utils.py
@@ -281,6 +281,6 @@ def requires_grad(tensor, interface=None):
if interface == "jax":
import jax

return isinstance(tensor, jax.interpreters.ad.JVPTracer)
return isinstance(tensor, jax.core.Tracer)
Member Author
`jax.core.Tracer` is the original parent class, so this is a lot safer :) There are cases I discovered where JAX will use tracers that aren't `JVPTracer`.


raise ValueError(f"Argument {tensor} is an unknown object")
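
A quick sketch (not from the PR; the exact tracer class names vary across JAX versions) of why the broader base class is safer: different JAX transforms use different tracer subclasses, but all of them derive from `jax.core.Tracer`.

```python
import jax
import jax.numpy as jnp

def report(x):
    # The concrete tracer type depends on the transform, but the generic
    # isinstance check against jax.core.Tracer succeeds for all of them.
    print(type(x).__name__, isinstance(x, jax.core.Tracer))
    return jnp.sin(x)

jax.grad(report)(0.5)                    # JVPTracer, True
jax.jit(report)(0.5)                     # DynamicJaxprTracer, True (not a JVPTracer)
jax.vmap(report)(jnp.array([0.1, 0.2]))  # BatchTracer, True (not a JVPTracer)
```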
3 changes: 3 additions & 0 deletions pennylane/transforms/batch_transform.py
@@ -272,6 +272,8 @@ def default_qnode_wrapper(self, qnode, targs, tkwargs):
the QNode, and return the output of applying the tape transform
to the QNode's constructed tape.
"""
transform_max_diff = tkwargs.pop("max_diff", None)

if "shots" in inspect.signature(qnode.func).parameters:
raise ValueError(
"Detected 'shots' as an argument of the quantum function to transform. "
@@ -287,6 +289,7 @@ def _wrapper(*args, **kwargs):
interface = qnode.interface
execute_kwargs = getattr(qnode, "execute_kwargs", {})
max_diff = execute_kwargs.pop("max_diff", 2)
max_diff = transform_max_diff or max_diff

gradient_fn = getattr(qnode, "gradient_fn", qnode.diff_method)
gradient_kwargs = getattr(qnode, "gradient_kwargs", {})
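
A hedged usage sketch of what the popped keyword enables (the transform and circuit below are illustrative assumptions, not taken from the PR): a `max_diff` passed directly when applying a batch/gradient transform to a QNode now overrides the value stored in the QNode's `execute_kwargs`.

```python
import pennylane as qml
from pennylane import numpy as np

dev = qml.device("default.qubit", wires=1)

# The QNode itself is configured for second-order derivatives...
@qml.qnode(dev, max_diff=2)
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# ...but the keyword passed to the transform takes precedence (it is popped
# into `transform_max_diff` above), so only first-order behaviour is requested here.
grad_fn = qml.gradients.param_shift(circuit, max_diff=1)
print(grad_fn(np.array(0.543, requires_grad=True)))
```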
2 changes: 1 addition & 1 deletion tests/collections/test_collections.py
@@ -400,7 +400,7 @@ def test_dot_product_qnodes_tensor(self, qnodes, interface, tf_support, torch_su
coeffs = coeffs.numpy()

expected = np.dot(qcval, coeffs)
assert np.all(res == expected)
assert np.allclose(res, expected)
Member Author
For some reason, this test was failing for me on CI (but not locally) 🤔


def test_unknown_interface(self, monkeypatch):
"""Test exception raised if the interface is unknown"""
Expand Down
8 changes: 4 additions & 4 deletions tests/gradients/test_finite_difference.py
@@ -568,7 +568,7 @@ def test_torch(self, approx_order, strategy, tol):
torch = pytest.importorskip("torch")
from pennylane.interfaces.torch import TorchInterface

dev = qml.device("default.qubit", wires=2)
dev = qml.device("default.qubit.torch", wires=2)
params = torch.tensor([0.543, -0.654], dtype=torch.float64, requires_grad=True)

with TorchInterface.apply(qml.tape.QubitParamShiftTape()) as tape:
@@ -578,7 +578,7 @@ def test_torch(self, approx_order, strategy, tol):
qml.expval(qml.PauliZ(0) @ qml.PauliX(1))

tapes, fn = finite_diff(tape, n=1, approx_order=approx_order, strategy=strategy)
jac = fn([t.execute(dev) for t in tapes])
jac = fn(dev.batch_execute(tapes))
cost = torch.sum(jac)
cost.backward()
hess = params.grad
@@ -606,7 +606,7 @@ def test_jax(self, approx_order, strategy, tol):

config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", wires=2)
dev = qml.device("default.qubit.jax", wires=2)
params = jnp.array([0.543, -0.654])

def cost_fn(x):
@@ -618,7 +618,7 @@ def cost_fn(x):

tape.trainable_params = {0, 1}
tapes, fn = finite_diff(tape, n=1, approx_order=approx_order, strategy=strategy)
jac = fn([t.execute(dev) for t in tapes])
jac = fn(dev.batch_execute(tapes))
return jac

res = jax.jacobian(cost_fn)(params)
24 changes: 24 additions & 0 deletions tests/gradients/test_gradient_transform.py
@@ -365,3 +365,27 @@ def circuit(x):
res.backward()
expected = -2 * np.cos(2 * x_)
assert np.allclose(x.grad.detach(), expected, atol=tol, rtol=0)

def test_jax(self, tol):
"""Test that a gradient transform remains differentiable
with JAX"""
jax = pytest.importorskip("jax")
jnp = jax.numpy
dev = qml.device("default.qubit", wires=2)

@qml.gradients.param_shift
@qml.qnode(dev, interface="jax")
def circuit(x):
qml.RY(x ** 2, wires=[1])
qml.CNOT(wires=[0, 1])
return qml.var(qml.PauliX(1))

x = jnp.array(-0.654)

res = circuit(x)
expected = -4 * x * np.cos(x ** 2) * np.sin(x ** 2)
assert np.allclose(res, expected, atol=tol, rtol=0)

res = jax.grad(circuit)(x)
expected = -2 * (4 * x ** 2 * np.cos(2 * x ** 2) + np.sin(2 * x ** 2))
assert np.allclose(res, expected, atol=tol, rtol=0)
8 changes: 4 additions & 4 deletions tests/gradients/test_parameter_shift.py
@@ -905,7 +905,7 @@ def test_torch(self, tol):
torch = pytest.importorskip("torch")
from pennylane.interfaces.torch import TorchInterface

dev = qml.device("default.qubit", wires=2)
dev = qml.device("default.qubit.torch", wires=2)
params = torch.tensor([0.543, -0.654], dtype=torch.float64, requires_grad=True)

with TorchInterface.apply(qml.tape.QubitParamShiftTape()) as tape:
@@ -915,7 +915,7 @@ def test_torch(self, tol):
qml.var(qml.PauliZ(0) @ qml.PauliX(1))

tapes, fn = qml.gradients.param_shift(tape)
jac = fn([t.execute(dev) for t in tapes])
jac = fn(dev.batch_execute(tapes))
cost = jac[0, 0]
cost.backward()
hess = params.grad
@@ -938,7 +938,7 @@ def test_jax(self, tol):

config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", wires=2)
dev = qml.device("default.qubit.jax", wires=2)
params = jnp.array([0.543, -0.654])

def cost_fn(x):
@@ -950,7 +950,7 @@ def cost_fn(x):

tape.trainable_params = {0, 1}
tapes, fn = qml.gradients.param_shift(tape)
jac = fn([t.execute(dev) for t in tapes])
jac = fn(dev.batch_execute(tapes))
return jac

res = jax.jacobian(cost_fn)(params)
Expand Down
Loading