In [1]:
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
import functools
import jaxopt
jax.config.update("jax_enable_x64", True)
import numpy as np

In [2]:
# Scratch check: 2**8 = 256 (the model's Hilbert-space dimension is 2**n_qubits).
# NOTE(review): presumably sizing the 8-qubit model built further down; remove once no longer needed.
np.power(2, 8)

256

In [3]:
from jax.scipy.linalg import sqrtm
from quocslib.tools.fidelities import fidelity_AD
from quocslib.timeevolution.piecewise_integrator_AD import pw_final_evolution_AD_scan, pw_final_evolution_AD


class IsingModel(eqx.Module):
    """
    Ising-ring state-transfer problem packaged as an Equinox module (a JAX pytree).

    The drift Hamiltonian comes from ``get_static_hamiltonian`` and the single
    control Hamiltonian from ``get_control_hamiltonian``; the system starts in
    ``get_initial_state`` and should be steered to ``get_target_state``.
    Because the module is a pytree, instances can be passed as ordinary
    arguments through ``jax.jit`` / ``jax.grad``.

    NOTE(review): the int fields (``n_qubits``, ``n_slices``) are ordinary
    pytree leaves, so under ``jax.jit`` they arrive as tracers. Mark them with
    ``eqx.field(static=True)`` if any code needs them as Python ints.
    """

    n_qubits: int
    J: float
    g: float
    n_slices: int
    H_drift: jnp.ndarray
    H_control: jnp.ndarray
    rho_0: jnp.ndarray
    sqrt_rho_target: jnp.ndarray
    rho_target: jnp.ndarray
    rho_final: jnp.ndarray
    u0: jnp.ndarray

    def __init__(self, args_dict: dict = None):
        """
        Build all operators and states for the problem.

        :param args_dict: optional settings. Recognised keys are ``n_qubits``
            (default 5), ``J`` (default 1), ``g`` (default 2) and ``n_slices``
            (default 100). Missing keys are written back into the dict via
            ``setdefault``, so the caller's dict ends up holding the values used.
        """
        if args_dict is None:
            args_dict = {}
        # Dynamics variables
        self.n_qubits = args_dict.setdefault("n_qubits", 5)
        self.J = args_dict.setdefault("J", 1)
        self.g = args_dict.setdefault("g", 2)
        self.n_slices = args_dict.setdefault("n_slices", 100)

        self.H_drift = jnp.asarray(get_static_hamiltonian(self.n_qubits, self.J, self.g))
        self.H_control = jnp.asarray(get_control_hamiltonian(self.n_qubits))
        self.rho_0 = jnp.asarray(get_initial_state(self.n_qubits))
        self.rho_target = jnp.asarray(get_target_state(self.n_qubits))
        # Placeholder field: get_final_state returns its result and does not update this.
        self.rho_final = jnp.zeros_like(self.rho_target)
        # The fidelity helper takes the target's square root; computing it once here
        # avoids a sqrtm on every loss evaluation.
        self.sqrt_rho_target = sqrtm(self.rho_target)
        # Identity propagator: the starting point of the piecewise time evolution.
        self.u0 = jnp.identity(2 ** self.n_qubits, dtype=np.complex128)

    def get_control_Hamiltonians(self):
        """Return the control Hamiltonian."""
        return self.H_control

    def get_drift_Hamiltonian(self):
        """Return the drift (static) Hamiltonian."""
        return self.H_drift

    def get_target_state(self):
        """Return the target density matrix."""
        return self.rho_target

    def get_initial_state(self):
        """Return the initial density matrix."""
        return self.rho_0

    def get_propagator(self,
                       pulse: jnp.ndarray,
                       timegrid: jnp.ndarray) -> jnp.ndarray:
        """
        Compute the total propagator for a piecewise-constant drive.

        :param pulse: drive amplitudes, one per time slice (expected 1-D).
        :param timegrid: time points of the pulse. Only its last value and its
            length are used: the total time ``timegrid[-1]`` is split into
            ``len(timegrid)`` equal slices of duration ``dt``.
        :return: final unitary propagator from ``pw_final_evolution_AD_scan``.
        """
        # One row of amplitudes, presumably paired with the one-element list of
        # control Hamiltonians passed below.
        drive = pulse.reshape(1, len(pulse))
        dt = timegrid[-1] / len(timegrid)
        return pw_final_evolution_AD_scan(drive,
                                          self.H_drift,
                                          jnp.asarray([self.H_control]),
                                          self.n_slices,
                                          dt,
                                          self.u0)

    def get_final_state(self,
                        pulse: jnp.ndarray,
                        timegrid: jnp.ndarray) -> jnp.ndarray:
        """
        Evolve the initial density matrix under the given pulse.

        :param pulse: drive amplitudes (see ``get_propagator``).
        :param timegrid: time points connected to the pulse (see ``get_propagator``).
        :return: final density matrix ``U rho_0 U^dagger``.
        """
        U_final = self.get_propagator(pulse=pulse, timegrid=timegrid)
        return U_final @ self.rho_0 @ U_final.conj().T


# Single-qubit building blocks.
# Spin-1/2 operators are half the Pauli matrices (S = sigma / 2).
i2 = np.eye(2)  # real identity; np.kron promotes it when mixed with complex factors
sz = np.array([[0.5, 0.0], [0.0, -0.5]], dtype=np.complex128)
sx = np.array([[0.0, 0.5], [0.5, 0.0]], dtype=np.complex128)
# Projectors |0><0| and |1><1|, used per site for the initial and target states.
psi0 = np.array([[1, 0], [0, 0]], dtype=np.complex128)
psiT = np.array([[0, 0], [0, 1]], dtype=np.complex128)


def tensor_together(A):
    """
    Kronecker (tensor) product of a sequence of operators, folded left to right.

    ``tensor_together([a, b, c]) == np.kron(np.kron(a, b), c)``.

    :param A: non-empty sequence (or other iterable) of array-likes.
    :return: a new numpy array; a single operator is returned as a copy.
    :raises ValueError: if ``A`` is empty.
    """
    ops = list(A)
    if not ops:
        raise ValueError("tensor_together needs at least one operator")
    # Start from a copy so a one-element input never aliases the caller's array.
    return functools.reduce(np.kron, ops[1:], np.array(ops[0]))


def get_static_hamiltonian(nqu, J, g):
    """
    Drift Hamiltonian of a ZZ ring with nearest- and next-nearest-neighbour couplings.

        H0 = -J * sum_j sz_j sz_{j+1}  -  g * sum_j sz_j sz_{j+2}

    A partner index that runs past the last site is wrapped back by ``nqu``
    (ring geometry). NOTE(review): for nqu <= 2 a site can be its own partner;
    the factor list then holds a single ``sz`` on that site rather than ``sz @ sz``.

    :param nqu: number of qubits
    :param J: nearest-neighbour coupling
    :param g: next-nearest-neighbour coupling
    :return: complex (2**nqu, 2**nqu) numpy array
    """
    dim = 2 ** nqu
    H0 = np.zeros((dim, dim), dtype=np.complex128)
    # All J bonds first, then all g bonds, accumulated one term at a time.
    for coupling, distance in ((J, 1), (g, 2)):
        for site in range(nqu):
            partner = site + distance
            if partner >= nqu:
                partner -= nqu  # wrap around the ring (single subtraction)
            factors = [i2] * nqu
            factors[site] = sz
            factors[partner] = sz
            H0 = H0 - coupling * tensor_together(factors)
    return H0


def get_control_hamiltonian(nqu: int):
    """
    Uniform drive: the sum over all sites j of ``sx`` acting on site j.

    :param nqu: number of qubits
    :return: complex (2**nqu, 2**nqu) numpy array
    """
    dim = 2 ** nqu

    def sx_on(site):
        # sx on `site`, identity on every other site.
        factors = [i2] * nqu
        factors[site] = sx
        return tensor_together(factors)

    # sum() folds left to right from the zero matrix, one site at a time.
    return sum((sx_on(site) for site in range(nqu)),
               np.zeros((dim, dim), dtype=np.complex128))


def get_initial_state(nqu: int):
    """Product density matrix with ``psi0`` (|0><0|) on each of the ``nqu`` qubits."""
    return tensor_together([psi0 for _ in range(nqu)])


def get_target_state(nqu: int):
    """Product density matrix with ``psiT`` (|1><1|) on each of the ``nqu`` qubits."""
    return tensor_together([psiT for _ in range(nqu)])


@jax.jit
def fom_funct(rho_evolved, sqrt_rho_aim):
    """
    Fidelity between an evolved density matrix and the target.

    The target enters through its matrix square root, which callers
    precompute once rather than on every evaluation.

    :param rho_evolved: density matrix produced by the time evolution
    :param sqrt_rho_aim: matrix square root of the target density matrix
    :return: fidelity as computed by ``fidelity_AD``
    """
    return fidelity_AD(sqrt_rho_aim, rho_evolved)

In [4]:
from functools import partial


@jax.jit
def loss_function(params, model, timegrid):
    """
    Infidelity of the pulse ``params``: ``1 - fidelity(final state, target)``.

    :param params: pulse amplitudes handed to ``model.get_final_state``
    :param model: problem instance (e.g. ``IsingModel``), traced as a pytree by jit
    :param timegrid: time grid matching ``params``
    :return: scalar loss (smaller is better)
    """
    rho_T = model.get_final_state(pulse=params, timegrid=timegrid)
    return 1.0 - fom_funct(rho_T, model.sqrt_rho_target)

In [5]:
# 5-qubit Ising ring, sampled with 100 piecewise-constant pulse slices.
model = IsingModel(args_dict=dict(n_qubits=5, J=1, g=2, n_slices=100))

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [6]:
# Initial parameters: a random pulse with one amplitude per time slice.
SEED = 50
INIT_PULSE_SCALE = 30
timegrid = jnp.linspace(0.0, 1.0, model.n_slices)
key = jax.random.PRNGKey(SEED)
init_params = jax.random.normal(key, (model.n_slices,), dtype=jnp.float64) * INIT_PULSE_SCALE

In [7]:
# # Unconstrained optimization
# params = init_params
# solver_unconstrained = jaxopt.LBFGS(fun=loss_function, maxiter = 100, verbose=True)
# result_unconstrained = solver_unconstrained.run(params, model, timegrid)

In [8]:
# loss_function(result_unconstrained.params, model, timegrid)
# result_unconstrained.params

In [9]:
# Box-constrained optimization with SciPy's L-BFGS-B (via jaxopt):
# every slice amplitude is kept inside [-PULSE_BOUND, PULSE_BOUND].
from jaxopt import ScipyBoundedMinimize

PULSE_BOUND = 100.0
# jaxopt expects lower/upper bounds with the same pytree structure as the
# parameters (here a single array), not Python lists of scalars.
lbounds = jnp.full_like(init_params, -PULSE_BOUND)
ubounds = jnp.full_like(init_params, PULSE_BOUND)
bounds = (lbounds, ubounds)
lbfgsb = ScipyBoundedMinimize(fun=loss_function, method="l-bfgs-b", options={'disp': True})
result_bounded = lbfgsb.run(init_params, bounds=bounds, model=model, timegrid=timegrid)

RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =          100     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  9.07539D-01    |proj g|=  2.91678D-03

At iterate    1    f=  9.07217D-01    |proj g|=  2.92581D-03
  ys=-1.153E-06  -gs= 3.213E-04 BFGS update SKIPPED

At iterate    2    f=  2.13250D-01    |proj g|=  2.46702D-03

At iterate    3    f=  1.79902D-01    |proj g|=  3.62414D-03

At iterate    4    f=  1.53983D-01    |proj g|=  2.13901D-03

At iterate    5    f=  1.41952D-01    |proj g|=  1.46240D-03

At iterate    6    f=  1.24540D-01    |proj g|=  2.40044D-03

At iterate    7    f=  1.06012D-01    |proj g|=  2.93769D-03

At iterate    8    f=  8.99431D-02    |proj g|=  2.42906D-03

At iterate    9    f=  6.83901D-02    |proj g|=  1.30597D-03

At iterate   10    f=  5.59505D-02    |proj g|=  5.16569D-04

At iterate   11    f=  5.12111D-02    |proj g|=  3.01676D-04

At iterate   12    f=  4.92608D-02  

In [10]:
# Infidelity reached by the bounded L-BFGS-B run on the 5-qubit model.
final_loss_bounded = loss_function(result_bounded.params, model, timegrid)
final_loss_bounded

Array(0.00131895, dtype=float64)

### Sparsification

In [1]:
# NOTE(review): scratch arithmetic (0.95**2); its purpose is not stated in the notebook -- document or remove.
0.95 * 0.95

0.9025

In [19]:
from jax.experimental import sparse

In [20]:
# Larger 8-qubit instance (2**8 = 256-dimensional space) with the same couplings and slicing.
model_8 = IsingModel(args_dict=dict(n_qubits=8, J=1, g=2, n_slices=100))

In [21]:
# Wrap the loss with jax.experimental.sparse.sparsify.
# `loss_function` is the identical jitted definition from the earlier cell, so it is
# reused here instead of being defined (and silently shadowed) a second time.
# NOTE(review): sparsify only switches to sparse kernels for BCOO/BCSR inputs. The
# params, timegrid and the model's Hamiltonians passed in this notebook are dense
# arrays, so this presumably evaluates exactly like `loss_function`. Convert the
# operators to sparse arrays if a sparse speed-up is the goal.
loss_function_sparse = sparse.sparsify(loss_function)

In [22]:
# Same bounded L-BFGS-B run, now on the 8-qubit model with the sparsified loss.
# The initial pulse and time grid are shared with the 5-qubit run, so check that they
# match model_8's slicing instead of silently relying on both models using 100 slices.
assert init_params.shape == (model_8.n_slices,), "init_params does not match model_8.n_slices"
assert timegrid.shape == (model_8.n_slices,), "timegrid does not match model_8.n_slices"
# Bounds shaped like the parameters (same pytree structure), as jaxopt expects.
lbounds = jnp.full_like(init_params, -100.0)
ubounds = jnp.full_like(init_params, 100.0)
bounds = (lbounds, ubounds)
lbfgsb = ScipyBoundedMinimize(fun=loss_function_sparse, method="l-bfgs-b", options={'disp': True})
# Separate name so the 5-qubit `result_bounded` is not overwritten.
result_bounded_8 = lbfgsb.run(init_params, bounds=bounds, model=model_8, timegrid=timegrid)

RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =          100     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  9.75748D-01    |proj g|=  1.26414D-03

At iterate    1    f=  9.75682D-01    |proj g|=  1.26712D-03
  ys=-1.613E-07  -gs= 6.579E-05 BFGS update SKIPPED

At iterate    2    f=  2.99806D-01    |proj g|=  3.46228D-03
  ys=-7.678E-02  -gs= 1.412E-01 BFGS update SKIPPED

At iterate    3    f=  2.92351D-01    |proj g|=  2.94393D-03

At iterate    4    f=  2.42572D-01    |proj g|=  3.22604D-03

At iterate    5    f=  2.13329D-01    |proj g|=  2.20855D-03

At iterate    6    f=  1.54407D-01    |proj g|=  2.29238D-03

At iterate    7    f=  1.17306D-01    |proj g|=  2.11694D-03

At iterate    8    f=  9.19909D-02    |proj g|=  1.96433D-03

At iterate    9    f=  7.64864D-02    |proj g|=  1.28822D-03

At iterate   10    f=  7.23439D-02    |proj g|=  1.03853D-03

At iterate   11    f=  7.15126D-02    |proj g|

KeyboardInterrupt: 

### Gradient-descent optimization with optax (full, deterministic gradients)

In [None]:
# Plain gradient descent via optax.sgd; optax.adamw / optax.adamax are drop-in alternatives.
start_learning_rate = 0.9
optimizer = optax.sgd(learning_rate=start_learning_rate)
opt_state = optimizer.init(init_params)

In [None]:
# Evaluate the loss and its gradient once at the initial pulse, before the training loop.
params = init_params
loss_value, grads = jax.value_and_grad(loss_function, argnums=0)(
    params, model, timegrid
)

In [None]:
# Gradient-descent loop. Start from the initial pulse AND a fresh optimizer state,
# so re-running this cell reproduces the same trajectory instead of continuing from
# a stale `opt_state` (which matters for stateful optimizers such as adam).
MAX_STEPS = 1000
TARGET_LOSS = 1e-6
LOG_EVERY = 10

# Build and jit the loss-and-gradient function once rather than on every step.
loss_and_grad = jax.jit(jax.value_and_grad(loss_function))

params = init_params
opt_state = optimizer.init(params)
loss_history = []
for i in range(MAX_STEPS):
    loss_value, grads = loss_and_grad(params, model, timegrid)
    loss_history.append(float(loss_value))
    if i % LOG_EVERY == 0:
        print(f"Loss {i}: {loss_value}")
    if loss_value < TARGET_LOSS:
        print(f"Loss {i}: {loss_value} (target reached)")
        break
    # params is required by some transforms (e.g. adamw's weight decay) and ignored by sgd.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)