In [10]:
import pennylane as qml
import jax
from jax import numpy as jnp
import optax

n_wires = 4
weights = jnp.ones((n_wires,3))
bias = jnp.array(0.)
opt = optax.adam(learning_rate=0.1)
params = {"weights": weights, "bias": bias}
opt_state = opt.init(params)

@qml.qnode(qml.device("lightning.qubit", wires=n_wires), diff_method="adjoint")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*jnp.arccos(x),wires = i)

    # Variational Ansatz   
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

def loss_fnc(params):
    # Loss function of: du/dx = 1, u(0) = 0
    x = jnp.linspace(0,0.99,11)
    _dudx = jax.grad(my_model, argnums=0)
    dudx = jnp.array([_dudx(i, params["weights"], params["bias"]) for i in x])
    
    loss_diff = jnp.mean((dudx - jnp.ones_like(dudx))**2)
    loss_initial = jnp.mean(my_model(jnp.zeros_like(x),params["weights"], params["bias"])**2)
    
    return loss_diff + loss_initial

def optimize(params, opt_state, n=10):
    loss_history = []

    for i in range(1,n+1):
        loss_val, grads = jax.value_and_grad(loss_fnc)(params)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        if i%1 == 0: jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)
        loss_history.append(loss_val)

    return params, opt_state, loss_history

%timeit -r1 -n1 optimize(params, opt_state)

Step: 1  Loss: 22.767384660616983
Step: 2  Loss: 22.075861797576376
Step: 3  Loss: 24.740213192148016
Step: 4  Loss: 28.548354969906313
Step: 5  Loss: 32.84860372675379
Step: 6  Loss: 37.56952225127498
Step: 7  Loss: 42.778332784546386
Step: 8  Loss: 48.53382627549362
Step: 9  Loss: 54.829978789433554
Step: 10  Loss: 61.57769251022365
9.82 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [None]:
# params, opt_state, loss_history = optimize(params, opt_state)

In [None]:
import matplotlib.pyplot as plt

fig = plt.figure()
fig.set_figheight(5)
fig.set_figwidth(12)

# Add energy plot on column 1
ax1 = fig.add_subplot(121)
ax1.plot(range(len(loss_history)), loss_history, "go", ls="dashed")
ax1.set_xlabel("Optimization step", fontsize=13)
ax1.set_ylabel("Loss", fontsize=13)

ax2 = fig.add_subplot(122)
x = jnp.linspace(0,0.99,21)
f_qc = my_model(x,params["weights"], params["bias"])
f_an = x
ax2.plot( x, f_qc, "ro", ls="dashed")
ax2.plot( x, f_an, "go", ls="dashed")
ax2.legend(["QCML", "Analytical"])

plt.show()