# Baseline-inspired pulses (interactive)

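Adjust the sliders (Rabi-amplitude scale, constant detuning shift $\Delta$, and total duration $T$), then click **Play** to rerun the Bloch-sphere animation for two protocols: an area-normalized $\pi$ pulse and a rapid adiabatic passage (RAP) sweep. The pulse plot below is synchronized with the spheres: as the state evolves, a marker moves along each control field ($\Omega$ solid, $\Delta$ dashed).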


In [None]:

import json
from pathlib import Path

import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib import animation, rcParams
import numpy as np
from IPython.display import HTML, display
from qutip import Bloch, Qobj, sesolve, sigmax, sigmay, sigmaz

rcParams.update(
    {
        "animation.embed_limit": 120,  # MB budget for frames embedded in the jshtml output
        "animation.html": "jshtml",  # render animations as an interactive JS player
    }
)


def find_baseline_dir():
    """Locate ``data/baselines/_baseline_crab`` by walking up from the working directory.

    The current directory is checked first, then each of its ancestors in turn.

    Returns:
        Path: the first existing candidate directory.

    Raises:
        FileNotFoundError: if no ancestor contains the baseline directory.
    """
    here = Path.cwd()
    for root in (here, *here.parents):
        candidate = root.joinpath("data", "baselines", "_baseline_crab")
        if candidate.exists():
            return candidate
    raise FileNotFoundError("Could not find data/baselines/_baseline_crab from this working directory")


BASELINE_DIR = find_baseline_dir()
# Read as UTF-8 explicitly so loading does not depend on the platform's locale encoding.
metadata = json.loads((BASELINE_DIR / "metadata.json").read_text(encoding="utf-8"))
# np.load on an .npz returns a lazy NpzFile that keeps the archive open; copy every
# array into a plain dict inside `with` so the file handle is released right away.
with np.load(BASELINE_DIR / "arrays.npz") as _npz:
    arrays = {key: _npz[key] for key in _npz.files}

# Static operators and initial state (density converted to ket for sesolve)
SX = sigmax()
SY = sigmay()
SZ = sigmaz()
# Half-Pauli drive terms, weighted by Omega(t) and Delta(t) in the time-dependent Hamiltonian.
HX = 0.5 * SX
HZ = 0.5 * SZ

rho0 = Qobj(arrays["rho0"])
# With QuTiP's default eigenvalue sort (ascending), the last ket is the dominant
# component of rho0. This reproduces rho0 exactly only if rho0 is pure.
# NOTE(review): confirm the stored baseline rho0 is a pure state.
_, eigvecs = rho0.eigenstates()
psi0 = eigvecs[-1]

# NOTE(review): delta_ref is not used by the cells in this notebook; kept in case
# other code reads it.
delta_ref = float(np.max(np.abs(arrays["Delta0"])))

T_DEFAULT = 5.0  # initial value of the T slider (us)
N_SAMPLES = 1000  # simulation time-grid resolution
MAX_FRAMES = 50  # upper bound on animation frames (downsampled from N_SAMPLES)
COLORS = ["#1f77b4", "#ff7f0e"]  # per-protocol plot colors

print(
    f"Baseline loaded from {BASELINE_DIR} (Nt={metadata['summary']['Nt']}, T0={arrays['T'].item()} us)."
)


Baseline loaded from /home/yehon/projects/grape-crab-qoc/data/baselines/_baseline_crab (Nt=201, T0=0.1 us).


In [None]:

# Simulation helpers


def time_grid(T_value):
    """Return N_SAMPLES evenly spaced times covering [0, T_value], both ends included."""
    start, stop = 0.0, T_value
    return np.linspace(start, stop, num=N_SAMPLES)


def base_protocols(t, T_value, rap_omega_max=10.0, rap_delta_max=10.0):
    """Build the two reference control protocols sampled on the time grid ``t``.

    Parameters
    ----------
    t : numpy.ndarray
        Sample times spanning [0, T_value].
    T_value : float
        Total protocol duration; sets the period of the sin^2 / cos envelopes.
    rap_omega_max : float, optional
        Peak Rabi amplitude of the RAP protocol (default 10.0).
    rap_delta_max : float, optional
        Amplitude of the RAP cosine detuning chirp (default 10.0).

    Returns
    -------
    list of dict
        Each has ``label``, ``omega`` and ``delta`` arrays aligned with ``t``:
        a sin^2 pi pulse (pulse area pi, zero detuning) followed by a RAP sweep
        (sin^2 Rabi envelope, detuning swept from +rap_delta_max to -rap_delta_max).
    """
    pulse_shape = np.sin(np.pi * t / T_value) ** 2
    detuning_shape = np.cos(np.pi * t / T_value)

    # np.trapezoid exists only in NumPy >= 2.0; np.trapz is the older name.
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    # The drive enters as Omega * sigma_x / 2, so a pulse area of pi gives a pi rotation.
    omega_ref = np.pi / integrate(pulse_shape, t)
    pi_omega = omega_ref * pulse_shape
    pi_delta = np.zeros_like(t)

    rap_omega = rap_omega_max * pulse_shape
    rap_delta = rap_delta_max * detuning_shape

    return [
        {"label": "pi pulse: Omega sin^2, Delta=0", "omega": pi_omega, "delta": pi_delta},
        {"label": "RAP: Omega sin^2, Delta cos", "omega": rap_omega, "delta": rap_delta},
    ]


def simulate_protocol(omega_t, delta_t, t):
    """Evolve psi0 under H(t) = omega(t) * HX + delta(t) * HZ and track the Bloch vector.

    ``omega_t`` and ``delta_t`` are control samples on the grid ``t``; the solver
    sees them as piecewise-linear functions via ``np.interp``.

    Returns:
        numpy.ndarray stacking <sigma_x>, <sigma_y>, <sigma_z> as rows, one column per time in ``t``.
    """
    def omega_coeff(tt, args=None):
        return np.interp(tt, t, omega_t)

    def delta_coeff(tt, args=None):
        return np.interp(tt, t, delta_t)

    hamiltonian = [[HX, omega_coeff], [HZ, delta_coeff]]
    solution = sesolve(hamiltonian, psi0, t, e_ops=[SX, SY, SZ])
    return np.vstack(solution.expect)


def downsample_indices(n, max_frames=MAX_FRAMES):
    """Evenly spaced integer indices into a length-``n`` sequence, at most ``max_frames`` of them.

    When two or more indices are kept, both 0 and n - 1 are included.
    """
    n_keep = max_frames if max_frames < n else n
    return np.linspace(0, n - 1, num=n_keep, dtype=int)


def make_protocols(T_value, omega_scale, detuning_shift):
    """Simulate every base protocol with a scaled Rabi drive and a shifted detuning.

    Returns ``(t_full, protocols)``. ``t_full`` is the full time grid. Each
    protocol dict holds:

    - the full-resolution controls (``omega_full``, ``delta_full``);
    - the frame-downsampled times and controls (``t_frames``, ``omega_frames``,
      ``delta_frames``);
    - the downsampled Bloch coordinates (``coords``).
    """
    t_full = time_grid(T_value)
    bases = base_protocols(t_full, T_value)
    keep = downsample_indices(len(t_full))

    def run_one(base):
        scaled_omega = omega_scale * base["omega"]
        shifted_delta = base["delta"] + detuning_shift
        bloch_full = simulate_protocol(scaled_omega, shifted_delta, t_full)
        return {
            "label": base["label"],
            "omega_full": scaled_omega,
            "delta_full": shifted_delta,
            "t_frames": t_full[keep],
            "omega_frames": scaled_omega[keep],
            "delta_frames": shifted_delta[keep],
            "coords": bloch_full[:, keep],
        }

    return t_full, [run_one(base) for base in bases]


def pulse_limits(protocols):
    """Symmetric y-limits that fit every Omega/Delta trace, with 10% headroom.

    Falls back to (-1.1, 1.1) when there are no protocols or every trace is zero.
    """
    peaks = [
        float(np.max(np.abs(proto[key])))
        for proto in protocols
        for key in ("omega_full", "delta_full")
    ]
    peak = max([0.0, *peaks])
    half_height = (peak if peak > 0 else 1.0) * 1.1
    return -half_height, half_height


# Shared styling for the red state-vector arrow drawn on each Bloch sphere.
_ARROW_STYLE = {"color": "#d62728", "lw": 2.2, "arrow_length_ratio": 0.25}


def _init_bloch_panel(fig, ax, proto, color):
    """Draw a styled, empty Bloch sphere on ``ax``; return its (trace, point, vector) artists."""
    bloch = Bloch(fig=fig, axes=ax)
    bloch.sphere_alpha = 0.14
    bloch.frame_color = "#888888"
    bloch.frame_alpha = 0.2
    bloch.vector_width = 5
    bloch.point_size = [30]
    bloch.point_marker = ["o"]
    bloch.vector_color = ["#d62728"]
    bloch.point_color = ["#d62728"]
    bloch.xlabel = ["x", ""]
    bloch.ylabel = ["y", ""]
    bloch.zlabel = ["z", ""]
    bloch.render()
    # Title and animated artists are added after render(), which redraws ``ax``
    # (NOTE(review): QuTiP presumably clears the axes there, so order matters).
    ax.set_title(proto["label"])
    trace, = ax.plot([], [], [], color=color, lw=2, alpha=0.9)
    point, = ax.plot([], [], [], marker="o", color="#d62728", markersize=6)
    vector = ax.quiver(0, 0, 0, 0, 0, 0, **_ARROW_STYLE)
    return trace, point, vector


def _init_pulse_panel(ax, t_full, protocols, colors):
    """Plot each protocol's full Omega (solid) and Delta (dashed) traces on ``ax``.

    Returns one ``(omega_marker, delta_marker)`` pair of initially empty line
    artists per protocol, which the animation moves along the traces.
    """
    markers = []
    for proto, color in zip(protocols, colors):
        ax.plot(t_full, proto["omega_full"], color=color, label=f"Omega | {proto['label']}")
        ax.plot(
            t_full,
            proto["delta_full"],
            color=color,
            linestyle="--",
            alpha=0.8,
            label=f"Delta | {proto['label']}",
        )
        omega_marker, = ax.plot([], [], marker="o", color=color)
        delta_marker, = ax.plot([], [], marker="x", color=color, alpha=0.9)
        markers.append((omega_marker, delta_marker))

    ax.set_xlabel("Time (us)")
    ax.set_ylabel("Amplitude (rad/us)")
    ax.legend(loc="upper right")
    ax.set_ylim(*pulse_limits(protocols))
    ax.set_xlim(0, float(t_full[-1]) if len(t_full) else 1.0)
    return markers


def build_animation(t_full, protocols, interval_ms=50):
    """Build a Bloch-sphere animation synchronized with markers on the pulse plot.

    Parameters
    ----------
    t_full : numpy.ndarray
        Full simulation time grid (x-axis of the pulse panel).
    protocols : list of dict
        Entries as produced by make_protocols. Keys used are ``label``,
        ``omega_full``, ``delta_full``, ``t_frames``, ``omega_frames``,
        ``delta_frames`` and ``coords``. ``coords`` has rows x, y, z and one
        column per animation frame.
    interval_ms : int, optional
        Delay between animation frames in milliseconds.

    Returns
    -------
    (matplotlib.figure.Figure, matplotlib.animation.FuncAnimation)

    Raises
    ------
    ValueError
        If ``protocols`` is empty (checked before any figure is created).
    """
    if not protocols:
        raise ValueError("build_animation needs at least one protocol")
    n_proto = len(protocols)
    # Cycle the palette. A plain zip() with COLORS would silently drop protocols
    # beyond len(COLORS) and make update() index past the artist lists.
    colors = [COLORS[i % len(COLORS)] for i in range(n_proto)]

    fig = plt.figure(figsize=(12, 7))
    # One Bloch column per protocol, with the pulse panel spanning the full width below.
    gs = fig.add_gridspec(2, n_proto, height_ratios=[3, 1])
    bloch_axes = [fig.add_subplot(gs[0, i], projection="3d") for i in range(n_proto)]
    pulse_ax = fig.add_subplot(gs[1, :])

    bloch_lines, bloch_points, bloch_vectors = [], [], []
    for ax, proto, color in zip(bloch_axes, protocols, colors):
        trace, point, vector = _init_bloch_panel(fig, ax, proto, color)
        bloch_lines.append(trace)
        bloch_points.append(point)
        bloch_vectors.append(vector)

    pulse_markers = _init_pulse_panel(pulse_ax, t_full, protocols, colors)

    n_frames = protocols[0]["coords"].shape[1]

    def update(frame):
        """Advance every panel to animation frame ``frame``; return the updated artists."""
        idx = frame % n_frames
        artists = []
        for i, proto in enumerate(protocols):
            coords = proto["coords"]
            x, y, z = coords[:, idx]
            # Trail of the state from the first frame up to the current one.
            bloch_lines[i].set_data(coords[0, : idx + 1], coords[1, : idx + 1])
            bloch_lines[i].set_3d_properties(coords[2, : idx + 1])
            bloch_points[i].set_data([x], [y])
            bloch_points[i].set_3d_properties([z])

            # Replace the arrow each frame instead of mutating the existing 3D quiver.
            bloch_vectors[i].remove()
            bloch_vectors[i] = bloch_axes[i].quiver(0, 0, 0, x, y, z, normalize=False, **_ARROW_STYLE)
            artists.extend([bloch_lines[i], bloch_points[i], bloch_vectors[i]])

            t_now = proto["t_frames"][idx]
            omega_marker, delta_marker = pulse_markers[i]
            omega_marker.set_data([t_now], [proto["omega_frames"][idx]])
            delta_marker.set_data([t_now], [proto["delta_frames"][idx]])
            artists.extend([omega_marker, delta_marker])
        return artists

    anim = animation.FuncAnimation(fig, update, frames=n_frames, interval=interval_ms, blit=False)
    return fig, anim


In [None]:
# Interactive controls: sliders update parameters silently; Play triggers the simulation.
# All sliders share one readout format and commit their value only on release.
_SLIDER_KW = {"readout_format": ".2f", "continuous_update": False}

omega_scale = widgets.FloatSlider(
    value=1.0, min=0.25, max=2.5, step=0.05, description="Omega scale", **_SLIDER_KW
)
detuning_shift = widgets.FloatSlider(
    value=0.0, min=-15.0, max=15.0, step=0.25, description="Delta shift", **_SLIDER_KW
)
T_slider = widgets.FloatSlider(
    value=T_DEFAULT, min=0.1, max=10.0, step=0.1, description="T (us)", **_SLIDER_KW
)
play_button = widgets.Button(description="Play", icon="play", button_style="success")
# HTML banner that shows the current parameters and run progress.
status = widgets.HTML()


def update_status(*_):
    """Slider observer: show the pending parameters in the status banner (event args ignored)."""
    parts = [
        f"<b>Parameters</b>: Omega scale={omega_scale.value:.2f}, ",
        f"Delta shift={detuning_shift.value:.2f} rad/us, T={T_slider.value:.2f} us. ",
        "Click Play to run the simulation.",
    ]
    status.value = "".join(parts)


_sliders = (omega_scale, detuning_shift, T_slider)
for slider in _sliders:
    slider.observe(update_status, names="value")
# Fill the banner once so it is not blank before the first slider change.
update_status()

# Output area that receives the rendered animation (and any traceback) from Play.
out = widgets.Output()


def on_play(_):
    """Play-button callback: simulate with the current sliders and show the animation.

    Clears ``out``, runs make_protocols/build_animation, and displays the
    jshtml animation inside ``out``. Progress and errors are reported in
    ``status``. Exceptions are re-raised so the traceback appears in ``out``.
    The Play button is disabled while a run is in progress.
    """
    # Snapshot the sliders once so the simulation and every message agree.
    T_value = T_slider.value
    scale = omega_scale.value
    shift = detuning_shift.value

    play_button.disabled = True
    try:
        with out:
            out.clear_output(wait=True)
            status.value = (
                f"Running with Omega scale={scale:.2f}, "
                f"Delta shift={shift:.2f} rad/us, T={T_value:.2f} us..."
            )
            try:
                t_full, protocols = make_protocols(T_value, scale, shift)
                fig, anim = build_animation(t_full, protocols)
                try:
                    html = anim.to_jshtml()
                finally:
                    # Close even if rendering fails so figures don't pile up across runs.
                    plt.close(fig)
                status.value = (
                    f"Complete. Frames={protocols[0]['coords'].shape[1]}, T={T_value:.2f} us. "
                    "Press Play again after adjusting sliders to rerun."
                )
                display(HTML(html))
            except Exception as exc:  # noqa: BLE001
                status.value = f"<span style='color:#cc0000;'>Simulation failed: {exc}</span>"
                raise
    finally:
        play_button.disabled = False


play_button.on_click(on_play)

# Header, control row, status banner, then the animation output area.
header = widgets.HTML("<b>Adjust sliders, then press Play.</b>")
controls_row = widgets.HBox([omega_scale, detuning_shift, T_slider, play_button])
layout = widgets.VBox(children=[header, controls_row, status, out])

display(layout)


VBox(children=(HTML(value='<b>Adjust sliders, then press Play.</b>'), HBox(children=(FloatSlider(value=1.0, co…