# Full Workflow - Batched Observables and States to Desired Probability Vectors

Here we include the pre-processing necessary to convert some standard Qiskit Observables into our desired forms.

In [1]:
# All Imports

from typing import Optional, Union

import numpy as np
import matplotlib.pyplot as plt
import sympy as sym

import qiskit
from qiskit import QuantumCircuit, pulse
from qiskit.circuit.parameter import Parameter
from qiskit.quantum_info import Statevector, DensityMatrix, Operator, SparsePauliOp

from qiskit_dynamics import Solver, DynamicsBackend
from qiskit_dynamics.pulse import InstructionToSignals
from qiskit_dynamics.array import Array

import jax
import jax.numpy as jnp
from jax import jit, vmap, block_until_ready, config

import chex

# Select 'jax' as the default backend for qiskit-dynamics `Array` objects, then
# configure JAX itself: enable 64-bit types and pin the platform to CPU.
Array.set_default_backend('jax')
config.update('jax_enable_x64', True)
config.update('jax_platform_name', 'cpu')

In [2]:
# Constructing a Two Qutrit Hamiltonian
#
# Each subsystem is truncated to `dim` levels. Single-subsystem operators are
# lifted to the joint space with np.kron(ident, op) for subsystem 0 and
# np.kron(op, ident) for subsystem 1, so the joint basis index is
# dim * n1 + n0 (subsystem 0 is the fastest-varying index).

dim = 3

# Subsystem 0: frequency, anharmonicity and drive strength. All three are
# multiplied by pi / 2*pi below, so they are presumably ordinary frequencies in Hz.
v0 = 4.86e9
anharm0 = -0.32e9
r0 = 0.22e9

# Subsystem 1: same three parameters.
v1 = 4.97e9
anharm1 = -0.32e9
r1 = 0.26e9

# Exchange coupling strength between the two subsystems.
J = 0.002e9

# Truncated ladder operators and number operator for a single subsystem.
a = np.diag(np.sqrt(np.arange(1, dim)), 1)
adag = np.diag(np.sqrt(np.arange(1, dim)), -1)
N = np.diag(np.arange(dim))

ident = np.eye(dim, dtype=complex)
full_ident = np.eye(dim**2, dtype=complex)

# Operators lifted to the joint (dim**2)-dimensional space.
N0 = np.kron(ident, N)
N1 = np.kron(N, ident)

a0 = np.kron(ident, a)
a1 = np.kron(a, ident)

a0dag = np.kron(ident, adag)
a1dag = np.kron(adag, ident)


# The anharmonic term is the operator product N (N - 1). The matrix product `@`
# is used rather than element-wise `*`: the two only coincide because N0 and N1
# are diagonal, and `@` stays correct if the basis is ever changed.
static_ham0 = 2 * np.pi * v0 * N0 + np.pi * anharm0 * N0 @ (N0 - full_ident)
static_ham1 = 2 * np.pi * v1 * N1 + np.pi * anharm1 * N1 @ (N1 - full_ident)

# Full static Hamiltonian: both subsystems plus the (a + a^dag)(a + a^dag) coupling.
static_ham_full = static_ham0 + static_ham1 + 2 * np.pi * J * ((a0 + a0dag) @ (a1 + a1dag))

# Drive operators, one per subsystem.
drive_op0 = 2 * np.pi * r0 * (a0 + a0dag)
drive_op1 = 2 * np.pi * r1 * (a1 + a1dag)

In [3]:
batchsize = 400

# One column per pulse parameter; each row of `batch_params` is (amp, sigma, freq).
amp_vals = jnp.linspace(0.5, 0.99, batchsize, dtype=jnp.float64)[:, None]
# NOTE(review): the int8 dtype makes the sigma grid integer-valued before it is
# concatenated with the float columns -- presumably intentional; confirm.
sigma_vals = jnp.linspace(20, 80, batchsize, dtype=jnp.int8)[:, None]
freq_vals = 1e6 * jnp.linspace(-0.5, 0.5, batchsize, dtype=jnp.float64)[:, None]
batch_params = jnp.concatenate([amp_vals, sigma_vals, freq_vals], axis=1)

# One identical all-ones (unnormalised) 9-component initial state per batch entry.
batch_y0 = jnp.tile(np.ones(9), (batchsize, 1))
# One copy of N0 per batch entry.
# NOTE(review): no later visible cell reads batch_obs -- confirm it is still needed.
batch_obs = jnp.tile(N0, (batchsize, 1, 1))

print(f"Batched Params Shape: {batch_params.shape}")

# Schedule builder: maps one parameter vector to a pulse schedule. The schedule
# is only simulated later, when it is handed to the solver.

def standard_func(params):
    """Build a schedule that plays the same Drag pulse on d0, d1, u0 and u1.

    Args:
        params: Iterable unpacked as ``(amp, sigma, freq)``. ``amp`` and
            ``sigma`` parametrise the Drag pulse; ``freq`` is the frequency
            shift applied to each channel before its pulse is played.

    Returns:
        The schedule produced by ``pulse.build`` with sequential alignment.
    """
    amp, sigma, freq = params

    # A Drag pulse is used as it is already a Scalable Symbolic Pulse in qiskit pulse.
    drag = pulse.Drag(
        duration=320,
        amp=amp,
        sigma=sigma,
        beta=0.1,
        angle=0.1,
        limit_amplitude=False,
    )

    with pulse.build(default_alignment='sequential') as sched:
        channels = (
            pulse.DriveChannel(0),
            pulse.DriveChannel(1),
            pulse.ControlChannel(0),
            pulse.ControlChannel(1),
        )
        # Same shift-then-play pair on every channel, in the order d0, d1, u0, u1.
        for channel in channels:
            pulse.shift_frequency(freq, channel)
            pulse.play(drag, channel)

    return sched

Batched Params Shape: (400, 3)


In [10]:
# Solver configuration: sample time, integrator tolerances, time grid and the
# mapping from pulse channels to drive operators / carrier frequencies.

dt = 1 / 4.5e9
# NOTE(review): atol is two orders of magnitude looser than rtol -- confirm this is intended.
atol = 1e-2
rtol = 1e-4

t_linspace = np.linspace(0.0, 400e-9, 11)
# Only the first and last grid points are used as the integration interval.
t_span = t_linspace[[0, -1]]

# Channels d0/u0 act through drive_op0 and d1/u1 through drive_op1; the control
# channels carry the *other* subsystem's frequency.
ham_chans = ["d0", "d1", "u0", "u1"]
ham_ops = [drive_op0, drive_op1] * 2
chan_freqs = dict(zip(ham_chans, (v0, v1, v1, v0)))

# The static Hamiltonian doubles as the rotating frame.
solver = Solver(
    static_hamiltonian=static_ham_full,
    rotating_frame=static_ham_full,
    hamiltonian_operators=ham_ops,
    hamiltonian_channels=ham_chans,
    channel_carrier_freqs=chan_freqs,
    dt=dt,
)

class JaxifiedSolver:
    """Batched, JIT-compiled pulse simulation returning per-basis-state probabilities.

    ``fast_batched_sim`` is ``jit(vmap(run_sim))``: it maps ``run_sim`` over the
    leading axis of the initial states, observable matrices and parameter vectors.

    Args:
        schedule_func: Callable mapping one parameter vector to a pulse schedule.
        dt: Sample time handed to ``InstructionToSignals``.
        carrier_freqs: Channel-name -> carrier-frequency mapping for the converter.
        ham_chans: Channel names handed to the converter.
        t_span: Integration interval; also used as the evaluation times.
        rtol: Relative tolerance passed to the solver.
        atol: Absolute tolerance passed to the solver.
        dynamics_solver: Optional solver object exposing ``solve``. When left as
            ``None`` the module-level ``solver`` is looked up at simulation time,
            which is the original behaviour.
    """

    def __init__(
        self,
        schedule_func,
        dt,
        carrier_freqs,
        ham_chans,
        t_span,
        rtol,
        atol,
        dynamics_solver=None,
    ):
        self.schedule_func = schedule_func
        self.dt = dt
        self.carrier_freqs = carrier_freqs
        self.ham_chans = ham_chans
        self.t_span = t_span
        self.rtol = rtol
        self.atol = atol
        self.dynamics_solver = dynamics_solver
        self.fast_batched_sim = jit(vmap(self.run_sim))

    @staticmethod
    def probs_from_state(state_vec, obs):
        """Apply a two-qubit operator to the qubit subspace and return probabilities.

        The two-qudit state is indexed as ``dim * n1 + n0``, so the two-qubit
        subspace (both levels in {0, 1}) sits at indices ``[0, 1, dim, dim + 1]``,
        not at the first four entries: index 2 is the leakage state
        ``|n1=0, n0=2>``. Those four amplitudes are multiplied by ``obs``; every
        other amplitude is left untouched.

        Args:
            state_vec: 1-D state of length ``dim ** 2``; it is normalised here.
            obs: ``(4, 4)`` operator matrix in the two-qubit basis ordering
                ``2 * q1 + q0``.

        Returns:
            Real vector with the same length and basis ordering as ``state_vec``
            holding the squared magnitudes, clipped to ``[0, 1]``.
        """
        state_vec = state_vec / jnp.linalg.norm(state_vec)

        # Shapes are static under jit/vmap, so this is ordinary Python arithmetic.
        dim = int(round(state_vec.shape[-1] ** 0.5))
        qubit_idx = jnp.array([0, 1, dim, dim + 1])

        probs = jnp.abs(state_vec) ** 2
        evolved = jnp.dot(obs, state_vec[qubit_idx])
        # Scatter real-valued probabilities (not complex amplitudes) so the
        # update never has to cast complex values into a real array.
        probs = probs.at[qubit_idx].set((jnp.abs(evolved) ** 2).astype(probs.dtype))

        # Positional bounds work on both old and new JAX; the a_min/a_max
        # keywords are deprecated in recent releases.
        return jnp.clip(probs, 0.0, 1.0)

    def run_sim(self, y0, obs, params):
        """Simulate one schedule and return the probability vector.

        Args:
            y0: Initial state; normalised before it is handed to the solver.
            obs: ``(4, 4)`` operator matrix applied to the final qubit-subspace
                amplitudes (see ``probs_from_state``).
            params: Parameter vector forwarded to ``schedule_func``.

        Returns:
            Probability vector computed from the state at the last evaluation time.
        """
        sched = self.schedule_func(params)

        converter = InstructionToSignals(
            self.dt, carriers=self.carrier_freqs, channels=self.ham_chans
        )
        signals = converter.get_signals(sched)

        # Fall back to the notebook-level `solver` when none was injected.
        active_solver = solver if self.dynamics_solver is None else self.dynamics_solver

        results = active_solver.solve(
            t_span=self.t_span,
            y0=y0 / jnp.linalg.norm(y0),
            t_eval=self.t_span,
            signals=signals,
            rtol=self.rtol,
            atol=self.atol,
            convert_results=False,
            method='jax_odeint'
        )

        # Shots instead of probabilities

        return self.probs_from_state(results.y.data[-1], obs)

    def estimate(self, batch_y0, batch_obs, batch_params):
        """Run the batched simulation for a list of operators.

        Args:
            batch_y0: Batch of initial states.
            batch_obs: Iterable with one operator per batch entry. Objects with a
                ``to_matrix()`` method are converted through it; anything else is
                taken to be a matrix already.
            batch_params: Batch of parameter vectors.

        Returns:
            Whatever ``fast_batched_sim`` returns: one probability vector per
            batch entry.
        """
        ops_mat = [b.to_matrix() if hasattr(b, "to_matrix") else b for b in batch_obs]
        ops_arr = jnp.array(ops_mat)
        return self.fast_batched_sim(batch_y0, ops_arr, batch_params)

In [11]:
# Bind the schedule builder and the integration settings into the batched simulator.
# Positional order: schedule_func, dt, carrier_freqs, ham_chans, t_span, rtol, atol.
j_solver = JaxifiedSolver(
    standard_func, dt, chan_freqs, ham_chans, t_span, rtol, atol
)

In [12]:
# One two-qubit Pauli observable per batch entry, cycling IX, IY, YZ, ZX.
# The list length is derived from `batchsize` rather than a hard-coded repeat
# count: vmap requires every batched argument to share the same leading size.
pauli_labels = ["IX", "IY", "YZ", "ZX"]
ops_list = [
    SparsePauliOp([pauli_labels[i % len(pauli_labels)]]) for i in range(batchsize)
]

batch_res = j_solver.estimate(batch_y0, ops_list, batch_params)

In [13]:
# Time one full call to estimate(): operator-to-matrix conversion plus the batched simulation.
%timeit j_solver.estimate(batch_y0,ops_list,batch_params)

7.18 s ± 715 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Remaining TODOs:

1. Handle converting qubit operators to qudit operators (or, instead, convert the qudit statevector to a qubit statevector).
2. Handle arbitrary statevector inputs.
3. Construct a proper … (this item was left unfinished).

## Done! For now...

Here we have added an `estimate` function (which would more realistically be part of a `DynamicsBackend` or an Estimator primitive) that takes a batch of Qiskit operators as input and returns the probability vectors we want.

In [22]:
# Snapshot the statevector of a 3-qubit circuit before and after flipping
# qubits, to see which amplitude index each basis state occupies.
qc = QuantumCircuit(3)

ket = Statevector(qc)    # no gates applied: amplitude at index 0
qc.x(2)
ket2 = Statevector(qc)   # X on qubit 2: amplitude moves to index 4
qc.x(1)
ket3 = Statevector(qc)   # additionally X on qubit 1: amplitude at index 6

print(ket.data)
print(ket)
print(ket2)
print(ket3)

[1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
Statevector([1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j,
             0.+0.j],
            dims=(2, 2, 2))
Statevector([0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j,
             0.+0.j],
            dims=(2, 2, 2))
Statevector([0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j,
             0.+0.j],
            dims=(2, 2, 2))


000 => (1, 0, 0, 0, 0, 0, 0, 0)
001 => (0, 1, 0, 0, 0, 0, 0, 0)
011 => (0, 0, 0, 1, 0, 0, 0, 0)
101 => (0, 0, 0, 0, 0, 1, 0, 0)
2**2 + 1 = 5

NOTES

Psi = a(000) + b(001) + c(010) + d(011) + e(100) + f(101) + g(110) + h(111)
(qubit convention, 2->0)
Taking the basis index modulo 4 gives the two-qubit state: x mod 4 ∈ {0, 1, 2, 3}

Qubit_Psi = a(00) + b(01) + c(10) + d(11) + e(00) + f(01) + g(10) + h(11)

In [31]:
# Uniform, unit-norm vector over the 3 ** 2 two-qutrit basis states.
uniform_amplitudes = np.ones(3 ** 2)
total_vec = uniform_amplitudes / np.linalg.norm(uniform_amplitudes)

