In [102]:
import pennylane as qml
import jax
from jax import numpy as jnp
#import jaxopt
import optax
import catalyst

# Number of qubits used by the device and by the per-wire layers in `circuit`.
n_wires: int = 2

@qml.qnode(qml.device("lightning.qubit", wires=n_wires), diff_method="best")
def circuit(x, theta):
    """Variational trial function u(x; theta).

    Every wire first receives the encoding RY(2 * arccos(x)); the angle is only
    finite for x in [-1, 1]. Then, wire by wire, an RX-RY-RX rotation triple is
    applied, followed by a CNOT from that wire onto the next one (cyclically).

    Args:
        x: scalar input point.
        theta: rotation angles indexed as ``theta[wire, k]`` with k in {0, 1, 2};
            assumed to have shape (n_wires, 3).

    Returns:
        Expectation value of sum_i Z_i over all ``n_wires`` wires.
    """
    # The encoding angle is identical for every wire, so compute it once.
    encoding_angle = 2 * jnp.arccos(x)
    for wire in range(n_wires):
        qml.RY(encoding_angle, wires=wire)

    for wire in range(n_wires):
        next_wire = (wire + 1) % n_wires
        qml.RX(theta[wire, 0], wires=wire)
        qml.RY(theta[wire, 1], wires=wire)
        qml.RX(theta[wire, 2], wires=wire)
        # NOTE(review): assumes n_wires >= 2; with a single wire this would be
        # CNOT(0, 0), i.e. control == target.
        qml.CNOT(wires=[wire, next_wire])

    total_z = qml.sum(*(qml.PauliZ(wire) for wire in range(n_wires)))
    return qml.expval(total_z)

def loss_fnc(theta):
    """Physics-informed loss for du/dx = 1 with u(0) = 0, using `circuit` as u.

    The collocation points are 11 evenly spaced values in [0, 0.9]. The first
    of them (x = 0) also serves as the boundary point.

    Args:
        theta: variational parameters forwarded to `circuit`.

    Returns:
        u(0)**2 + mean((du/dx - 1)**2) over the collocation points.
    """
    # Presumably stops at 0.9 because d/dx arccos(x) diverges at x = 1.
    colloc = jnp.linspace(0, .9, 11)

    # Derivative of the circuit output with respect to its scalar input x.
    du_dx = jax.grad(circuit, argnums=0)
    # NOTE(review): differentiating this loss w.r.t. theta differentiates
    # through du_dx, i.e. needs second-order QNode derivatives; confirm the
    # diff_method that "best" resolves to on lightning.qubit supports that.
    slopes = jnp.array([du_dx(point, theta) for point in colloc])

    boundary_value = circuit(colloc[0], theta)

    boundary_loss = jnp.mean(boundary_value**2)
    slope_loss = jnp.mean((slopes - jnp.ones_like(slopes))**2)
    return boundary_loss + slope_loss

# Adam optimiser with a fixed learning rate of 0.01.
opt = optax.adam(learning_rate=0.01)
# Initial rotation angles, all 1.0: one row per wire, one column per
# RX/RY/RX rotation in `circuit`.
theta = jnp.ones(shape=(n_wires, 3))
opt_state = opt.init(theta)

def optimize(theta, opt_state, n_steps=100):
    """Run ``n_steps`` optimiser updates of ``theta`` against ``loss_fnc``.

    Uses the module-level optax optimiser ``opt`` for the update rule.

    Args:
        theta: current variational parameters (the pytree ``opt.init`` was
            called with).
        opt_state: optimiser state from ``opt.init`` or a previous call.
        n_steps: number of optimisation steps (default 100).

    Returns:
        Tuple ``(theta, opt_state, loss_history)``: the updated parameters, the
        updated optimiser state, and one Python float per step holding the loss
        evaluated *before* that step's update.
    """
    # Build the value-and-gradient function once; it does not change per step.
    loss_and_grad = jax.value_and_grad(loss_fnc)
    loss_history = []

    for step in range(n_steps):
        loss_val, grads = loss_and_grad(theta)
        updates, opt_state = opt.update(grads, opt_state)
        theta = optax.apply_updates(theta, updates)
        print(f"Step: {step}  Loss: {loss_val}")
        # Store a plain float so the history is easy to plot or serialise.
        loss_history.append(float(loss_val))

    return theta, opt_state, loss_history


In [103]:
# Re-running this cell continues training from the current theta/opt_state
# (and starts a new loss_history); re-run the previous cell to start over.
theta, opt_state, loss_history = optimize(theta=theta, opt_state=opt_state)

Step: 0  Loss: 6.206298278863095
Step: 1  Loss: 5.973365632428284
Step: 2  Loss: 5.74568701590633
Step: 3  Loss: 5.523665986970554
Step: 4  Loss: 5.307676035139273
Step: 5  Loss: 5.098056012364894
Step: 6  Loss: 4.895106012314087
Step: 7  Loss: 4.699083860209125
Step: 8  Loss: 4.510202362141512
Step: 9  Loss: 4.328627438612831


In [None]:
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Column 1: training loss per optimisation step.
ax1.plot(range(len(loss_history)), loss_history, "go", ls="dashed")
ax1.set_xlabel("Optimization step", fontsize=13)
ax1.set_ylabel("Loss", fontsize=13)
ax1.set_title("Training loss")

# Column 2: trained circuit vs. the analytical solution of the problem encoded
# in `loss_fnc` (du/dx = 1, u(0) = 0), i.e. u(x) = x.
# Training only used x in [0, 0.9]; points in (0.9, 1] are extrapolation.
x = jnp.linspace(0, 1, 21)
f_qc = jnp.array([circuit(xi, theta) for xi in x])
f_an = x
ax2.plot(x, f_qc, "ro", ls="dashed")
ax2.plot(x, f_an, "go", ls="dashed")
ax2.set_xlabel("x", fontsize=13)
ax2.set_ylabel("u(x)", fontsize=13)
ax2.set_title("Solution")
ax2.legend(["QCML", "Analytical"])

plt.show()