From f5c41a2ecc47ebac101cb5ba23390442df459a75 Mon Sep 17 00:00:00 2001
From: Josh Izaac
Date: Wed, 12 Jun 2024 14:14:38 -0400
Subject: [PATCH] Update paper.md

---
 paper/paper.md | 55 ++++++++++++++++++++++++++++++++------------------
 1 file changed, 35 insertions(+), 20 deletions(-)

diff --git a/paper/paper.md b/paper/paper.md
index bef339031f..4383bccc95 100644
--- a/paper/paper.md
+++ b/paper/paper.md
@@ -171,9 +171,9 @@ quantum computer to find the ground state energy of a molecule:
 
 ```python
 import pennylane as qml
-from catalyst import qjit, grad, for_loop
+from catalyst import grad, for_loop, qjit
+import jaxopt
 from jax import numpy as jnp
-import optax
 
 mol = qml.data.load("qchem", molname="H3+")[0]
 n_qubits = len(mol.hamiltonian.wires)
@@ -188,20 +188,17 @@ def cost(params):
     qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
     return qml.expval(mol.hamiltonian)
 
-opt = optax.adam(learning_rate=0.3)
-
-@qml.qjit
+@qjit
 def optimization(params):
-    opt_state = opt.init(params)
+    loss = lambda x: (cost(x), grad(cost)(x))
 
-    def update_step(i, args):
-        params, opt_state = args
-        grads = grad(cost)(params)
-        updates, opt_state = opt.update(grads, opt_state)
-        params = optax.apply_updates(params, updates)
-        return (params, opt_state)
+    # set up optimizer and define optimization step
+    opt = jaxopt.GradientDescent(loss, stepsize=0.3, value_and_grad=True)
+    update_step = lambda step, args: tuple(opt.update(*args))
 
-    (params, opt_state) = qml.for_loop(0, 100, 1)(update_step)((params, opt_state))
+    # gradient descent parameter update loop using jit-compatible for-loop
+    state = opt.init_state(params)
+    (params, _) = for_loop(0, 10, step=1)(update_step)((params, state))
     return params
 ```
 
@@ -209,7 +206,7 @@ def optimization(params):
 >>> params = jnp.array([0.54, 0.3154])
 >>> final_params = optimization(params)
 >>> cost(final_params) # optimized energy of H3+
--1.2621747418777423
+-1.2621179827928877
 >>> mol.vqe_energy # expected energy of H3+
 -1.2613407428534986
 ```
@@ -221,17 +218,35 @@ the same system, as a non-rigorous demonstration of the advantage of performing
 outside of Python:
 
 ```python
->>> %timeit workflow(params)
-62 ms ± 254 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
+>>> %timeit optimization(params)
+599 ms ± 96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
 ```
 
 Comparing this to a non-compiled workflow, where the `@qjit` decorator has
-been removed, and the Catalyst gradient function has been replaced with
-`jax.grad`:
+been removed:
 
 ```python
->>> %timeit nojit_workflow(params)
-440 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
+@qml.qnode(dev)
+def no_qjit_cost(params):
+    qml.BasisState(jnp.array(mol.hf_state), wires=range(n_qubits))
+    qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
+    qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
+    return qml.expval(mol.hamiltonian)
+
+def no_qjit_optimization(params):
+    # set up optimizer
+    opt = jaxopt.GradientDescent(no_qjit_cost, stepsize=0.3, jit=False)
+    state = opt.init_state(params)
+
+    for i in range(15):
+        (params, state) = opt.update(params, state)
+
+    return params
+```
+
+```pycon
+>>> %timeit no_qjit_optimization(params)
+3.73 s ± 522 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
 ```
 
 For more code examples, please see the Catalyst documentation\footnote
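
Note: the hunks above rely on a device and a `cost` QNode that are defined earlier in `paper.md`, outside the patch context. Below is a minimal sketch of that assumed setup; the `lightning.qubit` device name is an assumption for illustration and is not taken from this patch.

```python
# Minimal sketch of the setup assumed (but not shown) by the patch above.
# The device choice ("lightning.qubit") is an assumption for illustration.
import pennylane as qml
from jax import numpy as jnp

# load the H3+ dataset and build a device with one wire per Hamiltonian qubit
mol = qml.data.load("qchem", molname="H3+")[0]
n_qubits = len(mol.hamiltonian.wires)
dev = qml.device("lightning.qubit", wires=n_qubits)

@qml.qnode(dev)
def cost(params):
    # Hartree-Fock reference state followed by two double excitations
    qml.BasisState(jnp.array(mol.hf_state), wires=range(n_qubits))
    qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
    return qml.expval(mol.hamiltonian)
```

With a setup along these lines in place, the `optimization` and `no_qjit_optimization` functions from the diff can be benchmarked as in the `%timeit` blocks above.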