-
Notifications
You must be signed in to change notification settings - Fork 575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow returning the state #818
Conversation
Codecov Report
@@ Coverage Diff @@
## master #818 +/- ##
==========================================
+ Coverage 90.92% 90.95% +0.03%
==========================================
Files 129 129
Lines 8548 8584 +36
==========================================
+ Hits 7772 7808 +36
Misses 776 776
Continue to review full report at Codecov.
|
.github/CHANGELOG.md
Outdated
qml.CNOT(wires=[0, 1]) | ||
qml.RY(y, wires=1) | ||
qml.CNOT(wires=[0, 2]) | ||
return state(range(3)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a thought here, we could even make this implement full state tomography in future to work on hardware devices...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would be cool! I guess that could be done in a state()
method or attribute of the device
doc/introduction/measurements.rst
Outdated
|
||
Support for returning the quantum state of the QNode is also provided. Similar to the | ||
:func:`~.pennylane.probs` measurement function, **observables should not be input** into the | ||
:func:`~.state` function. Moreover, the state must be returned over **all** wires in the device |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is obviously a bit strange from a user perspective: An argument that one has to set to a very particular value without any options. You mention that otherwise there were side effects, is it not better to deal with them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I know this is usually easier said than done :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the state is returned over all wires, why does the user need to even provide any arguments? 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, now the user no longer needs to provide wires: doing return state()
is enough.
If you can think of any edge cases of doing this, please let me know!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we didn't require state() to have a wires argument, but this led to some side effects
Would any of the previous side effects be relevant here too?
if self.interface is not None: | ||
# pylint: disable=protected-access | ||
if self.qtape._returns_state and self.interface in ["torch", "tf"]: | ||
self.INTERFACE_MAP[self.interface](self, dtype=np.complex128) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to understand, why is that necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The state is complex, and we need to pass this on to the to_tf()
and to_torch()
functions, which then use the interface apply()
methods with the corresponding dtype. If we didn't do this, the state would be cast to reals.
I also added a comment just above this line to help.
pennylane/_qubit_device.py
Outdated
state = self.state | ||
except AttributeError: | ||
state = None | ||
|
||
if state is None: | ||
raise AttributeError("The state is not available in the current device") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid the awkwardness of doing it this way, you could do someting like getattr(self, "state", None
)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like it! Have done this.
Co-authored-by: antalszava <antalszava@gmail.com>
Hi @trbromley, had a high-level review here. Looking great from my side. Leaving a comment here, as didn't delve into more details and saw that there might be an open question regarding wires. Let me know if it would be good to have another look though! |
Thanks @co9olguy good point regarding wires.
Yes, I managed to find a case where the ordering alters the state in Cirq, see below. Rather than being due to custom wire labels, I think this is linked to the optional dev_wires = [w.tolist()[0] for w in self.wire_map.values()]
state = state.reshape([2] * self.num_wires).transpose(dev_wires).ravel() This transposes the state back into the order a user might expect and then agrees with
No I don't think we should restrict to core devices in principle. However, I think we should get plugin devices to overwrite the Another thing to note is that users wouldn't currently be able to return the state in Cirq until we update the capabilities. Hence, I'd recommend we do this as a follow up in the Cirq plugin. Additionally, we may want to add details to the building a plugin page. Here is the example: import pennylane as qml
from pennylane.beta.queuing import state, probs
from pennylane.beta.tapes import QNode
import cirq
import itertools
from pennylane import numpy as np
qubits = [
cirq.GridQubit(0, 0),
cirq.GridQubit(0, 1),
cirq.GridQubit(1, 0),
]
for q in itertools.permutations(qubits):
dev_circ = qml.device("cirq.simulator", wires=3, qubits=q)
dev_def = qml.device("default.qubit", wires=3)
def my_quantum_function(x, y):
qml.RZ(x, wires=0)
qml.CNOT(wires=[0, 1])
qml.RY(y, wires=1)
qml.CNOT(wires=[0, 2])
return state()
qnode_circ = QNode(my_quantum_function, dev_circ)
qnode_def = QNode(my_quantum_function, dev_def)
res_circ = qnode_circ(0.56, 0.1)
res_def = qnode_def(0.56, 0.1)
print(np.allclose(res_circ, res_def)) |
Thanks @antalszava for the review! |
return state() | ||
|
||
state_ev = func() | ||
assert np.allclose(state_ev, dev.state) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thought: some of these tests would be nice to be added to the shared device tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! I also want to add some tests to the shared suite that compare a given device to the results from default.qubit
(we did this in lightning.qubit
). Seems a good way to make sure everything is on a par. Will add to my list of ideas.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Much better without wires argument :)
Great!
if self.wires.labels != tuple(range(self.num_wires)): | ||
raise QuantumFunctionError( | ||
"Returning the state is not supported when using " "custom wire labels" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@trbromley and I were discussing and decided there are too many possible ambiguities here. This solution gives us the possibility to reach the vast majority of users while deferring unlikely but difficult edge cases to be further supported later on
pennylane/_qubit_device.py
Outdated
QuantumFunctionError: if the device is not capable of returning the state | ||
|
||
Returns: | ||
array: the state of the device |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be array or tensor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work @trbromley!
pennylane/beta/queuing/measure.py
Outdated
raise ValueError("Cannot set the wires if an observable is provided.") | ||
|
||
self._wires = wires | ||
if wires is not None and obs is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Checking whether wires
is None
is no longer needed (from above code, it can never be)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point! I've moved this check up a few lines so that it is still relevant.
@@ -356,10 +357,14 @@ def construct(self, args, kwargs): | |||
"or a nonempty sequence of measurements." | |||
) | |||
|
|||
state_returns = any([m.return_type is State for m in measurement_processes]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this feels like a better way to do this 👍
"Version 1.6.0 or above of PyTorch must be installed" | ||
"for complex support, such as returning the state" | ||
) | ||
self.dtype = torch.complex128 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a biggie, but would we want to support other complex number types?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's not needed for this state PR - we are choosing np.complex128
for the state dtype earlier on and then just converting here.
The other use case might be e.g. if a user passes np.complex64
, but I don't think the to_torch()
function is very visible.
As a side note, I hope Torch introduce an analogue of tf.as_dtype(), which would make this easier.
@@ -1284,6 +1279,9 @@ def jacobian(self, device, params=None, **options): | |||
>>> tape.jacobian(dev) | |||
array([], shape=(4, 0), dtype=float64) | |||
""" | |||
if any([m.return_type is State for m in self.measurements]): | |||
raise ValueError("The jacobian method does not support circuits that return the state") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One day it might be cool to consider this 😆
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes! We already have done some of the groundwork for the parameter shift rule.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something to keep in mind as we evolve the parameter-shift rule --- allowing gate grad recipes to be altered based on the circuit output.
state = func() | ||
assert state.shape == (1, 2 ** wires) | ||
assert state.dtype == np.complex128 | ||
state_ev = func() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How come ev
? It's not an expectation value 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
state_val
perhaps better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed! I was going for "ev"aluated, but yeah state_val
looks better
This is a new version of #737 and #789.
This PR allows users to return the quantum state using a new
state()
function:Not currently implemented:
jacobian()
method, although it is possible using a PassthruQNodeexecute()
method inDevice
to return the state, similar to inQubitDevice
). However, there is a greater variety in the objects stored indev.state
for CV devices. Thedefault.gaussian
device stores a tuple of cov and means indev._state
, while the PL-SF devices return a state object. These are not immediately compatible withDevice
, which is expecting something it can convert into an array.Things to note:
The circuit drawer could not be tested as it is not yet available in the new core
The state is still analytic when the device is in
non_analytic
mode. I think this makes sense, but just wanted to make a note of it.