# Verifying the Parameter Shift Rule using automatic differentiation in JAX

## Cost Function

- Computes the expectation value of the observable $M$ with respect to the output state $U(\boldsymbol{\theta})|0\rangle$ for given parameters $\boldsymbol{\theta}$.
- Strictly follows the description in the Overleaf document.

---

In this paper, we consider the following unconstrained optimization problem,
$$
\min _{\boldsymbol{\theta} \in \mathbb{R}^m} f(\boldsymbol{\theta})=\left\langle 0\left|U(\boldsymbol{\theta})^{\dagger} M U(\boldsymbol{\theta})\right| 0\right\rangle .
$$
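To make the objective concrete, here is a minimal plain-NumPy sketch for the smallest case, $n = m = 1$ with $V_1 = I$. The generator $H_1 = X$ and observable $M = Z$ (Pauli matrices) are illustrative assumptions, not part of the formulation above. The helper `expm_hermitian` (a name introduced here) computes $e^{-(i/2)H\theta}$ from the eigendecomposition of $H$, so it does not rely on any closed form. For these choices $f(\theta) = \cos\theta$, which the loop prints alongside for comparison.

```python
import numpy as np

# Single-qubit Pauli matrices and the |0> state (illustrative choices, n = 1)
X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)
ket0 = np.array([1, 0], dtype=complex)


def expm_hermitian(H, theta):
    """Return exp(-(i/2) * H * theta) for a Hermitian H via its eigendecomposition."""
    evals, evecs = np.linalg.eigh(H)
    return evecs @ np.diag(np.exp(-0.5j * evals * theta)) @ evecs.conj().T


def f(theta, H=X, M=Z):
    """f(theta) = <0| U(theta)^dagger M U(theta) |0> with U(theta) = exp(-(i/2) H theta)."""
    psi = expm_hermitian(H, theta) @ ket0   # output state U(theta)|0>
    return np.real(np.vdot(psi, M @ psi))   # np.vdot conjugates its first argument


for theta in (0.0, 0.3, np.pi / 2, np.pi):
    print(f"theta = {theta:.4f}   f(theta) = {f(theta):+.6f}   cos(theta) = {np.cos(theta):+.6f}")
```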

Here, $U(\boldsymbol{\theta}) \in \mathbb{C}^{2^n \times 2^n}$ is a parameterized quantum circuit that depends on a collection of classical parameters $\boldsymbol{\theta}=\left(\theta_1, \theta_2, \ldots, \theta_m\right) \in \mathbb{R}^m$. Without loss of generality, it suffices to consider the following typical structure:
$$
U(\boldsymbol{\theta})=V_m U_m\left(\theta_m\right) \cdots V_1 U_1\left(\theta_1\right),
$$
where $V_j \in \mathbb{C}^{2^n \times 2^n}$ are arbitrary constant quantum circuits, while $U_j\left(\theta_j\right) \in \mathbb{C}^{2^n \times 2^n}$ are rotation-like gates, i.e., $U_j\left(\theta_j\right)=e^{-(i / 2) H_j \theta_j}$ for some Hermitian generator $H_j \in \mathbb{C}^{2^n \times 2^n}$. It is well known that $U_j\left(\theta_j\right)$ is unitary for any $\theta_j \in \mathbb{R}$ if the generator $H_j$ is Hermitian${ }^1$. In addition, we assume that $H_j$ is involutory${ }^2$, i.e., $H_j^2=I$, where $I$ denotes the identity matrix whose size is clear from context. In this case, we have
$$
U_j\left(\theta_j\right)=e^{-(i / 2) H_j \theta_j}=\cos \left(\theta_j / 2\right) I-i \sin \left(\theta_j / 2\right) H_j
$$
For a proof, see Lemma 2 in the appendix. This equation implies that $U_j\left(\theta_j\right)$ is a linear combination of the constant matrices $I$ and $H_j$, and that $U_j\left(\theta_j\right)$ depends on $\theta_j$ only through the coefficients. Typically, the circuit $U(\boldsymbol{\theta})$ is applied to a fixed, easy-to-prepare input state $|0\rangle \in \mathbb{C}^{2^n}$, producing the output state $U(\boldsymbol{\theta})|0\rangle \in \mathbb{C}^{2^n}$ on a quantum device. Hence, $f(\boldsymbol{\theta})$ is exactly the expectation value of a Hermitian observable $M \in \mathbb{C}^{2^n \times 2^n}$ measured with respect to that output state.
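The closed form and the unitarity claim can be checked numerically. This NumPy sketch assumes a hypothetical two-qubit Pauli string $H = X \otimes Z$ as the generator (the notebook's `generate_random_H_paulis` may produce different strings). It compares $\cos(\theta/2) I - i\sin(\theta/2) H$ with the exponential computed by eigendecomposition, confirms that the result is unitary, and shows that the same formula fails for the Hermitian but non-involutory generator $2X$, which is why the assumption $H_j^2 = I$ matters.

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)


def u_closed_form(H, theta):
    """cos(theta/2) I - i sin(theta/2) H; equals exp(-(i/2) H theta) only if H^2 = I."""
    return np.cos(theta / 2) * np.eye(H.shape[0]) - 1j * np.sin(theta / 2) * H


def u_exact(H, theta):
    """exp(-(i/2) H theta) for Hermitian H, via the eigendecomposition H = Q diag(w) Q^dagger."""
    w, Q = np.linalg.eigh(H)
    return Q @ np.diag(np.exp(-0.5j * w * theta)) @ Q.conj().T


theta = 0.7

# Hypothetical two-qubit generator: a Pauli string is Hermitian and involutory
H = np.kron(X, Z)
print("H^2 = I:", np.allclose(H @ H, np.eye(4)))
U = u_closed_form(H, theta)
print("closed form == exponential:", np.allclose(U, u_exact(H, theta)))
print("closed form is unitary:", np.allclose(U.conj().T @ U, np.eye(4)))

# A Hermitian but non-involutory generator breaks the closed form
H_bad = 2 * X
print("2X: closed form == exponential:",
      np.allclose(u_closed_form(H_bad, theta), u_exact(H_bad, theta)))
```

The four lines should print `True`, `True`, `True`, `False`.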

---

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg
from lai_utils import is_unitary, is_hermitian
from lai_utils import generate_random_unitary, generate_random_hermitian, generate_random_H_paulis
from lai_utils import create_uniform_superposed_state, create_ket_zero_state
from lai_utils import U_j

key = jax.random.PRNGKey(2)  # Create the initial key for the random number generator

print('='*100)
print("In this example, we will use JAX's Automatic Differentiation to verify the Parameter Shift Rule is correct.")
print('='*100)

n = 2  # Number of qubits
m = 4  # Number of parameters

# |0> state as input
input_state = jnp.eye(2**n, dtype=jnp.complex64)[:, 0]

# Generate m different random unitary matrices V_1, ..., V_m
subkeys = jax.random.split(key, m+1)
key = subkeys[m]
V_list = [generate_random_unitary(2**n, prng_key=subkeys[i]) for i in range(m)]

# Check if each V matrix is indeed unitary
V_is_unitary_list = [is_unitary(V) for V in V_list]
print(f"V_is_unitary_list: {[value.item() for value in V_is_unitary_list]}")

# Generate m different generator matrices H_1, ..., H_m
subkeys = jax.random.split(key, m+1)
key = subkeys[m]
H_list = [generate_random_H_paulis(n, prng_key=subkeys[i]) for i in range(m)]

# Check if each H matrix is Hermitian
H_is_hermitian_list = [is_hermitian(H) for H in H_list]
print(f"H_is_hermitian_list: {[value.item() for value in H_is_hermitian_list]}")

# Check if each H matrix is unitary; Hermitian + unitary implies involutory (H^2 = H H^dagger = I)
H_is_unitary_list = [is_unitary(H) for H in H_list]
print(f"H_is_unitary_list: {[value.item() for value in H_is_unitary_list]}")

# Define the observable M (a random Hermitian matrix)
key, subkey = jax.random.split(key)
M = generate_random_hermitian(2**n, prng_key=subkey)

# Naive cost function: build the full circuit matrix U(theta) = V_m U_m(theta_m) ... V_1 U_1(theta_1),
# then evaluate <0|U(theta)^dagger M U(theta)|0>
def exact_cost_naive(theta):
    U_total = jnp.eye(2**n, dtype=jnp.complex64)

    for i in range(m):
        U_theta_i = U_j(theta[i], H_list[i], method='exponential')
        U_total = V_list[i] @ U_theta_i @ U_total

    U_total_dagger = U_total.conjugate().T
    expectation_value = jnp.vdot(input_state, U_total_dagger @ M @ U_total @ input_state)

    return jnp.real(expectation_value)

# Practical cost function: apply each gate directly to the state vector,
# so only matrix-vector products are needed and U(theta) is never formed
def exact_cost(theta):
    state = input_state

    for i in range(m):
        U_theta_i = U_j(theta[i], H_list[i], method='exponential')
        state = U_theta_i @ state  # Apply U_theta_i to the state vector
        state = V_list[i] @ state  # Apply V_list[i] to the state vector

    expectation_value = jnp.vdot(state, M @ state)

    return jnp.real(expectation_value)

# Randomly initialize parameters
key, subkey = jax.random.split(key)
theta = jax.random.normal(subkey, (m,))

# Compare the outputs of both functions
output_naive = exact_cost_naive(theta)
output_practical = exact_cost(theta)

print(f"Output of exact_cost_naive: {output_naive}")
print(f"Output of exact_cost: {output_practical}")

# Verify that the results agree up to floating-point tolerance
assert jnp.allclose(output_naive, output_practical), "The outputs are not equal!"
print("The outputs of both functions are equal!")
```

```
In this example, we will use JAX's Automatic Differentiation to verify the Parameter Shift Rule is correct.
V_is_unitary_list: [True, True, True, True]
H_is_hermitian_list: [True, True, True, True]
H_is_unitary_list: [True, True, True, True]
Output of exact_cost_naive: 1.2712680101394653
Output of exact_cost: 1.2712682485580444
The outputs of both functions are equal!
```


This completes the definition of the cost function.
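The two printed cost values differ by about $2 \times 10^{-7}$, which is the scale of single-precision rounding: the notebook uses `complex64` matrices and JAX's default 32-bit floats. The state-vector version (`exact_cost`) also never forms the full $2^n \times 2^n$ product, so each gate costs a matrix-vector product ($O(4^n)$ work for dense matrices) instead of a matrix-matrix product ($O(8^n)$). The NumPy sketch below repeats both evaluation orders in `complex64` and `complex128` on its own hypothetical random circuit (QR-based random unitaries, fixed Pauli-string generators, a random Hermitian $M$; none of it comes from `lai_utils`). Under that setup we would expect the `complex128` gap to sit near machine epsilon and the `complex64` gap to be several orders of magnitude larger.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 2, 4
d = 2 ** n

X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)
I2 = np.eye(2, dtype=complex)


def random_unitary(d, rng):
    """Random unitary from the QR decomposition of a complex Gaussian matrix."""
    z = (rng.normal(size=(d, d)) + 1j * rng.normal(size=(d, d))) / np.sqrt(2)
    q, r = np.linalg.qr(z)
    return q * (np.diag(r) / np.abs(np.diag(r)))  # rescale column phases


def rotation(H, theta):
    """cos(theta/2) I - i sin(theta/2) H for an involutory generator H."""
    return np.cos(theta / 2) * np.eye(H.shape[0]) - 1j * np.sin(theta / 2) * H


V_list = [random_unitary(d, rng) for _ in range(m)]
H_list = [np.kron(X, I2), np.kron(Z, Y), np.kron(Y, Y), np.kron(I2, X)]  # Pauli strings
A = rng.normal(size=(d, d)) + 1j * rng.normal(size=(d, d))
M = (A + A.conj().T) / 2  # random Hermitian observable
theta = rng.normal(size=m)


def cost_naive(theta, dtype):
    """Build U(theta) as a full matrix, then evaluate <0|U^dagger M U|0>."""
    U = np.eye(d, dtype=dtype)
    for V, H, t in zip(V_list, H_list, theta):
        U = V.astype(dtype) @ rotation(H, t).astype(dtype) @ U
    psi0 = np.zeros(d, dtype=dtype)
    psi0[0] = 1
    return np.real(np.vdot(psi0, U.conj().T @ M.astype(dtype) @ U @ psi0))


def cost_statevector(theta, dtype):
    """Apply U_j(theta_j) and then V_j to the state vector, one gate at a time."""
    psi = np.zeros(d, dtype=dtype)
    psi[0] = 1
    for V, H, t in zip(V_list, H_list, theta):
        psi = V.astype(dtype) @ (rotation(H, t).astype(dtype) @ psi)
    return np.real(np.vdot(psi, M.astype(dtype) @ psi))


for dtype in (np.complex64, np.complex128):
    gap = abs(cost_naive(theta, dtype) - cost_statevector(theta, dtype))
    print(f"{np.dtype(dtype).name}: |naive - statevector| = {gap:.3e}")
```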

## Gradient Calculation

- Uses the Parameter Shift Rule (PSR) to compute the exact gradient of the cost function.
- Checks the PSR gradient, for the shifts $\pi/2$ and $\pi/4$, against the gradient computed by JAX's automatic differentiation.
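The shifted-difference formula is exact rather than a finite-difference approximation, and this follows from the closed form $U_j(\theta_j) = \cos(\theta_j/2) I - i\sin(\theta_j/2) H_j$. Fix every parameter except $\theta_j$. Because $U_j(\theta_j)$ enters $f$ once through $U(\boldsymbol{\theta})$ and once through $U(\boldsymbol{\theta})^\dagger$, and each is linear in $\cos(\theta_j/2)$ and $\sin(\theta_j/2)$, $f$ is a quadratic form in these two quantities. The double-angle identities then give, for real constants $a_j, b_j, c_j$ that depend only on the other parameters,

$$
f(\boldsymbol{\theta}) = a_j + b_j \cos\theta_j + c_j \sin\theta_j .
$$

With $\boldsymbol{e}_j$ the $j$-th unit vector and any shift $s$,

$$
f(\boldsymbol{\theta} + s\boldsymbol{e}_j) - f(\boldsymbol{\theta} - s\boldsymbol{e}_j)
= 2\sin s\,\bigl(c_j\cos\theta_j - b_j\sin\theta_j\bigr)
= 2\sin s\,\frac{\partial f}{\partial \theta_j}(\boldsymbol{\theta}),
$$

so that

$$
\frac{\partial f}{\partial \theta_j}(\boldsymbol{\theta}) = \frac{f(\boldsymbol{\theta} + s\boldsymbol{e}_j) - f(\boldsymbol{\theta} - s\boldsymbol{e}_j)}{2\sin s}, \qquad s \notin \pi\mathbb{Z}.
$$

This is the formula `psr_grad_function` evaluates. One reason to prefer $s = \pi/2$ is that it maximizes the denominator $|2\sin s|$, which minimizes how much errors in the evaluations of $f$ are amplified.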

```python
# Calculate the exact gradient using the Parameter Shift Rule:
#   df/dtheta_j = [f(theta + s*e_j) - f(theta - s*e_j)] / (2*sin(s)),
# which holds for any shift s that is not an integer multiple of pi
def psr_grad_function(theta, shift=jnp.pi/2):
    # The denominator 2*sin(shift) vanishes when shift is an integer multiple of pi
    if jnp.abs(jnp.sin(shift)) < 1e-6:
        raise ValueError("Shift must not be an integer multiple of pi. The best shift value is pi/2")

    m = len(theta)
    gradient = jnp.zeros(m)

    for j in range(m):
        theta_shifted = theta.at[j].set(theta[j] + shift)
        f_forward = exact_cost(theta_shifted)

        theta_shifted = theta.at[j].set(theta[j] - shift)
        f_backward = exact_cost(theta_shifted)

        gradient = gradient.at[j].set((f_forward - f_backward) / (2 * jnp.sin(shift)))

    return gradient

key, subkey = jax.random.split(key)

# Initialize the parameters; three options:
# 1. random normal values (used here)
theta = jax.random.normal(subkey, (m,))

# 2. the amplitudes of |+⟩^⊗n (shape-compatible only because 2**n == m here)
# theta = create_uniform_superposed_state(n)

# 3. the amplitudes of |0⟩^⊗n (shape-compatible only because 2**n == m here)
# theta = create_ket_zero_state(n)

print(f"Randomly chosen theta: {theta}")

# Calculate the cost function value
cost_value = exact_cost(theta)
print(f"Cost value: {cost_value}")

# Calculate the gradient of the cost function in two ways

# 1. Using JAX automatic differentiation
jax_grad_function = jax.grad(exact_cost)
jax_grad = jax_grad_function(theta)

# 2. Using the Parameter Shift Rule with two different shifts
SHIFT0 = jnp.pi / 2
psr_grad_0 = psr_grad_function(theta, SHIFT0)
SHIFT1 = jnp.pi / 4
psr_grad_1 = psr_grad_function(theta, SHIFT1)

# Print the gradients computed by JAX and the Parameter Shift Rule
print(f"JAX - gradient: {jax_grad}")
print(f"Parameter shift rule (shift value: {SHIFT0}) - gradient: {psr_grad_0}")
print(f"Parameter shift rule (shift value: {SHIFT1}) - gradient: {psr_grad_1}")

# Verify that the results are consistent
print('='*100)
# Compute the L2 distances between the gradients
distance1 = jnp.linalg.norm(psr_grad_0 - jax_grad)
print(f"L2 distance between JAX and PSR gradient (shift: {SHIFT0}): \n {distance1}")

distance2 = jnp.linalg.norm(psr_grad_1 - jax_grad)
print(f"L2 distance between JAX and PSR gradient (shift: {SHIFT1}): \n {distance2}")

distance3 = jnp.linalg.norm(psr_grad_0 - psr_grad_1)
print(f"L2 distance between the two PSR gradients: \n {distance3}")
print('='*100)
```

```
Randomly chosen theta: [-0.32004303 -0.63536096  0.5037522   0.39323545]
Cost value: 0.4903776943683624
JAX - gradient: [ 0.6967335   0.21761885  0.43304354 -0.8293713 ]
Parameter shift rule (shift value: 1.5707963267948966) - gradient: [ 0.69673365  0.21761894  0.43304333 -0.8293711 ]
Parameter shift rule (shift value: 0.7853981633974483) - gradient: [ 0.69673336  0.21761917  0.43304336 -0.8293713 ]
L2 distance between JAX and PSR gradient (shift: 1.5707963267948966): 
 3.3979875979639473e-07
L2 distance between JAX and PSR gradient (shift: 0.7853981633974483): 
 3.7961422094667796e-07
L2 distance between the two PSR gradients: 
 4.1429515817981155e-07
```
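All three distances are about $3$ to $4 \times 10^{-7}$, consistent with single-precision rounding rather than an error in the rule; the shift $\pi/4$ agrees with JAX about as closely as $\pi/2$ does, as expected for any $s \notin \pi\mathbb{Z}$. For tighter agreement, JAX can run in double precision via `jax.config.update("jax_enable_x64", True)` at startup, although the notebook's explicit `jnp.complex64` dtypes would also need changing (we have not rerun it this way). As an independent float64 cross-check, the NumPy sketch below applies the PSR with two shifts to a hypothetical random circuit (not the `lai_utils` one) and compares it with a central finite difference. The finite difference is only an approximation with $O(h^2)$ truncation error, so it serves as a loose reference, whereas the two PSR results should agree with each other to near machine precision.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 2, 4
d = 2 ** n

X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)
I2 = np.eye(2, dtype=complex)


def random_unitary(d, rng):
    """Random unitary from the QR decomposition of a complex Gaussian matrix."""
    z = (rng.normal(size=(d, d)) + 1j * rng.normal(size=(d, d))) / np.sqrt(2)
    q, r = np.linalg.qr(z)
    return q * (np.diag(r) / np.abs(np.diag(r)))


V_list = [random_unitary(d, rng) for _ in range(m)]
H_list = [np.kron(Y, I2), np.kron(X, Z), np.kron(Z, Z), np.kron(I2, Y)]  # involutory Pauli strings
A = rng.normal(size=(d, d)) + 1j * rng.normal(size=(d, d))
M = (A + A.conj().T) / 2  # random Hermitian observable


def cost(theta):
    """State-vector evaluation of f(theta) = <0|U(theta)^dagger M U(theta)|0> in float64."""
    psi = np.zeros(d, dtype=complex)
    psi[0] = 1
    for V, H, t in zip(V_list, H_list, theta):
        psi = V @ ((np.cos(t / 2) * np.eye(d) - 1j * np.sin(t / 2) * H) @ psi)
    return np.real(np.vdot(psi, M @ psi))


def psr_grad(theta, shift):
    """Parameter Shift Rule: [f(theta + s e_j) - f(theta - s e_j)] / (2 sin s) for each j."""
    grad = np.zeros_like(theta)
    for j in range(len(theta)):
        e = np.zeros_like(theta)
        e[j] = shift
        grad[j] = (cost(theta + e) - cost(theta - e)) / (2 * np.sin(shift))
    return grad


def central_fd_grad(theta, h=1e-5):
    """Central finite difference: approximate, with O(h^2) truncation error."""
    grad = np.zeros_like(theta)
    for j in range(len(theta)):
        e = np.zeros_like(theta)
        e[j] = h
        grad[j] = (cost(theta + e) - cost(theta - e)) / (2 * h)
    return grad


theta = rng.normal(size=m)
g_half_pi = psr_grad(theta, np.pi / 2)
g_quarter_pi = psr_grad(theta, np.pi / 4)
g_fd = central_fd_grad(theta)

print("PSR (s = pi/2):", g_half_pi)
print("|PSR(pi/2) - PSR(pi/4)|_2 =", np.linalg.norm(g_half_pi - g_quarter_pi))
print("|PSR(pi/2) - finite diff|_2 =", np.linalg.norm(g_half_pi - g_fd))
```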
