# Simulating custom pulses

### Imports and settings

In [None]:
import sys

sys.path.append("../")
import pulse_simulator as ps

In [None]:
import csv
import functools

import numpy as np
import qiskit as qk
import qiskit.providers.fake_provider as qk_fp
import qiskit.quantum_info as qk_qi
import qiskit_dynamics as qk_d
from qiskit.circuit.library import SXGate, XGate
from scipy import integrate

In [None]:
import matplotlib.pyplot as plt

In [None]:
# JAX settings for the simulation: 64-bit floats, and run on the CPU.
import jax

for _option, _value in (
    ("jax_enable_x64", True),
    ("jax_platform_name", "cpu"),
):
    jax.config.update(_option, _value)

# Make JAX the default array backend for qiskit-dynamics. The trailing bare
# expression echoes the active backend as the cell's notebook output.
qk_d.array.Array.set_default_backend("jax")
qk_d.array.Array.default_backend()

In [None]:
backend = qk_fp.FakeManila()

# Scale factor between SI units and the simulation units; `ns` multiplies a
# time given in seconds, `GHz` multiplies a frequency given in Hz.
units = 1e9
ns = units
GHz = 1 / units

# Backend sample period, converted from the configuration value (seconds in
# Qiskit backend configurations) into ns.
dt = ns * backend.configuration().dt
duration = dt * 50  # ns, i.e. 50 backend samples

### Retrieve saved pulses

In [None]:
# Saved pulses: one gate per CSV row, each row a sequence of sample amplitudes.
file_name = "../pico-pulses/saved-pulses-23-12-05/a_single_qubit_gateset_R1e-3.csv"

# newline="" is what the csv module requires for files it reads; blank rows
# (e.g. a trailing empty line) are skipped instead of becoming empty pulses.
with open(file_name, newline="") as file:
    gates = [[float(x) for x in row] for row in csv.reader(file) if row]

In [None]:
# Overlay every saved pulse as a piecewise-constant (step) trace.
fig, ax = plt.subplots()
for gate in gates:
    # Build the time axis from each pulse's own length rather than a fixed
    # 50 samples, so ax.step never fails on a length mismatch.
    # NOTE(review): 0.2 is assumed to be the sample spacing the pulses were
    # designed on — confirm against the pulse generator.
    ts = 0.2 * np.arange(len(gate))
    ax.step(ts, gate)
ax.set_xlabel("time")
ax.set_ylabel("amplitude")

### Construct a solver

In [None]:
registers = [0]
config_vars = ps.backend_simulation_vars(backend, rabi=True, units=units)

# Bind everything except the qubit index, so the model can be built per qubit.
H_rx = functools.partial(
    ps.rx_model,
    registers=registers,
    backend=backend,
    variables=config_vars,
    rotating_frame=False,
)

# Gather the control operators and their channel names for every simulated
# qubit. Iterating `registers` (instead of a hard-coded range) keeps this loop
# in sync with the register list passed to the model above.
Hs_control = []
Hs_channels = []
for qubit in registers:
    # NOTE(review): the drift term is unpacked but never used, and the solver
    # below gets static_hamiltonian=None — confirm dropping it is intended.
    Hj_drift, Hjs_control, Hjs_channel = H_rx(qubit)
    Hs_control.extend(Hjs_control)
    Hs_channels.extend(Hjs_channel)


# Every channel gets a 0.0 carrier frequency, so the solver applies no carrier
# modulation to the channel signals.
solver = qk_d.Solver(
    static_hamiltonian=None,
    hamiltonian_operators=Hs_control,
    static_dissipators=None,
    rotating_frame=None,
    rwa_cutoff_freq=None,
    hamiltonian_channels=Hs_channels,
    channel_carrier_freqs=dict.fromkeys(Hs_channels, 0.0),
    dt=dt,
)

In [None]:
# Simpson-rule area of the first saved pulse with 0.2 sample spacing.
integrate.simpson(y=gates[0], dx=0.2)

In [None]:
# Trapezoid-rule area of the first saved pulse with 0.2 sample spacing, for
# comparison with the Simpson estimate. scipy's trapezoid replaces np.trapz,
# which is deprecated as of NumPy 2.0.
integrate.trapezoid(gates[0], dx=0.2)

In [None]:
def get_pulse_unitary(pulse_array, duration, solver, expected=None, pulse_dt=0.2):
    """Simulate a sampled drive pulse and return the final solver state.

    The pulse is first rescaled so that its Simpson-rule area, computed with
    the backend sample spacing ``dt``, equals ``expected``. It is then played
    on ``DriveChannel(0)`` and evolved with ``solver`` starting from the
    identity operator.

    Args:
        pulse_array: 1-D sequence of pulse amplitudes, one per backend sample.
        duration: number of backend samples to simulate; converted to time by
            multiplying with the module-level ``dt``.
        solver: ``qiskit_dynamics.Solver`` used for the evolution. Its
            ``model.evaluation_mode`` is set to ``"sparse"`` as a side effect.
        expected: target pulse area on the ``dt`` grid. When ``None``, the
            pulse's area with spacing ``pulse_dt`` is used, so the amplitude
            is rescaled to keep that area when the samples are played on the
            ``dt`` grid.
        pulse_dt: sample spacing the pulse was designed on; only used when
            ``expected`` is ``None``.

    Returns:
        ``sol.y[-1]``, the solver state at the end of the evolution
        (presumably a qiskit ``Operator``, since ``y0`` is one — confirm).

    Raises:
        ValueError: if the rescaling factor is zero or not finite (e.g. a
            zero-area pulse, or ``expected == 0``).
    """
    t_final = duration * dt

    pulse_array = np.asarray(pulse_array, dtype=float)

    # Rescale so that the pulse area on the dt grid matches the target.
    if expected is None:
        expected = integrate.simpson(pulse_array, dx=pulse_dt)
    normalization = integrate.simpson(pulse_array, dx=dt) / expected
    if normalization == 0 or not np.isfinite(normalization):
        raise ValueError(
            f"cannot rescale pulse: normalization factor is {normalization}"
        )
    pulse_array = pulse_array / normalization

    # The rescaled pulse may exceed |1|, so the amplitude limit is disabled.
    pulse = qk.pulse.Waveform(pulse_array, limit_amplitude=False)
    with qk.pulse.build() as pulse_moment:
        channel = qk.pulse.DriveChannel(0)
        qk.pulse.play(pulse, channel)

    U0 = ps.qiskit_identity_operator(1)
    solver.model.evaluation_mode = "sparse"
    sol = solver.solve(
        t_span=[0.0, t_final],
        y0=U0,
        signals=pulse_moment,
        atol=1e-8,
        rtol=1e-8,
        method="jax_odeint",
    )

    return sol.y[-1]

In [None]:
# Target gate for each saved pulse (in CSV row order) and the pulse area it
# is rescaled to before simulation.
expected_list = [XGate(), XGate(), SXGate(), SXGate()]
expected_angle = [np.pi / 2, np.pi / 2, np.pi / 4, np.pi / 4]

# The three lists are paired element-wise; fail loudly instead of raising an
# IndexError mid-loop or silently ignoring targets.
if len(gates) != len(expected_list) or len(gates) != len(expected_angle):
    raise ValueError(
        f"got {len(gates)} saved pulses but {len(expected_list)} target gates "
        f"and {len(expected_angle)} target areas"
    )

for gate, target_gate, target_area in zip(gates, expected_list, expected_angle):
    print("Output unitary:")
    # Simulate exactly as many samples as the pulse contains.
    out = get_pulse_unitary(gate, len(gate), solver, expected=target_area)
    print(np.round(out.data, 5))

    print("Expected unitary:")
    expected = qk_qi.Operator(target_gate)
    print(np.round(expected.data, 5))

    fidelity = qk_qi.process_fidelity(
        expected, out, require_cp=False, require_tp=False
    )
    print("Fidelity: ", fidelity, "\n")