In [1]:
from jax import jit, vmap, grad
from jax.lax import scan
from jax.experimental.ode import odeint
import jaxquantum as jqt
import jax.numpy as jnp

# QuTIP

In [2]:
import qutip as qt
import numpy as np

In [3]:
def gate_step_qutip(gate, p0, H0, c_ops):
    H1 = gate[0]
    ts = gate[1]

    H = [H0, [H1, 'cos(t)']]
    opts = qt.Options(rhs_reuse=True)
    output = qt.mesolve(H, p0, ts, c_ops, options=opts)
    return output.states

def test_qutip(p0, _):
    N_size = 50
    H1 = qt.num(N_size) + 0.0j
    ts = np.linspace(0, 1, 101)
    gate = (H1, ts)
    H0 = qt.destroy(N_size) + qt.create(N_size)
    c_ops = np.array([])
    results = gate_step_qutip(gate, p0, H0, c_ops)
    return results[-1], results


def test_multi_qutip(p0):
    for j in range(100):
        p0, _ = test_qutip(p0, None)
    return p0

In [4]:
p0 = qt.ket2dm(qt.coherent(50,20))

In [5]:
%timeit -n1 -r1 test_qutip(p0, None)
%timeit test_qutip(p0, None)

1.91 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
21.7 ms ± 269 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
%timeit -n1 -r1 test_multi_qutip(p0)
%timeit test_multi_qutip(p0)

2.09 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
2.06 s ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# JAX

In [7]:
@jit
def gate_step_jax(gate, p0, H0, c_ops):
    H1 = gate[0]
    ts = gate[1]
    
    def H_func(rho, t, H0_val, H1_val, c_ops_val):
        H = H0_val + H1_val*jnp.cos(t)
        rho_dot = -1.0j*(H @ rho - rho @ H)
        return rho_dot
        
    states = odeint(H_func, p0, ts, H0, H1, c_ops)
    return states

@jit
def test_jax(p0, _):
    N_size = 50
    H1 = jqt.num(N_size) + 0.0j
    ts = jnp.linspace(0, 1, 101)
    gate = (H1, ts)
    H0 = jqt.destroy(N_size) + jqt.create(N_size) + 0.0j
    c_ops = jnp.array([])
    results = gate_step_jax(gate, p0, H0, c_ops)
    return results[-1], results

@jit
def test_multi_jax(p0):
    return scan(test_jax, p0, None, length=100)

In [8]:
p0 = jqt.ket2dm(jqt.coherent(50,20))

In [9]:
%timeit -n1 -r1 test_jax(p0 + 0.0j, None)
%timeit test_jax(p0 + 0.0j, None)

381 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
50.7 ms ± 223 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%timeit -n1 -r1 test_multi_jax(p0 + 0.0j)
%timeit test_multi_jax(p0 + 0.0j)

5.51 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
5.11 s ± 35.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Diffrax

In [11]:
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

In [12]:
@jit
def gate_step_diffrax(gate, p0, H0, c_ops):
    H1 = gate[0]
    ts = gate[1]
    
    def H_func(t, rho, args):
        H0_val = args[0]
        H1_val = args[1]
        c_ops_val = args[2]

        H = H0_val + H1_val*jnp.cos(t)
        rho_dot = -1.0j*(H @ rho - rho @ H)
        return rho_dot

    term = ODETerm(H_func)
    solver = Dopri5()
    saveat = SaveAt(ts=ts)
    stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

    sol = diffeqsolve(term, solver, t0=ts[0], t1=ts[-1], dt0=ts[1]-ts[0], y0=p0, saveat=saveat,
                    stepsize_controller=stepsize_controller, args=[H0, H1, c_ops])

    print(sol.ts)  # DeviceArray([0.   , 1.   , 2.   , 3.    ])
    print(sol.ys)  # DeviceArray([1.   , 0.368, 0.135, 0.0498])
    
    states = sol.ys
    return states + 0.0j

@jit
def test_diffrax(p0, _):
    N_size = 50
    H1 = jqt.num(N_size) + 0.0j
    ts = jnp.linspace(0, 1, 101)
    gate = (H1, ts)
    H0 = (jqt.destroy(N_size) + 0.0j) + (jqt.create(N_size) + 0.0j)
    c_ops = jnp.array([])
    results = gate_step_diffrax(gate, p0, H0, c_ops)
    return results[-1], results

@jit
def test_multi_diffrax(p0):
    return scan(test_diffrax, p0, None, length=100)

In [13]:
p0 = jqt.ket2dm(jqt.coherent(50,20))

In [14]:
%timeit -n1 -r1 test_diffrax(p0 + 0.0j, None)
%timeit test_diffrax(p0 + 0.0j, None)

  return lax_internal._convert_element_type(out, dtype, weak_type)


Traced<ShapedArray(float64[101], weak_type=True)>with<DynamicJaxprTrace(level=0/2)>
Traced<ShapedArray(float64[101,50,50], weak_type=True)>with<DynamicJaxprTrace(level=0/2)>
909 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
10.9 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
%timeit -n1 -r1 test_multi_diffrax(p0 + 0.0j)
%timeit test_multi_diffrax(p0 + 0.0j)

Traced<ShapedArray(float64[101], weak_type=True)>with<DynamicJaxprTrace(level=1/3)>
Traced<ShapedArray(float64[101,50,50], weak_type=True)>with<DynamicJaxprTrace(level=1/3)>


  return lax_internal._convert_element_type(out, dtype, weak_type)


1.99 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
1.05 s ± 7.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
# vector_field = lambda t, y, args: -y
# term = ODETerm(vector_field)
# solver = Dopri5()
# saveat = SaveAt(ts=[0., 1., 2., 3.])
# stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

# sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
#                   stepsize_controller=stepsize_controller)

# print(sol.ts)  # DeviceArray([0.   , 1.   , 2.   , 3.    ])
# print(sol.ys)  # DeviceArray([1.   , 0.368, 0.135, 0.0498])

In [20]:
# from diffrax import diffeqsolve, ODETerm, Dopri5
# import jax.numpy as jnp

# def f(t, y, args):
#     return -y

# term = ODETerm(f)
# solver = Dopri5()
# y0 = jnp.array([2., 3.])
# solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)