In [None]:
! pip install pennylane optax

In [4]:
import pennylane as qml
import jax
from jax import numpy as jnp
import optax

n_wires = 5
data = jnp.cos(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev = qml.device("default.qubit", wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    # data embedding
    for i in range(n_wires):
        # data[i] will be of shape (4,); we are
        # taking advantage of operation vectorization here
        qml.RY(data[i], wires=i)

    # trainable ansatz
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # we use a sum of local Z's as an observable since a
    # local Z would only be affected by params on that qubit.
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

In [5]:
@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss

In [6]:
weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

In [7]:
print(loss_fn(params, data, targets))
print(jax.grad(loss_fn)(params, data, targets))

0.17344648
{'bias': Array(-0.65765524, dtype=float32, weak_type=True), 'weights': Array([[-0.17739275, -0.06189498, -0.17819384],
       [-0.21560565, -0.0275434 , -0.33223915],
       [-0.0020373 ,  0.09686179, -0.23871091],
       [-0.15701011, -0.00134047, -0.2873454 ],
       [-0.0184525 , -0.0064371 , -0.01853886]], dtype=float32)}


In [8]:
opt = optax.adam(learning_rate=0.3)
opt_state = opt.init(params)

In [9]:
def update_step(opt, params, opt_state, data, targets):
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

loss_history = []

for i in range(100):
    params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets)

    if i % 5 == 0:
        print(f"Step: {i} Loss: {loss_val}")

    loss_history.append(loss_val)

Step: 0 Loss: 0.17344647645950317
Step: 5 Loss: 0.14710889756679535
Step: 10 Loss: 0.05850798636674881
Step: 15 Loss: 0.04570798948407173
Step: 20 Loss: 0.03201589733362198
Step: 25 Loss: 0.02616313472390175
Step: 30 Loss: 0.024366136640310287
Step: 35 Loss: 0.022175978869199753
Step: 40 Loss: 0.02028464339673519
Step: 45 Loss: 0.019171901047229767
Step: 50 Loss: 0.018599100410938263
Step: 55 Loss: 0.01829129084944725
Step: 60 Loss: 0.018130116164684296
Step: 65 Loss: 0.017963577061891556
Step: 70 Loss: 0.01769602671265602
Step: 75 Loss: 0.01737845502793789
Step: 80 Loss: 0.017195910215377808
Step: 85 Loss: 0.016997545957565308
Step: 90 Loss: 0.01689930260181427
Step: 95 Loss: 0.01683393120765686


In [11]:
# Define the optimizer we want to work with
opt = optax.adam(learning_rate=0.3)

@jax.jit
def update_step_jit(i, args):
    params, opt_state, data, targets, print_training = args

    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)

    # if print_training=True, print the loss every 5 steps
    jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)

    return (params, opt_state, data, targets, print_training)

@jax.jit
def optimization_jit(params, data, targets, print_training=False):

    opt_state = opt.init(params)
    args = (params, opt_state, data, targets, print_training)
    (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update_step_jit, args)

    return params

In [12]:
params = {"weights": weights, "bias": bias}
optimization_jit(params, data, targets, print_training=True)

Step: 0  Loss: 0.17344647645950317
Step: 5  Loss: 0.14710880815982819
Step: 10  Loss: 0.05850811302661896
Step: 15  Loss: 0.045708067715168
Step: 20  Loss: 0.03201604634523392
Step: 25  Loss: 0.02616310305893421
Step: 30  Loss: 0.024366099387407303
Step: 35  Loss: 0.022175926715135574
Step: 40  Loss: 0.020284635946154594
Step: 45  Loss: 0.01917184703052044
Step: 50  Loss: 0.018599100410938263
Step: 55  Loss: 0.01829131320118904
Step: 60  Loss: 0.01813017763197422
Step: 65  Loss: 0.017963571473956108
Step: 70  Loss: 0.017696011811494827
Step: 75  Loss: 0.017378445714712143
Step: 80  Loss: 0.01719590835273266
Step: 85  Loss: 0.016997555270791054
Step: 90  Loss: 0.01689927838742733
Step: 95  Loss: 0.01683397963643074


{'bias': Array(0.01127004, dtype=float32),
 'weights': Array([[1.2887369 , 2.0005755 , 0.99289083],
        [1.6590643 , 1.3155318 , 1.2168641 ],
        [1.4232814 , 0.2568506 , 1.7053303 ],
        [1.2973213 , 1.880504  , 0.85184103],
        [0.05557076, 3.3709753 , 3.1788893 ]], dtype=float32)}

In [13]:
from timeit import repeat

def optimization(params, data, targets):
    opt = optax.adam(learning_rate=0.3)
    opt_state = opt.init(params)

    for i in range(100):
        params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets)

    return params

reps = 5
num = 2

times = repeat("optimization(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num

print(f"Jitting just the cost (best of {reps}): {result} sec per loop")

times = repeat("optimization_jit(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num

print(f"Jitting the entire optimization (best of {reps}): {result} sec per loop")

Jitting just the cost (best of 5): 0.6283237499992538 sec per loop
Jitting the entire optimization (best of 5): 0.0019953500013798475 sec per loop


In [14]:
n_wires = 5
data = jnp.cos(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])
print(data)

[[-7.2067559e-02 -1.1728344e-02 -2.4895622e-05  4.9101636e-03]
 [ 4.7578786e-02  1.5772879e-01  3.3818188e-01  5.6220156e-01]
 [ 7.8138560e-01  9.4138408e-01  1.0000000e+00  9.4138354e-01]
 [ 7.8138465e-01  5.6220043e-01  3.3818090e-01  1.5772806e-01]
 [ 4.7578432e-02  4.9100821e-03 -2.4898063e-05 -1.1728490e-02]]
