In [1]:
import pennylane as qml
import pennylane.numpy as np
import jax.numpy as jnp
import jax

jax.config.update("jax_enable_x64", True)

import sys
sys.path.append('../')
from util import grad_estimate

In [2]:
# Load the HeH+ dataset entry (STO-3G basis, bond length 1.5) and take the first result.
data = qml.data.load("qchem", molname="HeH+", basis="STO-3G", bondlength=1.5)[0]
H_obj = data.tapered_hamiltonian
# Rebuild the observable with its coefficients converted to a JAX array
# (presumably so it is compatible with the JAX interface used for the QNode below).
H_obj = qml.Hamiltonian(jnp.array(H_obj.coeffs), H_obj.ops)
E_exact = data.fci_energy

# values and parametrization matching exactly https://arxiv.org/pdf/2210.15812.pdf
# all in units of 10^9
qubit_freq = jnp.pi * 2 * np.array([5.23, 5.01])  # 2*pi times the listed per-qubit values
eps = np.array([32.9, 31.5]) # units of 10^9; prefactors of the (I - Z) terms in the drift Hamiltonian
max_amp = jnp.array([0.955, 0.987]) # much larger than in ctrl-vqe paper
connections = [(0, 1)]  # NOTE(review): not referenced again in the visible cells
coupling = 0.0123  # strength of the XX + YY coupling term between qubits 0 and 1
wires = [0, 1]
n_wires = len(wires)
timespan = 100.  # total pulse duration; the time grid runs from 0 to timespan

def normalize(x):
    """Differentiable normalization to +/- 1 outputs (shifted sigmoid).

    Mathematically this is ``(1 - exp(-x)) / (1 + exp(-x))``, which is identical
    to ``tanh(x / 2)`` for real and complex ``x``. The tanh form is used because
    the explicit quotient overflows for large negative ``x`` (``exp(-x)`` becomes
    ``inf`` and the result ``inf / inf`` is ``nan``), whereas tanh saturates
    cleanly at -1 / +1.

    Args:
        x: scalar or array (real or complex) to squash.

    Returns:
        ``tanh(x / 2)`` with the same shape as ``x``.
    """
    return jnp.tanh(x / 2)

# Coefficient rows of the Legendre polynomials P_0 .. P_4, written highest power
# first (x^4 ... x^0), which is the ordering jnp.polyval expects.
legendres = jnp.array([
    [0, 0, 0, 0, 1],                            # P_0(x) = 1
    [0, 0, 0, 1, 0],                            # P_1(x) = x
    jnp.array([0, 0, 3, 0, -1]) * 0.5,          # P_2(x) = (3x^2 - 1) / 2
    jnp.array([0, 5, 0, -3, 0]) * 0.5,          # P_3(x) = (5x^3 - 3x) / 2
    jnp.array([35, 0, -30, 0, 3]) * (1 / 8),    # P_4(x) = (35x^4 - 30x^2 + 3) / 8
])

# Evaluate every polynomial (mapped over the rows) at one shared argument.
_polyval_rows = jax.vmap(jnp.polyval, in_axes=(0, None))
leg_func = jax.jit(_polyval_rows)

# Number of Legendre polynomials in the basis.
dLeg = legendres.shape[0]

def amp(timespan, omega, max_amp):
    """Build the time-dependent drive amplitude for a single qubit.

    Args:
        timespan: total pulse duration; ``t`` is rescaled via ``2*t/timespan - 1``
            so that ``t = 0`` maps to -1 and ``t = timespan`` maps to +1 before the
            Legendre polynomials are evaluated.
        omega: frequency entering the ``exp(1j*omega*t)`` factor.
        max_amp: prefactor multiplying the final real-valued amplitude.

    Returns:
        A callable ``wrapped(p, t)``. ``p`` is split at index ``dLeg``: the first
        ``dLeg`` entries are the real parts and the remaining entries the imaginary
        parts of the complex Legendre expansion coefficients.
    """
    def wrapped(p, t):
        # Complex expansion coefficients assembled from the two halves of p.
        coeffs = p[:dLeg] + 1j * p[dLeg:]
        # Legendre basis evaluated at the rescaled time.
        basis_vals = leg_func(legendres, 2*t/timespan - 1)
        z = coeffs @ basis_vals
        # NOTE(review): the envelope multiplies normalize(z) by the phase angle of z
        # (not by exp(1j*angle)); confirm this is the intended form of eq. (27).
        envelope = normalize(z) * jnp.angle(z)
        carrier = jnp.exp(1j*omega*t)
        return max_amp * jnp.real(carrier * envelope) # eq. (27)
    return wrapped

# Drift Hamiltonian: sum_i eps_i/2 * (I_i - Z_i) on every wire ...
H_D = qml.dot(0.5*eps, [qml.Identity(i) - qml.PauliZ(i) for i in wires])
# ... plus the coupling/2 * (X0 X1 + Y0 Y1) interaction between qubits 0 and 1.
H_D += coupling/2 * (qml.PauliX(0) @ qml.PauliX(1) + qml.PauliY(0) @ qml.PauliY(1)) # implicit factor 2 due to redundancy in formula

# One drive-amplitude callable per wire (own frequency and maximum amplitude),
# each paired with a PauliX acting on that wire.
fs = [amp(timespan, qubit_freq[i], max_amp[i]) for i in range(n_wires)]
ops = [qml.PauliX(i) for i in wires]

# Control Hamiltonian sum_i fs[i](p_i, t) * X_i, and the full pulse Hamiltonian.
H_C = qml.dot(fs, ops)
H = H_D + H_C

# Absolute tolerance passed to qml.evolve wherever H is evolved in this notebook.
atol=1e-16

dev = qml.device("default.qubit.jax", wires=n_wires)

def circuit(params, ts):
    """Evolve under the pulse Hamiltonian ``H`` and measure ``H_obj``.

    Args:
        params: pulse parameters forwarded to the parametrized evolution.
        ts: time argument forwarded as ``t`` to the parametrized evolution.

    Returns:
        The expectation-value measurement of ``H_obj``.
    """
    pulse_evolution = qml.evolve(H, atol=atol)
    pulse_evolution(params, t=ts)
    return qml.expval(H_obj)

def f(params, tau):
    """Evaluate every drive amplitude in ``fs`` at time ``tau``.

    Args:
        params: indexable container; ``params[i]`` is passed to ``fs[i]``.
        tau: time at which the amplitudes are evaluated.

    Returns:
        List with one amplitude value per entry of ``fs``.
    """
    return [pulse(params[idx], tau) for idx, pulse in enumerate(fs)]

# QNode running `circuit` on `dev` through the JAX interface.
cost_jax = qml.QNode(circuit, dev, interface="jax")

# Fixed-seed random starting parameters: one row per wire with 2*dLeg entries
# (the real and imaginary coefficient halves that `amp` splits at index dLeg).
key = jax.random.PRNGKey(42)
params = jax.random.normal(key, shape=(n_wires, 2*dLeg))

In [3]:
# Number of grid points including both endpoints; later cells drop the two
# endpoints and work with the remaining N_s - 2 interior times.
N_s = 10002
taus = jnp.linspace(0., timespan, N_s)

# Reference gradient: jacobian of the QNode with respect to its first argument (params).
grad_exact = jnp.array(jax.jacobian(cost_jax)(params, taus))
grad_exact

Array([[-1.52170615e+01, -2.39397289e+00,  2.36863833e+00,
        -4.37190483e-01,  4.99867229e+00, -6.80037822e+00,
         1.07393153e+01,  5.20989824e+00, -3.89863420e+00,
        -1.51870226e+00],
       [-2.22706166e+01, -3.30561140e-01,  9.51753532e+00,
         3.11007803e-01, -5.33835977e+00,  1.97866925e+01,
         1.61035905e-02, -6.84630350e+00,  1.35886302e-01,
         3.01250957e+00]], dtype=float64)

In [4]:
# compute static data once
# The full time grid is rebuilt here instead of slicing `taus` in place, so this
# cell is idempotent: re-running it no longer shrinks `taus` a second time (which
# would break the (N_s-2, ...) reshape further down).
taus_full = jnp.linspace(0., timespan, N_s)
# [U(t0, t0)] U(t0, t1) .. U(t0, tn-1) [U(t0, tn)]
U0 = qml.matrix(qml.evolve(H, atol=atol)(params, t=taus_full, return_intermediate=True))[1:-1]
# [U(t0, tn)] U(t1, tn) .. U(tn-1, tn) [U(tn, tn)]
U1 = qml.matrix(qml.evolve(H, atol=atol)(params, t=taus_full, return_intermediate=True, complementary=True))[1:-1]
# Interior times only, matching the entries kept in U0 and U1.
taus = taus_full[1:-1]

In [5]:
def _expvals_with_inserted_gate(gate_matrix):
    """Expectation values of ``H_obj`` with a fixed gate inserted at every interior time.

    For each index ``N`` along the leading axis of ``U0`` / ``U1`` the state is
    ``U1[N] @ gate_matrix @ U0[N] @ psi0``, where ``psi0`` is the first
    computational basis state.

    Args:
        gate_matrix: square matrix of dimension ``2**n_wires`` for the inserted gate.

    Returns:
        Real array with one value ``<psi_N| H_obj |psi_N>`` per entry of ``U0`` / ``U1``.
    """
    Hm = qml.matrix(H_obj, wire_order=range(n_wires))

    psi0 = jnp.eye(2**n_wires)[0]
    # Evolve from the start up to each interior time ...
    psit = jnp.einsum("Nij,j->Ni", U0, psi0)
    # ... apply the inserted gate (the same matrix for every time) ...
    psit = jnp.einsum("ij,Nj->Ni", gate_matrix, psit)
    # ... and evolve from that time to the end.
    psit = jnp.einsum("Nij,Nj->Ni", U1, psit)
    return jnp.einsum("Ni,ij,Nj->N", psit.conj(), Hm, psit).real

def p(h, sign):
    """Expectation values with the gate ``qml.evolve(h, sign*pi/4)`` inserted at each interior time.

    Args:
        h: operator generating the inserted gate.
        sign: factor (+1 or -1 in this notebook) multiplying the pi/4 coefficient.

    Returns:
        Real array with one expectation value of ``H_obj`` per interior time.
    """
    digi_gate = qml.matrix(qml.evolve(h, sign*np.pi/4), wire_order=range(n_wires))
    return _expvals_with_inserted_gate(digi_gate)

def p_drift(h, sign):
    """Like ``p``, but the inserted gate also evolves under the drift ``H_D``.

    The gate is the evolution under ``H_D + c * h`` with the constant coefficient
    ``c = sign*pi/4`` and ``t=1.``.

    Args:
        h: operator added to the drift Hamiltonian for the inserted gate.
        sign: factor (+1 or -1 in this notebook) multiplying the pi/4 coefficient.

    Returns:
        Real array with one expectation value of ``H_obj`` per interior time.
    """
    digi_gate = qml.evolve(H_D + qml.pulse.constant * h)([sign*jnp.pi/4], t=1.)
    digi_gate = qml.matrix(digi_gate, wire_order=range(n_wires))
    return _expvals_with_inserted_gate(digi_gate)

In [6]:
# classical jacobian
# Derivative of the drive amplitudes f with respect to params, evaluated at every
# interior time: params is shared (None), tau is mapped over (axis 0).
jac_fun = jax.vmap(jax.jit(jax.jacobian(f)), [None, 0])
jac = jnp.array(jac_fun(params, taus))
jac = jnp.moveaxis(jac, 1, 0) # time axis first: Ns, Nops, <shape of params>
jac = jnp.reshape(jac, (N_s-2, n_wires, -1)) # flatten parameter axes: Ns, Nops, Nparams (taus must already be the N_s-2 interior points)

# integrand results
# Difference of the expectation values for the +pi/4 and -pi/4 gate insertions,
# one row per operator in `ops`.
ps = jnp.array([p(h, 1) - p(h, -1) for h in ops])
ps = jnp.moveaxis(ps, -1, 0)[:, :, jnp.newaxis] # Ns, Nops, 1 (for broadcasting)

# Same difference, with the inserted gate also evolving under the drift H_D.
ps_drift = jnp.array([p_drift(h, 1) - p_drift(h, -1) for h in ops])
ps_drift = jnp.moveaxis(ps_drift, -1, 0)[:, :, jnp.newaxis] # Ns, Nops, 1 (for broadcasting)

# Per-time, per-operator, per-parameter integrands (jacobian times expectation-value difference).
res = jac * ps
res_drift = jac * ps_drift

In [7]:
# Settings for the gradient estimate below.
# presumably the numbers of sampled time points to try — confirm against util.grad_estimate
n_tauss = [5, 10, 20, 40, 80, 160]
# passed as `reps` to grad_estimate; presumably repetitions per setting — confirm against util.grad_estimate
reps = 100

In [8]:
# NOTE(review): `n_tauss` is defined in this notebook and stored next to `grads` in the
# saved archive, but `n_tauss=None` is passed here. Confirm that grad_estimate's None
# default corresponds to the same list; otherwise pass `n_tauss=n_tauss`.
grads = grad_estimate(res, jac, timespan, seed=92, n_tauss=None, reps=reps, importance=False)

In [9]:
# Store the integrands, time grid, estimator settings, estimated and exact
# gradients, and the pulse parameters in one .npz archive.
# The path is relative to the working directory; the `data/` directory must already exist.
name = "data/legendre_s2n"
np.savez(name,
    res=res,
    res_drift=res_drift,
    taus=taus,
    n_tauss=n_tauss,
    grads=grads, 
    grad_exact=grad_exact,
    qubit_freq=qubit_freq,
    params=params)