In [1]:
# Automatically reload imported project modules (e.g. some_functions) before each
# cell runs, so edits to the .py source are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [8]:
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import qutip as qu
from jax.scipy.linalg import expm

from some_functions import *
from some_functions import (
    _make_superop_hamiltonian,
    _make_dissipator,
    _make_pauli_dissipator,
    generate_hamiltonian_one_qubit,
    generate_hermitian_matrix,
    _Pauli_dissipators_array
)

# An n x n Hermitian matrix has n**2 real parameters (n is the dimension).

In [9]:
# Seed JAX's PRNG once and split off a subkey for the parameter draw below;
# `key` is the carried-forward key for any later draws.
key, subkey = jax.random.split(jax.random.key(seed=0))

In [10]:
# Build a random one-qubit Lindbladian (Hamiltonian superoperator + dissipator)
# from 12 uniformly drawn real parameters, then sanity-check it.
# NOTE(review): 12 presumably = 3 (traceless 2x2 Hermitian Hamiltonian) + 9
# (3x3 Hermitian dissipator matrix), matching the n**2 parameter count; confirm
# against OneQubitParameters.
example_initial_pars = jax.random.uniform(subkey, shape=(12,))

pars = OneQubitParameters(2, example_initial_pars)

hamiltonian = generate_hamiltonian_one_qubit(
    pars.hamiltonian_pars, generators_traceless_hermitian
)

dissipator_matrix = generate_hermitian_matrix(
    pars.dissipator_pars, generators_hermitian_3d_matrices
)

hamiltonian_superop = _make_superop_hamiltonian(hamiltonian)

dissipator = _make_dissipator(dissipator_matrix, _Pauli_dissipators_array)

lindbladian = hamiltonian_superop + dissipator

# Trace preservation: d/dt Tr(rho) = vec(I)^T L vec(rho) must vanish for every rho,
# i.e. vec(I)^T L == 0. For a 2x2 identity, vec(I) is the same under row-major and
# column-major vectorization, so this check is independent of that convention.
vec_identity = jnp.eye(2).flatten()
assert jnp.allclose(vec_identity @ lindbladian, 0, atol=1e-5), "Lindbladian is not trace-preserving"

lindbladian

[[ 0.20480433+0.j         -0.5412418 +0.8710841j  -0.5412419 -0.8710841j
   0.97443646+0.j        ]
 [ 2.0635958 +0.01971477j -1.1999061 -0.6452427j   0.15700915-0.65097016j
  -0.11242938-0.00890517j]
 [ 2.0635958 -0.01971471j  0.15700915+0.65097016j -1.1999061 +0.6452427j
  -0.11242938+0.00890511j]
 [-0.20480433+0.j          0.5412418 -0.8710841j   0.5412419 +0.8710841j
  -0.97443646+0.j        ]]


In [11]:
# Check that compute_probability on vectorized operators reproduces QuTiP's
# expectation values for the six Pauli-basis projectors (I ± sigma_k) / 2, k = x, y, z.
initial_rho = qu.rand_dm(2, seed=128)
# Row-major (C-order) vectorization; the POVM elements below are flattened the same
# way, so both vectors share one index ordering.
# NOTE(review): QuTiP's operator_to_vector (and the evolve_state check later) uses
# column stacking (order='F'); confirm which convention a consumer expects before
# reusing initial_rho_super elsewhere.
initial_rho_super = initial_rho.full().flatten()

# Unnormalized effects I + sigma and I - sigma per Pauli axis (x, y, z);
# the factor 1/2 is applied to the expectation values below.
qutip_povm = [
    [qu.identity(2) + sigma, qu.identity(2) - sigma]
    for sigma in (qu.sigmax(), qu.sigmay(), qu.sigmaz())
]

# Flattened in (axis, outcome) order to line up with pauli_projective_povm[i, j].
probs_qu = np.array(
    [qu.expect(effect, initial_rho) for basis in qutip_povm for effect in basis]
) / 2

probs_mine = np.array([
    compute_probability(initial_rho_super, pauli_projective_povm[i, j].flatten())
    for i in range(3)
    for j in range(2)
])

# Fail loudly under Restart & Run All instead of only displaying a boolean array;
# np.allclose uses the same default tolerances as np.isclose.
assert np.allclose(probs_mine, probs_qu), "compute_probability disagrees with QuTiP"

np.isclose(probs_mine, probs_qu)

array([ True,  True,  True,  True,  True,  True])

In [155]:
# Test evolve_state (matrix exponential of a superoperator applied to a state) against QuTiP

In [156]:
# qu.rand_super draws a random superoperator. It is only a stand-in generator here
# (not necessarily a valid Lindbladian); the point is to compare exp(t * L) @ vec(rho)
# between QuTiP and evolve_state.
qu_lindbladian = qu.rand_super(2, seed=0)
jax_lindbladian = qu_lindbladian.full()

# NOTE(review): `time` shadows the stdlib module name if it is ever imported here;
# kept because later cells may reference it.
time = 3.5
# Reference result: exp(t L) applied to QuTiP's vectorized rho.
evolved_state_qu = ((time * qu_lindbladian).expm() * qu.operator_to_vector(initial_rho)).full()

# order='F' (column stacking) matches the vectorization used by qu.operator_to_vector.
evolved_state_jax = evolve_state(jax_lindbladian, time, initial_rho.full().flatten(order='F'))

# Fail loudly under Restart & Run All instead of only displaying a boolean array.
assert np.allclose(evolved_state_qu.squeeze(), evolved_state_jax.squeeze()), "evolve_state disagrees with QuTiP"

np.isclose(evolved_state_qu.squeeze(), evolved_state_jax.squeeze())

array([ True,  True,  True,  True])

In [157]:
# Testing generation of the Lindbladian