In [1]:
%env JAX_ENABLE_X64=True
# The %env line must stay before the jax import so 64-bit mode is picked up.
import jax
import jax.numpy as jnp

import numpy as np
import scipy
import scipy.optimize
import scipy.sparse.linalg

env: JAX_ENABLE_X64=True


In [2]:
from openfermion.hamiltonians import MolecularData
from openfermion.transforms import get_fermion_operator, jordan_wigner

# HDF5 file (OpenFermion MolecularData format) to load. The commented-out lines
# are alternative molecules from the same dataset; uncomment one to switch.
filename = "./Quantaggle_dataset-master/datasets/Small_Molecules_1/H2_line_sto-3g/H2_line_sto-3g_singlet_0.50.hdf5"
# filename = "./Quantaggle_dataset-master/datasets/Small_Molecules_1/H4_line_sto-3g/H4_line_sto-3g_singlet_0.50.hdf5"
# filename = "./Quantaggle_dataset-master/datasets/Small_Molecules_1/H2O_sto-3g/H2O_sto-3g_singlet_0.50_104.5deg_0.50.hdf5"

molecular_data = MolecularData(filename=filename)

# Second-quantized (fermionic) form of the molecular Hamiltonian.
molecular_hamiltonian = get_fermion_operator(
    molecular_data.get_molecular_hamiltonian()
)

# Jordan-Wigner mapping: the same Hamiltonian as a sum of Pauli strings.
jw_hamiltonian = jordan_wigner(molecular_hamiltonian)
# print(jw_hamiltonian)

In [3]:
# Number of qubits under the Jordan-Wigner mapping = number of fermionic modes,
# i.e. the largest mode index appearing in any term, plus one.
# Each term key is a tuple of (mode_index, second_element) pairs; for a fermionic
# operator the second element is presumably the creation/annihilation flag, not a
# Pauli name, so it is ignored here. A constant-only Hamiltonian has no modes,
# hence default=-1 (-> 0 qubits) instead of np.max failing on an empty array.
NQBITS = 1 + max(
    (mode for term in molecular_hamiltonian.terms for mode, _flag in term),
    default=-1,
)
print("Number of qbits", NQBITS)

Number of qbits 4


In [4]:
def U3Gate(t,p,l):
    """
    Universal single qbit gate U3(t, p, l).

    Matrix (angles in radians):
        [[cos(t/2),             -exp(i*l) * sin(t/2)],
         [exp(i*p) * sin(t/2),   exp(i*(p+l)) * cos(t/2)]]
    """
    half = t / 2
    cos_half, sin_half = jnp.cos(half), jnp.sin(half)
    top_row = [cos_half, -jnp.exp(1j * l) * sin_half]
    bottom_row = [jnp.exp(1j * p) * sin_half, jnp.exp(1j * (p + l)) * cos_half]
    return jnp.array([top_row, bottom_row])

def XGate():
    """Pauli-X (bit flip) gate as a 2x2 matrix."""
    return jnp.array([[0.0, 1.0],
                      [1.0, 0.0]])
def YGate():
    """
    Pauli-Y gate, [[0, -i], [i, 0]] (so Y|0> = i|1> and Y = i*X*Z).

    The signs matter: [[0, i], [-i, 0]] is -Y, which only gives the right result
    for Pauli strings that contain an even number of Y factors.
    """
    return jnp.array([[0.0, -1.0j],
                      [1.0j, 0.0]])
def ZGate():
    """Pauli-Z (phase flip) gate as a 2x2 matrix."""
    return jnp.diag(jnp.array([1.0, -1.0]))

# Lookup table from a Pauli label, as found in the Hamiltonian's Pauli strings,
# to its 2x2 matrix.
GATES = dict(X=XGate(), Y=YGate(), Z=ZGate())

def apply_gate(gate, qbit, phi):
    """
    Apply a single-qbit gate to one axis of a statevector.
    Argument:
        gate: 2x2 matrix of the gate.
        qbit: index of the qbit (axis of phi) the gate acts on.
        phi: statevector in the form of 2x2x2x...x2 array.
    Returns:
        Array of the same shape as phi with gate contracted into axis `qbit`:
        out[..., a, ...] = sum_b gate[a, b] * phi[..., b, ...].
    """
    nq = phi.ndim
    assert qbit < nq
    in_labels = list(range(nq))
    # The output keeps the axis order of phi; only the acted-on axis gets the
    # fresh label `nq`, which is the gate's output index.
    out_labels = [nq if axis == qbit else axis for axis in in_labels]
    return jnp.einsum(gate, [nq, qbit], phi, in_labels, out_labels)

def apply_controlled_gate(control, gate, qbit, phi):
    """
    Apply a controlled single-qbit gate to a statevector.
    Argument:
        control: control bit (axis of phi).
        gate: 2x2 matrix of the gate.
        qbit: index of qbit affected by the gate.
        phi: statevector in the form of 2x2x2x...x2 array.
    Returns:
        Array of the same shape as phi. The slice with control == 0 is unchanged;
        the gate is applied on axis `qbit` of the slice with control == 1.
    """
    nq = phi.ndim
    assert control<nq and qbit<nq and control!=qbit
    lead = (slice(None),) * control
    # Length-1 slices keep the control axis, so both halves still have nq axes
    # and can be glued back together along it.
    off_half = phi[lead + (slice(0, 1),)]
    on_half = phi[lead + (slice(1, 2),)]
    in_labels = list(range(nq))
    out_labels = list(in_labels)
    out_labels[qbit] = nq
    on_half = jnp.einsum(gate, [nq, qbit], on_half, in_labels, out_labels)
    return jnp.concatenate((off_half, on_half), axis=control)

# Small test: the "state" has entries 0..3, so the output shows directly how
# each gate permutes the amplitudes.
phi = jnp.arange(4).reshape((2,2))
x0_phi = apply_gate(XGate(), 0, phi)
x1_phi = apply_gate(XGate(), 1, phi)
c0x1_phi = apply_controlled_gate(0, XGate(), 1, phi)
print("Initial state phi", phi.tolist())
print("X0 phi", x0_phi.tolist())
print("X1 phi", x1_phi.tolist())
print("C0X1 phi", c0x1_phi.tolist())

# Check the results instead of relying on reading the printout.
np.testing.assert_allclose(x0_phi, [[2, 3], [0, 1]])    # rows swapped (flip qbit 0)
np.testing.assert_allclose(x1_phi, [[1, 0], [3, 2]])    # columns swapped (flip qbit 1)
np.testing.assert_allclose(c0x1_phi, [[0, 1], [3, 2]])  # only the qbit0 == 1 row flipped



Initial state phi [[0, 1], [2, 3]]
X0 phi [[2.0, 3.0], [0.0, 1.0]]
X1 phi [[1.0, 0.0], [3.0, 2.0]]
C0X1 phi [[0.0, 1.0], [3.0, 2.0]]


In [5]:
def apply_entangle(phi):
    """
    Entangle layer in HEA (hardware-efficient ansatz).

    Applies a ring of CNOTs in order q = 0, ..., n-1: qbit q controls an X on
    qbit (q+1) mod n.
    Argument:
        phi: statevector in the form of 2x2x2x...x2 array (one axis per qbit).
    Returns:
        The entangled statevector, same shape as phi.

    With fewer than two qbits there is no pair to entangle, so phi is returned
    unchanged. Without this guard the 1-qbit ring would ask for a CNOT from qbit 0
    onto itself, which apply_controlled_gate rejects.
    """
    nqbits = phi.ndim
    if nqbits < 2:
        return phi
    x_gate = XGate()  # loop-invariant: same target gate for every CNOT
    for q in range(nqbits):
        phi = apply_controlled_gate(q, x_gate, (q+1)%nqbits, phi)
    return phi

def ansatz(theta):
    """
    Compute ansatz for VQE.

    Starts from |0...0> and applies `depth` layers. Each layer is a U3 gate on
    every qbit. Every layer after the first is preceded by the CNOT-ring
    entangler (apply_entangle).
    Input:
        theta: array of parameters 3 x nqbits x depth; theta[:, q, d] are the
            (t, p, l) angles of the U3 gate on qbit q in layer d.
    Output:
        statevector: complex array of shape 2x2x...x2 (nqbits axes, 2^nqbits amplitudes).
    """
    _, nqbits, depth = theta.shape
    phi = jnp.zeros((2,)*nqbits, dtype=np.complex128)
    # Functional update via `.at[...]`; jax.ops.index_update was deprecated and
    # later removed from JAX. This form also works under jit/grad.
    phi = phi.at[(0,)*nqbits].set(1.)

    for d in range(depth):
        if d > 0:
            phi = apply_entangle(phi)
        for q in range(nqbits):
            phi = apply_gate(U3Gate(*theta[:, q, d]), q, phi)
    return phi

# A small test: with all angles zero every U3 gate is the identity (and a single
# layer has no entangler), so the result should be the basis state |0000>.
theta_zero = np.zeros((3, 4, 1))
ansatz(theta_zero).tolist()

[[[[(1+0j), 0j], [0j, 0j]], [[0j, 0j], [0j, 0j]]],
 [[[0j, 0j], [0j, 0j]], [[0j, 0j], [0j, 0j]]]]

In [6]:
def apply_hamiltonian_term(terms, phi):
    """
    Apply one Pauli string to a statevector.
    Argument:
        terms: sequence of (qbit, name) pairs; name is a key of GATES ('X', 'Y', 'Z').
        phi: statevector in the form of 2x2x2x...x2 array.
    Returns:
        The statevector after applying each listed single-qbit gate in turn.
    """
    # GATES is only read, so no `global` declaration is needed.
    psi = phi
    for qbit, pauli in terms:
        psi = apply_gate(GATES[pauli], qbit, psi)
    return psi

def apply_hamiltonian(hamiltonian, phi):
    """
    Apply a Pauli-sum Hamiltonian to a statevector: H|phi> = sum_k c_k P_k |phi>.
    Argument:
        hamiltonian: iterable whose elements expose a `.terms` dict mapping a Pauli
            string (tuple of (qbit, name) pairs) to its coefficient.
        phi: statevector in the form of 2x2x2x...x2 array.
    Returns:
        Array with the same shape as phi. An empty Hamiltonian gives an all-zero
        array instead of the Python scalar 0., so callers such as energy_flat
        can still call .flatten() on the result.
    """
    result = jnp.zeros_like(phi)
    for op in hamiltonian:
        for terms, const in op.terms.items():
            result = result + const*apply_hamiltonian_term(terms, phi)
    return result

def mean_energy(hamiltonian, phi):
    """
    Expectation value <phi|H|phi> of the Hamiltonian in state phi.

    phi is not renormalized, so it is assumed to have unit norm. Only the real
    part is returned; for a Hermitian H the imaginary part should be ~0
    (check np.abs(np.imag(...)) < 1e-6 when debugging).
    """
    h_phi = apply_hamiltonian(hamiltonian, phi)
    bra = jnp.conj(phi)
    return jnp.real(jnp.sum(bra * h_phi))
    
def experiment(theta, hamiltonian):
    """
    One VQE evaluation: prepare the ansatz state for parameters theta and
    return its mean energy with respect to hamiltonian.
    """
    return mean_energy(hamiltonian, ansatz(theta))
    
# A small test: energy for random single-layer parameters. The generator is
# seeded so the printed value is reproducible under Restart & Run All.
theta0 = np.random.default_rng(0).standard_normal((3, NQBITS, 1))
experiment(theta0, jw_hamiltonian)

DeviceArray(0.49861167, dtype=float64)

In [7]:
nqbits = NQBITS
@jax.jit
def energy_flat(x):
    return apply_hamiltonian(jw_hamiltonian, x.reshape((2,)*nqbits)).flatten()
H = scipy.sparse.linalg.LinearOperator((2**nqbits, 2**nqbits), matvec=energy_flat)
%time e0, w0 = scipy.sparse.linalg.eigsh(H, k=1, which='SA', tol=1e-6) # Lowest eigenvalue
print("Ground state energy:", e0)

CPU times: user 263 ms, sys: 4.47 ms, total: 267 ms
Wall time: 269 ms
Ground state energy: [-1.05515979]


In [8]:
# Reminder, how to extract data from OpenFermion.
# op = next(jw_hamiltonian.get_operators())
# ?op.terms

# for op in jw_hamiltonian:
#     print(f"{op=}")
#     for key, const in op.terms.items():
#         for nq, name_op in key:
#             print(nq, name_op)
#         print("Const", const)

In [9]:
nqbits = NQBITS
depth = 2
theta_initial = np.random.randn(3, nqbits, depth) # Initial values of parameters

@jax.jit
def loss(theta):
    theta = theta.reshape(theta_initial.shape)
    return experiment(theta, jw_hamiltonian)

%time res = scipy.optimize.minimize(loss, theta_initial.flatten())
theta = res.x
min_energy = res.fun
print(res.message, f"#parameters {theta.size}. Cost: {res.nfev}. Min. energy: {min_energy:.6}")
# print(res)

CPU times: user 810 ms, sys: 21.3 ms, total: 831 ms
Wall time: 813 ms
Optimization terminated successfully. #parameters 24. Cost: 675. Min. energy: -1.05516


In [10]:
# Parameter shift rule.
def d_loss_psr(theta):
    """
    Gradient of `loss` by the parameter shift rule.

    For every parameter x_k:
        dE/dx_k = (E(x + pi/2 * e_k) - E(x - pi/2 * e_k)) / 2
    This is exact, not a finite-difference approximation, when each parameter
    enters the circuit as a rotation angle. It costs 2 * theta.size calls of
    `loss` per gradient.
    Input:
        theta: flat parameter vector (NumPy or JAX array).
    Output:
        JAX array with one entry per parameter.
    """
    # jax.ops.index_add has been removed from JAX; `.at[...].add` replaces it.
    # jnp.asarray is required because SciPy passes plain NumPy arrays, which
    # have no `.at`. Not jax.jit-compiled itself; each `loss` call is jitted.
    theta = jnp.asarray(theta)
    shift = jnp.pi/2
    grad = []
    for idx in range(theta.shape[0]):
        gp = loss(theta.at[idx].add(shift))
        gm = loss(theta.at[idx].add(-shift))
        grad.append(0.5*(gp-gm))
    return jnp.asarray(grad)

# Much slower than JIT. 
%time res = scipy.optimize.minimize(loss, theta_initial.flatten(), jac=d_loss_psr)
theta = res.x
min_energy = res.fun
print(res.message, f"#parameters {theta.size}. Cost: {res.nfev}. Min. energy: {min_energy:.6}")

# Full cost here is much higher, since energy should be computed 2*#parameters+1 times per iteration.
print(f"Full cost: {res.nfev*(2*theta.size+1)}")

CPU times: user 1.85 s, sys: 242 ms, total: 2.09 s
Wall time: 1.68 s
Optimization terminated successfully. #parameters 24. Cost: 27. Min. energy: -1.05516
Full cost: 1323


In [11]:
# If JAX is enabled, the gradient can be computed automatically.
# Where jaxlib is not available (at the time of writing this was Windows), try
# autograd (https://github.com/HIPS/autograd) instead.

d_loss_ag = jax.jit(jax.grad(loss))

dedtheta_ag = d_loss_ag(theta_initial.flatten()) # Function d_loss_ag will be compiled here, can take minutes.
# Again slow. Do not run for larger problems.
dedtheta_psr = d_loss_psr(theta_initial.flatten())
print("Error:", np.linalg.norm(dedtheta_ag-dedtheta_psr))
# The parameter shift rule is exact, so the two gradients must agree up to rounding.
np.testing.assert_allclose(dedtheta_ag, dedtheta_psr, atol=1e-8)

Error: 3.5930558143486307e-16


In [12]:
%time res = scipy.optimize.minimize(loss, theta_initial.flatten(), jac=d_loss_ag)
theta = res.x
min_energy = res.fun
print(res.message, f"#parameters {theta.size}. Cost: {res.nfev}. Min. energy: {min_energy:.6}")


CPU times: user 11.6 ms, sys: 260 µs, total: 11.9 ms
Wall time: 9.53 ms
Optimization terminated successfully. #parameters 24. Cost: 27. Min. energy: -1.05516
