## Experimental Business

Here we demonstrate the `TwoQuditHamiltonian` class (which makes it easy to construct Hamiltonians), the construction of a custom
schedule function, and finally the `JaxedSolver`, which lets us estimate the simulated state vectors in the desired basis.

In [38]:
# All Imports

import numpy as np

import jax

# JAX expects jax_enable_x64 to be set at startup, before any jax arrays exist.
# Set it (and the platform) before importing the modules below, so that any
# arrays they may create at import time are already 64-bit.
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp
from jax.numpy.linalg import norm

import qiskit.pulse as pulse
from qiskit_dynamics.array import Array

from library.utils import PauliToQuditOperator, TwoQuditHamiltonian
from library.new_sims import JaxedSolver

# Make qiskit_dynamics Arrays use JAX as their default backend.
Array.set_default_backend('jax')

In [15]:
# Build the two-qudit Hamiltonian with TwoQuditHamiltonian and pull out the
# pieces (solver, operators, channels, carrier frequencies) the JaxedSolver needs.

# Sample time passed to the Hamiltonian and solver (1 / 4.5e9).
dt = 1 / 4.5e9
# Solver tolerances.
atol, rtol = 1e-2, 1e-4

batchsize = 400

# Only the endpoints of this grid are used here, as the integration window.
t_linspace = np.linspace(0.0, 400e-9, 11)
t_span = t_linspace[[0, -1]]

qudit_dim = 3

q_end = TwoQuditHamiltonian(qudit_dim=qudit_dim, dt=dt)

solver, ham_ops, ham_chans, chan_freqs = (
    q_end.solver,
    q_end.ham_ops,
    q_end.ham_chans,
    q_end.chan_freqs,
)

In [34]:
# Make the Custom Schedule Construction Function
# (this cell first builds the batched inputs; the schedule function follows)

# Pulse parameters: one (amp, sigma, freq) row per batch element.
amp_vals = jnp.linspace(0.5, 0.99, batchsize, dtype=jnp.float64).reshape(-1, 1)
# Integer sigmas. int32 rather than int8: int8 tops out at 127, so widening this
# range would overflow. The values are promoted to float64 by the concatenate
# below either way.
sigma_vals = jnp.linspace(20, 80, batchsize, dtype=jnp.int32).reshape(-1, 1)
freq_vals = jnp.linspace(-0.5, 0.5, batchsize, dtype=jnp.float64).reshape(-1, 1) * 1e6
batch_params = jnp.concatenate((amp_vals, sigma_vals, freq_vals), axis=-1)

# Same normalized equal-amplitude initial state for every batch element.
# jax arrays are immutable, so rebind explicitly instead of using `/=`.
init_y0 = jnp.ones(qudit_dim ** 2, dtype=jnp.complex128)
init_y0 = init_y0 / norm(init_y0)
batch_y0 = jnp.tile(init_y0, (batchsize, 1))

# Observables cycled out to exactly `batchsize` entries, so the list length
# always matches the batch even if `batchsize` is changed.
obs_cycle = ["XX", "IX", "YZ", "ZY"]
batch_str = [obs_cycle[i % len(obs_cycle)] for i in range(batchsize)]

# Guard the invariants the JaxedSolver call relies on.
assert batch_params.shape == (batchsize, 3)
assert batch_y0.shape == (batchsize, qudit_dim ** 2)
assert len(batch_str) == batchsize

print(f"initial statevec: {init_y0}")
print(f"statevector * hc: {init_y0 @ init_y0.conj().T}")

def standard_func(params):
    """Build the pulse schedule for a single batch element.

    Args:
        params: length-3 sequence unpacked as ``(amp, sigma, freq)``;
            presumably one row of ``batch_params`` as passed by JaxedSolver.

    Returns:
        A ``qiskit.pulse`` schedule that, on d0, d1, u0 and u1 in turn, shifts
        the channel frequency by ``freq`` and then plays the same Drag pulse
        (duration 320 samples, beta 0.1, angle 0.1).
    """
    amp, sigma, freq = params

    # qiskit's Drag is already a ScalableSymbolicPulse, so use it directly.
    drag = pulse.Drag(
        duration=320,
        amp=amp,
        sigma=sigma,
        beta=0.1,
        angle=0.1,
        limit_amplitude=False,
    )

    with pulse.build(default_alignment='sequential') as sched:
        channels = (
            pulse.DriveChannel(0),
            pulse.DriveChannel(1),
            pulse.ControlChannel(0),
            pulse.ControlChannel(1),
        )
        # Sequential alignment: each channel's pulse is placed after the
        # previous one, in the order listed above.
        for chan in channels:
            pulse.shift_frequency(freq, chan)
            pulse.play(drag, chan)

    return sched

initial statevec: [0.33333333+0.j 0.33333333+0.j 0.33333333+0.j 0.33333333+0.j
 0.33333333+0.j 0.33333333+0.j 0.33333333+0.j 0.33333333+0.j
 0.33333333+0.j]
statevector * hc: (1.0000000000000002+0j)


In [35]:
# Wrap the two-qudit solver and the schedule function in a JaxedSolver backend.
j_solver = JaxedSolver(
    # builds one pulse schedule from one row of batch parameters
    schedule_func=standard_func,
    # model pieces produced by TwoQuditHamiltonian
    solver=solver,
    dt=dt,
    carrier_freqs=chan_freqs,
    ham_chans=ham_chans,
    ham_ops=ham_ops,
    # integration window and tolerances
    t_span=t_span,
    rtol=rtol,
    atol=atol,
)

In [37]:
j_solver.estimate2(batch_y0=batch_y0, batch_params=batch_params, batch_obs_str=batch_str)

%timeit j_solver.estimate2(batch_y0=batch_y0, batch_params=batch_params, batch_obs_str=batch_str)

(400, 9)
7.46 s ± 756 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Workflow

1. Use `TwoQuditHamiltonian` to build the Hamiltonian and the corresponding Solver.
2. Write a custom schedule function that constructs the appropriate pulse Schedule from one row of batch parameters.
3. Pass the standard two-qudit solver and the schedule function to `JaxedSolver`, which outputs the
   corresponding estimator results.