In [17]:
import numpy as np
import jax
import pennylane as qml
import jax.numpy as jnp
import optax
import time

dev = qml.device("lightning.qubit", wires=4)

@qml.qnode(dev)
def cost(weights, data):
    qml.AngleEmbedding(data, wires=range(4))

    for x in weights:
        # each trainable layer
        for i in range(4):
            # for each wire
            if x[i] > 0:
                qml.RX(x[i], wires=i)
            elif x[i] < 0:
                qml.RY(x[i], wires=i)

        for i in range(4):
            qml.CNOT(wires=[i, (i + 1) % 4])

    return qml.expval(qml.PauliZ(0) + qml.PauliZ(3))

weights = jnp.array(2 * np.random.random([5, 4]) - 1)
data = jnp.array(np.random.random([4]))

opt = optax.sgd(learning_rate=0.4)

params = weights
state = opt.init(params)

startTime = time.time()
for i in range(200):
    gradient = jax.grad(cost)(params, data)
    (updates, state) = opt.update(gradient, state)
    params = optax.apply_updates(params, updates)

endTime = time.time()
print(params)

print("Final Time: ",endTime-startTime)

[[-9.05858318e-03 -6.26938201e-04  9.52088275e-01 -4.02600033e-04]
 [-6.03197898e-03 -1.81841675e-02  1.57074446e+00  1.51057752e+00]
 [-3.56914642e-04 -1.58165906e+00 -6.37902783e-02 -4.75711047e-02]
 [-1.25092535e+00 -1.54563805e+00  4.72232893e-04 -3.33794086e-02]
 [ 2.33617919e-01 -3.49388710e-01 -2.98290528e-02  3.23012111e-04]]
Final Time:  4.3881189823150635


In [20]:
import jax
from catalyst import qjit, for_loop, cond, grad

dev = qml.device("lightning.qubit", wires=4)

@qjit
@qml.qnode(dev)
def cost(weights, data):
    qml.AngleEmbedding(data, wires=range(4))

    def layer_loop(i):
        x = weights[i]
        def wire_loop(j):

            @cond(x[j] > 0)
            def trainable_gate():
                qml.RX(x[j], wires=j)

            @trainable_gate.else_if(x[j] < 0)
            def negative_gate():
                qml.RY(x[j], wires=j)

            trainable_gate.otherwise(lambda: None)
            trainable_gate()

        def cnot_loop(j):
            qml.CNOT(wires=[j, jnp.mod((j + 1), 4)])

        for_loop(0, 4, 1)(wire_loop)()
        for_loop(0, 4, 1)(cnot_loop)()

    for_loop(0, jnp.shape(weights)[0], 1)(layer_loop)()
    return qml.expval(qml.PauliZ(0) + qml.PauliZ(3))

opt = optax.sgd(learning_rate=0.4)

params = weights
state = opt.init(params)

startTimeQjit = time.time()
for i in range(200):
    gradient = jax.grad(cost)(params, data)
    (updates, state) = opt.update(gradient, state)
    params = optax.apply_updates(params, updates)

endTimeQjit = time.time()
print(cost.mlir)
print("Final Time: ", endTimeQjit-startTimeQjit)

module @cost {
  func.func public @jit_cost(%arg0: tensor<5x4xf64>, %arg1: tensor<4xf64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    %0 = catalyst.launch_kernel @module_cost::@cost(%arg0, %arg1) : (tensor<5x4xf64>, tensor<4xf64>) -> tensor<f64>
    return %0 : tensor<f64>
  }
  module @module_cost {
    module attributes {transform.with_named_sequence} {
      transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
        transform.yield 
      }
    }
    func.func public @cost(%arg0: tensor<5x4xf64>, %arg1: tensor<4xf64>) -> tensor<f64> attributes {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage<internal>, qnode} {
      %c4 = arith.constant 4 : index
      %c1 = arith.constant 1 : index
      %c5 = arith.constant 5 : index
      %c0 = arith.constant 0 : index
      %c0_i64 = arith.constant 0 : i64
      %cst = stablehlo.constant dense<1.000000e+00> : tensor<f64>
      %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f6

In [22]:
@qjit
def optimize(init_weights, data, steps):
    def loss(x):
        dy = grad(cost, argnums=0)(x, data)
        return (cost(x, data), dy)

    opt = optax.sgd(learning_rate=0.4)

    def update_step(i, params, state):
        (_, gradient) = loss(params)
        (updates, state) = opt.update(gradient, state)
        params = optax.apply_updates(params, updates)
        return (params, state)

    params = init_weights
    state = opt.init(params)

    return for_loop(0, steps, 1)(update_step)(params, state)

startOptimize = time.time()
optimize(weights, data, 200)
endOptimize = time.time()

print(optimize.mlir)

print("Final time: ", endOptimize-startOptimize)

module @optimize {
  func.func public @jit_optimize(%arg0: tensor<5x4xf64>, %arg1: tensor<4xf64>, %arg2: tensor<i64>) -> tensor<5x4xf64> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %cst = stablehlo.constant dense<-4.000000e-01> : tensor<f64>
    %c0 = arith.constant 0 : index
    %extracted = tensor.extract %arg2[] : tensor<i64>
    %0 = arith.index_cast %extracted : i64 to index
    %1 = scf.for %arg3 = %c0 to %0 step %c1 iter_args(%arg4 = %arg0) -> (tensor<5x4xf64>) {
      %2 = gradient.grad "auto" @module_cost::@cost(%arg4, %arg1) {diffArgIndices = dense<0> : tensor<1xi64>} : (tensor<5x4xf64>, tensor<4xf64>) -> tensor<5x4xf64>
      %3 = catalyst.launch_kernel @module_cost_0::@cost_1(%arg4, %arg1) : (tensor<5x4xf64>, tensor<4xf64>) -> tensor<f64>
      %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<5x4xf64>
      %5 = stablehlo.multiply %4, %2 : tensor<5x4xf64>
      %6 = stablehlo.add %arg4, %5 : tensor<5x4xf64>
      scf.