# Example E: State stabilization with SNAP gates and displacement gates

The use of feedback GRAPE applied to the Jaynes-
Cummings scenario allows us to discover strategies
extending the lifetime of a range of quantum states. How-
ever, for more complex quantum states such as kitten
states, the infidelity becomes significant after just a few
dissipative evolution steps in spite of the feedback [cf.
Fig. 6(c)]. This raises the question of whether the limited
quality of the stabilization is to be attributed to a failure
of our feedback-GRAPE learning algorithm to properly
explore the control-parameter landscape or, rather, to the
limited expressivity of the controls. With the goal of
addressing this question, we test our method on the state-
stabilization task using a more expressive control scheme.

In [19]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import sigmam, identity, cosm, sinm
from feedback_grape.utils.states import coherent, basis
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)

## Initialize states

In [20]:
from feedback_grape.utils.fidelity import ket2dm

N_cav = 30  # number of cavity modes
N_snap = 15

alpha = 2
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = ket2dm(psi_target)

rho_target = tensor(rho_target, ket2dm(basis(2)))

In [21]:
# Parity Operator
from feedback_grape.utils.operators import create, destroy


def parity_operator(N_cav):
    return tensor(
        jax.scipy.linalg.expm(1j * jnp.pi * (create(N_cav) @ destroy(N_cav))),
        identity(2),
    )

In [22]:
# Confirm that the kitten2 state has an even parity
parity_op = parity_operator(N_cav)
parity_check = jnp.isclose(
    jnp.trace((parity_op @ rho_target) @ rho_target), 1.0
)
print("Parity check for the kitten2 state:", parity_check)
print("parity_check trace :", jnp.real(jnp.trace(parity_op @ rho_target)))

Parity check for the kitten2 state: True
parity_check trace : 1.0000000000000013


## Initialize the parameterized Gates

In [23]:
def displacement_gate(alpha_re, alpha_im):
    """Displacement operator for a coherent state."""
    alpha = alpha_re + 1j * alpha_im
    gate = jax.scipy.linalg.expm(
        alpha * create(N_cav) - alpha.conj() * destroy(N_cav)
    )
    return tensor(gate, identity(2))


def displacement_gate_dag(alpha_re, alpha_im):
    """Displacement operator for a coherent state."""
    alpha = alpha_re + 1j * alpha_im
    gate = (
        jax.scipy.linalg.expm(
            alpha * create(N_cav) - alpha.conj() * destroy(N_cav)
        )
        .conj()
        .T
    )
    return tensor(gate, identity(2))

In [24]:
# TODO/ QUESTION: see whether to remove the dereferencing thing or keep it that way?
def snap_gate(
    phase0,
    phase1,
    phase2,
    phase3,
    phase4,
    phase5,
    phase6,
    phase7,
    phase8,
    phase9,
    phase10,
    phase11,
    phase12,
    phase13,
    phase14,
):
    phase_list = [
        phase0,
        phase1,
        phase2,
        phase3,
        phase4,
        phase5,
        phase6,
        phase7,
        phase8,
        phase9,
        phase10,
        phase11,
        phase12,
        phase13,
        phase14,
    ]
    diags = jnp.ones(shape=(N_cav - len(phase_list)))
    exponentiated = jnp.exp(1j * jnp.array(phase_list))
    diags = jnp.concatenate((exponentiated, diags))
    return tensor(jnp.diag(diags), identity(2))

In [25]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

## Initialize RNN of choice

In [26]:
import flax.linen as nn


# You can do whatever you want inside so long as you maintaing the hidden_size and output size shapes
class RNN(nn.Module):
    hidden_size: int  # number of features in the hidden state
    output_size: int  # number of features in the output ( 2 in the case of gamma and beta)

    @nn.compact
    def __call__(self, measurement, hidden_state):
        """
        If your GRU has a hidden state increasing number of features in the hidden stateH means:

        - You're allowing the model to store more information across time steps

        - Each time step can represent more complex features, patterns, or dependencies

        - You're giving the GRU more representational capacity
        """
        gru_cell = nn.GRUCell(
            features=self.hidden_size,
            gate_fn=nn.sigmoid,
            activation_fn=nn.tanh,
        )
        self.make_rng('dropout')

        if measurement.ndim == 1:
            measurement = measurement.reshape(1, -1)

        new_hidden_state, _ = gru_cell(hidden_state, measurement)
        new_hidden_state = nn.Dropout(rate=0.1, deterministic=False)(
            new_hidden_state
        )
        # this returns the povm_params after linear regression through the hidden state which contains
        # the information of the previous time steps and this is optimized to output best povm_params
        # new_hidden_state = nn.Dense(features=self.hidden_size)(new_hidden_state)
        new_hidden_state = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(new_hidden_state)
        new_hidden_state = nn.relu(new_hidden_state)
        new_hidden_state = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(new_hidden_state)
        new_hidden_state = nn.relu(new_hidden_state)
        output = nn.Dense(
            features=self.output_size,
            kernel_init=nn.initializers.glorot_uniform(),
            bias_init=nn.initializers.constant(0.1),
        )(new_hidden_state)
        output = nn.relu(output)
        # output = jnp.asarray(output)
        return output[0], new_hidden_state

### In this notebook, we decreased the convergence threshold and evaluate for num_time_steps = 2

In [27]:
# Note if tsave = jnp.linspace(0, 1, 1) = [0.0] then the decay is not applied ?
# because the first time step has the original non decayed state
key = jax.random.PRNGKey(42)
snap_init = jax.random.uniform(
    key, shape=(N_snap,), minval=-jnp.pi, maxval=jnp.pi
)
# TODO/QUESTION: In documentation, clarify that the initial_params are the params up to the
# point where measurement occurs, compared with other modes where the initial_params
# are the initial params for the entire system for all time steps.
measure = {
    "gate": povm_measure_operator,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": True,
    # "param_constraints": [[0, 0.5], [-1, 1]],
}

displacement = {
    "gate": displacement_gate,
    "initial_params": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "measurement_flag": False,
}

snap = {
    "gate": snap_gate,
    "initial_params": snap_init.tolist(),
    "measurement_flag": False,
}

displacement_dag = {
    "gate": displacement_gate_dag,
    "initial_params": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "measurement_flag": False,
}

system_params = [measure, displacement, snap, displacement_dag]


result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    decay={
        "decay_indices": [0, 1],  # indices of gates before which decay occurs
        "c_ops": {
            "tm": [tensor(identity(N_cav), jnp.sqrt(0.005) * sigmam())],
            "tc": [tensor(identity(N_cav), jnp.sqrt(0.005) * sigmam())],
        },
        "tsave": jnp.linspace(0, 1, 2),  # time grid for decay
        "Hamiltonian": None,
    },
    system_params=system_params,
    num_time_steps=2,
    mode="nn",
    goal="fidelity",
    max_iter=1000,
    convergence_threshold=1e-6,
    learning_rate=0.09,
    type="density",
    batch_size=16,
    rnn=RNN,
    rnn_hidden_size=30,
)

Iteration 0, Loss: 0.051286
Iteration 10, Loss: -0.046407
Iteration 20, Loss: -0.289864
Iteration 30, Loss: -0.796614
Iteration 40, Loss: -0.857959
Iteration 50, Loss: -0.955282
Iteration 60, Loss: -0.748941
Iteration 70, Loss: -0.908621
Iteration 80, Loss: -0.785943
Iteration 90, Loss: -0.416823
Iteration 100, Loss: -0.676020
Iteration 110, Loss: -0.240745
Iteration 120, Loss: -0.233151
Iteration 130, Loss: -0.682477
Iteration 140, Loss: -0.281225
Iteration 150, Loss: -0.315091
Iteration 160, Loss: -0.413254
Iteration 170, Loss: -0.227406
Iteration 180, Loss: -0.266969
Iteration 190, Loss: -0.964452
Iteration 200, Loss: -0.973934
Iteration 210, Loss: -0.555330
Iteration 220, Loss: -0.968880
Iteration 230, Loss: -0.779439
Iteration 240, Loss: -0.953672
Iteration 250, Loss: -0.983785
Iteration 260, Loss: -0.958988
Iteration 270, Loss: -0.615746
Iteration 280, Loss: -0.330535
Iteration 290, Loss: -0.257443
Iteration 300, Loss: -0.659959
Iteration 310, Loss: -0.798333
Iteration 320, Loss:

In [28]:
result

FgResult(optimized_trainable_parameters={'initial_params': [Array([ 1.97485858e-04, -2.51347818e+00], dtype=float64), Array([-0.46038428, -3.04728161], dtype=float64), Array([-0.46038428, -3.04728161,  0.46586312,  1.44399472,  1.17998434,
        0.90438751,  0.90876566,  0.58447858, -1.65424795, -1.49014079,
       -1.62133312, -2.78063235,  1.41394885,  2.00049328,  0.55503047],      dtype=float64), Array([-0.46038428, -3.04728161], dtype=float64)], 'rnn_params': {'params': {'Dense_0': {'bias': Array([-0.86135095, -0.28204468, -0.27452636, -0.35027426, -0.64815074,
       -0.3532077 , -0.9384373 ,  0.4386696 , -0.4282595 , -0.46899855,
       -0.64997715, -0.18722911, -0.0946504 , -0.8911167 , -0.70567155,
       -0.7837897 , -0.7694482 , -0.5528712 , -0.49137455, -0.83089715,
       -0.54747313, -0.49697948, -0.88991284, -0.9998346 , -0.2879899 ,
       -0.7133515 ,  0.3154305 , -0.50793445, -0.49798992,  0.05102238],      dtype=float32), 'kernel': Array([[-7.65240133e-01, -5.54185

In [29]:
result.final_fidelity

Array(0.98983774, dtype=float64)

In [30]:
result.final_state.shape

(10, 60, 60)

In [31]:
from feedback_grape.utils.fidelity import ket2dm

N_cav = 30  # number of cavity modes
N_snap = 15

alpha = 2
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = ket2dm(psi_target)

rho_target = tensor(rho_target, ket2dm(basis(2)))

In [32]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho_target, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 1.0000000191119465
fidelity of state 0: 0.9896874628446899
fidelity of state 1: 0.9896897540287319
fidelity of state 2: 0.9894047072413638
fidelity of state 3: 0.9898964331514902
fidelity of state 4: 0.98981801110817
fidelity of state 5: 0.9899699879774597
fidelity of state 6: 0.9899626413311895
fidelity of state 7: 0.9900407902354398
fidelity of state 8: 0.9900467965501678
fidelity of state 9: 0.9898608063615215
