## State Stabilization

In [1]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
from jax.scipy.linalg import expm

In [2]:
N_cav = 30

In [3]:
def qubit_unitary(alpha):
    """
    TODO: see if alpha, can be sth elser other than scalar, and if the algo understands this
    see if there can be multiple params like alpha and beta input
    """
    return expm(
        -1j
        * (
            alpha * tensor(identity(N_cav), sigmap())
            + alpha.conjugate() * tensor(identity(N_cav), sigmam())
        )
        / 2
    )

In [4]:
def qubit_cavity_unitary(beta):
    return expm(
        -1j
        * (
            beta
            * (
                tensor(destroy(N_cav), identity(2))
                @ tensor(identity(N_cav), sigmap())
            )
            + beta.conjugate()
            * (
                tensor(create(N_cav), identity(2))
                @ tensor(identity(N_cav), sigmam())
            )
        )
        / 2
    )

In [5]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

In [6]:
# initial state is a thermal state coupled to a qubit in the ground state?
n_average = 9
# natural logarithm
beta = jnp.log((1 / n_average) + 1)
diags = jnp.exp(-beta * jnp.arange(N_cav))
normalized_diags = diags / jnp.sum(diags, axis=0)
rho_cav = jnp.diag(normalized_diags)

In [7]:
rho0 = tensor(rho_cav, basis(2, 0) @ basis(2, 0).conj().T)

In [8]:
rho0.shape

(60, 60)

In [9]:
from feedback_grape.utils.states import coherent

alpha = 3
psi_target = tensor(
    coherent(N_cav, alpha)
    + coherent(N_cav, -alpha)
    + coherent(N_cav, 1j * alpha)
    + coherent(N_cav, -1j * alpha),
    basis(2),
)  # 4-legged state

rho_target = psi_target @ psi_target.conj().T

In [10]:
rho_target.shape

(60, 60)

In [11]:
from feedback_grape.utils.fidelity import fidelity

print(fidelity(U_final=rho0, C_target=rho_target, type="density"))

0.20605832342551042


### Without dissipation

In [None]:
result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    parameterized_gates=[
        povm_measure_operator,
        qubit_unitary,
        qubit_cavity_unitary,
    ],
    measurement_indices=[0],
    initial_params={
        "POVM": [0.0, jnp.pi / 3],
        "U_q": [jnp.pi / 2],
        "U_qc": [jnp.pi / 2],
    },
    num_time_steps=2,
    mode="lookup",
    goal="fidelity",
    optimizer="adam",
    max_iter=1000,
    convergence_threshold=1e-6,
    learning_rate=0.005,
    type="density",
    batch_size=1,
)

In [None]:
print(result.final_fidelity)

0.9999999898509099


In [None]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho0, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 0.20605832342551042
fidelity of state 0: 0.9999999898509099


### With Dissipation

In [12]:
result = optimize_pulse_with_feedback(
    U_0=rho_target,
    C_target=rho_target,
    parameterized_gates=[
        povm_measure_operator,
        qubit_unitary,
        qubit_cavity_unitary,
    ],
    measurement_indices=[0],
    decay={
        "decay_indices": [0],  # indices of gates before which decay occurs
        # c_ops need to be tensored with the identity operator for the cavity
        # because it is used directly in the lindblad equation
        "c_ops":  # c_ops for each decay index
        [
            [
                tensor(identity(N_cav), jnp.sqrt(0.01) * sigmam()),
            ],
            [tensor(identity(N_cav), jnp.sqrt(0.1) * sigmap())],
        ],
        "time_grid": [0.2],  # time grid for decay
        "Hamiltonian": None,
    },
    initial_params={
        "POVM": [0.0, jnp.pi / 3],
        "U_q": [jnp.pi / 2],
        "U_qc": [jnp.pi / 2],
    },
    num_time_steps=2,
    mode="lookup",
    goal="fidelity",
    optimizer="adam",
    max_iter=1000,
    convergence_threshold=1e-6,
    learning_rate=0.005,
    type="density",
    batch_size=1,
)

Iteration 0, Loss: -0.184245
Iteration 10, Loss: -0.045857
Iteration 20, Loss: -0.060631
Iteration 30, Loss: -0.731739
Iteration 40, Loss: -0.977474
Iteration 50, Loss: -0.182478
Iteration 60, Loss: -1.259945
Iteration 70, Loss: -1.321056
Iteration 80, Loss: -1.365773
Iteration 90, Loss: -1.390369
Iteration 100, Loss: -1.392654
Iteration 110, Loss: -1.369867
Iteration 120, Loss: -1.396049
Iteration 130, Loss: -1.398159
Iteration 140, Loss: -1.373595
Iteration 150, Loss: -1.401985
Iteration 160, Loss: -0.493813
Iteration 170, Loss: -1.184238
Iteration 180, Loss: -1.280387
Iteration 190, Loss: -1.353666
Iteration 200, Loss: -1.361504
Iteration 210, Loss: -0.776920
Iteration 220, Loss: -0.866213
Iteration 230, Loss: -0.815163
Iteration 240, Loss: -1.158534
Iteration 250, Loss: -1.010049
Iteration 260, Loss: -1.074011
Iteration 270, Loss: -1.127229
Iteration 280, Loss: -1.023924
Iteration 290, Loss: -1.082989
Iteration 300, Loss: -1.133557
Iteration 310, Loss: -1.041971
Iteration 320, Loss

In [13]:
# 0.9999999976677415
print(result.final_fidelity)

0.9999999898509099


In [14]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho0, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 0.20605832342551042
fidelity of state 0: 0.9999999898509099


In [15]:
c_ops = (
    [  # c_ops for each decay index
        [
            tensor(identity(N_cav), jnp.sqrt(0.01) * sigmam()),
        ],
        ["hi"],
    ],
)

In [None]:
c_ops_iter = iter(c_ops)
for _ in range(2):
    print(next(c_ops_iter))

[Array([[0. +0.j, 0. +0.j, 0. +0.j, ..., 0. +0.j, 0. +0.j, 0. +0.j],
       [0.1+0.j, 0. +0.j, 0. +0.j, ..., 0. +0.j, 0. +0.j, 0. +0.j],
       [0. +0.j, 0. +0.j, 0. +0.j, ..., 0. +0.j, 0. +0.j, 0. +0.j],
       ...,
       [0. +0.j, 0. +0.j, 0. +0.j, ..., 0. +0.j, 0. +0.j, 0. +0.j],
       [0. +0.j, 0. +0.j, 0. +0.j, ..., 0. +0.j, 0. +0.j, 0. +0.j],
       [0. +0.j, 0. +0.j, 0. +0.j, ..., 0. +0.j, 0.1+0.j, 0. +0.j]],      dtype=complex128)]
['hi']
