## Imports

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jaxquantum.codes as jqtb
import jaxquantum as jqt
import jaxquantum.circuits as jqtc
import matplotlib.pyplot as plt
from tqdm import tqdm
import jax
import numpy as np

# Use matplotlib's "ggplot" style sheet for every figure in this notebook.
plt.style.use('ggplot')

## Check CUDA devices

In [None]:
# List the devices JAX can see (e.g. to confirm a CUDA GPU is visible).
jax.devices()

In [None]:
# Memory statistics reported by the first JAX device.
jax.devices()[0].memory_stats()

## Compute overlap of common states

In [None]:
delta = 0.3
gkp_qubit = jqtb.GKPQubit({"delta": delta, "N": 500})
gkp = gkp_qubit.basis["+z"]

# Logical basis states plus two reference states: the logical fully mixed state
# and the oscillator vacuum (same 500-level truncation as the GKP qubit above).
plus_z = gkp_qubit.basis["+z"]
minus_z = gkp_qubit.basis["-z"]
fully_mixed = (plus_z.to_dm() + minus_z.to_dm())/2
vacuum = jqt.basis(500, 0)

print(jqt.overlap(minus_z, plus_z))
print(jqt.overlap(fully_mixed, plus_z))
# For each test state, print its overlap with the "Z_0" and then the "Z" logical gate.
for test_state in (plus_z, minus_z, fully_mixed, vacuum):
    for gate_name in ("Z_0", "Z"):
        print(jqt.overlap(gkp_qubit.common_gates[gate_name], test_state))

## Declare circuit parameters

In [None]:
N = 125  # oscillator truncation used for the GKP states / circuits below
T = 100  # number of sBs cycles simulated (each cycle = two sBs half-rounds)
kappa_0 = 1/30  # oscillator damping rate [1/us] (plot titles quote kappa = 1/(30 us))
T_0 = 2276/1000  # duration of one sBs half-round [us]
err_prob = 1-jnp.exp(-kappa_0*T_0)  # amplitude-damping probability applied per half-round
deltas = jnp.linspace(0.25, 0.6, 40)  # swept GKP envelope parameter Delta
sd_ratios = jnp.linspace(0.5, 5, 50)  # swept ratio alpha_2 / alpha_1
times = jnp.linspace(0, T*2*T_0, T+1)  # one sample per cycle of length 2*T_0, including t=0 [us]

## Sweep circuit parameters

In [None]:
@jax.jit
def sBs_round(initial_state, alphas, phis, thetas, err_prob):
    """Simulate one sBs half-round on a (qubit, oscillator) state.

    Gate sequence (qubit = subsystem 0, oscillator = subsystem 1):
    Ry(pi/2); twice [CD(alphas[k]), Ry(phis[k]), Rx(thetas[k])] for k = 0, 1;
    CD(alphas[2]); qubit Reset; then Amp_Damp(err_prob) on the oscillator.
    The circuit is simulated with ``mode='kraus'``.

    Args:
        initial_state: joint qubit/oscillator state; ``space_dims[1]`` gives the
            oscillator dimension.
        alphas: three conditional-displacement amplitudes.
        phis: two qubit Ry angles.
        thetas: two qubit Rx angles.
        err_prob: amplitude-damping probability applied to the oscillator.

    Returns:
        The final state, i.e. ``res[-1][-1]`` of the simulation output.
    """
    osc_dim = initial_state.space_dims[1]
    circuit = jqtc.Circuit.create(jqtc.Register([2, osc_dim]), layers=[])

    circuit.append(jqtc.Ry(jnp.pi / 2), 0)
    for k in range(2):
        circuit.append(jqtc.CD(osc_dim, alphas[k]), [0, 1])
        circuit.append(jqtc.Ry(phis[k]), 0)
        circuit.append(jqtc.Rx(thetas[k]), 0)
    circuit.append(jqtc.CD(osc_dim, alphas[2]), [0, 1])
    circuit.append(jqtc.Reset(), 0)

    # NOTE(review): the trailing 20 is passed straight to Amp_Damp; presumably a
    # Kraus-operator truncation order — confirm against jaxquantum.circuits.
    circuit.append(jqtc.Amp_Damp(osc_dim, err_prob, 20), 1)

    history = jqtc.simulate(circuit, initial_state, mode='kraus')
    return history[-1][-1]

In [None]:
def sbs_batch(delta, sd_ratio, err_prob):
    """Average logical Z/X signal of a GKP qubit over ``T`` sBs cycles.

    Two runs are made, each with the qubit in ``basis(2, 0)``:
    the GKP ``+z`` state tracked with the ``Z_0`` gate, and the GKP ``+x`` state
    tracked with the ``X_0`` gate. Each cycle applies the half-round with
    ``alphas[0:3]`` and then the one with ``alphas[3:6]``. After every cycle (and
    once before the first) the real part of ``jqt.overlap(gate, oscillator state)``
    is recorded.

    Uses the notebook globals ``N`` (oscillator truncation) and ``T`` (cycles).

    Returns:
        Array of length ``T + 1`` holding ``Z_trace / 2 + X_trace / 2``.
    """
    l = jnp.sqrt(2*jnp.pi)
    epsilon = jnp.sinh(delta*delta)*l

    alphas_real = jnp.array([epsilon/2, 0., sd_ratio*epsilon/2, 0., l, 0.])
    alphas_imag = jnp.array([0., -l, 0., epsilon/2, 0., sd_ratio*epsilon/2])
    alphas = alphas_real + alphas_imag * 1.j
    phis = jnp.array([0., 0., 0., 0.])
    thetas = jnp.array([jnp.pi/2, -jnp.pi/2, jnp.pi/2, -jnp.pi/2])

    # The GKP code depends only on (delta, N), so build it once rather than once
    # per tracked axis (under vmap this halves the traced construction work).
    gkp_qubit = jqtb.GKPQubit({"delta": delta, "N": N})

    # axis -> (initial GKP basis state, tracked logical gate)
    tracked = {"Z": ("+z", "Z_0"), "X": ("+x", "X_0")}
    traces = {}
    for axis, (basis_key, gate_key) in tracked.items():
        logical_op = gkp_qubit.common_gates[gate_key]
        state = jqt.basis(2, 0) ^ gkp_qubit.basis[basis_key]

        trace = [jnp.real(jqt.overlap(logical_op, state.ptrace(1)))]
        for _ in range(T):
            state = sBs_round(state, alphas[0:3], phis[0:2], thetas[0:2], err_prob)
            state = sBs_round(state, alphas[3:6], phis[2:4], thetas[2:4], err_prob)
            trace.append(jnp.real(jqt.overlap(logical_op, state.ptrace(1))))
        traces[axis] = trace

    return jnp.array([z / 2 + x / 2 for z, x in zip(traces["Z"], traces["X"])])

In [None]:
def fit_t1(times, amps):
    """Return the slope of a degree-1 least-squares fit of ``log(amps)`` vs ``times``.

    For a decay ``amps ~ exp(-t / T1)`` the slope equals ``-1 / T1``; callers turn
    it into a lifetime with ``-1 / fit_t1(...)``. Non-positive amplitudes make the
    logarithm non-finite, which yields an invalid fit.
    """
    log_amps = jnp.log(jnp.array(amps))
    coeffs = jnp.polyfit(times, log_amps, deg=1)  # highest power first: [slope, intercept]
    return coeffs[0]

In [None]:
# Two nested vmaps cover both axes of the 2-D (delta, sd_ratio) mesh; err_prob
# (and, for the fit, the shared time axis) is broadcast with in_axes=None.
_sbs_row = jax.vmap(sbs_batch, in_axes=(0, 0, None), out_axes=0)
sbs = jax.vmap(_sbs_row, in_axes=(0, 0, None), out_axes=0)

_fit_t1_row = jax.vmap(fit_t1, in_axes=(None, 0), out_axes=0)
fit_t1_vmap = jax.vmap(_fit_t1_row, in_axes=(None, 0), out_axes=0)

In [None]:
# Cartesian ('xy') indexing: rows follow sd_ratios, columns follow deltas.
deltas_mg, sd_ratios_mg = jnp.meshgrid(deltas, sd_ratios, indexing="xy")

In [None]:
# Shape (len(sd_ratios), len(deltas), T + 1). Magnitudes are taken before fit_t1 takes the log.
results = jnp.abs(sbs(deltas_mg, sd_ratios_mg, err_prob))

In [None]:
slopes = fit_t1_vmap(times, results)  # fitted log-decay rate per grid point [1/us]
t1s = -1 / slopes

# TODO: add masking of invalid fits (e.g. non-negative or non-finite slopes).

In [None]:
fig, ax = plt.subplots(1, figsize=(10, 7))
handle = ax.pcolormesh(deltas_mg, sd_ratios_mg, t1s, vmin=0, vmax=100, shading='nearest')
# Raw strings: "\mu", "\cdot", "\kappa" are LaTeX, not Python escape sequences.
fig.colorbar(handle, ax=ax, label=r"Lifetime $[\mu s]$")
ax.set_xlabel(r"$\Delta$")
ax.set_ylabel(r"$\alpha_2 / \alpha_1$")
# The cycle count and kappa come from the sweep parameters instead of being hard-coded.
ax.set_title(rf"{T} cycles; $T_{{sBs}}=2\cdot {T_0}\mu s$; $\kappa=1/({1/kappa_0:.0f}\mu s)$")

## Zoom in

In [None]:
# Zoom window (inclusive target bounds) on the delta and sd_ratio axes.
delta_min, delta_max = 0.37, 0.48
sd_ratio_min, sd_ratio_max = 1.2, 2.2

In [None]:
def _nearest_index(grid, value):
    """Index of the grid point closest to ``value``."""
    return jnp.abs(grid - value).argmin()


delta_min_idx = _nearest_index(deltas, delta_min)
delta_max_idx = _nearest_index(deltas, delta_max)
sd_ratio_min_idx = _nearest_index(sd_ratios, sd_ratio_min)
sd_ratio_max_idx = _nearest_index(sd_ratios, sd_ratio_max)

In [None]:
# Meshgrid rows follow sd_ratios and columns follow deltas ('xy' indexing).
# The "+ 1" keeps the grid point nearest to each requested maximum inside the
# window; a plain [min:max] slice would silently drop it.
sd_ratio_window = slice(sd_ratio_min_idx, sd_ratio_max_idx + 1)
delta_window = slice(delta_min_idx, delta_max_idx + 1)
cut_deltas_mg = deltas_mg[sd_ratio_window, delta_window]
cut_sd_ratios_mg = sd_ratios_mg[sd_ratio_window, delta_window]
cut_t1s = t1s[sd_ratio_window, delta_window]

In [None]:
# Invalid fits can produce NaN/inf lifetimes, and argmax would return such an
# entry. Rank only finite lifetimes; non-finite ones are treated as -inf.
ranked_t1s = jnp.where(jnp.isfinite(cut_t1s), cut_t1s, -jnp.inf)
max_t1_idx = jnp.unravel_index(jnp.argmax(ranked_t1s), cut_t1s.shape)
max_t1 = cut_t1s[max_t1_idx]
max_t1_delta = cut_deltas_mg[max_t1_idx]
max_t1_sd_ratio = cut_sd_ratios_mg[max_t1_idx]

In [None]:
fig, ax = plt.subplots(1, figsize=(10, 7))
handle = ax.pcolormesh(cut_deltas_mg, cut_sd_ratios_mg, cut_t1s, shading='nearest')
fig.colorbar(handle, ax=ax, label=r"Lifetime $[\mu s]$")
# Mark the best (delta, sd_ratio) found inside the zoom window.
ax.scatter(max_t1_delta, max_t1_sd_ratio, color='red', label=f"Lifetime={max_t1:.1f}$\\mu s$")
ax.set_xlabel(r"$\Delta$")
ax.set_ylabel(r"$\alpha_2 / \alpha_1$")
# The sweep ran T cycles of two T_0 half-rounds; take the numbers from the parameters.
ax.set_title(rf"Lifetime after {T} cycles; $T_{{sBs}}=2\cdot {T_0}\mu s$; $\kappa=1/({1/kappa_0:.0f}\mu s)$")
ax.legend()
print(f"max_T1 = {max_t1:.1f}")
print(f"delta = {max_t1_delta:.3f}")
print(f"sd_ratio = {max_t1_sd_ratio:.3f}")

## Export best sBs sequence

In [None]:
# Adopt the best zoom-window point as the working (delta, sd_ratio).
delta, sd_ratio = max_t1_delta, max_t1_sd_ratio

In [None]:
# Rebuild the sBs parameters for the selected (delta, sd_ratio) in the export format.
l = jnp.sqrt(2*jnp.pi)  # sqrt(2*pi); presumably the GKP lattice spacing
epsilon = jnp.sinh(delta*delta)*l  # epsilon = sinh(delta^2) * l

# Six displacements: entries 0-2 are exported as the Z half-round, 3-5 as the X half-round.
alphas_real = jnp.array([epsilon/2, 0., sd_ratio*epsilon/2, 0., l, 0.])
alphas_imag = jnp.array([0., -l, 0., epsilon/2, 0., sd_ratio*epsilon/2])
# Rotation angles are stored in units of 2*pi (run_circuit multiplies them by 2*pi).
phis = jnp.array([jnp.pi/2, 0., 0., jnp.pi/2, 0., 0.]) / 2 / jnp.pi
thetas = (jnp.array([jnp.pi, jnp.pi/2, -jnp.pi/2, jnp.pi, jnp.pi/2, -jnp.pi/2])-jnp.pi) / 2 / jnp.pi  # shifted by -pi first

In [None]:
# Export layout: rows = (Rx angles, Ry angles, Re beta, Im beta), one column per layer.
half_z, half_x = slice(0, 3), slice(3, 6)
sbs_max_LL_Z = jnp.stack([thetas[half_z], phis[half_z], alphas_real[half_z], alphas_imag[half_z]])
sbs_max_LL_X = jnp.stack([thetas[half_x], phis[half_x], alphas_real[half_x], alphas_imag[half_x]])

In [None]:
# File names encode delta and sd_ratio in thousandths; the X stabilizer is written first.
for stab, best_params in (("X", sbs_max_LL_X), ("Z", sbs_max_LL_Z)):
    jnp.savez(f"./sbs_delta{delta*1e3:.0f}_sdratio{sd_ratio*1e3:.0f}_stab{stab}", best_params=best_params)

## Compute X and Z lifetimes

In [None]:
def sbs_batch(delta, sd_ratio, err_prob):
    """Per-axis logical Z and X traces of a GKP qubit over ``T`` sBs cycles.

    Same simulation as the sweep version: ``+z`` tracked with ``Z_0`` and ``+x``
    tracked with ``X_0``, each starting with the qubit in ``basis(2, 0)`` and
    running ``T`` cycles of [half-round alphas[0:3], half-round alphas[3:6]].
    Unlike the sweep version, the two traces are returned separately.

    Uses the notebook globals ``N`` and ``T``.

    Returns:
        ``(exp_Z, exp_X)``: two arrays of length ``T + 1``.
    """
    l = jnp.sqrt(2*jnp.pi)
    epsilon = jnp.sinh(delta*delta)*l

    alphas_real = jnp.array([epsilon/2, 0., sd_ratio*epsilon/2, 0., l, 0.])
    alphas_imag = jnp.array([0., -l, 0., epsilon/2, 0., sd_ratio*epsilon/2])
    alphas = alphas_real + alphas_imag * 1.j
    phis = jnp.array([0., 0., 0., 0.])
    thetas = jnp.array([jnp.pi/2, -jnp.pi/2, jnp.pi/2, -jnp.pi/2])

    # The GKP code depends only on (delta, N); build it once for both axes.
    gkp_qubit = jqtb.GKPQubit({"delta": delta, "N": N})

    # axis -> (initial GKP basis state, tracked logical gate)
    tracked = {"Z": ("+z", "Z_0"), "X": ("+x", "X_0")}
    traces = {}
    for axis, (basis_key, gate_key) in tracked.items():
        logical_op = gkp_qubit.common_gates[gate_key]
        state = jqt.basis(2, 0) ^ gkp_qubit.basis[basis_key]

        trace = [jnp.real(jqt.overlap(logical_op, state.ptrace(1)))]
        for _ in range(T):
            state = sBs_round(state, alphas[0:3], phis[0:2], thetas[0:2], err_prob)
            state = sBs_round(state, alphas[3:6], phis[2:4], thetas[2:4], err_prob)
            trace.append(jnp.real(jqt.overlap(logical_op, state.ptrace(1))))
        traces[axis] = trace

    return jnp.array(traces["Z"]), jnp.array(traces["X"])

In [None]:
# Separate logical traces for the selected parameters: results = (exp_Z, exp_X).
results = sbs_batch(delta, sd_ratio, err_prob)

In [None]:
# Lifetime = -1 / fitted log-decay rate, for the Z trace then the X trace.
t_z, t_x = (-1 / fit_t1(times, jnp.abs(trace)) for trace in results)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
# One panel per logical operator: (axes, trace, fitted lifetime, operator label).
panels = ((axs[0], results[0], t_z, "Z"), (axs[1], results[1], t_x, "X"))
for ax, trace, lifetime, op in panels:
    # Raw strings keep "\mu", "\langle", "\hat" as LaTeX instead of invalid escapes.
    ax.scatter(times, trace, label=rf"$T_{op}={lifetime:.3f} \mu s$")
    ax.set_xlabel(r"t [$\mu s$]")
    ax.set_ylabel(rf"$\langle \hat{{{op}}} \rangle$")
    ax.legend()

# Format delta and take the cycle time from T_0 instead of hard-coding it.
fig.suptitle(rf"Logical operator decay during sBs, $\Delta={delta:.3f}$, $r_{{SB}}={sd_ratio:.2f}$, $T_{{cyc}}=2\cdot{T_0*1000:.0f}ns$")

## Plot the characteristic function (CF) after a single sBs half-round

In [None]:
def calculate_cf(osc_state, betas_re=None, betas_im=None):
    """Evaluate the characteristic function of ``osc_state`` on a displacement grid.

    Each value is ``jqt.overlap(jqt.displace(N, beta), osc_state)`` with
    ``beta = betas_re[i] + 1j * betas_im[j]``. ``N`` is read from ``osc_state.dims[0][0]``.

    Args:
        osc_state: oscillator state.
        betas_re: real parts of the grid (default ``linspace(-4, 4, 41)``).
        betas_im: imaginary parts of the grid (default ``linspace(-4, 4, 41)``).

    Returns:
        ``(cf_vals, betas_re, betas_im)``. ``cf_vals`` is a complex64 NumPy array
        indexed ``[i_re, i_im]``.
    """
    dim = osc_state.dims[0][0]

    if betas_re is None:
        betas_re = jnp.linspace(-4,4, 41)
    if betas_im is None:
        betas_im = jnp.linspace(-4,4, 41)
    grid = betas_re.reshape(-1,1) + 1j*betas_im.reshape(1,-1)

    n_re, n_im = len(betas_re), len(betas_im)
    cf_vals = np.zeros((n_re, n_im), dtype=jnp.complex64)
    for row in tqdm(range(n_re)):
        for col in range(n_im):
            cf_vals[row, col] = jqt.overlap(jqt.displace(dim, grid[row, col]), osc_state)
    return cf_vals, betas_re, betas_im

In [None]:
def qubit_phase(beta_squared, η0, η2, η4, η6, ξ2, ξ4, offset):
    """Evaluate the fitted (sx, sy) model at ``|beta|^2``.

    With ``b = |beta_squared|`` the phase is ``theta = ξ2*b + ξ4*b**2``. The pair
    ``(cos 2θ, sin 2θ)`` is scaled by ``1 - (η0 + η2*b + η4*b**2 + η6*b**3)`` and
    both components are shifted by ``offset``.
    """
    beta_squared = np.abs(beta_squared)
    theta = ξ2 * beta_squared + ξ4 * beta_squared**2
    sx = np.cos(2*theta)
    sy = np.sin(2*theta)
    purity_scale_factor = (1 - (η0 + η2 *beta_squared + η4 * beta_squared**2 + η6 * beta_squared**3))
    return sx * purity_scale_factor + offset, sy * purity_scale_factor + offset


def purity_func(beta_squared, η0, η2, η4, η6, ξ2, ξ4, offset):
    """Return ``0.5 * (1 + sx**2 + sy**2)`` for the ``qubit_phase`` model."""
    sx, sy = qubit_phase(beta_squared, η0, η2, η4, η6, ξ2, ξ4, offset)
    return 0.5 * (1 + sx**2 + sy**2)


def calculate_envelope(betas_re, betas_im, fit_vals):
    """Return the impurity envelope ``2 * (purity - 0.5)`` on the beta grid.

    The result is indexed ``[i_re, i_im]`` (matrix 'ij' indexing). That layout
    matches the ``cf_vals`` array from ``calculate_cf``, so the two can be
    multiplied element-wise. The default 'xy' meshgrid would give
    ``[i_im, i_re]``. That transposes the envelope against the CF, and fails
    outright when the two grids have different lengths.

    Args:
        betas_re: real parts of the grid.
        betas_im: imaginary parts of the grid.
        fit_vals: ``(η0, η2, η4, η6, ξ2, ξ4, offset)`` passed to ``purity_func``.
    """
    betas_re_grid, betas_im_grid = jnp.meshgrid(betas_re, betas_im, indexing="ij")
    betas_squared_grid = jnp.abs(betas_re_grid)**2 + jnp.abs(betas_im_grid)**2
    purity_vals = purity_func(betas_squared_grid, *fit_vals)
    return 2*(purity_vals-0.5)


# Model parameters (η0, η2, η4, η6, ξ2, ξ4, offset) for purity_func / calculate_envelope.
# NOTE(review): the fit that produced these values is not shown in this notebook.
fit_vals = (0.39097432835648155,
 0.034291746251959555,
 -0.0010010145791953715,
 1.0867037992914275e-05,
 -0.04611114282106909,
 9.382815933549335e-06,
 0.07836293463518865)

### sBs Z half-round

#### Full CF

In [None]:
# One sBs Z half-round applied to the GKP basis["-x"] state; plot the oscillator CF.
l = jnp.sqrt(2*jnp.pi)
epsilon = jnp.sinh(delta*delta)*l

alphas_real = jnp.array([epsilon/2, 0., sd_ratio*epsilon/2, 0., l, 0.])
alphas_imag = jnp.array([0., -l, 0., epsilon/2, 0., sd_ratio*epsilon/2])
alphas = alphas_real + alphas_imag * 1.j
phis = jnp.array([0., 0., 0., 0.])
thetas = jnp.array([jnp.pi/2, -jnp.pi/2, jnp.pi/2, -jnp.pi/2])

gkp_qubit = jqtb.GKPQubit({"delta": delta, "N": N})
gkp = gkp_qubit.basis["-x"]

initial_state = jqt.basis(2, 0) ^ gkp
# alphas[0:3], phis[0:2], thetas[0:2] are the Z half-round parameters.
current_state = sBs_round(initial_state, alphas[0:3], phis[0:2],
                          thetas[0:2], err_prob)

betas_re = jnp.linspace(-3, 3, 41)
betas_im = jnp.linspace(-3, 3, 41)

cf_vals, _, _ = calculate_cf(current_state.ptrace(1), betas_re, betas_im)

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/jnp.sqrt(jnp.pi*2)  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # cf_vals is indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

# The prepared state is basis["-x"], so label it as such.
fig.suptitle(r"$\widehat{sBs_Z} \cdot \vert -x \rangle$", y=1.05)

    

#### Enveloped CF

In [None]:
# Recompute the CF from scratch (so rerunning this cell never double-applies the
# envelope), then weight it by the impurity envelope.
envelope = calculate_envelope(betas_re, betas_im, fit_vals)
raw_cf = calculate_cf(current_state.ptrace(1), betas_re, betas_im)
cf_vals, betas_re, betas_im = raw_cf
cf_vals = envelope * cf_vals

In [None]:

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/(jnp.sqrt(2*jnp.pi))  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # cf_vals is indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

# The state was prepared from basis["-x"] in the preceding cell.
fig.suptitle(r"$\widehat{sBs_Z} \cdot \vert -x \rangle$ with impurity envelope", y=1.05)

    

### sBs X half-round

#### Full CF

In [None]:
# One sBs X half-round applied to the GKP basis["-x"] state; plot the oscillator CF.
l = jnp.sqrt(2*jnp.pi)
epsilon = jnp.sinh(delta*delta)*l

alphas_real = jnp.array([epsilon/2, 0., sd_ratio*epsilon/2, 0., l, 0.])
alphas_imag = jnp.array([0., -l, 0., epsilon/2, 0., sd_ratio*epsilon/2])
alphas = alphas_real + alphas_imag * 1.j
phis = jnp.array([0., 0., 0., 0.])
thetas = jnp.array([jnp.pi/2, -jnp.pi/2, jnp.pi/2, -jnp.pi/2])

gkp_qubit = jqtb.GKPQubit({"delta": delta, "N": N})
gkp = gkp_qubit.basis["-x"]

initial_state = jqt.basis(2, 0) ^ gkp
# alphas[3:6], phis[2:4], thetas[2:4] are the X half-round parameters.
current_state = sBs_round(initial_state, alphas[3:6], phis[2:4],
                          thetas[2:4], err_prob)

betas_re = jnp.linspace(-3, 3, 41)
betas_im = jnp.linspace(-3, 3, 41)

cf_vals, _, _ = calculate_cf(current_state.ptrace(1), betas_re, betas_im)

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/jnp.sqrt(jnp.pi*2)  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # cf_vals is indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

# The prepared state is basis["-x"], so label it as such.
fig.suptitle(r"$\widehat{sBs_X} \cdot \vert -x \rangle$", y=1.05)

    

#### Enveloped CF

In [None]:
# Recompute the CF from scratch (so rerunning this cell never double-applies the
# envelope), then weight it by the impurity envelope.
envelope = calculate_envelope(betas_re, betas_im, fit_vals)
raw_cf = calculate_cf(current_state.ptrace(1), betas_re, betas_im)
cf_vals, betas_re, betas_im = raw_cf
cf_vals = envelope * cf_vals

In [None]:

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/(jnp.sqrt(2*jnp.pi))  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # cf_vals is indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

# The state was prepared from basis["-x"] in the preceding sBs_X cell.
fig.suptitle(r"$\widehat{sBs_X} \cdot \vert -x \rangle$ with impurity envelope", y=1.05)

    

## Action on the vacuum state

### sBs Z half-round

#### Full CF

In [None]:
# One sBs Z half-round applied to the oscillator vacuum; plot the oscillator CF.
# (No GKP state is needed here, so none is constructed.)
l = jnp.sqrt(2*jnp.pi)
epsilon = jnp.sinh(delta*delta)*l

alphas_real = jnp.array([epsilon/2, 0., sd_ratio*epsilon/2, 0., l, 0.])
alphas_imag = jnp.array([0., -l, 0., epsilon/2, 0., sd_ratio*epsilon/2])
alphas = alphas_real + alphas_imag * 1.j
phis = jnp.array([0., 0., 0., 0.])
thetas = jnp.array([jnp.pi/2, -jnp.pi/2, jnp.pi/2, -jnp.pi/2])

initial_state = jqt.basis(2, 0) ^ jqt.basis(N, 0)
# alphas[0:3], phis[0:2], thetas[0:2] are the Z half-round parameters.
current_state = sBs_round(initial_state, alphas[0:3], phis[0:2],
                          thetas[0:2], err_prob)

betas_re = jnp.linspace(-3, 3, 41)
betas_im = jnp.linspace(-3, 3, 41)

cf_vals, _, _ = calculate_cf(current_state.ptrace(1), betas_re, betas_im)

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/jnp.sqrt(jnp.pi*2)  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # cf_vals is indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

fig.suptitle(r"$\widehat{sBs_Z} \cdot \vert vac \rangle$", y=1.05)

    

#### Enveloped CF

In [None]:
# Recompute the CF from scratch (so rerunning this cell never double-applies the
# envelope), then weight it by the impurity envelope.
envelope = calculate_envelope(betas_re, betas_im, fit_vals)
raw_cf = calculate_cf(current_state.ptrace(1), betas_re, betas_im)
cf_vals, betas_re, betas_im = raw_cf
cf_vals = envelope * cf_vals

In [None]:

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/(jnp.sqrt(2*jnp.pi))  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # cf_vals is indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

fig.suptitle(r"$\widehat{sBs_Z} \cdot \vert vac \rangle$ with impurity envelope", y=1.05)

    

### sBs X half-round

#### Full CF

In [None]:
# One sBs X half-round applied to the oscillator vacuum; plot the oscillator CF.
# (No GKP state is needed here, so none is constructed.)
l = jnp.sqrt(2*jnp.pi)
epsilon = jnp.sinh(delta*delta)*l

alphas_real = jnp.array([epsilon/2, 0., sd_ratio*epsilon/2, 0., l, 0.])
alphas_imag = jnp.array([0., -l, 0., epsilon/2, 0., sd_ratio*epsilon/2])
alphas = alphas_real + alphas_imag * 1.j
phis = jnp.array([0., 0., 0., 0.])
thetas = jnp.array([jnp.pi/2, -jnp.pi/2, jnp.pi/2, -jnp.pi/2])

initial_state = jqt.basis(2, 0) ^ jqt.basis(N, 0)
# alphas[3:6], phis[2:4], thetas[2:4] are the X half-round parameters.
current_state = sBs_round(initial_state, alphas[3:6], phis[2:4],
                          thetas[2:4], err_prob)

betas_re = jnp.linspace(-3, 3, 41)
betas_im = jnp.linspace(-3, 3, 41)

cf_vals, _, _ = calculate_cf(current_state.ptrace(1), betas_re, betas_im)

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/jnp.sqrt(jnp.pi*2)  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # cf_vals is indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

fig.suptitle(r"$\widehat{sBs_X} \cdot \vert vac \rangle$", y=1.05)

    

#### Enveloped CF

In [None]:
# Recompute the CF from scratch (so rerunning this cell never double-applies the
# envelope), then weight it by the impurity envelope.
envelope = calculate_envelope(betas_re, betas_im, fit_vals)
raw_cf = calculate_cf(current_state.ptrace(1), betas_re, betas_im)
cf_vals, betas_re, betas_im = raw_cf
cf_vals = envelope * cf_vals

In [None]:

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/(jnp.sqrt(2*jnp.pi))  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # cf_vals is indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

fig.suptitle(r"$\widehat{sBs_X} \cdot \vert vac \rangle$ with impurity envelope", y=1.05)

    

## Compare with run_circuit

In [None]:
# Rebuild the exported sBs parameters (same construction as the export section above).
l = jnp.sqrt(2*jnp.pi)  # sqrt(2*pi); presumably the GKP lattice spacing
epsilon = jnp.sinh(delta*delta)*l  # epsilon = sinh(delta^2) * l

# Entries 0-2 are used as the Z half-round, entries 3-5 as the X half-round.
alphas_real = jnp.array([epsilon/2, 0., sd_ratio*epsilon/2, 0., l, 0.])
alphas_imag = jnp.array([0., -l, 0., epsilon/2, 0., sd_ratio*epsilon/2])
# Angles in units of 2*pi, as run_circuit multiplies rows 0 and 1 by 2*pi.
phis = jnp.array([jnp.pi/2, 0., 0., jnp.pi/2, 0., 0.]) / 2 / jnp.pi
thetas = (jnp.array([jnp.pi, jnp.pi/2, -jnp.pi/2, jnp.pi, jnp.pi/2, -jnp.pi/2])-jnp.pi) / 2 / jnp.pi  # shifted by -pi first

In [None]:
# Rows = (Rx angles, Ry angles, Re beta, Im beta), one column per layer (run_circuit format).
half_z, half_x = slice(0, 3), slice(3, 6)
sbs_max_LL_Z = jnp.stack([thetas[half_z], phis[half_z], alphas_real[half_z], alphas_imag[half_z]])
sbs_max_LL_X = jnp.stack([thetas[half_x], phis[half_x], alphas_real[half_x], alphas_imag[half_x]])

In [None]:
def run_circuit(params, N):
    """Simulate a layered qubit/oscillator circuit from ``basis(2, 0) ^ basis(N, 0)``.

    ``params`` rows: [0] Rx angles and [1] Ry angles (both in units of 2*pi),
    [2] real and [3] imaginary parts of the conditional-displacement amplitudes.
    Layer ``i`` applies Rx(2*pi*params[0][i]), Ry(2*pi*params[1][i]),
    CD(N, beta_i), Rx(pi), using ``mode="default"``.

    Returns:
        ``res[-1][-1].unit()`` of the simulation output.
    """
    rx_angles = 2*jnp.pi*params[0]
    ry_angles = 2*jnp.pi*params[1]
    cd_betas = params[2] + 1j*params[3]

    circuit = jqtc.Circuit.create(jqtc.Register([2, N]), layers=[])
    for rx_angle, ry_angle, beta in zip(rx_angles, ry_angles, cd_betas):
        circuit.append(jqtc.Rx(rx_angle), 0)
        circuit.append(jqtc.Ry(ry_angle), 0)
        circuit.append(jqtc.CD(N, beta), [0, 1])
        circuit.append(jqtc.Rx(jnp.pi), 0)

    start = jqt.basis(2, 0) ^ jqt.basis(N, 0)
    history = jqtc.simulate(circuit, start, mode="default")
    return history[-1][-1].unit()

In [None]:
# Apply the exported sBs_X parameters to the vacuum with a 200-level oscillator.
res = run_circuit(sbs_max_LL_X, 200)

In [None]:
betas_re = jnp.linspace(-3, 3, 41)
betas_im = jnp.linspace(-3, 3, 41)

# CF of the oscillator after run_circuit, weighted by the impurity envelope.
cf_vals, _, _ = calculate_cf(res.ptrace(1), betas_re, betas_im)
envelope = calculate_envelope(betas_re, betas_im, fit_vals)
cf_vals = envelope * cf_vals

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/jnp.sqrt(jnp.pi*2)  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # cf_vals is indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

# The envelope is applied above, so say so (as in the other enveloped plots).
fig.suptitle(r"$\widehat{sBs_X} \cdot \vert vac \rangle$ with impurity envelope", y=1.05)

    

In [None]:
# np.load returns an NpzFile for .npz archives, which holds an open file handle.
# Use it as a context manager so the handle is closed once the array is copied out.
with np.load("./20250702_h11m09s01_gkp_state_prep.npz") as npz_file:
    params = jnp.array(npz_file["best_params"])

In [None]:
# Inspect the loaded parameter array (rows as consumed by run_circuit).
params

In [None]:
# Simulate the loaded circuit from the vacuum with a 200-level oscillator.
res = run_circuit(params, 200)

In [None]:
betas_re = jnp.linspace(-3, 3, 41)
betas_im = jnp.linspace(-3, 3, 41)

# CF of the oscillator produced by the loaded circuit, weighted by the impurity envelope.
cf_vals, _, _ = calculate_cf(res.ptrace(1), betas_re, betas_im)
envelope = calculate_envelope(betas_re, betas_im, fit_vals)
cf_vals = envelope * cf_vals

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/jnp.sqrt(jnp.pi*2)  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # cf_vals is indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

# `res` comes from the parameters loaded from the *_gkp_state_prep.npz file, not from
# the sBs_X export, so the title names the loaded circuit.
fig.suptitle(r"Loaded GKP state-prep circuit $\cdot \vert vac \rangle$ with impurity envelope", y=1.05)

    

## Compare with simulated CF tomography

In [None]:
def cf_tomography_circuit(state, beta, measure_real=True):
    """Simulate the qubit-assisted CF measurement of ``state`` at displacement ``beta``.

    Qubit sequence: Ry(pi/2), CD(beta) on (qubit, oscillator), then Ry(pi/2) when
    ``measure_real`` is true and Rx(pi/2) otherwise. Returns the real part of
    ``<psi| sigma_z (x) I |psi>`` for the final state.

    NOTE(review): the ``dag() @ op @ state`` expectation is only valid for a ket.
    ``sim_cf`` builds a density-matrix ``state`` when it is given a density matrix,
    and that path would need ``Tr(sigma_z rho)`` instead. Confirm before using it.
    """
    N = state.dims[0][1]
    reg = jqtc.Register([2,N])
    cirq = jqtc.Circuit.create(reg, layers=[])

    cirq.append(jqtc.Ry(jnp.pi/2), 0)
    cirq.append(jqtc.CD(N, beta), [0,1])

    if measure_real:
        cirq.append(jqtc.Ry(jnp.pi/2), 0)
    else:
        cirq.append(jqtc.Rx(jnp.pi/2), 0)

    res = jqtc.simulate(cirq, state)
    final_state = res[-1][-1]
    sigmaz = jqt.sigmaz() ^ jqt.identity(N)
    sigmaz_exp = final_state.dag() @ sigmaz @ final_state
    return sigmaz_exp.data[0][0].real


def sim_cf(osc_state, betas_re=None, betas_im=None):
    """Simulate CF tomography of ``osc_state`` on a ``betas_re`` x ``betas_im`` grid.

    An oscillator-only state (a single subsystem in ``dims[0]``) is paired with
    the qubit in ``basis(2, 0)``, as a ket or a density matrix to match ``osc_state``.
    Otherwise ``osc_state`` is used as the joint (qubit, oscillator) state.
    Each grid point is evaluated with a jitted, vmapped ``cf_tomography_circuit``,
    once for the real part and once for the imaginary part.

    Args:
        osc_state: oscillator state, or joint (qubit, oscillator) state.
        betas_re: real parts of the grid (default ``linspace(-4, 4, 101)``).
        betas_im: imaginary parts of the grid (default ``linspace(-4, 4, 101)``).

    Returns:
        ``(tomo_res, betas_re, betas_im)``. ``tomo_res[i, j]`` corresponds to
        ``betas_re[i] + 1j * betas_im[j]``.
    """
    if len(osc_state.dims[0]) == 1:
        if osc_state.is_dm():
            state = jqt.ket2dm(jqt.basis(2,0)) ^ osc_state
        else:
            state = jqt.basis(2,0) ^ osc_state
    else:
        state = osc_state

    # Each axis falls back to its own default independently. The imaginary axis
    # used to test and reuse betas_re, which silently ignored a supplied betas_im.
    betas_re = betas_re if betas_re is not None else jnp.linspace(-4,4, 101)
    betas_im = betas_im if betas_im is not None else jnp.linspace(-4,4, 101)
    betas = betas_re.reshape(-1,1) + 1j*betas_im.reshape(1,-1)
    betas_flat = betas.flatten()

    cf_tomography_circuit_vmap = jax.jit(jax.vmap(lambda beta: cf_tomography_circuit(state, beta, measure_real=True)))
    tomo_res_real = cf_tomography_circuit_vmap(betas_flat)

    cf_tomography_circuit_vmap = jax.jit(jax.vmap(lambda beta: cf_tomography_circuit(state, beta, measure_real=False)))
    tomo_res_imag = cf_tomography_circuit_vmap(betas_flat)

    tomo_res_real = tomo_res_real.reshape(*betas.shape)
    tomo_res_imag = tomo_res_imag.reshape(*betas.shape)

    tomo_res = tomo_res_real + 1j*tomo_res_imag

    return tomo_res, betas_re, betas_im

In [None]:
betas_re = jnp.linspace(-3, 3, 41)
betas_im = jnp.linspace(-3, 3, 41)

# Simulated tomography of the current `res` (no impurity envelope applied).
cf_vals, _, _ = sim_cf(res, betas_re, betas_im)

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=150)
vmin, vmax = -1, 1
sf = 1/jnp.sqrt(jnp.pi*2)  # show beta in units of sqrt(2*pi)
panels = ((jnp.real, r"Re[$\mathcal{C}(\beta)$]"),
          (jnp.imag, r"Im[$\mathcal{C}(\beta)$]"))
for ax, (part, cbar_title) in zip(axs, panels):
    # sim_cf results are indexed [re, im]; contourf expects Z[y, x], hence the transpose.
    im = ax.contourf(betas_re*sf, betas_im*sf, part(cf_vals).T,
                     levels=np.linspace(vmin, vmax, 101), cmap="seismic",
                     vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.grid()
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlabel(r"Re[$\beta$]/$\sqrt{2\pi}$")
    ax.set_ylabel(r"Im[$\beta$]/$\sqrt{2\pi}$")
    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
    cbar.ax.set_title(cbar_title)
    cbar.ax.set_yticks(np.linspace(-1, 1, 11))

# `res` was last produced by run_circuit on the params loaded from the state-prep file.
fig.suptitle(r"Simulated CF tomography: loaded GKP state-prep circuit $\cdot \vert vac \rangle$", y=1.05)
