In [1]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
from jax.scipy.linalg import expm

## defining parameterized operations that are repeated num_time_steps times

In [2]:
N_cav = 10

In [3]:
def qubit_unitary(alpha):
    """
    TODO: see if alpha, can be sth elser other than scalar, and if the algo understands this
    see if there can be multiple params like alpha and beta input
    """
    return expm(
        -1j
        * (
            alpha * tensor(identity(N_cav), sigmap())
            + alpha.conjugate() * tensor(identity(N_cav), sigmam())
        )
        / 2
    )

In [4]:
def qubit_cavity_unitary(beta):
    return expm(
        -1j
        * (
            beta
            * (
                tensor(destroy(N_cav), identity(2))
                @ tensor(identity(N_cav), sigmap())
            )
            + beta.conjugate()
            * (
                tensor(create(N_cav), identity(2))
                @ tensor(identity(N_cav), sigmam())
            )
        )
        / 2
    )

In [5]:
alpha = 0.1 + 0.1j
beta = 0.1 + 0.1j
Uq = qubit_unitary(alpha)
Uqc = qubit_cavity_unitary(beta)
print(
    "Uq unitary check:",
    jnp.allclose(Uq.conj().T @ Uq, jnp.eye(Uq.shape[0]), atol=1e-7),
)
print(
    "Uqc unitary check:",
    jnp.allclose(Uqc.conj().T @ Uqc, jnp.eye(Uqc.shape[0]), atol=1e-7),
)

Uq unitary check: True
Uqc unitary check: True


In [6]:
qubit_unitary(0.1)

Array([[0.99875026+0.j        , 0.        -0.04997917j,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ],
       [0.        -0.04997917j, 0.99875026+0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j 

In [7]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

### defining initial (thermal) state

In [8]:
# initial state is a thermal state coupled to a qubit in the ground state?
n_average = 1
# natural logarithm
beta = jnp.log((1 / n_average) + 1)
diags = jnp.exp(-beta * jnp.arange(N_cav))
normalized_diags = diags / jnp.sum(diags, axis=0)
rho_cav = jnp.diag(normalized_diags)

In [9]:
rho_cav.shape

(10, 10)

In [10]:
rho0 = tensor(rho_cav, basis(2, 0) @ basis(2, 0).conj().T)

In [11]:
from feedback_grape.utils.povm import (
    _probability_of_a_measurement_outcome_given_a_certain_state,
)

_probability_of_a_measurement_outcome_given_a_certain_state(
    rho0, 1, povm_measure_operator, [0.1, -3 * jnp.pi / 2]
)

Array(0.95971493, dtype=float64)

### defining target state

In [12]:
psi_target = tensor(
    (fock(N_cav, 1) + fock(N_cav, 2) + fock(N_cav, 3)) / jnp.sqrt(3), basis(2)
)

rho_target = psi_target @ psi_target.conj().T
rho_target.shape

(20, 20)

In [13]:
from feedback_grape.utils.fidelity import fidelity

print(fidelity(U_final=rho0, C_target=rho_target, type="density"))

0.3820679100780015


In [14]:
import flax.linen as nn


class RNN(nn.Module):
    hidden_size: int  # number of features in the hidden state
    output_size: int  # number of features in the output ( 2 in the case of gamma and beta)

    @nn.compact
    def __call__(self, measurement, hidden_state):
        """
        If your GRU has a hidden state increasing number of features in the hidden stateH means:

        - You're allowing the model to store more information across time steps

        - Each time step can represent more complex features, patterns, or dependencies

        - You're giving the GRU more representational capacity
        """
        gru_cell = nn.GRUCell(features=self.hidden_size)
        self.make_rng('dropout')
        if measurement.ndim == 1:
            measurement = measurement.reshape(1, -1)
        new_hidden_state, _ = gru_cell(hidden_state, measurement)
        new_hidden_state = nn.Dropout(rate=0.2, deterministic=False)(
            new_hidden_state
        )
        # this returns the povm_params after linear regression through the hidden state which contains
        # the information of the previous time steps and this is optimized to output best povm_params
        # new_hidden_state = nn.Dense(features=self.hidden_size)(new_hidden_state)
        output = nn.Dense(
            features=self.output_size,
            kernel_init=nn.initializers.glorot_uniform(),
            bias_init=nn.initializers.constant(jnp.pi),
        )(new_hidden_state)
        # output = jnp.asarray(output)
        return output[0], new_hidden_state

### initialize random params

In [15]:
num_time_steps = 5
num_of_iterations = 1000
learning_rate = 0.05
# avg_photon_numer = 2 When testing kitten state


initial_params = {
    "POVM": [jnp.pi / 3, jnp.pi / 3],
    "U_q": [jnp.pi / 3],
    "U_qc": [jnp.pi / 3],
}


result = optimize_pulse_with_feedback(
    U_0=rho0,
    C_target=rho_target,
    parameterized_gates=[
        povm_measure_operator,
        qubit_unitary,
        qubit_cavity_unitary,
    ],
    measurement_indices=[0],
    initial_params=initial_params,
    num_time_steps=num_time_steps,
    mode="lookup",
    goal="fidelity",
    optimizer="adam",
    max_iter=num_of_iterations,
    convergence_threshold=1e-20,
    learning_rate=learning_rate,
    type="density",
    batch_size=10,
    RNN=RNN,
)

Iteration 0, Loss: 0.325266
Iteration 10, Loss: 0.598453
Iteration 20, Loss: 0.856740
Iteration 30, Loss: 0.541849
Iteration 40, Loss: 0.746586
Iteration 50, Loss: 0.716033
Iteration 60, Loss: 0.358708
Iteration 70, Loss: 0.625978
Iteration 80, Loss: 0.832234
Iteration 90, Loss: 0.729699
Iteration 100, Loss: 0.593591
Iteration 110, Loss: 0.764127
Iteration 120, Loss: 0.410453
Iteration 130, Loss: 0.880016
Iteration 140, Loss: 0.472813
Iteration 150, Loss: 0.680386
Iteration 160, Loss: 0.647792
Iteration 170, Loss: 0.711039
Iteration 180, Loss: 0.621183
Iteration 190, Loss: 0.692695
Iteration 200, Loss: 0.638986
Iteration 210, Loss: 0.820091
Iteration 220, Loss: 0.718212
Iteration 230, Loss: 0.953349
Iteration 240, Loss: 0.833595
Iteration 250, Loss: 0.724370
Iteration 260, Loss: 0.926623
Iteration 270, Loss: 0.681148
Iteration 280, Loss: 0.665806
Iteration 290, Loss: 0.478922
Iteration 300, Loss: 0.551558
Iteration 310, Loss: 0.576074
Iteration 320, Loss: 0.558215
Iteration 330, Loss: 

In [16]:
result

FgResult(optimized_trainable_parameters={'initial_params': [Array([1.51454515, 0.05927654], dtype=float64), Array([1.04719755], dtype=float64), Array([1.04719755], dtype=float64)], 'lookup_table': [[Array([ 0.26358866, -0.73890009,  5.23342429,  4.52071003], dtype=float64), Array([-3.14242731,  0.0248276 ,  1.75733912,  1.05560556], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0., 0., 0., 0.], dtype=float64), Array([0.

In [17]:
print(result.final_purity)

None


In [18]:
print(result.final_fidelity)

0.7602333159442032


In [19]:
result.optimized_trainable_parameters['lookup_table']

[[Array([ 0.26358866, -0.73890009,  5.23342429,  4.52071003], dtype=float64),
  Array([-3.14242731,  0.0248276 ,  1.75733912,  1.05560556], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Ar

In [20]:
def convert_to_index(measurement_history):
    # Convert measurement history from [1, -1, ...] to [0, 1, ...] and then to an integer index
    binary_history = jnp.where(jnp.array(measurement_history) == 1, 0, 1)
    # Convert binary list to integer index (e.g., [0,1] -> 1)
    reversed_binary = binary_history[::-1]
    int_index = jnp.sum(
        (2 ** jnp.arange(len(reversed_binary))) * reversed_binary
    )
    return int_index

In [21]:
from numpy import int64


x = [
    jnp.array([-1]),
    jnp.array([1]),
    jnp.array([1]),
    jnp.array([1]),
    jnp.array([1]),
]

In [22]:
len(x)

5

In [23]:
convert_to_index(x)

Array(31, dtype=int64)

In [24]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho0, type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, type="density"),
    )

initial fidelity: 0.3820679100780015
fidelity of state 0: 0.686397141231233
fidelity of state 1: 0.905359981389137
fidelity of state 2: 0.663491871769858
fidelity of state 3: 0.905359981389137
fidelity of state 4: 0.39348466142454636
fidelity of state 5: 0.8647252769973807
fidelity of state 6: 0.905359981389137
fidelity of state 7: 0.686397141231233
fidelity of state 8: 0.905359981389137
fidelity of state 9: 0.686397141231233


In [25]:
result.final_state

Array([[[ 4.96575158e-02-3.90312782e-17j,
          4.45375356e-02+7.20186056e-04j,
         -1.04659652e-01-5.91976587e-02j, ...,
         -4.00418145e-03+1.38095913e-02j,
          1.57677741e-02+6.96547272e-03j,
          1.08628027e-02-5.89695152e-03j],
        [ 4.45375356e-02-7.20186056e-04j,
          6.19217941e-02+5.20417043e-17j,
         -7.56696597e-02-7.94800133e-02j, ...,
         -1.12894131e-02+2.14895349e-02j,
          2.05864370e-02+1.89981223e-02j,
          1.17334461e-02-3.38448639e-04j],
        [-1.04659652e-01+5.91976587e-02j,
         -7.56696597e-02+7.94800133e-02j,
          3.74344868e-01-1.21430643e-17j, ...,
         -3.75904191e-02-3.39703547e-02j,
         -5.13851245e-02+3.48433057e-02j,
         -2.37389124e-02+3.89765563e-02j],
        ...,
        [-4.00418145e-03-1.38095913e-02j,
         -1.12894131e-02-2.14895349e-02j,
         -3.75904191e-02+3.39703547e-02j, ...,
          1.59949524e-02+4.22838847e-18j,
          3.84574450e-03-1.71151740e-02j

In [26]:
result.optimized_trainable_parameters['lookup_table']

[[Array([ 0.26358866, -0.73890009,  5.23342429,  4.52071003], dtype=float64),
  Array([-3.14242731,  0.0248276 ,  1.75733912,  1.05560556], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Array([0., 0., 0., 0.], dtype=float64),
  Ar

In [27]:
result.returned_params

[[Array([[1.51454515, 0.05927654],
         [1.51454515, 0.05927654],
         [1.51454515, 0.05927654],
         [1.51454515, 0.05927654],
         [1.51454515, 0.05927654],
         [1.51454515, 0.05927654],
         [1.51454515, 0.05927654],
         [1.51454515, 0.05927654],
         [1.51454515, 0.05927654],
         [1.51454515, 0.05927654]], dtype=float64),
  Array([[5.23342429],
         [5.23342429],
         [1.75733912],
         [5.23342429],
         [5.23342429],
         [1.75733912],
         [5.23342429],
         [5.23342429],
         [5.23342429],
         [5.23342429]], dtype=float64),
  Array([[4.52071003],
         [4.52071003],
         [1.05560556],
         [4.52071003],
         [4.52071003],
         [1.05560556],
         [4.52071003],
         [4.52071003],
         [4.52071003],
         [4.52071003]], dtype=float64)],
 [Array([[ 0.26358866, -0.73890009],
         [ 0.26358866, -0.73890009],
         [-3.14242731,  0.0248276 ],
         [ 0.26358866, -0.7

In [28]:
# from feedback_grape.utils.povm import povm
# from feedback_grape.fgrape import apply_gate
# from feedback_grape.utils.purity import purity
# import jax

# time_steps = 5

# rho = rho0
# print(
#     "initial purity:",
#     fidelity(U_final=rho, C_target=rho_target, type="density"),
# )
# time_step_keys = jax.random.split(jax.random.PRNGKey(2), time_steps)
# print("time step keys:", time_step_keys.shape)
# for i in range(time_steps):
#     params = result.returned_params[i][0]
#     rho, _, _ = povm(rho, povm_measure_operator, params[0], time_step_keys[i])
#     rho = apply_gate(
#         rho,
#         qubit_unitary,
#         params[1],
#         type="density",
#     )
#     rho = apply_gate(
#         rho,
#         qubit_cavity_unitary,
#         params[2],
#         type="density",
#     )
#     print(
#         f"fid of rho after time step {i}",
#         fidelity(U_final=rho, C_target=rho_target, type="density"),
#     )
# final_rho_cav = rho

In [29]:
print(result.iterations)

1000


In [30]:
print(result.returned_params[1])

[Array([[ 0.26358866, -0.73890009],
       [ 0.26358866, -0.73890009],
       [-3.14242731,  0.0248276 ],
       [ 0.26358866, -0.73890009],
       [ 0.26358866, -0.73890009],
       [-3.14242731,  0.0248276 ],
       [ 0.26358866, -0.73890009],
       [ 0.26358866, -0.73890009],
       [ 0.26358866, -0.73890009],
       [ 0.26358866, -0.73890009]], dtype=float64), Array([[ 3.1206339 ],
       [ 3.1206339 ],
       [ 0.23311092],
       [ 3.1206339 ],
       [-1.29278068],
       [ 0.23311092],
       [ 3.1206339 ],
       [ 3.1206339 ],
       [ 3.1206339 ],
       [ 3.1206339 ]], dtype=float64), Array([[ 1.83781023],
       [ 1.83781023],
       [-1.36127119],
       [ 1.83781023],
       [-0.49512343],
       [-1.36127119],
       [ 1.83781023],
       [ 1.83781023],
       [ 1.83781023],
       [ 1.83781023]], dtype=float64)]


## Code Tests 

In [31]:
import jax

F = []
F_1_0 = jnp.array([jnp.pi / 2, 0])
key = jax.random.PRNGKey(0)
appo = jax.random.uniform(key, shape=(2**1, 4)) * jnp.pi
F.append(appo)
F

[Array([[1.31462179, 0.67951221, 3.03264681, 1.80484666],
        [1.67203883, 1.11496751, 2.77405451, 1.98724708]], dtype=float64)]

In [32]:
variables = [F] + [F_1_0]

In [33]:
variables

[[Array([[1.31462179, 0.67951221, 3.03264681, 1.80484666],
         [1.67203883, 1.11496751, 2.77405451, 1.98724708]], dtype=float64)],
 Array([1.57079633, 0.        ], dtype=float64)]

In [34]:
def prepare_parameters_from_dict(params_dict):
    """
    Convert a nested dictionary of parameters to a flat list and record shapes.

    Args:
        params_dict: Nested dictionary of parameters.

    Returns:
        tuple: Flattened parameters list and list of shapes.
    """
    flat_params = []
    param_shapes = []

    # returns a flat list of the leaves
    def flatten_dict(d):
        result = []
        for key, value in d.items():
            if isinstance(value, dict):
                result.extend(flatten_dict(value))
            else:
                result.append(value)
        return result

    # flatten each top-level gate
    for gate_name, gate_params in params_dict.items():
        if isinstance(gate_params, dict):
            # Extract parameters for this gate
            gate_flat_params = jnp.array(flatten_dict(gate_params))
        else:
            # If already a flat array
            gate_flat_params = jnp.array(gate_params)
        # this is checking if use can enter sth like {'gate1': 1} instead of {"gate1": {"param1": 1}}
        # if not (isinstance(gate_flat_params, list)):
        #     flat_params.append([gate_flat_params])
        #     param_shapes.append(1)
        # else:
        #     flat_params.append(gate_flat_params)
        #     param_shapes.append(len(gate_flat_params))
        flat_params.append(gate_flat_params)
        param_shapes.append(gate_flat_params.shape)
    return flat_params, param_shapes

In [35]:
import numpy as np


def reshape_params(param_shapes, rnn_flattened_params):
    """
    Reshape the parameters for the gates.
    """
    # Reshape the flattened parameters from RNN output according
    # to each gate corressponding params
    reshaped_params = []
    param_idx = 0
    for shape in param_shapes:
        num_params = int(np.prod(shape))
        # rnn outputs a flat list, this takes each and assigns according to the shape
        gate_params = rnn_flattened_params[
            param_idx : param_idx + num_params
        ].reshape(shape)
        reshaped_params.append(gate_params)
        param_idx += num_params

    new_params = reshaped_params
    return new_params

In [36]:
import os

os.sys.path.append("..")
import jax
import jax.numpy as jnp
from feedback_grape.fgrape import prepare_parameters_from_dict, reshape_params

initial_params = {
    "POVM": [jnp.pi / 3, jnp.pi / 3],
    "U_q": [jnp.pi / 3],
    "U_qc": [jnp.pi / 3],
}

flat_params, param_shapes = prepare_parameters_from_dict(initial_params)
num_of_columns = 4
num_of_sub_lists = 3
F = []


def construct_ragged_row(num_of_rows):
    res = []
    for i in range(num_of_rows):
        flattened = jax.random.uniform(
            jax.random.PRNGKey(0 + i),
            shape=(num_of_columns,),
            minval=-jnp.pi,
            maxval=jnp.pi,
        )
        res.append(flattened)
    return res


for i in range(1, num_of_sub_lists + 1):
    F.append(construct_ragged_row(num_of_rows=2**i))

for i in range(len(F)):
    print("length of F[{}]: {}".format(i, len(F[i])))

print("############")
min_num_of_rows = 2 ** len(F)
for i in range(len(F)):
    if len(F[i]) < min_num_of_rows:
        zeros_arrays = [
            jnp.zeros((num_of_columns,), dtype=jnp.float32)
            for _ in range(min_num_of_rows - len(F[i]))
        ]
        F[i] = F[i] + zeros_arrays

print("############")
from pprint import pprint

for i in range(len(F[0])):
    print("length of F[{}]: {}".format(i, len(F[0][i])))
    pprint(F[0][i])

length of F[0]: 2
length of F[1]: 4
length of F[2]: 8
############
############
length of F[0]: 4
Array([-0.51234908, -1.78256823,  2.92370097,  0.46810066], dtype=float64)
length of F[1]: 4
Array([-2.39923276, -0.29054238,  0.43072842,  2.07843927], dtype=float64)
length of F[2]: 4
Array([0., 0., 0., 0.], dtype=float32)
length of F[3]: 4
Array([0., 0., 0., 0.], dtype=float32)
length of F[4]: 4
Array([0., 0., 0., 0.], dtype=float32)
length of F[5]: 4
Array([0., 0., 0., 0.], dtype=float32)
length of F[6]: 4
Array([0., 0., 0., 0.], dtype=float32)
length of F[7]: 4
Array([0., 0., 0., 0.], dtype=float32)


In [37]:
print(F)

[[Array([-0.51234908, -1.78256823,  2.92370097,  0.46810066], dtype=float64), Array([-2.39923276, -0.29054238,  0.43072842,  2.07843927], dtype=float64), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32)], [Array([-0.51234908, -1.78256823,  2.92370097,  0.46810066], dtype=float64), Array([-2.39923276, -0.29054238,  0.43072842,  2.07843927], dtype=float64), Array([-0.4773859 , -2.20922398,  1.04214091, -0.38718841], dtype=float64), Array([ 2.2958313 , -0.29402536, -1.6428162 , -1.85833117], dtype=float64), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32), Array([0., 0., 0., 0.], dtype=float32)], [Array([-0.51234908, -1.78256823,  2.92370097,  0.46810066], dtype=float64), Array([-2.39923276, -0.29054238,  0.43072842,  2.07843927], dt

In [38]:
def convert_to_index(measurement_history):
    # Convert measurement history from [1, -1, ...] to [0, 1, ...] and then to an integer index
    binary_history = jnp.where(jnp.array(measurement_history) == 1, 0, 1)
    print(f"binary_history: {binary_history}")
    # Convert binary list to integer index (e.g., [0,1] -> 1)
    # Reverse the binary_history to operate from last element backwards
    reversed_binary = binary_history[::-1]
    int_index = sum(
        (2**i) * reversed_binary[i] for i in range(len(reversed_binary))
    )
    return int_index


def extract_from_lut(lut, measurement_history):
    """
    Extract parameters from the lookup table based on the measurement history.

    Args:
        lut: Lookup table for parameters.
        measurement_history: History of measurements.
        time_step: Current time step.

    Returns:
        Extracted parameters.
    """
    sub_array_idx = len(measurement_history) - 1
    print(f"sub_array_idx: {sub_array_idx}")
    sub_array_param_idx = convert_to_index(measurement_history)
    print(f"sub_array_param_idx: {sub_array_param_idx}")
    return lut[sub_array_idx][sub_array_param_idx]


extracted_lut_params = extract_from_lut(F, [1, -1])

sub_array_idx: 1
binary_history: [0 1]
sub_array_param_idx: 1


In [39]:
print(extracted_lut_params)

[-2.39923276 -0.29054238  0.43072842  2.07843927]


In [40]:
import tensorflow as tf

print(tf.random.uniform(shape=[1]))

tf.Tensor([0.39600933], shape=(1,), dtype=float32)


In [41]:
import jax

jax.random.uniform(
    jax.random.PRNGKey(0),
    shape=(1,),
)

Array([0.41845711], dtype=float64)

In [42]:
def convert_to_index(measurement_history):
    # Convert measurement history from [1, -1, ...] to [0, 1, ...] and then to an integer index
    binary_history = jnp.where(jnp.array(measurement_history) == 1, 0, 1)
    # Convert binary list to integer index (e.g., [0,1] -> 1)
    reversed_binary = binary_history[::-1]
    int_index = jnp.sum(
        (2 ** jnp.arange(len(reversed_binary))) * reversed_binary
    )
    return int_index

In [43]:
convert_to_index([1, -1, 1, -1, -1])  # Example usage, should return 10 01010

Array(11, dtype=int64)

In [44]:
i = [[[1, 2], [2]]]
i + [jnp.zeros((2,))]  # Example usage, should return [[1], [2], [0.0, 0.0]]

[[[1, 2], [2]], Array([0., 0.], dtype=float64)]

In [45]:
F = [[[[1, 2], [1], [1]]], [[[1, 2], [1], [1]]], [[[1, 2], [1], [1]]]]
for i in range(len(F)):
    if len(F[i]) < 8:
        zeros_arrays = [
            jnp.zeros((4,), dtype=jnp.float64) for _ in range(8 - len(F[i]))
        ]
        F[i] = F[i] + zeros_arrays

In [46]:
from pprint import pprint

pprint("F:")
for i in range(len(F)):
    print("length of F[{}]: {}".format(i, len(F[i])))
    pprint(F[i])

'F:'
length of F[0]: 8
[[[1, 2], [1], [1]],
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64)]
length of F[1]: 8
[[[1, 2], [1], [1]],
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64)]
length of F[2]: 8
[[[1, 2], [1], [1]],
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64),
 Array([0., 0., 0., 0.], dtype=float64)]
