# Example E: State stabilization with SNAP gates and displacement gates

The use of feedback GRAPE applied to the Jaynes-
Cummings scenario allows us to discover strategies
extending the lifetime of a range of quantum states. How-
ever, for more complex quantum states such as kitten
states, the infidelity becomes significant after just a few
dissipative evolution steps in spite of the feedback [cf.
Fig. 6(c)]. This raises the question of whether the limited
quality of the stabilization is to be attributed to a failure
of our feedback-GRAPE learning algorithm to properly
explore the control-parameter landscape or, rather, to the
limited expressivity of the controls. With the goal of
addressing this question, we test our method on the state-
stabilization task using a more expressive control scheme.

In [2]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import sigmam, identity, cosm, sinm
from feedback_grape.utils.states import coherent, basis
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)

## Initialize states

In [3]:
from feedback_grape.utils.fidelity import ket2dm

N_cav = 30  # number of cavity modes
N_snap = 15

alpha = 2
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = ket2dm(psi_target)

rho_target = tensor(rho_target, ket2dm(basis(2)))

In [4]:
# Parity Operator
from feedback_grape.utils.operators import create, destroy


def parity_operator(N_cav):
    return tensor(
        jax.scipy.linalg.expm(1j * jnp.pi * (create(N_cav) @ destroy(N_cav))),
        identity(2),
    )

In [5]:
# Confirm that the kitten2 state has an even parity
parity_op = parity_operator(N_cav)
parity_check = jnp.isclose(
    jnp.trace((parity_op @ rho_target) @ rho_target), 1.0
)
print("Parity check for the kitten2 state:", parity_check)
print("parity_check trace :", jnp.real(jnp.trace(parity_op @ rho_target)))

Parity check for the kitten2 state: True
parity_check trace : 1.0000000000000013


## Initialize the parameterized Gates

In [6]:
def displacement_gate(alpha_re, alpha_im):
    """Displacement operator for a coherent state."""
    alpha = alpha_re + 1j * alpha_im
    gate = jax.scipy.linalg.expm(
        alpha * create(N_cav) - alpha.conj() * destroy(N_cav)
    )
    return tensor(gate, identity(2))


def displacement_gate_dag(alpha_re, alpha_im):
    """Displacement operator for a coherent state."""
    alpha = alpha_re + 1j * alpha_im
    gate = (
        jax.scipy.linalg.expm(
            alpha * create(N_cav) - alpha.conj() * destroy(N_cav)
        )
        .conj()
        .T
    )
    return tensor(gate, identity(2))

In [7]:
def snap_gate(
    phase0,
    phase1,
    phase2,
    phase3,
    phase4,
    phase5,
    phase6,
    phase7,
    phase8,
    phase9,
    phase10,
    phase11,
    phase12,
    phase13,
    phase14,
):
    phase_list = [
        phase0,
        phase1,
        phase2,
        phase3,
        phase4,
        phase5,
        phase6,
        phase7,
        phase8,
        phase9,
        phase10,
        phase11,
        phase12,
        phase13,
        phase14,
    ]
    diags = jnp.ones(shape=(N_cav - len(phase_list)))
    exponentiated = jnp.exp(1j * jnp.array(phase_list))
    diags = jnp.concatenate((exponentiated, diags))
    return tensor(jnp.diag(diags), identity(2))

In [8]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

## Initialize RNN of choice

In [9]:
import flax.linen as nn


class RNN(nn.Module):
    hidden_size: int  # number of features in the hidden state
    output_size: int  # number of features in the output ( 2 in the case of gamma and beta)

    @nn.compact
    def __call__(self, measurement, hidden_state):
        """
        If your GRU has a hidden state increasing number of features in the hidden stateH means:

        - You're allowing the model to store more information across time steps

        - Each time step can represent more complex features, patterns, or dependencies

        - You're giving the GRU more representational capacity
        """
        gru_cell = nn.GRUCell(
            features=self.hidden_size,
            gate_fn=nn.sigmoid,
            activation_fn=nn.tanh,
        )
        self.make_rng('dropout')

        if measurement.ndim == 1:
            measurement = measurement.reshape(1, -1)

        new_hidden_state, _ = gru_cell(hidden_state, measurement)
        new_hidden_state = nn.Dropout(rate=0.1, deterministic=False)(
            new_hidden_state
        )
        # this returns the povm_params after linear regression through the hidden state which contains
        # the information of the previous time steps and this is optimized to output best povm_params
        # new_hidden_state = nn.Dense(features=self.hidden_size)(new_hidden_state)
        new_hidden_state = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(new_hidden_state)
        new_hidden_state = nn.relu(new_hidden_state)
        new_hidden_state = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(new_hidden_state)
        new_hidden_state = nn.relu(new_hidden_state)
        output = nn.Dense(
            features=self.output_size,
            kernel_init=nn.initializers.glorot_uniform(),
            bias_init=nn.initializers.constant(0.1),
        )(new_hidden_state)
        output = nn.relu(output)
        # output = jnp.asarray(output)
        return output[0], new_hidden_state

### In this notebook we evaluate for time_steps = 100

In [None]:
# Note if tsave = jnp.linspace(0, 1, 1) = [0.0] then the decay is not applied ?
# because the first time step has the original non decayed state
key = jax.random.PRNGKey(42)
snap_init = jax.random.uniform(
    key, shape=(N_snap,), minval=-jnp.pi, maxval=jnp.pi
)

measure = {
    "gate": povm_measure_operator,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": True,
    # "param_constrains": [[0, 0.5], [-1, 1]],
}

displacement = {
    "gate": displacement_gate,
    "initial_params": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "measurement_flag": False,
}

snap = {
    "gate": snap_gate,
    "initial_params": snap_init.tolist(),
    "measurement_flag": False,
}

displacement_dag = {
    "gate": displacement_gate_dag,
    "initial_params": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "measurement_flag": False,
}

system_params = [measure, displacement, snap, displacement_dag]


result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    decay={
        "decay_indices": [0, 1],  # indices of gates before which decay occurs
        "c_ops": {
            "tm": [tensor(identity(N_cav), jnp.sqrt(0.005) * sigmam())],
            "tc": [tensor(identity(N_cav), jnp.sqrt(0.005) * sigmam())],
        },
        "tsave": jnp.linspace(0, 1, 2),  # time grid for decay
        "Hamiltonian": None,
    },
    system_params=system_params,
    num_time_steps=2,
    mode="nn",
    goal="fidelity",
    max_iter=1000,
    convergence_threshold=1e-6,
    learning_rate=0.01,
    type="density",
    batch_size=16,
    rnn=RNN,
    rnn_hidden_size=30
)

parameter shapes: [2, 2, 15, 2]
Number of parameters: 21
Iteration 0, Loss: 0.061050
Iteration 10, Loss: 0.232090
Iteration 20, Loss: 0.014440
Iteration 30, Loss: 0.088678
Iteration 40, Loss: 0.055143
Iteration 50, Loss: 0.050234
Iteration 60, Loss: 0.225698
Iteration 70, Loss: -0.058120
Iteration 80, Loss: 0.194732
Iteration 90, Loss: 0.188085
Iteration 100, Loss: -0.023787
Iteration 110, Loss: 0.116117
Iteration 120, Loss: 0.155788
Iteration 130, Loss: 0.229707
Iteration 140, Loss: 0.093308
Iteration 150, Loss: 0.048008
Iteration 160, Loss: -0.026538
Iteration 170, Loss: -0.257134
Iteration 180, Loss: -0.261263
Iteration 190, Loss: -0.282137
Iteration 200, Loss: -0.543713
Iteration 210, Loss: -0.977180
Iteration 220, Loss: -0.963586
Iteration 230, Loss: -0.987092
Iteration 240, Loss: -0.515571
Iteration 250, Loss: -0.651303
Iteration 260, Loss: -0.725734
Iteration 270, Loss: -0.863907
Iteration 280, Loss: -0.972941
Iteration 290, Loss: -0.970709
Iteration 300, Loss: -0.984676
Iterati

In [11]:
result

FgResult(optimized_trainable_parameters={'initial_params': [Array([-1.53516836e-04, -3.14005769e+00], dtype=float64), Array([-0.46038428, -3.04728161], dtype=float64), Array([-0.46038428, -3.04728161,  0.46586312,  1.44399472,  1.17998434,
        0.90438751,  0.90876566,  0.58447858, -1.65424795, -1.49014079,
       -1.62133312, -2.78063235,  1.41394885,  2.00049328,  0.55503047],      dtype=float64), Array([-0.46038428, -3.04728161], dtype=float64)], 'rnn_params': {'params': {'Dense_0': {'bias': Array([-0.28734857, -0.20091279, -0.17574432, -0.07674688, -0.09007517,
       -0.13788211,  0.01112831, -0.01104879, -0.14902   , -0.16531822,
       -0.03819051, -0.06022503, -0.05114665, -0.18305182, -0.06655017,
       -0.05708452, -0.18409555, -0.10593934, -0.13076983, -0.1865322 ,
       -0.16929407, -0.16696413, -0.03470135, -0.04131583, -0.03627504,
       -0.03043098, -0.11653399, -0.17078209, -0.09991261, -0.01972033],      dtype=float32), 'kernel': Array([[-2.60665983e-01,  1.93471

In [12]:
result.final_fidelity

Array(0.6063228, dtype=float64)

In [13]:
result.final_state.shape

(10, 60, 60)

In [14]:
from feedback_grape.utils.fidelity import ket2dm

N_cav = 30  # number of cavity modes
N_snap = 15

alpha = 2
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = ket2dm(psi_target)

rho_target = tensor(rho_target, ket2dm(basis(2)))

In [15]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho_target, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 1.0000000139428749
fidelity of state 0: 0.6063498271324191
fidelity of state 1: 0.6062127139247991
fidelity of state 2: 0.6063457295727779
fidelity of state 3: 0.6063517799148108
fidelity of state 4: 0.6063145887121253
fidelity of state 5: 0.6063317057791199
fidelity of state 6: 0.6063245798395572
fidelity of state 7: 0.6063252234058198
fidelity of state 8: 0.6063173922805738
fidelity of state 9: 0.6063544609284945
