
Review suggestions for the Catalyst paper (#807)
josh146 committed Jun 20, 2024
1 parent 36ed03b commit 77c2ee7
Showing 1 changed file with 63 additions and 31 deletions: paper/paper.md. The updated section reads as follows.

```python
import pennylane as qml
from catalyst import qjit, measure
from jax import numpy as jnp

dev = qml.device("lightning.qubit", wires=2)
@qjit(autograph=True)
def hybrid_function(x):

    @qml.qnode(dev, diff_method="parameter-shift")
    def circuit(x):
        qml.RX(x, wires=0)
        qml.RY(x ** 2, wires=1)
        qml.CNOT(wires=[0, 1])

        for i in range(0, 10):
            # mid-circuit measurement of the first qubit
            m = measure(wires=0)

            if m == 1:
                qml.CRX(x * jnp.exp(-x ** 2), wires=[0, 1])

            x = x * 0.2

        return qml.expval(qml.PauliZ(0))

    return jnp.sin(circuit(x)) ** 2
```

```pycon
>>> hybrid_function(0.543)
array(0.70807342)
```
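For readers curious about what AutoGraph is doing above, the sketch below (our illustration, not taken from the paper) shows roughly how the same circuit can be written with Catalyst's explicit control-flow primitives, `for_loop` and `cond`, which is approximately the form AutoGraph lowers the Python `for` and `if` statements into. It reuses `dev`, `qml`, and `jnp` from the snippet above; exact API details may vary between Catalyst versions.

```python
from catalyst import cond, for_loop, measure, qjit

@qjit
def hybrid_function_explicit(x):

    @qml.qnode(dev, diff_method="parameter-shift")
    def circuit(x):
        qml.RX(x, wires=0)
        qml.RY(x ** 2, wires=1)
        qml.CNOT(wires=[0, 1])

        def loop_body(i, x):
            m = measure(wires=0)

            # apply the controlled rotation only if the measurement returned 1
            @cond(m == 1)
            def apply_crx():
                qml.CRX(x * jnp.exp(-x ** 2), wires=[0, 1])
            apply_crx()

            # x is loop-carried state: return its value for the next iteration
            return x * 0.2

        for_loop(0, 10, step=1)(loop_body)(x)
        return qml.expval(qml.PauliZ(0))

    return jnp.sin(circuit(x)) ** 2
```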

We can also consider an example that includes a classical optimization loop, such as optimizing a
variational quantum circuit to find the ground-state energy of a molecule:

```python
import pennylane as qml
import jaxopt
from catalyst import qjit, grad, for_loop
from jax import numpy as jnp

mol = qml.data.load("qchem", molname="H3+")[0]
n_qubits = len(mol.hamiltonian.wires)

dev = qml.device("lightning.qubit", wires=n_qubits)

@qjit
@qml.qnode(dev)
def cost(params):
    # prepare the Hartree-Fock reference state
    qml.BasisState(jnp.array(mol.hf_state), wires=range(n_qubits))
    qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
    return qml.expval(mol.hamiltonian)

@qjit
def optimization(params):
    loss = lambda x: (cost(x), grad(cost)(x))

    # set up optimizer and define optimization step
    opt = jaxopt.GradientDescent(loss, stepsize=0.3, value_and_grad=True)
    update_step = lambda step, args: tuple(opt.update(*args))

    # gradient descent parameter update loop using jit-compatible for-loop
    state = opt.init_state(params)
    (params, _) = for_loop(0, 10, step=1)(update_step)((params, state))
    return params
```

```pycon
>>> params = jnp.array([0.54, 0.3154])
>>> final_params = optimization(params)
>>> cost(final_params) # optimized energy of H3+
-1.2621179827928877
>>> mol.vqe_energy # expected energy of H3+
-1.2613407428534986
```
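As a quick sanity check (our addition, not from the paper), the same Catalyst `grad` function used inside `optimization` can also be compiled on its own, to verify that the gradient is close to zero at the optimized parameters:

```python
@qjit
def cost_grad(params):
    # gradient of the qjit-compiled cost function; this should be
    # close to zero at final_params if gradient descent has converged
    return grad(cost)(params)
```

Evaluating `cost_grad(final_params)` should then return values near zero.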

Here, we are using the JAXopt gradient optimization library [@blondel2022efficient] alongside the
Catalyst `grad` and `for_loop` functions, so that the entire optimization loop is qjit-compiled. We
can time the execution of the optimization on the same system, as a non-rigorous demonstration of
the advantage of performing the full workflow outside of Python:

```pycon
>>> %timeit optimization(params)
599 ms ± 96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

Comparing this to a non-compiled workflow, where the `@qjit` decorator has
been removed:

```python
@qml.qnode(dev)
def no_qjit_cost(params):
    qml.BasisState(jnp.array(mol.hf_state), wires=range(n_qubits))
    qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
    return qml.expval(mol.hamiltonian)

def no_qjit_optimization(params):
    # set up optimizer
    opt = jaxopt.GradientDescent(no_qjit_cost, stepsize=0.3, jit=False)
    state = opt.init_state(params)

    for i in range(15):
        (params, state) = opt.update(params, state)

    return params
```

```pycon
>>> %timeit no_qjit_optimization(params)
3.73 s ± 522 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
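One caveat worth keeping in mind (our addition, not from the paper): the first call to a
`@qjit`-compiled function triggers compilation, which `%timeit` amortizes over its many
repetitions. A plain-`timeit` sketch that makes the warm-up explicit:

```python
import timeit

# warm-up call: triggers JIT compilation of the workflow, so the
# measurement below reflects execution time only
optimization(params)

avg = timeit.timeit(lambda: optimization(params), number=10) / 10
print(f"average time per optimization run: {avg:.3f} s")
```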

For more code examples, please see the Catalyst documentation\footnote