In [3]:
import numpy as np
import qutip
import pickle
import functools

import dynamiqs as dq
dq.set_device('cpu')

from CoupledQuantumSystems.drive import DriveTerm
from CoupledQuantumSystems.IFQ import gfIFQ

import jax
import jax.numpy as jnp
import optax
import warnings
warnings.filterwarnings("ignore")

In [4]:
def square_pulse_with_rise_fall_jnp(t,
                                args = {}):
    
    w_d = args['w_d']
    amp = args['amp']
    t_start = args.get('t_start', 0)  # Default start time is 0
    t_rise = args.get('t_rise', 1e-10)  # Default rise time is 0 for no rise
    t_square = args.get('t_square', 0)  # Duration of constant amplitude

    def cos_modulation():
        return 2 * jnp.pi * amp * jnp.cos(w_d * 2 * jnp.pi * t)
    
    t_fall_start = t_start + t_rise + t_square  # Start of fall
    t_end = t_fall_start + t_rise  # End of the pulse

    before_pulse_start = jnp.less(t, t_start)
    during_rise_segment = jnp.logical_and(jnp.greater(t_rise, 0), jnp.logical_and(jnp.greater_equal(t, t_start), jnp.less_equal(t, t_start + t_rise)))
    constant_amplitude_segment = jnp.logical_and(jnp.greater(t, t_start + t_rise), jnp.less_equal(t, t_fall_start))
    during_fall_segment = jnp.logical_and(jnp.greater(t_rise, 0), jnp.logical_and(jnp.greater(t, t_fall_start), jnp.less_equal(t, t_end)))

    return jnp.where(before_pulse_start, 0,
                    jnp.where(during_rise_segment, jnp.sin(jnp.pi * (t - t_start) / (2 * t_rise)) ** 2 * cos_modulation(),
                            jnp.where(constant_amplitude_segment, cos_modulation(),
                                        jnp.where(during_fall_segment, jnp.sin(jnp.pi * (t_end - t) / (2 * t_rise)) ** 2 * cos_modulation(), 0))))

solver = dq.solver.Dopri8(
                    rtol= 1e-06,
                    atol= 1e-06,
                    safety_factor= 0.6,
                    min_factor= 0.1,
                    max_factor = 4.0,
                    max_steps = int(1e4*1000),
                )


# -------------------------------
# Define your parameters and objects
# -------------------------------
EJ = 3
EJoverEC = 6
EJoverEL = 25
EC = EJ / EJoverEC
EL = EJ / EJoverEL

qbt = gfIFQ(EJ=EJ, EC=EC, EL=EL, flux=0, truncated_dim=20)

driven_op=qutip.Qobj(qbt.fluxonium.n_operator(energy_esys=True))

e_ops = [
    qutip.basis(qbt.truncated_dim, i) * qutip.basis(qbt.truncated_dim, i).dag()
    for i in range(10)
]

element = np.abs(
    qbt.fluxonium.matrixelement_table("n_operator", evals_count=3)[1, 2]
)
freq = (qbt.fluxonium.eigenvals()[2] - qbt.fluxonium.eigenvals()[1]) * 2 * np.pi
init_wd = qbt.fluxonium.eigenvals()[2] - qbt.fluxonium.eigenvals()[1]

# -------------------------------
# Objective function
# -------------------------------
def objective(t_tot, amp, w_d,ramp):
    tlist = jnp.linspace(0, t_tot * (1+ramp), 100)

    initial_states = [
        qutip.basis(qbt.truncated_dim, 1),
        qutip.basis(qbt.truncated_dim, 2),
    ]

    def _H(t,args):
        _H =  qbt.diag_hamiltonian.full()
        _H += driven_op.full() * square_pulse_with_rise_fall_jnp(t, args)
        return _H 
    f = functools.partial(_H, args = {
                                    "w_d": w_d,    # No extra 2pi factor
                                    "amp": amp,    # No extra 2pi factor
                                    "t_square": t_tot * (1-ramp),
                                    "t_rise": t_tot * ramp,
                                })
    H =  dq.timecallable(f)

    results = dq.sesolve(
                    H = H,
                    psi0 = initial_states,
                    tsave = tlist,
                    exp_ops = e_ops,
                    solver = solver,
                    options=dq.Options(progress_meter = None),
                    )
    one_minus_pop2 = abs((1 - results.expects[0][2][-1]).real)  # from state |1> to |2>
    one_minus_pop1 = abs((1 - results.expects[1][1][-1]).real)  # from state |2> to |1>

    return one_minus_pop2 + one_minus_pop1



t_tot_list = np.linspace(50, 250, 21)
for t_tot in t_tot_list:
    def dq_objective(x):
        amp, w_d,ramp = x
        return objective(t_tot=t_tot, amp=amp, w_d=w_d,ramp=ramp)
    
    amp_guess = 50/t_tot * 22

    func = jax.value_and_grad(dq_objective)
    params = jnp.array([amp_guess, init_wd, 0.15]) 

    optimizer = optax.nadam(learning_rate=jnp.array([1e-3,
                                                    1e-4,
                                                    1e-3])) 
    opt_state = optimizer.init( params )

    num_steps = 1000
    for step in range(num_steps):
        print(f"iter: {step}, params: {params}")
        val, grads = func(params)
        print(f"\t val={val} grads: {grads}")
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

    print(f'Optimized params: {params}')
    
    # Extract best values
    best_amp = params[0]
    best_w_d = params[1]
    best_ramp = params[2]
    best_cost = val

    # Optionally, store results in a pickle
    results_dict = {
        "best_cost": best_cost,
        "best_amp": best_amp,
        "best_w_d": best_w_d,
        "best_ramp": best_ramp,
    }

    with open(f"dq_optimized_results_{t_tot}.pkl", "wb") as f:
        pickle.dump(results_dict, f)

iter: 0, params: [2.2000000e+01 3.2520439e-03 1.5000001e-01]
	 val=1.5756362676620483 grads: [-3.4779453e+01  1.0640644e+05  3.6932852e+02]
iter: 1, params: [2.2001474e+01 3.1046765e-03 1.4852633e-01]
	 val=0.2995262145996094 grads: [ 2.2857506e+01 -6.8315148e+04 -2.2225786e+02]
iter: 2, params: [2.2001162e+01 3.1346960e-03 1.4879014e-01]
	 val=0.6525201797485352 grads: [ 2.8799770e+01 -8.4015391e+04 -2.9533163e+02]
iter: 3, params: [2.2000608e+01 3.1877793e-03 1.4932163e-01]
	 val=1.8474371433258057 grads: [-2.2804089e+01  6.8423078e+04  2.3793832e+02]
iter: 4, params: [2.2000877e+01 3.1599079e-03 1.4903361e-01]
	 val=1.8394994735717773 grads: [ 2.3678612e+01 -7.0381844e+04 -2.4140982e+02]
iter: 5, params: [2.2000519e+01 3.1946765e-03 1.4937371e-01]
	 val=1.7973053455352783 grads: [ 3.5664040e+01 -1.0850958e+05 -3.6749664e+02]
iter: 6, params: [2.1999928e+01 3.2535784e-03 1.4995202e-01]
	 val=1.714721918106079 grads: [-3.0513477e+01  9.3489609e+04  3.2542279e+02]
iter: 7, params: [2.2

KeyboardInterrupt: 