## State Stabilization

Here the result is as follows: The algorithm optimizes the params, such that the POVM always outputs 1, implying that the measurement leaves the target state invariant. this is what we are indeed seeing when printing the measurement outcome and its probability, when batching however, the optimizer struggles to converge.

Also, lookup here is much better than nn for the same hyperparameters.

This is actually a perfect example, of feedback grape modifying the params so that a certain measurement sequence will always be output because this measurement sequence is the one that is going to lead to the best fidelity

In [1]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
from jax.scipy.linalg import expm

In [2]:
N_cav = 30

In [3]:
def qubit_unitary(alpha):
    """
    TODO: see if alpha, can be sth elser other than scalar, and if the algo understands this
    see if there can be multiple params like alpha and beta input
    """
    return expm(
        -1j
        * (
            alpha * tensor(identity(N_cav), sigmap())
            + alpha.conjugate() * tensor(identity(N_cav), sigmam())
        )
        / 2
    )

In [4]:
def qubit_cavity_unitary(beta):
    return expm(
        -1j
        * (
            beta
            * (
                tensor(destroy(N_cav), identity(2))
                @ tensor(identity(N_cav), sigmap())
            )
            + beta.conjugate()
            * (
                tensor(create(N_cav), identity(2))
                @ tensor(identity(N_cav), sigmam())
            )
        )
        / 2
    )

In [5]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

In [6]:
from feedback_grape.utils.states import coherent

alpha = 3
psi_target = tensor(
    coherent(N_cav, alpha)
    + coherent(N_cav, -alpha)
    + coherent(N_cav, 1j * alpha)
    + coherent(N_cav, -1j * alpha),
    basis(2),
)  # 4-legged state

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)
rho_target = psi_target @ psi_target.conj().T

In [7]:
rho_target.shape

(60, 60)

### It is important to test what the POVM probability is, to check if your state is normalized. if the probability is bounded between 0 and 1 then normalized

In [8]:
# TODO/Question: should one normalize within the optimization just in case?
from feedback_grape.utils.povm import (
    _probability_of_a_measurement_outcome_given_a_certain_state,
)

_probability_of_a_measurement_outcome_given_a_certain_state(
    rho_target, 1, povm_measure_operator, [0.058, jnp.pi / 2]
)

Array(0.70390199, dtype=float64)

In [9]:
from feedback_grape.utils.fidelity import fidelity

print(fidelity(U_final=rho_target, C_target=rho_target, type="density"))

1.0000000079954297


### Without dissipation

In [10]:
# Here the loss directly corressponds to the -fidelity (when converging) because log(1) is 0 and 
# the algorithm is choosing params that makes the POVM generate prob = 1
result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    parameterized_gates=[
        povm_measure_operator,
        qubit_unitary,
        qubit_cavity_unitary,
    ],
    measurement_indices=[0],
    initial_params={
        "POVM": [0.058, jnp.pi / 2],
        "U_q": [jnp.pi / 3],
        "U_qc": [jnp.pi / 3],
    },
    num_time_steps=1,
    mode="lookup",
    goal="fidelity",
    optimizer="adam",
    max_iter=1000,
    convergence_threshold=1e-16,
    learning_rate=0.01,
    type="density",
    batch_size=1,
)

Iteration 0, Loss: -0.191559
Iteration 10, Loss: -0.308840
Iteration 20, Loss: -0.386906
Iteration 30, Loss: -0.430668
Iteration 40, Loss: -0.476257
Iteration 50, Loss: -0.514217
Iteration 60, Loss: -0.549715
Iteration 70, Loss: -0.582782
Iteration 80, Loss: -0.613787
Iteration 90, Loss: -0.642515
Iteration 100, Loss: -0.669039
Iteration 110, Loss: -0.693400
Iteration 120, Loss: -0.715645
Iteration 130, Loss: -0.735849
Iteration 140, Loss: -0.754097
Iteration 150, Loss: -0.770484
Iteration 160, Loss: -0.785118
Iteration 170, Loss: -0.798112
Iteration 180, Loss: -0.809586
Iteration 190, Loss: -0.819659
Iteration 200, Loss: -0.828454
Iteration 210, Loss: -0.836089
Iteration 220, Loss: -0.842682
Iteration 230, Loss: -0.848343
Iteration 240, Loss: -0.853184
Iteration 250, Loss: -0.857337
Iteration 260, Loss: -0.861203
Iteration 270, Loss: -0.888402
Iteration 280, Loss: -0.975265
Iteration 290, Loss: -0.987029
Iteration 300, Loss: -0.988439
Iteration 310, Loss: -0.993398
Iteration 320, Loss

In [11]:
print(result.final_fidelity)

1.0000000075622455


In [12]:
result.optimized_trainable_parameters["initial_params"]

[Array([6.71204658e-17, 1.67551608e+00], dtype=float64),
 Array([1.04719755], dtype=float64),
 Array([1.04719755], dtype=float64)]

In [13]:
result.returned_params

[[Array([[6.71204658e-17, 1.67551608e+00],
         [6.71204658e-17, 1.67551608e+00],
         [6.71204658e-17, 1.67551608e+00],
         [6.71204658e-17, 1.67551608e+00],
         [6.71204658e-17, 1.67551608e+00],
         [6.71204658e-17, 1.67551608e+00],
         [6.71204658e-17, 1.67551608e+00],
         [6.71204658e-17, 1.67551608e+00],
         [6.71204658e-17, 1.67551608e+00],
         [6.71204658e-17, 1.67551608e+00]], dtype=float64),
  Array([[1.17388843e-08],
         [1.17388843e-08],
         [1.17388843e-08],
         [1.17388843e-08],
         [1.17388843e-08],
         [1.17388843e-08],
         [1.17388843e-08],
         [1.17388843e-08],
         [1.17388843e-08],
         [1.17388843e-08]], dtype=float64),
  Array([[-3.75038534e-16],
         [-3.75038534e-16],
         [-3.75038534e-16],
         [-3.75038534e-16],
         [-3.75038534e-16],
         [-3.75038534e-16],
         [-3.75038534e-16],
         [-3.75038534e-16],
         [-3.75038534e-16],
         [-3.7

In [14]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho_target, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 1.0000000079954297
fidelity of state 0: 1.0000000075622453
fidelity of state 1: 1.0000000075622453
fidelity of state 2: 1.0000000075622453
fidelity of state 3: 1.0000000075622453
fidelity of state 4: 1.0000000075622453
fidelity of state 5: 1.0000000075622453
fidelity of state 6: 1.0000000075622453
fidelity of state 7: 1.0000000075622453
fidelity of state 8: 1.0000000075622453
fidelity of state 9: 1.0000000075622453


### With Dissipation

In [15]:
# Note if tsave = jnp.linspace(0, 1, 1) = [0.0] then the decay is not applied ?
# because the first time step has the original non decayed state
result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    parameterized_gates=[
        povm_measure_operator,
        qubit_unitary,
        qubit_cavity_unitary,
    ],
    measurement_indices=[0],
    decay={
        "decay_indices": [0],  # indices of gates before which decay occurs
        # c_ops need to be tensored with the identity operator for the cavity
        # because it is used directly in the lindblad equation
        "c_ops": {
            # weird behavior when gamma is 0.00
            "tm": [tensor(identity(N_cav), jnp.sqrt(0.15) * sigmam())],
            # "tc": [tensor(identity(N_cav), jnp.sqrt(0.15) * sigmap())],
        },  # c_ops for each decay index
        "tsave": jnp.linspace(0, 1, 2),  # time grid for decay
        "Hamiltonian": None,
    },
    initial_params={
        # if POVM params is 0 and jnp.pi/3 the probability that the measurement outcome is 1 is 1.0
        # therefore the algorithm trains really well on that measurement outcome
        # this POVM, initial_params, leads to probabilities similar to the fig 6b example
        "POVM": [0.058, jnp.pi / 2],
        "U_q": [jnp.pi / 3],
        "U_qc": [jnp.pi / 3],
    },
    num_time_steps=1,
    mode="lookup",
    goal="fidelity",
    optimizer="adam",
    max_iter=1000,
    convergence_threshold=1e-6,
    learning_rate=0.02,
    type="density",
    batch_size=1,
)

Iteration 0, Loss: -0.216127
Iteration 10, Loss: -0.332174
Iteration 20, Loss: -0.522050
Iteration 30, Loss: -0.610393
Iteration 40, Loss: -0.685936
Iteration 50, Loss: -0.741057
Iteration 60, Loss: -0.788380
Iteration 70, Loss: -0.826686
Iteration 80, Loss: -0.856720
Iteration 90, Loss: -0.879411
Iteration 100, Loss: -0.895846
Iteration 110, Loss: -0.907367
Iteration 120, Loss: -0.915141
Iteration 130, Loss: -0.920202
Iteration 140, Loss: -0.923379
Iteration 150, Loss: -0.925302
Iteration 160, Loss: -0.926425
Iteration 170, Loss: -0.927056
Iteration 180, Loss: -0.927399
Iteration 190, Loss: -0.927577
Iteration 200, Loss: -0.927667
Iteration 210, Loss: -0.927709
Iteration 220, Loss: -0.927729


In [16]:
# 0.9276290167783705
print(result.final_fidelity)

0.9277343738602085


In [17]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho_target, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 1.0000000079954297
fidelity of state 0: 0.9277343738602085
fidelity of state 1: 0.9277343738602085
fidelity of state 2: 0.9277343738602085
fidelity of state 3: 0.9277343738602085
fidelity of state 4: 0.9277343738602085
fidelity of state 5: 0.9277343738602085
fidelity of state 6: 0.9277343738602085
fidelity of state 7: 0.9277343738602085
fidelity of state 8: 0.9277343738602085
fidelity of state 9: 0.9277343738602085


In [18]:
result.returned_params

[[Array([[6.30137489e-07, 1.46607620e+00],
         [6.30137489e-07, 1.46607620e+00],
         [6.30137489e-07, 1.46607620e+00],
         [6.30137489e-07, 1.46607620e+00],
         [6.30137489e-07, 1.46607620e+00],
         [6.30137489e-07, 1.46607620e+00],
         [6.30137489e-07, 1.46607620e+00],
         [6.30137489e-07, 1.46607620e+00],
         [6.30137489e-07, 1.46607620e+00],
         [6.30137489e-07, 1.46607620e+00]], dtype=float64),
  Array([[0.0096846],
         [0.0096846],
         [0.0096846],
         [0.0096846],
         [0.0096846],
         [0.0096846],
         [0.0096846],
         [0.0096846],
         [0.0096846],
         [0.0096846]], dtype=float64),
  Array([[2.98303997e-06],
         [2.98303997e-06],
         [2.98303997e-06],
         [2.98303997e-06],
         [2.98303997e-06],
         [2.98303997e-06],
         [2.98303997e-06],
         [2.98303997e-06],
         [2.98303997e-06],
         [2.98303997e-06]], dtype=float64)]]

In [19]:
result.optimized_trainable_parameters["initial_params"]

[Array([6.30137489e-07, 1.46607620e+00], dtype=float64),
 Array([1.04719755], dtype=float64),
 Array([1.04719755], dtype=float64)]

In [20]:
c_ops = {
    "a": [tensor(identity(N_cav), jnp.sqrt(0.15) * sigmam())],
    "b": [jnp.array([2])],
}

In [21]:
import jax
def prepare_parameters_from_dict(params_dict):
    """
    Convert a nested dictionary of parameters to a flat list and record shapes.

    Args:
        params_dict: Nested dictionary of parameters.

    Returns:
        tuple: Flattened parameters list and list of shapes.
    """
    res = []
    shapes = []
    for value in params_dict.values():
        flat_params = jax.tree_util.tree_leaves(value)
        res.append(jnp.array(flat_params))
        shapes.append(jnp.array(flat_params).shape[0])
    return res, shapes

In [22]:
res, _ = prepare_parameters_from_dict(c_ops)

In [23]:
res.pop(0)

Array([[[0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.        +0.j, 0.        +0.j],
        [0.38729833+0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.        +0.j, 0.        +0.j],
        [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.        +0.j, 0.        +0.j],
        ...,
        [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.        +0.j, 0.        +0.j],
        [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.        +0.j, 0.        +0.j],
        [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.38729833+0.j, 0.        +0.j]]],      dtype=complex128)