In [None]:
using QuantumCollocation
using NamedTrajectories
using TrajectoryIndexingUtils

using CairoMakie
using DelimitedFiles
using Distributions
using LinearAlgebra

In [None]:
"""
    show_matrix([io::IO,] A)

Pretty-print `A` using its multi-line `text/plain` representation (the form the REPL
uses), writing to `io`, or to `stdout` when `io` is omitted.
"""
show_matrix(io::IO, A) = show(io, "text/plain", A)
show_matrix(A) = show_matrix(stdout, A)

In [None]:
# Check previous solutions

In [None]:
# Single-level ladder operators and the Pauli matrices built from them; the Paulis are
# combined into two-qubit strings (e.g. "XI", "ZZ") with `kron_from_dict` below.
const n_levels = 2
a = annihilate(n_levels)   # lowering operator
at = create(n_levels)      # raising operator

H_operators = Dict(
    "X" => a + at,
    "Y" => -im * (a - at),
    "Z" => I - 2 * at * a,
)

# Time parameters. `T` is not referenced in the cells visible here; Δt is the fixed step
# used by the rollouts (an earlier value was 0.2).
T = 50
Δt = 2 / 9

# 51 evenly spaced ZZ crosstalk strengths ζ from 0 to 0.1 (inclusive).
ζs = range(0, 0.1; length = 51)
;

In [None]:
# Load the two saved trajectories compared throughout this notebook.
default, solution = map(
    load_traj,
    (
        "saved-pulses-2023-12-13/single_qubit_gateset_default.jld2",
        "saved-pulses-2023-12-13/single_qubit_gateset_R1e-3.jld2",
    ),
)

In [None]:
# Control amplitudes side by side: `default` on the left, `solution` on the right.
# Each panel uses its own trajectory's time grid, so the two trajectories need not share
# the same timesteps.
fig = Figure()
ax1 = Axis(fig[1, 1], title = "Default", xlabel = "Time", ylabel = "Control amplitude")
ax2 = Axis(fig[1, 2], title = "Solution", xlabel = "Time", ylabel = "Control amplitude")

ts = cumsum(timesteps(default))            # cumulative time at each knot of `default`
for row in eachrow(default[:a])
    lines!(ax1, ts, row)
end

ts_solution = cumsum(timesteps(solution))  # cumulative time at each knot of `solution`
for row in eachrow(solution[:a])
    lines!(ax2, ts_solution, row)
end
fig

In [None]:
# Two-qubit system with no drift and one X drive per qubit (XI on the first, IX on the
# second). The crosstalk variant adds an always-on ZZ term to the drift.
H_drift = zeros(n_levels^2, n_levels^2)
H_controls = [kron_from_dict(label, H_operators) for label in ("XI", "IX")]
H_crosstalk = kron_from_dict("ZZ", H_operators)

sys = QuantumSystem(H_drift, H_controls)

# Build the same system with ZZ crosstalk of strength `s` added to the drift Hamiltonian.
function sys_xtalk_fn(s)
    H_drift_with_xtalk = H_drift + s * H_crosstalk
    return QuantumSystem(H_drift_with_xtalk, H_controls)
end
;

In [None]:
# uf = iso_vec_to_operator(unitary_rollout(u0, default[:a][1:2, :], Δt, sys)[:, end])
# uf = iso_vec_to_operator(unitary_rollout(u0, solution[:a][1:2, :], Δt, sys)[:, end])
# uf = iso_vec_to_operator(unitary_rollout(u0, solution[:a][1:2, :], Δt, sys_xtalk)[:, end])
# uf = iso_vec_to_operator(unitary_rollout(u0, default[:a][1:2, :], Δt, sys_xtalk)[:, end])

In [None]:
"""
    my_fn(controls, system, target; timestep = Δt)

Return the gate infidelity `1 - unitary_fidelity(Uf, target)`. `Uf` is the final unitary
obtained by rolling out `controls` on `system`, starting from the identity.

`controls` is passed unchanged to `unitary_rollout`; the callers in this notebook pass
two rows of a trajectory's `:a` matrix. `timestep` defaults to the global `Δt`.
"""
function my_fn(controls, system, target; timestep = Δt)
    # Start from the identity on the space `target` acts on, rather than a hard-coded
    # two-qubit "II" string, so the helper is not tied to a particular qubit count.
    dim = size(target, 1)
    u0 = operator_to_iso_vec(Matrix{ComplexF64}(I, dim, dim))
    Uf = iso_vec_to_operator(unitary_rollout(u0, controls, timestep, system)[:, end])
    return 1 - unitary_fidelity(Uf, target)
end

In [None]:
# ZZ-crosstalk sweep for the simultaneous X ⊗ X gate, driven by rows 1 and 2 of `:a`
# (presumably in the order of `H_controls = [XI, IX]`). `my_fn` returns 1 - fidelity.
target = GATES[:X] ⊗ GATES[:X]

# Infidelities at or below zero (e.g. from floating-point round-off) cannot be drawn on a
# log axis, so every value is floored at machine epsilon.
def_xtalk_res = [
    max(my_fn(default[:a][1:2, :], sys_xtalk_fn(s), target), eps()) for s in ζs
]
sol_xtalk_res = [
    max(my_fn(solution[:a][1:2, :], sys_xtalk_fn(s), target), eps()) for s in ζs
]

fig = Figure()
ax = Axis(
    fig[1, 1];
    title = "X ⊗ X (rows 1, 2)",
    yscale = log10,
    xlabel = "Crosstalk strength",
    ylabel = "Infidelity",  # plotted values are 1 - fidelity
)
l1 = lines!(ax, ζs, def_xtalk_res, color = :blue, label = "Default")
l2 = lines!(ax, ζs, sol_xtalk_res, color = :red, label = "Solution")
Legend(fig[1, 2], [l1, l2], ["Default", "Solution"])
fig

In [None]:
# ZZ-crosstalk sweep for X ⊗ √X, driven by rows 1 and 3 of `:a` (presumably in the order
# of `H_controls = [XI, IX]`). `my_fn` returns 1 - fidelity.
target = GATES[:X] ⊗ sqrt(GATES[:X])

# Infidelities at or below zero (e.g. from floating-point round-off) cannot be drawn on a
# log axis, so every value is floored at machine epsilon.
def_xtalk_res = [
    max(my_fn(default[:a][[1, 3], :], sys_xtalk_fn(s), target), eps()) for s in ζs
]
sol_xtalk_res = [
    max(my_fn(solution[:a][[1, 3], :], sys_xtalk_fn(s), target), eps()) for s in ζs
]

fig = Figure()
ax = Axis(
    fig[1, 1];
    title = "X ⊗ √X (rows 1, 3)",
    yscale = log10,
    xlabel = "Crosstalk strength",
    ylabel = "Infidelity",  # plotted values are 1 - fidelity
)
l1 = lines!(ax, ζs, def_xtalk_res, color = :blue, label = "Default")
l2 = lines!(ax, ζs, sol_xtalk_res, color = :red, label = "Solution")
Legend(fig[1, 2], [l1, l2], ["Default", "Solution"])
fig

In [None]:
# ZZ-crosstalk sweep for X ⊗ √X, driven by rows 2 and 4 of `:a` (presumably in the order
# of `H_controls = [XI, IX]`). `my_fn` returns 1 - fidelity.
target = GATES[:X] ⊗ sqrt(GATES[:X])

# Infidelities at or below zero (e.g. from floating-point round-off) cannot be drawn on a
# log axis, so every value is floored at machine epsilon.
def_xtalk_res = [
    max(my_fn(default[:a][[2, 4], :], sys_xtalk_fn(s), target), eps()) for s in ζs
]
sol_xtalk_res = [
    max(my_fn(solution[:a][[2, 4], :], sys_xtalk_fn(s), target), eps()) for s in ζs
]

fig = Figure()
ax = Axis(
    fig[1, 1];
    title = "X ⊗ √X (rows 2, 4)",
    yscale = log10,
    xlabel = "Crosstalk strength",
    ylabel = "Infidelity",  # plotted values are 1 - fidelity
)
l1 = lines!(ax, ζs, def_xtalk_res, color = :blue, label = "Default")
l2 = lines!(ax, ζs, sol_xtalk_res, color = :red, label = "Solution")
Legend(fig[1, 2], [l1, l2], ["Default", "Solution"])
fig

In [None]:
# ZZ-crosstalk sweep for X ⊗ √X, driven by rows 1 and 4 of `:a` (presumably in the order
# of `H_controls = [XI, IX]`). `my_fn` returns 1 - fidelity.
target = GATES[:X] ⊗ sqrt(GATES[:X])

# Infidelities at or below zero (e.g. from floating-point round-off) cannot be drawn on a
# log axis, so every value is floored at machine epsilon.
def_xtalk_res = [
    max(my_fn(default[:a][[1, 4], :], sys_xtalk_fn(s), target), eps()) for s in ζs
]
sol_xtalk_res = [
    max(my_fn(solution[:a][[1, 4], :], sys_xtalk_fn(s), target), eps()) for s in ζs
]

fig = Figure()
ax = Axis(
    fig[1, 1];
    title = "X ⊗ √X (rows 1, 4)",
    yscale = log10,
    xlabel = "Crosstalk strength",
    ylabel = "Infidelity",  # plotted values are 1 - fidelity
)
l1 = lines!(ax, ζs, def_xtalk_res, color = :blue, label = "Default")
l2 = lines!(ax, ζs, sol_xtalk_res, color = :red, label = "Solution")
Legend(fig[1, 2], [l1, l2], ["Default", "Solution"])
fig

In [None]:
# ZZ-crosstalk sweep for √X ⊗ √X, driven by rows 3 and 4 of `:a` (presumably in the order
# of `H_controls = [XI, IX]`). `my_fn` returns 1 - fidelity.
target = sqrt(GATES[:X]) ⊗ sqrt(GATES[:X])

# Infidelities at or below zero (e.g. from floating-point round-off) cannot be drawn on a
# log axis, so every value is floored at machine epsilon.
def_xtalk_res = [
    max(my_fn(default[:a][[3, 4], :], sys_xtalk_fn(s), target), eps()) for s in ζs
]
sol_xtalk_res = [
    max(my_fn(solution[:a][[3, 4], :], sys_xtalk_fn(s), target), eps()) for s in ζs
]

fig = Figure()
ax = Axis(
    fig[1, 1];
    title = "√X ⊗ √X (rows 3, 4)",
    yscale = log10,
    xlabel = "Crosstalk strength",
    ylabel = "Infidelity",  # plotted values are 1 - fidelity
)
l1 = lines!(ax, ζs, def_xtalk_res, color = :blue, label = "Default")
l2 = lines!(ax, ζs, sol_xtalk_res, color = :red, label = "Solution")
Legend(fig[1, 2], [l1, l2], ["Default", "Solution"])
fig

# ECR gate

In [None]:
# Two-qubit Pauli strings used by the ECR-gate checks in the following cells.
ZX, XZ, XI, IX, XY, YX = map(("ZX", "XZ", "XI", "IX", "XY", "YX")) do label
    kron_from_dict(label, H_operators)
end
;

In [None]:
# R = i√2 · exp(iπ/8 XZ) · exp(-iπ/2 IX) · exp(-iπ/8 XZ), displayed to 4 decimals for
# comparison with IX - XY. An optional trailing factor exp(-im * π/2 * XI) is left out.
R = im * √2 * exp(im * π/8 * XZ) * exp(-im * π/2 * IX) * exp(-im * π/8 * XZ)
round.(R; digits = 4) |> show_matrix

In [None]:
# Reference operator IX - XY, to compare against R from the previous cell.
(IX - XY) |> show_matrix

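Be careful with qubit ordering in Qiskit: it orders qubits in the opposite (little-endian) convention. The cells below therefore repeat the construction with the qubits swapped (`ZX` instead of `XZ`, `XI` instead of `IX`).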

In [None]:
# Qubit-swapped variant: R = i√2 · exp(iπ/8 ZX) · exp(-iπ/2 XI) · exp(-iπ/8 ZX), shown to
# 4 decimals for comparison with XI - YX. An optional trailing exp(-im * π/2 * XI) is left out.
R = im * √2 * exp(im * π/8 * ZX) * exp(-im * π/2 * XI) * exp(-im * π/8 * ZX)
round.(R; digits = 4) |> show_matrix

In [None]:
# Reference operator XI - YX, to compare against the qubit-swapped R from the previous cell.
(XI - YX) |> show_matrix