Expected behavior
Using decompositions and transforms should not change the derivative of the overall workflow.
Actual behavior
Some decompositions/transforms only reproduce the function, but not its derivative. I found this in the following parts of the codebase:
- merge_rotations: some rotation gates are skipped for zero angles.
- single_qubit_fusion: some rotation gates are skipped for zero angles.
- MottonenStatePreparation: depending on the input state, gates are skipped, which either leads to errors when JITting (no gradient entries to stack) or produces nan values.
- fuse_rot_angles: used in merge_rotations and single_qubit_fusion; it introduces a second, separate bug in both functions.
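The gate-skipping problem can be reproduced without PennyLane. The sketch below is a hypothetical stand-in, not the actual transform code: it merges RX(x) RX(x) into RX(2x), uses the closed form <Y> = -sin(theta) for RX(theta) applied to |0>, and drops the merged gate when |2x| <= atol. Because the dropped branch returns a constant, jax.grad reports a zero derivative there, although the true derivative is -2 cos(2x).

```python
import jax
import jax.numpy as jnp


def expval_y_after_rx(theta):
    # <Y> for RX(theta)|0> is -sin(theta)
    return -jnp.sin(theta)


def merged_with_skip(x, atol):
    # Hypothetical merge: RX(x) RX(x) -> RX(2x). The merged gate is dropped
    # when its angle is within atol of zero, so the circuit becomes the identity.
    angle = 2 * x
    if abs(angle) <= atol:  # Python branch on the concrete value under jax.grad
        return jnp.asarray(0.0)  # <Y> of |0>, no longer depends on x
    return expval_y_after_rx(angle)


# Exactly zero: the gate is dropped and the derivative is lost
print(jax.grad(merged_with_skip)(0.0, 1e-8))   # 0.0 instead of -2.0
# A larger tolerance makes the derivative wrong on a whole interval
print(jax.grad(merged_with_skip)(1e-7, 1e-6))  # 0.0 instead of about -2.0
print(jax.grad(merged_with_skip)(1e-5, 1e-6))  # about -2.0
```

The second call mirrors the single_qubit_fusion example below, where atol=1e-6 turns the derivative at x=1e-7 into 0.0.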
Additional information
Note that JITting usually avoids the source of error (except for MottonenStatePreparation); in all examples above, the code base has special logic for JITting.
As a consequence, JITted derivatives tend to be unaffected by the gate-skipping type of bug observed in the transforms.
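Why JIT hides the skipping bug can be sketched as follows, under the assumption that the special logic amounts to "only skip if the angle's value can be inspected". Inside jax.jit the angle is an abstract tracer, converting it to a Python bool raises jax.errors.ConcretizationTypeError, and the gate is kept. This mimics the behavior described above; it is not PennyLane's implementation, and merged_with_jit_guard is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def merged_with_jit_guard(x, atol=1e-8):
    # Hypothetical merge RX(x) RX(x) -> RX(2x), with <Y> = -sin(angle)
    angle = 2 * x
    try:
        # Only possible when the value is concrete (eager call or plain jax.grad)
        skip = bool(abs(angle) <= atol)
    except jax.errors.ConcretizationTypeError:
        # Under jax.jit the angle is abstract, so the gate is always kept
        skip = False
    if skip:
        return jnp.asarray(0.0)
    return -jnp.sin(angle)


print(jax.grad(merged_with_jit_guard)(0.0))           # 0.0, gate dropped
print(jax.grad(jax.jit(merged_with_jit_guard))(0.0))  # -2.0, gate kept
```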
Under the hood, this seems similar to #5541, which concerns AmplitudeEmbedding and is being solved in #5620 by modifying the diff method of GlobalPhase. However, the bug described here has a different origin; I encountered it while finalizing the tests of #5620 for MottonenStatePreparation.
Source code
```python
import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
from functools import partial

# Assumed: 64-bit floats, as the precision of the printed values suggests
jax.config.update("jax_enable_x64", True)

#### BUG caused by merge_rotations itself
@qml.transforms.merge_rotations
def _node(x):
    qml.RX(x, 1)
    qml.RX(x, 1)
    return qml.expval(qml.Y(1))

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

print("Derivatives at 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(0.))
# Derivatives at 0:
# 0.0                  (node_ps)
# -2.0                 (node_ps_jit)
# 0.0                  (node_ad)
# -1.9999999999999996  (node_ad_jit)

print("Derivatives close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-8))
# Derivatives close to 0:
# -1.9999999999999993  (node_ps)
# -1.9999999999999993  (node_ps_jit)
# -1.9999999999999993  (node_ad)
# -1.9999999999999993  (node_ad_jit)
```
```python
#### BUG caused by fuse_rot_angles via merge_rotations
@qml.transforms.merge_rotations
def _node(x):
    qml.Rot(x, x, x, 1)
    qml.Rot(x, x, x, 1)
    return qml.expval(qml.X(1))

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

print("Derivatives at 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(0.))
# Derivatives at 0:
# 0.0                 (node_ps)
# 2.0                 (node_ps_jit)
# 0.0                 (node_ad)
# 1.9999999999999996  (node_ad_jit)

print("Derivatives close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-6))
# Derivatives close to 0:
# 2.0000221220668224  (node_ps)
# 1.9999999999840001  (node_ps_jit)
# 2.0000221220668215  (node_ad)
# 1.9999999999839995  (node_ad_jit)
```
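To tell which of the printed values is right, the unfused circuit can be differentiated directly. The following is our own minimal single-qubit statevector sketch in pure JAX (no PennyLane, no fusion), assuming PennyLane's convention Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi) with the usual RZ and RY matrices. To first order, two Rot(x, x, x) gates give <X> of about 2x, so the derivative near 0 should be about 2.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rz(phi):
    return jnp.diag(jnp.array([jnp.exp(-0.5j * phi), jnp.exp(0.5j * phi)]))


def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])


def rot(phi, theta, omega):
    # Assumed convention: Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi)
    return rz(omega) @ ry(theta) @ rz(phi)


PAULI_X = jnp.array([[0.0, 1.0], [1.0, 0.0]])


def expval_x(x):
    state = jnp.array([1.0, 0.0], dtype=jnp.complex128)  # |0>
    u = rot(x, x, x)
    state = u @ (u @ state)  # two separate, unfused Rot gates
    return jnp.real(jnp.conj(state) @ PAULI_X @ state)


print(jax.grad(expval_x)(0.0))   # about 2.0
print(jax.grad(expval_x)(1e-6))  # about 2.0
```

Both reference gradients come out at about 2.0, in line with the JITted values above; this suggests, but does not prove, that the non-JITted 2.00002... at 1e-6 stems from the fused-angle computation.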
```python
#### BUGS in single_qubit_fusion, one in the function itself, one from fuse_rot_angles
@partial(qml.transforms.single_qubit_fusion, atol=1e-6)
def _node(x):
    qml.RX(x, 1)
    qml.RX(x, 1)
    return qml.expval(qml.Y(1))

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

print("Derivatives at 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(0.))
# Derivatives at 0:
# 0.0  (node_ps)
# 0.0  (node_ps_jit)
# 0.0  (node_ad)
# 0.0  (node_ad_jit)

print("Derivatives close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-7))
# Derivatives close to 0:
# 0.0                 (node_ps)
# -2.000799757290469  (node_ps_jit)
# 0.0                 (node_ad)
# -2.000799757290469  (node_ad_jit)

print("Derivatives less close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-5))
# Derivatives less close to 0:
# -1.9999999168263007  (node_ps)
# -1.9999999168263007  (node_ps_jit)
# -1.9999999168263003  (node_ad)
# -1.9999999168263003  (node_ad_jit)
```
```python
#### BUGS with MottonenStatePreparation
def _node(x):
    qml.MottonenStatePreparation(x, wires=[0, 1])
    return qml.probs()

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

x1 = jnp.array([1, 1, 0, 1]) / np.sqrt(3)
for node_ in [node_ps, node_ad, node_ps_jit, node_ad_jit]:  # Fails with the JITted nodes
    print(jax.jacobian(node_)(x1))
# node_ps:
# [[ 7.69800359e-01 -3.84900179e-01 nan nan]
#  [-3.84900179e-01  7.69800359e-01 nan nan]
#  [-4.80740672e-17  4.80740672e-17 nan nan]
#  [-3.84900179e-01 -3.84900179e-01 nan nan]]
# node_ad:
# [[ 7.69800359e-01 -3.84900179e-01 nan nan]
#  [-3.84900179e-01  7.69800359e-01 nan nan]
#  [-4.80740672e-17  4.80740672e-17 nan nan]
#  [-3.84900179e-01 -3.84900179e-01 nan nan]]

x2 = jnp.array([1, 0, 0, 1]) / np.sqrt(2)
for node_ in [node_ps, node_ad, node_ps_jit, node_ad_jit]:  # Fails with the JITted nodes
    print(jax.jacobian(node_)(x2))
# node_ps:
# [[nan nan nan nan]
#  [nan nan nan nan]
#  [nan nan nan nan]
#  [nan nan nan nan]]
# node_ad:
# [[nan nan nan nan]
#  [nan nan nan nan]
#  [nan nan nan nan]
#  [nan nan nan nan]]
```
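One plausible mechanism for the nan entries can be sketched in pure JAX. This assumes, for illustration only, that the decomposition computes rotation angles from ratios of partial norms of the amplitudes; pair_angle is a hypothetical helper, not the MottonenStatePreparation code. When one amplitude of a pair is zero, the arcsin argument sits exactly at 1, where arcsin' diverges, and the chain rule ends up multiplying infinities by zeros.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def pair_angle(a):
    # Hypothetical Mottonen-style angle for one amplitude pair (a0, a1):
    # theta = 2 * arcsin(|a1| / ||(a0, a1)||)
    norm = jnp.sqrt(a[0] ** 2 + a[1] ** 2)
    return 2 * jnp.arcsin(jnp.abs(a[1]) / norm)


# Generic amplitudes: finite gradient
print(jax.grad(pair_angle)(jnp.array([0.6, 0.8])))
# a0 == 0 puts the arcsin argument at 1: the entries are not finite (nan)
print(jax.grad(pair_angle)(jnp.array([0.0, 1.0])))
```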
Tracebacks
No response
System information
pl dev
Existing GitHub issues
I have searched existing GitHub issues to make sure the issue does not already exist.
While trying to fix this, I noticed that fuse_rot_angles uses a function that, as it stands, is not differentiable everywhere. At those singular points, we return wrong derivatives in yet another way :/
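A toy illustration of this kind of singularity, assuming an arccos-based fusion formula (fused_angle is hypothetical and much simpler than fuse_rot_angles): fusing RY(a) and RY(b) this way yields |a + b| for |a + b| <= 2*pi, which has a kink at a + b = 0. At that point arccos' is infinite, the inner derivative is zero, and JAX returns a non-finite gradient instead of any meaningful value.

```python
import jax
import jax.numpy as jnp


def fused_angle(a, b):
    # Hypothetical arccos-based fusion of RY(a) and RY(b):
    # for |a + b| <= 2*pi this equals |a + b|, which has a kink at a + b == 0
    inner = jnp.cos(a / 2) * jnp.cos(b / 2) - jnp.sin(a / 2) * jnp.sin(b / 2)
    return 2 * jnp.arccos(inner)


print(jax.grad(fused_angle)(0.3, 0.2))  # about 1.0, matching d|a + b|/da
print(jax.grad(fused_angle)(0.0, 0.0))  # not finite: arccos'(1) diverges
```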
**Context:**
The decomposition of `MottonenStatePreparation` skips some gates for
special parameter values/input states.
See the linked issue for details.
**Description of the Change:**
This PR introduces a check for differentiability so that gates are only
skipped when no derivatives are being computed.
Note that this does *not* fix the non-differentiability at other special
parameter points, which is also referenced in #5715 and which the docs
already warn against.
Also, the linked issue covers multiple operations, whereas this PR only
addresses `MottonenStatePreparation`.
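A minimal sketch of the idea "skip gates only when no derivatives are computed", applied to a toy RX merge rather than to `MottonenStatePreparation`. The differentiability check is an assumption: it treats any JAX tracer (`jax.core.Tracer`, present under both `jax.grad` and `jax.jit`) as "derivatives may be requested"; the check in this PR may well differ, and `merged_skip_only_if_constant` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def merged_skip_only_if_constant(x, atol=1e-8):
    # Hypothetical merge RX(x) RX(x) -> RX(2x), with <Y> = -sin(angle)
    angle = 2 * x
    # Assumption: a traced angle might be differentiated, so the gate is kept;
    # only plain (untraced) values are simplified away.
    if not isinstance(angle, jax.core.Tracer) and abs(angle) <= atol:
        return jnp.asarray(0.0)
    return -jnp.sin(angle)


print(merged_skip_only_if_constant(0.0))                     # 0.0, gate dropped
print(jax.grad(merged_skip_only_if_constant)(0.0))           # -2.0, gate kept
print(jax.grad(jax.jit(merged_skip_only_if_constant))(0.0))  # -2.0, gate kept
```

The forward value is unchanged either way (both branches give 0.0 at x = 0); only the derivative depends on whether the gate survives.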
**Benefits:**
Fixes parts of #5715. Unblocks #5620.
**Possible Drawbacks:**
**Related GitHub Issues:**
#5715