In [64]:
from copy import deepcopy
import numpy as np
from functools import partial, reduce

from typing import Callable

from qiskit.aqua.operators.gradients.qfi import QFI
from qiskit import BasicAer
from qiskit.aqua.operators import X, Z, StateFn, CircuitStateFn, CircuitSampler, Zero, Plus, One
from qiskit import QuantumCircuit, QuantumRegister
from qiskit.circuit import Parameter
from qiskit.aqua import QuantumInstance, aqua_globals
from qiskit import QuantumCircuit
from qiskit.circuit import ParameterExpression, Parameter, ParameterVector, Instruction
from qiskit.providers import BaseBackend
from qiskit.aqua import QuantumInstance, AquaError
from qiskit.aqua.operators.gradients.gradient.operator_gradient import ObservableGradient
# from qiskit.aqua.operators.gradients.gradient.prob_gradient import ProbabilityGradient
from qiskit.aqua.operators.gradients.gradient.state_gradient import StateGradient
from qiskit.aqua.operators import OperatorBase, ListOp, SummedOp, ComposedOp, TensoredOp

In [89]:
def get_grad_combo_fn(operator: ListOp) -> Callable:
    """Return a function computing the derivative of ``operator.combo_fn``.

    Args:
        operator: The operator whose ``combo_fn`` should be differentiated.

    Returns:
        A callable that takes the list of operand values ``x``:

        * ``TensoredOp``: ``operator.combo_fn`` itself is returned unchanged.
        * ``ComposedOp``: the combo_fn is treated as the ``np.dot`` product
          ``x[0] · x[1] · … · x[n-1]``. The returned function yields, for each
          ``i``, the product of every operand except ``x[i]``, or ``1`` when
          there is no other operand.
        * Any other operator (e.g. ``SummedOp`` or a ``ListOp`` with a custom
          ``combo_fn``): ``jax.jit(jax.grad(operator.combo_fn))``.

    Raises:
        TypeError: If jax cannot build a gradient for ``operator.combo_fn``.
    """
    if isinstance(operator, TensoredOp):
        # NOTE(review): this returns combo_fn itself rather than a derivative;
        # confirm that this is the intended gradient for tensor products.
        return operator.combo_fn

    if isinstance(operator, ComposedOp):
        def grad_composed_combo_fn(x):
            # Partial derivative of x_0·…·x_{n-1} w.r.t. x_i: dot together the
            # remaining operands, in their original order.
            # NOTE(review): for non-commuting (matrix) operands this is only the
            # ordered product of the others, not a full matrix derivative.
            gradients = []
            for i in range(len(x)):
                others = [value for j, value in enumerate(x) if j != i]
                # reduce() without an initializer raises on an empty sequence,
                # so a single-operand composition gets the multiplicative identity.
                gradients.append(reduce(np.dot, others) if others else 1)
            return gradients
        return grad_composed_combo_fn

    # jax is imported lazily so it is only required for this branch.
    from jax import grad, jit
    try:
        # Presumably also covers SummedOp, whose combo_fn is a sum — verify.
        return jit(grad(operator.combo_fn))
    except Exception as exc:
        # No symbolic fallback exists, so surface the failure with context.
        raise TypeError(
            f"Cannot build a jax gradient for the combo_fn of {type(operator).__name__}"
        ) from exc

In [90]:
# Operands: the adjoints (~) of the StateFns wrapping the Pauli operators X and Z.
a = [~StateFn(pauli) for pauli in (X, Z)]

# Operator whose combo_fn gradient is computed with get_grad_combo_fn.
# Swap in one of these to exercise the other branches:
#   ListOp(a, combo_fn=lambda x: (-0.5) * (x[0] * np.conjugate(x[1]) + x[1] * np.conjugate(x[0])))
#   TensoredOp(a)
#   SummedOp(a)
op = ComposedOp(a)

In [91]:
# Gradient of op's combo_fn, evaluated at the operand values [1, 2].
g = get_grad_combo_fn(op)
eval_point = [1, 2]
print(g(eval_point))

[2, 1]
