# Import

In [None]:
import functools
import itertools
import numpy as np
from scipy.linalg import expm

import qiskit as qk
import qiskit_dynamics as qk_d
import qiskit.providers.fake_provider as qk_fp

import qutip as qt
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import importlib

In [None]:
# configure jax to use 64 bit mode
# (set at startup, before any JAX arrays are created, so simulations run in
# double precision instead of JAX's float32 default)
import jax
jax.config.update("jax_enable_x64", True)

# tell JAX we are using CPU
jax.config.update('jax_platform_name', 'cpu')

# set default backend
# Make JAX the array backend for qiskit_dynamics; the last line echoes the
# active backend as the cell output.
qk_d.array.Array.set_default_backend('jax')
qk_d.array.Array.default_backend()

In [None]:
# Make the local `pulse_simulator` module importable from the parent of the
# current working directory.
import sys
sys.path.append("../")
import pulse_simulator as ps

# Inspect two qubit gate Hamiltonian

In [None]:
# Offline snapshot of the ibmq_manila device (configuration + properties).
backend = qk_fp.FakeManila()

# Initialize device
# =====
# Undo units
# The backend reports times in seconds; multiplying by `ns` (= 1e9) expresses
# them in ns. `GHz` is the matching frequency factor (not used in this cell).
units = 1e9
GHz = 1/units
ns = units

dt = backend.configuration().dt * ns  # sample time, ns
duration = 220 * dt  # ns; length of one pulse segment (220 samples)

registers = [0, 1, 2, 3]  # TODO: Active registers

# Variables
# NOTE: If the Rabi rates are different, you have to calibrate!
config_vars = ps.backend_simulation_vars(backend, rabi=False, units=units)

# Carrier frequencies of each control line
carriers = ps.backend_carriers(backend, config_vars)

# Trailing `;` suppresses the notebook display of config_vars.
config_vars;

In [None]:
# Calibrated CX gate length on qubits (0, 1) from the backend properties
# (Qiskit reports gate lengths in seconds).
backend.properties().gate_length('cx', [0, 1])

## Type Toy

In [None]:
# Partially compile to get this circuit's gates
# (bind the arguments shared by every model so each call only needs the qubit
# index or the (control, target) pair)
cr_model = functools.partial(
    ps.cross_resonance_model,
    registers=registers,
    backend=backend,
    variables=config_vars, 
    model_name="Toy",
    return_params=True
)

qb_model = functools.partial(
    ps.rx_model,
    registers=registers,
    backend=backend,
    variables=config_vars, 
    rotating_frame=True,
    return_params=True
)

# Two qubit model
# Visit every ordered pair of register indices that differ by one, i.e. both
# (i, i+1) and (i+1, i). This is index adjacency in `registers`; presumably it
# mirrors the device coupling map -- confirm. Each iteration overwrites
# H_drift / Hs_control / H_channel / params, so only the last pair survives.
for i,j in itertools.permutations(registers, 2):
    if abs(i-j) == 1:
        control = i
        target = j
        H_drift, Hs_control, H_channel, params = cr_model((control, target))
        print(f"Control: {control}, Target: {target}\n Params: {params}, \n Channel: {H_channel}")

print()

# Single qubit model (one per register); again only the last qubit's results
# remain in the globals afterwards.
for i in registers:
    H_drift, Hs_control, H_channel, params = qb_model(i)
    print(f"Qubit: {i}, Params: {params}")

## Type SWPT

In [None]:
# Partially compile to get this circuit's gates
# Same bindings as the Toy cell, but with the "SWPT" cross-resonance model.
cr_model = functools.partial(
    ps.cross_resonance_model,
    registers=registers,
    backend=backend,
    variables=config_vars, 
    model_name="SWPT",
    return_params=True
)

# Re-bound unchanged; not used in this cell.
qb_model = functools.partial(
    ps.rx_model,
    registers=registers,
    backend=backend,
    variables=config_vars, 
    rotating_frame=True,
    return_params=True
)

# Control model
# Ordered pairs of register indices that differ by one. The globals left
# behind (H_drift, Hs_control, H_channel, params) belong to the last pair.
for i,j in itertools.permutations(registers, 2):
    if abs(i-j) == 1:
        control = i
        target = j
        H_drift, Hs_control, H_channel, params = cr_model((control, target))
        print(f"Control: {control}, Target: {target}\n Params: {params}")

## Plot

In [None]:
def plot_Hamiltonian(H, vmin=-1, vmax=1):
    """Show the real and imaginary parts of a matrix as side-by-side heat maps.

    Args:
        H: 2D array-like to visualize (e.g. a Hamiltonian or a unitary).
        vmin, vmax: Color limits shared by both panels. A symmetric range keeps
            zero at the center of the diverging "RdBu" colormap.

    Returns:
        (fig, axes): the matplotlib figure and its two axes, where axes[0]
        shows the real part and axes[1] the imaginary part.
    """
    fig, axes = plt.subplots(1, 2)
    kwargs = {"vmin": vmin, "vmax": vmax, "cmap": "RdBu"}
    for ax, part in zip(axes, (np.real(H), np.imag(H))):
        ax.imshow(part, **kwargs)
        # Matrix-element heat maps: axis ticks carry no useful information.
        ax.axis("off")
    return fig, axes

In [None]:
# params['IX'], params['ZX']

In [None]:
# Plot the first control operator of whichever model cell ran most recently
# (its `Hs_control` belongs to the last pair/qubit of that cell's loop).
plot_Hamiltonian(Hs_control[0], vmin=-0.05, vmax=0.05)

# ECR as qiskit circuit

In [None]:
# Reference ECR with qubit 0 as control and qubit 1 as target.
circ = qk.QuantumCircuit(2)
circ.ecr(0, 1)

# Multiply by sqrt(2) to strip the 1/sqrt(2) normalization of the ECR matrix.
print(qk.quantum_info.Operator(circ).data * np.sqrt(2))

circ.draw('mpl')

In [None]:
# Reference ECR with the roles swapped: qubit 1 as control, qubit 0 as target.
circ = qk.QuantumCircuit(2)
circ.ecr(1, 0)

# Multiply by sqrt(2) to strip the 1/sqrt(2) normalization of the ECR matrix.
print(qk.quantum_info.Operator(circ).data * np.sqrt(2))

circ.draw('mpl')

# Two qubit ECR gate (three pulses)

This section explores creating a two qubit ECR gate from components.

We need to define the Toy two-qubit Hamiltonian and make sure the drive and control channels are assigned correctly.

We also need to add wait gates (constant zero pulses) to preserve timing.

```Python
     ┌─────────┐            ┌────────────┐┌────────┐┌─────────────┐
q_0: ┤0        ├       q_0: ┤0           ├┤ RX(pi) ├┤0            ├
     │   ECR   │   =        │  RZX(pi/4) │└────────┘│  RZX(-pi/4) │
q_1: ┤1        ├       q_1: ┤1           ├──────────┤1            ├
     └─────────┘            └────────────┘          └─────────────┘
```

In [None]:
def gaussian(x, mu=0, sigma=1):
    """Unnormalized Gaussian bump exp(-(x - mu)^2 / (2 sigma^2)).

    The peak value is 1 at x = mu. Works on scalars and elementwise on
    NumPy arrays.
    """
    offset = x - mu
    exponent = -offset**2 / 2 / sigma**2
    return np.exp(exponent)

def lifted_gaussian(x, mu, sigma, x0=-1):
    """Gaussian shifted and rescaled so it is 0 at ``x0`` and 1 at ``x = mu``.

    Computes ``(g(x) - g(x0)) / (1 - g(x0))`` with ``g = gaussian(., mu, sigma)``.
    NOTE(review): the default x0 = -1 presumably mirrors the lifting used by
    Qiskit's Gaussian pulse (zero one sample before the start) -- confirm.
    """
    floor_value = gaussian(x0, mu=mu, sigma=sigma)
    return (gaussian(x, mu=mu, sigma=sigma) - floor_value) / (1 - floor_value)

def truncated_gaussian(x, mu=0, sigma=1):
    """Gaussian shifted down so its value at ``x = 0`` is zero: ``g(x) - g(0)``.

    Unlike ``lifted_gaussian`` the result is not rescaled, so the peak value
    is ``1 - g(0)``.
    """
    baseline = gaussian(0, mu=mu, sigma=sigma)
    return gaussian(x, mu=mu, sigma=sigma) - baseline

def gaussian_envelope(dt, duration, angle=np.pi):
    """ Define gaussian envelope function to accumulate the angle.

    The pulse amplitude is chosen so that the time integral of the pulse
    (area, in the time units of ``dt``) equals ``angle``. How that area maps
    to a rotation angle depends on the Hamiltonian the pulse drives.

    Args:
        dt: Sample time, in the same units as ``duration``.
        duration: Pulse length; rounded to the nearest whole number of samples.
        angle: Target pulse area.

    Returns:
        Qiskit pulse implementing angle (a ``qk.pulse.Gaussian`` with
        ``steps`` samples and ``sigma = steps / 4`` samples).

    Raises:
        ValueError: if ``duration`` is shorter than half a sample.
    """
    # Round rather than truncate: duration / dt can land just below an
    # integer (e.g. 219.99999999999997) in floating point, which int() would
    # floor, silently dropping a sample. Must match zero_envelope below.
    steps = int(round(duration / dt))
    if steps < 1:
        raise ValueError(f"duration={duration} is shorter than one sample (dt={dt}).")

    # Arbitrary shape
    # NOTE: Qiskit doesn't like pulse amplitudes > 1. Widen to avoid this.
    sigma = steps / 4
    mu = steps / 2

    # Normalize: numerically integrate the unit-height lifted Gaussian over
    # [0, steps] (measured in samples); multiplying by dt converts to time.
    x = np.linspace(0, steps, endpoint=True)
    area = np.trapz(lifted_gaussian(x, mu, sigma), x)

    # Adjust the amplitude to achieve the angle
    amplitude = angle / area / dt
    return qk.pulse.Gaussian(steps, amplitude, sigma)

def zero_envelope(dt, duration):
    """Build an all-zero Qiskit pulse used as a wait (idle) segment.

    Args:
        dt: Sample time, in the same units as ``duration``.
        duration: Wait length; rounded to the nearest whole number of samples,
            exactly as in ``gaussian_envelope``, so waits and pulses of the same
            ``duration`` have the same number of samples.

    Returns:
        ``qk.pulse.Constant`` of ``steps`` samples with amplitude 0.0.
    """
    # Same rounding as gaussian_envelope, so the channels stay time-aligned.
    steps = int(round(duration / dt))
    return qk.pulse.Constant(steps, 0.0)

In [None]:
# brief aside to check the order of operators
# default is standard order (no `reverse`): the label is read as written
ps.from_label("XI")

In [None]:
# notice that this is the qiskit order, which would reverse to IX
# (Qiskit labels put qubit 0 in the rightmost character, so "XI" here equals
# the standard-order operator "IX")
ps.from_label("XI", reverse=True)

In [None]:
# Pulse library keyed "<gate>_<color>".
# Pulse areas are half the nominal rotation angle (area pi/2 for X = RX(pi),
# pi/4 for SX, pi/8 for RZX(pi/4)). That is what is needed if each drive term
# enters as a(t) * P so that U = exp(-i * area * P).
# NOTE(review): confirm that convention against the ps models.
# The "_red" and "_blue" variants are currently identical. The "*_wait"
# entries are zero-amplitude pulses of the same length, used as idle time.
gate_lookup = {
    "sx_red": gaussian_envelope(dt, duration, angle=np.pi/4),
    "sx_blue": gaussian_envelope(dt, duration, angle=np.pi/4),
    "x_red": gaussian_envelope(dt, duration, angle=np.pi/2),
    "x_blue": gaussian_envelope(dt, duration, angle=np.pi/2),
    "zx+_red": gaussian_envelope(dt, duration, angle=np.pi/8),
    "zx-_red": gaussian_envelope(dt, duration, angle=-np.pi/8),
    "zx+_blue": gaussian_envelope(dt, duration, angle=np.pi/8),
    "zx-_blue": gaussian_envelope(dt, duration, angle=-np.pi/8),
    "x_wait": zero_envelope(dt, duration),
    "zx_wait": zero_envelope(dt, duration),
    "sx_wait": zero_envelope(dt, duration),
}

# Re-set here (same value as earlier in the notebook).
registers = [0, 1, 2, 3]

# Define gates
# NOTE: The order of the indices is control, target
gates = {(0, 1): 'ecr_red', (2, 3): 'ecr_red'}

# Design pulse schedule of a single ECR gate
# =====
# Each ECR is three back-to-back segments of length `duration`:
#   control channel: RZX(+pi/4) | wait        | RZX(-pi/4)
#   drive channel:   wait       | X (= RX(pi)) | wait
# Instructions on the same channel play one after another, so the waits keep
# the two channels in step.
with qk.pulse.build(name="Current moment") as pulse_moment:
    for (i_c, i_t), gate in gates.items():
        print(i_c, i_t)

        # gate_type (e.g. "ecr") is parsed but unused; gate_color picks the
        # pulse variant from gate_lookup.
        gate_type, gate_color = gate.split("_")
        
        # Drive using two kinds of channels
        drive_channel = ps.get_drive_channel(i_c, backend)
        control_channel = ps.get_control_channel(i_c, i_t, backend)
        print(drive_channel, control_channel)

        # Control channel is for R_ZX
        qk.pulse.play(gate_lookup[f"zx+_{gate_color}"], control_channel)
        qk.pulse.play(gate_lookup["x_wait"], control_channel)
        qk.pulse.play(gate_lookup[f"zx-_{gate_color}"], control_channel)

        # Drive channel is for R_X (on the control qubit i_c)
        qk.pulse.play(gate_lookup["zx_wait"], drive_channel)
        qk.pulse.play(gate_lookup[f"x_{gate_color}"], drive_channel)
        qk.pulse.play(gate_lookup["zx_wait"], drive_channel)

In [None]:
# Visualize the pulse schedule built above, channel by channel.
pulse_moment.draw()

In [None]:
# Create a system model
# =====
# Partially compile to get this circuit's gates
cr_model = functools.partial(
    ps.cross_resonance_model,
    registers=registers,
    backend=backend,
    variables=config_vars, 
    model_name="Toy"
)

# Control model
# Combine the cross-resonance models of every gate in `gates` into one system.
# NOTE(review): H_drift is accumulated but never passed to the solver
# (static_hamiltonian=None below); `params` and `label` are unused.
# `Hs_channels += Hjs_channel` assumes cr_model returns a list of channel
# names (a bare string would be split into characters) -- confirm.
H_drift = 0.
Hs_control = []
Hs_channels = []
for (control, target), label in gates.items():
    Hj_drift, Hjs_control, Hjs_channel, params = cr_model((control, target), return_params=True)
    H_drift += Hj_drift
    Hs_control += Hjs_control
    Hs_channels += Hjs_channel
    
# Construct the solver
# =====
"""
Simulating the effective model, therefore the drift
is ZZ crosstalk, and there is no rotating frame.

Use the crosstalk computed perviously.
"""
# All carrier frequencies are 0 and there is no rotating frame, so the
# schedule's envelopes are applied directly as the channel signals.
solver = qk_d.Solver(
    static_hamiltonian=None, #H_xtalk,
    hamiltonian_operators=Hs_control,
    static_dissipators=None,
    rotating_frame=None,
    rwa_cutoff_freq=None,
    hamiltonian_channels=Hs_channels,
    channel_carrier_freqs={ch: 0. for ch in Hs_channels},
    dt=dt
)

In [None]:
# Start all qubits in their ground state.
y0 = ps.qiskit_ground_state(len(registers))

# Identity matrix: initial condition for propagating the full unitary.
id_label = ''.join(['I'] * len(registers))
U0 = qk.quantum_info.Operator.from_label(id_label)

# Simulation time NOTE: longer than a single pulse
# Three segments of `duration` (ZX+, X, ZX-), matching the schedule above.
# NOTE(review): a later cell reassigns `duration = 10`; re-running this cell
# after that one changes the simulated time.
moment_duration = duration * 3

# Unitary sim.
# Dense matrices, fixed-step matrix-exponential solver (steps <= dt);
# only the initial and final times are recorded.
solver.model.evaluation_mode = 'dense'
sol = solver.solve(
    t_span=[0.0, moment_duration],
    y0=U0,
    signals=pulse_moment,
    max_dt=dt,
    t_eval=[0, moment_duration],
    method="jax_expm",
    magnus_order=1,
)

# Sparse state vector sim
# Independent cross-check with an adaptive ODE solver and tight tolerances.
solver.model.evaluation_mode = 'sparse'
sol1 = solver.solve(
    t_span=[0., moment_duration],
    y0=y0,
    signals=pulse_moment,
    atol=1e-8,
    rtol=1e-8,
    method='jax_odeint'
)

In [None]:
basis = ps.hilbert_space_basis([2] * len(registers))

# Check final states
yf1 = sol1.y[-1]

Uf = sol.y[-1]
yf = y0.evolve(Uf)
# Compare
print(f"Are close? ||y1 - y2|| = {np.linalg.norm(yf1 - yf)}\n")

# States
ps.print_wavefunction(yf, basis, tol=1e-3)
print()

### Check the answer in a handful of ways

In [None]:
# 1. Solution of moment
# Each ECR carries a 1/sqrt(2) factor. With len(registers)/2 ECRs in the
# moment (assumes every register is in exactly one pair, as in `gates`), the
# product carries 2**(-len(registers)/4), which this prefactor multiplies back
# out. The 1j is a phase convention; the analytic comparison below applies
# the same prefactor.
Uf_prefactor = 1j * np.power(2, len(registers) / 4)
Uf_corrected = Uf_prefactor * Uf.data

In [None]:
# Real/imaginary parts of the rescaled simulated unitary.
plot_Hamiltonian(Uf_corrected, vmin=-2, vmax=2)

Notice that the qubit order in the Pauli labels is switched between the desired ECR and the implemented ECR.

In [None]:
# 2. Analytic solution
# The commented-out lines are the desired ECR labels; the active lines use the
# reversed label order that the simulation actually implements.
if len(registers) == 2:
    # ECR01 = 1 / np.sqrt(2) * (ps.from_label("IX") - ps.from_label("XY"))
    ECR01 = 1 / np.sqrt(2) * (ps.from_label("XI") - ps.from_label("YX"))
    Uf_expected = ECR01.data
elif len(registers) == 4:
    # ECR01 = 1 / np.sqrt(2) * (ps.from_label("IXII") - ps.from_label("XYII"))
    # ECR23 = 1 / np.sqrt(2) * (ps.from_label("IIIX") - ps.from_label("IIXY"))
    ECR01 = 1 / np.sqrt(2) * (ps.from_label("XIII") - ps.from_label("YXII"))
    ECR23 = 1 / np.sqrt(2) * (ps.from_label("IIXI") - ps.from_label("IIYX"))
    # The two ECRs act on disjoint qubits, so their order in the product is irrelevant.
    Uf_expected = ECR01.data @ ECR23.data
else:
    raise ValueError("Only 2 or 4 qubits supported.")

# Expected result with the same prefactor (and an overall sign) as Uf_corrected.
plot_Hamiltonian(-Uf_prefactor * Uf_expected, vmin=-2, vmax=2)

In [None]:
# True if every entry of the simulated unitary matches the analytic one to atol=1e-3.
np.isclose(Uf_corrected, -Uf_prefactor * Uf_expected, atol=1e-3).all()

In [None]:
# 3. Compare with the expected unitary: Uf_expected^dagger @ Uf should be
#    proportional to the identity. Uf_expected is Hermitian (each ECR factor
#    is a real combination of Pauli strings), so it serves as its own adjoint
#    here. If the previous check passed (Uf ~ -Uf_expected) and
#    ECR @ ECR = I, the plot shows ~ -I.
plot_Hamiltonian(Uf_expected @ Uf.data)

In [None]:
# 4. Single ECR from components
# Circuit order RZX(pi/4) -> RX(pi) on the first qubit -> RZX(-pi/4), with
# RZX(theta) = exp(-i theta/2 ZX) and RX(pi) = exp(-i pi/2 XI). As matrices
# the rightmost factor acts first. Scaled by 1j*sqrt(2) for comparison with
# the next cell.
ZX = ps.from_label("ZX")
XI = ps.from_label("XI")
1j * np.sqrt(2) * expm(1j * np.pi / 8 * ZX) @ expm(-1j * np.pi / 2 * XI) @ expm(-1j * np.pi / 8 * ZX)

In [None]:
# sqrt(2) * ECR in this label order; should match the previous cell's output.
(ps.from_label("XI") - ps.from_label("YX"))

# ECR control and control-spectator crosstalk study

Notice that if you run two X gates next to each other, the crosstalk condition vanishes.

The key assumptions here are that:
1. The amplitude of the gate should be chosen so that the accumulated angle $\int_0^T a(t)\,dt$ equals $\frac{\pi}{2}$. The crosstalk condition below is derived for a Hamiltonian of the form $H(t) = a(t)X$.
2. The crosstalk condition for a single gate is
\begin{equation}
    \left(\int_0^T \cos\left(2\int_0^t a(s)\,ds\right)dt \right)^2 + \left(\int_0^T \sin\left(2\int_0^t a(s)\,ds \right)dt\right)^2 = 0
\end{equation}

## Spectators and the gate

\begin{equation}
CS \leftrightarrow C \leftrightarrow T \leftrightarrow TS
\end{equation}

There are two parts to the crosstalk robustness. The first part is due to the X gate. The second part is due to the ZX rotation. They happen sequentially.

First, ignore the ZX rotation. This means studying the impact of the X gate and a control spectator. In this case, the ZZ crosstalk vanishes.

Now, we need to consider how to add the ZX drive back into the calculation. We want to see two things. First, the ZX drive should not affect the ability of the XI drive to cancel crosstalk when both are present. (This is the current behavior of ECR.) Second, the ZX drive should be able to cancel crosstalk on the target spectator. (This would be novel.)

Question:
Why $2at$ vs. $\int_0^t a(s)\,ds$?

In [None]:
def angle(t, amp=np.pi/2, duration=1, pulse_duration=1/4):
    """Piecewise-constant drive amplitude a(t) for a two-pulse sequence.

    The window is laid out as  pulse | wait | pulse | wait,  where each wait
    lasts (duration - 2 * pulse_duration) / 2. Only the second pulse is driven,
    at the constant rate amp / pulse_duration, so its area is ``amp``.
    The first pulse slot returns 0 because its drive is currently disabled.

    Despite the name, the return value is the rate a(t); the accumulated angle
    is its time integral.
    """
    gap = (duration - pulse_duration * 2) / 2
    if t < pulse_duration:
        # First pulse slot (drive disabled; formerly amp / pulse_duration).
        return 0
    second_start = pulse_duration + gap
    second_stop = 2 * pulse_duration + gap
    if second_start <= t < second_stop:
        return amp / pulse_duration
    return 0.

In [None]:
# Toy timing for the crosstalk study (arbitrary time units).
# NOTE(review): this overwrites the global `duration` (the ECR segment length
# in ns); re-running the earlier simulation cells afterwards will use 10.
duration = 10
# Length of each pulse slot, passed to angle() as pulse_duration.
pulse = 1.
ts = np.linspace(0, duration, 1000, endpoint=True)
fig, ax = plt.subplots()
ax.plot(ts, [angle(t, duration=duration, pulse_duration=pulse) for t in ts])
ax.set_title("Control vs. time")

In [None]:
# Total area of a(t) with amp=pi. Only the second pulse slot is driven (area
# amp), so this should be close to pi, up to discretization error.
np.trapz([angle(t, amp=np.pi, duration=duration, pulse_duration=pulse) for t in ts], ts)

In the next cell, plot the net crosstalk contributions of sine and cosine. Notice that we get back to zero for both sine and cosine.

In [None]:
from scipy.integrate import cumulative_trapezoid

# Drive amplitude a(t) on the time grid.
amps = np.array([angle(t, duration=duration, pulse_duration=pulse) for t in ts])

# Accumulated angle phi(t) = \int_0^t a(s) ds at every grid point.
# Element k integrates over ts[0]..ts[k] inclusive, with phi[0] = 0.
# (Integrating each prefix ts[:m] instead lags by one sample, never includes
# the final interval, and costs O(N^2).)
int_angle = cumulative_trapezoid(amps, ts, initial=0)

fig, axes = plt.subplots(2, 1, sharex=True)
ax = axes[0]
ax.plot(ts, int_angle, label=r"$\int_0^t a(s) ds$")
ax.legend()
ax = axes[1]
# Running crosstalk integrals \int_0^t cos(2 phi) dt' and \int_0^t sin(2 phi) dt'.
cos_vals = cumulative_trapezoid(np.cos(2 * int_angle), ts, initial=0)
sin_vals = cumulative_trapezoid(np.sin(2 * int_angle), ts, initial=0)
ax.plot(ts, cos_vals, label="cos")
ax.plot(ts, sin_vals, label="sin")
# Reference lines: y = pulse, and the cos integral at the first grid time
# after the first pulse slot ends.
ax.axhline(pulse, color="black", linestyle="--", alpha=.5, lw=1)
ax.axhline(cos_vals[np.argwhere(ts > pulse)[0][0]], color="black", linestyle="--", alpha=.5, lw=1)
ax.set_xlabel("t")
ax.legend()

In [None]:
# Squared cos and sin crosstalk integrals over the whole window; the single-gate
# crosstalk condition is met when their sum is ~0.
(np.trapz([np.cos(2 * i) for i in int_angle], ts)**2, 
np.trapz([np.sin(2 * i) for i in int_angle], ts)**2)
