# A. State preparation with Jaynes-Cummings controls

In [1]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
from jax.scipy.linalg import expm

As a preliminary step, we consider state preparation
of a target state starting from a pure state. In addition,
we assume that any coupling to an external environment
is negligible and that the parametrized controls can be
implemented perfectly.

Here no feedback is required, we are just testing the parameterized gates setup.

As a first example, we consider the state preparation
of a cavity resonantly coupled to an externally driven
qubit

Here, we consider a particular sequence of
parametrized unitary gates originally introduced by Law
and Eberly

In [2]:
N_cav = 30

In [3]:
def qubit_unitary(alpha_re, alpha_im):
    alpha = alpha_re + 1j * alpha_im
    return tensor(
        identity(N_cav),
        expm(-1j * (alpha * sigmap() + alpha.conjugate() * sigmam()) / 2),
    )

In [4]:
def qubit_cavity_unitary(beta_re, beta_im):
    beta = beta_re + 1j * beta_im
    return expm(
        -1j
        * (
            beta * (tensor(destroy(N_cav), sigmap()))
            + beta.conjugate() * (tensor(create(N_cav), sigmam()))
        )
        / 2
    )

In [5]:
# Uq = qubit_unitary(0.1, 0.1)
# Uqc = qubit_cavity_unitary(0.2, 0.2)
# print(
#     "Uq unitary check:",
#     jnp.allclose(Uq.conj().T @ Uq, jnp.eye(Uq.shape[0]), atol=1e-7),
# )
# print(
#     "Uqc unitary check:",
#     jnp.allclose(Uqc.conj().T @ Uqc, jnp.eye(Uqc.shape[0]), atol=1e-7),
# )

In [6]:
# qubit_unitary(0.1, 0.1).shape

In [7]:
# qubit_unitary(0.1, 0.1)

In their groundbreaking work, Law and Eberly have
shown that any arbitrary superposition of Fock states with
maximal excitation number N can be prepared out of the
ground state in a sequence of N such interleaved gates,
also providing an algorithm to find the correct angles and
interaction durations

In [8]:
# print(qubit_cavity_unitary(0.1, 0.1).shape)

In [9]:
# qubit_cavity_unitary(0.1, 0.1)

## First target is the state $ | 1 ⟩ + | 3 ⟩ $ 

In [10]:
# TODO: Why if time steps are less than 5 then it plateus at 0.5
time_steps = 5  # corressponds to maximal excitation number of an arbitrary Fock State Superposition

In [11]:
psi0 = tensor(basis(N_cav), basis(2))
psi0 = psi0 / jnp.linalg.norm(psi0)
psi_target = tensor((fock(N_cav, 1) + fock(N_cav, 3)) / jnp.sqrt(2), basis(2))
psi_target = psi_target / jnp.linalg.norm(psi_target)

In [12]:
psi0.shape

(60, 1)

In [13]:
psi_target.shape

(60, 1)

In [14]:
print(fock(N_cav, 1))

[[0.+0.j]
 [1.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]]


Law and Eberly provided an algorithm to determine the correct parameters for state preparation. These include:

- The rotation angle $ |\alpha| $,
- The azimuthal angle $ \arg\left(\frac{\alpha}{|\alpha|}\right) $,
- The interaction duration $ |\beta| $. <br>

So Goal is to find the best control vector (rather than control amplitudes, this time) that leads to finding the optimal state-preparation strategies. Performing as well as the Law-Eberly algorithm.

In [15]:
# TODO: see why if initial [paramaters are complex, adam fidelity is so bad and l-bfgs does error]
import jax.random as random

key = random.PRNGKey(42)  # Random seed for reproducibility
num_gates = len(
    [qubit_unitary, qubit_cavity_unitary]
)  # Number of parameterized gates
# doing a .tolist() will lead to weird shape if you multiply outside with 5 or sth
# playing wiht the range will increase or decrease fidelity
initial_parameters = 5 * random.uniform(
    key,
    shape=(time_steps, num_gates, 2),
    minval=-jnp.pi,
    maxval=2*jnp.pi,
)

print(f"Initial parameters: \n {initial_parameters}")

Initial parameters: 
 [[[  4.40109956 -15.00063041]
  [ 11.34795501  18.68394201]]

 [[ 16.70386422  14.63688792]
  [ 14.66972408  12.237571  ]]

 [[ -4.552878    -3.3220743 ]
  [ -4.30601676 -13.000761  ]]

 [[ 18.45859804  22.85768124]
  [ 12.01671016  -4.82850043]]

 [[-14.68603543  13.94188746]
  [ -5.23058656  29.91035793]]]


## Optimizing
Currently l-bfgs with the same learning rate of 0.3 converges at a local minimum of 0.5, adam also converges at 0.5 but at smaller learning rates

In [16]:
from feedback_grape.utils.fidelity import ket2dm

result = optimize_pulse_with_feedback(
    U_0=ket2dm(psi0),
    C_target=ket2dm(psi_target),
    parameterized_gates=[qubit_unitary, qubit_cavity_unitary],
    initial_params=initial_parameters,
    num_time_steps=time_steps,
    max_iter=1000,
    convergence_threshold=1e-16,
    type="density",
    mode="no-measurement",
    # decay={
    #     "decay_indices": [0],  # indices of gates before which decay occurs
    #     # c_ops need to be tensored with the identity operator for the cavity
    #     # because it is used directly in the lindblad equation
    #     "c_ops": {
    #         # weird behavior when gamma is 0.00
    #         "tm": [tensor(identity(N_cav), jnp.sqrt(0.0015) * sigmam())],
    #         # "tc": [tensor(identity(N_cav), jnp.sqrt(0.15) * sigmap())],
    #     },  # c_ops for each decay index
    #     "tsave": jnp.linspace(0, 1, 2),  # time grid for decay
    #     "Hamiltonian": None,
    # },
    goal="fidelity",
    learning_rate=0.01,
    batch_size=10,
    eval_batch_size=1
)

Iteration 0, Loss: -0.229757
Iteration 10, Loss: -0.325373
Iteration 20, Loss: -0.421851
Iteration 30, Loss: -0.513910
Iteration 40, Loss: -0.593821
Iteration 50, Loss: -0.653871
Iteration 60, Loss: -0.696064
Iteration 70, Loss: -0.723898
Iteration 80, Loss: -0.741365
Iteration 90, Loss: -0.753371
Iteration 100, Loss: -0.763371
Iteration 110, Loss: -0.772574
Iteration 120, Loss: -0.781272
Iteration 130, Loss: -0.789695
Iteration 140, Loss: -0.798120
Iteration 150, Loss: -0.806801
Iteration 160, Loss: -0.816083
Iteration 170, Loss: -0.826462
Iteration 180, Loss: -0.838564
Iteration 190, Loss: -0.853077
Iteration 200, Loss: -0.870567
Iteration 210, Loss: -0.891175
Iteration 220, Loss: -0.913044
Iteration 230, Loss: -0.934060
Iteration 240, Loss: -0.951622
Iteration 250, Loss: -0.965055
Iteration 260, Loss: -0.974556
Iteration 270, Loss: -0.981016
Iteration 280, Loss: -0.985373
Iteration 290, Loss: -0.988372
Iteration 300, Loss: -0.990554
Iteration 310, Loss: -0.992193
Iteration 320, Loss

In [17]:
jnp.array(result.returned_params)

Array([[[[  7.07071327, -17.47324898]],

        [[ 13.43665757,  17.34150913]]],


       [[[ 15.82165926,  15.27411773]],

        [[ 14.40872735,  11.77617119]]],


       [[[ -4.2302568 ,  -4.64533809]],

        [[ -4.84003419, -14.48472724]]],


       [[[ 20.70677358,  19.25265726]],

        [[ 12.01604936,  -3.31537511]]],


       [[[-15.65903925,  15.4406667 ]],

        [[ -6.21590111,  30.78631907]]]], dtype=float64)

In [25]:
print(result.final_state)

[[[4.55597494e-11+3.33151737e-24j 2.55393006e-08+1.50794686e-07j
   4.59907497e-06+1.20934815e-06j ... 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j]
  [2.55393006e-08-1.50794686e-07j 5.13420146e-04-1.08420217e-19j
   6.58081831e-03-1.45442011e-02j ... 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j]
  [4.59907497e-06-1.20934815e-06j 6.58081831e-03+1.45442011e-02j
   4.96359480e-01-2.42861287e-17j ... 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j]
  ...
  [0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j ... 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j]
  [0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j ... 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e

In [19]:
print(result)

FgResult(optimized_trainable_parameters=Array([[[  7.07071327, -17.47324898],
        [ 13.43665757,  17.34150913]],

       [[ 15.82165926,  15.27411773],
        [ 14.40872735,  11.77617119]],

       [[ -4.2302568 ,  -4.64533809],
        [ -4.84003419, -14.48472724]],

       [[ 20.70677358,  19.25265726],
        [ 12.01604936,  -3.31537511]],

       [[-15.65903925,  15.4406667 ],
        [ -6.21590111,  30.78631907]]], dtype=float64), iterations=1000, final_state=Array([[[4.55597494e-11+3.33151737e-24j, 2.55393006e-08+1.50794686e-07j,
         4.59907497e-06+1.20934815e-06j, ...,
         0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j],
        [2.55393006e-08-1.50794686e-07j, 5.13420146e-04-1.08420217e-19j,
         6.58081831e-03-1.45442011e-02j, ...,
         0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j],
        [4.59907497e-06-1.20934815e-06j, 6.58081831e-03+1

In [20]:
result.final_fidelity

Array(0.99961268, dtype=float64)

In [21]:
# from feedback_grape.grape_paramaterized import calculate_trajectory
# from feedback_grape.utils.fidelity import fidelity

# U_final = calculate_trajectory(
#     U_0=psi0,
#     parameters=result.returned_params[0],
#     time_steps=time_steps,
#     parameterized_gates=[qubit_unitary, qubit_cavity_unitary],
#     propcomp="memory-efficient",
#     type="state",
#     decay=None,
# )
# print(
#     fidelity(C_target=psi_target, U_final=result.final_operator, type="state")
# )