[BUG] jax.jit(jax.grad()) of a circuit with shots crashes #3218

Closed
1 task done
PhilipVinc opened this issue Oct 26, 2022 · 23 comments · Fixed by #3244
Labels
bug 🐛 Something isn't working

Comments

@PhilipVinc
Contributor

PhilipVinc commented Oct 26, 2022

Expected behavior

Jitting the gradient of a QNode whose device uses shots and an explicitly set PRNGKey leads to a crash. I would expect this to work.
The snippet below reproduces this issue on master. Note that if you remove the jax.jit, the gradient works, but only by accident.

I think I know what is causing the bug, but the explanation is a bit involved. I will first give a TLDR, then show exactly where the crash happens, and then reason about what is happening there.

TLDR

The problem arises because you are storing a tracer in DefaultQubitJax._prng_key, but you are not passing this PRNG key as an argument of the host callback in jax_jit.py:_execute. Conceptually, you should pass the PRNG key as an argument of the callback, as you do for the parameters.

Instead, the device, and therefore its _prng_key, is captured in a nested series of lambdas/functions called from the callback. When the callback is executed, it therefore encounters a tracer object for the PRNG key that has not been substituted with a concrete value, and it crashes.
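The failure mode can be reproduced in miniature outside PennyLane. The sketch below is a toy built on assumptions, not PennyLane code: ToyDevice and leaky are hypothetical names, and jax.pure_callback stands in for host_callback.call. The device is built inside jax.jit, so its key is a tracer; the host function reaches the key through a closure instead of receiving it as a callback operand, and calling JAX on that stale tracer fails when the callback runs.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyDevice:
    """Hypothetical stand-in for a device that stores its PRNG key."""

    def __init__(self, prng_key):
        self._prng_key = prng_key  # a tracer when constructed under jax.jit

    def sample(self, n):
        # Draws random numbers with whatever key the device holds.
        return jax.random.uniform(self._prng_key, (n,))


@jax.jit
def leaky(params, prng_key):
    dev = ToyDevice(prng_key)

    def host_fn(params):
        # params arrives as a concrete array, but dev._prng_key is still the
        # tracer from the already-finished jit trace.
        noise = np.asarray(dev.sample(params.shape[0]))
        return (np.asarray(params) + noise).astype(np.float32)

    out_type = jax.ShapeDtypeStruct(params.shape, jnp.float32)
    # The parameters are passed as a callback operand; the key is NOT.
    return jax.pure_callback(host_fn, out_type, params)


try:
    leaky(jnp.ones(3, dtype=jnp.float32), jax.random.PRNGKey(0)).block_until_ready()
    failed = False
except Exception as err:  # JAX may wrap the underlying UnexpectedTracerError
    failed = True
    print(type(err).__name__)
```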

Observing where the crash happens

As I am not very familiar with the internals of PennyLane, and as this crash happens inside a callback, which prevents proper stack traces from being printed, I had to resort to a very primitive way of debugging.
I have added several print statements to various PennyLane functions. You can install my instrumented copy of PennyLane by running:

pip install git+https://github.com/PhilipVinc/pennylane@pv/debug

Using this copy, and running the snippet below, you will see the following messages printed:

INSIDE THE non_diff_wrapper CALLBACK, EXECUTING PYTHON CODE. Called with args=([array(0.34564769), array(0.45750395)], [array([1.])])
 ...
 after batch vjp, ...
  INSIDE cache_execute called with tapes=[<QuantumTape: wires=[0, 1], params=2>, <QuantumTape: wires=[0, 1], params=2>, <QuantumTape: wires=[0, 1], params=2>, <QuantumTape: wires=[0, 1], params=2>] and kwargs={}
  doing some stuff in the wrapper
   this qubit device has prng_key=Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
   executing the circuit
     inside circuit execute for self=<DefaultQubitJax device (wires=2, shots=1000) at 0x1342ecdf0>.circuit=<QuantumTape: wires=[0, 1], params=2>
     generating samples
      - Sampling basis states for a jax qubit device with self._prng_key=Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
      - IF THE PRNG KEY ABOVE IS A TRACER, THIS WILL CRASH!
ERROR:absl:Outside call <jax.experimental.host_callback._CallbackWrapper object at 0x134312230> threw exception Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.

The call chain at the point of the crash is the following:

In this function we use the ._prng_key to execute some JAX random functions. But, as I said before, this is all being executed inside a callback, so there should be no tracers there! Instead, because the device was captured in some lambdas, its PRNG key is a tracer, which leads to the crash.

Possible solution

The solution is to pass the PRNG key as an argument to the callback. In a sense, you'd need to do something similar to cp_tape, but for the device's PRNG key.

However, this seems complicated because you are not passing the device itself as an argument to those functions; instead, it is captured inside lambdas (I think). Maybe someone more familiar with PennyLane, such as @antalszava, knows how to do this?
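A minimal sketch of the proposed fix, under similar assumptions (jax.pure_callback in place of host_callback.call, hypothetical helper names, a single-qubit RY toy instead of a QNode): the key is passed to the callback as an operand, exactly like the parameters, so the host function only ever sees concrete arrays. Seeding a NumPy generator from the key on the host is an illustrative choice that keeps JAX calls out of the callback.

```python
import jax
import jax.numpy as jnp
import numpy as np

SHOTS = 1000


def _host_expval_z(prng_key, p0):
    """Hypothetical host-side sampler: receives concrete NumPy arrays only."""
    seed = np.asarray(prng_key).astype(np.uint64).tolist()  # uint32[2] -> ints
    rng = np.random.default_rng(seed)
    p0 = float(np.asarray(p0))
    outcomes = rng.choice([1.0, -1.0], size=SHOTS, p=[p0, 1.0 - p0])
    return np.asarray(outcomes.mean(), dtype=np.float32)


@jax.jit
def expval_z(theta, prng_key):
    # Probability of measuring |0> after RY(theta)|0>.
    p0 = jnp.cos(theta / 2) ** 2
    out_type = jax.ShapeDtypeStruct((), jnp.float32)
    # The key is a callback operand, just like the parameters.
    return jax.pure_callback(_host_expval_z, out_type, prng_key, p0)


key = jax.random.PRNGKey(0)
print(expval_z(0.3, key), expval_z(0.3, key))  # same key, same estimate
```

Because the key is traced like any other argument, passing a new key does not trigger recompilation.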

Source code

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

import pennylane as qml

phys_qubits = 2
pars_q      = np.random.rand(3)

def minimal_circ(params, prng_key=None):
    if prng_key is not None:
        dev = qml.device("default.qubit.jax", wires=tuple(range(phys_qubits)), shots=1000, prng_key=prng_key)
    else:
        dev = qml.device("default.qubit.jax", wires=tuple(range(phys_qubits)), shots=1000)
    @qml.qnode(dev, interface="jax",diff_method="parameter-shift")
    def _measure_operator():
        qml.RY(params[0],wires=0)
        qml.RY(params[1],wires=1)
        op = qml.Hamiltonian([1.0],[qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(op)
    res = _measure_operator()
    return res

grad_fun = jax.grad(minimal_circ)

jax.jit(grad_fun)(pars_q, jax.random.PRNGKey(0))

System information

>>> import pennylane as qml; qml.about()
Name: PennyLane
Version: 0.27.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/filippovicentini/Documents/pythonenvs/dev-pennylane/python-3.10.7/lib/python3.10/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic-version, toml
Required-by: PennyLane-Lightning

Platform info:           macOS-13.0-x86_64-i386-64bit
Python version:          3.10.7
Numpy version:           1.23.4
Scipy version:           1.9.3
Installed devices:
- default.gaussian (PennyLane-0.27.0.dev0)
- default.mixed (PennyLane-0.27.0.dev0)
- default.qubit (PennyLane-0.27.0.dev0)
- default.qubit.autograd (PennyLane-0.27.0.dev0)
- default.qubit.jax (PennyLane-0.27.0.dev0)
- default.qubit.tf (PennyLane-0.27.0.dev0)
- default.qubit.torch (PennyLane-0.27.0.dev0)
- default.qutrit (PennyLane-0.27.0.dev0)
- null.qubit (PennyLane-0.27.0.dev0)
- lightning.qubit (PennyLane-Lightning-0.26.1)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@CatalinaAlbornoz
Contributor

Hi @PhilipVinc, thank you for opening this issue. Can you please confirm whether you get this issue with PennyLane v0.26.0 too? Or only with the dev version?

@PhilipVinc
Contributor Author

Hi @CatalinaAlbornoz. I just tried, and the issue also exists with 0.26.0. The diagnosis is the same: the host_callback should be fed the prng_key as an argument.

@antalszava
Contributor

Hi @PhilipVinc, thank you for the report and comments! 🙂 We'll be looking into this and come back with our findings shortly.

@PhilipVinc
Contributor Author

Thank you, actually!

If you would like an opinion, or to discuss some of those JAX-related mysteries more interactively on a call, feel free to drop me an email.
This issue is currently blocking the last section of a paper we're writing, so I have a strong interest in giving any assistance you might need to address it (cc @co9olguy).

@antalszava
Contributor

antalszava commented Oct 28, 2022

Just went through the description and tried the example myself. I agree with the points: as the key is not passed in, a tracer leaks when jitting, which leads to a leaked-tracer error.

This is definitely a byproduct of the design in PennyLane and likely requires a major design change, because the pipeline described originally is a Device API that is used by (almost) all devices in our ecosystem. Having said that, I'll continue the investigation and try to come up with a solution that could benefit this use case.

On the side, one question I have: why would we like the PRNGKey to be passed to the function minimal_circ? Is it to follow the pure functional perspective of JAX?

Would the following (executable) solution still work, where we pass dev to minimal_circ and mark it as a static argument?

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

import pennylane as qml

phys_qubits = 2
pars_q      = np.random.rand(3)

def minimal_circ(params, dev):
    
    @qml.qnode(dev, interface="jax",diff_method="parameter-shift")
    def _measure_operator():
        qml.RY(params[0],wires=0)
        qml.RY(params[1],wires=1)
        op = qml.Hamiltonian([1.0],[qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(op)
    res = _measure_operator()
    return res

grad_fun = jax.grad(minimal_circ)

prng_key = jax.random.PRNGKey(0)
dev = qml.device("default.qubit.jax", wires=tuple(range(phys_qubits)), shots=1000, prng_key=prng_key)
jax.jit(grad_fun, static_argnums=[1])(pars_q, dev)
DeviceArray([-0.045, -0.144,  0.   ], dtype=float64)

@antalszava
Contributor

If you would like an opinion, or to discuss some of those JAX-related mysteries more interactively on a call, feel free to drop me an email.

Appreciate this a lot, definitely keen to hear more about what we could improve! 🙏

Notice that the callback is essentially executing python code, so this code is not jitted itself (question: do you really need a callback here?), therefore if you have tracers anywhere at this point, they will lead to crashes.

To answer the question in the brackets: though the callback calls Python code, it encapsulates the quantum device execution which may happen using a remote simulator/a remote QPU. That's the main motivation behind our use of host_callback.call: everything before and after the quantum device execution can then be jitted. Here too, I'm keen to hear whether there is room for improvement.

@PhilipVinc
Contributor Author

Yes, what you propose would work, in principle.

However, JAX retraces/recompiles every time you change some static information, which it detects through the hash of the static data. If you feed in different devices with different PRNG seeds, you will recompile the code a lot.

In my use case, where I have a hybrid structure coupling a neural network and a quantum circuit, recompiling leads to very, very large increases in computational time (at least when not using shots).

As a side note, to make this work, you'd need to compute the hash (and equality) of your devices correctly from the static data they contain (like the PRNG key), so that if a user changes the PRNG key in the device, the hash changes and JAX recompiles.
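To illustrate both points with a toy, here is a sketch assuming a hypothetical StaticDeviceConfig class (not a PennyLane device) passed through static_argnums. JAX looks up compiled functions by the hash and equality of static arguments, so the class derives both from its contents, including the seed. A trace counter shows that an equal configuration reuses the cached compilation, while every new seed forces a retrace and recompile.

```python
import jax
import jax.numpy as jnp


class StaticDeviceConfig:
    """Hypothetical hashable stand-in for a device passed via static_argnums."""

    def __init__(self, shots, seed):
        self.shots = shots
        self.seed = seed

    def _fields(self):
        # Everything that should trigger a recompile when it changes.
        return (self.shots, self.seed)

    def __hash__(self):
        return hash(self._fields())

    def __eq__(self, other):
        return isinstance(other, StaticDeviceConfig) and self._fields() == other._fields()


n_traces = 0


def noisy_sum(x, cfg):
    global n_traces
    n_traces += 1  # Python side effect: only runs while JAX is tracing
    key = jax.random.PRNGKey(cfg.seed)
    noise = jax.random.normal(key, x.shape) / jnp.sqrt(cfg.shots)
    return jnp.sum(x + noise)


noisy_sum_jit = jax.jit(noisy_sum, static_argnums=1)
x = jnp.ones(3)
noisy_sum_jit(x, StaticDeviceConfig(1000, 0))
noisy_sum_jit(x, StaticDeviceConfig(1000, 0))  # equal config: cache hit
noisy_sum_jit(x, StaticDeviceConfig(1000, 1))  # new seed: retrace + recompile
print(n_traces)  # 2
```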

@PhilipVinc
Contributor Author

To answer the question in the brackets: though the callback calls Python code, it encapsulates the quantum device execution which may happen using a remote simulator/a remote QPU.

Yeah, this definitely makes sense. Though for the particular case of a local jax device you could drop it; in any case, I agree with your analysis.

because the pipeline described originally is a Device API that is used by (almost) all devices in our ecosystem

I'm not sure I understand what is limiting here, probably also because I fail to see exactly where the device is captured in your execute_fn.

My very uninformed understanding is that you are taking the object-oriented/Pythonic approach of passing around the method of an object, which implicitly (and somewhat opaquely) captures the underlying instance.

The standard way to do this in functional programming would be to split the functions from the data structure, so that you are, in a sense, obliged to pass the data structure as an argument. JAX likes that because it can do its tracer magic on the arguments.
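A sketch of that split, assuming a hypothetical DeviceData NamedTuple that holds the device's data (here only the PRNG key) and a pure function that receives it explicitly. NamedTuples are JAX pytrees, so the key inside is traced rather than treated as static, and changing it does not retrace. The single-qubit RY estimator is a stand-in for circuit execution.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class DeviceData(NamedTuple):
    """Hypothetical 'data half' of a device; NamedTuples are JAX pytrees."""
    prng_key: jnp.ndarray


n_traces = 0


def expval_z_sampled(data, theta, shots):
    """Pure function: every input, including the key, is an explicit argument."""
    global n_traces
    n_traces += 1  # counts traces, not calls
    p0 = jnp.cos(theta / 2) ** 2  # P(|0>) after RY(theta)|0>
    is_zero = jax.random.bernoulli(data.prng_key, p0, (shots,))
    return jnp.mean(jnp.where(is_zero, 1.0, -1.0))


f = jax.jit(expval_z_sampled, static_argnums=2)
for seed in range(3):
    f(DeviceData(jax.random.PRNGKey(seed)), 0.3, 1000)
print(n_traces)  # 1: new keys are new traced values, not new static data
```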

@antalszava
Contributor

Though for the particular case of a local jax device you could drop it

This is a great point! We can leverage this specifically for default.qubit.jax.

As for the other comments: yes, the OOP approach and the more implicit pipeline are disadvantageous in this specific case. While PennyLane does follow functional approaches, parts of it are definitely not purely functional.

After some more local exploration, a fix seems doable, and we'll focus on getting it into the code base as soon as we can. We have a release coming up soon (v0.27.0); having the fix in it doesn't seem out of reach, and that would make the feature available in a stable release.

I'll be commenting on the progress as this work moves along. 👍

@antalszava
Contributor

@PhilipVinc was there a specific reason for using the default.qubit.jax device? Although written in JAX natively, our original thinking for that device has been to be used with diff_method="backprop".

Switching to the C++-based lightning.qubit device and turning off caching (passing cache=None to the QNode), we can use the parameter-shift rule to get jax.jit to work with finite shots:

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

import pennylane as qml

phys_qubits = 2
pars_q      = np.random.rand(3)

def minimal_circ(params):
    dev = qml.device("lightning.qubit", wires=tuple(range(phys_qubits)), shots=100)
    
    @qml.qnode(dev, interface="jax-jit",diff_method="parameter-shift", cache=None)
    def _measure_operator():
        qml.RY(params[0],wires=0)
        qml.RY(params[1],wires=1)
        op = qml.Hamiltonian([1.0],[qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(op)
    res = _measure_operator()
    return res

grad_fun = jax.grad(minimal_circ)

fun = jax.jit(grad_fun)
for _ in range(5):
    print(fun(pars_q))
[-0.81 -0.01  0.  ]
[-0.79  0.08  0.  ]
[-0.77  0.23  0.  ]
[-0.79 -0.03  0.  ]
[-0.81 -0.09  0.  ]

On top of this, it could be interesting to see how the originally suggested default.qubit.jax device with a PRNGKey performs compared to the above example. However, it doesn't seem to unlock a completely different use case; rather, it would do the computation completely in JAX from start to end.
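For reference, here is what doing the computation completely in JAX can look like for the two-qubit MWE circuit, as a sketch only: the basis-state probabilities of RY(a) x RY(b) applied to |00> are written out by hand rather than simulated by default.qubit.jax, and the PRNG key is an ordinary argument. The exact value to compare against is <Z0 Z1> = cos(a) cos(b). Sampling with jax.random.choice is not differentiable with respect to the probabilities, so this covers the forward pass only.

```python
import jax
import jax.numpy as jnp


def zz_probabilities(params):
    """Probabilities of |00>, |01>, |10>, |11> for RY(a) x RY(b) on |00>."""
    a, b = params[0], params[1]
    q0 = jnp.stack([jnp.cos(a / 2) ** 2, jnp.sin(a / 2) ** 2])
    q1 = jnp.stack([jnp.cos(b / 2) ** 2, jnp.sin(b / 2) ** 2])
    return jnp.outer(q0, q1).ravel()  # index = 2 * bit0 + bit1


def expval_zz(params, prng_key, shots):
    probs = zz_probabilities(params)
    samples = jax.random.choice(prng_key, 4, shape=(shots,), p=probs)
    bit0, bit1 = samples // 2, samples % 2
    eigvals = 1.0 - 2.0 * ((bit0 + bit1) % 2)  # (-1) ** (bit0 + bit1)
    return jnp.mean(eigvals)


expval_zz_jit = jax.jit(expval_zz, static_argnums=2)

params = jnp.array([0.4, 0.9])
key = jax.random.PRNGKey(0)
print(expval_zz_jit(params, key, 10_000), jnp.cos(0.4) * jnp.cos(0.9))
```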

@PhilipVinc
Contributor Author

@PhilipVinc was there a specific reason for using the default.qubit.jax device? Although written in JAX natively, our original thinking for that device has been to be used with diff_method="backprop".

Not really, no. I think I used it because at first I (erroneously) thought I could not mix and match different devices and qml.qnode interfaces, and then it stuck. No other reason.

On top of this, it could be interesting to see how the originally suggested default.qubit.jax device with a PRNGKey performs compared to the above example. However, it doesn't seem to unlock a completely different use case; rather, it would do the computation completely in JAX from start to end.

Probably worse. JAX is especially bad at using more than one or two cores on CPU (I think its BLAS implementation is particularly conservative before switching to multi-threading), and I wouldn't be surprised if any purpose-written C kernel could beat XLA (JAX's compiler) when applying gates...

@PhilipVinc
Contributor Author

Thanks for the snippet!

I'll surely try this out. Just to understand... how will the RNG seed work in that case? Is it using some internal state that gets updated every time it calls back into Python/Lightning?

@antalszava
Contributor

Just to understand... how will the RNG seed work in that case? Is it using some internal state that gets updated every time it calls back into Python/Lightning?

The sampling (including the RNG seed generation) is entirely encompassed in the function that is invoked by host_callback.call.

The function passed to host_callback.call is a wrapper around the execute_fn callable argument, which mainly uses the QubitDevice.batch_execute method. QubitDevice.batch_execute is device-agnostic: under the hood, it runs the individual steps of device execution, including how the samples are generated (the QubitDevice.generate_samples method, which may be overridden). Specifically, default.qubit.jax requires the RNG seed to generate differing samples with jax.jit and redefines the sample_basis_states method used by generate_samples. Other devices, however, have their own sampling mechanisms; for example, PennyLane-Lightning redefines generate_samples entirely.
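As a rough illustration of key-driven sampling, the sketch below mirrors the method names mentioned above, but its signatures and bodies are assumptions, not PennyLane's implementation. Basis-state indices are drawn from the probability vector with jax.random.choice and then converted to one row of bits per shot, with wire 0 as the most significant bit.

```python
import jax
import jax.numpy as jnp


def sample_basis_states(prng_key, state_probability, shots):
    """Draw basis-state indices according to their probabilities."""
    number_of_states = state_probability.shape[0]
    return jax.random.choice(prng_key, number_of_states, shape=(shots,), p=state_probability)


def states_to_binary(samples, num_wires):
    """Convert basis-state indices to bit rows, wire 0 as the most significant bit."""
    powers = 2 ** jnp.arange(num_wires - 1, -1, -1)
    return (samples[:, None] // powers) % 2


def generate_samples(prng_key, state_probability, num_wires, shots):
    indices = sample_basis_states(prng_key, state_probability, shots)
    return states_to_binary(indices, num_wires)


probs = jnp.array([0.0, 0.0, 0.0, 1.0])  # the state |11>
bits = generate_samples(jax.random.PRNGKey(0), probs, num_wires=2, shots=5)
print(bits)  # every row is [1 1]
```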

@PhilipVinc
Contributor Author

@antalszava thanks a lot for the snippet, it indeed pushes us forward!
However, I had only provided you with an MWE. What we actually want to do is compute the Jacobian of this expectation value, which means taking the vmap of grad.

To get to the bottom of what we'd need, here is a more complicated MWE that breaks down once we start playing with vmap.

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np
import pennylane as qml

phys_qubits = 2
n_configs = 5
pars_q      = np.random.rand(n_configs,2)

def minimal_circ(params):
    dev = qml.device("lightning.qubit", wires=tuple(range(phys_qubits)), shots=100)
    
    @qml.qnode(dev, interface="jax-jit",diff_method="parameter-shift", cache=None)
    def _measure_operator():
        qml.RY(params[0],wires=0)
        qml.RY(params[1],wires=1)
        op = qml.Hamiltonian([1.0],[qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(op)
    res = _measure_operator()
    return res

# this works
jax.jit(minimal_circ)(pars_q[0])
jax.jit(jax.grad(minimal_circ))(pars_q[0])

# Vmapping the circuit does not work
minimal_circ_batch = jax.vmap(minimal_circ)
minimal_circ_batch(pars_q)

# Getting the jacobian (aka, vmap of grad) does not work
jax.vmap(jax.grad(minimal_circ))(pars_q)

In short, we want to run the same circuit for several different parameter sets and compute the gradient for each of them. At least for the forward pass, I'd expect your lightning device to support batching (vmap, in JAX parlance) and speed up the calculation. I'm unsure whether that would also work for the backward pass...

Of course, host_callback does not support vmapping. If your underlying C code does support batching, you could transition to jax.pure_callback, which supports it.

A note: your lightning execute function is not functionally pure, because it has a hidden RNG state for the shots...
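A generic contrast makes the purity point concrete (this is not Lightning's actual sampler; both functions are hypothetical). A sampler backed by a hidden NumPy Generator returns different results for identical inputs, whereas one that takes a JAX key is a pure function of its arguments; fresh randomness then comes from explicitly splitting the key.

```python
import jax
import jax.numpy as jnp
import numpy as np

_hidden_rng = np.random.default_rng(0)  # hidden, mutable module-level state


def impure_estimate(p0, shots=100):
    """Result depends on _hidden_rng's state, not only on the arguments."""
    return _hidden_rng.binomial(shots, p0) / shots


def pure_estimate(prng_key, p0, shots=100):
    """Result depends only on the arguments, including the key."""
    return jnp.mean(jax.random.bernoulli(prng_key, p0, (shots,)))


print(impure_estimate(0.5), impure_estimate(0.5))  # generally differ

key = jax.random.PRNGKey(42)
print(pure_estimate(key, 0.5) == pure_estimate(key, 0.5))  # True
key, k1, k2 = jax.random.split(key, 3)
print(pure_estimate(k1, 0.5), pure_estimate(k2, 0.5))  # independent draws
```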

@PhilipVinc
Contributor Author

PhilipVinc commented Oct 31, 2022

OK, I just noticed that you implicitly support parameter broadcasting by adding leading dimensions, so that

minimal_circ(pars_q)

works and does what I wanted to do with

jax.vmap(minimal_circ)(pars_q)

So I'm even more confident that jax.pure_callback(..., vectorized=True) in place of host_callback.call should work for you.

About the gradient... It's going to be a bit more tricky.

Note: if someone wonders why I would want to use vmap when your QNodes already vectorise, the code above is an MWE. In my actual use case, I mix this with JAX code that I need to vmap.

@antalszava
Contributor

Hi @PhilipVinc, the behaviour with jax.jacobian is a known limitation of the JAX-JIT interface, and there are related discussion points in #2163 too. Our plan to allow jax.jacobian to work with jax.jit is to update our custom gradient recipe (this will also mean a shift from custom_vjp to custom_jvp, as you may have come across in #3235).

The use of jax.pure_callback could be investigated for a shorter-term fix.

@PhilipVinc
Contributor Author

@antalszava changing all host callbacks to pure_callback(..., vectorized=False) gets vmap and jacobian to work with no further effort!

Performance will be sub-optimal, though, because it inserts a loop. But I think it's possible to make it work without switching to JVPs...

@antalszava
Contributor

You're right! 🎉 Tried it in #3244. We'll check how JAX's vectorized version could be wired in with PennyLane's parameter broadcasting.

As for the JVPs: it's a change we've been contemplating anyway, because it allows both jacrev and jacfwd to work.
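A toy illustration of the difference, using a plain sin function rather than a QNode: with custom_jvp, jax.jacfwd works directly and jax.jacrev works because JAX transposes the (linear) JVP rule, whereas a custom_vjp function only supports reverse mode, and jax.jacfwd raises a TypeError.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f_jvp(x):
    return jnp.sin(x)


@f_jvp.defjvp
def _f_jvp_rule(primals, tangents):
    (x,), (t,) = primals, tangents
    return jnp.sin(x), jnp.cos(x) * t  # tangent output is linear in t


@jax.custom_vjp
def f_vjp(x):
    return jnp.sin(x)


def _f_fwd(x):
    return jnp.sin(x), x  # save x as the residual


def _f_bwd(x, g):
    return (jnp.cos(x) * g,)


f_vjp.defvjp(_f_fwd, _f_bwd)

x = jnp.array([0.1, 0.2])
print(jax.jacfwd(f_jvp)(x))  # works
print(jax.jacrev(f_jvp)(x))  # works: JAX transposes the linear JVP rule
print(jax.jacrev(f_vjp)(x))  # works

try:
    jax.jacfwd(f_vjp)(x)
    jacfwd_failed = False
except TypeError:  # forward mode is not defined for custom_vjp functions
    jacfwd_failed = True
print("jacfwd on custom_vjp failed:", jacfwd_failed)
```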

@PhilipVinc
Contributor Author

Can I open an issue to properly track vectorising the grad call when using the JAX interface?

@antalszava
Contributor

For sure! Just wanted to leave a comment here, mentioning that although this issue is being closed, we'd like to track the improvements we discussed.

@antalszava
Contributor

@PhilipVinc could you also open an issue describing the improvements we could make to the parameter broadcasting UI?

@PhilipVinc
Contributor Author

Yes, it's on my to-do list. I have a deadline on Monday, so my ETA for phrasing something decently is 10 days.

And thanks for implementing this. Being able to work with JAX without worrying too much about workarounds is very helpful for us.

@antalszava
Contributor

For sure! Sounds good. 👍

I've opened an issue to Skip using a callback for default.qubit.jax and diff_method="parameter-shift":
#3259

3 participants