diff --git a/doc/introduction/interfaces/jax.rst b/doc/introduction/interfaces/jax.rst index b27e27405ad..ecd8fe049af 100644 --- a/doc/introduction/interfaces/jax.rst +++ b/doc/introduction/interfaces/jax.rst @@ -122,7 +122,7 @@ explicitly seeded. (See the `JAX random package documentation details). When simulations include randomness (i.e., if the device has a finite ``shots`` value, or the qnode -returns ``qml.samples()``), the JAX device requires a ``jax.random.PRNGKey``. Usually, PennyLane +returns ``qml.sample()``), the JAX device requires a ``jax.random.PRNGKey``. Usually, PennyLane automatically handles this for you. However, if you wish to use jitting with randomness, both the qnode and the device need to be created in the context of the ``jax.jit`` decorator. This can be achieved by wrapping device and qnode creation into a function decorated by ``@jax.jit``: @@ -136,11 +136,11 @@ Example: @jax.jit - def sample_circuit(phi, theta, key) + def sample_circuit(phi, theta, key): # Device construction should happen inside a `jax.jit` decorated # method when using a PRNGKey. - dev = qml.device('default.qubit.jax', wires=2, prng_key=key) + dev = qml.device('default.qubit.jax', wires=2, prng_key=key, shots=100) @qml.qnode(dev, interface='jax') @@ -149,9 +149,9 @@ Example: qml.RZ(phi[1], wires=1) qml.CNOT(wires=[0, 1]) qml.RX(theta, wires=0) - return qml.samples() # Here, we take samples instead. + return qml.sample() # Here, we take samples instead. - return circuit(phi, theta, key) + return circuit(phi, theta) # Get the samples from the jitted method. samples = sample_circuit([0.0, 1.0], 0.0, jax.random.PRNGKey(0))