In [1]:
%reload_ext autoreload
%autoreload 2

In [8]:
import jax
import jax.numpy as jnp
import numpy as np
import qutip as qu

from data import Data
from experiments import ExperimentOneQubitTomography

In [3]:
key = jax.random.key(seed=0)
key, subkey = jax.random.split(key)

In [4]:
no_experiments = 500

initial_states = jax.random.choice(subkey, a=jnp.arange(4), shape=(no_experiments,))

key, subkey = jax.random.split(key)
measurement_basis = jax.random.choice(subkey, a=jnp.arange(3), shape=(no_experiments,))

key, subkey = jax.random.split(key)
outcomes = jax.random.choice(subkey, a=jnp.arange(2), shape=(no_experiments,))

key, subkey = jax.random.split(key)
times = jax.random.uniform(subkey, shape=(no_experiments,), minval=0, maxval=23.0)

In [5]:
experiments = ExperimentOneQubitTomography(times, initial_states, outcomes)
data = Data(experiments, outcomes)

In [6]:
print(Data(experiments, outcomes)[34])

Experiment: Time [19.670788]
Initial state [3]
Measurement basis [0] 
outcome: [0]


In [7]:
jax.vmap(lambda d: d.experiment.time)(data)

Array([21.659256  ,  8.946924  ,  5.168275  ,  2.8279095 , 12.376248  ,
       17.427198  ,  6.486207  , 16.328367  , 19.101318  ,  5.7807083 ,
       10.375083  ,  7.2965503 ,  0.67727184,  3.5101905 , 13.456909  ,
       20.552803  , 13.75807   , 19.250679  , 12.603158  , 16.941072  ,
       14.741426  ,  5.618245  ,  0.4464413 , 18.88208   ,  7.3296633 ,
       22.5204    ,  3.3687787 ,  1.7993042 , 12.539093  , 21.35176   ,
        4.8376203 , 10.229603  , 14.25875   , 21.589403  , 19.670788  ,
       12.888397  , 11.236268  , 15.295279  ,  6.808299  , 20.17367   ,
        1.698751  , 16.939709  , 18.539524  , 19.318773  ,  7.367138  ,
       16.954699  ,  8.957796  ,  3.3474858 , 10.130445  , 19.010181  ,
        0.02316284, 20.558643  ,  7.482577  , 21.13469   , 17.599424  ,
        1.1015812 , 11.865004  , 13.9847355 , 13.292115  , 10.910601  ,
        0.9375714 ,  6.0560465 , 14.333798  ,  7.3269567 ,  5.0104284 ,
       12.419705  ,  1.5937093 ,  3.9072928 , 10.936969  , 10.65

In [None]:
# Steps to do :
# Receives the data. Take the experiment and the outcome
# For each experiment, compute the probability of 0 or 1.
# Select the outcome with the outcome from data
# take the log lkl of this. (or compute first the log lkl and then select it does not matter)
# do everything with vmap and sum everything. Put the minus sign.

# For the lkl of an experiment.
# Receive the parameters. Split the hamiltonian and the dissipator
# Built the hamiltonian with the traceless shit whatever.
# Build the dissipator with the generators.
# Construct the full lindbladian
# Evolve the initial state given the experiment
# compute the probability given the measurement chosen.
# Return the number.


In [None]:
from some_functions import (
    _make_dissipator,
    _make_superop_hamiltonian,
    _Pauli_dissipators_array,
    generate_hamiltonian_one_qubit,
    generate_hermitian_matrix,
    generators_hermitian_3d_matrices,
    generators_traceless_hermitian,
)


def generate_complete_lindbladian(parameters_hamiltonian, parameters_dissipator):
    hamiltonian = generate_hamiltonian_one_qubit(
        parameters_hamiltonian, generators_traceless_hermitian
    )
    lindblad_matrix = generate_hermitian_matrix(
        parameters_dissipator, generators_hermitian_3d_matrices
    )

    hamiltonian_superop = _make_superop_hamiltonian(hamiltonian)
    dissipator_superop = _make_dissipator(lindblad_matrix, _Pauli_dissipators_array)
    lindbladian = hamiltonian_superop + dissipator_superop
    return lindbladian


In [None]:
[
    qu.basis(2, 0),
    qu.basis(2, 1),
    (qu.basis(2, 0) + qu.basis(2, 1)).unit(),
    (qu.basis(2, 0) + 1j * qu.basis(2, 1)).unit(),
]

In [64]:
_set_of_initial_states_super = jnp.array(
    [
        qu.ket2dm(ket).full().flatten()
        for ket in [
            qu.basis(2, 0),
            qu.basis(2, 1),
            (qu.basis(2, 0) + qu.basis(2, 1)).unit(),
            (qu.basis(2, 0) + 1j * qu.basis(2, 1)).unit(),
        ]
    ]
)

In [39]:
ex_rho0 = data[5].experiment.initial_state.squeeze()
ex_povm = data[5].experiment.measurement_basis.squeeze()
ex_outcome = data[5].outcome.squeeze()

In [24]:
from some_functions import pauli_projective_povm

In [43]:
data[3].experiment.time.squeeze()

Array(2.8279095, dtype=float32)

In [31]:
pauli_projective_povm.shape

(3, 2, 2, 2)

In [62]:
pauli_projective_povm_super = jnp.array(
    [a.flatten() for a in pauli_projective_povm.reshape(6, 2, 2)]
).reshape(-1, 2, 4)


In [63]:
pauli_projective_povm_super

Array([[[ 0.5+0.j ,  0.5+0.j ,  0.5+0.j ,  0.5+0.j ],
        [ 0.5+0.j , -0.5+0.j , -0.5+0.j ,  0.5+0.j ]],

       [[ 0.5+0.j ,  0. -0.5j,  0. +0.5j,  0.5+0.j ],
        [ 0.5+0.j ,  0. +0.5j,  0. -0.5j,  0.5+0.j ]],

       [[ 1. +0.j ,  0. +0.j ,  0. +0.j ,  0. +0.j ],
        [ 0. +0.j ,  0. +0.j ,  0. +0.j ,  1. +0.j ]]], dtype=complex64)

In [41]:
pauli_projective_povm[ex_povm][ex_outcome].shape

(2, 2)

In [22]:
from some_functions import compute_probability, evolve_state


def likelihood_experiment(experiment, parameters):
    initial_state_index = experiment.initial_state.squeeze()
    measurement_basis_index = experiment.measurement_basis.squeeze()
    time = experiment.time.squeeze()

    initial_state_super = _set_of_initial_states_super[initial_state_index]
    povm_super = pauli_projective_povm_super[measurement_basis_index]

    lindbladian = generate_complete_lindbladian(
        parameters.hamiltonian_pars, parameters.dissipator_pars
    )

    evolved_initial_state_super = evolve_state(lindbladian, time, initial_state_super)

    probabilities_outcomes_basis = compute_probability(
        evolved_initial_state_super, povm_super
    )
    return probabilities_outcomes_basis


Array([[0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j]], dtype=complex64)

In [None]:
def likelihood_data(data, parameters):
    experiment = data.experiment
    outcome = data.outcome.squeeze()
    probability_outcome = likelihood_experiment(experiment, parameters)[outcome]
    return probability_outcome


def neg_log_likelihood_data(data, parameters):
    minus_log_lkl = -1 * jnp.log(likelihood_data(data, parameters))
    return minus_log_lkl


In [None]:
def negative_likelihood_data(data, parameters):
    experiment = data.experiment
    outcome = data.outcome

    log_lkl_data = jax.vmap(_log_lkl_experiment)(data, parameters)
