# Perturbative solver demo

This demo walks through the construction and usage of `PerturbativeSolver` objects for simulating a two-transmon gate, comparing them to traditional ODE solvers using both dense and sparse arrays.

In [None]:
from functools import partial
from time import time

import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import expm as jexpm
from jax.scipy.special import erf
from jax import jit, value_and_grad

from qiskit.quantum_info import Operator

from qiskit_dynamics import Solver, Signal
from qiskit_dynamics.perturbation import PerturbativeSolver

Configure JAX: enable 64-bit precision, run on the CPU, and make JAX the default backend for `Array`.

In [None]:
from qiskit_dynamics.array import Array

# configure jax to use 64 bit mode (presumably required for the very tight
# 1e-13 / 1e-14 solver tolerances used later in the notebook)
import jax
jax.config.update("jax_enable_x64", True)

# tell JAX we are using CPU
jax.config.update('jax_platform_name', 'cpu')

# set the default backend of qiskit_dynamics `Array` objects to JAX
Array.set_default_backend('jax')

# 1. Define envelope functions

We define Gaussian square and bipolar Gaussian square pulse shapes.

In [None]:
def gaussian_square(t, amp, sigma, risefall, T):
    """Gaussian square pulse envelope on the window [0, T].

    The envelope rises over the first ``risefall * sigma`` along a lifted
    Gaussian (shifted down by ``lift`` so it starts at zero), stays flat, and
    falls symmetrically over the last ``risefall * sigma``. Every segment is
    divided by ``norm``, the area under the combined rise/fall lifted
    Gaussian, and the result is scaled by ``amp``.
    """
    times = Array(t).data
    lift = jnp.exp(-(2*risefall*sigma)**2/(8*sigma**2))
    norm = (jnp.sqrt(jnp.pi*2*sigma**2)*erf(2*risefall*sigma/(jnp.sqrt(8)*sigma))-2*risefall*sigma*lift)

    def rising_edge(s):
        return (jnp.exp(-(s-sigma*risefall)**2/(2*sigma**2))-lift)/norm

    def falling_edge(s):
        return (jnp.exp(-(T-s-sigma*risefall)**2/(2*sigma**2))-lift)/norm

    def flat_top(s):
        return (1-lift)/norm

    in_rise = times < (risefall * sigma)
    in_fall = (T - times) < (risefall * sigma)
    return amp * jnp.piecewise(times,
                               condlist=[in_rise, in_fall],
                               funclist=[rising_edge, falling_edge, flat_top])

def bipolar_gaussian_square(t, amp, sigma, risefall, T):
    """Bipolar Gaussian square envelope on [0, T].

    The first half of the window is a Gaussian square pulse of duration
    ``T/2``; the second half is the same pulse shifted by ``T/2`` with its
    sign flipped.
    """
    half_T = T / 2

    def first_half(s):
        return gaussian_square(s, amp, sigma, risefall, half_T)

    def second_half(s):
        return -first_half(s - half_T)

    times = Array(t).data
    return jnp.piecewise(times,
                         condlist=[times < half_T],
                         funclist=[first_half, second_half])

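Fix the pulse parameters used throughout the notebook (`T`, `risefall`, `sigma`) and plot an example bipolar Gaussian square envelope.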

In [None]:
# Pulse parameters shared by the rest of the notebook: `T`, `risefall` and
# `sigma` are read as globals by the simulation functions defined later.
T = 200.
risefall = 2.
sigma = 7.
# amplitude used only for this example plot
amp = 4.

# Vectorize the bipolar envelope so it can be sampled on arrays of times.
# Named `example_envelope` rather than `test` so it is not confused with the
# benchmark results stored in `test` further down.
example_envelope = jnp.vectorize(lambda t: bipolar_gaussian_square(t, amp, sigma, risefall, T))

sig = Signal(example_envelope)

In [None]:
# draw the envelope of `sig` over [0, T] (1000 presumably being the number of samples)
sig.draw(0, T, 1000, function='envelope')

# 2. Construct model operators

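We construct a two-transmon model (transmon 0 is the control, transmon 1 the target), with $N_i = a_i^\dagger a_i$:

$$H(t) = \omega_c N_0 + \frac{\alpha_c}{2} N_0 (N_0 - 1) + \omega_t N_1 + \frac{\alpha_t}{2} N_1 (N_1 - 1) \\
   + J (a_0 a_1^\dagger + a_0^\dagger a_1) \\
   + 2 \pi d_c(t) (a_0 + a_0^\dagger) + 2 \pi d_t(t) (a_1 + a_1^\dagger)$$

where $\omega_{c,t}$, $\alpha_{c,t}$ and $J$ already include the factor $2\pi$ (see `w_c`, `alpha_c`, `J` below), and $d_c(t)$, $d_t(t)$ are the control and target drive signals.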

In [None]:
# Device parameters in angular-frequency units (2*pi * value); presumably GHz
# with times in ns -- TODO confirm units.
w_c = 2 * np.pi * 5.105           # control transmon frequency
w_t = 2 * np.pi * 5.033           # target transmon frequency
alpha_c = 2 * np.pi * (-0.33516)  # control transmon anharmonicity
alpha_t = 2 * np.pi * (-0.33721)  # target transmon anharmonicity
J = 2 * np.pi * 0.002             # exchange coupling strength

# number of levels kept per transmon
dim = 5

# single-transmon lowering, raising and number operators, plus identities
a = np.diag(np.sqrt(np.arange(1, dim)), 1)
adag = a.transpose()
N = np.diag(np.arange(dim))
ident = np.eye(dim)
ident2 = np.eye(dim**2)

# operators on the control qubit (first tensor factor)
a0 = np.kron(a, ident)
adag0 = np.kron(adag, ident)
N0 = np.kron(N, ident)

# operators on the target qubit (second tensor factor)
a1 = np.kron(ident, a)
adag1 = np.kron(ident, adag)
N1 = np.kron(ident, N)

In [None]:
# Static (drift) Hamiltonian: two anharmonic transmons with exchange coupling.
control_bare = w_c * N0
control_anharm = 0.5 * alpha_c * N0 @ (N0 - ident2)
target_bare = w_t * N1
target_anharm = 0.5 * alpha_t * N1 @ (N1 - ident2)
exchange = J * (a0 @ adag1 + adag0 @ a1)
H0 = control_bare + control_anharm + target_bare + target_anharm + exchange

# Drive operators for the control and target transmons.
Hdc = 2 * np.pi * (a0 + adag0)
Hdt = 2 * np.pi * (a1 + adag1)

## 2.1 Get the dressed computational states and the target qubit frequency

In [None]:
def basis_vec(ind, dimension):
    """Return the complex standard basis vector with a 1 at ``ind`` and length ``dimension``."""
    vec = np.zeros(dimension, dtype=complex)
    vec[ind] = 1.
    return vec

def two_q_basis_vec(inda, indb, dimension):
    """Return the bare two-transmon basis vector |inda, indb>.

    ``inda`` labels the first tensor factor and ``indb`` the second; each
    factor has ``dimension`` levels.
    """
    vec_a = basis_vec(inda, dimension)
    vec_b = basis_vec(indb, dimension)
    return np.kron(vec_a, vec_b)

def get_dressed_state_index(inda, indb, dimension, evectors):
    """Index of the eigenvector with the largest overlap with |inda, indb>.

    Args:
        evectors: 2D array whose ROWS are eigenvectors (this notebook passes
            ``B.transpose()``, with ``B`` the column-eigenvector matrix).
    """
    b_vec = two_q_basis_vec(inda, indb, dimension)
    # b_vec is a real standard basis vector, so |evectors @ b_vec| gives the
    # overlap magnitudes without needing an explicit conjugate
    overlaps = np.abs(evectors @ b_vec)
    return overlaps.argmax()

def get_dressed_state_and_energy(inda, indb, dimension, evecs, energies=None):
    """Return ``(energy, eigenvector)`` of the dressed state closest to |inda, indb>.

    Args:
        inda, indb: bare levels of the first and second transmon.
        dimension: number of levels per transmon.
        evecs: 2D array whose rows are eigenvectors.
        energies: eigenvalues aligned with the rows of ``evecs``. If omitted,
            the notebook-level ``evals`` from the diagonalization cell is used;
            pass them explicitly to avoid relying on that global.
    """
    if energies is None:
        energies = evals
    ind = get_dressed_state_index(inda, indb, dimension, evecs)
    return energies[ind], evecs[ind]

Diagonalize $H_0$ and get the dressed energies and states for the computational states.

In [None]:
# Diagonalize the static Hamiltonian; the eigenvectors are the columns of B.
evals, B = jnp.linalg.eigh(H0)
Badj = B.conj().transpose()

# the helper functions expect eigenvectors as ROWS
B_rows = B.transpose()
E00, dressed00 = get_dressed_state_and_energy(0, 0, dim, B_rows)
E01, dressed01 = get_dressed_state_and_energy(0, 1, dim, B_rows)
E10, dressed10 = get_dressed_state_and_energy(1, 0, dim, B_rows)
E11, dressed11 = get_dressed_state_and_energy(1, 1, dim, B_rows)

# "target dressed frequency": the |00> -> |01> transition frequency. E00 is
# subtracted explicitly instead of relying on the ground energy being zero.
v_t = (E01 - E00) / (2 * np.pi)

In [None]:
# Express the drift and drive operators in the eigenbasis of H0 (the dressed
# basis); H0_B is therefore diagonal up to numerical error.
H0_B, Hdc_B, Hdt_B = (Badj @ op @ B for op in (H0, Hdc, Hdt))

Define the gate fidelity with respect to the target gate $\exp(-i \frac{\pi}{4} Z \otimes X)$, embedded on the dressed computational states.

In [None]:
# Indices of the dressed computational states (rows of B.transpose() are the
# eigenvectors). |00> is looked up like the others instead of assumed to be 0.
idx00 = get_dressed_state_index(0, 0, dim, B.transpose())
idx01 = get_dressed_state_index(0, 1, dim, B.transpose())
idx10 = get_dressed_state_index(1, 0, dim, B.transpose())
idx11 = get_dressed_state_index(1, 1, dim, B.transpose())

# dressed-basis standard vectors for the computational states
e00 = basis_vec(idx00, dim**2)
e10 = basis_vec(idx10, dim**2)
e01 = basis_vec(idx01, dim**2)
e11 = basis_vec(idx11, dim**2)

# set up observables: S maps the 4-dim computational space into the full space
S = np.array([e00, e01, e10, e11]).transpose()
Sdag = S.conj().transpose()

zx_4 = np.array(Operator.from_label('ZX'))
ZX = S @ zx_4 @ Sdag

# target gate exp(-i * pi/4 * ZX), embedded into the full space
target = S @ jexpm(-1j * zx_4 * jnp.pi / 4) @ Sdag
target_conj = target.conj()

def fidelity(U):
    """Gate fidelity |Tr(target^dagger U)|^2 / 4^2 of ``U`` with the target gate.

    ``target`` vanishes outside the computational subspace, so only that block
    of ``U`` contributes; 4 is the computational-subspace dimension.
    """
    return jnp.abs(jnp.sum(target_conj * U))**2 / (4**2)

# 3. Construct dense version of simulation

Here we construct a function for simulating the system in the dressed basis and in the rotating frame of the drift Hamiltonian, using a standard ODE solver and dense arrays.

In [None]:
dense_solver = Solver(
    rotating_frame=np.diag(H0_B),
    hamiltonian_operators=[Hdc_B, Hdt_B],
    static_hamiltonian=H0_B,
)

# identity initial state, so the final state is the full propagator
y0 = np.eye(dim**2, dtype=complex)


def ode_sim(params, tol):
    """Simulate the gate with the dense solver and return the final unitary.

    Args:
        params: ``[cr_amp, rotary_amp, bipolar_amp, cr_phase, rotary_phase, bipolar_phase]``.
        tol: value used for both ``atol`` and ``rtol`` of ``jax_odeint``.

    Returns:
        The final state ``results.y[-1]`` of the dense solver (dressed basis,
        dense solver's rotating frame).
    """
    cr_amp, rotary_amp, bipolar_amp, cr_phase, rotary_phase, bipolar_phase = (
        params[k] for k in range(6)
    )

    def cr_envelope(t):
        return gaussian_square(t, cr_amp, sigma, risefall, T)

    def rotary_envelope(t):
        return gaussian_square(t, rotary_amp, sigma, risefall, T)

    def bipolar_envelope(t):
        return bipolar_gaussian_square(t, bipolar_amp, sigma, risefall, T)

    # control drive: CR pulse; target drive: rotary and bipolar pulses summed
    cr_signal = Signal(cr_envelope, carrier_freq=v_t, phase=cr_phase)
    target_signal = (
        Signal(rotary_envelope, carrier_freq=v_t, phase=rotary_phase)
        + Signal(bipolar_envelope, carrier_freq=v_t, phase=bipolar_phase)
    ).flatten()

    # set signals on a copy so the shared `dense_solver` is left untouched
    solver = dense_solver.copy()
    solver.signals = [cr_signal, target_signal]
    results = solver.solve(
        t_span=[0, T], y0=y0, method='jax_odeint', atol=tol, rtol=tol
    )
    return results.y[-1]


def ode_obj(params, tol):
    """Gate fidelity of the dense-simulation unitary for ``params``."""
    return fidelity(ode_sim(params, tol))


## Set up a collection of input values and create benchmark final unitaries

In [None]:
# Fixed seed so the benchmark parameters, and therefore every metric below,
# are reproducible under Restart & Run All.
rng = np.random.default_rng(seed=12345)
input_params = jnp.array(rng.uniform(low=-2, high=2, size=(10, 6)))

# a hand-picked parameter set, kept for reference:
# orig = jnp.array([1.4, 1., 0.3, 0., 0., 0.])

In [None]:
# Reference solutions at a very tight tolerance (1e-14), used to judge accuracy.
benchmark_sim = jit(lambda x: ode_sim(x, 1e-14))

benchmark_yfs = list(map(benchmark_sim, input_params))

## Create error metrics and a function for benchmarking simulations

In [15]:
def distance(U, V):
    """Normalized Frobenius distance ||U - V||_F / sqrt(n) between n x n matrices.

    For the (dim**2 x dim**2) unitaries simulated here sqrt(n) == dim. For
    unitary U and V the result lies in [0, 2].
    """
    return jnp.linalg.norm(U - V) / U.shape[-1] ** 0.5


def gate_fidelity(U):
    """Gate fidelity of ``U`` with the exp(-i*pi/4*ZX) target.

    Delegates to ``fidelity``, which is defined together with ``target`` in
    the observables cell, instead of rebuilding an identical target here.
    """
    return fidelity(U)

In [16]:
from time import time

def compute_solver_metrics(sim_func):
    """Time and validate a simulation function against the benchmark unitaries.

    Args:
        sim_func: function mapping one row of ``input_params`` to a final
            unitary. It is wrapped in ``jit`` here.

    Returns:
        dict with
            'jit_time': seconds for the first jitted call (compile + one run),
            'ave_run_time': mean seconds per call over ``input_params``,
            'ave_distance': mean ``distance`` to the matching ``benchmark_yfs``,
            'jit_grad_time': seconds for the first call of the jitted
                value-and-gradient of ``gate_fidelity(sim_func(x))``,
            'ave_grad_run_time': mean seconds per value-and-gradient call.

    Reads the notebook globals ``input_params``, ``benchmark_yfs``,
    ``distance`` and ``gate_fidelity``.
    """
    sim_func = jit(sim_func)
    n_inputs = len(input_params)

    def _run_and_wait(func, x):
        """Call ``func(x)`` and block until every output array is computed."""
        out = func(x)
        for leaf in jax.tree_util.tree_leaves(out):
            leaf.block_until_ready()
        return out

    # time to jit (compilation plus the first evaluation)
    start = time()
    _run_and_wait(sim_func, input_params[0])
    jit_time = time() - start

    # loop over and run simulations. JAX dispatches asynchronously, so each
    # result is waited on; otherwise only the dispatch would be timed.
    start = time()
    yfs = [_run_and_wait(sim_func, x) for x in input_params]
    ave_run_time = (time() - start) / n_inputs

    distances = [distance(yf, benchmark_yf) for yf, benchmark_yf in zip(yfs, benchmark_yfs)]
    ave_distance = np.sum(distances).real / n_inputs

    def fid_func(x):
        return gate_fidelity(sim_func(x))

    jit_grad_fid_func = jit(value_and_grad(fid_func))

    # time to jit the value-and-gradient function
    start = time()
    _run_and_wait(jit_grad_fid_func, input_params[0])
    jit_grad_time = time() - start

    # time to compute values and gradients (both outputs are waited on)
    start = time()
    for x in input_params:
        _run_and_wait(jit_grad_fid_func, x)
    ave_grad_run_time = (time() - start) / n_inputs

    return {
        'jit_time': jit_time,
        'ave_run_time': ave_run_time,
        'ave_distance': ave_distance,
        'jit_grad_time': jit_grad_time,
        'ave_grad_run_time': ave_grad_run_time
    }

# Dense simulation

Run the sims for dense simulation at various tolerances.

**TODO:** extend this to tolerances down to `1e-13` (i.e. `k` up to 13), and possibly intermediate values, to fill out the curve.

In [57]:
tols = [10**-k for k in range(6, 11)]

# `partial` binds each tolerance when the function is created, rather than
# closing over the loop variable (which would be late-bound).
dense_results = [
    compute_solver_metrics(partial(ode_sim, tol=tol)) for tol in tols
]

# Sparse version of simulation

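For the sparse simulation we need to work in a basis in which the operators are actually sparse, so we use the bare (undressed) basis and transform the final unitary back into the dressed basis and the dense solver's frame for comparison.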

In [None]:
# The sparse solver works in the bare (undressed) basis, where the drift and
# drive operators are sparse; its rotating frame is the diagonal of H0.
sparse_solver = Solver(
    evaluation_mode='sparse',
    rotating_frame=np.diag(H0),
    hamiltonian_operators=[Hdc, Hdt],
    static_hamiltonian=H0,
)

# y0 is the identity, so this is B itself: the dressed eigenvectors (columns)
# expressed in the bare basis
y0_sparse = B @ y0


def ode_sparse_sim(params, tol):
    """Simulate the gate with the sparse solver and return the final unitary.

    Args:
        params: ``[cr_amp, rotary_amp, bipolar_amp, cr_phase, rotary_phase, bipolar_phase]``.
        tol: value used for both ``atol`` and ``rtol`` of ``jax_odeint``.

    Returns:
        The final unitary converted to the dressed basis and the dense
        solver's rotating frame, so it is comparable with ``ode_sim``.
    """
    cr_amp, rotary_amp, bipolar_amp, cr_phase, rotary_phase, bipolar_phase = (
        params[k] for k in range(6)
    )

    def cr_envelope(t):
        return gaussian_square(t, cr_amp, sigma, risefall, T)

    def rotary_envelope(t):
        return gaussian_square(t, rotary_amp, sigma, risefall, T)

    def bipolar_envelope(t):
        return bipolar_gaussian_square(t, bipolar_amp, sigma, risefall, T)

    cr_signal = Signal(cr_envelope, carrier_freq=v_t, phase=cr_phase)
    target_signal = (
        Signal(rotary_envelope, carrier_freq=v_t, phase=rotary_phase)
        + Signal(bipolar_envelope, carrier_freq=v_t, phase=bipolar_phase)
    ).flatten()

    # set signals on a copy so the shared `sparse_solver` is left untouched
    solver = sparse_solver.copy()
    solver.signals = [cr_signal, target_signal]
    results = solver.solve(
        t_span=[0, T], y0=y0_sparse, method='jax_odeint', atol=tol, rtol=tol
    )

    # leave the sparse frame, rotate into the dressed basis, then enter the
    # dense solver's frame
    U_lab_bare = solver.model.rotating_frame.state_out_of_frame(T, results.y[-1])
    U_dressed = Array(Badj) @ U_lab_bare
    return dense_solver.model.rotating_frame.state_into_frame(T, U_dressed).data

In [None]:
# `partial` binds each tolerance when the function is created, rather than
# closing over the loop variable (which would be late-bound).
sparse_results = [
    compute_solver_metrics(partial(ode_sparse_sim, tol=tol)) for tol in tols
]

# Perturbative solver (Dyson expansion by default)

In [26]:
# system information
operators = [-1j  * Hdc_B, -1j * Hdt_B]
carrier_freqs = [v_t, v_t]
frame_operator = -1j * np.diag(H0_B)

def perturbative_solver_metrics(
    n_steps, 
    expansion_order, 
    chebyshev_order, 
    expansion_method='dyson',
    zero_carriers=False
):
    dt = T / n_steps
    
    reference_freqs = carrier_freqs
    include_imag = [True, True]
    if zero_carriers:
        reference_freqs = np.array([0., 0.])
        include_imag = [False, False]
    
    # construct solver
    start = time()
    perturb_solver = PerturbativeSolver(
        operators=operators,
        rotating_frame=frame_operator,
        dt=dt,
        carrier_freqs=reference_freqs,
        chebyshev_orders=[chebyshev_order] * 2,
        expansion_method=expansion_method,
        expansion_order=expansion_order,
        integration_method='jax_odeint',
        include_imag=include_imag,
        atol=1e-13,
        rtol=1e-13
    )
    construction_time = time() - start
    def perturb_sim(params):
        cr_amp = params[0]
        rotary_amp = params[1]
        bipolar_amp = params[2] 

        cr_phase = params[3]
        rotary_phase = params[4]
        bipolar_phase = params[5]

        cr_signal = Signal(lambda t: gaussian_square(t, cr_amp, sigma, risefall, T), 
                           carrier_freq=v_t, 
                           phase=cr_phase)
        rotary_signal = Signal(lambda t: gaussian_square(t, rotary_amp, sigma, risefall, T), 
                               carrier_freq=v_t,
                               phase=rotary_phase)
        bipolar_signal = Signal(lambda t: bipolar_gaussian_square(t, bipolar_amp, sigma, risefall, T), 
                                carrier_freq=v_t,
                                phase=bipolar_phase)

        target_signal = (rotary_signal + bipolar_signal).flatten()

        return perturb_solver.solve([cr_signal, target_signal], y0, 0., n_steps)

    results = compute_solver_metrics(perturb_sim)
    results['construction_time'] = construction_time
    results['num_terms'] = len(perturb_solver.perturbation_polynomial)
    return results


For reference, these settings give a very high-quality approximation.

In [59]:
# reference frequencies are the signal carriers (zero_carriers defaults to False)
test = perturbative_solver_metrics(
    expansion_order=4,
    chebyshev_order=2,
    n_steps=50_000,
)

In [60]:
# metrics dict returned by `perturbative_solver_metrics` in the previous cell
test

{'jit_time': 14.078038215637207,
 'ave_run_time': 7.516575574874878,
 'ave_distance': 5.677371668412095e-09,
 'jit_grad_time': 35.84115195274353,
 'ave_grad_run_time': 18.43170909881592,
 'construction_time': 27.006247758865356,
 'num_terms': 1820}

In [48]:
# zero reference carrier frequencies, with a higher Chebyshev order
test2 = perturbative_solver_metrics(
    zero_carriers=True,
    expansion_order=4,
    chebyshev_order=4,
    n_steps=50_000,
)

In [49]:
# metrics dict for the zero-carrier, expansion-order-4 configuration
test2

{'jit_time': 10.976101160049438,
 'ave_run_time': 4.769232988357544,
 'ave_distance': 5.638225479315886e-09,
 'jit_grad_time': 22.496140956878662,
 'ave_grad_run_time': 10.987490820884705,
 'construction_time': 26.107959032058716,
 'num_terms': 1001}

In [55]:
# zero reference carrier frequencies, with expansion order raised to 5
test3 = perturbative_solver_metrics(
    zero_carriers=True,
    expansion_order=5,
    chebyshev_order=4,
    n_steps=50_000,
)

In [56]:
# metrics dict for the zero-carrier, expansion-order-5 configuration
test3

{'jit_time': 24.03316617012024,
 'ave_run_time': 12.59182322025299,
 'ave_distance': 1.873819275867231e-10,
 'jit_grad_time': 68.87276005744934,
 'ave_grad_run_time': 46.108818006515506,
 'construction_time': 74.88299798965454,
 'num_terms': 3003}

In [58]:
# one metrics dict per tolerance in `tols`, for the dense ODE solver
dense_results

[{'jit_time': 3.2447097301483154,
  'ave_run_time': 0.6624621152877808,
  'ave_distance': 0.0034862463385427105,
  'jit_grad_time': 10.773071050643921,
  'ave_grad_run_time': 4.387863397598267},
 {'jit_time': 3.7990450859069824,
  'ave_run_time': 0.9602039813995361,
  'ave_distance': 0.00018223826904978878,
  'jit_grad_time': 13.530965089797974,
  'ave_grad_run_time': 7.015220308303833},
 {'jit_time': 3.1883771419525146,
  'ave_run_time': 1.5688777923583985,
  'ave_distance': 8.745089126281964e-06,
  'jit_grad_time': 17.250519037246704,
  'ave_grad_run_time': 11.260475730895996},
 {'jit_time': 3.914842128753662,
  'ave_run_time': 2.4593165159225463,
  'ave_distance': 4.47805087879055e-07,
  'jit_grad_time': 24.137096166610718,
  'ave_grad_run_time': 17.128229308128358},
 {'jit_time': 6.467632055282593,
  'ave_run_time': 3.9658479928970336,
  'ave_distance': 2.3332094138803402e-08,
  'jit_grad_time': 35.78397989273071,
  'ave_grad_run_time': 28.103383922576903}]