From ce15837d46901aab2ea9fbb32a57a16f7d0f8ce1 Mon Sep 17 00:00:00 2001 From: Romain Date: Tue, 12 Oct 2021 10:17:22 -0400 Subject: [PATCH] Update Jax documentation (#1742) * remove newlines in docstring (#1647) * Change Jax doc. * Typo. Co-authored-by: Theodor Co-authored-by: antalszava --- doc/introduction/interfaces/jax.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/introduction/interfaces/jax.rst b/doc/introduction/interfaces/jax.rst index b27e27405ad..ecd8fe049af 100644 --- a/doc/introduction/interfaces/jax.rst +++ b/doc/introduction/interfaces/jax.rst @@ -122,7 +122,7 @@ explicitly seeded. (See the `JAX random package documentation details). When simulations include randomness (i.e., if the device has a finite ``shots`` value, or the qnode -returns ``qml.samples()``), the JAX device requires a ``jax.random.PRNGKey``. Usually, PennyLane +returns ``qml.sample()``), the JAX device requires a ``jax.random.PRNGKey``. Usually, PennyLane automatically handles this for you. However, if you wish to use jitting with randomness, both the qnode and the device need to be created in the context of the ``jax.jit`` decorator. This can be achieved by wrapping device and qnode creation into a function decorated by ``@jax.jit``: @@ -136,11 +136,11 @@ Example: @jax.jit - def sample_circuit(phi, theta, key) + def sample_circuit(phi, theta, key): # Device construction should happen inside a `jax.jit` decorated # method when using a PRNGKey. - dev = qml.device('default.qubit.jax', wires=2, prng_key=key) + dev = qml.device('default.qubit.jax', wires=2, prng_key=key, shots=100) @qml.qnode(dev, interface='jax') @@ -149,9 +149,9 @@ Example: qml.RZ(phi[1], wires=1) qml.CNOT(wires=[0, 1]) qml.RX(theta, wires=0) - return qml.samples() # Here, we take samples instead. + return qml.sample() # Here, we take samples instead. - return circuit(phi, theta, key) + return circuit(phi, theta) # Get the samples from the jitted method. samples = sample_circuit([0.0, 1.0], 0.0, jax.random.PRNGKey(0))