In [1]:
import catalyst
from catalyst import qjit

import pennylane as qml
from pennylane import numpy as np

import jax.numpy as jnp

import functools
import time

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Trihydrogen cation (H3+): three hydrogen atoms with a net charge of +1.
symbols = ["H"] * 3
coordinates = np.array(
    [
        0.028, 0.054, 0.0,  # H atom 1 (x, y, z)
        0.986, 1.610, 0.0,  # H atom 2
        1.855, 0.002, 0.0,  # H atom 3
    ]
)

# Build the molecular Hamiltonian; the second return value is the qubit count it acts on.
hamiltonian, qubits = qml.qchem.molecular_hamiltonian(symbols, coordinates, charge=1)

print("qubits: {}".format(qubits))

qubits: 6


In [3]:
# The Hartree-Fock State: 2 electrons (H3+ = 3 protons, charge +1).
# The orbital count is taken from ``qubits`` (returned with the Hamiltonian) instead of
# being hard-coded, so the reference state always matches the Hamiltonian's register size.
hf = qml.qchem.hf_state(electrons=2, orbitals=qubits)

# Lightning simulator with one wire per Hamiltonian qubit.
dev = qml.device("lightning.qubit", wires=qubits)


@qml.qnode(dev, diff_method="adjoint")
def cost_func(params):
    """Energy expectation of the two-parameter double-excitation ansatz.

    Prepares the Hartree-Fock state ``hf``, applies two DoubleExcitation gates
    with angles ``params[0]`` and ``params[1]``, and measures ``hamiltonian``.

    Args:
        params: sequence of two rotation angles.

    Returns:
        The expectation value of ``hamiltonian`` in the prepared state.
    """
    qml.BasisState(hf, wires=range(qubits))
    qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
    return qml.expval(hamiltonian)

In [4]:
def workflow(params, ntrials, stepsize=0.4):
    """Run ``ntrials`` steps of PennyLane gradient descent on ``cost_func``.

    This is the uncompiled baseline. Note that ``workflow`` is redefined later in
    this notebook as a zero-argument ``@qjit`` Optax loop, which shadows this one.

    Args:
        params: initial angles (a trainable ``pennylane.numpy`` array).
        ntrials: number of optimization steps.
        stepsize: optimizer learning rate; defaults to 0.4, the value used
            throughout this notebook.

    Returns:
        The parameters after ``ntrials`` steps.
    """
    opt = qml.GradientDescentOptimizer(stepsize=stepsize)

    for _ in range(ntrials):
        # step_and_cost also returns the energy *before* this step; it is unused here.
        params, _energy = opt.step_and_cost(cost_func, params)

    return params

# perf_counter is monotonic and high-resolution, so it is the right clock for
# measuring elapsed time; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
theta = workflow(np.array([0.0, 0.0]), 20)
end = time.perf_counter()

print("Final time:", end - start, "\n\n")
print(f"Final angle parameters: {theta}")

Final time: 0.1291189193725586 


Final angle parameters: [0.18814175 0.18891605]


In [5]:
# Hartree-Fock reference state read by catalyst_cost_func below. The orbital count
# comes from ``qubits`` so it stays in sync with the Hamiltonian's register size.
hf = qml.qchem.hf_state(electrons=2, orbitals=qubits)
print(f"The Hartree-Fock State: {hf}")

@qml.qnode(qml.device("lightning.qubit", wires=qubits))
def catalyst_cost_func(params):
    """Energy of the double-excitation ansatz, used by the Catalyst workflows.

    Same gate sequence as ``cost_func``. The Hamiltonian is rebuilt from its
    coefficients and operators instead of being passed through directly
    (presumably so Catalyst can trace it; NOTE(review): confirm this is still
    required with the installed PennyLane/Catalyst versions).

    Args:
        params: sequence of two rotation angles.

    Returns:
        The expectation value of the rebuilt Hamiltonian.
    """
    qml.BasisState(hf, wires=range(qubits))
    excitations = ((params[0], [0, 1, 2, 3]), (params[1], [0, 1, 4, 5]))
    for angle, excitation_wires in excitations:
        qml.DoubleExcitation(angle, wires=excitation_wires)
    coeffs, ops = hamiltonian.terms()
    rebuilt = qml.Hamiltonian(np.array(coeffs), ops)
    return qml.expval(rebuilt)

The Hartree-Fock State: [1 1 0 0 0 0]


In [6]:
@qjit
def grad_descent(params, ntrials: int, stepsize: float):
    """Compiled gradient descent: ``ntrials`` updates of ``theta - grad * stepsize``.

    Args:
        params: initial angles.
        ntrials: number of update steps.
        stepsize: learning rate.

    Returns:
        The angles after the final step.
    """
    grad_fn = catalyst.grad(catalyst_cost_func, argnums=0)

    # catalyst.for_loop can only be used in JIT mode. The loop-carried value is the
    # current angle vector, seeded with ``params``.
    @catalyst.for_loop(0, ntrials, 1)
    def update(step, current):
        gradient = grad_fn(current)
        return current - gradient * stepsize

    return update(params)

# NOTE: grad_descent has no type hint on ``params``, so Catalyst compiles it lazily on
# this first call. The figure below therefore includes compilation time, not only
# execution. perf_counter is used because it is monotonic and high-resolution.
startCatalyst = time.perf_counter()
theta = grad_descent(jnp.array([0.0, 0.0]), 20, 0.4)
endCatalyst = time.perf_counter()

print("Final time:", endCatalyst - startCatalyst, "\n\n")
print(f"Final angle parameters: {theta}")

Final time: 1.203507900238037 


Final angle parameters: [0.18814175 0.18891605]


In [7]:
import optax
from jax.lax import fori_loop

In [8]:
@qjit
def workflow():
    """Compiled gradient descent on ``catalyst_cost_func`` driven by Optax SGD.

    Runs 20 SGD steps (learning rate 0.4) from zero initial angles inside a
    ``jax.lax.fori_loop`` and returns the final angles.
    """
    grad_fn = catalyst.grad(catalyst_cost_func, argnums=0)
    opt = optax.sgd(learning_rate=0.4)

    def gd_update(i, carry):
        # Only the gradient drives the update. Evaluating the cost here as well would
        # add one circuit execution per step whose result is thrown away.
        param, state = carry
        gradient = grad_fn(param)
        updates, state = opt.update(gradient, state)
        return optax.apply_updates(param, updates), state

    params = jnp.array([0.0, 0.0])
    state = opt.init(params)
    num_steps = 20
    params, _ = fori_loop(0, num_steps, gd_update, (params, state))
    return params

# workflow() takes no arguments, so Catalyst likely compiled it ahead of time when the
# @qjit decorator ran; if so, this timing measures execution only and is not directly
# comparable to timings that include a first-call compile.
# NOTE(review): confirm against the Catalyst qjit AOT-compilation docs.
startJAXOpt = time.perf_counter()
theta = workflow()
endJAXOpt = time.perf_counter()

print("Final time:", endJAXOpt - startJAXOpt, "\n\n")

print(f"Final angle parameters: {theta}")

Final time: 0.11701726913452148 


Final angle parameters: [0.18814175 0.18891605]


In [9]:
@qjit
def grad_descent_step(params, stepsize: float):
    """Compiled single gradient-descent update: ``params - grad * stepsize``.

    Args:
        params: current angles.
        stepsize: learning rate.

    Returns:
        The updated angles.
    """
    gradient = catalyst.grad(catalyst_cost_func, argnums=0)(params)
    return params - gradient * stepsize

# Drive the compiled single step from a Python loop. grad_descent_step has no type hint
# on ``params``, so its first call compiles it and that cost is included in the timing.
# (Avoid wrapping a function in qjit(...) inside this loop for debugging: that builds a
# new qjit object, and therefore recompiles, on every iteration.)
theta = jnp.array([0.0, 0.0])

startJIT = time.perf_counter()

for _ in range(20):
    theta = grad_descent_step(theta, 0.4)

endJIT = time.perf_counter()

print("Final time:", endJIT - startJIT, "\n\n")

print(f"Final angle parameters: {theta}")

Final time: 0.9155933856964111 


Final angle parameters: [0.18814175 0.18891605]
