Skip to content

Commit

Permalink
Update paper.md
Browse files Browse the repository at this point in the history
  • Loading branch information
josh146 committed Jun 12, 2024
1 parent 2cc7214 commit f5c41a2
Showing 1 changed file with 35 additions and 20 deletions.
55 changes: 35 additions & 20 deletions paper/paper.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ quantum computer to find the ground state energy of a molecule:

```python
import pennylane as qml
from catalyst import qjit, grad, for_loop
from catalyst import grad, for_loop, qjit
import jaxopt
from jax import numpy as jnp
import optax

mol = qml.data.load("qchem", molname="H3+")[0]
n_qubits = len(mol.hamiltonian.wires)
Expand All @@ -188,28 +188,25 @@ def cost(params):
qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
return qml.expval(mol.hamiltonian)

opt = optax.adam(learning_rate=0.3)

@qml.qjit
@qjit
def optimization(params):
opt_state = opt.init(params)
loss = lambda x: (cost(x), grad(cost)(x))

def update_step(i, args):
params, opt_state = args
grads = grad(cost)(params)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return (params, opt_state)
# set up optimizer and define optimization step
opt = jaxopt.GradientDescent(loss, stepsize=0.3, value_and_grad=True)
update_step = lambda step, args: tuple(opt.update(*args))

(params, opt_state) = qml.for_loop(0, 100, 1)(update_step)((params, opt_state))
# gradient descent parameter update loop using jit-compatible for-loop
state = opt.init_state(params)
(params, _) = for_loop(0, 10, step=1)(update_step)((params, state))
return params
```

```pycon
>>> params = jnp.array([0.54, 0.3154])
>>> final_params = optimization(params)
>>> cost(final_params) # optimized energy of H3+
-1.2621747418777423
-1.2621179827928877
>>> mol.vqe_energy # expected energy of H3+
-1.2613407428534986
```
Expand All @@ -221,17 +218,35 @@ the same system, as a non-rigorous demonstration of the advantage of performing
outside of Python:

```python
>>> %timeit workflow(params)
62 ms ± 254 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit optimization(params)
599 ms ± 96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

Comparing this to a non-compiled workflow, where the `@qjit` decorator has
been removed, and the Catalyst gradient function has been replaced with
`jax.grad`:
been removed:

```python
>>> %timeit nojit_workflow(params)
440 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
@qml.qnode(dev)
def no_qjit_cost(params):
qml.BasisState(jnp.array(mol.hf_state), wires=range(n_qubits))
qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
return qml.expval(mol.hamiltonian)

def no_qjit_optimization(params):
# set up optimizer
opt = jaxopt.GradientDescent(no_qjit_cost, stepsize=0.3, jit=False)
state = opt.init_state(params)

for i in range(15):
(params, state) = opt.update(params, state)

return params
```

```pycon
>>> %timeit no_qjit_optimization(params)
3.73 s ± 522 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

For more code examples, please see the Catalyst documentation\footnote
Expand Down

0 comments on commit f5c41a2

Please sign in to comment.