In [14]:
import numpy as np
import jax
import pennylane as qml
import jax.numpy as jnp
import optax
import time

dev = qml.device("default.qubit", wires=4)


@qml.qnode(dev)
def cost(weights, data):
    """Sign-conditioned variational circuit evaluated on ``default.qubit``.

    ``data`` is angle-embedded on wires 0-3. Then, for every row of ``weights``
    (one row per layer), entry ``k`` (k = 0..3) rotates wire ``k``: an RX gate
    when the angle is positive, an RY gate when it is negative, and no gate at
    all when it is exactly zero. Every layer finishes with a cyclic CNOT chain
    in which wire ``k`` controls wire ``(k + 1) % 4``.

    Returns:
        The expectation value of the observable Z(0) + Z(3).
    """
    qml.AngleEmbedding(data, wires=range(4))

    for layer in weights:
        for wire in range(4):
            angle = layer[wire]
            # Python control flow on the value: fine under jax.grad (concrete primals),
            # the branch taken decides which gate enters the circuit.
            if angle > 0:
                qml.RX(angle, wires=wire)
            elif angle < 0:
                qml.RY(angle, wires=wire)

        # Cyclic entangler: the last wire wraps around to wire 0.
        for control in range(4):
            target = (control + 1) % 4
            qml.CNOT(wires=[control, target])

    return qml.expval(qml.PauliZ(0) + qml.PauliZ(3))

# Fixed seed so a "Restart & Run All" reproduces the same initial weights and data
# (all later timing cells start from these same arrays).
SEED = 42
rng = np.random.default_rng(SEED)

# 5 layers x 4 wires of angles in [-1, 1); 4 input features in [0, 1).
weights = jnp.array(2 * rng.random([5, 4]) - 1)
data = jnp.array(rng.random([4]))

opt = optax.sgd(learning_rate=0.4)

params = weights
state = opt.init(params)

# Build the gradient function once instead of re-wrapping `cost` every iteration.
grad_cost = jax.grad(cost)

# perf_counter is the monotonic, high-resolution clock intended for benchmarking.
startTime = time.perf_counter()
for _ in range(200):
    gradient = grad_cost(params, data)
    updates, state = opt.update(gradient, state)
    params = optax.apply_updates(params, updates)
# JAX dispatches work asynchronously; wait for the last update before stopping the clock.
params = jax.block_until_ready(params)
endTime = time.perf_counter()
print(params)

print("Final Time: ", endTime - startTime)

[[-0.08257651 -0.0075295  -0.01007543  0.10362513]
 [-1.58010072 -0.4642326   0.01680865  1.60022535]
 [-1.64318142 -0.04204414  0.22940424  0.01821919]
 [ 0.00550645  1.69226951  0.00624569  0.06770775]
 [-0.02196728  0.15368392  1.25898752 -0.78985517]]
Final Time:  25.078877925872803


In [15]:
import jax
from catalyst import qjit, for_loop, cond, grad

dev = qml.device("lightning.qubit", wires=4)


@qjit
@qml.qnode(dev)
def cost(weights, data):
    """Catalyst-compiled sign-conditioned variational circuit on ``lightning.qubit``.

    The data-dependent branches and loops use Catalyst's ``cond`` and
    ``for_loop`` so that the whole circuit can be captured by ``@qjit``:

    * ``data`` is angle-embedded on wires 0-3;
    * for each row of ``weights`` (one row per layer) and each wire ``w`` in 0..3:
      RX(angle) if the angle is > 0, RY(angle) if it is < 0, no gate if it is 0;
    * each layer ends with CNOTs where wire ``w`` controls wire ``(w + 1) mod 4``.

    Returns:
        The expectation value of the observable Z(0) + Z(3).
    """
    qml.AngleEmbedding(data, wires=range(4))

    n_layers = jnp.shape(weights)[0]

    @for_loop(0, n_layers, 1)
    def apply_layer(layer_idx):
        layer = weights[layer_idx]

        @for_loop(0, 4, 1)
        def rotate_wires(wire):
            angle = layer[wire]

            @cond(angle > 0)
            def rotate():
                qml.RX(angle, wires=wire)

            @rotate.else_if(angle < 0)
            def rotate_negative():
                qml.RY(angle, wires=wire)

            # Exactly-zero angle: apply nothing.
            rotate.otherwise(lambda: None)
            rotate()

        @for_loop(0, 4, 1)
        def entangle(wire):
            qml.CNOT(wires=[wire, jnp.mod((wire + 1), 4)])

        rotate_wires()
        entangle()

    apply_layer()
    return qml.expval(qml.PauliZ(0) + qml.PauliZ(3))

opt = optax.sgd(learning_rate=0.4)

# Restart from the same initial weights as the reference run so the timings are comparable.
params = weights
state = opt.init(params)

# Build the gradient function once instead of re-wrapping `cost` every iteration.
grad_cost_qjit = jax.grad(cost)

# NOTE(review): `cost` is @qjit-compiled lazily, so this timing presumably includes
# its one-off compilation on the first call — confirm before comparing numbers.
startTimeQjit = time.perf_counter()
for _ in range(200):
    gradient = grad_cost_qjit(params, data)
    updates, state = opt.update(gradient, state)
    params = optax.apply_updates(params, updates)
# JAX dispatches work asynchronously; wait for the last update before stopping the clock.
params = jax.block_until_ready(params)
endTimeQjit = time.perf_counter()
print(params)
print("Final Time: ", endTimeQjit - startTimeQjit)

[[-0.08257651 -0.0075295  -0.01007543  0.10362513]
 [-1.58010072 -0.4642326   0.01680865  1.60022535]
 [-1.64318142 -0.04204414  0.22940424  0.01821919]
 [ 0.00550645  1.69226951  0.00624569  0.06770775]
 [-0.02196728  0.15368392  1.25898752 -0.78985517]]
Final Time:  3.5193538665771484


In [16]:
@qjit
def optimize(init_weights, data, steps):
    """Run ``steps`` SGD updates on ``cost`` inside a single Catalyst program.

    The full optimisation loop (gradient, optax update, parameter update) is
    compiled with ``@qjit``, so there is no Python round-trip per step.

    Args:
        init_weights: starting circuit parameters, in the layout ``cost`` expects.
        data: input features passed unchanged to ``cost`` at every step.
        steps: number of SGD iterations to run.

    Returns:
        ``(params, state)``: the weights after ``steps`` updates and the final
        optax optimiser state.
    """
    opt = optax.sgd(learning_rate=0.4)

    def update_step(_step, params, state):
        # An SGD step only needs the gradient. Evaluating `cost` as well would
        # add one extra circuit execution per iteration whose value is unused.
        gradient = grad(cost, argnums=0)(params, data)
        updates, state = opt.update(gradient, state)
        params = optax.apply_updates(params, updates)
        return (params, state)

    params = init_weights
    state = opt.init(params)

    # Loop-carried values are (params, state); the loop returns their final values.
    return for_loop(0, steps, 1)(update_step)(params, state)

# NOTE(review): `optimize` is a fresh @qjit program, so this timing presumably also
# includes its one-off compilation on the first call — confirm before comparing.
startOptimize = time.perf_counter()
# Keep the optimiser's output: `params` still holds the result of the earlier loop,
# so printing it would not reflect this run at all.
final_params, final_state = optimize(weights, data, 200)
# Wait for the result to be materialised before stopping the clock.
final_params = jax.block_until_ready(final_params)
endOptimize = time.perf_counter()

print(final_params)

print("Final time: ", endOptimize - startOptimize)

[[-0.08257651 -0.0075295  -0.01007543  0.10362513]
 [-1.58010072 -0.4642326   0.01680865  1.60022535]
 [-1.64318142 -0.04204414  0.22940424  0.01821919]
 [ 0.00550645  1.69226951  0.00624569  0.06770775]
 [-0.02196728  0.15368392  1.25898752 -0.78985517]]
Final time:  1.3835999965667725
