In [33]:
import torch
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

In [34]:
from data_gen.data_gen_sim_constants import (
    TOTAL_TIME,
    NUM_PULSES_PER_TRAIN,
    TIME_STEP_VALUES,
    NUM_TIME_STEPS,
    NUM_CONTROL_CHANNELS,
)
from qubit_sim.control_pulse_funcs import generate_gaussian_pulses

In [35]:
# Switch matplotlib to the "pdf" backend before any figure is created.
mpl.use("pdf")

# rc groups applied in order: a serif font family using the "cm10" face, and
# LaTeX rendering for all figure text.
LATEX_RC_GROUPS = {
    "font": {"family": "serif", "serif": "cm10"},
    "text": {"usetex": True},
}
for rc_group_name, rc_group_settings in LATEX_RC_GROUPS.items():
    plt.rc(rc_group_name, **rc_group_settings)

# Opacity value; not referenced by the cells shown here (presumably intended
# for plot styling - confirm before removing).
alpha_value = 0.75

In [36]:
def generate_randomised_gaussian_centre_positions(
    num_diff_control_pulses: int = 1,
    total_time: float = TOTAL_TIME,
    num_time_steps: int = NUM_TIME_STEPS,
    num_pulses_per_train: int = NUM_PULSES_PER_TRAIN,
    num_control_channels: int = NUM_CONTROL_CHANNELS,
    max_shift_in_time_steps: float = 24,
    generator: "torch.Generator | None" = None,
) -> torch.Tensor:
    """Return pulse-centre times jittered uniformly around the slot midpoints.

    The unperturbed centre of pulse ``n`` (``n = 1 .. num_pulses_per_train``)
    is ``(n - 0.5) / num_pulses_per_train * total_time``. Each centre is then
    displaced by an independent uniform random shift in
    ``[-shift_limit, shift_limit)``, where
    ``shift_limit = max_shift_in_time_steps * total_time / num_time_steps``.

    Args:
        num_diff_control_pulses: Number of independent pulse trains to draw
            (size of the first axis of the result).
        total_time: Total duration covered by one pulse train.
        num_time_steps: Number of time steps ``total_time`` is divided into;
            only used to convert the shift limit from steps to time.
        num_pulses_per_train: Number of pulses in each train.
        num_control_channels: Number of control channels (size of the last
            axis of the result).
        max_shift_in_time_steps: Largest displacement of a centre, expressed
            in time steps. Defaults to 24, the value previously hard-coded.
        generator: Optional ``torch.Generator`` used for the random draw so
            the result can be made reproducible. ``None`` uses the global
            torch RNG, matching the previous behaviour.

    Returns:
        Tensor of shape ``(num_diff_control_pulses, num_pulses_per_train,
        num_control_channels)`` holding the randomised centre positions.
    """
    # Same operation order as the original expression so the default result
    # is bit-for-bit unchanged.
    shift_limit = max_shift_in_time_steps * total_time / num_time_steps

    pulse_numbers = torch.arange(1, num_pulses_per_train + 1)
    tau_n = (pulse_numbers - 0.5) / num_pulses_per_train * total_time
    # Shape (1, num_pulses_per_train, 1) so it broadcasts against the shifts.
    tau_n_expanded = tau_n[None, :, None]

    # torch.rand samples [0, 1); 2 * r - 1 maps that onto [-1, 1).
    unit_shifts = (
        2
        * torch.rand(
            (num_diff_control_pulses, num_pulses_per_train, num_control_channels),
            generator=generator,
        )
        - 1
    )
    random_shifts = shift_limit * unit_shifts

    return tau_n_expanded + random_shifts


In [37]:
# Ideal pulse train on a single channel: every pulse has amplitude pi, sits at
# the midpoint of its slot, and shares one standard deviation.
scale = 1 / 96
std_value = TOTAL_TIME / NUM_PULSES_PER_TRAIN * scale

# Slot midpoints (n - 0.5) / N * T for n = 1 .. N, laid out as (1, N, 1).
pulse_numbers = torch.arange(1, NUM_PULSES_PER_TRAIN + 1)
slot_midpoints = (pulse_numbers - 0.5) / NUM_PULSES_PER_TRAIN * TOTAL_TIME
gaussian_centre_positions = slot_midpoints[None, :, None]

# Constant amplitude and width tensors with the same (1, N, 1) layout.
per_pulse_shape = (1, NUM_PULSES_PER_TRAIN, 1)
gaussian_amplitudes = torch.full(per_pulse_shape, torch.pi)
gaussian_std_values = torch.full(per_pulse_shape, std_value)

# Stack (amplitude, centre, std) along a new last axis -> (1, N, 1, 3).
gaussian_pulse_parameters = torch.stack(
    [gaussian_amplitudes, gaussian_centre_positions, gaussian_std_values],
    dim=-1,
)

# Merge the trailing (1, 3) axes into one of length 3 -> (1, N, 3).
pulse_parameters = gaussian_pulse_parameters.reshape(1, NUM_PULSES_PER_TRAIN, 3)

# Keep index 0 on the first and last axes of the generated pulses
# (presumably batch and channel - confirm against generate_gaussian_pulses).
control_pulse = generate_gaussian_pulses(
    number_of_channels=1,
    time_range_values=TIME_STEP_VALUES.to("cpu"),
    pulse_parameters=pulse_parameters,
)[0, :, 0]

In [38]:
# Non-ideal pulse train on a single channel: centres are randomly shifted,
# amplitudes fluctuate by up to +/-20% of pi, and the pulses are wider than in
# the ideal case.

# Seed the global torch RNG so the random centre shifts and amplitude
# fluctuations below - and therefore the saved figure - are identical on every
# run of this cell.
NON_IDEAL_PULSE_SEED = 0
torch.manual_seed(NON_IDEAL_PULSE_SEED)

scale = 1 / 24
non_ideal_std_value = TOTAL_TIME / NUM_PULSES_PER_TRAIN * scale

# One pulse train, one channel -> shape (1, NUM_PULSES_PER_TRAIN, 1).
non_ideal_gaussian_centre_positions = generate_randomised_gaussian_centre_positions(
    num_diff_control_pulses=1,
    total_time=TOTAL_TIME,
    num_time_steps=NUM_TIME_STEPS,
    num_pulses_per_train=NUM_PULSES_PER_TRAIN,
    num_control_channels=1,
)

# Nominal amplitude pi for every pulse.
gaussian_amplitudes = torch.full((1, NUM_PULSES_PER_TRAIN, 1), torch.pi)

# Uniform fluctuation in [-0.2 * pi, 0.2 * pi) per pulse: torch.rand samples
# [0, 1), and 2 * r - 1 maps that onto [-1, 1).
small_amplitudes_fluctuations = (
    0.20 * torch.pi * (2 * torch.rand((1, NUM_PULSES_PER_TRAIN, 1)) - 1)
)

non_ideal_gaussian_amplitudes = gaussian_amplitudes + small_amplitudes_fluctuations

non_ideal_gaussian_std_values = torch.full(
    (1, NUM_PULSES_PER_TRAIN, 1), non_ideal_std_value
)

# Stack (amplitude, centre, std) along a new last axis -> (1, N, 1, 3).
non_ideal_gaussian_pulse_parameters = torch.stack(
    (
        non_ideal_gaussian_amplitudes,
        non_ideal_gaussian_centre_positions,
        non_ideal_gaussian_std_values,
    ),
    dim=-1,
)

# Merge the trailing (1, 3) axes into one of length 3 -> (1, N, 3).
non_ideal_pulse_parameters = non_ideal_gaussian_pulse_parameters.reshape(
    1, NUM_PULSES_PER_TRAIN, 1 * 3
)

# Keep index 0 on the first and last axes of the generated pulses
# (presumably batch and channel - confirm against generate_gaussian_pulses).
non_ideal_control_pulse = generate_gaussian_pulses(
    number_of_channels=1,
    time_range_values=TIME_STEP_VALUES.to("cpu"),
    pulse_parameters=non_ideal_pulse_parameters,
)[0, :, 0]

In [39]:
# Integer sample indices 0 .. NUM_TIME_STEPS - 1 for the "Time Step" x-axis.
# np.linspace(0, N, N) includes the endpoint and therefore spaces its N points
# by N / (N - 1), so sample k would not be plotted at x = k.
time_steps = np.arange(NUM_TIME_STEPS)

In [40]:
# Two stacked panels comparing the ideal (A) and non-ideal (B) control pulses,
# saved as a PDF figure.

# Axis ticks shared by both panels. The x ticks sit at every quarter of the
# step range and are derived from NUM_TIME_STEPS rather than hard-coded, so
# the axis keeps matching the data if the step count changes.
x_ticks = [quarter * NUM_TIME_STEPS // 4 for quarter in range(5)]
y_ticks = [0, np.pi / 4, np.pi / 2, 3 * np.pi / 4, np.pi]
y_tick_labels = ["0", r"$\pi/4$", r"$\pi/2$", r"$3\pi/4$", r"$\pi$"]

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5))

# (axes, data, line colour, line label, panel title) for each panel.
panels = (
    (
        ax1,
        control_pulse,
        "blue",
        "Ideal Control Pulse",
        "Ideal Control Pulse Over Time (A)",
    ),
    (
        ax2,
        non_ideal_control_pulse,
        "red",
        "Non-Ideal Control Pulse",
        "Realistic Control Pulse Over Time (B)",
    ),
)

for panel_ax, pulse_values, line_colour, line_label, panel_title in panels:
    # The label is attached to the line, but no legend is drawn; the panel
    # title identifies the curve instead.
    panel_ax.plot(time_steps, pulse_values, color=line_colour, label=line_label)
    panel_ax.set_xlim(0, NUM_TIME_STEPS)
    panel_ax.set_xticks(x_ticks)
    panel_ax.set_xlabel("Time Step")
    panel_ax.set_yticks(y_ticks)
    panel_ax.set_yticklabels(y_tick_labels)
    panel_ax.set_ylabel("Control Pulse Value")
    panel_ax.set_title(panel_title)
    panel_ax.grid(True)

plt.tight_layout()
plt.savefig("./feature_identification_paper/figures/control_pulses.pdf")