[BUG] Jax compiled default.qubit.jax device raises ConversionError for qml.QubitStateVector #1670

Closed
bonfab opened this issue Sep 20, 2021 · 5 comments
Labels
bug 🐛 Something isn't working


@bonfab

bonfab commented Sep 20, 2021

Expected behavior

Setting a state vector with qml.QubitStateVector inside a jax.jit-compiled QNode on the default.qubit.jax device should work. However, the _apply_state_vector method does not appear to have been adapted for JAX-compiled code.

Actual behavior

Specifically, the following check in _apply_state_vector

if not np.allclose(np.linalg.norm(state, ord=2), 1.0, atol=tolerance):
    raise ValueError("Sum of amplitudes-squared does not equal one.")

computes the norm of the state vector with np.linalg.norm, which raises a jax._src.errors.TracerArrayConversionError when the QNode is compiled with jax.jit.

A possible fix would be to use the jax.numpy equivalent instead: jnp.linalg.norm.
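The conversion failure can be reproduced without PennyLane. Under jax.jit the function argument is an abstract tracer that carries a shape and dtype but no values, and np.linalg.norm tries to turn it into a NumPy array. A minimal sketch, assuming a JAX version that exposes jax.errors.TracerArrayConversionError; the helper names numpy_norm and numpy_norm_under_jit_fails are hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_norm(x):
    # np.linalg.norm calls np.asarray(x); under jit, x is a tracer with no concrete values
    return np.linalg.norm(x, ord=2)


def numpy_norm_under_jit_fails(x):
    """Return True if jitting the NumPy-based norm raises TracerArrayConversionError."""
    try:
        jax.jit(numpy_norm)(x)
    except jax.errors.TracerArrayConversionError:
        return True
    return False


state = jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])

print(numpy_norm_under_jit_fails(state))  # True

# The jax.numpy version operates on the tracer directly, so it compiles and runs
print(jax.jit(jnp.linalg.norm)(state))
```

This only covers the norm itself; the surrounding Python if statement is a separate concern under jit.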

Source code

import pennylane as qml
import jax
import numpy as np

def circuit(x):
    wires = list(range(2))
    qml.QubitStateVector(x, wires=wires)
    return [qml.expval(qml.PauliX(wires=i)) for i in wires]


dev = qml.device("default.qubit.jax", wires=list(range(2)))

qnode = jax.jit(qml.QNode(circuit, dev, interface="jax"))

state_vector = np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])

f_norm = jax.jit(jax.numpy.linalg.norm)  # works

# f_norm = jax.jit(np.linalg.norm)  # does not work, raises the same error

print(f_norm(state_vector))

qnode(state_vector)  # raises jax._src.errors.TracerArrayConversionError

System information

Name: PennyLane
Version: 0.17.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/fabian/.local/lib/python3.8/site-packages
Requires: appdirs, networkx, semantic-version, numpy, scipy, toml, autoray, autograd
Required-by: pennylane-qulacs, PennyLane-qiskit
Platform info:           Linux-5.11.0-34-generic-x86_64-with-glibc2.29
Python version:          3.8.10
Numpy version:           1.19.2
Scipy version:           1.7.1
Installed devices:
- default.gaussian (PennyLane-0.17.0)
- default.mixed (PennyLane-0.17.0)
- default.qubit (PennyLane-0.17.0)
- default.qubit.autograd (PennyLane-0.17.0)
- default.qubit.jax (PennyLane-0.17.0)
- default.qubit.tf (PennyLane-0.17.0)
- default.tensor (PennyLane-0.17.0)
- default.tensor.tf (PennyLane-0.17.0)
- qulacs.simulator (pennylane-qulacs-0.15.0)
- qiskit.aer (PennyLane-qiskit-0.17.0)
- qiskit.basicaer (PennyLane-qiskit-0.17.0)
- qiskit.ibmq (PennyLane-qiskit-0.17.0)

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@bonfab bonfab added the bug 🐛 Something isn't working label Sep 20, 2021
@CatalinaAlbornoz
Contributor

Hi @bonfab! Thank you for reporting this bug. We'll get on it and try to fix it!

@josh146
Member

josh146 commented Sep 21, 2021

One approach to fix this could be to make sure that the qml.math.linalg.norm and qml.math.allclose functions both work with the JAX jit. Once this is the case, we can modify this default.qubit method to use these functions instead.

@bonfab
Author

bonfab commented Sep 21, 2021

I realized it might not be as trivial to fix as I first thought. Even after adapting the source code to use jax.numpy, one still gets a ConcretizationTypeError. The only workaround that works for me at the moment is to comment out the check completely.
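The ConcretizationTypeError comes from the Python if itself: jnp.linalg.norm and jnp.allclose trace fine, but their result is a traced boolean, and an if statement needs a concrete True or False at trace time. The sketch below reproduces that and then shows one jit-compatible alternative, jax.experimental.checkify, which records a failed check as an error value instead of branching in Python. It assumes a JAX version that ships checkify, it is not necessarily how #1683 resolved the issue, and the function names and the 1e-6 tolerance are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

TOLERANCE = 1e-6  # hypothetical tolerance for this sketch


def check_with_python_if(state):
    # jnp functions trace fine, but `if` needs a concrete bool at trace time
    if not jnp.allclose(jnp.linalg.norm(state, ord=2), 1.0, atol=TOLERANCE):
        raise ValueError("Sum of amplitudes-squared does not equal one.")
    return state


def check_with_checkify(state):
    # checkify.check records a failure instead of branching in Python
    norm = jnp.linalg.norm(state, ord=2)
    checkify.check(
        jnp.allclose(norm, 1.0, atol=TOLERANCE),
        "Sum of amplitudes-squared does not equal one.",
    )
    return state


def python_if_fails_under_jit(state):
    """Return True if the Python-level check raises ConcretizationTypeError under jit."""
    try:
        jax.jit(check_with_python_if)(state)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


# checkify wraps the function before jit, so the error is returned as an extra output
checked = jax.jit(checkify.checkify(check_with_checkify))

good_state = jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])
bad_state = jnp.array([1.0 + 0.0j, 1.0 + 0.0j, 0, 0])

print(python_if_fails_under_jit(good_state))  # True

err, _ = checked(good_state)
print(err.get())  # None: the state is normalized

err, _ = checked(bad_state)
print(err.get())  # the check's message, since the norm is sqrt(2)
# err.throw() would raise the recorded failure as an exception instead
```

The trade-off is that the caller has to inspect or throw the returned error outside the compiled function, rather than getting a ValueError during tracing.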

@CatalinaAlbornoz
Contributor

Yes @bonfab, I was also getting an error. I think Josh's approach is a good way to go. If you have any other ideas on how to fix this bug let us know here!

@antalszava
Contributor

Hi @bonfab, with #1683 merged, this should be resolved in the master branch.


5 participants