In [1]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

## defining parameterized operations that are repeated num_time_steps times

In [2]:
# # N = 100 with
# the following parameters: took 2 hours and output an average fidelity of 0.6350644364724286
# N_cav = 20
# import jax

# num_time_steps = 5
# num_of_iterations = 1000
# learning_rate = 0.05
# # avg_photon_numer = 2 When testing kitten state

# key = jax.random.PRNGKey(88)
# initial_params = {
#     "POVM": jax.random.uniform(key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi)[0].tolist(),
#     "U_q": jax.random.uniform(key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi)[0].tolist(),
#     "U_qc": jax.random.uniform(key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi)[0].tolist(),
# }


# result = optimize_pulse_with_feedback(
#     U_0=rho0,
#     C_target=rho_target,
#     parameterized_gates=[
#         povm_measure_operator,
#         qubit_unitary,
#         qubit_cavity_unitary,
#     ],
#     measurement_indices=[0],
#     initial_params=initial_params,
#     num_time_steps=num_time_steps,
#     mode="lookup",
#     lookup_min_init_value=-jnp.pi,
#     lookup_max_init_value=jnp.pi,
#     goal="fidelity",
#     optimizer="adam",
#     max_iter=num_of_iterations,
#     convergence_threshold=1e-6,
#     learning_rate=learning_rate,
#     type="density",
#     batch_size=2,
#     RNN=RNN
# )

In [3]:
N_cav = 20

In [4]:
def qubit_unitary(alpha_re, alpha_im):
    alpha = alpha_re + 1j * alpha_im
    return tensor(
        identity(N_cav),
        expm(-1j * (alpha * sigmap() + alpha.conjugate() * sigmam()) / 2),
    )

In [5]:
def qubit_cavity_unitary(beta_re, beta_im):
    beta = beta_re + 1j * beta_im
    return expm(
        -1j
        * (
            beta * (tensor(destroy(N_cav), sigmap()))
            + beta.conjugate() * (tensor(create(N_cav), sigmam()))
        )
        / 2
    )

In [6]:
Uq = qubit_unitary(0.1, 0.1)
Uqc = qubit_cavity_unitary(0.1, 0.1)
print(
    "Uq unitary check:",
    jnp.allclose(Uq.conj().T @ Uq, jnp.eye(Uq.shape[0]), atol=1e-7),
)
print(
    "Uqc unitary check:",
    jnp.allclose(Uqc.conj().T @ Uqc, jnp.eye(Uqc.shape[0]), atol=1e-7),
)

Uq unitary check: True
Uqc unitary check: True


In [7]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

### defining initial (thermal) state

In [8]:
# initial state is a thermal state coupled to a qubit in the ground state?
n_average = 1
# natural logarithm
beta = jnp.log((1 / n_average) + 1)
diags = jnp.exp(-beta * jnp.arange(N_cav))
normalized_diags = diags / jnp.sum(diags, axis=0)
rho_cav = jnp.diag(normalized_diags)

In [9]:
rho_cav.shape

(20, 20)

In [10]:
rho0 = tensor(rho_cav, basis(2, 0) @ basis(2, 0).conj().T)

In [11]:
from feedback_grape.utils.povm import (
    _probability_of_a_measurement_outcome_given_a_certain_state,
)

_probability_of_a_measurement_outcome_given_a_certain_state(
    rho0, 1, povm_measure_operator, [0.1, -3 * jnp.pi / 2]
)

Array(0.9479526, dtype=float64)

### defining target state

In [12]:
psi_target = tensor(
    (fock(N_cav, 1) + fock(N_cav, 2) + fock(N_cav, 3)) / jnp.sqrt(3), basis(2)
)
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = psi_target @ psi_target.conj().T
rho_target.shape

(40, 40)

In [13]:
from feedback_grape.utils.fidelity import fidelity

print(fidelity(U_final=rho0, C_target=rho_target, type="density"))

0.3818814900083147


### initialize random params

In [15]:
import jax

print(
    jax.random.uniform(
        jax.random.PRNGKey(0),
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jax.numpy.pi,
        maxval=jax.numpy.pi,
    ).tolist()
)

[[-0.5123490775685872, -1.782568231202715]]


In [None]:
import jax

num_time_steps = 5
num_of_iterations = 1000
learning_rate = 0.05
# avg_photon_numer = 2 When testing kitten state

key = jax.random.PRNGKey(0)
initial_params = {
    "POVM": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "U_q": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
    "U_qc": jax.random.uniform(
        key, shape=(1, 2), minval=-jnp.pi, maxval=jnp.pi
    )[0].tolist(),
}


result = optimize_pulse_with_feedback(
    U_0=rho0,
    C_target=rho_target,
    parameterized_gates=[
        povm_measure_operator,
        qubit_unitary,
        qubit_cavity_unitary,
    ],
    measurement_indices=[0],
    initial_params=initial_params,
    num_time_steps=num_time_steps,
    mode="lookup",
    lookup_min_init_value=-jnp.pi,
    lookup_max_init_value=jnp.pi,
    goal="fidelity",
    max_iter=num_of_iterations,
    convergence_threshold=1e-6,
    learning_rate=learning_rate,
    type="density",
    batch_size=1,
)

Iteration 0, Loss: 0.205495
Iteration 10, Loss: 0.199460
Iteration 20, Loss: 0.382446
Iteration 30, Loss: 0.468029
Iteration 40, Loss: 0.570421
Iteration 50, Loss: 0.229321
Iteration 60, Loss: 0.211824
Iteration 70, Loss: 0.805874
Iteration 80, Loss: 0.974972
Iteration 90, Loss: 0.799284
Iteration 100, Loss: 0.879406
Iteration 110, Loss: 0.075007
Iteration 120, Loss: 1.042973
Iteration 130, Loss: 0.641436
Iteration 140, Loss: 0.635126
Iteration 150, Loss: 0.266109
Iteration 160, Loss: 0.552036
Iteration 170, Loss: 0.106394
Iteration 180, Loss: 0.817852
Iteration 190, Loss: 0.888952
Iteration 200, Loss: 0.123760
Iteration 210, Loss: 0.851424
Iteration 220, Loss: 0.634080
Iteration 230, Loss: 0.749260
Iteration 240, Loss: 0.313148
Iteration 250, Loss: 0.312770
Iteration 260, Loss: 0.620389
Iteration 270, Loss: 0.138180
Iteration 280, Loss: 0.634205
Iteration 290, Loss: 0.923655
Iteration 300, Loss: 0.920671
Iteration 310, Loss: 0.799627
Iteration 320, Loss: 0.520353
Iteration 330, Loss: 

In [32]:
result

FgResult(optimized_trainable_parameters={'initial_params': [Array([-0.55868141, -1.76949085], dtype=float64), Array([-0.51234908, -1.78256823], dtype=float64), Array([-0.51234908, -1.78256823], dtype=float64)], 'lookup_table': [[Array([-0.16279808, -2.7677011 ,  1.66810334, -1.95091371, -0.7356064 ,
       -0.51728184], dtype=float64), Array([-1.53844188e-03, -2.82658451e+00,  1.73141307e+00, -2.00858101e+00,
       -5.97821713e-01, -5.57022033e-01], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0., 0., 0.], dtype=float64), A

In [33]:
print(result.final_purity)

None


In [34]:
print(result.final_fidelity)

0.3485331287499053


In [25]:
result.optimized_trainable_parameters['lookup_table']

[[Array([1.44752688, 0.11856753, 2.38747835, 0.56396388, 1.05940005,
         1.42821844], dtype=float64),
  Array([1.37077901, 0.21669131, 2.34303692, 0.42056495, 1.05932891,
         1.34995575], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dt

In [26]:
def convert_to_index(measurement_history):
    # Convert measurement history from [1, -1, ...] to [0, 1, ...] and then to an integer index
    binary_history = jnp.where(jnp.array(measurement_history) == 1, 0, 1)
    # Convert binary list to integer index (e.g., [0,1] -> 1)
    reversed_binary = binary_history[::-1]
    int_index = jnp.sum(
        (2 ** jnp.arange(len(reversed_binary))) * reversed_binary
    )
    return int_index

In [27]:
from numpy import int64


x = [
    jnp.array([-1]),
    jnp.array([1]),
    jnp.array([1]),
    jnp.array([1]),
    jnp.array([1]),
]

In [28]:
len(x)

5

In [29]:
convert_to_index(x)

Array(31, dtype=int64)

In [30]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho0, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 0.3818814900083147
fidelity of state 0: 0.36170891069239874
fidelity of state 1: 0.5782315688243697
fidelity of state 2: 0.3669477699239496
fidelity of state 3: 0.3146258696896755
fidelity of state 4: 0.3669477699239496
fidelity of state 5: 0.4418651907277243
fidelity of state 6: 0.38593667040591645
fidelity of state 7: 0.33350235551806784
fidelity of state 8: 0.3655008258099386
fidelity of state 9: 0.41104797540719173


In [None]:
result.final_state

Array([[[ 3.64418841e-02-3.46944695e-18j,
          1.74921366e-02+1.24900952e-02j,
         -7.18528016e-02+1.20187242e-02j, ...,
         -3.80157528e-03+9.90472387e-04j,
         -1.89460388e-03-2.76496567e-03j,
         -1.87411074e-03+3.11066823e-03j],
        [ 1.74921366e-02-1.24900952e-02j,
          3.85322526e-02-1.73472348e-18j,
         -5.24559906e-02+7.06732227e-02j, ...,
         -3.28074748e-03+6.32578994e-03j,
         -2.18411317e-03+6.29395784e-03j,
          5.96391197e-04+1.15863114e-03j],
        [-7.18528016e-02-1.20187242e-02j,
         -5.24559906e-02-7.06732227e-02j,
          2.47389629e-01-2.08166817e-17j, ...,
          1.69252539e-02-1.12540224e-03j,
          1.39388326e-02+2.46330184e-03j,
          3.10194486e-03-5.73239314e-03j],
        ...,
        [-3.80157528e-03-9.90472387e-04j,
         -3.28074748e-03-6.32578994e-03j,
          1.69252539e-02+1.12540224e-03j, ...,
          2.57726611e-03+5.14996032e-19j,
          1.92497820e-03-6.34154675e-04j

In [None]:
result.optimized_trainable_parameters['lookup_table']

[[Array([ 8.25764925e-17, -2.82743339e+00,  1.05567664e+00, -8.65645033e-01,
         -1.10358073e+00, -1.13445749e-01], dtype=float64),
  Array([ 1.13790776e-04, -2.67070712e+00,  2.37190809e+00, -1.39489586e+00,
         -7.02793564e-01, -3.64905402e-01], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0., 0

In [None]:
result.returned_params

[[Array([[-0.75556801, -0.30392829],
         [-0.75556801, -0.30392829],
         [-0.75556801, -0.30392829],
         [-0.75556801, -0.30392829],
         [-0.75556801, -0.30392829],
         [-0.75556801, -0.30392829],
         [-0.75556801, -0.30392829],
         [-0.75556801, -0.30392829],
         [-0.75556801, -0.30392829],
         [-0.75556801, -0.30392829]], dtype=float64),
  Array([[ 2.37190809, -1.39489586],
         [ 1.05567664, -0.86564503],
         [ 1.05567664, -0.86564503],
         [ 1.05567664, -0.86564503],
         [ 1.05567664, -0.86564503],
         [ 1.05567664, -0.86564503],
         [ 1.05567664, -0.86564503],
         [ 2.37190809, -1.39489586],
         [ 1.05567664, -0.86564503],
         [ 1.05567664, -0.86564503]], dtype=float64),
  Array([[-0.70279356, -0.3649054 ],
         [-1.10358073, -0.11344575],
         [-1.10358073, -0.11344575],
         [-1.10358073, -0.11344575],
         [-1.10358073, -0.11344575],
         [-1.10358073, -0.11344575],
    

In [None]:
# from feedback_grape.utils.povm import povm
# from feedback_grape.fgrape import apply_gate
# from feedback_grape.utils.purity import purity
# import jax

# time_steps = 5

# rho = rho0
# print(
#     "initial purity:",
#     fidelity(U_final=rho, C_target=rho_target, type="density"),
# )
# time_step_keys = jax.random.split(jax.random.PRNGKey(2), time_steps)
# print("time step keys:", time_step_keys.shape)
# for i in range(time_steps):
#     params = result.returned_params[i][0]
#     rho, _, _ = povm(rho, povm_measure_operator, params[0], time_step_keys[i])
#     rho = apply_gate(
#         rho,
#         qubit_unitary,
#         params[1],
#         type="density",
#     )
#     rho = apply_gate(
#         rho,
#         qubit_cavity_unitary,
#         params[2],
#         type="density",
#     )
#     print(
#         f"fid of rho after time step {i}",
#         fidelity(U_final=rho, C_target=rho_target, type="density"),
#     )
# final_rho_cav = rho

In [None]:
print(result.iterations)

1000


In [None]:
print(result.returned_params[1])

[Array([[ 1.13790776e-04, -2.67070712e+00],
       [ 8.25764925e-17, -2.82743339e+00],
       [ 8.25764925e-17, -2.82743339e+00],
       [ 8.25764925e-17, -2.82743339e+00],
       [ 8.25764925e-17, -2.82743339e+00],
       [ 8.25764925e-17, -2.82743339e+00],
       [ 8.25764925e-17, -2.82743339e+00],
       [ 1.13790776e-04, -2.67070712e+00],
       [ 8.25764925e-17, -2.82743339e+00],
       [ 8.25764925e-17, -2.82743339e+00]], dtype=float64), Array([[-3.64897638,  3.38182758],
       [-1.82079981,  2.09023662],
       [-1.82079981,  2.09023662],
       [-1.82079981,  2.09023662],
       [-1.82079981,  2.09023662],
       [-1.82079981,  2.09023662],
       [-1.82079981,  2.09023662],
       [-3.64897638,  3.38182758],
       [-1.82079981,  2.09023662],
       [-1.82079981,  2.09023662]], dtype=float64), Array([[-2.24983481,  0.69887889],
       [-1.6713748 ,  0.91550651],
       [-1.6713748 ,  0.91550651],
       [-1.6713748 ,  0.91550651],
       [-1.6713748 ,  0.91550651],
       [-1

## Code Tests 

In [None]:
import jax

F = []
F_1_0 = jnp.array([jnp.pi / 2, 0])
key = jax.random.PRNGKey(0)
appo = jax.random.uniform(key, shape=(2**1, 4)) * jnp.pi
F.append(appo)
F

[Array([[1.31462179, 0.67951221, 3.03264681, 1.80484666],
        [1.67203883, 1.11496751, 2.77405451, 1.98724708]], dtype=float64)]

In [None]:
variables = [F] + [F_1_0]

In [None]:
variables

[[Array([[1.31462179, 0.67951221, 3.03264681, 1.80484666],
         [1.67203883, 1.11496751, 2.77405451, 1.98724708]], dtype=float64)],
 Array([1.57079633, 0.        ], dtype=float64)]

In [None]:
def prepare_parameters_from_dict(params_dict):
    """
    Convert a nested dictionary of parameters to a flat list and record shapes.

    Args:
        params_dict: Nested dictionary of parameters.

    Returns:
        tuple: Flattened parameters list and list of shapes.
    """
    flat_params = []
    param_shapes = []

    # returns a flat list of the leaves
    def flatten_dict(d):
        result = []
        for key, value in d.items():
            if isinstance(value, dict):
                result.extend(flatten_dict(value))
            else:
                result.append(value)
        return result

    # flatten each top-level gate
    for gate_name, gate_params in params_dict.items():
        if isinstance(gate_params, dict):
            # Extract parameters for this gate
            gate_flat_params = jnp.array(flatten_dict(gate_params))
        else:
            # If already a flat array
            gate_flat_params = jnp.array(gate_params)
        # this is checking if use can enter sth like {'gate1': 1} instead of {"gate1": {"param1": 1}}
        # if not (isinstance(gate_flat_params, list)):
        #     flat_params.append([gate_flat_params])
        #     param_shapes.append(1)
        # else:
        #     flat_params.append(gate_flat_params)
        #     param_shapes.append(len(gate_flat_params))
        flat_params.append(gate_flat_params)
        param_shapes.append(gate_flat_params.shape)
    return flat_params, param_shapes

In [None]:
import numpy as np


def reshape_params(param_shapes, rnn_flattened_params):
    """
    Reshape the parameters for the gates.
    """
    # Reshape the flattened parameters from RNN output according
    # to each gate corressponding params
    reshaped_params = []
    param_idx = 0
    for shape in param_shapes:
        num_params = int(np.prod(shape))
        # rnn outputs a flat list, this takes each and assigns according to the shape
        gate_params = rnn_flattened_params[
            param_idx : param_idx + num_params
        ].reshape(shape)
        reshaped_params.append(gate_params)
        param_idx += num_params

    new_params = reshaped_params
    return new_params

In [None]:
import os

os.sys.path.append("..")
import jax
import jax.numpy as jnp
from feedback_grape.fgrape import prepare_parameters_from_dict, reshape_params

initial_params = {
    "POVM": [jnp.pi / 3, jnp.pi / 3],
    "U_q": [jnp.pi / 3],
    "U_qc": [jnp.pi / 3],
}

flat_params, param_shapes = prepare_parameters_from_dict(initial_params)
num_of_columns = 4
num_of_sub_lists = 3
F = []


def construct_ragged_row(num_of_rows):
    res = []
    for i in range(num_of_rows):
        flattened = jax.random.uniform(
            jax.random.PRNGKey(0 + i),
            shape=(num_of_columns,),
            minval=-jnp.pi,
            maxval=jnp.pi,
        )
        res.append(flattened)
    return res


for i in range(1, num_of_sub_lists + 1):
    F.append(construct_ragged_row(num_of_rows=2**i))

for i in range(len(F)):
    print("length of F[{}]: {}".format(i, len(F[i])))

print("############")
min_num_of_rows = 2 ** len(F)
for i in range(len(F)):
    if len(F[i]) < min_num_of_rows:
        zeros_arrays = [
            jnp.zeros((num_of_columns,), dtype=jnp.float32)
            for _ in range(min_num_of_rows - len(F[i]))
        ]
        F[i] = F[i] + zeros_arrays

print("############")
from pprint import pprint

for i in range(len(F[0])):
    print("length of F[{}]: {}".format(i, len(F[0][i])))
    pprint(F[0][i])

length of F[0]: 2
length of F[1]: 4
length of F[2]: 8
############
############
length of F[0]: 4
Array([-0.51234908, -1.78256823,  2.92370097,  0.46810066], dtype=float64)
length of F[1]: 4
Array([-2.39923276, -0.29054238,  0.43072842,  2.07843927], dtype=float64)
length of F[2]: 4
Array([0., 0., 0., 0.], dtype=float32)
length of F[3]: 4
Array([0., 0., 0., 0.], dtype=float32)
length of F[4]: 4
Array([0., 0., 0., 0.], dtype=float32)
length of F[5]: 4
Array([0., 0., 0., 0.], dtype=float32)
length of F[6]: 4
Array([0., 0., 0., 0.], dtype=float32)
length of F[7]: 4
Array([0., 0., 0., 0.], dtype=float32)


In [None]:
print(F)

[[Array([-0.51234908, -1.78256823,  2.92370097,  0.46810066], dtype=float64), Array([-2.39923276, -0.29054238,  0.43072842,  2.07843927], dtype=float64), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32)], [Array([-0.51234908, -1.78256823,  2.92370097,  0.46810066], dtype=float64), Array([-2.39923276, -0.29054238,  0.43072842,  2.07843927], dtype=float64), Array([-0.4773859 , -2.20922398,  1.04214091, -0.38718841], dtype=float64), Array([ 2.2958313 , -0.29402536, -1.6428162 , -1.85833117], dtype=float64), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32)], [Array([-0.51234908, -1.78256823,  2.92370097,  0.46810066], dtype=float64), Array([-2.39923276, -0.29054238,  0.43072842,  2.07843927], dt

In [None]:
def convert_to_index(measurement_history):
    # Convert measurement history from [1, -1, ...] to [0, 1, ...] and then to an integer index
    binary_history = jnp.where(jnp.array(measurement_history) == 1, 0, 1)
    print(f"binary_history: {binary_history}")
    # Convert binary list to integer index (e.g., [0,1] -> 1)
    # Reverse the binary_history to operate from last element backwards
    reversed_binary = binary_history[::-1]
    int_index = sum(
        (2**i) * reversed_binary[i] for i in range(len(reversed_binary))
    )
    return int_index


def extract_from_lut(lut, measurement_history):
    """
    Extract parameters from the lookup table based on the measurement history.

    Args:
        lut: Lookup table for parameters.
        measurement_history: History of measurements.
        time_step: Current time step.

    Returns:
        Extracted parameters.
    """
    sub_array_idx = len(measurement_history) - 1
    print(f"sub_array_idx: {sub_array_idx}")
    sub_array_param_idx = convert_to_index(measurement_history)
    print(f"sub_array_param_idx: {sub_array_param_idx}")
    return lut[sub_array_idx][sub_array_param_idx]


extracted_lut_params = extract_from_lut(F, [1, -1])

sub_array_idx: 1
binary_history: [0 1]
sub_array_param_idx: 1


In [None]:
print(extracted_lut_params)

[-2.39923276 -0.29054238  0.43072842  2.07843927]


In [None]:
import tensorflow as tf

print(tf.random.uniform(shape=[1]))

tf.Tensor([0.4571935], shape=(1,), dtype=float32)


In [None]:
import jax

jax.random.uniform(
    jax.random.PRNGKey(0),
    shape=(1,),
)

Array([0.41845711], dtype=float64)

In [None]:
def convert_to_index(measurement_history):
    # Convert measurement history from [1, -1, ...] to [0, 1, ...] and then to an integer index
    binary_history = jnp.where(jnp.array(measurement_history) == 1, 0, 1)
    # Convert binary list to integer index (e.g., [0,1] -> 1)
    reversed_binary = binary_history[::-1]
    int_index = jnp.sum(
        (2 ** jnp.arange(len(reversed_binary))) * reversed_binary
    )
    return int_index

In [None]:
convert_to_index([1, -1, 1, -1, -1])  # Example usage, should return 10 01010

Array(11, dtype=int64)

In [None]:
i = [[[1, 2], [2]]]
i + [jnp.zeros((2,))]  # Example usage, should return [[1], [2], [0.0, 0.0]]

[[[1, 2], [2]], Array([0., 0.], dtype=float64)]

In [None]:
F = [[[[1, 2], [1], [1]]], [[[1, 2], [1], [1]]], [[[1, 2], [1], [1]]]]
for i in range(len(F)):
    if len(F[i]) < 8:
        zeros_arrays = [
            jnp.zeros((4,), dtype=jnp.float64) for _ in range(8 - len(F[i]))
        ]
        F[i] = F[i] + zeros_arrays

In [None]:
from pprint import pprint

pprint("F:")
for i in range(len(F)):
    print("length of F[{}]: {}".format(i, len(F[i])))
    pprint(F[i])

'F:'
length of F[0]: 8
[[[1, 2], [1], [1]],
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64)]
length of F[1]: 8
[[[1, 2], [1], [1]],
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64)]
length of F[2]: 8
[[[1, 2], [1], [1]],
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64)]


In [None]:
from feedback_grape.utils.operators import create, destroy

N_cav = 2
a_dag = create(N_cav)
a = destroy(N_cav)

comm = (a @ a_dag) - (a_dag @ a)

In [None]:
comm

Array([[ 1.+0.j,  0.+0.j],
       [ 0.+0.j, -1.+0.j]], dtype=complex128)