In [1]:
from qiskit import QuantumCircuit, QuantumRegister
from qiskit.circuit import Parameter
from qiskit_dynamics import Solver, DynamicsBackend
from qiskit_dynamics.array import Array
from pulse_level.qiskit_pulse.custom_jax_sim import DynamicsBackendEstimator, JaxSolver
from qiskit import transpile, pulse
from qiskit.compiler import schedule as build_schedule
import numpy as np
from qiskit.quantum_info import Statevector, Operator
from rl_qoc.helper_functions import perform_standard_calibrations
import jax
import jax.numpy as jnp

# JAX runtime flags: turn on the x64 option and select the CPU platform, then
# make "jax" the default backend of the qiskit-dynamics Array class.
for _flag, _setting in (("jax_enable_x64", True), ("jax_platform_name", "cpu")):
    jax.config.update(_flag, _setting)
Array.set_default_backend("jax")

In [2]:
# System definition: two coupled oscillators, each truncated to `dim` levels.
dim = 3

# Frequencies, anharmonicities and drive strengths of subsystems 0 and 1
# (presumably in Hz: every one of them is multiplied by 2*pi or pi below).
v0 = 4.86e9
v1 = 4.97e9
anharm0 = -0.32e9
anharm1 = -0.32e9
r0 = 0.22e9
r1 = 0.26e9

# Strength of the (a0 + a0^dag)(a1 + a1^dag) coupling term.
J = 0.002e9

# Single-subsystem operators: sqrt(1), ..., sqrt(dim - 1) on the first
# super-diagonal (a) / sub-diagonal (adag), and the diagonal number operator.
ladder_weights = np.sqrt(np.arange(1, dim))
a = np.diag(ladder_weights, k=1)
adag = np.diag(ladder_weights, k=-1)
N = np.diag(np.arange(dim))

ident = np.eye(dim, dtype=complex)
full_ident = np.eye(dim**2, dtype=complex)

# Embed into the joint space; subsystem 0 is the right-hand Kronecker factor.
N0 = np.kron(ident, N)
N1 = np.kron(N, ident)
a0 = np.kron(ident, a)
a1 = np.kron(a, ident)
a0dag = np.kron(ident, adag)
a1dag = np.kron(adag, ident)

# Frequency term plus the anharmonic N(N - 1) correction for each subsystem.
# The elementwise `*` matches a matrix product here because both N_k and
# (N_k - I) are diagonal.
static_ham0 = 2 * np.pi * v0 * N0 + np.pi * anharm0 * N0 * (N0 - full_ident)
static_ham1 = 2 * np.pi * v1 * N1 + np.pi * anharm1 * N1 * (N1 - full_ident)

# (a + a^dag) for each subsystem, shared by the coupling and drive terms.
x_op0 = a0 + a0dag
x_op1 = a1 + a1dag

coupling_term = 2 * np.pi * J * (x_op0 @ x_op1)
static_ham_full = static_ham0 + static_ham1 + coupling_term

drive_op0 = 2 * np.pi * r0 * x_op0
drive_op1 = 2 * np.pi * r1 * x_op1

# Sample time shared by both solvers.
dt = 1 / 4.5e9


def _two_qubit_solver_kwargs():
    """Return freshly built keyword arguments shared by both solver flavours.

    Channels "d0"/"d1" drive subsystem 0/1 at its own frequency; "u0"/"u1"
    reuse the same drive operators but carry the other subsystem's frequency.
    New list/dict objects are created on every call so the two solvers never
    share mutable containers.
    """
    return dict(
        static_hamiltonian=static_ham_full,
        hamiltonian_operators=[drive_op0, drive_op1, drive_op0, drive_op1],
        rotating_frame=static_ham_full,
        hamiltonian_channels=["d0", "d1", "u0", "u1"],
        channel_carrier_freqs={"d0": v0, "d1": v1, "u0": v1, "u1": v0},
        dt=dt,
    )


# Same model twice: once with the project's JaxSolver, once with the stock Solver.
solver_2q_jax = JaxSolver(**_two_qubit_solver_kwargs())
solver_2q = Solver(**_two_qubit_solver_kwargs())
# Solver settings reused throughout the notebook. Two separate but identical
# dicts are created, one per backend.
solver_options = [
    dict(method="jax_odeint", atol=1e-6, rtol=1e-8, hmax=dt) for _ in range(2)
]
solver = solver_2q_jax

# Each backend gets its own solver and options dict; `subsystem_dims` is
# needed for computing measurement data, and `solver_options` is applied on
# every call to run.
jax_backend = DynamicsBackend(
    solver=solver_2q_jax,
    subsystem_dims=[dim] * 2,
    solver_options=solver_options[0],
)

standard_backend = DynamicsBackend(
    solver=solver_2q,
    subsystem_dims=[dim] * 2,
    solver_options=solver_options[1],
)

In [3]:
# Calibrate both backends with the same routine, JAX-backed one first; each
# call yields a (calibrations, results) pair.
(jax_cals, jax_results), (standard_cals, standard_results) = (
    perform_standard_calibrations(backend=backend_to_calibrate)
    for backend_to_calibrate in (jax_backend, standard_backend)
)

In [4]:
# Reference amplitude of the calibrated X pulse on qubit 0.
# The amplitude was previously read from the "sx" calibration even though the
# variable is named for X and the X pulse is the one drawn below; both now
# come from the same "x" calibration entry.
# Assumes the first instruction of the calibration schedule is the Play that
# carries the gate pulse -- TODO confirm against perform_standard_calibrations.
x_cal_q0 = jax_backend.target.get_calibration("x", (0,))
x_pulse_q0 = x_cal_q0.instructions[0][1].pulse
x_amp_ref = x_pulse_q0.amp
x_pulse_q0.draw()

In [71]:
# use amplitude as the function argument
from qiskit import pulse
from qiskit_dynamics.pulse import InstructionToSignals
import sympy as sym
from qiskit_dynamics.array import wrap

# jax.jit and jax.vmap passed through qiskit-dynamics' `wrap` in decorator mode.
jit, qd_vmap = (wrap(transform, decorator=True) for transform in (jax.jit, jax.vmap))

# Free parameter named "amp" and a one-qubit circuit holding a single X gate.
param = Parameter("amp")
qc = QuantumCircuit(1)
qc.x(0)


def jit_func(amp):
    """Simulate a DRAG pulse of amplitude ``amp`` played on drive channel 0.

    The DRAG pulse is the only instruction in the schedule, so the simulated
    evolution actually depends on ``amp``. Previously the pulse was built but
    never played (the schedule came from ``pulse.call(qc)``), which made the
    function return the same result for every amplitude.

    Args:
        amp: Amplitude of the DRAG envelope. The remaining envelope parameters
            are fixed: duration 160 samples, sigma 40, angle 0 and
            beta 6.6166741255.

    Returns:
        Tuple ``(t, y)`` of raw arrays (the ``.data`` of qiskit-dynamics
        ``Array`` objects): the solver time points and the solution at those
        points, starting from the 9x9 identity matrix.
    """
    drag_pulse = pulse.Drag(duration=160, amp=amp, sigma=40, angle=0, beta=6.6166741255)

    # Build a schedule whose only instruction is the amplitude-dependent pulse.
    with pulse.build(backend=jax_backend) as schedule:
        pulse.play(drag_pulse, pulse.DriveChannel(0))

    # Integrate over 300 samples; the 160-sample pulse fits inside that window.
    results = solver_2q_jax.solve(
        t_span=Array([0, 300 * dt]),
        y0=jnp.eye(9),
        signals=schedule,
        **solver_options[0],
    )
    # Unwrap to plain arrays so the function can be traced by jax.jit directly.
    return Array(results.t).data, Array(results.y).data


# Compile the simulation with jax.jit and evaluate it at the reference
# amplitude; `results` is the (t, y) tuple returned by jit_func.
sim_func = jax.jit(jit_func)
results = sim_func(x_amp_ref)

In [72]:
# results[1] is the solution array y; index 1 picks its second entry
# (presumably the propagator at the end of t_span -- verify against the solver).
final_propagator = Operator(np.array(results[1][1]))
Statevector.from_int(0, 9).evolve(final_propagator)

In [73]:
from rl_qoc.helper_functions import projected_statevector, qubit_projection

# Second entry of the solution array, converted to a plain numpy matrix.
final_unitary = np.array(results[1][1])

# Evolve basis state 0 of the 9-dimensional space, then hand the amplitudes to
# projected_statevector with subsystem dimensions [3, 3].
evolved_ground = Statevector.from_int(0, 9).evolve(Operator(final_unitary))
print(projected_statevector(evolved_ground.data, [3, 3]))

# Same projection applied to the full matrix; shown as the cell output.
qubit_projection(final_unitary, [3, 3])

In [76]:
from qiskit.quantum_info import state_fidelity, average_gate_fidelity
from qiskit.circuit.library import HGate

# Second entry of the solution array, converted to a plain numpy matrix.
final_unitary = np.array(results[1][1])

# Fidelity between the projected evolution of basis state 1 and the "00" state.
evolved_one = Statevector.from_int(1, 9).evolve(Operator(final_unitary))
projected_one = projected_statevector(evolved_one.data, [3, 3])
print(state_fidelity(projected_one, Statevector.from_label("00"), validate=False))

# Average gate fidelity of the projected matrix against the "IH" operator.
# NOTE(review): the circuit `qc` defined earlier holds an X gate, whereas the
# reference here is "IH" -- confirm which target gate is intended.
projected_unitary = qubit_projection(final_unitary, [3, 3])
gate_fid = average_gate_fidelity(projected_unitary, Operator.from_label("IH"))
gate_fid

In [85]:
# Product of two three-level states, each in basis state 0.
s1 = Statevector.from_int(0, 3)
s2 = Statevector.from_int(0, 3)
s = s1.tensor(s2)
print(s)

# Wrap the simulated matrix as an operator on two 3-dimensional subsystems,
# evolve the product state, project it and compare against the "01" state.
qutrit_dims = (3, 3)
propagator = Operator(
    np.array(results[1][1]), input_dims=qutrit_dims, output_dims=qutrit_dims
)
projected_state = projected_statevector(s.evolve(propagator).data, [3, 3])
state_fidelity(projected_state, Statevector.from_label("01"), validate=False)

In [77]:
# Reference value: average gate fidelity between the "I" and "Z" operators.
identity_1q = Operator.from_label("I")
pauli_z = Operator.from_label("Z")
average_gate_fidelity(identity_1q, pauli_z)

In [68]:
# Display the operator built from the label "IH" (the same reference operator
# passed to average_gate_fidelity in the gate-fidelity cell).
Operator.from_label("IH")