# B. State purification with qubit-mediated measurement

In [1]:
# ruff: noqa
import os

os.sys.path.append("../../../..")

In [2]:
from feedback_grape.fgrape import optimize_pulse_with_feedback
import jax.numpy as jnp

## The cavity is initially in a  mixed state --> Goal is to purify the state

We are trying to maximize the property determined by $tr (\rho_{\text{cav}}^2)$ which is the purity

In the following, we consider an adaptive measurement
scheme, demonstrated in a series of experiments on Rydberg atoms interacting
with microwave cavities. In this scheme, the
cavity is coupled to an ancilla qubit, which can then be
read out to update our knowledge of the quantum state of
the cavity. <br> <br>
<p align="center">
    <img src="../../_static/notebook-images/b_tut.png" alt="image" width="350"/>
</p>


In [3]:
# initial state is a thermal state
n_average = 2
N_cavity = 30
# natural logarithm
beta = jnp.log((1 / n_average) + 1)
diags = jnp.exp(-beta * jnp.arange(N_cavity))
normalized_diags = diags / jnp.sum(diags, axis=0)
rho_cav = jnp.diag(normalized_diags)

### Now the thing is here, we don't need a rho_final because the purity or the reward that we want to maximize is $tr (\rho_{\text{cav}}^2)$.
Unlike fidelity expressions which wants to find how close to states are

## Next Step is to construct our POVM

In [4]:
from feedback_grape.utils.operators import cosm, sinm

In [5]:
from feedback_grape.utils.operators import create, destroy
import jax


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = create(N_cavity) @ destroy(N_cavity)
    angle = (gamma * number_operator) + delta / 2
    return jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )

In [6]:
# Answer: why does RNN outputs two the same, then two the same, then two the same and so on? -> calculate during forward proagation then again during back_propagation
from feedback_grape.fgrape import Gate
measure = Gate(
    gate=povm_measure_operator,
    initial_params=jax.random.uniform(
        key=jax.random.PRNGKey(42), shape=(1, 2), minval=0.0, maxval=jnp.pi
    )[0].tolist(),
    measurement_flag=True,
    # param_constraints=[
    #     [0, jnp.pi],
    #     [-2 * jnp.pi, 2 * jnp.pi],
    # ],  # if this is commented and no bounds the avg purity would be 0.91 and highest purity of a measurement sequence is 0.997
)

system_params = [measure]

result = optimize_pulse_with_feedback(
    U_0=rho_cav,
    C_target=None,
    system_params=system_params,
    num_time_steps=5,
    mode="lookup",
    goal="purity",
    max_iter=1000,
    convergence_threshold=1e-20,
    learning_rate=0.01,
    evo_type="density",
    batch_size=10,
)

2025-06-25 10:45:36.811088: E external/xla/xla/service/slow_operation_alarm.cc:73] 
********************************
[Compiling module jit_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2025-06-25 10:45:43.767346: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 2m6.962167s

********************************
[Compiling module jit_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


In [7]:
print(result.final_fidelity)

None


In [8]:
# 0.9163363647226792
print(result.final_purity)

0.9194757960201889


In [9]:
from feedback_grape.utils.purity import purity

# the highest purity can be 0.995 if the initial params that initializes the lookup table
# are between 0 and pi rather than -pi and pi
print("initial purity:", purity(rho=rho_cav))
for i, state in enumerate(result.final_state):
    print(f"Purity of state {i}:", purity(rho=state))

initial purity: 0.20000208604889932
Purity of state 0: 0.9929850865960047
Purity of state 1: 0.9912625349139808
Purity of state 2: 0.9827059473529275
Purity of state 3: 0.6203510241040638
Purity of state 4: 0.9827059473529275
Purity of state 5: 0.8252355569116118
Purity of state 6: 0.9912625349139808
Purity of state 7: 0.9827059473529275
Purity of state 8: 0.9282047975295309
Purity of state 9: 0.8973385831739333


In [10]:
result.returned_params

[[Array([[1.05297151, 0.07352308],
         [1.05297151, 0.07352308],
         [1.05297151, 0.07352308],
         [1.05297151, 0.07352308],
         [1.05297151, 0.07352308],
         [1.05297151, 0.07352308],
         [1.05297151, 0.07352308],
         [1.05297151, 0.07352308],
         [1.05297151, 0.07352308],
         [1.05297151, 0.07352308]], dtype=float64)],
 [Array([[1.11495276, 0.03728731],
         [1.56022894, 0.07388012],
         [1.11495276, 0.03728731],
         [1.56022894, 0.07388012],
         [1.11495276, 0.03728731],
         [1.56022894, 0.07388012],
         [1.56022894, 0.07388012],
         [1.11495276, 0.03728731],
         [1.11495276, 0.03728731],
         [1.56022894, 0.07388012]], dtype=float64)],
 [Array([[1.56069326, 0.11265689],
         [1.30275233, 0.04059488],
         [1.56069326, 0.11265689],
         [1.34320626, 0.24920048],
         [1.56069326, 0.11265689],
         [1.34320626, 0.24920048],
         [1.30275233, 0.04059488],
         [1.5606932