In [1]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

## defining parameterized operations that are repeated num_time_steps times

takes 9 minutes on GPU instead of 17 on CPU

In [2]:
N_cav = 20

In [3]:
def qubit_unitary(alpha_re, alpha_im):
    alpha = alpha_re + 1j * alpha_im
    return tensor(
        identity(N_cav),
        expm(-1j * (alpha * sigmap() + alpha.conjugate() * sigmam()) / 2),
    )

In [4]:
def qubit_cavity_unitary(beta_re, beta_im):
    beta = beta_re + 1j * beta_im
    return expm(
        -1j
        * (
            beta * (tensor(destroy(N_cav), sigmap()))
            + beta.conjugate() * (tensor(create(N_cav), sigmam()))
        )
        / 2
    )

In [5]:
Uq = qubit_unitary(0.1, 0.1)
Uqc = qubit_cavity_unitary(0.1, 0.1)
print(
    "Uq unitary check:",
    jnp.allclose(Uq.conj().T @ Uq, jnp.eye(Uq.shape[0]), atol=1e-7),
)
print(
    "Uqc unitary check:",
    jnp.allclose(Uqc.conj().T @ Uqc, jnp.eye(Uqc.shape[0]), atol=1e-7),
)

Uq unitary check: True
Uqc unitary check: True


In [6]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

### defining initial (thermal) state

In [7]:
# initial state is a thermal state coupled to a qubit in the ground state?
n_average = 1
# natural logarithm
beta = jnp.log((1 / n_average) + 1)
diags = jnp.exp(-beta * jnp.arange(N_cav))
normalized_diags = diags / jnp.sum(diags, axis=0)
rho_cav = jnp.diag(normalized_diags)

In [8]:
rho_cav.shape

(20, 20)

In [9]:
rho0 = tensor(rho_cav, basis(2, 0) @ basis(2, 0).conj().T)

In [10]:
from feedback_grape.utils.povm import (
    _probability_of_a_measurement_outcome_given_a_certain_state,
)

_probability_of_a_measurement_outcome_given_a_certain_state(
    rho0, 1, povm_measure_operator, [0.1, -3 * jnp.pi / 2]
)

Array(0.9479526, dtype=float64)

### defining target state

In [11]:
psi_target = tensor(
    (fock(N_cav, 1) + fock(N_cav, 2) + fock(N_cav, 3)) / jnp.sqrt(3), basis(2)
)
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = psi_target @ psi_target.conj().T
rho_target.shape

(40, 40)

In [12]:
from feedback_grape.utils.fidelity import fidelity

print(fidelity(U_final=rho0, C_target=rho_target, type="density"))

0.38188149897874146


In [13]:
# Print the eigenvalues of rho0
eigenvalues = jnp.linalg.eigvalsh(rho_target)
print("Eigenvalues of rho0:", eigenvalues)

Eigenvalues of rho0: [-2.72829216e-16 -6.02376918e-17  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]


### initialize random params

In [14]:
import jax

print(
    jax.random.uniform(
        jax.random.PRNGKey(0),
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jax.numpy.pi,
        maxval=jax.numpy.pi,
    ).tolist()
)

[[-0.5123490775685872, -1.782568231202715]]


In [None]:
import jax

num_time_steps = 5
num_of_iterations = 100
learning_rate = 0.05
# avg_photon_numer = 2 When testing kitten state
# TODO: remove this tolist thing
key = jax.random.PRNGKey(0)
measure = {
    "gate": povm_measure_operator,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": True,
    # "param_constrains": [[0, 0.5], [-1, 1]],
}

qub_unitary = {
    "gate": qubit_unitary,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": False,
    # "param_constrains": [[0, 0.5], [-1, 1]],
}

qub_cav = {
    "gate": qubit_cavity_unitary,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": False,
    # "param_constrains": [[0, 0.5], [-1, 1]],
}

system_params = [measure, qub_unitary, qub_cav]


result = optimize_pulse_with_feedback(
    U_0=rho0,
    C_target=rho_target,
    system_params=system_params,
    num_time_steps=num_time_steps,
    mode="lookup",
    lookup_min_init_value=-jnp.pi,
    lookup_max_init_value=jnp.pi,
    goal="fidelity",
    max_iter=num_of_iterations,
    convergence_threshold=1e-6,
    learning_rate=learning_rate,
    type="density",
    batch_size=10,
)

parameter shapes: [2, 2, 2]
Number of parameters: 6
Iteration 0, Loss: 0.361411
Iteration 10, Loss: 0.682436
Iteration 20, Loss: 0.557376
Iteration 30, Loss: 0.816386
Iteration 40, Loss: 0.942064
Iteration 50, Loss: 0.755550
Iteration 60, Loss: 0.775314
Iteration 70, Loss: 0.760871
Iteration 80, Loss: 1.013357
Iteration 90, Loss: 0.872184


In [None]:
# import jax

# num_time_steps = 5
# num_of_iterations = 100
# learning_rate = 0.05
# # avg_photon_numer = 2 When testing kitten state

# key = jax.random.PRNGKey(0)  
# initial_params = {
#     "POVM": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "U_q": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
#     "U_qc": jax.random.uniform(
#         key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
#     )[0].tolist(),
# }


# result = optimize_pulse_with_feedback(
#     U_0=rho0,
#     C_target=rho_target,
#     parameterized_gates=[
#         povm_measure_operator,
#         qubit_unitary,
#         qubit_cavity_unitary,
#     ],
#     measurement_indices=[0],
#     initial_params=initial_params,
#     num_time_steps=num_time_steps,
#     mode="lookup",
#     lookup_min_init_value=-jnp.pi,
#     lookup_max_init_value=jnp.pi,
#     goal="fidelity",
#     max_iter=num_of_iterations,
#     convergence_threshold=1e-6,
#     learning_rate=learning_rate,
#     type="density",
#     batch_size=10,
# )

In [17]:
print(result.final_purity)

None


In [18]:
# 0.602319331001811
print(result.final_fidelity)

0.602319331001811


In [19]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho0, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 0.38188149897874146
fidelity of state 0: 0.7483633308286942
fidelity of state 1: 0.6947796274505317
fidelity of state 2: 0.33841586658999984
fidelity of state 3: 0.6252296159890302
fidelity of state 4: 0.6505356128266584
fidelity of state 5: 0.6746359419330781
fidelity of state 6: 0.6746359419330781
fidelity of state 7: 0.2940488214045853
fidelity of state 8: 0.6479126091293765
fidelity of state 9: 0.6746359419330781


In [20]:
result.final_state

Array([[[ 1.60976520e-01+1.11022302e-16j,
          1.56050318e-02+2.68488241e-02j,
         -1.92635248e-01-1.39964495e-01j, ...,
          2.95040012e-03+2.13061592e-03j,
          4.84505451e-03-4.55256483e-03j,
          2.14671019e-03+5.94591950e-04j],
        [ 1.56050318e-02-2.68488241e-02j,
          1.25350038e-02+0.00000000e+00j,
         -3.33598586e-02+2.19974413e-02j, ...,
         -1.01673379e-03-9.29723833e-04j,
         -6.95832994e-04-4.61701210e-04j,
         -9.64188620e-04-4.80969683e-04j],
        [-1.92635248e-01+1.39964495e-01j,
         -3.33598586e-02-2.19974413e-02j,
          3.76755186e-01+2.77555756e-17j, ...,
         -8.31117287e-03+4.24442367e-04j,
         -1.63506142e-03+1.02578975e-02j,
         -5.59266731e-03+3.00147031e-03j],
        ...,
        [ 2.95040012e-03-2.13061592e-03j,
         -1.01673379e-03+9.29723833e-04j,
         -8.31117287e-03-4.24442367e-04j, ...,
          7.54475866e-04-1.21972744e-19j,
         -4.33501772e-05-6.11466916e-04j

In [21]:
result.optimized_trainable_parameters['lookup_table']

[[Array([-0.06641503, -3.19510687,  1.4741188 , -2.02568962, -0.28924361,
          0.45943259], dtype=float64),
  Array([ 0.03798034, -2.49834026,  3.00886766, -1.26668487, -0.14336197,
         -0.11319613], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.,

In [22]:
result.returned_params

[[Array([[-0.9969448 , -0.33415867],
         [-0.9969448 , -0.33415867],
         [-0.9969448 , -0.33415867],
         [-0.9969448 , -0.33415867],
         [-0.9969448 , -0.33415867],
         [-0.9969448 , -0.33415867],
         [-0.9969448 , -0.33415867],
         [-0.9969448 , -0.33415867],
         [-0.9969448 , -0.33415867],
         [-0.9969448 , -0.33415867]], dtype=float64),
  Array([[ 3.00886766, -1.26668487],
         [ 1.4741188 , -2.02568962],
         [ 1.4741188 , -2.02568962],
         [ 1.4741188 , -2.02568962],
         [ 1.4741188 , -2.02568962],
         [ 1.4741188 , -2.02568962],
         [ 1.4741188 , -2.02568962],
         [ 3.00886766, -1.26668487],
         [ 1.4741188 , -2.02568962],
         [ 1.4741188 , -2.02568962]], dtype=float64),
  Array([[-0.14336197, -0.11319613],
         [-0.28924361,  0.45943259],
         [-0.28924361,  0.45943259],
         [-0.28924361,  0.45943259],
         [-0.28924361,  0.45943259],
         [-0.28924361,  0.45943259],
    

In [23]:
print(result.iterations)

100


In [24]:
print(result.returned_params[1])

[Array([[ 0.03798034, -2.49834026],
       [-0.06641503, -3.19510687],
       [-0.06641503, -3.19510687],
       [-0.06641503, -3.19510687],
       [-0.06641503, -3.19510687],
       [-0.06641503, -3.19510687],
       [-0.06641503, -3.19510687],
       [ 0.03798034, -2.49834026],
       [-0.06641503, -3.19510687],
       [-0.06641503, -3.19510687]], dtype=float64), Array([[-2.09498036,  2.25193726],
       [-2.95636348,  2.06498388],
       [-2.95636348,  2.06498388],
       [-2.95636348,  2.06498388],
       [-2.95636348,  2.06498388],
       [-2.95636348,  2.06498388],
       [-2.95636348,  2.06498388],
       [-2.09498036,  2.25193726],
       [-2.95636348,  2.06498388],
       [-2.95636348,  2.06498388]], dtype=float64), Array([[-1.67406941,  1.1946259 ],
       [-1.62498524,  0.78680796],
       [-1.62498524,  0.78680796],
       [-1.62498524,  0.78680796],
       [-1.62498524,  0.78680796],
       [-1.62498524,  0.78680796],
       [-1.62498524,  0.78680796],
       [-1.67406941,

In [25]:
# measure = {
#     "gate": povm_measure_operator,
#     "initial_params": [0, jnp.pi / 2],
#     "measurement_flag": True,
#     "param_constrains": [[0, 0.5], [-1, 1]],
# }

# qub_unitary = {
#     "gate": qubit_unitary,
#     "initial_params": [0, jnp.pi / 2],
#     "measurement_flag": False,
#     "param_constrains": [[0, 0.5], [-1, 1]],
# }

# qub_cav = {
#     "gate": qubit_cavity_unitary,
#     "initial_params": [0, jnp.pi / 2],
#     "measurement_flag": False,
#     "param_constrains": [[0, 0.5], [-1, 1]],
# }

# system_params = [measure, qub_unitary, qub_cav]


# result = optimize_pulse_with_feedback(
#     U_0=rho0,
#     C_target=rho_target,
#     system_params=system_params,
#     num_time_steps=num_time_steps,
#     mode="lookup",
#     goal="fidelity",
#     max_iter=num_of_iterations,
#     convergence_threshold=1e-6,
#     learning_rate=learning_rate,
#     type="density",
#     batch_size=10,
# )

In [26]:
initial_params

{'POVM': [-0.5123490775685872, -1.782568231202715],
 'U_q': [-0.5123490775685872, -1.782568231202715],
 'U_qc': [-0.5123490775685872, -1.782568231202715]}

In [27]:
result = optimize_pulse_with_feedback(
    U_0=rho0,
    C_target=rho_target,
    num_time_steps=num_time_steps,
    mode="lookup",
    goal="fidelity",
    max_iter=num_of_iterations,
    convergence_threshold=1e-6,
    learning_rate=learning_rate,
    type="density",
    batch_size=10,
)

TypeError: optimize_pulse_with_feedback() missing 2 required positional arguments: 'parameterized_gates' and 'initial_params'

In [None]:
measure = {
    "gate": povm_measure_operator,
    "initial_params": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "measurement_flag": True,
    "param_constrains": [[0, 0.5], [-1, 1]],
}

qub_unitary = {
    "gate": qubit_unitary,
    "initial_params": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "measurement_flag": False,
    "param_constrains": [[0, 0.5], [-1, 1]],
}

qub_cav = {
    "gate": qubit_cavity_unitary,
    "initial_params": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "measurement_flag": False,
    "param_constrains": [[0, 0.5], [-1, 1]],
}

system_params = [measure, qub_unitary, qub_cav]

In [None]:
initial_params = {
    "POVM": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "U_q": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "U_qc": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
}

parameterized_gates=[
    povm_measure_operator,
    qubit_unitary,
    qubit_cavity_unitary,
],
measurement_indices=[0]

In [None]:
def convert_system_params(system_params):
    """
    Convert system_params format to (initial_params, parameterized_gates, measurement_indices) format.
    
    Args:
        system_params: List of dictionaries, each containing:
            - "gate": gate function
            - "initial_params": list of parameters
            - "measurement_flag": boolean indicating if this is a measurement gate
    
    Returns:
        tuple: (initial_params, parameterized_gates, measurement_indices)
            - initial_params: dict mapping gate names/types to parameter lists
            - parameterized_gates: list of gate functions
            - measurement_indices: list of indices where measurement gates appear
    """
    initial_params = {}
    parameterized_gates = []
    measurement_indices = []
    param_constrains = []
    
    for i, gate_config in enumerate(system_params):
        gate_func = gate_config["gate"]
        params = gate_config["initial_params"]
        is_measurement = gate_config["measurement_flag"]
        
        # Add gate to parameterized_gates list
        parameterized_gates.append(gate_func)
        
        # If this is a measurement gate, add its index
        if is_measurement:
            measurement_indices.append(i)
        
        param_name = f"gate_{i}"
        
        initial_params[param_name] = params

        # Add parameter constraints if provided
        # TODO: make sure to have default
        param_constrains.append(gate_config.get("param_constrains", None))
    
    return initial_params, parameterized_gates, measurement_indices, param_constrains


# Alternative version if you want to use gate function references directly
def convert_system_params_with_mapping(system_params, gate_name_map=None):
    """
    Convert system_params with explicit gate name mapping.
    
    Args:
        system_params: List of gate configurations
        gate_name_map: Optional dict mapping gate functions to parameter names
    
    Returns:
        tuple: (initial_params, parameterized_gates, measurement_indices)
    """
    if gate_name_map is None:
        gate_name_map = {}
    
    initial_params = {}
    parameterized_gates = []
    measurement_indices = []
    
    for i, gate_config in enumerate(system_params):
        gate_func = gate_config["gate"]
        params = gate_config["initial_params"]
        is_measurement = gate_config["measurement_flag"]
        
        parameterized_gates.append(gate_func)
        
        if is_measurement:
            measurement_indices.append(i)
        
        # Use provided mapping or fallback to function name
        param_name = gate_name_map.get(gate_func, gate_func.__name__)
        initial_params[param_name] = params
    
    return initial_params, parameterized_gates, measurement_indices

In [None]:
convert_system_params(system_params)

({'gate_0': [-1.6608348227001049, -0.6459624348505918],
  'gate_1': [-1.6608348227001049, -0.6459624348505918],
  'gate_2': [-1.6608348227001049, -0.6459624348505918]},
 [<function __main__.povm_measure_operator(measurement_outcome, gamma, delta)>,
  <function __main__.qubit_unitary(alpha_re, alpha_im)>,
  <function __main__.qubit_cavity_unitary(beta_re, beta_im)>],
 [0])