In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import dynamiqs as dq
import jax
import jax.numpy as jnp
import jaxquantum as jqt

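## dynamiqs

Damped harmonic oscillator ($H = \omega\, a^\dagger a$, single loss channel $\sqrt{\kappa}\, a$, coherent initial state $|\alpha\rangle$) solved with dynamiqs' `mesolve`: timing, a batched sweep over $\kappa$, and plots of the result.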

In [None]:
# parameters of the damped oscillator: H = omega * a^dag a, jump op sqrt(kappa) * a
n = 128      # Hilbert space dimension
omega = 1.0  # frequency
kappa = 0.1  # decay rate
alpha = 1.0  # initial coherent state amplitude

# initialize operators, initial state and saving times
a = dq.destroy(n)
H = omega * dq.dag(a) @ a
jump_ops = [jnp.sqrt(kappa) * a]
psi0 = dq.coherent(n, alpha)
tsave = jnp.linspace(0, 20.0, 101)

# Run the simulation once and keep the result. The plotting/inspection cells
# below read `result.Esave` and `result.ysave`. Assignments made inside `%timeit`
# are local to its timing function and are not kept in the notebook namespace.
result = dq.mesolve(H, jump_ops, psi0, tsave, exp_ops=[dq.dag(a) @ a], solver=dq.solver.Dopri5())

In [None]:
%timeit -n1 -r1 result = dq.mesolve(H, jump_ops, psi0, tsave)
%timeit result = dq.mesolve(H, jump_ops, psi0, tsave)

In [None]:
# NOTE: a later cell (jaxquantum section) defines another `map_kappa`, which
# shadows this one once it has run.
def map_kappa(kappa, n=128, omega=1.0, alpha=1.0, t_final=20.0, num_tsave=101):
    """Simulate the damped harmonic oscillator with dynamiqs for one decay rate.

    Builds H = omega * a^dag a with the single jump operator sqrt(kappa) * a,
    starts from the coherent state |alpha>, and integrates the master equation.
    Only ``kappa`` is meant to be mapped over (e.g. by ``vmap``). The other
    arguments are plain Python values fixed when the function is traced.

    Args:
        kappa: decay rate; may be a traced JAX scalar under ``vmap``/``jit``.
        n: Hilbert space dimension (a concrete Python int).
        omega: oscillator frequency.
        alpha: initial coherent-state amplitude.
        t_final: last saving time.
        num_tsave: number of evenly spaced saving times in [0, t_final]
            (a concrete Python int, since it sets the output length).

    Returns:
        The result of ``dq.mesolve`` using the Dopri5 solver, with
        ``dq.dag(a) @ a`` as the only expectation operator.
    """
    # initialize operators, initial state and saving times
    a = dq.destroy(n)
    H = omega * dq.dag(a) @ a
    jump_ops = [jnp.sqrt(kappa) * a]
    psi0 = dq.coherent(n, alpha)
    tsave = jnp.linspace(0.0, t_final, num_tsave)

    # run simulation
    return dq.mesolve(H, jump_ops, psi0, tsave, exp_ops=[dq.dag(a) @ a], solver=dq.solver.Dopri5())


In [None]:
from jax import jit, vmap

In [None]:
# Sweep the decay rate: vmap batches `map_kappa` (the dynamiqs version at this
# point) over the kappa values and jit compiles the batched simulation.
# Keep both the compiled function and its output so the (expensive) results can
# be reused, and display only a compact shape instead of the full result repr.
kappas = jnp.linspace(0.1, 0.6, 6)
dq_batched_map_kappa = jit(vmap(map_kappa))
dq_kappa_results = dq_batched_map_kappa(kappas)
dq_kappa_results.Esave.shape

In [None]:
import matplotlib.pyplot as plt

# Plot the first saved expectation value of the dynamiqs `result` over time.
# That is <a^dag a> when `result` was computed with exp_ops=[dq.dag(a) @ a].
# dynamiqs stores expectation values with a complex dtype, so take the real part
# explicitly rather than letting matplotlib discard the imaginary part with a
# warning. `.real` is a no-op if the values are already real.
fig, ax = plt.subplots()
ax.plot(result.tsave, result.Esave[0].real)
ax.set(xlabel="time", ylabel=r"$\langle a^\dagger a \rangle$", title="Photon number vs time (dynamiqs)")
plt.show()

In [None]:
# Shape of the first saved state in the dynamiqs result.
jnp.shape(result.ysave[0])

In [None]:
# Wigner function of the first saved state (the state at tsave[0]).
dq.plot_wigner(
    result.ysave[0, ...]
)

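## jaxquantum

The same damped-oscillator problem solved with jaxquantum's `mesolve` for comparison: timing and a batched sweep over $\kappa$.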

In [3]:
# parameters of the damped oscillator: H = omega * a^dag a, jump op sqrt(kappa) * a
n = 128      # Hilbert space dimension
omega = 1.0  # frequency
kappa = 0.1  # decay rate
alpha = 1.0  # initial coherent state amplitude

# initialize operators, initial state and saving times
a = jqt.destroy(n)
H = omega * jqt.dag(a) @ a
# jaxquantum receives the collapse operators stacked into one array
# (presumably shape (1, n, n)) rather than as a Python list as in dynamiqs.
jump_ops = jnp.array([jnp.sqrt(kappa) * a])
psi0 = jqt.coherent(n, alpha)
tsave = jnp.linspace(0, 20.0, 101)

In [4]:
# Evolve the initial density matrix (psi0 converted by ket2dm) under H0 with the
# collapse operators in jump_ops, over the saving times in tsave.
result = jqt.mesolve(
    jqt.ket2dm(psi0),
    tsave,
    H0=H,
    c_ops=jump_ops,
)

  out = fun(*args, **kwargs)


0.22686290740966797


In [None]:
%timeit -n1 -r1 result = jqt.mesolve(jqt.ket2dm(psi0), tsave, c_ops=jump_ops, H0=H)
%timeit result = jqt.mesolve(jqt.ket2dm(psi0), tsave, c_ops=jump_ops, H0=H)

In [None]:
# Intentionally empty. A bare `jqt.mesolve_old` expression here only echoed the
# legacy solver's function object (leftover interactive inspection) and was
# not used by any later cell.

In [None]:
# NOTE: this redefines (and shadows) the dynamiqs `map_kappa` from the section above.
def map_kappa(kappa, n=128, omega=1.0, alpha=1.0, t_final=20.0, num_tsave=101):
    """Simulate the damped harmonic oscillator with jaxquantum for one decay rate.

    Builds H = omega * a^dag a with the single collapse operator sqrt(kappa) * a,
    starts from the density matrix of the coherent state |alpha>, and runs
    ``jqt.mesolve``. Only ``kappa`` is meant to be mapped over (e.g. by ``vmap``).
    The other arguments are plain Python values fixed when the function is traced.

    Args:
        kappa: decay rate; may be a traced JAX scalar under ``vmap``/``jit``.
        n: Hilbert space dimension (a concrete Python int).
        omega: oscillator frequency.
        alpha: initial coherent-state amplitude.
        t_final: last saving time.
        num_tsave: number of evenly spaced saving times in [0, t_final]
            (a concrete Python int, since it sets the output length).

    Returns:
        Whatever ``jqt.mesolve`` returns for this problem (presumably the saved
        density matrices; confirm against jaxquantum's API).
    """
    # initialize operators, initial state and saving times
    a = jqt.destroy(n)
    H = omega * jqt.dag(a) @ a
    jump_ops = jnp.array([jnp.sqrt(kappa) * a])
    psi0 = jqt.coherent(n, alpha)
    tsave = jnp.linspace(0.0, t_final, num_tsave)

    # run simulation
    return jqt.mesolve(jqt.ket2dm(psi0), tsave, c_ops=jump_ops, H0=H)

In [None]:
# Sweep the decay rate with the jaxquantum `map_kappa` (defined in the cell
# above, shadowing the dynamiqs one). Keep the compiled function and its output
# for reuse, and display only the shapes of the output leaves instead of the
# full arrays.
kappas = jnp.linspace(0.1, 0.6, 6)
jqt_batched_map_kappa = jit(vmap(map_kappa))
jqt_kappa_results = jqt_batched_map_kappa(kappas)
jax.tree_util.tree_map(jnp.shape, jqt_kappa_results)