# 3.3.2. Augmenting differentiable compilation transforms with JIT compilation

This notebook contains all the results for the JIT compilation examples. Note that runtimes depend on your hardware and software versions, so they will not match those in the paper exactly.

In [1]:
import time

import pennylane as qml
from pennylane import numpy as np
from pennylane.transforms import (
    commute_controlled,
    single_qubit_fusion
)

import jax
from jax import numpy as jnp

# Enable float64 support
from jax.config import config

remember = config.read("jax_enable_x64")
config.update("jax_enable_x64", True)

First, we show how a compilation pipeline can be defined and applied to a circuit using the `@qml.compile` decorator.

In [2]:
pipeline = [
    commute_controlled(direction='left'),
    single_qubit_fusion
]

dev = qml.device('default.qubit', wires=3)

@qml.qnode(dev)
@qml.compile(pipeline=pipeline)
def circuit(x, y, z):
    qml.CNOT(wires=[0, 1])
    qml.RX(x, wires=1)
    qml.RY(y, wires=1)
    qml.S(wires=1)
    qml.CNOT(wires=[1, 2])
    qml.Hadamard(wires=2)
    qml.CNOT(wires=[2, 0])
    qml.RZ(z, wires=2)
    return qml.expval(qml.PauliZ(1))

params = np.array([0.1, 0.2, 0.3], requires_grad=True)

In [3]:
print(qml.draw(circuit)(*params))

0: ───────────────────────╭●──────────────────────────────────────────────╭X─┤     
1: ──Rot(1.57,0.10,-1.57)─╰X──Rot(0.00,0.20,1.57)─╭●──────────────────────│──┤  <Z>
2: ───────────────────────────────────────────────╰X──Rot(3.14,1.57,0.30)─╰●─┤     
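To make the idea of a pipeline concrete, here is a minimal toy sketch of "a list of transforms applied in order". It is not how PennyLane implements `qml.compile`: the tuple-based gate format and the functions `merge_adjacent_rz`, `drop_zero_rotations`, and `apply_pipeline` are hypothetical names introduced only for illustration.

```python
# Toy gate representation (an assumption of this sketch): (name, wire, angle)

def merge_adjacent_rz(gates):
    """Merge consecutive RZ rotations on the same wire by adding their angles."""
    merged = []
    for name, wire, angle in gates:
        if merged and name == "RZ" and merged[-1][0] == "RZ" and merged[-1][1] == wire:
            _, _, previous_angle = merged.pop()
            merged.append(("RZ", wire, previous_angle + angle))
        else:
            merged.append((name, wire, angle))
    return merged

def drop_zero_rotations(gates, atol=1e-12):
    """Remove RZ rotations whose angle is numerically zero."""
    return [g for g in gates if not (g[0] == "RZ" and abs(g[2]) < atol)]

def apply_pipeline(gates, pipeline):
    """Apply each transform in order, feeding the output of one into the next."""
    for transform in pipeline:
        gates = transform(gates)
    return gates

toy_pipeline = [merge_adjacent_rz, drop_zero_rotations]
toy_circuit = [
    ("RZ", 0, 0.3),
    ("RZ", 0, -0.3),
    ("H", 1, None),
    ("RZ", 1, 0.1),
    ("RZ", 1, 0.2),
]

result = apply_pipeline(toy_circuit, toy_pipeline)
print(result)
```

The order of the transforms matters: the two rotations on wire 0 only cancel because merging runs before the zero-angle cleanup.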


We can compute the gradient of the circuit output with respect to the three input parameters using the `qml.grad` transform.

In [4]:
qml.grad(circuit)(*params)

(array(-0.0978434), array(-0.19767681), array(1.33356867e-17))
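As a sanity check, these values can be compared against a closed form derived by hand. The sketch below assumes that wire 0 starts in |0⟩ (so the first CNOT acts trivially), and that neither `S` nor being the control of a later CNOT changes ⟨Z⟩ on wire 1, which gives ⟨Z₁⟩ = cos(x)·cos(y) with no dependence on `z`. The function names are hypothetical.

```python
import math

def expected_z1(x, y, z):
    # Hand-derived closed form (an assumption of this sketch): <Z_1> = cos(x) * cos(y)
    return math.cos(x) * math.cos(y)

def expected_grad(x, y, z):
    # Partial derivatives of cos(x) * cos(y) with respect to x, y, and z
    return (
        -math.sin(x) * math.cos(y),
        -math.cos(x) * math.sin(y),
        0.0,
    )

analytic = expected_grad(0.1, 0.2, 0.3)
print(analytic)
```

Under these assumptions the analytic values agree with the output above to the printed precision, and the tiny third component (of order 1e-17) is consistent with an exact zero.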

# JIT compilation

Next, we incorporate JIT into the compilation process. The 5-qubit circuit below applies an `RX` rotation and a Hadamard to each wire, then a layer of parametrized `Rot` gates, and finally a ring of CNOTs.

In [5]:
dev = qml.device('default.qubit', wires=5)

def circuit(x, weights):
    for wire in range(5):
        qml.RX(x[wire], wires=wire)
        qml.Hadamard(wires=wire)
        
    for wire in range(5):
        qml.Rot(*weights[wire, :], wires=wire)
    
    for wire in range(5):
        qml.CNOT(wires=[wire, (wire + 1) % 5])
    
    return qml.expval(
        qml.PauliY(0) @ qml.PauliY(1) @ qml.PauliY(2) @ qml.PauliY(3) @ qml.PauliY(4)
    )

In [6]:
original_qnode = qml.QNode(circuit, dev, interface="jax", diff_method="parameter-shift")
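The `diff_method="parameter-shift"` setting obtains gradients from extra circuit evaluations at shifted parameter values. The following is a minimal sketch of the two-term rule on a single-parameter example, assuming the textbook result that ⟨Z⟩ after `RX(theta)` on |0⟩ equals cos(theta). The function names are hypothetical, and this is not PennyLane's implementation.

```python
import math

def expval_z_after_rx(theta):
    # Stand-in for a circuit evaluation: <Z> after RX(theta) on |0> is cos(theta)
    return math.cos(theta)

def parameter_shift(f, theta, shift=math.pi / 2):
    # Two-term rule: two evaluations of f, one at each shifted parameter value
    return (f(theta + shift) - f(theta - shift)) / (2 * math.sin(shift))

theta = 0.1
ps_grad = parameter_shift(expval_z_after_rx, theta)
exact_grad = -math.sin(theta)
print(ps_grad, exact_grad)
```

Unlike a finite difference, the rule is exact for this function, at the cost of two evaluations per parameter.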

These are the weights that were used to generate the example in the paper.

In [7]:
x = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5])

weights = jnp.array([
    [-0.28371043,  0.93681631, -1.00500712],
    [ 1.41650132,  1.05433029,  0.91081303],
    [-0.42656701,  0.98618842, -0.55753227],
    [ 0.01532506, -2.07856628,  0.55483725],
    [ 0.91423682,  0.57445956,  0.72278638]]
)



In [8]:
original_qnode(x, weights)

DeviceArray(0.47543957, dtype=float64)

In [9]:
print(qml.draw(original_qnode, expansion_strategy="device")(x, weights))

0: ──RX(0.10)──H──Rot(-0.28,0.94,-1.01)─╭●──────────╭X─┤ ╭<Y@Y@Y@Y@Y>
1: ──RX(0.20)──H──Rot(1.42,1.05,0.91)───╰X─╭●───────│──┤ ├<Y@Y@Y@Y@Y>
2: ──RX(0.30)──H──Rot(-0.43,0.99,-0.56)────╰X─╭●────│──┤ ├<Y@Y@Y@Y@Y>
3: ──RX(0.40)──H──Rot(0.02,-2.08,0.55)────────╰X─╭●─│──┤ ├<Y@Y@Y@Y@Y>
4: ──RX(0.50)──H──Rot(0.91,0.57,0.72)────────────╰X─╰●─┤ ╰<Y@Y@Y@Y@Y>


In [10]:
with qml.Tracker(dev) as tracker:
    jax.grad(original_qnode, argnums=(0, 1))(x, weights)

In [11]:
tracker.totals

{'executions': 41, 'batches': 2, 'batch_len': 41}

In [12]:
%%timeit 
jax.grad(original_qnode, argnums=(0, 1))(x, weights)

79.3 ms ± 2.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
# Apply the single_qubit_fusion transform to the circuit before building the QNode
compiled_qnode = qml.QNode(
    qml.transforms.single_qubit_fusion()(circuit), 
    dev, 
    interface="jax", 
    diff_method="parameter-shift"
)

In [14]:
with qml.Tracker(dev) as tracker:
    jax.grad(compiled_qnode, argnums=(0, 1))(x, weights)

In [15]:
tracker.totals

{'executions': 31, 'batches': 2, 'batch_len': 31}
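The two execution counts are consistent with a simple accounting: one forward evaluation plus two shifted evaluations per gate parameter. The sketch below assumes that fusion merges the `RX`, Hadamard, and `Rot` on each wire into a single three-parameter `Rot`; the helper `parameter_shift_executions` is a hypothetical name.

```python
def parameter_shift_executions(num_gate_params, shifts_per_param=2, forward_passes=1):
    # One forward evaluation plus a fixed number of shifted evaluations per parameter
    return forward_passes + shifts_per_param * num_gate_params

n_wires = 5

# Original circuit: one RX parameter and three Rot parameters per wire
original_params = n_wires * (1 + 3)

# After single-qubit fusion (assumed): one fused Rot with three parameters per wire
fused_params = n_wires * 3

original_executions = parameter_shift_executions(original_params)
fused_executions = parameter_shift_executions(fused_params)
print(original_executions, fused_executions)
```

These evaluate to 41 and 31, matching the tracker totals for the original and compiled QNodes.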

In [16]:
%%timeit
jax.grad(compiled_qnode, argnums=(0, 1))(x, weights)

601 ms ± 14.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## JIT of grad

Let's apply JIT to the gradient computation. We build jitted gradient functions for both QNodes and start by timing the original one.

In [17]:
original_grad = jax.grad(original_qnode, argnums=(0, 1))
compiled_grad = jax.grad(compiled_qnode, argnums=(0, 1))
jitted_original_grad = jax.jit(original_grad)
jitted_compiled_grad = jax.jit(compiled_grad)

We time the first execution manually, since it includes tracing and compilation and will be the slowest, then benchmark subsequent calls with `%%timeit`.

In [18]:
t0 = time.time()
jitted_original_grad(x, weights)
t1 = time.time()
print(t1 - t0)

0.1258864402770996
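Because JAX dispatches work asynchronously, a manual timing is more reliable when it waits for the result before stopping the clock. The sketch below shows this pattern on a toy function; `toy_cost` and `timed_call` are hypothetical stand-ins, not the QNodes used in this notebook.

```python
import time

import jax
import jax.numpy as jnp

def toy_cost(x, weights):
    # Hypothetical stand-in for a QNode: a smooth scalar function of both inputs
    return jnp.sum(jnp.sin(x) * jnp.cos(weights))

jitted_toy_grad = jax.jit(jax.grad(toy_cost, argnums=(0, 1)))

def timed_call(fn, *args):
    start = time.perf_counter()
    out = fn(*args)
    # Wait for asynchronous dispatch to finish before stopping the clock
    for leaf in jax.tree_util.tree_leaves(out):
        leaf.block_until_ready()
    return out, time.perf_counter() - start

x_toy = jnp.array([0.1, 0.2, 0.3])
w_toy = jnp.array([0.4, 0.5, 0.6])

# The first call includes tracing and compilation; the second reuses the compiled function
first_out, first_time = timed_call(jitted_toy_grad, x_toy, w_toy)
second_out, second_time = timed_call(jitted_toy_grad, x_toy, w_toy)
print(f"first call: {first_time:.4f} s, second call: {second_time:.6f} s")
```

The first call is typically much slower than the second, which is the same pattern seen in the manual timings and `%%timeit` results here.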


In [19]:
%%timeit
jitted_original_grad(x, weights)

13.6 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Let's now time the jitted gradient of the compiled QNode.

In [20]:
t0 = time.time()
jitted_compiled_grad(x, weights)
t1 = time.time()
print(t1 - t0)

5.505212306976318


In [21]:
%%timeit
jitted_compiled_grad(x, weights)

8.28 ms ± 95 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


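The initial run takes substantially longer (about 5.5 s, versus about 0.13 s for the original QNode), but subsequent calls drop to 8.28 ms, compared with 13.6 ms for the jitted original QNode and 601 ms for the compiled QNode without JIT.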
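Whether the longer first call pays off depends on how many gradient evaluations follow. A rough break-even estimate from the timings in the cells above can be sketched as follows. It assumes a simple model in which every call after the first costs the mean `%%timeit` value; the variable names are hypothetical.

```python
# Timings copied from the cells above, in seconds
first_call = {"original": 0.1258864402770996, "compiled": 5.505212306976318}
steady_call = {"original": 13.6e-3, "compiled": 8.28e-3}

def total_time(kind, n_calls):
    # Simple model: one slow first call, then (n_calls - 1) calls at the steady-state cost
    return first_call[kind] + (n_calls - 1) * steady_call[kind]

extra_compile_cost = first_call["compiled"] - first_call["original"]
saving_per_call = steady_call["original"] - steady_call["compiled"]

# Number of calls after which the compiled + jitted gradient becomes cheaper overall
break_even_calls = 1 + extra_compile_cost / saving_per_call
print(f"break-even after about {break_even_calls:.0f} calls")
```

With these numbers the break-even point is on the order of a thousand gradient evaluations, so the compiled and jitted version mainly benefits long optimization runs.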

## Grad of JIT

We can also reverse the order: JIT the QNode first, then differentiate the jitted function.

In [22]:
jitted_original_qnode = jax.jit(original_qnode)
grad_original_jit = jax.grad(jitted_original_qnode, argnums=(0, 1))

In [23]:
t0 = time.time()
grad_original_jit(x, weights)
t1 = time.time()
print(t1 - t0)

0.15106821060180664


In [24]:
%%timeit
grad_original_jit(x, weights)

16.3 ms ± 197 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


For the compiled QNode:

In [25]:
jitted_compiled_qnode = jax.jit(compiled_qnode)
grad_compiled_jit = jax.grad(jitted_compiled_qnode, argnums=(0, 1))

In [26]:
t0 = time.time()
grad_compiled_jit(x, weights)
t1 = time.time()
print(t1 - t0)

6.736414194107056


In [27]:
%%timeit
grad_compiled_jit(x, weights)

24.4 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In these runs, this ordering is slower than applying JIT to the gradient computation directly: 16.3 ms versus 13.6 ms for the original QNode, and 24.4 ms versus 8.28 ms for the compiled QNode.
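Both orderings compute the same gradient and differ only in what is compiled. With JIT of grad, the whole gradient function is traced and compiled as one unit. With grad of JIT, only the forward function is wrapped and the outer `jax.grad` still runs in Python on each call, which may contribute to the gap seen here. A toy sketch of the two orderings, with `toy_cost` as a hypothetical stand-in for a QNode:

```python
import jax
import jax.numpy as jnp

def toy_cost(x, weights):
    # Hypothetical stand-in for a QNode: a smooth scalar function of both inputs
    return jnp.sum(jnp.tanh(x) ** 2) + jnp.dot(x, weights)

# JIT of grad: compile the full gradient computation
jit_of_grad = jax.jit(jax.grad(toy_cost, argnums=(0, 1)))

# Grad of JIT: differentiate a function whose forward pass is jitted
grad_of_jit = jax.grad(jax.jit(toy_cost), argnums=(0, 1))

x_toy = jnp.array([0.1, 0.2, 0.3])
w_toy = jnp.array([0.4, 0.5, 0.6])

grads_a = jit_of_grad(x_toy, w_toy)
grads_b = grad_of_jit(x_toy, w_toy)

for a, b in zip(grads_a, grads_b):
    print(jnp.allclose(a, b))
```

Since the results agree, the choice between the two orderings comes down to compilation and per-call overhead.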