# C. State preparation from a thermal state with Jaynes-Cummings controls

In [38]:
# ruff: noqa
import os
import sys

# Make the package importable when the notebook is run from its own directory.
# NOTE(review): "../../../.." is presumably the repository root that contains
# `feedback_grape` - confirm against the project layout.
# The membership guard keeps re-running this cell from appending duplicates.
_NOTEBOOK_IMPORT_ROOT = "../../../.."
if _NOTEBOOK_IMPORT_ROOT not in sys.path:
    sys.path.append(_NOTEBOOK_IMPORT_ROOT)

In [39]:
from feedback_grape.fgrape import optimize_pulse
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
    cosm,
    sinm,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

## defining parameterized operations that are repeated num_time_steps times

Runtime note: this takes about 9 minutes on a GPU, compared with about 17 minutes on a CPU.

In [40]:
# Cavity dimension: the size passed to identity/create/destroy/fock for the
# cavity subsystem throughout this notebook (i.e. the Fock-space truncation).
N_cav = 20

In [41]:
def qubit_unitary(alphas):
    """
    Qubit-only control gate, extended to the joint space with the cavity identity.

    alphas: pair (alpha_re, alpha_im) holding the real and imaginary parts of
        the complex amplitude alpha.
    returns identity(N_cav) (x) expm(-1j * (alpha * sigma_+ + conj(alpha) * sigma_-) / 2)
    """
    real_part, imag_part = alphas
    amplitude = real_part + 1j * imag_part
    drive = amplitude * sigmap() + amplitude.conjugate() * sigmam()
    qubit_rotation = expm(-1j * drive / 2)
    # Cavity factor comes first in the tensor product.
    return tensor(identity(N_cav), qubit_rotation)

In [42]:
def qubit_cavity_unitary(betas):
    """
    Joint qubit-cavity control gate.

    betas: pair (beta_re, beta_im) holding the real and imaginary parts of
        the complex amplitude beta.
    returns expm(-1j * (beta * (a (x) sigma_+) + conj(beta) * (a_dag (x) sigma_-)) / 2),
        where a = destroy(N_cav) and a_dag = create(N_cav).
    """
    real_part, imag_part = betas
    amplitude = real_part + 1j * imag_part
    a_sigmap = tensor(destroy(N_cav), sigmap())
    adag_sigmam = tensor(create(N_cav), sigmam())
    coupling = amplitude * a_sigmap + amplitude.conjugate() * adag_sigmam
    return expm(-1j * coupling / 2)

### `povm_measure_operator` (callable)

- It should take a measurement outcome and a list of params as input.
- The measurement outcome is either 1 or -1.

In [43]:
from feedback_grape.utils.operators import create, destroy

def povm_measure_operator(measurement_outcome, params):
    """
    Measurement operator for the cavity, extended to the joint cavity-qubit space.

    Returns the operator Mm itself, NOT the POVM element Em = Mm_dag @ Mm.

    measurement_outcome: the outcome m; 1 selects the cosine branch, any other
        value (the notebook uses -1) selects the sine branch.
    params: pair (gamma, delta).
    returns cosm(angle) or sinm(angle), where
        angle = (gamma * a_dag @ a + delta / 2 * identity) (x) identity(2).
    """
    gamma, delta = params
    # Same left-to-right evaluation order as `gamma * a_dag @ a + delta / 2 * I`.
    photon_term = (gamma * create(N_cav)) @ destroy(N_cav)
    offset_term = (delta / 2) * identity(N_cav)
    angle = tensor(photon_term + offset_term, identity(2))
    # jnp.where takes both branches as arguments, so both are evaluated.
    return jnp.where(measurement_outcome == 1, cosm(angle), sinm(angle))

### defining initial (thermal) state

In [44]:
# Cavity part of the initial state: a thermal (Boltzmann) distribution over the
# N_cav Fock levels. TODO (open question from the original note): confirm that
# the intended initial state is this thermal cavity coupled to a ground-state qubit.
n_average = 1
# Natural logarithm, chosen so that exp(-beta) = n_average / (n_average + 1).
beta = jnp.log((1 / n_average) + 1)
photon_numbers = jnp.arange(N_cav)
diags = jnp.exp(-beta * photon_numbers)
# Renormalise so the truncated populations sum to one.
normalized_diags = diags / diags.sum(axis=0)
rho_cav = jnp.diag(normalized_diags)

In [45]:
# Sanity check: rho_cav is built with jnp.diag from N_cav populations, so it
# should be an N_cav x N_cav matrix.
rho_cav.shape

(20, 20)

In [46]:
# Joint initial density matrix: thermal cavity (first factor) tensored with the
# projector onto the qubit state basis(2, 0) (second factor).
qubit_ket = basis(2, 0)
qubit_projector = qubit_ket @ qubit_ket.conj().T
rho0 = tensor(rho_cav, qubit_projector)

In [47]:
from feedback_grape.utils.povm import (
    _probability_of_a_measurement_outcome_given_a_certain_state,
)

# Probability of obtaining outcome +1 on the initial state for one hand-picked
# (gamma, delta) pair. The call is the last expression, so its value is shown
# as the cell output.
example_povm_params = [0.1, -3 * jnp.pi / 2]
_probability_of_a_measurement_outcome_given_a_certain_state(
    rho0,
    +1,
    povm_measure_operator(+1, params=example_povm_params),
    povm_measure_operator(-1, params=example_povm_params),
    evo_type="density",
)

Array(0.40800029, dtype=float64)

### defining target state

In [48]:
# Target: equal-weight superposition of the cavity Fock states 1, 2 and 3,
# tensored with the qubit state basis(2).
fock_1, fock_2, fock_3 = (fock(N_cav, level) for level in (1, 2, 3))
cavity_superposition = (fock_1 + fock_2 + fock_3) / jnp.sqrt(3)
psi_target = tensor(cavity_superposition, basis(2))
# Renormalise the joint state vector before forming the projector.
target_norm = jnp.linalg.norm(psi_target)
psi_target = psi_target / target_norm

rho_target = psi_target @ psi_target.conj().T
rho_target.shape

(40, 40)

In [49]:
from feedback_grape.utils.fidelity import fidelity

# Fidelity between the initial state and the target, before any optimization.
initial_fidelity = fidelity(
    C_target=rho_target,
    U_final=rho0,
    evo_type="density",
)
print(initial_fidelity)

0.14583347241097044


In [50]:
# Print the eigenvalues of rho_target (the matrix actually passed to eigvalsh).
# The label previously said "rho0", which did not match the computation.
# For a pure-state projector one eigenvalue should be 1 and the rest ~0
# up to floating-point noise.
eigenvalues = jnp.linalg.eigvalsh(rho_target)
print("Eigenvalues of rho_target:", eigenvalues)

Eigenvalues of rho0: [-1.46938075e-16  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  1.92355405e-18  3.83987166e-18  5.60289567e-17  1.00000000e+00]


### initialize random params

In [51]:
import jax

# Preview of one random parameter draw: a (1, 2) array sampled uniformly
# between -pi and pi, printed as a nested Python list.
print(
    jax.random.uniform(
        key=jax.random.PRNGKey(0),
        shape=(1, 2),  # 2 for gamma and delta
        minval=-jnp.pi,
        maxval=jnp.pi,
    ).tolist()
)

[[-0.5123490775685872, -1.782568231202715]]


In [52]:
import jax
from feedback_grape.fgrape import Gate

num_time_steps = 5
num_of_iterations = 1000
learning_rate = 0.02
# avg_photon_number = 2 when testing the kitten state
# If you provide param_constraints for only one parameter, the current behavior throws an error if you don't provide param_constraints for all parameters for all gates


def _random_initial_params(key):
    """Draw the two initial parameters of a gate uniformly between -2*pi and 2*pi."""
    return jax.random.uniform(
        key,
        shape=(2,),
        minval=-2 * jnp.pi,
        maxval=2 * jnp.pi,
    )


# One independent PRNG key per gate so the three initial draws differ.
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(42), 3)

measure = Gate(
    gate=povm_measure_operator,
    initial_params=_random_initial_params(key1),  # gamma, delta
    measurement_flag=True,
    # param_constraints=[[0, jnp.pi], [-2*jnp.pi, 2*jnp.pi]],
)

qub_unitary = Gate(
    gate=qubit_unitary,
    initial_params=_random_initial_params(key2),  # alpha_re, alpha_im
    measurement_flag=False,
    # param_constraints=[[-2*jnp.pi, 2*jnp.pi], [-2*jnp.pi, 2*jnp.pi]],
)

qub_cav = Gate(
    gate=qubit_cavity_unitary,
    initial_params=_random_initial_params(key3),  # beta_re, beta_im
    measurement_flag=False,
    # param_constraints=[[-jnp.pi, jnp.pi], [-jnp.pi, jnp.pi]],
)

# Gates are listed in the order: measurement, qubit gate, qubit-cavity gate.
system_params = [measure, qub_unitary, qub_cav]

# Two weightings are compared: all ones, and zeros with a single one on the
# last time step (presumably "reward every step" vs "reward only the final
# step" - confirm against the optimize_pulse documentation).
reward_weight_options = [
    [1] * num_time_steps,
    [0] * (num_time_steps - 1) + [1],
]

# Keep every run so both can be inspected afterwards. `result` is rebound on
# each pass, so after the loop it refers to the LAST weighting only; cells
# that read `result` therefore report on that last run.
all_results = []
for reward_weights in reward_weight_options:
    result = optimize_pulse(
        U_0=rho0,
        C_target=rho_target,
        system_params=system_params,
        num_time_steps=num_time_steps,
        reward_weights=reward_weights,
        mode="lookup",
        goal="fidelity",
        max_iter=num_of_iterations,
        convergence_threshold=1e-6,
        learning_rate=learning_rate,
        evo_type="density",
        batch_size=10,
        eval_time_steps=num_time_steps * 2,
    )
    all_results.append(result)

    print(f"reward weights: {reward_weights}\n final_fidelity: {result.final_fidelity}\n fidelity_each_timestep: {jnp.array(result.fidelity_each_timestep)}\n")

reward weights: [1, 1, 1, 1, 1]
 final_fidelity: 0.09844387294064301
 fidelity_each_timestep: [0.14583347 0.15010296 0.06118554 0.11299171 0.3734138  0.46918595
 0.09148511 0.11923448 0.10266431 0.09725998 0.09844387]

reward weights: [0, 0, 0, 0, 1]
 final_fidelity: 0.041271467876362794
 fidelity_each_timestep: [0.14583347 0.00673517 0.00997238 0.07401794 0.04399013 0.87899288
 0.03968911 0.43266268 0.03203552 0.24068979 0.04127147]



In [53]:
from feedback_grape.utils.fidelity import fidelity

def _fidelity_to_target(rho):
    """Fidelity of a density matrix with respect to rho_target."""
    return fidelity(C_target=rho_target, U_final=rho, evo_type="density")


# Compare the starting fidelity with the fidelity of each entry of
# result.final_state (`result` is the last run of the optimization loop).
print("initial fidelity:", _fidelity_to_target(rho0))
for i, state in enumerate(result.final_state):
    print(f"fidelity of state {i}:", _fidelity_to_target(state))

initial fidelity: 0.14583347241097044
fidelity of state 0: 0.11797565826025147
fidelity of state 1: 0.022981809938812877
fidelity of state 2: 0.05519759535173936
fidelity of state 3: 0.04038933114610371
fidelity of state 4: 0.010299283411518973
fidelity of state 5: 0.03175239854550286
fidelity of state 6: 0.012479172881022866
fidelity of state 7: 0.07153337995110355
fidelity of state 8: 0.022923297683128344
fidelity of state 9: 0.027182751594443968


In [54]:
# Parameters returned by the optimizer for the last run of the optimization
# loop (`result` is rebound on every pass, so earlier runs are not shown here).
result.returned_params

[[Array([[-4.71238898e+00,  6.11581715e-17],
         [-4.71238898e+00,  6.11581715e-17],
         [-4.71238898e+00,  6.11581715e-17],
         [-4.71238898e+00,  6.11581715e-17],
         [-4.71238898e+00,  6.11581715e-17],
         [-4.71238898e+00,  6.11581715e-17],
         [-4.71238898e+00,  6.11581715e-17],
         [-4.71238898e+00,  6.11581715e-17],
         [-4.71238898e+00,  6.11581715e-17],
         [-4.71238898e+00,  6.11581715e-17]], dtype=float64),
  Array([[2.9689331 , 1.37856256],
         [3.9058591 , 1.40551967],
         [2.9689331 , 1.37856256],
         [3.9058591 , 1.40551967],
         [2.9689331 , 1.37856256],
         [3.9058591 , 1.40551967],
         [3.9058591 , 1.40551967],
         [2.9689331 , 1.37856256],
         [2.9689331 , 1.37856256],
         [3.9058591 , 1.40551967]], dtype=float64),
  Array([[-2.90965046, -1.61421797],
         [-0.3001794 , -0.0168173 ],
         [-2.90965046, -1.61421797],
         [-0.3001794 , -0.0168173 ],
         [-2.90965

In [55]:
# Iteration count reported for the last optimization run.
# NOTE(review): the recorded value equals num_of_iterations (1000), which
# presumably means the convergence threshold was not reached first - confirm.
print(result.iterations)

1000
