# C. State preparation from a thermal state with Jaynes-Cummings controls

In [1]:
# ruff: noqa
import os

os.sys.path.append("../../../..")

In [None]:
from feedback_grape.fgrape import optimize_pulse
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

## defining parameterized operations that are repeated num_time_steps times

takes 9 minutes on GPU instead of 17 on CPU

In [3]:
N_cav = 20

In [4]:
def qubit_unitary(alphas):
    alpha_re, alpha_im = alphas
    alpha = alpha_re + 1j * alpha_im
    return tensor(
        identity(N_cav),
        expm(-1j * (alpha * sigmap() + alpha.conjugate() * sigmam()) / 2),
    )

In [5]:
def qubit_cavity_unitary(betas):
    beta_re, beta_im = betas
    beta = beta_re + 1j * beta_im
    return expm(
        -1j
        * (
            beta * (tensor(destroy(N_cav), sigmap()))
            + beta.conjugate() * (tensor(create(N_cav), sigmam()))
        )
        / 2
    )

### povm_measure_operator (callable): <br>
    - It should take a measurement outcome and list of params as input
    - The measurement outcome options are either 1 or -1

In [None]:
from feedback_grape.utils.operators import create, destroy

def povm_measure_operator(measurement_outcome, params):
    """
    POVM for the measurement of the cavity state.
    returns Mm ( NOT the POVM element Em = Mm_dag @ Mm ), given measurement_outcome m, gamma and delta
    """
    gamma, delta = params
    number_operator = tensor(create(N_cav) @ destroy(N_cav), identity(2))
    angle = (gamma * number_operator) + delta / 2
    meas_op = jnp.where(
        measurement_outcome == 1,
        cosm(angle),
        sinm(angle),
    )
    return meas_op

### defining initial (thermal) state

In [7]:
# initial state is a thermal state coupled to a qubit in the ground state?
n_average = 1
# natural logarithm
beta = jnp.log((1 / n_average) + 1)
diags = jnp.exp(-beta * jnp.arange(N_cav))
normalized_diags = diags / jnp.sum(diags, axis=0)
rho_cav = jnp.diag(normalized_diags)

In [8]:
rho_cav.shape

(20, 20)

In [9]:
rho0 = tensor(rho_cav, basis(2, 0) @ basis(2, 0).conj().T)

In [10]:
from feedback_grape.utils.povm import (
    _probability_of_a_measurement_outcome_given_a_certain_state,
)

_probability_of_a_measurement_outcome_given_a_certain_state(
    rho0, 1, povm_measure_operator, [0.1, -3 * jnp.pi / 2], evo_type="density"
)

Array(0.9479526, dtype=float64)

### defining target state

In [11]:
psi_target = tensor(
    (fock(N_cav, 1) + fock(N_cav, 2) + fock(N_cav, 3)) / jnp.sqrt(3), basis(2)
)
psi_target = psi_target / jnp.linalg.norm(psi_target)

rho_target = psi_target @ psi_target.conj().T
rho_target.shape

(40, 40)

In [12]:
from feedback_grape.utils.fidelity import fidelity

print(fidelity(U_final=rho0, C_target=rho_target, evo_type="density"))

0.14583347926225051


In [13]:
# Print the eigenvalues of rho0
eigenvalues = jnp.linalg.eigvalsh(rho_target)
print("Eigenvalues of rho0:", eigenvalues)

Eigenvalues of rho0: [-2.72829216e-16 -6.02376918e-17  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]


### initialize random params

In [14]:
import jax

print(
    jax.random.uniform(
        jax.random.PRNGKey(0),
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jax.numpy.pi,
        maxval=jax.numpy.pi,
    ).tolist()
)

[[-0.5123490775685872, -1.782568231202715]]


In [None]:
import jax
from feedback_grape.fgrape import Gate

num_time_steps = 5
num_of_iterations = 1000
learning_rate = 0.02
# avg_photon_numer = 2 When testing kitten state
# If you provide param_constraints for only one parameter, the current behavior throws an error if you don't provide param_constraints for all parameters for all gates
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(42), 3)
measure = Gate(
    gate=povm_measure_operator,
    initial_params=jax.random.uniform(
        key1,
        shape=(2,),  # 2 for gamma and delta
        minval=-2 * jnp.pi,
        maxval=2 * jnp.pi,
    ),
    measurement_flag=True,
    # param_constraints=[[0, jnp.pi], [-2*jnp.pi, 2*jnp.pi]],
)

qub_unitary = Gate(
    gate=qubit_unitary,
    initial_params=jax.random.uniform(
        key2,
        shape=(2,),  # 2 for gamma and delta
        minval=-2 * jnp.pi,
        maxval=2 * jnp.pi,
    ),
    measurement_flag=False,
    # param_constraints=[[-2*jnp.pi, 2*jnp.pi], [-2*jnp.pi, 2*jnp.pi]],
)

qub_cav = Gate(
    gate=qubit_cavity_unitary,
    initial_params=jax.random.uniform(
        key3,
        shape=(2,),  # 2 for gamma and delta
        minval=-2 * jnp.pi,
        maxval=2 * jnp.pi,
    ),
    measurement_flag=False,
    # param_constraints=[[-jnp.pi, jnp.pi], [-jnp.pi, jnp.pi]],
)

system_params = [measure, qub_unitary, qub_cav]


result = optimize_pulse(
    U_0=rho0,
    C_target=rho_target,
    system_params=system_params,
    num_time_steps=num_time_steps,
    mode="lookup",
    goal="fidelity",
    max_iter=num_of_iterations,
    convergence_threshold=1e-6,
    learning_rate=learning_rate,
    evo_type="density",
    batch_size=10,
)

In [16]:
print(result.final_purity)

None


In [17]:
print(result.final_fidelity)

0.8941443638691023


In [18]:
from feedback_grape.utils.fidelity import fidelity

print(
    "initial fidelity:",
    fidelity(C_target=rho_target, U_final=rho0, evo_type="density"),
)
for i, state in enumerate(result.final_state):
    print(
        f"fidelity of state {i}:",
        fidelity(C_target=rho_target, U_final=state, evo_type="density"),
    )

initial fidelity: 0.14583347926225051
fidelity of state 0: 0.7367146069217457
fidelity of state 1: 0.9812292203578734
fidelity of state 2: 0.9392403569468494
fidelity of state 3: 0.8706946937983159
fidelity of state 4: 0.7367146069217457
fidelity of state 5: 0.8359109991350477
fidelity of state 6: 0.9812292203578734
fidelity of state 7: 0.9392403569468494
fidelity of state 8: 0.9392403569468494
fidelity of state 9: 0.9812292203578734


In [19]:
result.returned_params

[[Array([[-4.71895968,  0.00786619],
         [-4.71895968,  0.00786619],
         [-4.71895968,  0.00786619],
         [-4.71895968,  0.00786619],
         [-4.71895968,  0.00786619],
         [-4.71895968,  0.00786619],
         [-4.71895968,  0.00786619],
         [-4.71895968,  0.00786619],
         [-4.71895968,  0.00786619],
         [-4.71895968,  0.00786619]], dtype=float64),
  Array([[0.26628386, 2.28535749],
         [3.08872086, 0.53234291],
         [0.26628386, 2.28535749],
         [3.08872086, 0.53234291],
         [0.26628386, 2.28535749],
         [3.08872086, 0.53234291],
         [3.08872086, 0.53234291],
         [0.26628386, 2.28535749],
         [0.26628386, 2.28535749],
         [3.08872086, 0.53234291]], dtype=float64),
  Array([[-1.89416644,  3.41749821],
         [-1.29979158,  1.73414231],
         [-1.89416644,  3.41749821],
         [-1.29979158,  1.73414231],
         [-1.89416644,  3.41749821],
         [-1.29979158,  1.73414231],
         [-1.29979158,  

In [20]:
print(result.iterations)

1000
