## State Stabilization

In [16]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
from jax.scipy.linalg import expm

In [17]:
N_cav = 30

In [18]:
def qubit_unitary(alpha):
    """
    TODO: see if alpha, can be sth elser other than scalar, and if the algo understands this
    see if there can be multiple params like alpha and beta input
    """
    return expm(
        -1j
        * (
            alpha * tensor(identity(N_cav), sigmap())
            + alpha.conjugate() * tensor(identity(N_cav), sigmam())
        )
        / 2
    )

In [19]:
def qubit_cavity_unitary(beta):
    return expm(
        -1j
        * (
            beta
            * (
                tensor(destroy(N_cav), identity(2))
                @ tensor(identity(N_cav), sigmap())
            )
            + beta.conjugate()
            * (
                tensor(create(N_cav), identity(2))
                @ tensor(identity(N_cav), sigmam())
            )
        )
        / 2
    )

In [20]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

In [21]:
# initial state is a thermal state coupled to a qubit in the ground state?
n_average = 9
# natural logarithm
beta = jnp.log((1 / n_average) + 1)
diags = jnp.exp(-beta * jnp.arange(N_cav))
normalized_diags = diags / jnp.sum(diags, axis=0)
rho_cav = jnp.diag(normalized_diags)

In [22]:
rho0 = tensor(rho_cav, basis(2, 0) @ basis(2, 0).conj().T)

In [23]:
rho0.shape

(60, 60)

In [24]:
from feedback_grape.utils.states import coherent

alpha = 3
psi_target = tensor(
    coherent(N_cav, alpha)
    + coherent(N_cav, -alpha)
    + coherent(N_cav, 1j * alpha)
    + coherent(N_cav, -1j * alpha),
    basis(2),
)  # 4-legged state

rho_target = psi_target @ psi_target.conj().T

In [25]:
rho_target.shape

(60, 60)

In [26]:
from feedback_grape.utils.fidelity import fidelity

print(fidelity(U_final=rho0, C_target=rho_target, type="density"))

0.20605832342551042


In [27]:
import flax.linen as nn


class RNN(nn.Module):
    hidden_size: int  # number of features in the hidden state
    output_size: int  # number of features in the output ( 2 in the case of gamma and beta)

    @nn.compact
    def __call__(self, measurement, hidden_state):
        """
        If your GRU has a hidden state increasing number of features in the hidden stateH means:

        - You're allowing the model to store more information across time steps

        - Each time step can represent more complex features, patterns, or dependencies

        - You're giving the GRU more representational capacity
        """
        gru_cell = nn.GRUCell(features=self.hidden_size)
        self.make_rng('dropout')
        if measurement.ndim == 1:
            measurement = measurement.reshape(1, -1)
        new_hidden_state, _ = gru_cell(hidden_state, measurement)
        new_hidden_state = nn.Dropout(rate=0.2, deterministic=False)(
            new_hidden_state
        )
        # this returns the povm_params after linear regression through the hidden state which contains
        # the information of the previous time steps and this is optimized to output best povm_params
        # new_hidden_state = nn.Dense(features=self.hidden_size)(new_hidden_state)
        output = nn.Dense(
            features=self.output_size,
            kernel_init=nn.initializers.glorot_uniform(),
            bias_init=nn.initializers.constant(jnp.pi),
        )(new_hidden_state)
        # output = jnp.asarray(output)
        return output[0], new_hidden_state

### Without dissipation

In [85]:
result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    parameterized_gates=[
        povm_measure_operator,
        qubit_unitary,
        qubit_cavity_unitary,
    ],
    measurement_indices=[0],
    initial_params = {
        "POVM": [0.0, jnp.pi / 3],
        "U_q": [jnp.pi / 2],
        "U_qc": [jnp.pi / 2],
    },
    num_time_steps=1,
    mode="lookup",
    goal="fidelity",
    optimizer="adam",
    max_iter=1000,
    convergence_threshold=1e-6,
    learning_rate=0.005,
    type="density",
    batch_size=10,
)

Iteration 0, Loss: -0.191833
Iteration 10, Loss: -0.352869
Iteration 20, Loss: -0.503874
Iteration 30, Loss: -0.539174
Iteration 40, Loss: -0.582511
Iteration 50, Loss: -0.624338
Iteration 60, Loss: -0.665082
Iteration 70, Loss: -0.706828
Iteration 80, Loss: -0.749155
Iteration 90, Loss: -0.791993
Iteration 100, Loss: -0.835120
Iteration 110, Loss: -0.878371
Iteration 120, Loss: -0.921594
Iteration 130, Loss: -0.964625
Iteration 140, Loss: -1.007331
Iteration 150, Loss: -1.049592
Iteration 160, Loss: -1.091303
Iteration 170, Loss: -1.132371
Iteration 180, Loss: -1.172718
Iteration 190, Loss: -1.212272
Iteration 200, Loss: -1.250976
Iteration 210, Loss: -1.288779
Iteration 220, Loss: -1.325638
Iteration 230, Loss: -1.361518
Iteration 240, Loss: -1.396392
Iteration 250, Loss: -1.430235
Iteration 260, Loss: -1.463031
Iteration 270, Loss: -1.494768
Iteration 280, Loss: -1.525438
Iteration 290, Loss: -1.555037
Iteration 300, Loss: -1.583566
Iteration 310, Loss: -1.611028
Iteration 320, Loss

In [88]:
print(result.final_fidelity)

0.999759638777322


In [87]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho0, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 0.20605832342551042
fidelity of state 0: 0.9997596387773222
fidelity of state 1: 0.9997596387773222
fidelity of state 2: 0.9997596387773222
fidelity of state 3: 0.9997596387773222
fidelity of state 4: 0.9997596387773222
fidelity of state 5: 0.9997596387773222
fidelity of state 6: 0.9997596387773222
fidelity of state 7: 0.9997596387773222
fidelity of state 8: 0.9997596387773222
fidelity of state 9: 0.9997596387773222


### With Dissipation

In [None]:
result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    parameterized_gates=[
        povm_measure_operator,
        qubit_unitary,
        qubit_cavity_unitary,
    ],
    measurement_indices=[0],
    decay = {
        "decay_indices": [0,1], # indices of gates before which decay occurs
        "decay_rates": [0.01, 0.01], # decay rates for each gate
        "decay_durations": [1, 1], # duration of decay for each gate
    },
    initial_params = {
        "POVM": [0.0, jnp.pi / 3],
        "U_q": [jnp.pi / 2],
        "U_qc": [jnp.pi / 2],
    },
    num_time_steps=1,
    mode="lookup",
    goal="fidelity",
    optimizer="adam",
    max_iter=1000,
    convergence_threshold=1e-6,
    learning_rate=0.005,
    type="density",
    batch_size=1,
)