In [None]:
import pennylane as qml
import numpy as np
import jax
import matplotlib.pyplot as plt

In [None]:
# Run JAX in double precision and pin it to the CPU backend before any arrays
# are created, then expose the NumPy-compatible API under the usual alias.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
import jax.numpy as jnp

In [None]:
def two_qubit_decomp(params, wires):
    """Apply a three-CNOT, 15-parameter two-qubit circuit intended to realise an
    arbitrary SU(4) gate, following the decomposition in Theorem 5 of
    https://arxiv.org/pdf/quant-ph/0308006.pdf

    Args:
        params: indexable sequence of (at least) 15 angles. ``params[0:6]`` and
            ``params[9:15]`` feed the outer layers of ``qml.Rot`` gates,
            ``params[6:9]`` the rotations sandwiched between the CNOTs.
        wires: pair ``(first, second)`` of wire labels.
    """
    first, second = wires

    def rot_pair(start):
        # One general single-qubit rotation on each wire (three angles apiece).
        qml.Rot(*params[start:start + 3], wires=first)
        qml.Rot(*params[start + 3:start + 6], wires=second)

    rot_pair(0)
    qml.CNOT(wires=[second, first])
    qml.RZ(params[6], wires=first)
    qml.RY(params[7], wires=second)
    qml.CNOT(wires=[first, second])
    qml.RY(params[8], wires=second)
    qml.CNOT(wires=[second, first])
    rot_pair(9)


# The three two-qubit building blocks compared below. Keys are
# (legend label, name used in progress messages); values are callables
# invoked as operation(params, wires).
operations = {
    ("Decomposition", "decomposition"): two_qubit_decomp,
    ("PauliRot sequence",) * 2: qml.ArbitraryUnitary,
    # Raw string: "\m" is not a valid escape sequence and would trigger a
    # SyntaxWarning (DeprecationWarning before Python 3.12). The value is unchanged.
    (r"$\mathrm{SU}(N)$ gate", "SU(N) gate"): qml.SpecialUnitary,
}

In [None]:
# Size of the qubit register, its wire labels, and a fixed seed for NumPy's
# global (legacy) random generator so any np.random draws are reproducible.
num_wires = 6
wires = [*range(num_wires)]
np.random.seed(62213)


In [None]:
# Transverse-field Ising Hamiltonian on a fully connected register:
#   H = -J * sum_{a<b} Z_a Z_b  -  h * sum_a X_a
J = 1.0  # coupling constant
h = 0.5  # transverse field strength

# Every unordered pair of distinct wires gets a ZZ coupling term.
zz_pairs = [(a, b) for a in range(num_wires) for b in range(a + 1, num_wires)]

# ZZ terms first, then the single-qubit X (field) terms.
coeffs = [-J] * len(zz_pairs) + [-h] * num_wires
obs = [qml.PauliZ(a) @ qml.PauliZ(b) for a, b in zz_pairs] + [
    qml.PauliX(a) for a in range(num_wires)
]

H = qml.Hamiltonian(coeffs, obs)

# Exact ground-state energy, used as the reference for the relative error.
E_min = min(qml.eigvals(H))
print(f"Ground state energy: {E_min:.5f}")

In [None]:
loc = 2
d = 4**loc - 1  # d = 15 for two-qubit operations
dev = qml.device("default.qubit", wires=num_wires)


def launchCircuit(repeatitions: int, params, operation=qml.SpecialUnitary):
    """Build the brickwall-ansatz QNode measuring <H>, print its drawing and return it.

    The ansatz applies ``operation`` on ``loc`` consecutive wires (cyclically,
    modulo ``num_wires``) in a brickwall-like pattern: e.g. two blocks with two
    layers, each layer holding several operations with ``d`` parameters each.

    Args:
        repeatitions: Intended number of blocks. NOTE(review): not used; the
            number of blocks applied is ``len(params)``.
        params: Nested parameter array indexed as ``params[block][layer][op]``,
            each entry holding the ``d`` angles of one operation. Layer ``i`` is
            shifted by ``i`` wires.
        operation: Gate used when drawing the circuit (defaults to
            ``qml.SpecialUnitary``).

    Returns:
        The JAX-interfaced QNode, called as ``qnode(params, operation)``.
    """

    def circuit(params, operation=None):
        """Apply ``operation`` in a brickwall pattern and return ``qml.expval(H)``.

        ``params`` is indexed as ``params[block][layer][op]``. Operation ``j`` of
        layer ``i`` acts on wires ``loc*j + i`` .. ``loc*(j+1) + i - 1``, taken
        modulo ``num_wires`` so the last operation can wrap around.
        """
        for params_block in params:
            for i, params_layer in enumerate(params_block):
                for j, params_op in enumerate(params_layer):
                    wires_op = [w % num_wires for w in range(loc * j + i, loc * (j + 1) + i)]
                    operation(params_op, wires_op)
        return qml.expval(H)

    qnode = qml.QNode(circuit, dev, interface="jax")
    # Draw with the parameters supplied by the caller rather than a global.
    print(qml.draw(qnode)(params, operation))
    return qnode

In [None]:
# NOTE(review): neither `run` nor `params_shape` is defined in any visible cell,
# so this cell raises NameError on a fresh kernel. TODO: define them, or replace
# this call with code that builds `init_params` and the QNode (e.g. via
# `launchCircuit`) so the optimisation cells below have what they need.
run(params_shape)

In [None]:
# Plain gradient-descent hyper-parameters shared by every building block below.
learning_rate = 5e-4
num_steps = 500

# NOTE(review): `init_params` and a module-level `qnode` must already exist at
# this point; neither is bound at module level by any visible cell.
init_params = jnp.array(init_params)

# Argument 1 (`operation`) is a Python callable, so it is marked static for jit.
# The gradient is built from the un-jitted QNode before `qnode` is rebound.
grad_fn = jax.jit(jax.jacobian(qnode), static_argnums=(1,))
qnode = jax.jit(qnode, static_argnums=(1,))

In [None]:
# Cost history of gradient descent for each building block, keyed by the
# legend label (first element of the `operations` key).
energies = {}

for (name, print_name), operation in operations.items():
    print(f"Running the optimization for the {print_name}")
    params = init_params.copy()
    energy = []
    for step in range(num_steps):
        # Cost at the current parameters, recorded *before* the update so that
        # energy[k] is the cost after k steps. Stored as a plain float.
        cost = float(qnode(params, operation))
        params = params - learning_rate * grad_fn(params, operation)
        energy.append(cost)
        if step % 10 == 0:  # Report current energy
            print(f"{step:3d} Steps: {cost:.6f}")

    energy.append(float(qnode(params, operation)))  # Final energy value
    energies[name] = energy

In [None]:
fig, ax = plt.subplots(1, 1)
styles = [":", "--", "-"]
colors = ["#70CEFF", "#C756B2", "#FFE096"]
for (name, energy), c, ls in zip(energies.items(), colors, styles):
    # `energy` is a Python list; convert explicitly instead of relying on
    # list - scalar working only because E_min happens to be a NumPy scalar.
    error = (np.asarray(energy, dtype=float) - E_min) / abs(E_min)
    ax.plot(np.arange(len(error)), error, label=name, c=c, ls=ls, lw=2.5)

ax.set(
    xlabel="Iteration",
    ylabel="Relative error",
    title="Relative energy error during gradient descent",
)
ax.legend()
plt.show()

In [None]:
# Grouped bar chart of the final energy for each (repetitions, steps) pair.
# NOTE(review): `repetitions`, `steps_list` and `final_energies` are not defined
# in any visible cell; presumably final_energies[steps] holds one value per
# entry of `repetitions`. Confirm before running.
bar_width = 0.1
x = np.arange(len(repetitions))

fig, ax = plt.subplots(figsize=(14, 8))

for i, steps in enumerate(steps_list):
    ax.bar(x + i * bar_width, final_energies[steps], bar_width, label=f"Steps = {steps}")

# Bars are centred on x + i * bar_width (i = 0 .. k-1), so the middle of each
# group of k = len(steps_list) bars sits at x + bar_width * (k - 1) / 2.
ax.set_xticks(x + bar_width * (len(steps_list) - 1) / 2)
ax.set_xticklabels(repetitions)
ax.set(
    xlabel="Number of Repetitions",
    ylabel="Absolute Final Energy",
    title="Absolute Final Energy for Different Repetitions and Steps (Arb_unitary)",
)
ax.legend()
ax.grid(True)
plt.show()