# Simulating custom pulses

### Imports and settings

In [96]:
import sys

sys.path.append("../")
import pulse_simulator as ps

In [188]:
import numpy as np
import csv
import functools
import qiskit as qk
import qiskit_dynamics as qk_d
import qiskit.quantum_info as qk_qi
import qiskit.providers.fake_provider as qk_fp
import matplotlib.pyplot as plt
from qiskit.circuit.library import XGate, SXGate

In [98]:
import jax

# JAX runtime options, applied in order: 64-bit floating point, CPU execution.
JAX_OPTIONS = {
    "jax_enable_x64": True,
    "jax_platform_name": "cpu",
}
for option_name, option_value in JAX_OPTIONS.items():
    jax.config.update(option_name, option_value)

# Route qiskit-dynamics arrays through JAX and display the active backend.
dynamics_array = qk_d.array.Array
dynamics_array.set_default_backend("jax")
dynamics_array.default_backend()

'jax'

In [246]:
# Fake (simulated) backend; its configuration feeds ps.backend_simulation_vars and ps.rx_model below.
backend = qk_fp.FakeManila()

### Retrieve saved pulses

In [249]:
file_name = "../pico-pulses/saved-pulses-23-12-05/a_single_qubit_gateset_R1e-6.csv"

# Each non-empty CSV row holds the samples of one saved pulse. Rows are not
# assumed to share a length, so keep a list of 1-D arrays rather than one 2-D array.
# newline="" is what the csv module documentation requires for files given to csv.reader;
# blank rows (which csv.reader yields as []) are skipped so they do not become empty "gates".
with open(file_name, newline="") as file:
    gates = [np.array([float(x) for x in row]) for row in csv.reader(file) if row]

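The saved pulses were sampled with a time step of `dt = 0.2`, but the solver below uses `dt = 1`. Rescale each pulse by dividing its samples by `1 / 0.2 = 5`. This keeps its area — the trapezoidal integral of the samples — the same on the `dt = 1` grid as on the original `dt = 0.2` grid.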

In [250]:
# Time step the saved pulses were sampled with.
original_dt = 0.2
# Use the first saved pulse for the area checks that follow.
pulse_array = gates[0]
# Number of original samples per unit of time on a dt = 1 grid.
1 / original_dt

5.0

In [251]:
# Pulse area (trapezoidal integral) on the original time grid; reuse
# original_dt instead of repeating the 0.2 literal.
original_area = np.trapz(y=pulse_array, dx=original_dt)
original_area

4.711539731629206

In [252]:
# Dividing by 1 / original_dt (samples per unit time) rescales the amplitudes so
# the trapezoidal area on a dt = 1 grid equals the area on the original grid.
new_pulse_array = pulse_array / (1 / original_dt)
new_area = np.trapz(y=new_pulse_array, dx=1)
# Check the rescaling instead of comparing the two numbers by eye.
assert np.isclose(new_area, original_area), (new_area, original_area)
new_area

4.711539731629206

### Construct a solver

In [253]:
registers = [0]
# FIXME: `units` is not defined anywhere in this notebook, so this line raises
# NameError after Restart & Run All. Define it (e.g. in the settings section)
# with the value ps.backend_simulation_vars expects before running this cell.
config_vars = ps.backend_simulation_vars(backend, rabi=True, units=units)

# Model builder for a single qubit index, with the backend-derived variables bound.
H_rx = functools.partial(
    ps.rx_model,
    registers=registers,
    backend=backend,
    variables=config_vars,
    rotating_frame=False,
)

# Collect the control operators and their channel names for each qubit.
# The drift term returned by rx_model is not used: the solver below is built
# with static_hamiltonian=None. NOTE(review): confirm dropping the drift is intended.
Hs_control = []
Hs_channels = []
for qubit in range(1):
    _drift, Hjs_control, Hjs_channel = H_rx(qubit)
    Hs_control.extend(Hjs_control)
    Hs_channels.extend(Hjs_channel)


# Zero carrier frequency on every channel, so signals enter as their raw envelopes.
solver = qk_d.Solver(
    static_hamiltonian=None,
    hamiltonian_operators=Hs_control,
    static_dissipators=None,
    rotating_frame=None,
    rwa_cutoff_freq=None,
    hamiltonian_channels=Hs_channels,
    channel_carrier_freqs={ch: 0.0 for ch in Hs_channels},
    dt=1,
)

In [254]:
def get_pulse_unitary(pulse_array, duration, solver, original_dt=0.2):
    """Simulate a sampled pulse on drive channel 0 and return the final solution.

    The samples are divided by ``1 / original_dt``, so their trapezoidal area on a
    unit-spaced grid equals their area on the original ``original_dt`` grid. The
    solver is assumed to have been built with ``dt=1``.

    Args:
        pulse_array: Pulse samples on the original ``original_dt`` grid.
        duration: End time of the integration window ``[0, duration]``.
        solver: ``qiskit_dynamics.Solver`` to run. Its channels are presumably
            expected to include drive channel 0 (confirm against the model).
        original_dt: Time step the samples were saved with. Defaults to 0.2, the
            value used for the pulses in this notebook. It is now an explicit
            parameter instead of a silent read of a global ``original_dt``.

    Returns:
        ``sol.y[-1]``: the solution at ``t = duration``, evolved from
        ``ps.qiskit_identity_operator(1)``.

    Side effects:
        Sets ``solver.model.evaluation_mode = "sparse"`` on the given solver.
    """
    # Same arithmetic as before (x / (1/dt)) so results are bit-for-bit unchanged.
    rescaled = pulse_array / (1 / original_dt)

    # Schedule playing the rescaled samples on drive channel 0.
    pulse = qk.pulse.Waveform(rescaled)
    with qk.pulse.build() as pulse_moment:
        qk.pulse.play(pulse, qk.pulse.DriveChannel(0))

    U0 = ps.qiskit_identity_operator(1)
    solver.model.evaluation_mode = "sparse"
    sol = solver.solve(
        t_span=[0.0, duration],
        y0=U0,
        signals=pulse_moment,
        atol=1e-8,
        rtol=1e-8,
        method="jax_odeint",
    )
    return sol.y[-1]

In [255]:
# Gates the saved pulses should implement, matched to the pulses in file order.
expected_list = [XGate(), XGate(), SXGate(), SXGate()]
# Fail up front instead of hitting IndexError mid-loop (or silently skipping
# expected gates) when the file holds a different number of pulses.
assert len(gates) == len(expected_list), (len(gates), len(expected_list))

PULSE_DURATION = 50  # integration window passed to get_pulse_unitary (solver time units)

for gate, expected_gate in zip(gates, expected_list):
    print("Output unitary:")
    out = get_pulse_unitary(gate, PULSE_DURATION, solver)
    print(np.round(out.data, 5))

    print("Expected unitary:")
    expected = qk_qi.Operator(expected_gate)
    print(np.round(expected.data, 5))

    # Non-CP/TP checks are disabled, matching the original comparison settings.
    fidelity = qk_qi.process_fidelity(expected, out, require_cp=False, require_tp=False)
    print("Fidelity: ", fidelity, "\n")

Output unitary:
[[-0.21248+0.j       0.     +0.97717j]
 [ 0.     +0.97717j -0.21248+0.j     ]]
Expected unitary:
[[0.+0.j 1.+0.j]
 [1.+0.j 0.+0.j]]
Fidelity:  0.9548532726720862 

Output unitary:
[[0.06791+0.j      0.     -0.99769j]
 [0.     -0.99769j 0.06791+0.j     ]]
Expected unitary:
[[0.+0.j 1.+0.j]
 [1.+0.j 0.+0.j]]
Fidelity:  0.995388530697306 

Output unitary:
[[-0.81965+0.j       0.     +0.57287j]
 [ 0.     +0.57287j -0.81965+0.j     ]]
Expected unitary:
[[0.5+0.5j 0.5-0.5j]
 [0.5-0.5j 0.5+0.5j]]
Fidelity:  0.9695489607291636 

Output unitary:
[[0.73306+0.j      0.     -0.68017j]
 [0.     -0.68017j 0.73306+0.j     ]]
Expected unitary:
[[0.5+0.5j 0.5-0.5j]
 [0.5-0.5j 0.5+0.5j]]
Fidelity:  0.998601352580867 

