In [1]:
import numpy as np
import qiskit.pulse as pulse
from qiskit_dynamics.pulse.pulse_simulator import PulseSimulator
from qiskit_dynamics import Solver

# JAX runtime settings, applied in order:
#   jax_enable_x64     -> switch JAX to 64 bit mode
#   jax_platform_name  -> tell JAX to run on the CPU
import jax

for _jax_option, _jax_value in (
    ("jax_enable_x64", True),
    ("jax_platform_name", "cpu"),
):
    jax.config.update(_jax_option, _jax_value)

# Bring in qiskit-dynamics' Array wrapper and make 'jax' its default backend,
# so Arrays created from here on use JAX unless told otherwise.
from qiskit_dynamics.array import Array

Array.set_default_backend("jax")

In [2]:
# 2x2 operators for the two-level model.
# Note the sign convention: Z here is diag(-1, +1).
Z = np.diag([-1., 1.])
X = np.array([
    [0., 1.],
    [1., 0.],
])

# Frequency appearing in the static term (same value as the 'd0' carrier
# frequency passed to the Solver) and the drive-strength prefactor.
# Units are not stated in this notebook -- TODO confirm (presumably GHz).
qubit_freq = 5
r = 0.1

# H_static = 2*pi * qubit_freq * Z / 2
static_ham = 2 * np.pi * qubit_freq * Z / 2
# Drive operator = 2*pi * r * X / 2
drive_op = 2 * np.pi * r * X / 2

# Solver for the driven two-level model.
#  - The rotating frame is the static Hamiltonian itself.
#  - The single drive operator is attached to channel 'd0', whose carrier
#    frequency (5.) equals the frequency factor used in static_ham.
#  - dt is the sample width used for pulse schedules; its unit is not stated
#    in this notebook -- TODO confirm.
solver = Solver(
    static_hamiltonian=static_ham,
    rotating_frame=static_ham,
    hamiltonian_operators=[drive_op],
    hamiltonian_channels=["d0"],
    channel_carrier_freqs={"d0": 5.},
    dt=0.1,
)

In [3]:
# Pulse-level simulator built on the solver above; subsystem_dims lists one
# subsystem of dimension 2, matching the 2x2 operators in the model.
backend = PulseSimulator(
    solver=solver,
    subsystem_dims=[2],
)

In [12]:
from qiskit.pulse import library

# Parameters of the Gaussian drive envelope.
num_samples = 64  # passed as the pulse duration
sigma = 10
amp = 1
#%%
gauss = library.Gaussian(
    duration=num_samples,
    amp=amp,
    sigma=sigma,
    name="Parametric Gauss",
)
# NOTE(review): the value returned by draw() is discarded because this is not
# the last expression of the cell; confirm the figure is actually displayed.
gauss.draw()


# Schedule that plays the Gaussian envelope on drive channel 0.
drive_channel = pulse.DriveChannel(0)

with pulse.build() as schedule:
    # No frequency is set explicitly in this schedule, e.g. via
    #     pulse.set_frequency(5., pulse.DriveChannel(0))
    # so the carrier frequency is automatically taken from
    # channel_carrier_freqs.
    # TODO: is this the behaviour we want? It is baked into the
    # InstructionToSchedule object (possibly meaning InstructionToSignals --
    # confirm). What do the real backends do in this case?
    pulse.play(gauss, drive_channel)

# ODE solver settings forwarded to the simulator: JAX odeint with tight
# absolute and relative tolerances.
run_solver_options = {
    "method": "jax_odeint",
    "atol": 1e-8,
    "rtol": 1e-8,
}

# Simulate the schedule and sample 100 shots.
# NOTE(review): no seed is passed here, so the counts may differ between
# runs -- check whether backend.run accepts a seed option.
results = backend.run(
    schedule,
    shots=100,
    solver_options=run_solver_options,
)
#%%
# Bare expression so the notebook displays the outcome counts.
results

[{'0': 87, '1': 13}]