In [5]:
import pennylane as qml
import jax
from jax import numpy as jnp
import optax

n_wires = 4
weights = jnp.ones((n_wires,3))
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}
#opt = optax.adam(learning_rate=0.1)
#opt_state = opt.init(params)
opt = qml.QNGOptimizer()


@qml.qnode(qml.device("default.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*jnp.arccos(x),wires = i)

    # Variational Ansatz   
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

@jax.jit
def my_model(data, weights, bias):
    return circuit(data, weights) + bias

@jax.jit
def loss_fnc(params):
    # Loss function of: du/dx = 1, u(0) = 0
    x = jnp.linspace(0,0.99,21)
    _dudx = jax.grad(my_model, argnums=0)
    dudx = jnp.array([_dudx(i, params["weights"], params["bias"]) for i in x])
    
    res = dudx - (4 * x**3 + x**2 - 2 * x - 0.5)
    loss_diff = jnp.mean(res**2)
    loss_initial = jnp.mean((my_model(jnp.zeros_like(x),params["weights"], params["bias"]) - jnp.ones_like(x))**2 )
    
def optimize(params, n=1000):
    loss_history = []

    for i in range(1,n+1):
        # loss_val, grads = jax.value_and_grad(loss_fnc)(params)
        # updates, opt_state = opt.update(grads, opt_state)
        # params = optax.apply_updates(params, updates)
        params = opt.step(loss_fnc, params)
        loss_val = loss_fnc(params)
        if i%100 == 0: jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)
        loss_history.append(loss_val)

    return params, loss_history

%timeit -r1 -n1 optimize(params)


ValueError: The objective function must be encoded as a single QNode for the natural gradient to be automatically computed. Otherwise, metric_tensor_fn must be explicitly provided to the optimizer.

In [2]:
%timeit -r1 -n1 optimize(params, opt_state)
# params, opt_state, loss_history = optimize(params, opt_state)

Step: 100  Loss: 0.28363654017448425
Step: 200  Loss: 0.07567053288221359
Step: 300  Loss: 0.09308571368455887
Step: 400  Loss: 0.04231073707342148
Step: 500  Loss: 0.015424626879394054
Step: 600  Loss: 0.012171284295618534
Step: 700  Loss: 0.011967191472649574
Step: 800  Loss: 0.03056567907333374
Step: 900  Loss: 0.010834799148142338
Step: 1000  Loss: 0.01202180702239275
51.2 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [None]:
import matplotlib.pyplot as plt

fig = plt.figure()
fig.set_figheight(5)
fig.set_figwidth(12)

# Add energy plot on column 1
ax1 = fig.add_subplot(121)
ax1.plot(range(len(loss_history)), loss_history, "go", ls="dashed")
ax1.set_xlabel("Optimization step", fontsize=13)
ax1.set_ylabel("Loss", fontsize=13)

ax2 = fig.add_subplot(122)
x = jnp.linspace(0,0.99,21)
f_qc = my_model(x,params["weights"], params["bias"])
f_an = x
ax2.plot( x, f_qc, "ro", ls="dashed")
ax2.plot( x, f_an, "go", ls="dashed")
ax2.legend(["QCML", "Analytical"])

plt.show()