In [23]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
from jax.scipy.linalg import expm

## defining parameterized operations that are repeated num_time_steps times

In [24]:
N_cav = 5

In [25]:
def qubit_unitary(alpha):
    """
    TODO: see if alpha, can be sth elser other than scalar, and if the algo understands this
    see if there can be multiple params like alpha and beta input
    """
    return expm(
        -1j
        * (
            alpha * tensor(identity(N_cav), sigmap())
            + alpha.conjugate() * tensor(identity(N_cav), sigmam())
        )
        / 2
    )

In [26]:
def qubit_cavity_unitary(beta):
    return expm(
        -1j
        * (
            beta
            * (
                tensor(destroy(N_cav), identity(2))
                @ tensor(identity(N_cav), sigmap())
            )
            + beta.conjugate()
            * (
                tensor(create(N_cav), identity(2))
                @ tensor(identity(N_cav), sigmam())
            )
        )
        / 2
    )

In [27]:
alpha = 0.1 + 0.1j
beta = 0.1 + 0.1j
Uq = qubit_unitary(alpha)
Uqc = qubit_cavity_unitary(beta)
print(
    "Uq unitary check:",
    jnp.allclose(Uq.conj().T @ Uq, jnp.eye(Uq.shape[0]), atol=1e-7),
)
print(
    "Uqc unitary check:",
    jnp.allclose(Uqc.conj().T @ Uqc, jnp.eye(Uqc.shape[0]), atol=1e-7),
)

Uq unitary check: True
Uqc unitary check: True


In [28]:
qubit_unitary(0.1)

Array([[0.99875027+0.j        , 0.        -0.04997917j,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ],
       [0.        -0.04997917j, 0.99875027+0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ],
       [0.        +0.j        , 0.        +0.j        ,
        0.99875027+0.j        , 0.        -0.04997917j,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ],
       [0.        +0.j        , 0.        +0.j        ,
        0.        -0.04997917j, 0.99875027+0.j        ,
        0.        +0.j        , 0.        +0.

In [29]:
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, gamma, delta):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

### defining initial (thermal) state

In [30]:
# initial state is a thermal state coupled to a qubit in the ground state?
n_average = 1
# natural logarithm
beta = jnp.log((1 / n_average) + 1)
diags = jnp.exp(-beta * jnp.arange(N_cav))
normalized_diags = diags / jnp.sum(diags, axis=0)
rho_cav = jnp.diag(normalized_diags)

In [31]:
rho_cav.shape

(5, 5)

In [32]:
rho0 = tensor(rho_cav, basis(2, 0) @ basis(2, 0).conj().T)

In [33]:
from feedback_grape.fgrape import _probability_of_a_measurement_outcome_given_a_certain_state
_probability_of_a_measurement_outcome_given_a_certain_state(rho0, 1, povm_measure_operator, [0.1, -3*jnp.pi / 2])

Array(0.88501454, dtype=float64)

### defining target state

In [34]:
psi_target = tensor((fock(N_cav, 1) + fock(N_cav, 2) + fock(N_cav, 3)) / jnp.sqrt(3), basis(2))

rho_target = psi_target @ psi_target.conj().T
rho_target.shape

(10, 10)

In [35]:
from feedback_grape.utils.fidelity import fidelity
print(fidelity(U_final=rho0, C_target=rho_target, type="density"))

0.3880207144091314


### initialize random params

In [None]:
num_time_steps = 5
num_of_iterations = 1000
learning_rate = 0.05
# avg_photon_numer = 2 When testing kitten state
import jax
import numpy as np

key = jax.random.PRNGKey(0)
two_pi = 2 * np.pi

def random_params(key):
    subkeys = jax.random.split(key, 4)
    return {
        "POVM": {
            "gamma": float(jax.random.uniform(subkeys[0], (), minval=-two_pi, maxval=two_pi)),
            "delta": float(jax.random.uniform(subkeys[1], (), minval=-two_pi, maxval=two_pi)),
        },
        "U_q": {
            "alpha": float(jax.random.uniform(subkeys[2], (), minval=-two_pi, maxval=two_pi)),
        },
        "U_qc": {
            "beta": float(jax.random.uniform(subkeys[3], (), minval=-two_pi, maxval=two_pi)),
        }
    }

result = None
while True:
    key, subkey = jax.random.split(key)
    initial_params = random_params(subkey)
    result = optimize_pulse_with_feedback(
        U_0=rho0,
        C_target=rho_target,
        parameterized_gates=[
            povm_measure_operator,
            qubit_unitary,
            qubit_cavity_unitary,
        ],
        measurement_indices=[0],
        initial_params=initial_params,
        num_time_steps=num_time_steps,
        mode="nn",
        goal="fidelity",
        optimizer="adam",
        max_iter=num_of_iterations,
        convergence_threshold=1e-20,
        learning_rate=learning_rate,
        type="density",
    )
    if 1 - result.final_fidelity <= 0.1:
        print("Final fidelity:", result.final_fidelity)
        print("Final params:", result.arr_of_povm_params)
        break

In [None]:
result

FgResult(optimized_rnn_parameters={'params': {'Dense_0': {'bias': Array([-0.02477128, -0.03159048,  0.09968101,  0.4406398 ], dtype=float32), 'kernel': Array([[-2.90966868e-01,  1.34036928e-01, -4.69395667e-01,
         1.22098148e-01],
       [ 7.39660114e-02,  1.36926293e-01,  3.61723781e-01,
        -3.82151127e-01],
       [-5.01855165e-02,  3.89884785e-02,  5.97730637e-01,
         1.59703195e+00],
       [-9.29817930e-02,  1.32404715e-02,  8.12745333e-01,
         1.82051718e-01],
       [-1.55934066e-01, -1.11989900e-01, -4.06773165e-02,
        -1.96117282e-01],
       [ 1.07954070e-02, -8.18710949e-04, -7.63501287e-01,
        -4.67290074e-01],
       [-2.70321611e-02, -3.29068638e-02, -2.76853710e-01,
        -1.22758424e+00],
       [ 1.09852508e-01,  9.81816947e-02,  1.48313180e-01,
        -5.74714363e-01],
       [-9.82326176e-03,  1.19504243e-01, -1.49631172e-01,
        -2.74633616e-01],
       [ 2.21493319e-02, -2.18173885e-03, -1.11455373e-01,
        -8.82297516e-01]

In [None]:
print(result.final_purity)

None


In [None]:
print(result.final_fidelity)

0.7276930828249264


In [None]:
print(result.iterations)

1000


In [None]:
print(result.arr_of_povm_params)

[[[0.1, 0.1], [0.1], [0.1]], [[-1.1960323259518207e-05, -1.0174736662565964e-05], [1.432765623141281], [4.084984046774791]], [[-5.488541162179325e-05, -4.9437703244878395e-05], [-2.8901165871260597], [2.707455909758734]], [[-6.937244158872202e-05, -5.910006171704757e-05], [-2.5715586835600095], [1.7188638336044595]], [[-8.05808833242716e-05, -6.244098723934277e-05], [-3.0635717895558887], [-0.0012586085398046953]]]


In [None]:
print(result.arr_of_povm_params[1])

[[-1.1960323259518207e-05, -1.0174736662565964e-05], [1.432765623141281], [4.084984046774791]]
