[BUG] Jax compiled default.qubit.jax device raises ConversionError for qml.QubitStateVector #1670
Closed
Labels: bug 🐛
Expected behavior
The `_apply_state_vector` method of `default.qubit.jax` does not seem to have been adapted for JAX-compiled (`jax.jit`) code, so setting a state vector with `qml.QubitStateVector` fails.
Actual behavior
Specifically, the line where the norm of the state vector is calculated with `np.linalg.norm` raises a `jax._src.errors.TracerArrayConversionError`.
A solution could be to use the `jax.numpy` version instead: `jnp.linalg.norm`.
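The failure and the proposed fix can be reproduced in plain JAX, without PennyLane. The sketch below is an illustration rather than the device's actual code. It assumes only that the state vector reaches the norm computation as a traced JAX array inside `jax.jit`. The helper names `norm_with_numpy` and `norm_with_jax` are hypothetical and stand in for the norm computation in `_apply_state_vector`.

```python
import numpy as np
import jax
import jax.numpy as jnp


def norm_with_numpy(state):
    # Hypothetical helper: np.linalg.norm converts its input with np.asarray,
    # which is not allowed for an abstract tracer inside jax.jit.
    return np.linalg.norm(state, ord=2)


def norm_with_jax(state):
    # Hypothetical helper: jnp.linalg.norm accepts tracers, so it can be traced and compiled.
    return jnp.linalg.norm(state, ord=2)


# A normalized two-qubit state |00> as a complex vector.
state = jnp.array([1.0, 0.0, 0.0, 0.0], dtype=jnp.complex64)

# Under jit, the NumPy version fails while tracing.
try:
    jax.jit(norm_with_numpy)(state)
    numpy_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_failed = True

# The jax.numpy version compiles and returns the expected norm.
jax_norm = jax.jit(norm_with_jax)(state)

print(numpy_failed)     # True
print(float(jax_norm))  # 1.0
```

If the computed norm is then used in a Python `if` statement, for example to check that the amplitudes are normalized, that branch may still fail under `jax.jit`. Converting a tracer to a Python `bool` raises a separate JAX error, so swapping in `jnp.linalg.norm` alone might not be the complete fix.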
Additional information
No response
Source code
Tracebacks
No response
System information