In [108]:
import catalyst
from catalyst import qjit

import pennylane as qml
from pennylane import numpy as np

import jax.numpy as jnp

import functools
import time
import numpy as np
import warnings
warnings.filterwarnings('ignore')

In [109]:
symbols = ["H", "H", "H"]
coordinates = np.array([0.028, 0.054, 0.0, 0.986, 1.610, 0.0, 1.855, 0.002, 0.0])

# Building the molecular hamiltonian for the trihydrogen cation
hamiltonian, qubits = qml.qchem.molecular_hamiltonian(symbols, coordinates, charge=1)

print(f"qubits: {qubits}")

qubits: 6


In [110]:
# The Hartree-Fock State
hf = qml.qchem.hf_state(electrons=2, orbitals=6)

# Define the device, using lightning.qubit device
dev = qml.device("lightning.qubit", wires=qubits)

@qml.qnode(dev, diff_method="adjoint")
def cost_func(params):
    qml.BasisState(hf, wires=range(qubits))
    qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
    return qml.expval(hamiltonian)

In [111]:
def workflow(params, ntrials):
    opt = qml.GradientDescentOptimizer(stepsize=0.4)

    for n in range(ntrials):
        params, prev_energy = opt.step_and_cost(cost_func, params)
        # print(f"--- Step: {n}, Energy: {cost_func(params):.8f}")

    return params

start = time.time()
theta = workflow(np.array([0.0, 0.0]), 2000)
end = time.time()

print("Final time:",end-start,"\n\n")
print(f"Final angle parameters: {theta}")

Final time: 14.823657035827637 


Final angle parameters: [0. 0.]


In [90]:
hf = qml.qchem.hf_state(electrons=2, orbitals=6)
print(f"The Hartree-Fock State: {hf}")

@qml.qnode(qml.device("lightning.qubit", wires=qubits))
def catalyst_cost_func(params):
    qml.BasisState(hf, wires=range(qubits))
    qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
    coeffs, ops = hamiltonian.terms()
    return qml.expval(qml.Hamiltonian(np.array(coeffs), ops))

The Hartree-Fock State: [1 1 0 0 0 0]


In [113]:
@qjit
def grad_descent(params, ntrials: int, stepsize: float):
    diff = catalyst.grad(catalyst_cost_func, argnums=0)
    theta = params

     # for_loop can only be used in JIT mode
    @catalyst.for_loop(0, ntrials, 1)
    def single_step(i, theta):
        h = diff(theta)
        return theta - h * stepsize

    return single_step(theta)

startCatalyst = time.time()
theta = grad_descent(np.array([0.0, 0.0]), 2000, 0.4)
endCatalyst = time.time()

print("Final time:",endCatalyst-startCatalyst,"\n\n")
print(f"Final angle parameters: {theta}")

Final time: 12.458559036254883 


Final angle parameters: [0.18825659 0.18904372]


In [92]:
import optax
from jax.lax import fori_loop

In [115]:
@qjit
def workflow():
    def gd_fun(params):
        diff = catalyst.grad(catalyst_cost_func, argnums=0)
        return catalyst_cost_func(params), diff(params)

    opt = optax.sgd(learning_rate=0.4)

    def gd_update(i, args):
        param, state = args
        _, gradient = gd_fun(param)
        (updates, state) = opt.update(gradient, state)
        param = optax.apply_updates(param, updates)
        return (param, state)

    params = np.array([0.0, 0.0])
    state = opt.init(params)
    upper = 20
    (params, _) = fori_loop(0, upper, gd_update, (params, state))
    return params

startJAXOpt = time.time()
theta = workflow()
endJAXOpt = time.time()

print("Final time:",endJAXOpt-startJAXOpt,"\n\n")

print(f"Final angle parameters: {theta}")
print(workflow.mlir)

Final time: 0.13218283653259277 


Final angle parameters: [0.18814175 0.18891605]
module @workflow {
  func.func public @jit_workflow() -> tensor<2xf64> attributes {llvm.emit_c_interface} {
    %cst = stablehlo.constant dense<-4.000000e-01> : tensor<f64>
    %c = stablehlo.constant dense<1> : tensor<i64>
    %c_0 = stablehlo.constant dense<20> : tensor<i64>
    %c_1 = stablehlo.constant dense<[1, 1, 0, 0, 0, 0]> : tensor<6xi64>
    %cst_2 = stablehlo.constant dense<"0xCB38BDC4AC8DD1BFE032BC5C78F4CB3F00682186E020CFBE00682186E020CFBE0018B5295C0DA13E0018B5295C0DA13E3FAB7FDF0E7FBEBF7C38265830A1BB3F70931002CD86BEBF5DF2666EDAA1BB3FE432BC5C78F4CB3F857EF49F82F4C23F003EDDCBB0F9E23E003EDDCBB0F9E23E807A6DFAE8C4B4BE807A6DFAE8C4B4BE00782186E020CFBE003EDDCBB0F9E23E00782186E020CFBE003EDDCBB0F9E23E572314F1CE6BA23F572314F1CE6BA2BF572314F1CE6BA2BF572314F1CE6BA23F0014B5295C0DA13E807A6DFAE8C4B4BE0014B5295C0DA13E807A6DFAE8C4B4BE1D3CBF7CD86BA23F1D3CBF7CD86BA2BF1D3CBF7CD86BA2BF1D3CBF7CD86BA23F00FCC6E28567B2

In [94]:
@qjit
def grad_descent_step(params, stepsize: float):
    diff = catalyst.grad(catalyst_cost_func, argnums=0)
    return params - diff(params) * stepsize

theta = jnp.array([0.0, 0.0])

startJIT = time.time()

for i in range(20):
    theta = grad_descent_step(theta, 0.4)
    # print(f"--- Step: {i}, Energy: {qjit(catalyst_cost_func)(theta):.8f}")

endJIT = time.time()

print("Final time:",endJIT-startJIT,"\n\n")

print(f"Final angle parameters: {theta}")

Final time: 0.9321701526641846 


Final angle parameters: [0.18814175 0.18891605]


In [101]:
def grad_descent_step(params, stepsize: float):
    diff = catalyst.grad(catalyst_cost_func, argnums=0)
    return params - diff(params) * stepsize

@qjit
def workflow():
    theta = jnp.array([0.0, 0.0])
    for i in range(20):
        theta = grad_descent_step(theta, 0.4)
        # print(f"--- Step: {i}, Energy: {qjit(catalyst_cost_func)(theta):.8f}")

startJITWork = time.time()
workflow()
endJITWork = time.time()

print("Final time:",endJITWork-startJITWork,"\n\n")

print(f"Final angle parameters: {theta}")

Final time: 0.18010687828063965 


Final angle parameters: [0. 0.]


In [97]:
from jax.core import ShapedArray

@qjit
def grad_descent_step_aot(params: ShapedArray([2], float), stepsize: float):
    diff = catalyst.grad(catalyst_cost_func, argnums=0)
    return params - diff(params) * stepsize

theta = jnp.array([0.0, 0.0])

startAOT= time.time()
for i in range(20):
    theta = grad_descent_step_aot(theta, 0.4)
    # print(f"--- Step: {i}, Energy: {qjit(catalyst_cost_func)(theta):.8f}")
endAOT = time.time()
print("Final Time: ",endAOT-startAOT,"\n\n")

print(f"Final angle parameters: {theta}")

Final Time:  0.24439597129821777 


Final angle parameters: [0.18814175 0.18891605]


In [102]:
from jax.core import ShapedArray

def grad_descent_step_aot(params: ShapedArray([2], float), stepsize: float):
    diff = catalyst.grad(catalyst_cost_func, argnums=0)
    return params - diff(params) * stepsize

@qjit
def workflow():
    theta = jnp.array([0.0, 0.0])
    for i in range(20):
        theta = grad_descent_step_aot(theta, 0.4)
        # print(f"--- Step: {i}, Energy: {qjit(catalyst_cost_func)(theta):.8f}")

startAOT= time.time()
workflow()
endAOT = time.time()
print("Final Time: ",endAOT-startAOT,"\n\n")

print(f"Final angle parameters: {theta}")

Final Time:  0.13121700286865234 


Final angle parameters: [0. 0.]
