# A. State preparation with Jaynes-Cummings controls

In [1]:
# ruff: noqa
import os

os.sys.path.append("..")
from feedback_grape.fgrape import optimize_pulse_with_feedback
from feedback_grape.utils.operators import (
    sigmap,
    sigmam,
    create,
    destroy,
    identity,
)
from feedback_grape.utils.states import basis, fock
from feedback_grape.utils.tensor import tensor
import jax.numpy as jnp
from jax.scipy.linalg import expm

ModuleNotFoundError: No module named 'utils'

As a preliminary step, we consider state preparation
of a target state starting from a pure state. In addition,
we assume that any coupling to an external environment
is negligible and that the parametrized controls can be
implemented perfectly.

Here no feedback is required; we are only testing the parameterized-gate setup.

As a first example, we consider the state preparation
of a cavity resonantly coupled to an externally driven
qubit.

Here, we consider a particular sequence of
parametrized unitary gates originally introduced by Law
and Eberly.

In [None]:
# Dimension passed to the cavity operators and states (identity, destroy,
# create, fock, basis) used throughout this notebook.
N_cav = 30

In [None]:
def qubit_unitary(alpha_re, alpha_im):
    """Return the qubit-only gate, extended by the identity on the cavity.

    The two real inputs are combined into the complex amplitude
    ``alpha_re + 1j * alpha_im``. The qubit factor is
    ``expm(-1j * (amp * sigmap() + conj(amp) * sigmam()) / 2)`` and the
    result is ``tensor(identity(N_cav), qubit_factor)``, i.e. the cavity
    factor comes first in the tensor product.

    Args:
        alpha_re: Real part of the drive amplitude.
        alpha_im: Imaginary part of the drive amplitude.
    """
    amplitude = alpha_re + 1j * alpha_im
    # The second term uses the conjugate amplitude with the opposite ladder
    # operator.
    generator = amplitude * sigmap() + amplitude.conjugate() * sigmam()
    qubit_factor = expm(-1j * generator / 2)
    return tensor(identity(N_cav), qubit_factor)

In [None]:
def qubit_cavity_unitary(beta_re):
    """Return the qubit-cavity exchange gate for a single real parameter.

    Computes ``expm(-1j * (b * A + conj(b) * B) / 2)`` with ``b = beta_re``,
    ``A = tensor(destroy(N_cav), sigmap())`` and
    ``B = tensor(create(N_cav), sigmam())``.

    Args:
        beta_re: Coupling parameter. Only a real part is exposed, but the
            conjugate is still taken on the second term.
    """
    coupling = beta_re
    absorb_term = coupling * tensor(destroy(N_cav), sigmap())
    emit_term = coupling.conjugate() * tensor(create(N_cav), sigmam())
    return expm(-1j * (absorb_term + emit_term) / 2)

In [None]:
# Uq = qubit_unitary(0.1, 0.1)
# Uqc = qubit_cavity_unitary(0.2)
# print(
#     "Uq unitary check:",
#     jnp.allclose(Uq.conj().T @ Uq, jnp.eye(Uq.shape[0]), atol=1e-7),
# )
# print(
#     "Uqc unitary check:",
#     jnp.allclose(Uqc.conj().T @ Uqc, jnp.eye(Uqc.shape[0]), atol=1e-7),
# )

In [None]:
# qubit_unitary(0.1, 0.1).shape

In [None]:
# qubit_unitary(0.1, 0.1)

In their groundbreaking work, Law and Eberly have
shown that an arbitrary superposition of Fock states with
maximal excitation number N can be prepared out of the
ground state in a sequence of N such interleaved gates,
also providing an algorithm to find the correct angles and
interaction durations.

In [None]:
# print(qubit_cavity_unitary(0.1).shape)

In [None]:
# qubit_cavity_unitary(0.1)

## First target is the state $ (|1\rangle + |3\rangle)/\sqrt{2} $

In [None]:
# TODO: Why does the result plateau at 0.5 when fewer than 5 time steps are used?
time_steps = 5  # corresponds to the maximal excitation number of an arbitrary Fock-state superposition

In [None]:
# Initial state: cavity factor first, qubit factor second (same ordering as
# the gates, which use tensor(cavity_operator, qubit_operator)).
# NOTE(review): basis(N_cav) / basis(2) are called without an index --
# presumably they return the lowest basis state; confirm against the helper.
psi0 = tensor(basis(N_cav), basis(2))
# Explicit renormalisation to unit norm.
psi0 = psi0 / jnp.linalg.norm(psi0)
# Target: equal superposition of cavity Fock states 1 and 3, tensored with
# the same qubit state as in psi0.
psi_target = tensor((fock(N_cav, 1) + fock(N_cav, 3)) / jnp.sqrt(2), basis(2))
psi_target = psi_target / jnp.linalg.norm(psi_target)

In [None]:
psi0.shape

(60, 1)

In [None]:
psi_target.shape

(60, 1)

In [None]:
print(fock(N_cav, 1))

[[0.+0.j]
 [1.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]
 [0.+0.j]]


Law and Eberly provided an algorithm to determine the correct parameters for state preparation. These include:

- The rotation angle $ |\alpha| $,
- The azimuthal angle $ \arg\left(\frac{\alpha}{|\alpha|}\right) $,
- The interaction duration $ |\beta| $. <br>

So the goal is to find the best control vector (rather than control amplitudes, this time), i.e. the optimal state-preparation strategy, performing as well as the Law-Eberly algorithm.

In [None]:
# Answer to why complex initial parameters give a poor Adam fidelity and make L-BFGS raise an error:
# --> the optimizers only train real parameters, so each complex parameter must be split into its real and imaginary parts

## Optimizing
Currently, L-BFGS with the same learning rate of 0.3 converges to a local minimum at 0.5; Adam also converges to 0.5, but at smaller learning rates.

In [None]:
from feedback_grape.utils.fidelity import ket2dm
import jax

# Seed for the random initial gate parameters.
# NOTE(review): the same `key` is passed to both jax.random.uniform calls
# below, so the two draws are not independent. JAX recommends deriving a
# fresh subkey per draw with jax.random.split; changing this would alter the
# initial parameters and therefore the recorded run.
key = jax.random.PRNGKey(42)
# Author's note: not providing param_constraints just propagates the same
# initial_parameters for each time step.
qub_unitary = {
    "gate": qubit_unitary,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 2),  # 2 values: alpha_re and alpha_im of qubit_unitary
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": False,
    # One [lower, upper] pair per gate parameter.
    "param_constraints": [
        [-2 * jnp.pi, 2 * jnp.pi],
        [-2 * jnp.pi, 2 * jnp.pi],
    ],
}

qub_cav = {
    "gate": qubit_cavity_unitary,
    "initial_params": jax.random.uniform(
        key,
        shape=(1, 1),  # 1 value: beta_re of qubit_cavity_unitary
        minval=-jnp.pi,
        maxval=jnp.pi,
    )[0].tolist(),
    "measurement_flag": False,
    "param_constraints": [[-2 * jnp.pi, 2 * jnp.pi]],
}

# Gate descriptions in list order: qubit gate first, then qubit-cavity gate.
system_params = [qub_unitary, qub_cav]


# The pure states are converted to density matrices with ket2dm before being
# handed to the optimizer, matching type="density".
result = optimize_pulse_with_feedback(
    U_0=ket2dm(psi0),
    C_target=ket2dm(psi_target),
    system_params=system_params,
    num_time_steps=time_steps,
    max_iter=1000,
    convergence_threshold=1e-16,
    type="density",
    mode="no-measurement",
    # decay={
    #     "decay_indices": [0],  # indices of gates before which decay occurs
    #     # c_ops need to be tensored with the identity operator for the cavity
    #     # because it is used directly in the lindblad equation
    #     "c_ops": {
    #         # weird behavior when gamma is 0.00
    #         "tm": [tensor(identity(N_cav), jnp.sqrt(0.15) * sigmam())],
    #         # "tc": [tensor(identity(N_cav), jnp.sqrt(0.15) * sigmap())],
    #     },  # c_ops for each decay index
    #     "tsave": jnp.linspace(0, 1, 2),  # time grid for decay
    #     "Hamiltonian": None,
    # },
    goal="fidelity",
    learning_rate=0.02,
    batch_size=10,
    eval_batch_size=2,
)

Iteration 0, Loss: -0.143059
Iteration 10, Loss: -0.324662
Iteration 20, Loss: -0.482136
Iteration 30, Loss: -0.583045
Iteration 40, Loss: -0.631174
Iteration 50, Loss: -0.654056
Iteration 60, Loss: -0.672719
Iteration 70, Loss: -0.685403
Iteration 80, Loss: -0.692890
Iteration 90, Loss: -0.698045
Iteration 100, Loss: -0.701554
Iteration 110, Loss: -0.703945
Iteration 120, Loss: -0.705716
Iteration 130, Loss: -0.706916
Iteration 140, Loss: -0.707906
Iteration 150, Loss: -0.708878
Iteration 160, Loss: -0.709865
Iteration 170, Loss: -0.710815
Iteration 180, Loss: -0.711790
Iteration 190, Loss: -0.712906
Iteration 200, Loss: -0.714382
Iteration 210, Loss: -0.716144
Iteration 220, Loss: -0.718533
Iteration 230, Loss: -0.721827
Iteration 240, Loss: -0.727115
Iteration 250, Loss: -0.737736
Iteration 260, Loss: -0.759792
Iteration 270, Loss: -0.797262
Iteration 280, Loss: -0.833381
Iteration 290, Loss: -0.846673
Iteration 300, Loss: -0.860648
Iteration 310, Loss: -0.876126
Iteration 320, Loss

In [None]:
len(result.returned_params)

5

In [None]:
# Here it makes sense that every batch element has the same set of parameters, since there are no measurements and therefore no stochasticity or randomness
result.returned_params

[[Array([[ 0.87397271, -6.22210517],
         [ 0.87397271, -6.22210517]], dtype=float64),
  Array([[-3.56962496],
         [-3.56962496]], dtype=float64)],
 [Array([[-4.15643271e-07,  3.14159260e+00],
         [-4.15643271e-07,  3.14159260e+00]], dtype=float64),
  Array([[2.77362709],
         [2.77362709]], dtype=float64)],
 [Array([[3.14159259e+00, 9.12424551e-08],
         [3.14159259e+00, 9.12424551e-08]], dtype=float64),
  Array([[-1.50698692],
         [-1.50698692]], dtype=float64)],
 [Array([[1.52230563, 6.09598281],
         [1.52230563, 6.09598281]], dtype=float64),
  Array([[2.54470129],
         [2.54470129]], dtype=float64)],
 [Array([[ 5.97007551e-08, -3.14159295e+00],
         [ 5.97007551e-08, -3.14159295e+00]], dtype=float64),
  Array([[3.1415927],
         [3.1415927]], dtype=float64)]]

In [None]:
print(result.final_state)

[[[8.27609081e-17+9.26963224e-31j 9.94359997e-17+4.11032234e-16j
   4.82353389e-09+4.25605186e-09j ... 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j]
  [9.94360042e-17-4.11032235e-16j 2.17506612e-15+3.30872248e-24j
   2.69330960e-08-1.88425082e-08j ... 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j]
  [4.82353389e-09-4.25605186e-09j 2.69330960e-08+1.88425082e-08j
   5.00000032e-01-1.48892510e-23j ... 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j]
  ...
  [0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j ... 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j]
  [0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j ... 0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j 0.00000000e+00+0.00000000e

In [None]:
print(result)

FgResult(optimized_trainable_parameters=[[Array([ 0.87397271, -6.22210517], dtype=float64), Array([-3.56962496], dtype=float64)], [Array([-4.15643271e-07,  3.14159260e+00], dtype=float64), Array([2.77362709], dtype=float64)], [Array([3.14159259e+00, 9.12424551e-08], dtype=float64), Array([-1.50698692], dtype=float64)], [Array([1.52230563, 6.09598281], dtype=float64), Array([2.54470129], dtype=float64)], [Array([ 5.97007551e-08, -3.14159295e+00], dtype=float64), Array([3.1415927], dtype=float64)]], iterations=627, final_state=Array([[[8.27609081e-17+9.26963224e-31j, 9.94359997e-17+4.11032234e-16j,
         4.82353389e-09+4.25605186e-09j, ...,
         0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j],
        [9.94360042e-17-4.11032235e-16j, 2.17506612e-15+3.30872248e-24j,
         2.69330960e-08-1.88425082e-08j, ...,
         0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j,
         0.00000000e+00+0.00000000e+00j],
 

In [None]:
result.final_fidelity

Array(1., dtype=float64)

In [None]:
from feedback_grape.utils.fgrape_helpers import prepare_parameters_from_dict


def _sample_gate_params(rng_key, bounds):
    """Draw one uniform scalar per ``[low, high]`` entry of ``bounds``.

    The key is split once per drawn value, in order.

    Returns:
        Tuple ``(advanced_key, samples)`` where ``samples`` is the drawn
        values stacked with ``jnp.array``.
    """
    samples = []
    for var_bounds in bounds:
        rng_key, subkey = jax.random.split(rng_key)
        samples.append(
            jax.random.uniform(
                subkey,
                shape=(),
                minval=var_bounds[0],
                maxval=var_bounds[1],
            )
        )
    return rng_key, jnp.array(samples)


def get_trainable_parameters(
    initial_parameters, param_constraints, num_time_steps, rng_key
):
    """Build the per-time-step list of trainable gate parameters.

    The first entry is the flattened ``initial_parameters`` as returned by
    ``prepare_parameters_from_dict``. Each of the remaining
    ``num_time_steps - 1`` entries is a list with one freshly sampled array
    per gate.

    Args:
        initial_parameters: Parameter dictionary forwarded to
            ``prepare_parameters_from_dict``.
        param_constraints: Per-gate list of ``[low, high]`` bounds, one pair
            per gate parameter. Gates are paired with constraints via
            ``zip``, so the shorter of the two determines how many gates are
            sampled. If empty or ``None``, every parameter is sampled from
            ``[-pi, pi]`` instead.
        num_time_steps: Total number of time steps, including the first one
            that holds the initial parameters.
        rng_key: JAX PRNG key; split once per sampled value.

    Returns:
        A list whose first element is the flattened initial parameters and
        whose following elements are lists of sampled per-gate arrays.
    """
    # assumes flat_params is a sequence of per-gate arrays exposing `.shape`
    # -- TODO confirm against prepare_parameters_from_dict
    flat_params, _ = prepare_parameters_from_dict(initial_parameters)
    trainable_params = [flat_params]

    # The sampling bounds are identical for every time step, so resolve them
    # once. A truthiness check (rather than `!= []`) also routes `None` to
    # the default range instead of failing inside `zip`.
    if param_constraints:
        per_gate_bounds = [
            gate_constraints
            for _, gate_constraints in zip(flat_params, param_constraints)
        ]
    else:
        per_gate_bounds = [
            [(-jnp.pi, jnp.pi)] * gate_params.shape[0]
            for gate_params in flat_params
        ]

    for _ in range(num_time_steps - 1):
        step_params = []
        for bounds in per_gate_bounds:
            rng_key, gate_sample = _sample_gate_params(rng_key, bounds)
            step_params.append(gate_sample)
        trainable_params.append(step_params)

    return trainable_params

In [None]:
# Small stand-alone demo of get_trainable_parameters; the values here are
# independent of the optimisation run above.
initial_params = {
    "qubit_unitary": [0.1, 0.1],
    "qubit_cavity_unitary": [0.2, 0.2],
}
# Example of a non-empty constraint set (one list of [low, high] pairs per gate):
# param_constraints = [[[0, 0.5], [0.5, 1.0]], [[1.0, 1.5], [1.5, 2.0]]]
# Empty constraints: later time steps are sampled from [-pi, pi].
param_constraints = []
num_time_steps = 3
key_1 = jax.random.PRNGKey(42)


trainable_params = get_trainable_parameters(
    initial_params, param_constraints, num_time_steps, key_1
)

In [None]:
from pprint import pprint

pprint(trainable_params)

[[Array([0.1, 0.1], dtype=float64), Array([0.2, 0.2], dtype=float64)],
 [Array([1.52127861, 0.99764266], dtype=float64),
  Array([1.17781588, 2.76860586], dtype=float64)],
 [Array([-0.05248534,  0.79283613], dtype=float64),
  Array([-1.34523035,  2.7576985 ], dtype=float64)]]
