# Example E: State stabilization with SNAP gates and displacement gates

The use of feedback GRAPE applied to the Jaynes-
Cummings scenario allows us to discover strategies
extending the lifetime of a range of quantum states. How-
ever, for more complex quantum states such as kitten
states, the infidelity becomes significant after just a few
dissipative evolution steps in spite of the feedback [cf.
Fig. 6(c)]. This raises the question of whether the limited
quality of the stabilization is to be attributed to a failure
of our feedback-GRAPE learning algorithm to properly
explore the control-parameter landscape or, rather, to the
limited expressivity of the controls. With the goal of
addressing this question, we test our method on the state-
stabilization task using a more expressive control scheme.

In [2]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import sigmam, identity, cosm, sinm
from feedback_grape.utils.states import coherent, basis
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)

## Initialize states

In [3]:
from feedback_grape.utils.fidelity import ket2dm

N_cav = 30  # number of cavity modes
N_snap = 15

alpha = 2
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = ket2dm(psi_target)

rho_target = tensor(rho_target, ket2dm(basis(2)))

In [4]:
# Parity Operator
from feedback_grape.utils.operators import create, destroy


def parity_operator(N_cav):
    return tensor(
        jax.scipy.linalg.expm(1j * jnp.pi * (create(N_cav) @ destroy(N_cav))),
        identity(2),
    )

In [5]:
# Confirm that the kitten2 state has an even parity
parity_op = parity_operator(N_cav)
parity_check = jnp.isclose(
    jnp.trace((parity_op @ rho_target) @ rho_target), 1.0
)
print("Parity check for the kitten2 state:", parity_check)
print("parity_check trace :", jnp.real(jnp.trace(parity_op @ rho_target)))

Parity check for the kitten2 state: True
parity_check trace : 1.0000000000000013


## Initialize the parameterized Gates

In [6]:
def displacement_gate(alpha_re, alpha_im):
    """Displacement operator for a coherent state."""
    alpha = alpha_re + 1j * alpha_im
    gate = jax.scipy.linalg.expm(
        alpha * create(N_cav) - alpha.conj() * destroy(N_cav)
    )
    return tensor(gate, identity(2))


def displacement_gate_dag(alpha_re, alpha_im):
    """Displacement operator for a coherent state."""
    alpha = alpha_re + 1j * alpha_im
    gate = (
        jax.scipy.linalg.expm(
            alpha * create(N_cav) - alpha.conj() * destroy(N_cav)
        )
        .conj()
        .T
    )
    return tensor(gate, identity(2))

In [7]:
def snap_gate(
    phase0,
    phase1,
    phase2,
    phase3,
    phase4,
    phase5,
    phase6,
    phase7,
    phase8,
    phase9,
    phase10,
    phase11,
    phase12,
    phase13,
    phase14,
):
    phase_list = [
        phase0,
        phase1,
        phase2,
        phase3,
        phase4,
        phase5,
        phase6,
        phase7,
        phase8,
        phase9,
        phase10,
        phase11,
        phase12,
        phase13,
        phase14,
    ]
    diags = jnp.ones(shape=(N_cav - len(phase_list)))
    exponentiated = jnp.exp(1j * jnp.array(phase_list))
    diags = jnp.concatenate((exponentiated, diags))
    return tensor(jnp.diag(diags), identity(2))

In [8]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

## Initialize RNN of choice

In [9]:
import flax.linen as nn


class RNN(nn.Module):
    hidden_size: int  # number of features in the hidden state
    output_size: int  # number of features in the output ( 2 in the case of gamma and beta)

    @nn.compact
    def __call__(self, measurement, hidden_state):
        """
        If your GRU has a hidden state increasing number of features in the hidden stateH means:

        - You're allowing the model to store more information across time steps

        - Each time step can represent more complex features, patterns, or dependencies

        - You're giving the GRU more representational capacity
        """
        gru_cell = nn.GRUCell(
            features=self.hidden_size,
            gate_fn=nn.sigmoid,
            activation_fn=nn.tanh,
        )
        self.make_rng('dropout')

        if measurement.ndim == 1:
            measurement = measurement.reshape(1, -1)

        new_hidden_state, _ = gru_cell(hidden_state, measurement)
        new_hidden_state = nn.Dropout(rate=0.1, deterministic=False)(
            new_hidden_state
        )
        # this returns the povm_params after linear regression through the hidden state which contains
        # the information of the previous time steps and this is optimized to output best povm_params
        # new_hidden_state = nn.Dense(features=self.hidden_size)(new_hidden_state)
        new_hidden_state = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(new_hidden_state)
        new_hidden_state = nn.relu(new_hidden_state)
        new_hidden_state = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(new_hidden_state)
        new_hidden_state = nn.relu(new_hidden_state)
        output = nn.Dense(
            features=self.output_size,
            kernel_init=nn.initializers.glorot_uniform(),
            bias_init=nn.initializers.constant(0.1),
        )(new_hidden_state)
        output = nn.relu(output)
        # output = jnp.asarray(output)
        return output[0], new_hidden_state

In [10]:
# Note if tsave = jnp.linspace(0, 1, 1) = [0.0] then the decay is not applied ?
# because the first time step has the original non decayed state
key = jax.random.PRNGKey(42)
snap_init = jax.random.uniform(
    key, shape=(N_snap,), minval=-jnp.pi, maxval=jnp.pi
)


result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    parameterized_gates=[
        povm_measure_operator,
        displacement_gate,
        snap_gate,
        displacement_gate_dag,
    ],
    measurement_indices=[0],
    decay={
        "decay_indices": [0, 1],  # indices of gates before which decay occurs
        "c_ops": {
            "tm": [tensor(identity(N_cav), jnp.sqrt(0.005) * sigmam())],
            "tc": [tensor(identity(N_cav), jnp.sqrt(0.005) * sigmam())],
        },
        "tsave": jnp.linspace(0, 1, 2),  # time grid for decay
        "Hamiltonian": None,
    },
    initial_params={
        "POVM": jax.random.uniform(
            key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
        )[0].tolist(),
        "D": jax.random.uniform(
            key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
        )[0].tolist(),
        "Snap": snap_init.tolist(),
        "D_dag": jax.random.uniform(
            key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
        )[0].tolist(),
    },
    num_time_steps=2,
    mode="lookup",
    goal="fidelity",
    max_iter=1000,
    convergence_threshold=1e-6,
    learning_rate=0.01,
    type="density",
    batch_size=16,
)

parameter shapes: [2, 2, 15, 2]
Number of parameters: 21
Iteration 0, Loss: 0.021928
Iteration 10, Loss: 0.068104
Iteration 20, Loss: 0.123502
Iteration 30, Loss: 0.116336
Iteration 40, Loss: 0.106777
Iteration 50, Loss: 0.118837
Iteration 60, Loss: 0.157045
Iteration 70, Loss: 0.187944
Iteration 80, Loss: 0.196117
Iteration 90, Loss: 0.215340
Iteration 100, Loss: 0.153767
Iteration 110, Loss: 0.243403
Iteration 120, Loss: 0.215006
Iteration 130, Loss: 0.226804
Iteration 140, Loss: 0.248245
Iteration 150, Loss: 0.231428
Iteration 160, Loss: 0.238630
Iteration 170, Loss: 0.230646
Iteration 180, Loss: 0.225896
Iteration 190, Loss: 0.286477
Iteration 200, Loss: 0.277647
Iteration 210, Loss: 0.229370
Iteration 220, Loss: 0.251809
Iteration 230, Loss: 0.222857
Iteration 240, Loss: 0.266373
Iteration 250, Loss: 0.245919
Iteration 260, Loss: 0.255272
Iteration 270, Loss: 0.331532
Iteration 280, Loss: 0.239328
Iteration 290, Loss: 0.291380
Iteration 300, Loss: 0.240949
Iteration 310, Loss: 0.2

In [11]:
result

FgResult(optimized_trainable_parameters={'initial_params': [Array([-0.4578815 , -3.14982266], dtype=float64), Array([-0.46038428, -3.04728161], dtype=float64), Array([-0.46038428, -3.04728161,  0.46586312,  1.44399472,  1.17998434,
        0.90438751,  0.90876566,  0.58447858, -1.65424795, -1.49014079,
       -1.62133312, -2.78063235,  1.41394885,  2.00049328,  0.55503047],      dtype=float64), Array([-0.46038428, -3.04728161], dtype=float64)], 'lookup_table': [[Array([1.54603278, 0.20808653, 2.48406743, 0.93286689, 2.65952632,
       0.79730518, 3.31906547, 0.79875239, 3.86400982, 4.04317131,
       5.33494901, 1.09421009, 4.28318359, 3.16261449, 4.24029802,
       4.67213644, 2.37102287, 2.19780994, 2.58248716, 2.62612895,
       2.67499731], dtype=float64), Array([ 1.39226034,  0.1309011 ,  2.08088821, -0.03401347,  3.90569858,
        1.25302374,  1.73385025,  0.61406551,  0.89131889,  4.3911209 ,
        4.44655996,  0.32176911,  4.7685982 ,  7.37266084,  2.07202918,
        3.830

In [12]:
result.final_fidelity

Array(0.75986007, dtype=float64)

In [13]:
result.final_state.shape

(10, 60, 60)

In [14]:
from feedback_grape.utils.fidelity import ket2dm

N_cav = 30  # number of cavity modes
N_snap = 15

alpha = 2
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = ket2dm(psi_target)

rho_target = tensor(rho_target, ket2dm(basis(2)))

In [15]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho_target, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 1.0000000139428749
fidelity of state 0: 0.7796323880584769
fidelity of state 1: 0.7278535162067473
fidelity of state 2: 0.8405094870895429
fidelity of state 3: 0.7278535162067473
fidelity of state 4: 0.7421784260620066
fidelity of state 5: 0.7278535162067473
fidelity of state 6: 0.7421784260620066
fidelity of state 7: 0.8405094870895429
fidelity of state 8: 0.7278535162067473
fidelity of state 9: 0.7421784260620066
