
[BUG] Performance regression of decomposition-based differentiation of SpecialUnitary #4635

Open
dwierichs opened this issue Sep 27, 2023 · 3 comments
Labels
bug 🐛 Something isn't working enhancement ✨ New feature or request

Comments

@dwierichs
Contributor

Feature details

In #4585, the decomposition of TmpPauliRot, a helper object of SpecialUnitary, was changed in order to allow the new DefaultQubit device to differentiate SpecialUnitary.
The decomposition-based differentiation pipeline assumed that the device's decomposition step before execution would happen after the trainable parameters are determined and the corresponding gradient transforms are called.
Unfortunately, this ordering changes with the new device, so that the zero-angle instances of TmpPauliRot are now decomposed into zero-angle instances of PauliRot, which in turn make it into the execution pipeline.
This makes the simulator execute a lot of identity operations, which causes some overhead (benchmark to be provided below).

This affects differentiation of SpecialUnitary only, and only when using a differentiation method that uses the decomposition, like param_shift (as happens when using shot-based simulation).
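For concreteness, a minimal sketch of the affected setting (hypothetical parameter values, autograd interface for brevity): parameter-shift differentiation of SpecialUnitary on the new default.qubit.

import pennylane as qml
from pennylane import numpy as np

# Sketch only: shot-based simulation together with diff_method="parameter-shift"
# relies on the decomposition-based differentiation of SpecialUnitary.
dev = qml.device("default.qubit", shots=1000)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(theta):
    qml.SpecialUnitary(theta, wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

theta = np.random.random(4**2 - 1, requires_grad=True)  # 15 parameters on 2 qubits
qml.grad(circuit)(theta)  # the decomposition-based pipeline runs here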

Implementation

I do not see a straightforward solution at this point. Possible implementations rely on the composability of gradient transforms envisioned for the mid-term future, or on a generalization of the generator property of operations.

Other ideas to implement the (non-backprop) differentiation of SpecialUnitary are very much welcome.

How important would you say this feature is?

2: Somewhat important. Needed this quarter.

Additional information

No response

@dwierichs dwierichs added the enhancement ✨ New feature or request label Sep 27, 2023
@github-actions github-actions bot added the bug 🐛 Something isn't working label Sep 27, 2023
@josh146
Member

josh146 commented Aug 27, 2024

Hey @dwierichs, is this still an open issue?

@josh146
Member

josh146 commented Aug 27, 2024

If not, @albi3ro, will this be fixed by swapping where gradient transforms are done in relation to transforms/decompositions (your proposal for Q4)?

@dwierichs
Contributor Author

I looked at example code on 4 qubits (for which SpecialUnitary has 255 parameters); the code is below.

  • For JAX, could it be that the identity operations we're concerned about are being compiled away within jax.grad? I'm not entirely sure, but the compile time/first run gets sped up (12.4s -> ~12s) by manually skipping these ops (in apply_operation; see the sketch after this list), whereas the execution time did not really change (it is at about 4.3s).
  • For torch, the execution is about 3.2% faster if we manually skip identity ops (64.34s -> 62.25s).
  • For TF, the execution time is about the same if we manually skip identity ops (at ~5.1s).

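By "manually skipping identity ops" I mean patching out zero-angle gates before they hit the simulator. As a rough, hypothetical illustration (written here as a tape transform rather than the apply_operation patch used for the timings above):

import pennylane as qml

@qml.transform
def drop_zero_angle_pauli_rots(tape):
    # Sketch: remove PauliRot gates whose angle is numerically zero,
    # since they act as identities on the state.
    ops = [
        op for op in tape.operations
        if not (isinstance(op, qml.PauliRot) and qml.math.allclose(op.data[0], 0.0))
    ]
    new_tape = qml.tape.QuantumScript(ops, tape.measurements, shots=tape.shots)

    def postprocessing(results):
        return results[0]

    return [new_tape], postprocessing

Note that such a transform would only help if it runs after the device decomposition that produces the zero-angle PauliRot instances, which is exactly the ordering question raised in the issue description.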
The code (run in Jupyter to use the %prun magic):

import pennylane as qml

interface = "jax"  # set to "jax", "torch", or "tf"

if interface == "jax":
    import jax
    jax.config.update("jax_enable_x64", True)
if interface == "torch":
    import torch
if interface == "tf":
    import tensorflow as tf

N = 4
wires = list(range(N))

@qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift")
def node(x):
    qml.SpecialUnitary(x, wires)
    return qml.expval(qml.Z(0))

if interface=="jax":
    key = jax.random.PRNGKey(824)
    x = jax.random.uniform(key, (4**N - 1,))
    %prun jax.grad(node)(x)
if interface=="torch":
    x = torch.rand(4**N - 1, requires_grad=True)
    out = node(x)
    %prun out.backward(retain_graph=True)
if interface=="tf":
    x = tf.Variable(tf.random.uniform(4**N - 1))
    with tf.GradientTape() as tape:
        out = node(x)
    %prun tape.gradient(out, x)
