In [72]:
%reload_ext autoreload
%autoreload 2

In [74]:
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import qutip as qu
from jax.scipy.linalg import expm

from some_functions import _G, _dag, _spost, _spre, _sprepost, canonical_povm

# Number of real parametes hermitian matrix: n**2 (n is dim)

In [43]:
def _make_pauli_dissipator(A, B):
    return _sprepost(A, B) - 0.5 * (_spre(A @ B) + _spost(A @ B))


dissipators_list = []
for g_i in _G[1:]:
    aux_list = []
    for g_j in _G[1:]:
        aux_list.append(_make_pauli_dissipator(g_i, g_j))
    dissipators_list.append(aux_list)

_Pauli_dissipators_array = jnp.array(dissipators_list, jnp.complex64)

In [4]:
generators_traceless_hermitian = jnp.array(
    [qu.sigmax().full(), qu.sigmay().full(), qu.sigmaz().full()]
)

generators_hermitian_3d_matrices = np.array(
    [
        np.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]),
        np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]]),
        np.array([[0, 0, 0], [0, 0, 0], [0, 0, 1]]),
        np.array([[0, 1, 0], [1, 0, 0], [0, 0, 0]]),
        np.array([[0, 0, 1], [0, 0, 0], [1, 0, 0]]),
        np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0]]),
        1j * np.array([[0, -1, 0], [1, 0, 0], [0, 0, 0]]),
        1j * np.array([[0, 0, -1], [0, 0, 0], [1, 0, 0]]),
        1j * np.array([[0, 0, 0], [0, 0, -1], [0, 1, 0]]),
    ]
)


In [107]:
from jax.experimental import checkify


class OneQubitParameters(eqx.Module):
    d: int
    N: int
    parameters: jnp.array

    def __init__(self, dimension_system, pars):
        self.d = dimension_system
        self.N = self.d**2
        self.parameters = self.set_pars(pars)

    @property
    def n_indep_hamiltonian(self):
        return self.d**2 - 1

    @property
    def n_indep_dissipator(self):
        return (self.N - 1) ** 2

    def set_pars(self, pars):
        parameters = jnp.zeros([self.n_indep_dissipator + self.n_indep_hamiltonian])
        parameters = parameters.at[:].set(pars)
        return parameters

    @property
    def hamiltonian_pars(self):
        return self.parameters[0 : self.n_indep_hamiltonian]

    @property
    def dissipator_pars(self):
        return self.parameters[self.n_indep_hamiltonian :]


@jax.jit
def generate_hamiltonian_one_qubit(
    hamiltonian_parameters, generators_traceless_hermitian
):
    return jnp.einsum("i, ijk", hamiltonian_parameters, generators_traceless_hermitian)


@jax.jit
def generate_hermitian_matrix(parameters, generators_hermitian):
    return jnp.einsum("i, ijk", parameters, generators_hermitian)


def _make_dissipator(dissipator_matrix, pauli_dissipators):
    return jnp.einsum("ij, ijmn-> mn", dissipator_matrix, pauli_dissipators)


def _make_superop_hamiltonian(hamiltonian_matrix, hbar=1):
    return -1j / hbar * (_spre(hamiltonian_matrix) - _spost(hamiltonian_matrix))


is_probability_correct = lambda p: jnp.logical_and((p >= 0.0), (p <= 1.0))
trim_invalid_probs = lambda prob_array: jnp.where(
    is_probability_correct(prob_array), prob_array, jnp.abs(prob_array) * 0
)

trim_nan_probs = lambda prob_array: jnp.where(
    ~jnp.isnan(prob_array), prob_array, jnp.abs(prob_array) * 0
)


def clean_probabilities(prob_array):
    return trim_nan_probs(trim_invalid_probs(prob_array))


def evolve_state(lindbladian, time, rho_super):
    return expm(lindbladian * time) @ rho_super


def compute_probability(rho_super, povm_super):
    return clean_probabilities(jnp.dot(_dag(rho_super), povm_super)).real


In [13]:
key = jax.random.key(seed=0)
key, subkey = jax.random.split(key)

In [62]:
example_initial_pars = jax.random.uniform(subkey, shape=(12,))

pars = OneQubitParameters(2, example_initial_pars)

hamiltonian = generate_hamiltonian_one_qubit(
    pars.hamiltonian_pars, generators_traceless_hermitian
)

dissipator_matrix = generate_hermitian_matrix(
    pars.dissipator_pars, generators_hermitian_3d_matrices
)

hamiltonian_superop = _make_superop_hamiltonian(hamiltonian)

dissipator = _make_dissipator(dissipator_matrix, _Pauli_dissipators_array)

lindbladian = hamiltonian_superop + dissipator
print(lindbladian)

[[-0.97443646+0.j         -1.516825  +0.876489j   -1.516825  -0.8764889j
   0.97443646+0.j        ]
 [ 1.0880125 +0.01430997j -1.1999061 -0.6452427j   0.15700915-0.65097016j
  -1.0880125 -0.01430997j]
 [ 1.0880125 -0.01430997j  0.15700915+0.65097016j -1.1999061 +0.6452427j
  -1.0880125 +0.01430997j]
 [-0.20480433+0.j         -0.4343413 -0.8656794j  -0.4343413 +0.8656793j
   0.20480433+0.j        ]]


In [110]:
initial_rho = qu.rand_dm(2, seed=128)
initial_rho_super = initial_rho.full().flatten()

qutip_povm = [
    [
        qu.identity(2) + qu.sigmax(),
        qu.identity(2) - qu.sigmax(),
    ],
    [
        qu.identity(2) + qu.sigmay(),
        qu.identity(2) - qu.sigmay(),
    ],
    [
        qu.identity(2) + qu.sigmaz(),
        qu.identity(2) - qu.sigmaz(),
    ],
]

probs_qu = []
for i in range(3):
    for j in range(2):
        probs_qu.append(qu.expect(qutip_povm[i][j], initial_rho))

probs_qu = np.array(probs_qu) / 2


probs_mine = []
for i in range(3):
    for j in range(2):
        probs_mine.append(
            compute_probability(initial_rho_super, canonical_povm[i, j].flatten())
        )

probs_mine = np.array(probs_mine)

np.isclose(probs_mine, probs_qu)

array([ True,  True,  True,  True,  True,  True])