In [1]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

## defining parameterized operations that are repeated num_time_steps times

takes 9 minutes on GPU instead of 17 on CPU

In [2]:
N_cav = 20

In [3]:
def qubit_unitary(alpha_re, alpha_im):
    alpha = alpha_re + 1j * alpha_im
    return tensor(
        identity(N_cav),
        expm(-1j * (alpha * sigmap() + alpha.conjugate() * sigmam()) / 2),
    )

In [4]:
def qubit_cavity_unitary(beta_re, beta_im):
    beta = beta_re + 1j * beta_im
    return expm(
        -1j
        * (
            beta * (tensor(destroy(N_cav), sigmap()))
            + beta.conjugate() * (tensor(create(N_cav), sigmam()))
        )
        / 2
    )

In [5]:
Uq = qubit_unitary(0.1, 0.1)
Uqc = qubit_cavity_unitary(0.1, 0.1)
print(
    "Uq unitary check:",
    jnp.allclose(Uq.conj().T @ Uq, jnp.eye(Uq.shape[0]), atol=1e-7),
)
print(
    "Uqc unitary check:",
    jnp.allclose(Uqc.conj().T @ Uqc, jnp.eye(Uqc.shape[0]), atol=1e-7),
)

Uq unitary check: True
Uqc unitary check: True


In [6]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

### defining initial (thermal) state

In [7]:
# initial state is a thermal state coupled to a qubit in the ground state?
n_average = 1
# natural logarithm
beta = jnp.log((1 / n_average) + 1)
diags = jnp.exp(-beta * jnp.arange(N_cav))
normalized_diags = diags / jnp.sum(diags, axis=0)
rho_cav = jnp.diag(normalized_diags)

In [8]:
rho_cav.shape

(20, 20)

In [9]:
rho0 = tensor(rho_cav, basis(2, 0) @ basis(2, 0).conj().T)

In [10]:
from feedback_grape.utils.povm import (
    _probability_of_a_measurement_outcome_given_a_certain_state,
)

_probability_of_a_measurement_outcome_given_a_certain_state(
    rho0, 1, povm_measure_operator, [0.1, -3 * jnp.pi / 2]
)

Array(0.9479526, dtype=float64)

### defining target state

In [11]:
psi_target = tensor(
    (fock(N_cav, 1) + fock(N_cav, 2) + fock(N_cav, 3)) / jnp.sqrt(3), basis(2)
)
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = psi_target @ psi_target.conj().T
rho_target.shape

(40, 40)

In [12]:
from feedback_grape.utils.fidelity import fidelity

print(fidelity(U_final=rho0, C_target=rho_target, type="density"))

0.38188149897874146


In [13]:
# Print the eigenvalues of rho0
eigenvalues = jnp.linalg.eigvalsh(rho_target)
print("Eigenvalues of rho0:", eigenvalues)

Eigenvalues of rho0: [-2.72829216e-16 -6.02376918e-17  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]


### initialize random params

In [14]:
import jax

print(
    jax.random.uniform(
        jax.random.PRNGKey(0),
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jax.numpy.pi,
        maxval=jax.numpy.pi,
    ).tolist()
)

[[-0.5123490775685872, -1.782568231202715]]


In [15]:
import jax

num_time_steps = 5
num_of_iterations = 1000
learning_rate = 0.05
# avg_photon_numer = 2 When testing kitten state
# TODO: remove this tolist thing
key = jax.random.PRNGKey(0)
measure = {
    "gate": povm_measure_operator,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": True,
    # "param_constraints": [[0, 0.5], [-1, 1]],
}

qub_unitary = {
    "gate": qubit_unitary,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": False,
    # "param_constraints": [[0, 0.5], [-1, 1]],
}

qub_cav = {
    "gate": qubit_cavity_unitary,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": False,
    # "param_constraints": [[0, 0.5], [-1, 1]],
}

system_params = [measure, qub_unitary, qub_cav]


result = optimize_pulse_with_feedback(
    U_0=rho0,
    C_target=rho_target,
    system_params=system_params,
    num_time_steps=num_time_steps,
    mode="lookup",
    goal="fidelity",
    max_iter=num_of_iterations,
    convergence_threshold=1e-6,
    learning_rate=learning_rate,
    type="density",
    batch_size=10,
)

Iteration 0, Loss: 0.310884
Iteration 10, Loss: 0.683557
Iteration 20, Loss: 0.530108
Iteration 30, Loss: 0.813211
Iteration 40, Loss: 0.987716
Iteration 50, Loss: 0.889173
Iteration 60, Loss: 0.995758
Iteration 70, Loss: 0.892294
Iteration 80, Loss: 1.082102
Iteration 90, Loss: 0.928762
Iteration 100, Loss: 0.940376
Iteration 110, Loss: 0.888819
Iteration 120, Loss: 1.021796
Iteration 130, Loss: 1.024147
Iteration 140, Loss: 0.883187
Iteration 150, Loss: 0.929602
Iteration 160, Loss: 1.021366
Iteration 170, Loss: 0.676595
Iteration 180, Loss: 1.079179
Iteration 190, Loss: 0.783117
Iteration 200, Loss: 0.432901
Iteration 210, Loss: 0.313344
Iteration 220, Loss: 0.357303
Iteration 230, Loss: 0.473585
Iteration 240, Loss: 0.643192
Iteration 250, Loss: 0.402930
Iteration 260, Loss: 0.700507
Iteration 270, Loss: 0.668225
Iteration 280, Loss: 0.709793
Iteration 290, Loss: 0.395643
Iteration 300, Loss: 0.669371
Iteration 310, Loss: 0.318889
Iteration 320, Loss: 0.492880
Iteration 330, Loss: 

In [16]:
# import jax

# num_time_steps = 5
# num_of_iterations = 100
# learning_rate = 0.05
# # avg_photon_numer = 2 When testing kitten state

# key = jax.random.PRNGKey(0)  
# initial_params = {
#     "POVM": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "U_q": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "U_qc": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
# }


# result = optimize_pulse_with_feedback(
#     U_0=rho0,
#     C_target=rho_target,
#     parameterized_gates=[
#         povm_measure_operator,
#         qubit_unitary,
#         qubit_cavity_unitary,
#     ],
#     measurement_indices=[0],
#     initial_params=initial_params,
#     num_time_steps=num_time_steps,
#     mode="lookup",
#     lookup_min_init_value=-jnp.pi,
#     lookup_max_init_value=jnp.pi,
#     goal="fidelity",
#     max_iter=num_of_iterations,
#     convergence_threshold=1e-6,
#     learning_rate=learning_rate,
#     type="density",
#     batch_size=10,
# )

In [17]:
print(result.final_purity)

None


In [18]:
# 0.602319331001811
print(result.final_fidelity)

0.8182228948764161


In [19]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho0, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 0.38188149897874146
fidelity of state 0: 0.8534223742536475
fidelity of state 1: 0.861455474896109
fidelity of state 2: 0.4451958759841024
fidelity of state 3: 0.861455474896109
fidelity of state 4: 0.861455474896109
fidelity of state 5: 0.861455474896109
fidelity of state 6: 0.861455474896109
fidelity of state 7: 0.8534223742536475
fidelity of state 8: 0.861455474896109
fidelity of state 9: 0.861455474896109


In [20]:
result.final_state

Array([[[ 6.88277432e-02-1.79977561e-17j,
         -5.37263325e-03+1.47250190e-02j,
         -1.73926786e-02-8.79174434e-03j, ...,
         -8.74367386e-03+5.54191072e-04j,
         -5.00333052e-03-4.12465847e-04j,
         -8.50597237e-03+5.61282247e-04j],
        [-5.37263325e-03-1.47250190e-02j,
          4.61256119e-03-1.62630326e-19j,
         -1.40285381e-02+3.07723281e-03j, ...,
          1.35383860e-03+1.85975249e-03j,
          6.08341730e-04+1.16952748e-03j,
          1.33009395e-03+1.80436096e-03j],
        [-1.73926786e-02+8.79174434e-03j,
         -1.40285381e-02-3.07723281e-03j,
          3.01186793e-01-5.20417043e-17j, ...,
         -4.83017820e-03-1.15990039e-03j,
         -2.47337859e-03-1.12213536e-03j,
         -4.96551432e-03-1.04921399e-03j],
        ...,
        [-8.74367386e-03-5.54191072e-04j,
          1.35383860e-03-1.85975249e-03j,
         -4.83017820e-03+1.15990039e-03j, ...,
          1.52581314e-03-2.71050543e-20j,
          8.60157822e-04+1.28144192e-04j

In [21]:
result.optimized_trainable_parameters['lookup_table']

[[Array([ 5.48037561e-11, -2.51327412e+00, -3.55396690e-02, -1.23146225e+00,
         -1.30948291e-01, -3.52153406e+00], dtype=float64),
  Array([-1.25588274e-10, -2.19911486e+00, -4.04872059e-01, -1.93001936e+00,
         -9.83774363e-01, -1.87973087e+00], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0

In [22]:
result.returned_params

[[Array([[-1.56139992,  0.01620462],
         [-1.56139992,  0.01620462],
         [-1.56139992,  0.01620462],
         [-1.56139992,  0.01620462],
         [-1.56139992,  0.01620462],
         [-1.56139992,  0.01620462],
         [-1.56139992,  0.01620462],
         [-1.56139992,  0.01620462],
         [-1.56139992,  0.01620462],
         [-1.56139992,  0.01620462]], dtype=float64),
  Array([[-0.40487206, -1.93001936],
         [-0.03553967, -1.23146225],
         [-0.03553967, -1.23146225],
         [-0.03553967, -1.23146225],
         [-0.03553967, -1.23146225],
         [-0.03553967, -1.23146225],
         [-0.03553967, -1.23146225],
         [-0.40487206, -1.93001936],
         [-0.03553967, -1.23146225],
         [-0.03553967, -1.23146225]], dtype=float64),
  Array([[-0.98377436, -1.87973087],
         [-0.13094829, -3.52153406],
         [-0.13094829, -3.52153406],
         [-0.13094829, -3.52153406],
         [-0.13094829, -3.52153406],
         [-0.13094829, -3.52153406],
    

In [23]:
print(result.iterations)

1000


In [24]:
print(result.returned_params[1])

[Array([[-1.25588274e-10, -2.19911486e+00],
       [ 5.48037561e-11, -2.51327412e+00],
       [ 5.48037561e-11, -2.51327412e+00],
       [ 5.48037561e-11, -2.51327412e+00],
       [ 5.48037561e-11, -2.51327412e+00],
       [ 5.48037561e-11, -2.51327412e+00],
       [ 5.48037561e-11, -2.51327412e+00],
       [-1.25588274e-10, -2.19911486e+00],
       [ 5.48037561e-11, -2.51327412e+00],
       [ 5.48037561e-11, -2.51327412e+00]], dtype=float64), Array([[-0.39295708, -1.01276577],
       [-0.05146666, -0.75345878],
       [-0.05146666, -0.75345878],
       [-0.05146666, -0.75345878],
       [-0.05146666, -0.75345878],
       [-0.05146666, -0.75345878],
       [-0.05146666, -0.75345878],
       [-0.39295708, -1.01276577],
       [-0.05146666, -0.75345878],
       [-0.05146666, -0.75345878]], dtype=float64), Array([[-2.03327775, -3.12462253],
       [ 0.2058885 ,  0.08555048],
       [ 0.2058885 ,  0.08555048],
       [ 0.2058885 ,  0.08555048],
       [ 0.2058885 ,  0.08555048],
       [ 0

In [25]:
# measure = {
#     "gate": povm_measure_operator,
#     "initial_params": [0, jnp.pi / 2],
#     "measurement_flag": True,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# qub_unitary = {
#     "gate": qubit_unitary,
#     "initial_params": [0, jnp.pi / 2],
#     "measurement_flag": False,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# qub_cav = {
#     "gate": qubit_cavity_unitary,
#     "initial_params": [0, jnp.pi / 2],
#     "measurement_flag": False,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# system_params = [measure, qub_unitary, qub_cav]


# result = optimize_pulse_with_feedback(
#     U_0=rho0,
#     C_target=rho_target,
#     system_params=system_params,
#     num_time_steps=num_time_steps,
#     mode="lookup",
#     goal="fidelity",
#     max_iter=num_of_iterations,
#     convergence_threshold=1e-6,
#     learning_rate=learning_rate,
#     type="density",
#     batch_size=10,
# )

In [26]:
# result = optimize_pulse_with_feedback(
#     U_0=rho0,
#     C_target=rho_target,
#     num_time_steps=num_time_steps,
#     mode="lookup",
#     goal="fidelity",
#     max_iter=num_of_iterations,
#     convergence_threshold=1e-6,
#     learning_rate=learning_rate,
#     type="density",
#     batch_size=10,
# )

In [27]:
# measure = {
#     "gate": povm_measure_operator,
#     "initial_params": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "measurement_flag": True,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# qub_unitary = {
#     "gate": qubit_unitary,
#     "initial_params": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "measurement_flag": False,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# qub_cav = {
#     "gate": qubit_cavity_unitary,
#     "initial_params": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "measurement_flag": False,
#     "param_constraints": [[0, 0.5], [-1, 1]],
# }

# system_params = [measure, qub_unitary, qub_cav]

In [28]:
# initial_params = {
#     "POVM": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "U_q": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "U_qc": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
# }

# parameterized_gates=[
#     povm_measure_operator,
#     qubit_unitary,
#     qubit_cavity_unitary,
# ],
# measurement_indices=[0]

In [29]:
# def convert_system_params(system_params):
#     """
#     Convert system_params format to (initial_params, parameterized_gates, measurement_indices) format.
    
#     Args:
#         system_params: List of dictionaries, each containing:
#             - "gate": gate function
#             - "initial_params": list of parameters
#             - "measurement_flag": boolean indicating if this is a measurement gate
    
#     Returns:
#         tuple: (initial_params, parameterized_gates, measurement_indices)
#             - initial_params: dict mapping gate names/types to parameter lists
#             - parameterized_gates: list of gate functions
#             - measurement_indices: list of indices where measurement gates appear
#     """
#     initial_params = {}
#     parameterized_gates = []
#     measurement_indices = []
#     param_constraints = []
    
#     for i, gate_config in enumerate(system_params):
#         gate_func = gate_config["gate"]
#         params = gate_config["initial_params"]
#         is_measurement = gate_config["measurement_flag"]
        
#         # Add gate to parameterized_gates list
#         parameterized_gates.append(gate_func)
        
#         # If this is a measurement gate, add its index
#         if is_measurement:
#             measurement_indices.append(i)
        
#         param_name = f"gate_{i}"
        
#         initial_params[param_name] = params

#         # Add parameter constraints if provided
#         # TODO: make sure to have default
#         param_constraints.append(gate_config.get("param_constraints", None))
    
#     return initial_params, parameterized_gates, measurement_indices, param_constraints


# # Alternative version if you want to use gate function references directly
# def convert_system_params_with_mapping(system_params, gate_name_map=None):
#     """
#     Convert system_params with explicit gate name mapping.
    
#     Args:
#         system_params: List of gate configurations
#         gate_name_map: Optional dict mapping gate functions to parameter names
    
#     Returns:
#         tuple: (initial_params, parameterized_gates, measurement_indices)
#     """
#     if gate_name_map is None:
#         gate_name_map = {}
    
#     initial_params = {}
#     parameterized_gates = []
#     measurement_indices = []
    
#     for i, gate_config in enumerate(system_params):
#         gate_func = gate_config["gate"]
#         params = gate_config["initial_params"]
#         is_measurement = gate_config["measurement_flag"]
        
#         parameterized_gates.append(gate_func)
        
#         if is_measurement:
#             measurement_indices.append(i)
        
#         # Use provided mapping or fallback to function name
#         param_name = gate_name_map.get(gate_func, gate_func.__name__)
#         initial_params[param_name] = params
    
#     return initial_params, parameterized_gates, measurement_indices

In [30]:
# convert_system_params(system_params)