# In [None]:
import numpy as np
import qutip
import pickle
import functools
import nevergrad as ng
import os
import dynamiqs as dq
import multiprocessing as mp
from multiprocessing import Pool
dq.set_device('cpu')

from CoupledQuantumSystems.drive import DriveTerm
from CoupledQuantumSystems.IFQ import gfIFQ

import jax
import jax.numpy as jnp
import optax
import warnings
warnings.filterwarnings("ignore")

def square_pulse_with_rise_fall_jnp(t, args=None):
    """Return a flat-top drive pulse with sin^2 rise and fall edges.

    The carrier ``2*pi*amp*cos(2*pi*w_d*t)`` is multiplied by an envelope that is
    0 before ``t_start``. It rises as ``sin^2`` over ``t_rise``, stays at 1 for
    ``t_square``, falls as ``sin^2`` over another ``t_rise``, and is 0 afterwards.
    Only ``jnp`` primitives and ``jnp.where`` are used, so the function can be
    traced and differentiated by JAX.

    Args:
        t: Time at which to evaluate the pulse (scalar or array).
        args: Dict with the required keys ``'w_d'`` (carrier frequency; the
            2*pi factor is applied here) and ``'amp'``. Optional keys are
            ``'t_start'`` (default 0), ``'t_rise'`` (default 1e-10) and
            ``'t_square'`` (default 0).

    Returns:
        The pulse value(s) at ``t``.

    Raises:
        KeyError: If ``'w_d'`` or ``'amp'`` is missing from ``args``.
    """
    if args is None:  # avoid a shared mutable default argument
        args = {}

    w_d = args['w_d']
    amp = args['amp']
    t_start = args.get('t_start', 0)  # Default start time is 0
    t_rise = args.get('t_rise', 1e-10)  # Tiny default rise time; pass 0 for hard edges
    t_square = args.get('t_square', 0)  # Duration of constant amplitude

    # Carrier shared by every non-zero segment; computed once.
    carrier = 2 * jnp.pi * amp * jnp.cos(w_d * 2 * jnp.pi * t)

    t_fall_start = t_start + t_rise + t_square  # Start of fall
    t_end = t_fall_start + t_rise  # End of the pulse

    has_edges = jnp.greater(t_rise, 0)
    # jnp.where evaluates every branch. With t_rise == 0, dividing by 2*t_rise in the
    # (unselected) edge branches would yield inf/NaN and poison gradients, so use
    # a non-zero stand-in there. When t_rise > 0, safe_rise is exactly t_rise.
    safe_rise = jnp.where(has_edges, t_rise, 1.0)

    before_pulse_start = jnp.less(t, t_start)
    during_rise_segment = jnp.logical_and(
        has_edges,
        jnp.logical_and(jnp.greater_equal(t, t_start), jnp.less_equal(t, t_start + t_rise)),
    )
    constant_amplitude_segment = jnp.logical_and(
        jnp.greater(t, t_start + t_rise), jnp.less_equal(t, t_fall_start)
    )
    during_fall_segment = jnp.logical_and(
        has_edges,
        jnp.logical_and(jnp.greater(t, t_fall_start), jnp.less_equal(t, t_end)),
    )

    rise_value = jnp.sin(jnp.pi * (t - t_start) / (2 * safe_rise)) ** 2 * carrier
    fall_value = jnp.sin(jnp.pi * (t_end - t) / (2 * safe_rise)) ** 2 * carrier

    return jnp.where(before_pulse_start, 0,
                     jnp.where(during_rise_segment, rise_value,
                               jnp.where(constant_amplitude_segment, carrier,
                                         jnp.where(during_fall_segment, fall_value, 0))))

# Adaptive Dopri8 integrator shared by every dq.sesolve call in this script.
# Tolerances are relative/absolute error targets; the factors bound how much the
# step size may shrink or grow per step, and max_steps caps the total step count.
solver = dq.solver.Dopri8(
    rtol=1e-6,
    atol=1e-6,
    safety_factor=0.6,
    min_factor=0.1,
    max_factor=4.0,
    max_steps=10_000_000,
)


# -------------------------------
# Define your parameters and objects
# -------------------------------
EJ = 3
EJoverEC = 6
EJoverEL = 25
EC = EJ / EJoverEC
EL = EJ / EJoverEL

# Qubit at zero external flux, truncated to 20 levels.
qbt = gfIFQ(EJ=EJ, EC=EC, EL=EL, flux=0, truncated_dim=20)

# Drive operator: the fluxonium `n_operator`, requested in its energy eigenbasis
# (energy_esys=True) and wrapped as a qutip Qobj.
driven_op = qutip.Qobj(qbt.fluxonium.n_operator(energy_esys=True))

# Projectors |i><i| onto the first 10 basis states of qbt's truncated space; their
# expectation values are the populations read out by `objective`.
e_ops = [
    qutip.basis(qbt.truncated_dim, i) * qutip.basis(qbt.truncated_dim, i).dag()
    for i in range(10)
]

# Magnitude of the n_operator matrix element between levels 1 and 2.
element = np.abs(
    qbt.fluxonium.matrixelement_table("n_operator", evals_count=3)[1, 2]
)

# Query the spectrum once and reuse it. The original called eigenvals() four
# times, and each call presumably re-diagonalizes the Hamiltonian.
fluxonium_evals = qbt.fluxonium.eigenvals()
# 1<->2 transition frequency without the 2*pi factor; used as the drive-frequency guess.
init_wd = fluxonium_evals[2] - fluxonium_evals[1]
# The same transition expressed as an angular frequency.
freq = init_wd * 2 * np.pi

# -------------------------------
# Objective function
# -------------------------------
def objective(t_tot, amp, w_d,ramp):
    """Cost of a pulse that should swap levels |1> and |2>.

    Two initial states, |1> and |2>, are simulated under the diagonal qubit
    Hamiltonian plus ``driven_op`` times a flat-top pulse. The pulse has rise and
    fall times ``t_tot*ramp`` and a plateau of ``t_tot*(1-ramp)``, so the whole
    evolution lasts ``t_tot*(1+ramp)``.

    Returns:
        ``|1 - P_2(T)|`` for the run started in |1> plus ``|1 - P_1(T)|`` for the
        run started in |2>, where ``P_k(T)`` is the final expectation of the k-th
        projector in ``e_ops``. The value is 0 for a perfect swap.
    """
    t_final = t_tot * (1+ramp)
    tlist = jnp.linspace(0, t_final, 100)

    initial_states = [qutip.basis(qbt.truncated_dim, level) for level in (1, 2)]

    pulse_args = {
        "w_d": w_d,    # No extra 2pi factor
        "amp": amp,    # No extra 2pi factor
        "t_square": t_tot * (1-ramp),
        "t_rise": t_tot * ramp,
    }

    def hamiltonian_at(t, args):
        # Static diagonal part plus the time-dependent drive term.
        drive = driven_op.full() * square_pulse_with_rise_fall_jnp(t, args)
        return qbt.diag_hamiltonian.full() + drive

    H = dq.timecallable(functools.partial(hamiltonian_at, args=pulse_args))

    results = dq.sesolve(
        H=H,
        psi0=initial_states,
        tsave=tlist,
        exp_ops=e_ops,
        solver=solver,
        options=dq.Options(progress_meter=None),
    )
    # expects is indexed [initial state][e_op][time]; only the final time is used.
    miss_from_1 = abs((1 - results.expects[0][2][-1]).real)  # from state |1> to |2>
    miss_from_2 = abs((1 - results.expects[1][1][-1]).real)  # from state |2> to |1>

    return miss_from_1 + miss_from_2

def _load_nevergrad_params(t_tot):
    """Load the (amp, w_d, ramp) starting point saved by the nevergrad stage.

    Tries ``nevergrad_optimized_results_{t_tot}.pkl`` first. If that file is
    missing, it falls back to ``nevergrad_optimized_results_{int(t_tot)}.pkl``
    (for example "50" instead of "50.0").

    Raises:
        FileNotFoundError: If neither file exists.
    """
    candidates = [
        f"nevergrad_optimized_results_{t_tot}.pkl",
        f"nevergrad_optimized_results_{int(t_tot)}.pkl",
    ]
    for file_name in candidates:
        if os.path.exists(file_name):
            break
    else:
        raise FileNotFoundError(f"No nevergrad results found; tried {candidates}")
    # NOTE: pickle.load can execute arbitrary code; only load trusted local results.
    with open(file_name, "rb") as fh:
        saved = pickle.load(fh)
    return jnp.array([saved["best_amp"], saved["best_w_d"], saved["best_ramp"]])


def optimize_for_t_tot(t_tot):
    """Refine the nevergrad pulse parameters for one pulse duration using NAdam.

    Starts from the saved nevergrad result and runs 1000 gradient steps on
    ``objective``. The lowest-cost parameters evaluated are pickled to
    ``dq_optimized_results_{t_tot}.pkl``.

    Returns:
        ``(t_tot, results_dict)``. ``results_dict`` has the float entries
        ``best_cost``, ``best_amp``, ``best_w_d`` and ``best_ramp``.
    """
    def dq_objective(x):
        amp, w_d, ramp = x
        return objective(t_tot=t_tot, amp=amp, w_d=w_d, ramp=ramp)

    value_and_grad = jax.value_and_grad(dq_objective)

    params = _load_nevergrad_params(t_tot)

    val = dq_objective(params)
    print(f"cost: {val}, t_tot: {t_tot}, amp: {params[0]}, w_d: {params[1]}, ramp: {params[2]}")

    # Per-parameter learning rates for (amp, w_d, ramp): w_d takes 10x smaller steps.
    optimizer = optax.nadam(learning_rate=jnp.array([1e-3, 1e-4, 1e-3]))
    opt_state = optimizer.init(params)

    # Track the lowest cost seen. `val` from value_and_grad belongs to the params
    # *before* each update, so the final (params, val) pair would be mismatched.
    best_val, best_params = val, params

    num_steps = 1000
    for step in range(num_steps):
        print(f"iter: {step}, params: {params}")
        val, grads = value_and_grad(params)
        print(f"\t val={val} grads: {grads}")
        if val < best_val:  # a NaN cost never replaces the best
            best_val, best_params = val, params
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

    # The params produced by the last update have not been evaluated yet.
    final_val = dq_objective(params)
    if final_val < best_val:
        best_val, best_params = final_val, params

    print(f'Optimized params for t_tot={t_tot}: {params}')
    print(f'Best params for t_tot={t_tot}: {best_params} (cost {best_val})')

    # Store plain floats so the pickle can be read without JAX and ships cheaply
    # back through multiprocessing.
    results_dict = {
        "best_cost": float(best_val),
        "best_amp": float(best_params[0]),
        "best_w_d": float(best_params[1]),
        "best_ramp": float(best_params[2]),
    }

    with open(f"dq_optimized_results_{t_tot}.pkl", "wb") as f:
        pickle.dump(results_dict, f)

    return t_tot, results_dict

if __name__ == '__main__':
    # 21 pulse durations evenly spaced from 50 to 250 (inclusive).
    t_tot_list = np.linspace(50, 250, 21)

    # One worker keeps this a minimal example. Raise it (e.g. to
    # max(1, mp.cpu_count() - 1)) to optimize several durations in parallel.
    num_processes = 1
    with Pool(processes=num_processes) as pool:
        all_results = pool.map(optimize_for_t_tot, t_tot_list)

    # Report every duration in ascending t_tot order.
    for t_tot, summary in sorted(all_results):
        report = (
            f"\nResults for t_tot={t_tot}:",
            f"Best cost: {summary['best_cost']}",
            f"Best parameters: amp={summary['best_amp']}, w_d={summary['best_w_d']}, ramp={summary['best_ramp']}",
        )
        print("\n".join(report))