In [1]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

## defining parameterized operations that are repeated num_time_steps times

takes 9 minutes on GPU instead of 17 on CPU

In [2]:
N_cav = 20

In [3]:
def qubit_unitary(alpha_re, alpha_im):
    alpha = alpha_re + 1j * alpha_im
    return tensor(
        identity(N_cav),
        expm(-1j * (alpha * sigmap() + alpha.conjugate() * sigmam()) / 2),
    )

In [4]:
def qubit_cavity_unitary(beta_re, beta_im):
    beta = beta_re + 1j * beta_im
    return expm(
        -1j
        * (
            beta * (tensor(destroy(N_cav), sigmap()))
            + beta.conjugate() * (tensor(create(N_cav), sigmam()))
        )
        / 2
    )

In [5]:
Uq = qubit_unitary(0.1, 0.1)
Uqc = qubit_cavity_unitary(0.1, 0.1)
print(
    "Uq unitary check:",
    jnp.allclose(Uq.conj().T @ Uq, jnp.eye(Uq.shape[0]), atol=1e-7),
)
print(
    "Uqc unitary check:",
    jnp.allclose(Uqc.conj().T @ Uqc, jnp.eye(Uqc.shape[0]), atol=1e-7),
)

Uq unitary check: True
Uqc unitary check: True


In [6]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

### defining initial (thermal) state

In [7]:
# initial state is a thermal state coupled to a qubit in the ground state?
n_average = 1
# natural logarithm
beta = jnp.log((1 / n_average) + 1)
diags = jnp.exp(-beta * jnp.arange(N_cav))
normalized_diags = diags / jnp.sum(diags, axis=0)
rho_cav = jnp.diag(normalized_diags)

In [8]:
rho_cav.shape

(20, 20)

In [9]:
rho0 = tensor(rho_cav, basis(2, 0) @ basis(2, 0).conj().T)

In [10]:
from feedback_grape.utils.povm import (
    _probability_of_a_measurement_outcome_given_a_certain_state,
)

_probability_of_a_measurement_outcome_given_a_certain_state(
    rho0, 1, povm_measure_operator, [0.1, -3 * jnp.pi / 2]
)

Array(0.9479526, dtype=float64)

### defining target state

In [11]:
psi_target = tensor(
    (fock(N_cav, 1) + fock(N_cav, 2) + fock(N_cav, 3)) / jnp.sqrt(3), basis(2)
)
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = psi_target @ psi_target.conj().T
rho_target.shape

(40, 40)

In [12]:
from feedback_grape.utils.fidelity import fidelity

print(fidelity(U_final=rho0, C_target=rho_target, evo_type="density"))

0.38188149897874146


In [13]:
# Print the eigenvalues of rho0
eigenvalues = jnp.linalg.eigvalsh(rho_target)
print("Eigenvalues of rho0:", eigenvalues)

Eigenvalues of rho0: [-2.72829216e-16 -6.02376918e-17  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]


### initialize random params

In [14]:
import jax

print(
    jax.random.uniform(
        jax.random.PRNGKey(0),
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jax.numpy.pi,
        maxval=jax.numpy.pi,
    ).tolist()
)

[[-0.5123490775685872, -1.782568231202715]]


In [15]:
import jax

num_time_steps = 5
num_of_iterations = 1000
learning_rate = 0.05
# avg_photon_numer = 2 When testing kitten state
# TODO: remove this tolist thing
# TODO: see if one should allow providing gate parameters for only one gate --> I don't think so
key = jax.random.PRNGKey(0)
measure = {
    "gate": povm_measure_operator,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": True,
    # "param_constraints": [[0, jnp.pi], [-2*jnp.pi, 2*jnp.pi]],
}

qub_unitary = {
    "gate": qubit_unitary,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": False,
    # "param_constraints": [[-2*jnp.pi, 2*jnp.pi], [-2*jnp.pi, 2*jnp.pi]],
}

qub_cav = {
    "gate": qubit_cavity_unitary,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": False,
    # "param_constraints": [[-jnp.pi, jnp.pi], [-jnp.pi, jnp.pi]],
}

system_params = [measure, qub_unitary, qub_cav]


result = optimize_pulse_with_feedback(
    U_0=rho0,
    C_target=rho_target,
    system_params=system_params,
    num_time_steps=num_time_steps,
    mode="lookup",
    goal="fidelity",
    max_iter=num_of_iterations,
    convergence_threshold=1e-6,
    learning_rate=learning_rate,
    evo_type="density",
    batch_size=10,
)

2025-06-23 09:47:30.148466: E external/xla/xla/service/slow_operation_alarm.cc:73] 
********************************
[Compiling module jit_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2025-06-23 09:47:41.849567: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 2m11.704962s

********************************
[Compiling module jit_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


Iteration 0, Loss: 0.438212
Iteration 10, Loss: 0.601029
Iteration 20, Loss: 0.730168
Iteration 30, Loss: 0.541349
Iteration 40, Loss: 1.081585
Iteration 50, Loss: 0.899316
Iteration 60, Loss: 0.953659
Iteration 70, Loss: 0.962733
Iteration 80, Loss: 0.944012
Iteration 90, Loss: 1.133018
Iteration 100, Loss: 1.068808
Iteration 110, Loss: 1.143562
Iteration 120, Loss: 1.074213
Iteration 130, Loss: 0.971917
Iteration 140, Loss: 0.885523
Iteration 150, Loss: 0.831739
Iteration 160, Loss: 0.730409
Iteration 170, Loss: 1.131186
Iteration 180, Loss: 1.251616
Iteration 190, Loss: 0.994662
Iteration 200, Loss: 1.069638
Iteration 210, Loss: 1.131658
Iteration 220, Loss: 0.991689
Iteration 230, Loss: 1.320514
Iteration 240, Loss: 1.187055
Iteration 250, Loss: 1.000081
Iteration 260, Loss: 0.961800
Iteration 270, Loss: 1.082210
Iteration 280, Loss: 1.115361
Iteration 290, Loss: 0.764079
Iteration 300, Loss: 1.284933
Iteration 310, Loss: 0.906655
Iteration 320, Loss: 0.832359
Iteration 330, Loss: 

In [16]:
# import jax

# num_time_steps = 5
# num_of_iterations = 100
# learning_rate = 0.05
# # avg_photon_numer = 2 When testing kitten state

# key = jax.random.PRNGKey(0)
# initial_params = {
#     "POVM": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "U_q": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "U_qc": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
# }


# result = optimize_pulse_with_feedback(
#     U_0=rho0,
#     C_target=rho_target,
#     parameterized_gates=[
#         povm_measure_operator,
#         qubit_unitary,
#         qubit_cavity_unitary,
#     ],
#     measurement_indices=[0],
#     initial_params=initial_params,
#     num_time_steps=num_time_steps,
#     mode="lookup",
#     lookup_min_init_value=-jnp.pi,
#     lookup_max_init_value=jnp.pi,
#     goal="fidelity",
#     max_iter=num_of_iterations,
#     convergence_threshold=1e-6,
#     learning_rate=learning_rate,
#     evo_type="density",
#     batch_size=10,
# )

In [17]:
print(result.final_purity)

None


In [18]:
# 0.602319331001811
print(result.final_fidelity)

0.7729112294290601


In [19]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho0, evo_type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, evo_type="density"),
    )

initial fidelity: 0.38188149897874146
fidelity of state 0: 0.5755941603557482
fidelity of state 1: 0.8472061276885658
fidelity of state 2: 0.767798931270149
fidelity of state 3: 0.7803519518776776
fidelity of state 4: 0.8146532291142413
fidelity of state 5: 0.7803519518776776
fidelity of state 6: 0.8472061276885658
fidelity of state 7: 0.767798931270149
fidelity of state 8: 0.767798931270149
fidelity of state 9: 0.7803519518776776


In [20]:
result.final_state

Array([[[ 7.28984427e-02+7.80625564e-18j,
         -7.70018820e-03-3.42452775e-03j,
         -8.50541939e-02+2.97616643e-02j, ...,
         -2.70961638e-03-1.52414570e-03j,
         -2.76573783e-02+2.49456839e-03j,
         -1.74677353e-03+1.47032774e-03j],
        [-7.70018820e-03+3.42452775e-03j,
          8.51942608e-02+1.21430643e-17j,
          4.28855121e-02-3.76017808e-03j, ...,
          8.41194513e-03-2.19817819e-03j,
          2.24556966e-02+3.60931739e-03j,
          5.14476419e-03-6.57884400e-03j],
        [-8.50541939e-02-2.97616643e-02j,
          4.28855121e-02+3.76017808e-03j,
          1.29924031e-01+1.82145965e-17j, ...,
          5.56174069e-03+1.25670556e-03j,
          4.16588630e-02+9.55551995e-03j,
          4.68495980e-03-4.23786146e-03j],
        ...,
        [-2.70961638e-03+1.52414570e-03j,
          8.41194513e-03+2.19817819e-03j,
          5.56174069e-03-1.25670556e-03j, ...,
          1.90481479e-03-6.50521303e-19j,
          3.21545899e-03+4.41098287e-04j

In [21]:
result.optimized_trainable_parameters['lookup_table']

[[Array([-0.76905561, -1.5402756 ,  0.03191816, -1.75404934, -0.06484542,
         -2.82110531], dtype=float64),
  Array([-1.47446526, -1.1039136 ,  0.68042511, -3.90795148, -0.21359896,
         -1.48232369], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.,

In [22]:
result.returned_params

[[Array([[-0.91692414, -1.57937911],
         [-0.91692414, -1.57937911],
         [-0.91692414, -1.57937911],
         [-0.91692414, -1.57937911],
         [-0.91692414, -1.57937911],
         [-0.91692414, -1.57937911],
         [-0.91692414, -1.57937911],
         [-0.91692414, -1.57937911],
         [-0.91692414, -1.57937911],
         [-0.91692414, -1.57937911]], dtype=float64),
  Array([[ 0.68042511, -3.90795148],
         [ 0.03191816, -1.75404934],
         [ 0.68042511, -3.90795148],
         [ 0.03191816, -1.75404934],
         [ 0.68042511, -3.90795148],
         [ 0.03191816, -1.75404934],
         [ 0.03191816, -1.75404934],
         [ 0.68042511, -3.90795148],
         [ 0.68042511, -3.90795148],
         [ 0.03191816, -1.75404934]], dtype=float64),
  Array([[-0.21359896, -1.48232369],
         [-0.06484542, -2.82110531],
         [-0.21359896, -1.48232369],
         [-0.06484542, -2.82110531],
         [-0.21359896, -1.48232369],
         [-0.06484542, -2.82110531],
    

In [23]:
print(result.iterations)

1000


In [24]:
print(result.returned_params[1])

[Array([[-1.47446526, -1.1039136 ],
       [-0.76905561, -1.5402756 ],
       [-1.47446526, -1.1039136 ],
       [-0.76905561, -1.5402756 ],
       [-1.47446526, -1.1039136 ],
       [-0.76905561, -1.5402756 ],
       [-0.76905561, -1.5402756 ],
       [-1.47446526, -1.1039136 ],
       [-1.47446526, -1.1039136 ],
       [-0.76905561, -1.5402756 ]], dtype=float64), Array([[-0.34323657,  0.52392861],
       [ 3.29113447, -0.05401933],
       [-0.51011992, -1.18823013],
       [-0.26776948, -0.6404829 ],
       [-0.34323657,  0.52392861],
       [-0.26776948, -0.6404829 ],
       [ 3.29113447, -0.05401933],
       [-0.51011992, -1.18823013],
       [-0.51011992, -1.18823013],
       [-0.26776948, -0.6404829 ]], dtype=float64), Array([[ 0.66344151, -2.12811321],
       [-0.31401645, -2.31034005],
       [-1.31292827, -3.03346913],
       [ 0.92754035, -1.5090427 ],
       [ 0.66344151, -2.12811321],
       [ 0.92754035, -1.5090427 ],
       [-0.31401645, -2.31034005],
       [-1.31292827,

In [25]:
# measure = {
#     "gate": povm_measure_operator,
#     "initial_params": [0, jnp.pi / 2],
#     "measurement_flag": True,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# qub_unitary = {
#     "gate": qubit_unitary,
#     "initial_params": [0, jnp.pi / 2],
#     "measurement_flag": False,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# qub_cav = {
#     "gate": qubit_cavity_unitary,
#     "initial_params": [0, jnp.pi / 2],
#     "measurement_flag": False,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# system_params = [measure, qub_unitary, qub_cav]


# result = optimize_pulse_with_feedback(
#     U_0=rho0,
#     C_target=rho_target,
#     system_params=system_params,
#     num_time_steps=num_time_steps,
#     mode="lookup",
#     goal="fidelity",
#     max_iter=num_of_iterations,
#     convergence_threshold=1e-6,
#     learning_rate=learning_rate,
#     evo_type="density",
#     batch_size=10,
# )

In [26]:
# result = optimize_pulse_with_feedback(
#     U_0=rho0,
#     C_target=rho_target,
#     num_time_steps=num_time_steps,
#     mode="lookup",
#     goal="fidelity",
#     max_iter=num_of_iterations,
#     convergence_threshold=1e-6,
#     learning_rate=learning_rate,
#     evo_type="density",
#     batch_size=10,
# )

In [27]:
# measure = {
#     "gate": povm_measure_operator,
#     "initial_params": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "measurement_flag": True,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# qub_unitary = {
#     "gate": qubit_unitary,
#     "initial_params": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "measurement_flag": False,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# qub_cav = {
#     "gate": qubit_cavity_unitary,
#     "initial_params": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "measurement_flag": False,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# system_params = [measure, qub_unitary, qub_cav]

In [28]:
# initial_params = {
#     "POVM": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "U_q": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "U_qc": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
# }

# parameterized_gates=[
#     povm_measure_operator,
#     qubit_unitary,
#     qubit_cavity_unitary,
# ],
# measurement_indices=[0]

In [29]:
# def convert_system_params(system_params):
#     """
#     Convert system_params format to (initial_params, parameterized_gates, measurement_indices) format.

#     Args:
#         system_params: List of dictionaries, each containing:
#             - "gate": gate function
#             - "initial_params": list of parameters
#             - "measurement_flag": boolean indicating if this is a measurement gate

#     Returns:
#         tuple: (initial_params, parameterized_gates, measurement_indices)
#             - initial_params: dict mapping gate names/types to parameter lists
#             - parameterized_gates: list of gate functions
#             - measurement_indices: list of indices where measurement gates appear
#     """
#     initial_params = {}
#     parameterized_gates = []
#     measurement_indices = []
#     param_constraints = []

#     for i, gate_config in enumerate(system_params):
#         gate_func = gate_config["gate"]
#         params = gate_config["initial_params"]
#         is_measurement = gate_config["measurement_flag"]

#         # Add gate to parameterized_gates list
#         parameterized_gates.append(gate_func)

#         # If this is a measurement gate, add its index
#         if is_measurement:
#             measurement_indices.append(i)

#         param_name = f"gate_{i}"

#         initial_params[param_name] = params

#         # Add parameter constraints if provided
#         # TODO: make sure to have default
#         param_constraints.append(gate_config.get("param_constraints", None))

#     return initial_params, parameterized_gates, measurement_indices, param_constraints


# # Alternative version if you want to use gate function references directly
# def convert_system_params_with_mapping(system_params, gate_name_map=None):
#     """
#     Convert system_params with explicit gate name mapping.

#     Args:
#         system_params: List of gate configurations
#         gate_name_map: Optional dict mapping gate functions to parameter names

#     Returns:
#         tuple: (initial_params, parameterized_gates, measurement_indices)
#     """
#     if gate_name_map is None:
#         gate_name_map = {}

#     initial_params = {}
#     parameterized_gates = []
#     measurement_indices = []

#     for i, gate_config in enumerate(system_params):
#         gate_func = gate_config["gate"]
#         params = gate_config["initial_params"]
#         is_measurement = gate_config["measurement_flag"]

#         parameterized_gates.append(gate_func)

#         if is_measurement:
#             measurement_indices.append(i)

#         # Use provided mapping or fallback to function name
#         param_name = gate_name_map.get(gate_func, gate_func.__name__)
#         initial_params[param_name] = params

#     return initial_params, parameterized_gates, measurement_indices

In [30]:
# convert_system_params(system_params)