# Incorporating Custom Input States and Arbitrary Operator Evolutions on Statevectors

Achieving ~8 ms per simulation (batched with `jit` + `vmap`) for a Hamiltonian of two coupled qutrits, with arbitrary input-state preparation and a user-chosen output operator!

In [1]:
# All Imports

import numpy as np
import matplotlib.pyplot as plt
import sympy as sym

import qiskit
from qiskit import pulse

from qiskit_dynamics import Solver, DynamicsBackend
from qiskit_dynamics.pulse import InstructionToSignals
from qiskit_dynamics.array import Array

from qiskit.quantum_info import Statevector, DensityMatrix, Operator
from qiskit.circuit.parameter import Parameter

import jax
import jax.numpy as jnp
from jax import jit, vmap, block_until_ready, config

import chex

from typing import Optional, Union

# Make qiskit_dynamics' Array wrapper use JAX as its numerical backend, so the
# solver calls in later cells can be traced by jax.jit / jax.vmap.
Array.set_default_backend('jax')
# Enable 64-bit (double-precision) floats and complex numbers in JAX.
config.update('jax_enable_x64', True)
# Run all JAX computations on the CPU.
config.update('jax_platform_name', 'cpu')

In [2]:
# Constructing a Two Qutrit Hamiltonian
#
# Two anharmonic oscillators truncated to `dim` levels each, coupled by an
# exchange term of strength J. Tensor ordering follows Qiskit's little-endian
# convention: qutrit 0 is the rightmost Kronecker factor, so a basis index
# decomposes as i = dim * n1 + n0. Frequencies are in Hz (times are in
# seconds elsewhere in the notebook); every term carries the 2*pi factor.

dim = 3

# Qutrit 0: transition frequency, anharmonicity, drive strength (Hz)
v0 = 4.86e9
anharm0 = -0.32e9
r0 = 0.22e9

# Qutrit 1: transition frequency, anharmonicity, drive strength (Hz)
v1 = 4.97e9
anharm1 = -0.32e9
r1 = 0.26e9

# Exchange coupling strength (Hz)
J = 0.002e9

# Single-qutrit lowering, raising and number operators
a = np.diag(np.sqrt(np.arange(1, dim)), 1)
adag = np.diag(np.sqrt(np.arange(1, dim)), -1)
N = np.diag(np.arange(dim))

ident = np.eye(dim, dtype=complex)
full_ident = np.eye(dim**2, dtype=complex)

# Embed the single-qutrit operators in the two-qutrit space
N0 = np.kron(ident, N)
N1 = np.kron(N, ident)

a0 = np.kron(ident, a)
a1 = np.kron(a, ident)

a0dag = np.kron(ident, adag)
a1dag = np.kron(adag, ident)

# Anharmonic-oscillator terms: 2*pi*v*N + pi*anharm*N(N - 1).
# N(N - 1) is an operator product, so it uses `@`; an elementwise `*` would
# only coincide with it because N0/N1 are diagonal.
static_ham0 = 2 * np.pi * v0 * N0 + np.pi * anharm0 * (N0 @ (N0 - full_ident))
static_ham1 = 2 * np.pi * v1 * N1 + np.pi * anharm1 * (N1 @ (N1 - full_ident))

# Exchange coupling: 2*pi*J (a0 + a0^dag)(a1 + a1^dag)
static_ham_full = static_ham0 + static_ham1 + 2 * np.pi * J * ((a0 + a0dag) @ (a1 + a1dag))

# Sanity check: the static Hamiltonian must be Hermitian.
assert np.allclose(static_ham_full, static_ham_full.conj().T)

# Drive operators 2*pi*r*(a + a^dag), one per qutrit
drive_op0 = 2 * np.pi * r0 * (a0 + a0dag)
drive_op1 = 2 * np.pi * r1 * (a1 + a1dag)

In [12]:
# Default Solver Options

# Uniform (unnormalised) superposition over all dim**2 basis states
y0 = Array(Statevector(np.ones(dim**2)))

# Evaluation grid in seconds: 11 points spanning 0 to 400 ns
t_linspace = np.linspace(0.0, 400e-9, 11)
t_span = np.array([t_linspace[0], t_linspace[-1]])

dt = 1/4.5e9  # pulse sample period (s)
# NOTE(review): atol=1e-2 is loose compared with statevector amplitudes of
# order 1/sqrt(dim**2); tighten it if accuracy matters more than speed.
atol = 1e-2
rtol = 1e-4

# Operators and channels are paired by index. d0/d1 drive each qutrit at its
# own frequency; u0/u1 drive a qutrit at the *other* qutrit's frequency
# (presumably cross-resonance-style control).
ham_ops = [drive_op0, drive_op1, drive_op0, drive_op1]
ham_chans = ["d0", "d1", "u0", "u1"]
chan_freqs = {"d0": v0, "d1": v1, "u0": v1, "u1": v0}

# The rotating frame is the full static Hamiltonian, so the solver integrates
# only the drive terms in that frame.
solver = Solver(
    static_hamiltonian=static_ham_full,
    hamiltonian_operators=ham_ops,
    rotating_frame=static_ham_full,
    hamiltonian_channels=ham_chans,
    channel_carrier_freqs=chan_freqs,
    dt=dt,
)

In [17]:
# Custom functions: `standard_func` builds a pulse schedule from a parameter vector, and `evolve_func` simulates it and returns the squared amplitudes of `obs @ final_state`

def standard_func(params):
    """Build the pulse schedule for a single parameter vector.

    `params` is unpacked as (amp, sigma, freq). One DRAG pulse (320 samples,
    beta=0.1, angle=0.1, amplitude limit disabled) is played on channels
    d0, d1, u0 and u1, in that order. Each play is preceded by a frequency
    shift of `freq` on the same channel. The schedule is built with
    `default_alignment='sequential'`.

    Returns:
        The schedule produced by `pulse.build`.
    """
    amp, sigma, freq = params

    # Drag is used because qiskit.pulse already defines it as a
    # ScalableSymbolicPulse.
    drag = pulse.Drag(
        duration=320,
        amp=amp,
        sigma=sigma,
        beta=0.1,
        angle=0.1,
        limit_amplitude=False,
    )

    with pulse.build(default_alignment='sequential') as schedule:
        targets = (
            pulse.DriveChannel(0),
            pulse.DriveChannel(1),
            pulse.ControlChannel(0),
            pulse.ControlChannel(1),
        )
        for channel in targets:
            pulse.shift_frequency(freq, channel)
            pulse.play(drag, channel)

    return schedule

def evolve_func(inp_y0, params, obs):
    """Simulate one pulse schedule; return |obs @ psi_final|**2 elementwise.

    Args:
        inp_y0: initial statevector for the solver. It is normalised here,
            so any non-zero vector is accepted.
        params: (amp, sigma, freq), forwarded to `standard_func`.
        obs: square operator applied to the final state. If it is a unitary
            basis change, the result is the vector of measurement
            probabilities in that basis. For a non-unitary operator (e.g. a
            number operator), the entries are |<i|obs|psi>|**2. These are
            not probabilities: they may exceed 1 and need not sum to 1.

    Returns:
        Real JAX array with one entry per basis state.
    """
    sched = standard_func(params)

    converter = InstructionToSignals(dt, carriers=chan_freqs, channels=ham_chans)
    signals = converter.get_signals(sched)

    results = solver.solve(
        t_span=t_span,
        y0=inp_y0 / jnp.linalg.norm(inp_y0),
        t_eval=t_linspace,
        signals=signals,
        rtol=rtol,
        atol=atol,
        convert_results=False,
        method='jax_odeint'
    )

    # Only the state at the last evaluation time is used; renormalise it
    # before applying the operator.
    state_vec = results.y.data[-1]
    evolved_vec = jnp.dot(obs, state_vec) / jnp.linalg.norm(state_vec)

    # No clipping: for a unitary `obs`, these values already lie in [0, 1]
    # up to rounding. For any other operator, clipping to [0, 1] would
    # silently corrupt the result.
    return jnp.abs(evolved_vec) ** 2

fast_evolve_func = jit(vmap(evolve_func))

batchsize = 400

amp_vals = jnp.linspace(0.5, 0.99, batchsize, dtype=jnp.float64).reshape(-1, 1)
sigma_vals = jnp.linspace(20, 80, batchsize, dtype=jnp.int8).reshape(-1, 1)
freq_vals = jnp.linspace(-0.5, 0.5, batchsize, dtype=jnp.float64).reshape(-1, 1) * 1e6
batch_params = jnp.concatenate((amp_vals, sigma_vals, freq_vals), axis=-1)

batch_y0 = jnp.tile(np.ones(9), (batchsize, 1))
batch_obs = jnp.tile(N0, (batchsize, 1, 1))

print(f"Batched Params Shape: {batch_params.shape}")

res = fast_evolve_func(batch_y0, batch_params, batch_obs)

print(res)
print(res.shape)

# Timing the fast jit + vmap batched simulation
%timeit fast_evolve_func(batch_y0, batch_params, batch_obs).block_until_ready()

Batched Params Shape: (400, 3)
[[0.         0.11111111 0.44444444 ... 0.         0.11111111 0.44444444]
 [0.         0.11111111 0.44444444 ... 0.         0.11111111 0.44444444]
 [0.         0.11111111 0.44444444 ... 0.         0.11111111 0.44444444]
 ...
 [0.         0.01295422 0.35747593 ... 0.         0.01299298 0.07212505]
 [0.         0.36094326 0.06762104 ... 0.         0.14119342 0.02276128]
 [0.         0.02367838 0.01447223 ... 0.         0.2119689  0.12432448]]
(400, 9)
7.37 s ± 593 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
# Timing a standard simulation without jitting or vmapping
%timeit evolve_func(batch_y0[200], batch_params[200], batch_obs[200])

1.71 s ± 28.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Done!

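A single un-jitted simulation takes 1.71 s. With `jit` + `vmap` batched evaluation, a batch of 400 different simulations took 3.26 s in a separate run. That is about 8 ms per simulation, roughly a 210× speed-up. The timing recorded in the output above is slower: 7.37 s per batch, which is ≈18 ms per simulation and a ≈93× speed-up. Timings vary between runs and machines.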

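Additionally, this notebook supports custom state preparation (via the input `y0`) and measurement in a chosen basis. The chosen operator is applied to the final statevector, and the squared magnitudes of the result are returned. These values are measurement probabilities only when the operator is a unitary basis change. With a non-unitary operator, such as the number operator `N0` used above, they are not probabilities and need not sum to 1.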