In [1]:
import jax 
from jax import jit
import jax.numpy as jnp
import numpy as np
import qutip as qu

In [2]:
seed = 1
key = jax.random.PRNGKey(seed)

In [3]:
key, s1, s2 = jax.random.split(key, 3)
a = jax.random.uniform(s1, [4 ,4]) - 0.5 + jax.random.uniform(s2, [4 ,4])-0.5

In [4]:
jax.random.split(key, 3)

Array([[ 849246744, 3571547783],
       [ 571865015,  580183114],
       [ 303579492, 3206762241]], dtype=uint32)

In [5]:

@jit
def vec(a):
    return jnp.expand_dims(a.flatten("F"), 1)

@jit
def v_trace(a):
    """
    trace of vectorized  operator (n x n) a
    :param a: vectorized operator
    :return: a dot vec(I)
    """
    one = jnp.identity(
        int(np.sqrt((a).shape))
    )

    one_v = vec(jnp.identity(int(np.sqrt(vec(example_rho.full()).shape[0]))))
    return jnp.dot(vec(one), a)


@jit
def v_evolve(t, L, rho_0v):
    rho = jax.scipy.linalg.expm(L*t)@rho_0v
    return rho/v_trace(rho)

@jit
def compute_p(rho_v, E_v):
    return jnp.clip(jnp.real(jnp.dot(rho_v, E_v)), 0.0, 1.0)
    


In [6]:
example_rho = qu.rand_dm_ginibre(2, seed=seed)
example_L = qu.rand_super(N = 2, seed=seed)

Array([[1.],
       [0.],
       [0.],
       [1.]], dtype=float32)

In [122]:
v_evolve(5.3, example_L.full(), vec(example_rho.full()))

Array([0.49860966+4.4668400e-10j, 0.00312532-3.4105554e-03j,
       0.00312532+3.4105391e-03j, 0.50139046-4.4668402e-10j],      dtype=complex64)

In [128]:
qu.operator_to_vector(example_rho).full().shape

(4, 1)

In [126]:
np.isclose(vec(example_rho.full()), qu.operator_to_vector(example_rho))

array([[ True, False, False, False],
       [False,  True, False, False],
       [False, False,  True, False],
       [False, False, False,  True]])

In [135]:
jnp.expand_dims(vec(example_rho.full()), axis= 1).shape

(4, 1)