# E: State stabilization with SNAP gates and displacement gates

The use of feedback GRAPE applied to the Jaynes-Cummings
scenario allows us to discover strategies extending the
lifetime of a range of quantum states. However, for more
complex quantum states such as kitten states, the infidelity
becomes significant after just a few dissipative evolution
steps in spite of the feedback [cf. Fig. 6(c)]. This raises
the question of whether the limited quality of the
stabilization is to be attributed to a failure of our
feedback-GRAPE learning algorithm to properly explore the
control-parameter landscape or, rather, to the limited
expressivity of the controls. With the goal of addressing
this question, we test our method on the state-stabilization
task using a more expressive control scheme.

In [6]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse, Decay, Gate, evaluate_on_longer_time
from feedback_grape.utils.operators import cosm, sinm, identity
from feedback_grape.utils.states import coherent
import jax.numpy as jnp
import jax
from feedback_grape.utils.operators import create, destroy
from feedback_grape.utils.fidelity import ket2dm
from library.utils.FgResult_to_dict import FgResult_to_dict
from tqdm import tqdm
import json

# Enable 64-bit types in JAX; the gate initializers below request dtype=jnp.float64.
jax.config.update("jax_enable_x64", True)

## Initialize states

In [7]:
# Training parameters
N_training_iterations = 1000 # Number of training iterations
learning_rate = 0.001 # Learning rate
convergence_threshold = 1e-6 # Convergence threshold for early stopping

# Each entry of `experiments` is a tuple (num_time_steps, lut_depth, reward_weights):
#num_time_steps : Number of time steps in the control pulse
#lut_depth : Depth of the lookup table for feedback
#reward_weights: Weights for the reward at each time step. Default only weights last timestep [0, 0, ... 0, 1]

# Number of runs per experiment; the run index is also used as the PRNG seed for the initial gate parameters.
samples = 5
experiments = [ # (timesteps, lut_depth, reward_weights)
    #(1,1,[1]),
    #(2,1,[1,1]),
    #(2,2,[0,1]),
    #(2,2,[1,1]),
    #(3,3,[0,0,1]),
    (5,3,[1,1,1,1,1]),
    #(4,4,[0,0,0,1]),
    #(4,3,[0,0,0,1]),
    #(4,3,[0,0,1,1]),
    #(4,3,[1,1,1,1]),
    #(4,4,[1,1,1,1]),
]

# Physical parameters
# Size passed to coherent/create/destroy/identity below
# (presumably the Fock-space truncation of the cavity - TODO confirm).
N_cav = 30  # number of cavity modes
# Number of trainable SNAP phases; snap_gate puts them on the first N_snap
# diagonal entries and pads the remaining N_cav - N_snap entries with ones.
N_snap = 15

# Target: superposition of the coherent states at +alpha and -alpha (not yet normalized here).
alpha = 2
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)

# Density matrix of the target state; used below both as initial state and as target.
rho_target = ket2dm(psi_target)

# Every weight is written into the output file names as one character
# (see the "".join(str(w) ...) used when saving results), so each weight must
# be a plain single-digit integer. Raise instead of asserting so the check is
# not stripped when Python runs with -O.
for _, _, weights in experiments:
    for w in weights:
        # Exact type check on purpose: bool (and other int subclasses) would
        # pass isinstance() but str(True) is not a single digit.
        if type(w) is not int or not 0 <= w <= 9:
            raise ValueError(
                "Weights must be integers between 0 and 9 to save as single character in filename."
            )

## Initialize the parameterized Gates

In [8]:
def displacement_gate(alphas):
    """Return expm(alpha * create(N_cav) - conj(alpha) * destroy(N_cav)).

    ``alphas`` holds the real and the imaginary part of the complex
    amplitude ``alpha``, in that order.
    """
    re_part, im_part = alphas
    amplitude = re_part + 1j * im_part
    generator = amplitude * create(N_cav) - amplitude.conj() * destroy(N_cav)
    return jax.scipy.linalg.expm(generator)

def initialize_displacement_gate(key):
    """Wrap ``displacement_gate`` in a non-measurement ``Gate``.

    The two initial parameters are sampled with ``jax.random.uniform``
    between -pi/2 and pi/2 using the PRNG ``key``.
    """
    start_params = jax.random.uniform(
        key, shape=(2,), minval=-jnp.pi / 2, maxval=jnp.pi / 2, dtype=jnp.float64
    )
    return Gate(
        gate=displacement_gate,
        initial_params=start_params,
        measurement_flag=False,
    )

def displacement_gate_dag(alphas):
    """Return the conjugate transpose of ``displacement_gate(alphas)``."""
    forward = displacement_gate(alphas)
    return forward.conj().T

def initialize_displacement_gate_dag(key):
    """Wrap ``displacement_gate_dag`` in a non-measurement ``Gate``.

    The two initial parameters are sampled with ``jax.random.uniform``
    between -pi/2 and pi/2 using the PRNG ``key``.
    """
    start_params = jax.random.uniform(
        key, shape=(2,), minval=-jnp.pi / 2, maxval=jnp.pi / 2, dtype=jnp.float64
    )
    return Gate(
        gate=displacement_gate_dag,
        initial_params=start_params,
        measurement_flag=False,
    )

def snap_gate(phase_list):
    """Return a diagonal matrix with exp(1j * phase) entries padded with ones.

    The first ``len(phase_list)`` diagonal entries carry the phases; the
    remaining ``N_cav - len(phase_list)`` entries are one.
    """
    phase_factors = jnp.exp(1j * jnp.array(phase_list))
    padding = jnp.ones(shape=(N_cav - len(phase_list)))
    return jnp.diag(jnp.concatenate((phase_factors, padding)))

def initialize_snap_gate(key):
    """Wrap ``snap_gate`` in a non-measurement ``Gate``.

    The ``N_snap`` initial phases are sampled with ``jax.random.uniform``
    between -pi/2 and pi/2 using the PRNG ``key``.
    """
    start_phases = jax.random.uniform(
        key, shape=(N_snap,), minval=-jnp.pi / 2, maxval=jnp.pi / 2, dtype=jnp.float64
    )
    return Gate(
        gate=snap_gate,
        initial_params=start_phases,
        measurement_flag=False,
    )

def povm_measure_operator(measurement_outcome, params):
    """
    Measurement operator Mm for the cavity state
    ( NOT the POVM element Em = Mm_dag @ Mm ).

    ``params`` is the pair (gamma, delta). With
    A = gamma * create @ destroy + delta * identity / 2, the result is
    cosm(A) where measurement_outcome == 1 and sinm(A) otherwise.
    """
    gamma, delta = params
    number_term = gamma * create(N_cav) @ destroy(N_cav)
    offset_term = delta * identity(N_cav) / 2
    angle = number_term + offset_term
    return jnp.where(measurement_outcome == 1, cosm(angle), sinm(angle))

def initialize_povm_gate(key):
    """Wrap ``povm_measure_operator`` in a measurement ``Gate``.

    The two initial parameters (gamma and delta) are sampled with
    ``jax.random.uniform`` between -pi/2 and pi/2 using the PRNG ``key``.
    """
    start_params = jax.random.uniform(
        key,
        shape=(2,),  # gamma and delta
        minval=-jnp.pi / 2,
        maxval=jnp.pi / 2,
        dtype=jnp.float64,
    )
    return Gate(
        gate=povm_measure_operator,
        initial_params=start_params,
        measurement_flag=True,
    )

# Decay step with a single collapse operator sqrt(0.005) * destroy(N_cav).
# NOTE(review): 0.005 is presumably the photon-loss probability/rate per step - confirm against Decay's definition.
decay_gate = Decay(c_ops=[jnp.sqrt(0.005) * destroy(N_cav)])

def initialize_system_params(key):
    """Build the list of steps applied in one time step.

    ``key`` is split into four sub-keys, one per trainable gate. The order
    is: decay, POVM measurement, decay, displacement, SNAP, and the
    conjugate-transposed displacement gate.
    """
    povm_key, disp_key, snap_key, disp_dag_key = jax.random.split(key, 4)

    measurement = initialize_povm_gate(povm_key)
    displacement = initialize_displacement_gate(disp_key)
    snap = initialize_snap_gate(snap_key)
    displacement_dag = initialize_displacement_gate_dag(disp_dag_key)

    return [
        decay_gate,
        measurement,
        decay_gate,
        displacement,
        snap,
        displacement_dag,
    ]

## Initialize RNN of choice

In [9]:
import flax.linen as nn


# You can do whatever you want inside so long as you maintain the hidden_size and output_size shapes
class RNN(nn.Module):
    """GRU cell followed by dropout and three dense layers.

    Called with a measurement and a hidden state, it returns the first row
    of the final dense output together with a new hidden state.
    """

    hidden_size: int  # number of features in the hidden state
    output_size: int  # number of features in the output (inferred from the number of parameters) just provide those attributes to the class

    @nn.compact
    def __call__(self, measurement, hidden_state):
        """Run one step and return ``(output[0], new_hidden_state)``.

        Note that the returned hidden state is the tensor after the two
        hidden dense layers and ReLU, not the raw output of the GRU cell.
        """

        # Add a leading axis of size 1 to 1-D measurements.
        if measurement.ndim == 1:
            measurement = measurement.reshape(1, -1)

        ###############
        ### Free to change whatever you want below as long as hidden layers have size self.hidden_size
        ### and output layer has size self.output_size
        ###############

        gru_cell = nn.GRUCell(
            features=self.hidden_size,
            gate_fn=nn.sigmoid,
            activation_fn=nn.tanh,
        )
        # NOTE(review): the key returned here is discarded; presumably the call is kept so
        # that a 'dropout' RNG stream has to be supplied - confirm before removing.
        self.make_rng('dropout')

        new_hidden_state, _ = gru_cell(hidden_state, measurement)
        # deterministic=False is hard-coded, so dropout (rate 0.1) is applied on every call.
        # NOTE(review): confirm this is also intended when the trained model is evaluated.
        new_hidden_state = nn.Dropout(rate=0.1, deterministic=False)(
            new_hidden_state
        )
        # The dense layers below map the hidden state, which carries the information of the
        # previous time steps, to the output parameters being optimized.
        # new_hidden_state = nn.Dense(features=self.hidden_size)(new_hidden_state)
        new_hidden_state = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(new_hidden_state)
        new_hidden_state = nn.relu(new_hidden_state)
        new_hidden_state = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(new_hidden_state)
        new_hidden_state = nn.relu(new_hidden_state)
        output = nn.Dense(
            features=self.output_size,
            kernel_init=nn.initializers.glorot_uniform(),
            bias_init=nn.initializers.constant(0.1),
        )(new_hidden_state)
        # The final ReLU makes every output value non-negative.
        output = nn.relu(output)

        ###############
        ### Do not change the return statement
        ###############

        # output[0] selects the first row along the leading axis.
        return output[0], new_hidden_state

In [None]:
# Create the output directories up front so that a finished (and expensive)
# training run is not lost to a FileNotFoundError when its results are saved.
os.makedirs("./optimized_architectures", exist_ok=True)
os.makedirs("./evaluation_results", exist_ok=True)

# For every experiment configuration, train `samples` lookup-table (LUT)
# controllers from different random initializations, save each optimized
# result as JSON, then evaluate it on a 1000-step run and save the fidelity
# at each time step.
for num_time_steps, lut_depth, reward_weights in experiments:
    # One character per weight; used in the output file names.
    weights_str = "".join(str(w) for w in reward_weights)

    for s in range(samples):
        print(f"Evaluating num_time_steps={num_time_steps}, lut_depth={lut_depth}, reward_weights={reward_weights}")
        print("Training LUT")

        # The sample index doubles as the PRNG seed for the initial gate parameters.
        system_params = initialize_system_params(jax.random.PRNGKey(s))

        # Train LUT: the target state is both the initial state and the goal.
        result = optimize_pulse(
            U_0=rho_target,
            C_target=rho_target,
            system_params=system_params,
            num_time_steps=num_time_steps,
            lut_depth=lut_depth,
            reward_weights=reward_weights,
            mode="lookup",
            goal="fidelity",
            max_iter=N_training_iterations,
            convergence_threshold=convergence_threshold,
            learning_rate=learning_rate,
            evo_type="density",
            batch_size=16,
            progress=True,
        )

        with open(f"./optimized_architectures/lut_t={num_time_steps}_l={lut_depth}_s={s}_w={weights_str}.json", "w") as f:
            json.dump(FgResult_to_dict(result), f)

        # Evaluate the trained parameters on a much longer run than they were trained on.
        result = evaluate_on_longer_time(
            U_0 = rho_target,
            C_target = rho_target,
            system_params = system_params,
            optimized_trainable_parameters = result.optimized_trainable_parameters,
            num_time_steps = 1000,
            evo_type = "density",
            goal = "fidelity",
            eval_batch_size = 16, # Default is 10
            mode = "lookup",
            rnn = None,
            rnn_hidden_size = 30,
        )

        fidelities_lut = result.fidelity_each_timestep

        jnp.savez(f"./evaluation_results/lut_t={num_time_steps}_l={lut_depth}_s={s}_w={weights_str}.npz", fidelities_lut=jnp.array(fidelities_lut))

        # Check if RNN has to be trained (only if num_time_steps and weight are different)
        if os.path.exists(f"./optimized_architectures/rnn_t={num_time_steps}_s={s}_w={weights_str}.json"):
            print(f"RNN for t={num_time_steps} and weights={weights_str} already trained, skipping training.")
            continue
        # NOTE: the RNN training and evaluation below is disabled by being wrapped
        # in a string literal, so the `continue` above currently changes nothing.
        """
        # Train RNN
        print("Training RNN")
        result = optimize_pulse(
            U_0=rho_target,
            C_target=rho_target,
            system_params=system_params,
            num_time_steps=num_time_steps,
            reward_weights=reward_weights,
            mode="nn",
            goal="fidelity",
            max_iter=N_training_iterations,
            convergence_threshold=convergence_threshold,
            learning_rate=learning_rate,
            evo_type="density",
            batch_size=16,
            rnn=RNN,
            rnn_hidden_size=30,
            progress=True,
        )

        with open(f"./optimized_architectures/rnn_t={num_time_steps}_s={s}_w={weights_str}.json", "w") as f:
            json.dump(FgResult_to_dict(result), f)

        result = evaluate_on_longer_time(
            U_0 = rho_target,
            C_target = rho_target,
            system_params = system_params,
            optimized_trainable_parameters = result.optimized_trainable_parameters,
            num_time_steps = 1000,
            evo_type = "density",
            goal = "fidelity",
            eval_batch_size = 16, # Default is 10
            mode = "nn",
            rnn = RNN,
            rnn_hidden_size = 30,
        )

        fidelities_rnn = result.fidelity_each_timestep

        jnp.savez(f"./evaluation_results/rnn_t={num_time_steps}_s={s}_w={weights_str}.npz", fidelities_rnn=jnp.array(fidelities_rnn))
        """

Evaluating num_time_steps=5, lut_depth=3, reward_weights=[1, 1, 1, 1, 1]
Training LUT


2025-11-10 22:00:44.961477: E external/xla/xla/service/slow_operation_alarm.cc:73] 
********************************
[Compiling module jit_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2025-11-10 22:00:54.322421: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 2m9.368274s

********************************
[Compiling module jit_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


Iteration 10, Loss: 0.160934, T=14s, eta=1265s
Iteration 20, Loss: 0.172045, T=27s, eta=1306s
Iteration 30, Loss: 0.195961, T=41s, eta=1311s
Iteration 40, Loss: 0.226193, T=55s, eta=1306s
Iteration 50, Loss: 0.291703, T=69s, eta=1299s
Iteration 60, Loss: 0.299828, T=83s, eta=1289s
Iteration 70, Loss: 0.270495, T=97s, eta=1278s
Iteration 80, Loss: 0.267050, T=112s, eta=1276s
Iteration 90, Loss: 0.234718, T=126s, eta=1266s
Iteration 100, Loss: 0.376861, T=140s, eta=1255s
Iteration 110, Loss: 0.317653, T=154s, eta=1243s
Iteration 120, Loss: 0.346998, T=169s, eta=1231s
Iteration 130, Loss: 0.327637, T=183s, eta=1219s
Iteration 140, Loss: 0.334984, T=197s, eta=1206s
Iteration 150, Loss: 0.379662, T=211s, eta=1193s
Iteration 160, Loss: 0.362941, T=226s, eta=1180s
Iteration 170, Loss: 0.248496, T=240s, eta=1167s
Iteration 180, Loss: 0.328306, T=254s, eta=1154s
Iteration 190, Loss: 0.393582, T=268s, eta=1140s
Iteration 200, Loss: 0.279497, T=282s, eta=1127s
Iteration 210, Loss: 0.454245, T=297