Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow returning the state #818

Merged
merged 55 commits into from
Sep 24, 2020
Merged

Allow returning the state #818

merged 55 commits into from
Sep 24, 2020

Conversation

trbromley
Copy link
Contributor

@trbromley trbromley commented Sep 21, 2020

This is a new version of #737 and #789.

This PR allows users to return the quantum state using a new state() function:

import pennylane as qml
from pennylane.beta.tapes import qnode
from pennylane.beta.queuing import state

dev = qml.device("default.qubit", wires=2)

@qnode(dev)
def f(x):
    qml.RX(x, wires=0)
    qml.RY(0.4, wires=1)
    qml.CNOT(wires=[0, 1])
    return state()

f(0.3)

Not currently implemented:

  • Support for differentiating the state using the jacobian() method, although it is possible using a PassthruQNode
  • Support for returning the state in CV devices. It is a simple change to support in principle (e.g., edit the execute() method in Device to return the state, similar to in QubitDevice). However, there is a greater variety in the objects stored in dev.state for CV devices. The default.gaussian device stores a tuple of cov and means in dev._state, while the PL-SF devices return a state object. These are not immediately compatible with Device, which is expecting something it can convert into an array.
  • Support for returning the state in versions of torch before 1.6.0 (since complex numbers are not supported)

Things to note:

  • The circuit drawer could not be tested as it is not yet available in the new core

  • The state is still analytic when the device is in non_analytic mode. I think this makes sense, but just wanted to make a note of it.

@trbromley trbromley self-assigned this Sep 21, 2020
@codecov
Copy link

codecov bot commented Sep 21, 2020

Codecov Report

Merging #818 into master will increase coverage by 0.03%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #818      +/-   ##
==========================================
+ Coverage   90.92%   90.95%   +0.03%     
==========================================
  Files         129      129              
  Lines        8548     8584      +36     
==========================================
+ Hits         7772     7808      +36     
  Misses        776      776              
Impacted Files Coverage Δ
pennylane/_device.py 95.85% <100.00%> (+0.04%) ⬆️
pennylane/_qubit_device.py 98.75% <100.00%> (+0.10%) ⬆️
pennylane/beta/interfaces/torch.py 100.00% <100.00%> (ø)
pennylane/beta/queuing/__init__.py 100.00% <100.00%> (ø)
pennylane/beta/queuing/measure.py 96.92% <100.00%> (+0.04%) ⬆️
pennylane/beta/tapes/qnode.py 98.59% <100.00%> (+0.06%) ⬆️
pennylane/beta/tapes/tape.py 98.76% <100.00%> (+0.01%) ⬆️
...ennylane/circuit_drawer/representation_resolver.py 99.29% <100.00%> (+0.01%) ⬆️
pennylane/operation.py 54.36% <100.00%> (+0.22%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 58fda45...184f633. Read the comment docs.

@trbromley trbromley requested review from co9olguy, antalszava, thisac and mariaschuld and removed request for thisac September 21, 2020 16:16
qml.CNOT(wires=[0, 1])
qml.RY(y, wires=1)
qml.CNOT(wires=[0, 2])
return state(range(3))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a thought here, we could even make this implement full state tomography in future to work on hardware devices...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be cool! I guess that could be done in a state() method or attribute of the device


Support for returning the quantum state of the QNode is also provided. Similar to the
:func:`~.pennylane.probs` measurement function, **observables should not be input** into the
:func:`~.state` function. Moreover, the state must be returned over **all** wires in the device
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is obviously a bit strange from a user perspective: An argument that one has to set to a very particular value without any options. You mention that otherwise there were side effects, is it not better to deal with them?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I know this is usually easier said than done :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the state is returned over all wires, why does the user need to even provide any arguments? 🤔

Copy link
Contributor Author

@trbromley trbromley Sep 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, now the user no longer needs to provide wires: doing return state() is enough.

If you can think of any edge cases of doing this, please let me know!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we didn't require state() to have a wires argument, but this led to some side effects

Would any of the previous side effects be relevant here too?

if self.interface is not None:
# pylint: disable=protected-access
if self.qtape._returns_state and self.interface in ["torch", "tf"]:
self.INTERFACE_MAP[self.interface](self, dtype=np.complex128)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to understand, why is that necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The state is complex, and we need to pass this on to the to_tf() and to_torch() functions, which then use the interface apply() methods with the corresponding dtype. If we didn't do this, the state would be cast to reals.

I also added a comment just above this line to help.

Comment on lines 277 to 282
state = self.state
except AttributeError:
state = None

if state is None:
raise AttributeError("The state is not available in the current device")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid the awkwardness of doing it this way, you could do someting like getattr(self, "state", None)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like it! Have done this.

trbromley and others added 2 commits September 23, 2020 11:14
@antalszava
Copy link
Contributor

Hi @trbromley, had a high-level review here. Looking great from my side. Leaving a comment here, as didn't delve into more details and saw that there might be an open question regarding wires. Let me know if it would be good to have another look though!

@trbromley
Copy link
Contributor Author

trbromley commented Sep 23, 2020

Thanks @co9olguy good point regarding wires.

The one to worry about is the Cirq plugin. When a user passes custom qubits, the indices of the wires do not necessary match the order they appear in the list of qubits. I can guarantee users will get confused (if they ever try to do it).

Yes, I managed to find a case where the ordering alters the state in Cirq, see below. Rather than being due to custom wire labels, I think this is linked to the optional qubits argument. As you mention, the device expresses the state following an internal order. One fix could be to extend the new access_state() method in QubitDevice (introduced in this PR to check if the state is available and then return it). At the end of this method, we could have:

dev_wires = [w.tolist()[0] for w in self.wire_map.values()]
state = state.reshape([2] * self.num_wires).transpose(dev_wires).ravel()

This transposes the state back into the order a user might expect and then agrees with default.qubit in the example below.

Is there a reason to only support this behaviour in "core" devices? Because if we want to support other plugins, we will have to think more carefully about this and potentially modify how it is done (either now or in the future)

No I don't think we should restrict to core devices in principle. However, I think we should get plugin devices to overwrite the access_state() method with machinery similar to above. This makes sense: devices will have to declare that they support returning the state in the capabilities dictionary, and also make sure that the state is in lexicographic order according to the user-input wires by editing the access_state() method.

Another thing to note is that users wouldn't currently be able to return the state in Cirq until we update the capabilities. Hence, I'd recommend we do this as a follow up in the Cirq plugin. Additionally, we may want to add details to the building a plugin page.

Here is the example:

import pennylane as qml
from pennylane.beta.queuing import state, probs
from pennylane.beta.tapes import QNode
import cirq
import itertools
from pennylane import numpy as np

qubits = [
  cirq.GridQubit(0, 0),
  cirq.GridQubit(0, 1),
  cirq.GridQubit(1, 0),
]

for q in itertools.permutations(qubits):
    dev_circ = qml.device("cirq.simulator", wires=3, qubits=q)
    dev_def = qml.device("default.qubit", wires=3)

    def my_quantum_function(x, y):
        qml.RZ(x, wires=0)
        qml.CNOT(wires=[0, 1])
        qml.RY(y, wires=1)
        qml.CNOT(wires=[0, 2])
        return state()

    qnode_circ = QNode(my_quantum_function, dev_circ)
    qnode_def = QNode(my_quantum_function, dev_def)

    res_circ = qnode_circ(0.56, 0.1)
    res_def = qnode_def(0.56, 0.1)
    print(np.allclose(res_circ, res_def))

@trbromley
Copy link
Contributor Author

Thanks @antalszava for the review!

return state()

state_ev = func()
assert np.allclose(state_ev, dev.state)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thought: some of these tests would be nice to be added to the shared device tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I also want to add some tests to the shared suite that compare a given device to the results from default.qubit (we did this in lightning.qubit). Seems a good way to make sure everything is on a par. Will add to my list of ideas.

Copy link
Contributor

@mariaschuld mariaschuld left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better without wires argument :)

Great!

if self.wires.labels != tuple(range(self.num_wires)):
raise QuantumFunctionError(
"Returning the state is not supported when using " "custom wire labels"
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@trbromley and I were discussing and decided there are too many possible ambiguities here. This solution gives us the possibility to reach the vast majority of users while deferring unlikely but difficult edge cases to be further supported later on

QuantumFunctionError: if the device is not capable of returning the state

Returns:
array: the state of the device
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be array or tensor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added, thanks!

Copy link
Member

@co9olguy co9olguy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work @trbromley!

raise ValueError("Cannot set the wires if an observable is provided.")

self._wires = wires
if wires is not None and obs is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking whether wires is None is no longer needed (from above code, it can never be)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! I've moved this check up a few lines so that it is still relevant.

@@ -356,10 +357,14 @@ def construct(self, args, kwargs):
"or a nonempty sequence of measurements."
)

state_returns = any([m.return_type is State for m in measurement_processes])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this feels like a better way to do this 👍

"Version 1.6.0 or above of PyTorch must be installed"
"for complex support, such as returning the state"
)
self.dtype = torch.complex128
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a biggie, but would we want to support other complex number types?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's not needed for this state PR - we are choosing np.complex128 for the state dtype earlier on and then just converting here.

The other use case might be e.g. if a user passes np.complex64, but I don't think the to_torch() function is very visible.

As a side note, I hope Torch introduce an analogue of tf.as_dtype(), which would make this easier.

@@ -1284,6 +1279,9 @@ def jacobian(self, device, params=None, **options):
>>> tape.jacobian(dev)
array([], shape=(4, 0), dtype=float64)
"""
if any([m.return_type is State for m in self.measurements]):
raise ValueError("The jacobian method does not support circuits that return the state")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One day it might be cool to consider this 😆

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! We already have done some of the groundwork for the parameter shift rule.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something to keep in mind as we evolve the parameter-shift rule --- allowing gate grad recipes to be altered based on the circuit output.

state = func()
assert state.shape == (1, 2 ** wires)
assert state.dtype == np.complex128
state_ev = func()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come ev? It's not an expectation value 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

state_val perhaps better?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed! I was going for "ev"aluated, but yeah state_val looks better

@trbromley trbromley merged commit 6b77015 into master Sep 24, 2020
@trbromley trbromley deleted the add_state_return_to_tape branch September 24, 2020 21:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants