## State Stabilization

Here the result is as follows: The algorithm optimizes the params, such that the POVM always outputs 1, implying that the measurement leaves the target state invariant. this is what we are indeed seeing when printing the measurement outcome and its probability, when batching however, the optimizer struggles to converge.

Also, lookup here is much better than nn for the same hyperparameters.

This is actually a perfect example, of feedback grape modifying the params so that a certain measurement sequence will always be output because this measurement sequence is the one that is going to lead to the best fidelity

In [1]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

In [2]:
N_cav = 30

# Here, dividing alpha into real and imaginary parts complicates the optimization and converges at 0.89 while if we do not use the imaginary part it converges at 0.999

In [3]:
def qubit_unitary(alpha_re):
    alpha = alpha_re
    return tensor(
        identity(N_cav),
        expm(-1j * (alpha * sigmap() + alpha.conjugate() * sigmam()) / 2),
    )

In [4]:
def qubit_cavity_unitary(beta_re):
    beta = beta_re
    return expm(
        -1j
        * (
            beta * (tensor(destroy(N_cav), sigmap()))
            + beta.conjugate() * (tensor(create(N_cav), sigmam()))
        )
        / 2
    )

In [5]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

In [6]:
from feedback_grape.utils.states import coherent

alpha = 3
psi_target = tensor(
    coherent(N_cav, alpha)
    + coherent(N_cav, -alpha)
    + coherent(N_cav, 1j * alpha)
    + coherent(N_cav, -1j * alpha),
    basis(2),
)  # 4-legged state

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)
rho_target = psi_target @ psi_target.conj().T

In [7]:
rho_target.shape

(60, 60)

### It is important to test what the POVM probability is, to check if your state is normalized. if the probability is bounded between 0 and 1 then normalized

In [8]:
# TODO/Question: should one normalize within the optimization just in case?
from feedback_grape.utils.povm import (
    _probability_of_a_measurement_outcome_given_a_certain_state,
)

_probability_of_a_measurement_outcome_given_a_certain_state(
    rho_target, 1, povm_measure_operator, [0.058, jnp.pi / 2]
)

Array(0.70390199, dtype=float64)

In [9]:
from feedback_grape.utils.fidelity import fidelity

print(fidelity(U_final=rho_target, C_target=rho_target, evo_type="density"))

1.000000011680078


### Without dissipation

In [10]:
# Here the loss directly corressponds to the -fidelity (when converging) because log(1) is 0 and
# the algorithm is choosing params that makes the POVM generate prob = 1
measure = {
    "gate": povm_measure_operator,
    "initial_params": [0.058, jnp.pi / 2],  # gamma and delta
    "measurement_flag": True,
    # "param_constraints": [[0, 0.5], [-1, 1]],
}

qub_unitary = {
    "gate": qubit_unitary,
    "initial_params": [jnp.pi / 3],
    "measurement_flag": False,
    # "param_constraints": [[0, 0.5], [-1, 1]],
}

qub_cav = {
    "gate": qubit_cavity_unitary,
    "initial_params": [jnp.pi / 3],
    "measurement_flag": False,
    # "param_constraints": [[0, 0.5], [-1, 1]],
}

system_params = [measure, qub_unitary, qub_cav]
result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    system_params=system_params,
    num_time_steps=1,
    mode="lookup",
    goal="fidelity",
    max_iter=1000,
    convergence_threshold=1e-16,
    learning_rate=0.02,
    evo_type="density",
    batch_size=1,
)

Iteration 0, Loss: 0.021114
Iteration 10, Loss: -0.065730
Iteration 20, Loss: -0.292843
Iteration 30, Loss: -0.327637
Iteration 40, Loss: -0.467265
Iteration 50, Loss: -0.484865
Iteration 60, Loss: -0.625497
Iteration 70, Loss: -0.665717
Iteration 80, Loss: -0.353591
Iteration 90, Loss: -0.603921
Iteration 100, Loss: -0.613936
Iteration 110, Loss: -0.664313
Iteration 120, Loss: -0.593153
Iteration 130, Loss: -0.653803
Iteration 140, Loss: -0.674439
Iteration 150, Loss: -0.681150
Iteration 160, Loss: -0.258173
Iteration 170, Loss: -0.255651
Iteration 180, Loss: -0.376554
Iteration 190, Loss: -0.599526
Iteration 200, Loss: 0.514135
Iteration 210, Loss: 0.389031
Iteration 220, Loss: 0.593266
Iteration 230, Loss: -0.515546
Iteration 240, Loss: -0.637913
Iteration 250, Loss: -0.475294
Iteration 260, Loss: -0.657642
Iteration 270, Loss: -0.441192
Iteration 280, Loss: -0.188770
Iteration 290, Loss: -0.607850
Iteration 300, Loss: -0.609874
Iteration 310, Loss: -0.641626
Iteration 320, Loss: -0

In [11]:
print(result.final_fidelity)

0.999995162213917


In [12]:
result.optimized_trainable_parameters["initial_params"]

[Array([1.06202532e-09, 1.46607657e+00], dtype=float64),
 Array([1.04719755], dtype=float64),
 Array([1.04719755], dtype=float64)]

In [13]:
result.returned_params

[[Array([[1.06202532e-09, 1.46607657e+00],
         [1.06202532e-09, 1.46607657e+00],
         [1.06202532e-09, 1.46607657e+00],
         [1.06202532e-09, 1.46607657e+00],
         [1.06202532e-09, 1.46607657e+00],
         [1.06202532e-09, 1.46607657e+00],
         [1.06202532e-09, 1.46607657e+00],
         [1.06202532e-09, 1.46607657e+00],
         [1.06202532e-09, 1.46607657e+00],
         [1.06202532e-09, 1.46607657e+00]], dtype=float64),
  Array([[-0.00622567],
         [-0.00622567],
         [-0.00622567],
         [-0.00622567],
         [-0.00622567],
         [-0.00622567],
         [-0.00622567],
         [-0.00622567],
         [-0.00622567],
         [-0.00622567]], dtype=float64),
  Array([[-4.05896188e-05],
         [-4.05896188e-05],
         [-4.05896188e-05],
         [-4.05896188e-05],
         [-4.05896188e-05],
         [-4.05896188e-05],
         [-4.05896188e-05],
         [-4.05896188e-05],
         [-4.05896188e-05],
         [-4.05896188e-05]], dtype=float64)]

In [14]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho_target, evo_type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, evo_type="density"),
    )

initial fidelity: 1.000000011680078
fidelity of state 0: 0.9999951622139169
fidelity of state 1: 0.9999951622139169
fidelity of state 2: 0.9999951622139169
fidelity of state 3: 0.9999951622139169
fidelity of state 4: 0.9999951622139169
fidelity of state 5: 0.9999951622139169
fidelity of state 6: 0.9999951622139169
fidelity of state 7: 0.9999951622139169
fidelity of state 8: 0.9999951622139169
fidelity of state 9: 0.9999951622139169


### With Dissipation

In [15]:
# Note if tsave = jnp.linspace(0, 1, 1) = [0.0] then the decay is not applied ?
# because the first time step has the original non decayed state
result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    decay={
        "decay_indices": [0],  # indices of gates before which decay occurs
        # c_ops need to be tensored with the identity operator for the cavity
        # because it is used directly in the lindblad equation
        "c_ops": {
            # weird behavior when gamma is 0.00
            "tm": [tensor(identity(N_cav), jnp.sqrt(0.15) * sigmam())],
            # "tc": [tensor(identity(N_cav), jnp.sqrt(0.15) * sigmap())],
        },  # c_ops for each decay index
        "tsave": jnp.linspace(0, 1, 2),  # time grid for decay
        "Hamiltonian": None,
    },
    system_params=system_params,
    num_time_steps=1,
    mode="lookup",
    goal="fidelity",
    max_iter=1000,
    convergence_threshold=1e-6,
    learning_rate=0.02,
    evo_type="density",
    batch_size=1,
)

Iteration 0, Loss: 0.023975
Iteration 10, Loss: -0.065218
Iteration 20, Loss: -0.277716
Iteration 30, Loss: -0.322778
Iteration 40, Loss: -0.425977
Iteration 50, Loss: -0.448958
Iteration 60, Loss: -0.580473
Iteration 70, Loss: -0.620405
Iteration 80, Loss: -0.321951
Iteration 90, Loss: -0.563551
Iteration 100, Loss: -0.567752
Iteration 110, Loss: -0.606977
Iteration 120, Loss: -0.546509
Iteration 130, Loss: -0.602735
Iteration 140, Loss: -0.624824
Iteration 150, Loss: -0.632265
Iteration 160, Loss: -0.240146
Iteration 170, Loss: -0.214720
Iteration 180, Loss: -0.363248
Iteration 190, Loss: -0.532644
Iteration 200, Loss: -0.836740
Iteration 210, Loss: 0.083543
Iteration 220, Loss: -0.562332
Iteration 230, Loss: -0.619736
Iteration 240, Loss: -0.329103
Iteration 250, Loss: -0.437918
Iteration 260, Loss: -0.494462
Iteration 270, Loss: -0.597588
Iteration 280, Loss: -0.601478
Iteration 290, Loss: -0.629192
Iteration 300, Loss: -0.601520
Iteration 310, Loss: -0.573487
Iteration 320, Loss: 

In [16]:
# 0.9276290167783705
print(result.final_fidelity)

0.9275978401176824


In [17]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho_target, evo_type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, evo_type="density"),
    )

initial fidelity: 1.000000011680078
fidelity of state 0: 0.9275978401176822
fidelity of state 1: 0.9275978401176822
fidelity of state 2: 0.9275978401176822
fidelity of state 3: 0.9275978401176822
fidelity of state 4: 0.9275978401176822
fidelity of state 5: 0.9275978401176822
fidelity of state 6: 0.9275978401176822
fidelity of state 7: 0.9275978401176822
fidelity of state 8: 0.9275978401176822
fidelity of state 9: 0.9275978401176822


In [18]:
result.returned_params

[[Array([[-0.00741632,  1.06379157],
         [-0.00741632,  1.06379157],
         [-0.00741632,  1.06379157],
         [-0.00741632,  1.06379157],
         [-0.00741632,  1.06379157],
         [-0.00741632,  1.06379157],
         [-0.00741632,  1.06379157],
         [-0.00741632,  1.06379157],
         [-0.00741632,  1.06379157],
         [-0.00741632,  1.06379157]], dtype=float64),
  Array([[-2.49457138e-05],
         [-2.49457138e-05],
         [-2.49457138e-05],
         [-2.49457138e-05],
         [-2.49457138e-05],
         [-2.49457138e-05],
         [-2.49457138e-05],
         [-2.49457138e-05],
         [-2.49457138e-05],
         [-2.49457138e-05]], dtype=float64),
  Array([[-0.00126234],
         [-0.00126234],
         [-0.00126234],
         [-0.00126234],
         [-0.00126234],
         [-0.00126234],
         [-0.00126234],
         [-0.00126234],
         [-0.00126234],
         [-0.00126234]], dtype=float64)]]

In [19]:
result.optimized_trainable_parameters["initial_params"]

[Array([-0.00741632,  1.06379157], dtype=float64),
 Array([1.04719755], dtype=float64),
 Array([1.04719755], dtype=float64)]

In [20]:
c_ops = {
    "a": [tensor(identity(N_cav), jnp.sqrt(0.15) * sigmam())],
    "b": [jnp.array([2])],
}

In [21]:
import jax


def prepare_parameters_from_dict(params_dict):
    """
    Convert a nested dictionary of parameters to a flat list and record shapes.

    Args:
        params_dict: Nested dictionary of parameters.

    Returns:
        tuple: Flattened parameters list and list of shapes.
    """
    res = []
    shapes = []
    for value in params_dict.values():
        flat_params = jax.tree_util.tree_leaves(value)
        res.append(jnp.array(flat_params))
        shapes.append(jnp.array(flat_params).shape[0])
    return res, shapes

In [22]:
res, _ = prepare_parameters_from_dict(c_ops)

In [23]:
res.pop(0)

Array([[[0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.        +0.j, 0.        +0.j],
        [0.38729833+0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.        +0.j, 0.        +0.j],
        [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.        +0.j, 0.        +0.j],
        ...,
        [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.        +0.j, 0.        +0.j],
        [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.        +0.j, 0.        +0.j],
        [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
         0.        +0.j, 0.38729833+0.j, 0.        +0.j]]],      dtype=complex128)