In [1]:
from time import time

import matplotlib.pyplot as plt

import numpy as np
from numpy.polynomial.chebyshev import chebpts1, chebvander

import sys,os
sys.path.append(os.getcwd())

import jax.numpy as jnp
from jax.scipy.linalg import expm as jexpm
from jax import vmap, jacfwd, jit
from jax.lax import cond, scan, associative_scan
from jax.config import config
config.update("jax_enable_x64", True)

from qiskit_dynamics import dispatch, solve_lmde
from qiskit_dynamics.dispatch import Array
from qiskit.quantum_info import Operator
from qiskit_dynamics.models import HamiltonianModel, Frame
from qiskit_dynamics.signals import Signal
from qiskit_dynamics_internal.perturbation import solve_lmde_perturbation

dispatch.set_default_backend('jax')

# 1. Define system parameters and operators

System parameters

In [2]:
# Angular frequencies: each value is 2*pi times a number (presumably GHz, so
# times are in ns — TODO confirm units).
w_c = 2 * np.pi * 5.1
w_t = 2 * np.pi * 5.
alpha_c = 2 * np.pi * (-0.33)
alpha_t = 2 * np.pi * (-0.33)
J = 2 * np.pi * 0.002
# drive strength
r = 2 * np.pi * 0.05

# number of levels kept in each of the two tensor factors
dim = 4

# single-subsystem operators: lowering operator a (sqrt(1), ..., sqrt(dim - 1) on
# the first superdiagonal), its adjoint, the number operator, and identities on
# one subsystem (ident) and on the two-subsystem space (ident2)
a = np.diag(np.sqrt(np.arange(1, dim)), 1)
adag = a.transpose()
# NOTE(review): N is rebound to an integer step count in section 2; re-running
# this cell after that point replaces the step count with this matrix.
N = np.diag(np.arange(dim))
ident = np.eye(dim)
ident2 = np.eye(dim**2)

# operators on the control qubit (first tensor factor)
a0 = np.kron(a, ident)
adag0 = np.kron(adag, ident)
N0 = np.kron(N, ident)

# operators on the target qubit (second tensor factor)
a1 = np.kron(ident, a)
adag1 = np.kron(ident, adag)
N1 = np.kron(ident, N)

Hamiltonian operators.

In [3]:
# Static Hamiltonian: each subsystem contributes w*N + (alpha/2) * N (N - 1),
# plus the exchange coupling J (a0 adag1 + adag0 a1) between them.
H0 = (w_c * N0 + 0.5 * alpha_c * N0 @ (N0 - ident2)
      + w_t * N1 + 0.5 * alpha_t * N1 @ (N1 - ident2)
      + J * (a0 @ adag1 + adag0 @ a1))
# Drive operators r (a + a^dagger) on the control (Hdc) and target (Hdt).
# NOTE(review): Hdt is not used anywhere else in this notebook.
Hdc = r*(a0 + adag0)
Hdt = r*(a1 + adag1)

Pulse envelope to simulate.

In [4]:
def gauss(t, sig):
    """Unnormalized Gaussian exp(-t^2 / (2 sig^2)).

    ``t`` is unwrapped from an ``Array`` to its raw data first; ``sig`` is used as is.
    """
    t_data = Array(t).data
    exponent = -(t_data**2) / (2 * (sig**2))
    return jnp.exp(exponent)

# Inputs may arrive as qiskit-dynamics Arrays, JAX arrays or Python scalars;
# each one is unwrapped to its raw data before use.
def pulse(sig, taur, tau, t):
    """Flat-top pulse envelope with Gaussian rise and fall.

    Piecewise in ``t``:
      * t < taur: rising Gaussian edge, shifted and rescaled so the envelope
        is 0 at t = 0 and approaches 1 at t = taur;
      * taur <= t < tau - taur: constant 1;
      * otherwise: falling Gaussian edge, which reaches 0 at t = tau and is
        not clamped, so it becomes negative for t > tau.

    Args:
        sig: Gaussian width.
        taur: rise/fall time.
        tau: total pulse duration.
        t: evaluation time (a scalar: branches are selected with ``jax.lax.cond``).
    """
    sig, taur, tau, t = (Array(v).data for v in (sig, taur, tau, t))
    # Gaussian value at the edge start; subtracted so the edges start/end at 0
    offset = gauss(taur, sig)
    scale = 1 - offset

    def rise(_):
        return (gauss(t - taur, sig) - offset) / scale

    def plateau(_):
        return 1.

    def fall(_):
        return (gauss(t - (tau - taur), sig) - offset) / scale

    def after_rise(_):
        return cond(t < tau - taur, plateau, fall, 0.)

    return cond(t < taur, rise, after_rise, 0.)

## 1.1 Establish dressed energies

In [5]:
def basis_vec(ind, dimension):
    """Return the complex standard basis vector e_ind of length ``dimension``."""
    unit = np.zeros(shape=(dimension,), dtype=complex)
    unit[ind] = 1.0
    return unit

def two_q_basis_vec(inda, indb, dimension):
    """Product basis vector |inda> (x) |indb> on two ``dimension``-level subsystems."""
    factors = [basis_vec(level, dimension) for level in (inda, indb)]
    return np.kron(factors[0], factors[1])

def get_dressed_state_index(inda, indb, dimension, evectors):
    """Index of the eigenvector with the largest overlap with bare state |inda, indb>.

    Parameters
    ----------
    inda, indb : int
        Bare level of the first / second tensor factor.
    dimension : int
        Number of levels per tensor factor.
    evectors : array, shape (dimension**2, dimension**2)
        Eigenvectors stored as COLUMNS, as returned by ``eigh``.

    Returns
    -------
    int
        Column index into ``evectors`` (and into the matching eigenvalue array).
    """
    # same vector as two_q_basis_vec(inda, indb, dimension) (real-valued here)
    eye = np.eye(dimension)
    b_vec = np.kron(eye[inda], eye[indb])
    # <b|v_i> = sum_j b_j * evectors[j, i]: the overlap with every eigenvector
    # is b_vec @ evectors. The previous ``evectors @ b_vec`` picked out a
    # column of the eigenvector matrix instead, which gives the wrong index
    # whenever eigh's eigenvalue ordering permutes the bare states non-trivially.
    overlaps = np.abs(b_vec @ np.asarray(evectors))
    return overlaps.argmax()

def get_dressed_state_and_energy(inda, indb, dimension, evecs, eigenvalues=None):
    """Dressed energy and eigenvector best matching the bare state |inda, indb>.

    Parameters
    ----------
    inda, indb : int
        Bare level of the first / second tensor factor.
    dimension : int
        Number of levels per tensor factor.
    evecs : array
        Eigenvectors stored as columns (as returned by ``eigh``).
    eigenvalues : array, optional
        Eigenvalues matching the columns of ``evecs``. Defaults to the
        notebook-level ``evals`` for backward compatibility; pass it explicitly
        to avoid depending on hidden global state.

    Returns
    -------
    (energy, state)
        The selected eigenvalue and the corresponding eigenvector (a column of ``evecs``).
    """
    if eigenvalues is None:
        eigenvalues = evals  # notebook global computed together with ``evecs``
    ind = get_dressed_state_index(inda, indb, dimension, evecs)
    # eigenvectors are the columns; ``evecs[ind]`` would return a row instead
    return eigenvalues[ind], evecs[:, ind]

Establish dressed energies and basis vectors

In [6]:
# Diagonalize the static Hamiltonian; eigh returns the eigenvectors as the
# columns of B, so Badj @ X @ B expresses an operator X in the eigenbasis of H0.
evals, B = jnp.linalg.eigh(H0)
Badj = B.conj().transpose()

# Energies/states matched to the bare product states |00>, |01>, |10>, |11>.
# NOTE(review): get_dressed_state_and_energy relies on the global `evals` above.
E00, dressed00 = get_dressed_state_and_energy(0, 0, dim, B)
E01, dressed01 = get_dressed_state_and_energy(0, 1, dim, B)
E10, dressed10 = get_dressed_state_and_energy(1, 0, dim, B)
E11, dressed11 = get_dressed_state_and_energy(1, 1, dim, B)



In [7]:
# Drive frequency: the energy matched to bare state |01> (angular units; E00 is
# 0 here because H0 annihilates |00>). v_d is the same frequency divided by 2*pi.
w_d = E01
v_d = w_d / (2 * jnp.pi)
# offset of the drive frequency from the bare target frequency
w_d - w_t

DeviceArray(-0.00025123, dtype=float64)

In [8]:
H0_B = Badj @ H0 @ B
Hdc_B = Badj @ Hdc @ B

def full_ham(t):
    """Driven Hamiltonian at time ``t`` in the eigenbasis of H0.

    NOTE(review): relies on ``pulse_func``, which is only defined in a later cell,
    and is not called anywhere in this notebook.
    """
    drive_coeff = pulse_func(t) * jnp.cos(w_d * t)
    return H0_B + drive_coeff * Hdc_B

# 2. Full simulation of the system

Set parameters of simulation (length of pulse, number of time steps, etc).

In [9]:
# choose parameters, express everything in terms of dt
# dt spans 15 periods of the drive (w_d * dt = 2*pi*15), so the carrier phase
# is the same at the start of every interval.
dt = 15. / v_d

# Round the pulse length (200) and ramp time (20) down to whole multiples of dt.
tau = 200
# NOTE(review): rebinds N (previously the number-operator matrix) to the step count.
N = int(tau // dt)
tau = N * dt
taur = (20 //dt) * dt
# Gaussian width of the ramps: half the ramp time
sig = taur / 2

# Envelope with the parameters above, vectorized over t.
pulse_func = jnp.vectorize(lambda t: pulse(sig, taur, tau, t))

In [10]:
N  # number of dt steps in the pulse

66

In [None]:
# Plot the pulse envelope over the full pulse duration [0, N*dt].
times = np.linspace(0, N*dt, 100)
vals = pulse_func(times)
plt.plot(times, vals)

In [None]:
# first simulate full Hamiltonian to as high a precision as possible for comparisons
# The model holds the drift H0_B and the control operator Hdc_B (both in the
# eigenbasis of H0); signals are attached to copies of it in each simulation.
ham = HamiltonianModel(operators=[H0_B, Hdc_B])

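Wrap the full simulation in a function of the ramp time and Gaussian width so it can be JIT-compiled and vmapped.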

In [None]:
def full_sim(taur, sig):
    """High-precision reference simulation of the driven system.

    Attaches a constant signal and the pulse envelope (ramp time ``taur``,
    Gaussian width ``sig``, total length from the global ``tau``) with carrier
    ``v_d`` to a copy of ``ham``. It then solves in the frame of the drift with
    ``jax_odeint`` at atol = rtol = 1e-15. Returns the solution at t = N*dt;
    since y0 is the identity, this is the propagator.
    """
    model = ham.copy()
    model.signals = [1., Signal(lambda t: pulse(sig, taur, tau, t), carrier_freq=v_d)]
    model.frame = model.drift

    y_init = jnp.eye(dim**2, dtype=complex)
    solution = solve_lmde(model,
                          t_span=[0, N*dt],
                          y0=y_init,
                          method='jax_odeint',
                          atol=1e-15,
                          rtol=1e-15)
    return solution.y[-1]

# just in time compile it
jit_sim = jit(full_sim)

## 2.1 measure single jit + simulate

In [None]:
# First call: includes JIT compilation time.
start = time()
single_sim_y = jit_sim(taur/5, sig).block_until_ready()
print(time() - start)

In [None]:
# Second call with the same argument types: reuses the compiled function,
# so this measures simulation time only.
start = time()
single_sim_y = jit_sim(taur/5, sig).block_until_ready()
print(time() - start)

In [None]:
# Ramp times for the batched simulations below.
# NOTE(review): jnp.linspace(0.1, 5, 1) has num=1 and returns only [0.1], so the
# "batch" contains a single simulation; increase num for a real batch.
tau_vals = taur * jnp.linspace(0.1, 5, 1)

## 2.2 measure vmap jit + simulate

In [None]:
jit_vmap_sim = jit(vmap(lambda taur_val: full_sim(taur_val, sig)))  # batched over ramp time; sig fixed

In [None]:
# First batched call: includes JIT compilation of the vmapped function.
start = time()
direct_sim_y = jit_vmap_sim(tau_vals).block_until_ready()
sim_time = time() - start
print('Batch simulation time including jit: ' + str(sim_time))
print('Batch simulation average time including jit: ' + str(sim_time / len(tau_vals)))

In [None]:
# Second batched call: reuses the compiled function.
start = time()
direct_sim_y = jit_vmap_sim(tau_vals).block_until_ready()
sim_time = time() - start
print('Batch simulation time without jit: ' + str(sim_time))
print('Batch simulation average time without jit: ' + str(sim_time / len(tau_vals)))

## 2.3 Simulate at lower tolerance

Here we set the tolerance of the simulation to try to achieve about $10^{-5}$ infidelity.

In [None]:
def full_sim2(taur, sig):
    """Same simulation as ``full_sim`` but with atol = rtol = 2.5e-8.

    Returns the solution at t = N*dt; since y0 is the identity, this is the propagator.
    """
    model = ham.copy()
    model.signals = [1., Signal(lambda t: pulse(sig, taur, tau, t), carrier_freq=v_d)]
    model.frame = model.drift

    y_init = jnp.eye(dim**2, dtype=complex)
    solution = solve_lmde(model,
                          t_span=[0, N*dt],
                          y0=y_init,
                          method='jax_odeint',
                          atol=2.5*1e-8,
                          rtol=2.5*1e-8)
    return solution.y[-1]

jit_sim2 = jit(full_sim2)

### 2.3.1 Measure single jit + sim time 

In [None]:
# First call: includes JIT compilation time.
start = time()
single_sim_lower_tol_y = jit_sim2(taur/5, sig).block_until_ready()
print(time() - start)

In [None]:
# Second call: reuses the compiled function (simulation time only).
start = time()
single_sim_lower_tol_y = jit_sim2(taur/5, sig).block_until_ready()
print(time() - start)

In [None]:
# Gate infidelity 1 - |Tr(U_low^dagger U_ref) / d|^2 against the reference, d = dim**2.
1-(jnp.abs((single_sim_lower_tol_y.conj().transpose() @ single_sim_y).trace()) / (dim**2))**2

### 2.3.2 Measure vmap jit + sim time

In [None]:
jit_vmap_sim_lower_tol = jit(vmap(lambda taur_val: full_sim2(taur_val, sig)))  # batched over ramp time

In [None]:
# First batched call: includes JIT compilation time.
start = time()
direct_sim_lower_tol_y = jit_vmap_sim_lower_tol(tau_vals).block_until_ready()
sim_time = time() - start
print('Batch simulation time including jit: ' + str(sim_time))
print('Batch simulation average time including jit: ' + str(sim_time / len(tau_vals)))

In [None]:
# Second batched call: reuses the compiled function.
start = time()
direct_sim_lower_tol_y = jit_vmap_sim_lower_tol(tau_vals).block_until_ready()
sim_time = time() - start
print('Batch simulation time without jit: ' + str(sim_time))
print('Batch simulation average time without jit: ' + str(sim_time / len(tau_vals)))

In [None]:
# Infidelity of each lower-tolerance batch result against the reference batch.
infidelities = [
    1-(jnp.abs((direct_sim_lower_tol_y[idx].conj().transpose() @ direct_sim_y[idx]).trace()) / (dim**2))**2
    for idx in range(len(tau_vals))
]

In [None]:
infidelities  # one entry per element of tau_vals

# 3. RWA simulation

In [None]:
def rwa_sim(taur, sig):
    """Rotating-wave-approximation simulation.

    Same setup as ``full_sim``, except that the frame is ``H0_B``, the model
    ``cutoff_freq`` is 1.9 * v_d, and the tolerances are looser
    (atol = rtol = 1e-6). Returns the solution at t = N*dt.
    """
    model = ham.copy()
    model.signals = [1., Signal(lambda t: pulse(sig, taur, tau, t), carrier_freq=v_d)]
    model.frame = H0_B
    model.cutoff_freq = 1.9 * v_d  # add cutoff

    y_init = jnp.eye(dim**2, dtype=complex)
    solution = solve_lmde(model,
                          t_span=[0, N*dt],
                          y0=y_init,
                          method='jax_odeint',
                          atol=1e-6,
                          rtol=1e-6)
    return solution.y[-1]

# just in time compile it
jit_rwa_sim = jit(rwa_sim)

In [None]:
# First call: includes JIT compilation time.
start = time()
single_rwa_sim_y = jit_rwa_sim(taur/5, sig).block_until_ready()
print(time() - start)

In [None]:
# Second call: reuses the compiled function (simulation time only).
start = time()
single_rwa_sim_y = jit_rwa_sim(taur/5, sig).block_until_ready()
print(time() - start)

In [None]:
# Gate infidelity of the RWA result against the high-precision reference.
1-(jnp.abs((single_rwa_sim_y.conj().transpose() @ single_sim_y).trace()) / (dim**2))**2

### 3.1 Measure vmap jit + sim time

In [None]:
jit_vmap_rwa_sim = jit(vmap(lambda taur_val: rwa_sim(taur_val, sig)))  # batched over ramp time

In [None]:
# First batched call: includes JIT compilation time.
start = time()
rwa_sim_y = jit_vmap_rwa_sim(tau_vals).block_until_ready()
sim_time = time() - start
print('Batch simulation time including jit: ' + str(sim_time))
print('Batch simulation average time including jit: ' + str(sim_time / len(tau_vals)))

In [None]:
# Second batched call: reuses the compiled function.
start = time()
rwa_sim_y = jit_vmap_rwa_sim(tau_vals).block_until_ready()
sim_time = time() - start
print('Batch simulation time without jit: ' + str(sim_time))
print('Batch simulation average time without jit: ' + str(sim_time / len(tau_vals)))

In [None]:
# Infidelity of each RWA batch result against the reference batch.
infidelities = [
    1-(jnp.abs((rwa_sim_y[idx].conj().transpose() @ direct_sim_y[idx]).trace()) / (dim**2))**2
    for idx in range(len(tau_vals))
]

In [None]:
infidelities  # RWA vs. reference, one entry per element of tau_vals

In the above we see the speed benefits of the RWA, however the fidelity of the simulation is fundamentally limited by the approximation.

# 4. Magnus based approximate simulation

Next, we simulate the system using the Magnus-RWA method. For this approach, over each short interval, we approximate the drive envelope as a polynomial decomposed via Chebyshev polynomials. For a fixed time interval, we compute Magnus terms for these Chebyshev polynomial drive terms (up to some order).

Given these pre-computed Magnus terms, to simulate:
- For each time interval, compute Chebyshev coefficients for drive envelope
- Take linear combination of Magnus terms based on these coefficients and exponentiate
- Multiply by frame correction

First, we set up the Chebyshev approximation routines.

In [None]:
def get_DCT_data(deg, domain=[-1, 1]):
    """Chebyshev interpolation data for polynomials of degree ``deg``.

    Returns ``(dct_mat, xcheb_shifted)``:
      * ``xcheb_shifted``: the ``deg + 1`` Chebyshev points of the first kind
        (``chebpts1``), mapped affinely from [-1, 1] onto ``domain``;
      * ``dct_mat``: maps function values at those points to Chebyshev series
        coefficients. Row j is T_j at the unshifted points, divided by
        (deg + 1) for j = 0 and by (deg + 1) / 2 otherwise.
    """
    n_points = deg + 1
    lo, hi = domain[0], domain[1]
    nodes = chebpts1(n_points)
    mapped_nodes = 0.5*((hi - lo) * nodes + (hi + lo))

    # rows indexed by polynomial degree, columns by node
    transform = chebvander(nodes, deg).T
    row_scale = np.full(n_points, 0.5 * n_points)
    row_scale[0] = n_points
    transform /= row_scale[:, None]

    return transform, mapped_nodes

def perform_DCT(func, dct_mat, xcheb_shifted):
    """Chebyshev coefficients of ``func`` from its values at ``xcheb_shifted``.

    ``func`` is called once on the whole array of points, so it must be vectorized.
    """
    return np.dot(dct_mat, func(xcheb_shifted))

def multi_interval_DCT(func, dct_mat, xcheb_shifted, start_point_shifts):
    """Stack the Chebyshev coefficients of ``func`` for several shifted node sets.

    Row i holds the coefficients for the nodes ``xcheb_shifted + start_point_shifts[i]``.
    """
    rows = []
    for shift in start_point_shifts:
        rows.append(perform_DCT(func, dct_mat, xcheb_shifted + shift))
    return np.array(rows)

def construct_multi_interval_DCT(deg, dt, n_intervals, start_time):
    """Build a function that computes Chebyshev coefficients on consecutive intervals.

    The intervals are [start_time + k*dt, start_time + (k+1)*dt] for
    k = 0, ..., n_intervals - 1. The returned ``approx_func(func)`` evaluates
    the vectorized ``func`` on a (deg + 1, n_intervals) grid of shifted
    Chebyshev nodes. It returns an array whose column k holds the degree-``deg``
    Chebyshev coefficients of ``func`` on interval k.
    """
    # single-interval transform and nodes on [0, dt]
    dct_mat, nodes_0 = get_DCT_data(deg, domain=[0, dt])

    # rows: Chebyshev node; columns: interval
    starts = start_time + np.arange(n_intervals) * dt
    t_grid = np.add.outer(nodes_0, starts)

    def approx_func(func):
        return dct_mat @ func(t_grid)

    return approx_func

In [None]:
# Coefficient map for degree-2 Chebyshev approximations on the N intervals
# [k*dt, (k+1)*dt], k = 0..N-1.
approx_func = construct_multi_interval_DCT(2, dt, N, start_time=0)

In [None]:
# NOTE(review): pulse_func is already a jnp.vectorize'd function, so this extra
# wrapping looks redundant.
vec_pulse_func = jnp.vectorize(pulse_func)

In [None]:
# Column k holds the Chebyshev coefficients of the pulse envelope on interval k.
cheb_coeffs = approx_func(vec_pulse_func)

## Chebyshev evaluation

Function for evaluating a Chebyshev series on an arbitrary interval.

In [None]:
def evaluate_cheb_series(x, c, domain=[-1, 1]):
    """Evaluate sum_j c[j] T_j(x), with the T_j mapped from [-1, 1] onto ``domain``.

    Uses Clenshaw's recurrence (as in ``numpy.polynomial.chebyshev.chebval``),
    written with ``jax.lax.scan`` so it can be traced. For three or more
    coefficients, ``c`` must support indexing by a traced integer (for example,
    a jax array rather than a list).
    """
    lo, hi = domain[0], domain[1]
    x = (2 * x - hi - lo) / (hi - lo)

    n_coeffs = len(c)
    if n_coeffs == 1:
        b0, b1 = c[0], 0
    elif n_coeffs == 2:
        b0, b1 = c[0], c[1]
    else:
        two_x = 2*x

        def clenshaw_step(state, j):
            b0_prev, b1_prev = state
            return (c[-j] - b1_prev, b0_prev + b1_prev * two_x), None

        (b0, b1), _ = scan(clenshaw_step, init=(c[-2], c[-1]), xs=jnp.arange(3, n_coeffs + 1))

    return b0 + b1*x

def approx_func_as_callable(cheb_coeffs, dt, n_intervals, start_time):
    """Turn per-interval Chebyshev coefficients into a piecewise-polynomial function.

    Parameters
    ----------
    cheb_coeffs : array, shape (n_coeffs, n_intervals)
        Column k holds the Chebyshev coefficients on
        [start_time + k*dt, start_time + (k+1)*dt] (the layout produced by
        ``construct_multi_interval_DCT``).
    dt : float
        Interval length.
    n_intervals : int
        Number of intervals. Times outside the covered range are evaluated with
        the first/last interval's polynomial (extrapolation).
    start_time : float
        Start of the first interval.

    Returns
    -------
    callable
        ``approx_func(t)`` evaluating the approximation at ``t``.
    """
    # never index past the coefficients actually supplied
    last_interval = min(n_intervals, cheb_coeffs.shape[-1]) - 1

    def approx_func(t):
        # time relative to the first interval; ``start_time`` was previously
        # ignored, which was only correct for start_time == 0
        t_rel = t - start_time
        k = jnp.clip(jnp.array(t_rel // dt, dtype=int), 0, last_interval)
        return evaluate_cheb_series(t_rel - (k * dt), cheb_coeffs[:, k], [0, dt])

    return approx_func
    

In [None]:
# Piecewise evaluation of the Chebyshev approximation of the envelope.
approx_pulse_func = approx_func_as_callable(cheb_coeffs, dt, N, 0)

In [None]:
# Compare the exact envelope with its Chebyshev approximation on interval k.
k = 0
times = np.linspace(k*dt, (k+1)*dt, 100)
vals = vec_pulse_func(times)
approx_vals = approx_pulse_func(times)
plt.plot(times, vals, times, approx_vals)
# TODO: also plot the approximation error (difference) on a log scale

## Compute magnus approx with chebyshev polynomials

In [None]:
def G(t):
    """Frame generator -i * H0_B; it is constant, so ``t`` is unused."""
    frame_generator = -1j * H0_B
    return frame_generator

def T0(t, t0):
    """Chebyshev polynomial T_0 on [t0, t0 + dt] (identically 1)."""
    local_t = t - t0
    return evaluate_cheb_series(local_t, [1], domain=[0, dt])

def A0(t, t0):
    """Drive term -i * T_0 * cos(w_d t) * Hdc_B for the interval starting at t0."""
    scalar = -1j * T0(t, t0) * jnp.cos(w_d * t)
    return scalar * Hdc_B

def T1(t, t0):
    """Chebyshev polynomial T_1 mapped onto [t0, t0 + dt]."""
    local_t = t - t0
    return evaluate_cheb_series(local_t, [0, 1], domain=[0, dt])

def A1(t, t0):
    """Drive term -i * T_1 * cos(w_d t) * Hdc_B for the interval starting at t0."""
    scalar = -1j * T1(t, t0) * jnp.cos(w_d * t)
    return scalar * Hdc_B

def T2(t, t0):
    """Chebyshev polynomial T_2 mapped onto [t0, t0 + dt].

    The coefficients are a jax array because the 3+-term evaluation path
    indexes them with a traced integer.
    """
    local_t = t - t0
    coefficients = jnp.array([0, 0, 1], dtype=float)
    return evaluate_cheb_series(local_t, coefficients, domain=[0, dt])

def A2(t, t0):
    """Drive term -i * T_2 * cos(w_d t) * Hdc_B for the interval starting at t0."""
    scalar = -1j * T2(t, t0) * jnp.cos(w_d * t)
    return scalar * Hdc_B

In [None]:
# Magnus expansion over a single interval [0, dt] for the three Chebyshev drive
# terms A0, A1, A2 (with t0 = 0) and generator G. Index lists such as [0, 1]
# refer to positions in A_list. Presumably perturbation_order=5 requests every
# index combination up to order 5, and perturbation_terms adds the order-7 term
# [0]*7 — confirm against the solve_lmde_perturbation documentation.
start = time()
mag_results = solve_lmde_perturbation(A_list=[lambda t: A0(t, 0), lambda t: A1(t, 0), lambda t: A2(t, 0)],
                                  perturbation_method='symmetric_magnus',
                                  perturbation_order=5,
                                  perturbation_terms=[[0, 0, 0, 0, 0, 0, 0]],
                                  t_span=[0, dt],
                                  generator=G,
                                  method='jax_odeint',
                                  atol=1e-8,
                                  rtol=1e-8)
print(time() - start)

## Approx simulation

In [None]:
# Exact Magnus terms are anti-Hermitian, but the solver tolerance does not
# enforce this, so each extracted term is projected onto its anti-Hermitian part.
def to_aherm(A):
    """Anti-Hermitian part (A - A^dagger) / 2 of a square matrix ``A``."""
    A_dagger = A.conj().T
    return 0.5 * (A - A_dagger)

# Final-time (index -1) Magnus terms, projected to be anti-Hermitian.
# The name M<digits> holds perturbation_results[[digits]], e.g. M012 <- [0, 1, 2].
M0 = to_aherm(mag_results.perturbation_results[[0]][-1])
M1 = to_aherm(mag_results.perturbation_results[[1]][-1])
M2 = to_aherm(mag_results.perturbation_results[[2]][-1])
M00 = to_aherm(mag_results.perturbation_results[[0, 0]][-1])
M01 = to_aherm(mag_results.perturbation_results[[0, 1]][-1])
M02 = to_aherm(mag_results.perturbation_results[[0, 2]][-1])
M11 = to_aherm(mag_results.perturbation_results[[1, 1]][-1])
M12 = to_aherm(mag_results.perturbation_results[[1, 2]][-1])
M22 = to_aherm(mag_results.perturbation_results[[2, 2]][-1])
M000 = to_aherm(mag_results.perturbation_results[[0, 0, 0]][-1])
M001 = to_aherm(mag_results.perturbation_results[[0, 0, 1]][-1])
M002 = to_aherm(mag_results.perturbation_results[[0, 0, 2]][-1])
M011 = to_aherm(mag_results.perturbation_results[[0, 1, 1]][-1])
M012 = to_aherm(mag_results.perturbation_results[[0, 1, 2]][-1])
M022 = to_aherm(mag_results.perturbation_results[[0, 2, 2]][-1])
M111 = to_aherm(mag_results.perturbation_results[[1, 1, 1]][-1])
M112 = to_aherm(mag_results.perturbation_results[[1, 1, 2]][-1])
M122 = to_aherm(mag_results.perturbation_results[[1, 2, 2]][-1])
M222 = to_aherm(mag_results.perturbation_results[[2, 2, 2]][-1])
M0000 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 0]][-1])
M0001 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 1]][-1])
M0002 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 2]][-1])
M0011 = to_aherm(mag_results.perturbation_results[[0, 0, 1, 1]][-1])
M0012 = to_aherm(mag_results.perturbation_results[[0, 0, 1, 2]][-1])
M0022 = to_aherm(mag_results.perturbation_results[[0, 0, 2, 2]][-1])
M0111 = to_aherm(mag_results.perturbation_results[[0, 1, 1, 1]][-1])
M0112 = to_aherm(mag_results.perturbation_results[[0, 1, 1, 2]][-1])
M0122 = to_aherm(mag_results.perturbation_results[[0, 1, 2, 2]][-1])
M0222 = to_aherm(mag_results.perturbation_results[[0, 2, 2, 2]][-1])
M1111 = to_aherm(mag_results.perturbation_results[[1, 1, 1, 1]][-1])
M1112 = to_aherm(mag_results.perturbation_results[[1, 1, 1, 2]][-1])
M1122 = to_aherm(mag_results.perturbation_results[[1, 1, 2, 2]][-1])
M1222 = to_aherm(mag_results.perturbation_results[[1, 2, 2, 2]][-1])
M2222 = to_aherm(mag_results.perturbation_results[[2, 2, 2, 2]][-1])
M00000 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 0, 0]][-1])
M00001 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 0, 1]][-1])
M00002 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 0, 2]][-1])
M00011 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 1, 1]][-1])
M00012 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 1, 2]][-1])
M00022 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 2, 2]][-1])
M00111 = to_aherm(mag_results.perturbation_results[[0, 0, 1, 1, 1]][-1])
M00112 = to_aherm(mag_results.perturbation_results[[0, 0, 1, 1, 2]][-1])
M00122 = to_aherm(mag_results.perturbation_results[[0, 0, 1, 2, 2]][-1])
M00222 = to_aherm(mag_results.perturbation_results[[0, 0, 2, 2, 2]][-1])
M01111 = to_aherm(mag_results.perturbation_results[[0, 1, 1, 1, 1]][-1])
M01112 = to_aherm(mag_results.perturbation_results[[0, 1, 1, 1, 2]][-1])
M01122 = to_aherm(mag_results.perturbation_results[[0, 1, 1, 2, 2]][-1])
M01222 = to_aherm(mag_results.perturbation_results[[0, 1, 2, 2, 2]][-1])
M02222 = to_aherm(mag_results.perturbation_results[[0, 2, 2, 2, 2]][-1])
M11111 = to_aherm(mag_results.perturbation_results[[1, 1, 1, 1, 1]][-1])
M11112 = to_aherm(mag_results.perturbation_results[[1, 1, 1, 1, 2]][-1])
M11122 = to_aherm(mag_results.perturbation_results[[1, 1, 1, 2, 2]][-1])
M11222 = to_aherm(mag_results.perturbation_results[[1, 1, 2, 2, 2]][-1])
M12222 = to_aherm(mag_results.perturbation_results[[1, 2, 2, 2, 2]][-1])
M22222 = to_aherm(mag_results.perturbation_results[[2, 2, 2, 2, 2]][-1])
M000000 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 0, 0, 0]][-1])
M0000000 = to_aherm(mag_results.perturbation_results[[0, 0, 0, 0, 0, 0, 0]][-1])

# single step frame operator
# exp(-i H0_B dt): free evolution over one interval
Uf_dt = jexpm(-1j * H0_B * dt)

# final step frame operator
# exp(+i H0_B N dt): undoes the free evolution over the whole pulse,
# presumably to match the frame used by the reference simulation (frame = drift)
Uf = jexpm(1j * H0_B * N * dt)

In [None]:
def magnus_sim(taur, sig):
    """Approximate propagator over [0, N*dt] built from the precomputed Magnus terms.

    For each of the N intervals of length dt, the pulse envelope (ramp time
    ``taur``, Gaussian width ``sig``, total length from the global ``tau``) is
    reduced to Chebyshev coefficients (c0, c1, c2) via ``approx_func``. The
    interval's Magnus exponent is the polynomial sum_I c_I * M_I over every
    extracted index combination I: all combinations up to order 5, plus [0]*6
    and [0]*7. Each step propagator is ``Uf_dt @ expm(exponent)``. The steps are
    multiplied in time order (later intervals on the left), and the product is
    finally multiplied by ``Uf``.

    Returns
    -------
    array, shape (dim**2, dim**2)
    """
    drive_func = jnp.vectorize(lambda t: pulse(sig, taur, tau, t))
    c_coeffs = approx_func(drive_func)

    def approx_gen(k):
        """Magnus exponent for interval ``k`` (column ``k`` of ``c_coeffs``)."""
        coeffs = c_coeffs[:, k]
        c0 = coeffs[0]
        c1 = coeffs[1]
        c2 = coeffs[2]
        c00 = c0 * c0
        c01 = c0 * c1
        c02 = c0 * c2
        c11 = c1 * c1
        c12 = c1 * c2
        c22 = c2 * c2
        c000 = c0 * c00
        c001 = c0 * c01
        c002 = c0 * c02
        c011 = c0 * c11
        c012 = c0 * c12
        c022 = c0 * c22
        c111 = c1 * c11
        c112 = c1 * c12
        c122 = c1 * c22
        c222 = c2 * c22
        c0000 = c0 * c000
        c0001 = c0 * c001
        c0002 = c0 * c002
        c0011 = c0 * c011
        c0012 = c0 * c012
        c0022 = c0 * c022
        c0111 = c0 * c111
        c0112 = c0 * c112
        c0122 = c0 * c122
        c0222 = c0 * c222
        c1111 = c1 * c111
        c1112 = c1 * c112
        c1122 = c1 * c122
        c1222 = c1 * c222
        c2222 = c2 * c222
        c00000 = c0 * c0000
        c00001 = c0 * c0001
        c00002 = c0 * c0002
        c00011 = c0 * c0011
        c00012 = c0 * c0012
        c00022 = c0 * c0022
        c00111 = c0 * c0111
        c00112 = c0 * c0112
        c00122 = c0 * c0122
        c00222 = c0 * c0222
        c01111 = c0 * c1111
        c01112 = c0 * c1112
        c01122 = c0 * c1122
        c01222 = c0 * c1222
        c02222 = c0 * c2222
        c11111 = c1 * c1111
        c11112 = c1 * c1112
        c11122 = c1 * c1122
        c11222 = c1 * c1222
        c12222 = c1 * c2222
        c22222 = c2 * c2222
        c000000 = c0 * c00000
        c0000000 = c0 * c000000

        return (c0 * M0
                + c1 * M1
                + c2 * M2
                + c00 * M00
                + c01 * M01
                + c02 * M02
                + c11 * M11
                + c12 * M12
                + c22 * M22
                + c000 * M000
                + c001 * M001
                + c002 * M002
                + c011 * M011
                + c012 * M012
                + c022 * M022
                + c111 * M111
                + c112 * M112
                + c122 * M122  # previously paired with M112 by mistake
                + c222 * M222
                + c0000 * M0000
                + c0001 * M0001
                + c0002 * M0002
                + c0011 * M0011
                + c0012 * M0012
                + c0022 * M0022
                + c0111 * M0111
                + c0112 * M0112
                + c0122 * M0122
                + c0222 * M0222
                + c1111 * M1111
                + c1112 * M1112
                + c1122 * M1122
                + c1222 * M1222
                + c2222 * M2222
                + c00000 * M00000  # previously computed but omitted from the sum
                + c00001 * M00001
                + c00002 * M00002
                + c00011 * M00011
                + c00012 * M00012
                + c00022 * M00022
                + c00111 * M00111
                + c00112 * M00112
                + c00122 * M00122
                + c00222 * M00222
                + c01111 * M01111
                + c01112 * M01112
                + c01122 * M01122
                + c01222 * M01222
                + c02222 * M02222
                + c11111 * M11111
                + c11112 * M11112
                + c11122 * M11122
                + c11222 * M11222
                + c12222 * M12222
                + c22222 * M22222
                + c000000 * M000000
                + c0000000 * M0000000)

    def single_step(k):
        """Propagator over interval k: free evolution times the Magnus correction."""
        return Uf_dt @ jexpm(approx_gen(k))

    # Steps ordered N-1, ..., 0 so the cumulative matmul product ends with
    # U_{N-1} @ ... @ U_0 (later times on the left).
    step_propagators = vmap(single_step)(jnp.flip(jnp.arange(0, N, dtype=int)))
    final_prop = associative_scan(jnp.matmul, step_propagators, axis=0)[-1]

    return Uf @ final_prop


jit_magnus_sim = jit(magnus_sim)

In [None]:
# First call: includes JIT compilation time.
start = time()
single_magnus_y = jit_magnus_sim(taur/5, sig).block_until_ready()
print(time() - start)

In [None]:
# Second call: reuses the compiled function (simulation time only).
start = time()
single_magnus_y = jit_magnus_sim(taur/5, sig).block_until_ready()
print(time() - start)

In [None]:
# Gate infidelity of the Magnus approximation against the high-precision reference.
1-(jnp.abs((single_magnus_y.conj().transpose() @ single_sim_y).trace()) / (dim**2))**2

### 4.1 Measure vmap jit + sim time

In [None]:
jit_vmap_magnus_sim = jit(vmap(lambda taur_val: magnus_sim(taur_val, sig)))  # batched over ramp time

In [None]:
# First batched call: includes JIT compilation time.
start = time()
vmap_magnus_sim_y = jit_vmap_magnus_sim(tau_vals).block_until_ready()
sim_time = time() - start
print('Batch simulation time including jit: ' + str(sim_time))
print('Batch simulation average time including jit: ' + str(sim_time / len(tau_vals)))

In [None]:
# Second batched call: reuses the compiled function.
start = time()
vmap_magnus_sim_y = jit_vmap_magnus_sim(tau_vals).block_until_ready()
sim_time = time() - start
print('Batch simulation time without jit: ' + str(sim_time))
print('Batch simulation average time without jit: ' + str(sim_time / len(tau_vals)))

In [None]:
# Infidelity of each Magnus batch result against the reference batch.
infidelities = [
    1-(jnp.abs((vmap_magnus_sim_y[idx].conj().transpose() @ direct_sim_y[idx]).trace()) / (dim**2))**2
    for idx in range(len(tau_vals))
]

In [None]:
infidelities  # Magnus approximation vs. reference, one entry per element of tau_vals

# 5. New code: Magnus simulation with the packaged `compile_magsolve_jax` solver

In [11]:
from qiskit_dynamics_internal.perturbation.dysolve_magsolve import (compile_dysolve_jax,
                                                                    compile_magsolve_jax)

In [12]:
start = time()
# Precompile a Magnus solver for the single drive term -i Hdc_B (carrier v_d) in
# the frame -i H0_B, with steps of length dt.
# NOTE(review): section 4 used degree-2 Chebyshev envelopes and perturbation
# order 5; here polynomial_degrees=[1] and perturbation_order=3, so the two
# computations are not identical.
mag_solver = compile_magsolve_jax([-1j * Hdc_B],
                                  carrier_freqs=[v_d],
                                  frame_operator=-1j*H0_B,
                                  dt=dt,
                                  polynomial_degrees=[1],
                                  perturbation_order=3,
                                  method='jax_odeint',
                                  atol=1e-8,
                                  rtol=1e-8)
print(time() - start)

8.837493896484375


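This compilation takes much longer than the Magnus-term computation in the previous section, which was expected to be essentially the same computation.

This may not indicate a bug: the solver likely computes many more terms, because it also has to handle the sine components of the carrier. The settings also differ from section 4 (Chebyshev degree 1 and perturbation order 3 here, versus degree 2 and order 5 there), so the two computations are not identical.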

In [13]:
def magnus_sim_new(taur, sig):
    """Propagator over N steps of length dt using the precompiled ``mag_solver``.

    The pulse envelope (ramp time ``taur``, Gaussian width ``sig``, total length
    from the global ``tau``) is wrapped in a ``Signal`` with carrier ``v_d`` and
    passed to ``mag_solver``, starting from the identity at t0 = 0.
    """
    drive_func = jnp.vectorize(lambda t: pulse(sig, taur, tau, t))
    # Use a new name for the Signal. Previously it was assigned to ``sig``;
    # because the lambda above looks ``sig`` up when it is called, the envelope
    # would then have received the Signal object instead of the Gaussian width.
    drive_signal = Signal(drive_func, carrier_freq=v_d)

    yf = mag_solver([drive_signal], y0=jnp.eye(dim**2, dtype=complex), t0=0, n_steps=N)
    return yf

jit_magnus_sim_new = jit(magnus_sim_new)

In [14]:
# First call: includes JIT compilation time.
# NOTE(review): unlike the other single-run timings, this passes taur (not taur/5).
start = time()
yf = jit_magnus_sim_new(taur, sig).block_until_ready()
print(time()-start)

TypeError: operand type(s) all returned NotImplemented from __array_ufunc__(<ufunc 'add'>, 'outer', Array(Traced<ShapedArray(float64[2])>with<DynamicJaxprTrace(level=0/1)>), Array(Traced<ShapedArray(float64[66])>with<DynamicJaxprTrace(level=0/1)>)): 'Array', 'Array'