In [1]:
%load_ext autoreload
%autoreload 2

In [7]:
from jax import jit, vmap
import jaxquantum as jqt 
import jax.numpy as jnp
import matplotlib.pyplot as plt



def _simulate_damped_cavity(f_a, jit_solver):
    """Evolve a lossy cavity at frequency ``f_a`` and return <a> over time.

    Shared body of ``f`` and ``f_jit``: the two benchmark variants differ only
    in whether ``jqt.mesolve`` is called directly or through an inner
    ``jax.jit``.

    Args:
        f_a: Cavity frequency; the angular frequency is ``2*pi*f_a``. A scalar,
            which is a traced value when this runs under ``vmap``/``jit``.
        jit_solver: If True, call ``jqt.mesolve`` through ``jax.jit``;
            otherwise call it directly.

    Returns:
        ``jqt.overlap(a, states)``, i.e. the expectation values of the
        annihilation operator at each of the 101 time points (presumably with
        an extra batch axis for the two loss rates -- confirm against
        jaxquantum's batching rules).
    """
    N = 100  # Fock-space truncation
    a = jqt.destroy(N)
    n = a.dag() @ a

    omega_a = 2.0 * jnp.pi * f_a
    H0 = omega_a * n  # Hamiltonian

    # Two loss rates. Scaling `a` by the array presumably yields a batch of two
    # collapse operators (per the variable name) -- confirm Qarray broadcasting.
    kappa = 2 * jnp.pi * jnp.array([1, 2])
    batched_loss_op = jnp.sqrt(kappa) * a
    c_ops = jqt.Qarray.from_list([batched_loss_op])  # collapse operators

    # Vacuum state displaced by 0.1, converted to a density matrix.
    initial_state = (jqt.displace(N, 0.1) @ jqt.basis(N, 0)).to_dm()

    # Four oscillation periods (period = 2*pi/omega_a), sampled at 101 points.
    ts = jnp.linspace(0, 4 * 2 * jnp.pi / omega_a, 101)

    solver_options = jqt.SolverOptions.create(progress_meter=False)
    if jit_solver:
        # `(5,)` is a one-element tuple; JAX also accepts the bare int 5.
        # `solver_options` is passed by keyword below, so JAX maps index 5 to
        # a parameter name through inspect.signature(jqt.mesolve).
        # NOTE(review): this assumes positional index 5 of jqt.mesolve is
        # `solver_options`; `static_argnames="solver_options"` would not depend
        # on parameter order -- confirm against mesolve's signature.
        solver = jit(jqt.mesolve, static_argnums=(5,))
    else:
        solver = jqt.mesolve

    states = solver(
        H0, initial_state, ts, c_ops=c_ops, solver_options=solver_options
    )  # solve

    return jqt.overlap(a, states)  # expectation values


def f_jit(f_a):
    """Simulate at frequency ``f_a`` with ``jqt.mesolve`` wrapped in an inner ``jax.jit``."""
    return _simulate_damped_cavity(f_a, jit_solver=True)


def f(f_a):
    """Simulate at frequency ``f_a`` calling ``jqt.mesolve`` directly (no inner jit)."""
    return _simulate_damped_cavity(f_a, jit_solver=False)


# The four jit/vmap compositions timed in the cells below.
vmap_f_jit = vmap(f_jit)                    # no outer jit; inner jit on mesolve only
jit_vmap_f = jit(vmap(f))                   # outer jit; no inner jit
jit_vmap_f_jit = jit(vmap(f_jit))           # outer jit + inner jit on mesolve
jit_vmap_jit_f_jit = jit(vmap(jit(f_jit)))  # outer jit + jit(f_jit) + inner jit on mesolve

In [3]:
# 21 evenly spaced cavity frequencies from 4.5 to 5.5 (both endpoints
# included); the vmapped simulations map over this array.
f_as = jnp.linspace(start=4.5, stop=5.5, num=21)

In [5]:
# Times jit(vmap(f)): a single outer jit, with mesolve not separately jitted.
# -n1 -r1 makes this one call, so if jit_vmap_f has not been called yet in this
# kernel the figure likely includes tracing + XLA compilation, not only execution.
# .block_until_ready() waits for JAX's asynchronous dispatch so device work is timed.
%timeit -n1 -r1 jit_vmap_f(f_as).block_until_ready()

7.83 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [4]:
# Times jit(vmap(f_jit)): an outer jit plus the inner jit around mesolve.
# Single call (-n1 -r1): likely includes tracing/compilation on a first call.
# NOTE(review): this cell (In[4]) ran before the definitions cell was last
# executed (In[7]); re-run the notebook top to bottom before comparing numbers.
%timeit -n1 -r1 jit_vmap_f_jit(f_as).block_until_ready()

8.16 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [6]:
# Times vmap(f_jit) with no outer jit: only the inner jit(mesolve) is compiled;
# the remaining setup in f_jit presumably dispatches op-by-op on each call.
# Single call (-n1 -r1): likely includes compiling the inner jit on a first call.
%timeit -n1 -r1 vmap_f_jit(f_as).block_until_ready()

9.27 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [8]:
# Times jit(vmap(jit(f_jit))): outer jit, jit around f_jit, and inner jit on mesolve.
# Single call (-n1 -r1): likely includes tracing/compilation on a first call.
# NOTE(review): this is the only timing (In[8]) taken after the last run of the
# definitions cell (In[7]); the others (In[4]-In[6]) predate it.
%timeit -n1 -r1 jit_vmap_jit_f_jit(f_as).block_until_ready()

7.36 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
