In [None]:
using Revise

using CairoMakie
using Distributions
using LinearAlgebra
using Random

using QuantumCollocation
using NamedTrajectories

# Show the source path of the loaded QuantumCollocation package
# (presumably to confirm which checkout Revise is tracking — verify).
pathof(QuantumCollocation)

In [None]:
# System: drift GATES[:Z] scaled by the parameter ζ, with X and Y drives,
# and the target gate `operator`.
T = 50      # number of time steps
Δt = 0.2    # timestep

H_drift = GATES[:Z]
H_drives = [GATES[:X], GATES[:Y]]

# Build the quantum system for a given drift scaling ζ.
function system(ζ)
    return QuantumSystem(ζ * H_drift, H_drives)
end

operator = GATES[:H]
;

In [None]:
# Initial data

# Objective weights: R_* regularize a, da, dda; Q weights the infidelity.
R = 1e-3
R_a, R_da, R_dda = fill(R, 3)
Q = 1e2

# Per-drive bounds on the controls and their second derivatives.
a_bounds = fill(1.0, length(H_drives))
dda_bounds = fill(1.0, length(H_drives))
# Standard deviation of the random initial guess for da and dda.
drive_derivative_σ = 0.1

# Warm starts (`nothing` = build an initial guess from scratch).
init_trajectory = nothing
a_guess = nothing
# Initial unitary guess: geodesic (true) or linear interpolation (false).
geodesic = true

integrator = :pade
pade_order = 4
# NOTE(review): autodiff is enabled for every Padé order except 4 — presumably
# only order 4 has hand-written derivatives; confirm.
autodiff = pade_order != 4

free_time = false
timesteps_all_equal = true
bound_unitary = false
leakage_suppression = false
control_norm_constraint = false

# Parameters read only by the optional branches below (free_time,
# leakage_suppression, control_norm_constraint). Defining them here keeps
# those branches from failing with an UndefVarError once they are enabled.
# NOTE(review): values are placeholders — tune before enabling the options.
Δt_min = 0.5 * first(Δt)
Δt_max = 1.5 * first(Δt)
R_leakage = 1e-1
control_norm_constraint_components = nothing
control_norm_R = nothing

# QuantumControlProblem data
blas_multithreading = false
max_iter = 100
linear_solver = "mumps"
verbose = true
ipopt_options = Options()
jacobian_structure = true
hessian_approximation = true
;

Goals:

- [x] Construct a system from a set of initial samples.
- [ ] Construct a problem from a set of systems.
- [ ] Refactor the problem-template code so that the trajectory initialization in the unitary smooth-pulse problem is not copy-pasted so much.
- [ ] Build a named trajectory with suffixes associated with each system.

In [None]:
# Scratch: two ways to build a NamedTuple from runtime keys and values.
# The splatted-pairs form is the one used later to build `initial` and `goal`.
# test = [:a, :b]
# val = [1, 2]
# (; (test .=> val)...)

# # Has to be tuples in this case
# test = (:a, :b)
# val = (1, 2)
# NamedTuple{test}(val)

In [None]:
# Sample the drift scaling ζ and build one system per sample, each with a
# string label ("1", "2", ...) and a unit weight.
ζs = range(-0.05, 0.05; length=5)
systems = map(system, ζs)
system_labels = [string(i) for i in eachindex(systems)]
system_weights = ones(length(systems))
;

In [None]:
# Display the sampled drift-scaling values used to build `systems`.
ζs

In [None]:
"""
    control_initial(n_drives, a_bounds, dda_bounds, drive_derivative_σ; n_timesteps=T)

Draw a random initial guess for the controls `a` and their first and second
derivatives `da` and `dda`, each an `n_drives × n_timesteps` matrix.

# Arguments
- `n_drives`: number of control drives (rows of each returned matrix).
- `a_bounds`: either a vector of symmetric bounds, so drive `i` is sampled from
  `Uniform(-a_bounds[i], a_bounds[i])`, or a tuple `(lower, upper)` of bound
  vectors, so drive `i` is sampled from `Uniform(lower[i], upper[i])`.
- `dda_bounds`: not used when sampling; accepted for interface compatibility.
- `drive_derivative_σ`: standard deviation of the zero-mean Gaussian entries
  of `da` and `dda`.

# Keywords
- `n_timesteps`: number of time steps; defaults to the notebook-global `T`.

The first and last columns of `a` are zero and the interior columns are
sampled independently per drive. `da` and `dda` are independent Gaussian
noise, not derivatives of `a`.

# Throws
- `ArgumentError` if `n_timesteps < 2`.
"""
function control_initial(
    n_drives::Int,
    a_bounds::Union{AbstractVector{<:Real}, Tuple},
    dda_bounds::AbstractVector{<:Real},
    drive_derivative_σ::Real;
    n_timesteps::Int=T
)
    n_timesteps ≥ 2 ||
        throw(ArgumentError("n_timesteps must be at least 2, got $n_timesteps"))

    a_dists = _control_distributions(a_bounds, n_drives)

    # Endpoints stay at zero; only the interior time steps are randomized.
    a = zeros(n_drives, n_timesteps)
    for i in 1:n_drives
        a[i, 2:end-1] = rand(a_dists[i], n_timesteps - 2)
    end

    da = randn(n_drives, n_timesteps) * drive_derivative_σ
    dda = randn(n_drives, n_timesteps) * drive_derivative_σ
    return a, da, dda
end

# Per-drive sampling distributions for symmetric bounds [-bᵢ, bᵢ].
_control_distributions(a_bounds::AbstractVector{<:Real}, n_drives::Int) =
    [Uniform(-a_bounds[i], a_bounds[i]) for i in 1:n_drives]

# Per-drive sampling distributions for a (lower_bounds, upper_bounds) tuple.
_control_distributions(a_bounds::Tuple, ::Int) =
    [Uniform(lb, ub) for (lb, ub) in zip(a_bounds...)]

In [None]:
"""
    unitary_initial(U_init, U_goal, T; geodesic=true)

Return an initial guess for the isomorphic unitary trajectory over `T` time
steps. With `geodesic=true`, use `unitary_geodesic(U_goal, T)`; `U_init` is not
passed in that case. Otherwise, linearly interpolate from `U_init` to `U_goal`.
"""
function unitary_initial(
    U_init::AbstractMatrix{<:Number},
    U_goal::AbstractMatrix{<:Number},
    T::Int;
    geodesic::Bool=true
)
    return geodesic ?
        unitary_geodesic(U_goal, T) :
        unitary_linear_interpolation(U_init, U_goal, T)
end

In [None]:
Random.seed!(1)

# Validate input before systems[1] is used below. Use an explicit exception
# rather than @assert, which is meant for internal invariants.
isempty(systems) &&
    throw(ArgumentError("systems must be a non-empty vector of QuantumSystems"))

# Goal and initial unitaries: taken from the EmbeddedOperator (its operator and
# get_subspace_identity), or as dense complex matrices for a plain operator.
if operator isa EmbeddedOperator
    U_goal = operator.operator
    U_init = get_subspace_identity(operator)
else
    U_goal = Matrix{ComplexF64}(operator)
    U_init = Matrix{ComplexF64}(I(size(U_goal, 1)))
end

# One unitary trajectory component per system, named via add_suffix(:Ũ⃗, label).
Ũ⃗_keys = [add_suffix(:Ũ⃗, ℓ) for ℓ ∈ system_labels]

if !blas_multithreading
    BLAS.set_num_threads(1)
end

if hessian_approximation
    ipopt_options.hessian_approximation = "limited-memory"
end

n_drives = length(systems[1].G_drives)

# Trajectory: reuse `init_trajectory` if given. Otherwise, build an initial
# guess with one unitary component per system (Ũ⃗_keys) plus the controls
# a, da and dda.
if !isnothing(init_trajectory)
    traj = init_trajectory
else
    # Free-time problems carry Δt as a 1 × T row of decision variables.
    if free_time && Δt isa Float64
        Δt = fill(Δt, 1, T)
    end

    Ũ⃗_init = operator_to_iso_vec(U_init)

    # Initial state and controls
    if isnothing(a_guess)
        Ũ⃗ = unitary_initial(U_init, U_goal, T; geodesic=geodesic)
        Ũ⃗_values = repeat([Ũ⃗], length(systems))
        a, da, dda = control_initial(n_drives, a_bounds, dda_bounds, drive_derivative_σ)
    else
        # Per-step timesteps. A scalar Δt cannot be indexed as Δt[end-1], and
        # the rollout and derivative helpers take one timestep per knot point.
        Δt_vec = Δt isa Real ? fill(Δt, T) : vec(Δt)

        # Roll the guessed controls through every sampled system (`system` is
        # the ζ -> QuantumSystem constructor, not a system), so each Ũ⃗ᵢ starts
        # close to satisfying its own system's dynamics.
        # NOTE(review): uses unitary_rollout's default integrator.
        Ũ⃗_values = [
            unitary_rollout(Ũ⃗_init, a_guess, Δt_vec, sysᵢ) for sysᵢ in systems
        ]
        Ũ⃗ = first(Ũ⃗_values)

        a = a_guess
        da = derivative(a, Δt_vec)
        dda = derivative(da, Δt_vec)

        # to avoid constraint violation error at initial iteration
        da[:, end] = da[:, end-1] + Δt_vec[end-1] * dda[:, end-1]
    end

    # Boundary conditions: every unitary starts at U_init and targets U_goal;
    # the controls are zero at both ends.
    Ũ⃗_inits = repeat([Ũ⃗_init], length(Ũ⃗_keys))
    initial = (; (Ũ⃗_keys .=> Ũ⃗_inits)..., a = zeros(n_drives))
    final = (a = zeros(n_drives),)

    Ũ⃗_goals = repeat([operator_to_iso_vec(U_goal)], length(Ũ⃗_keys))
    goal = (; (Ũ⃗_keys .=> Ũ⃗_goals)...)

    bounds = (a = a_bounds, dda = dda_bounds)

    if bound_unitary
        Ũ⃗_dim = size(Ũ⃗, 1)
        Ũ⃗_bounds = repeat([(-ones(Ũ⃗_dim), ones(Ũ⃗_dim))], length(Ũ⃗_keys))
        bounds = merge(bounds, (; (Ũ⃗_keys .=> Ũ⃗_bounds)...))
    end

    # Component names/data. Named traj_keys/traj_values (not keys/values) so
    # these notebook globals do not shadow Base.keys and Base.values.
    traj_keys = [Ũ⃗_keys..., :a, :da, :dda]
    traj_values = [Ũ⃗_values..., a, da, dda]

    if free_time
        push!(traj_keys, :Δt)
        push!(traj_values, Δt)
        controls = (:dda, :Δt)
        timestep = :Δt
        # requires Δt_min and Δt_max to be defined
        bounds = merge(bounds, (Δt = (Δt_min, Δt_max),))
    else
        controls = (:dda,)
        timestep = Δt
    end

    traj = NamedTrajectory(
        (; (traj_keys .=> traj_values)...);
        controls=controls,
        timestep=timestep,
        bounds=bounds,
        initial=initial,
        final=final,
        goal=goal
    )
end

# Objective: weighted sum of per-system unitary infidelities, plus quadratic
# regularization of the controls and their derivatives.
infidelity_subspace = operator isa EmbeddedOperator ? operator.subspace_indices : nothing

J = NullObjective()
for (wᵢ, Ũ⃗ᵢ) in zip(system_weights, Ũ⃗_keys)
    # Scale the infidelity weight Q by this system's weight.
    J += UnitaryInfidelityObjective(Ũ⃗ᵢ, traj, wᵢ * Q; subspace=infidelity_subspace)
end
J += QuadraticRegularizer(:a, traj, R_a)
J += QuadraticRegularizer(:da, traj, R_da)
J += QuadraticRegularizer(:dda, traj, R_dda)

# Constraints
constraints = AbstractConstraint[]

# Leakage suppression: for each system's unitary, add an L1 penalty on the
# leakage components to J and its slack constraint to `constraints`.
# TODO: Must change for the parameterized systems
if leakage_suppression
    if operator isa EmbeddedOperator
        @isdefined(R_leakage) ||
            throw(ArgumentError("R_leakage must be defined to use leakage_suppression"))
        leakage_indices = get_unitary_isomorphism_leakage_indices(operator)
        for Ũ⃗ᵢ in Ũ⃗_keys
            J_leakage, slack_con = L1Regularizer(
                Ũ⃗ᵢ,
                traj;
                R_value=R_leakage,
                indices=leakage_indices
            )
            push!(constraints, slack_con)
            J += J_leakage
        end
    else
        @warn "leakage_suppression is not supported for non-embedded operators, ignoring."
    end
end

if free_time && timesteps_all_equal
    push!(constraints, TimeStepsAllEqualConstraint(:Δt, traj))
end

if control_norm_constraint
    # Check definedness first; otherwise an undefined global raises an
    # UndefVarError instead of the intended message.
    (@isdefined(control_norm_constraint_components) &&
     !isnothing(control_norm_constraint_components)) ||
        throw(ArgumentError("control_norm_constraint_components must be provided"))
    (@isdefined(control_norm_R) && !isnothing(control_norm_R)) ||
        throw(ArgumentError("control_norm_R must be provided"))
    norm_con = ComplexModulusContraint(
        :a,
        control_norm_R,
        traj;
        name_comps=control_norm_constraint_components,
    )
    push!(constraints, norm_con)
end

# Integrators: one unitary integrator per sampled system, each acting on its
# own Ũ⃗ component with the shared control :a, plus the derivative links
# a -> da -> dda.
unitary_integrators = AbstractIntegrator[
    if integrator == :pade
        UnitaryPadeIntegrator(sysᵢ, Ũ⃗ᵢ, :a; order=pade_order, autodiff=autodiff)
    elseif integrator == :exponential
        UnitaryExponentialIntegrator(sysᵢ, Ũ⃗ᵢ, :a)
    else
        error("integrator must be one of (:pade, :exponential)")
    end for (sysᵢ, Ũ⃗ᵢ) ∈ zip(systems, Ũ⃗_keys)
]

integrators = [
    unitary_integrators...,
    DerivativeIntegrator(:a, :da, traj),
    DerivativeIntegrator(:da, :dda, traj),
]

# Ipopt: recompute the multipliers as least-squares estimates whenever the
# infeasibility drops below recalc_y_feas_tol (Ipopt notes this can help with
# quasi-Newton Hessian approximations).
ipopt_options.recalc_y = "yes"
ipopt_options.recalc_y_feas_tol = 1.0

# Build the problem on direct_sum(systems) so all sampled systems are
# optimized together with the shared controls.
prob = QuantumControlProblem(
    direct_sum(systems),
    traj,
    J,
    integrators;
    constraints,
    max_iter,
    linear_solver,
    verbose,
    ipopt_options,
    jacobian_structure,
    hessian_approximation,
    eval_hessian=!hessian_approximation,
)

In [200]:
# Solve the sampled problem, overriding the iteration cap.
solve!(prob; max_iter=500)

In [None]:
# Plot each drive of the optimized controls against time.
f = Figure()
ax = Axis(f[1, 1]; xlabel="Time", ylabel="Control")
times = get_times(prob.trajectory)
foreach(eachrow(prob.trajectory[:a])) do drive
    lines!(ax, times, drive)
end
f

In [None]:
# Baseline: a smooth-pulse problem on the nominal system (ζ = 0) only.
default_prob = UnitarySmoothPulseProblem(
    system(0), U_goal, T, Δt;
    ipopt_options=Options(print_level=2, recalc_y="yes", recalc_y_feas_tol=1.0),
)

solve!(default_prob; max_iter=500)

In [None]:
# Robustness problem with respect to the GATES[:Z] error term. It is seeded by
# a nominal (ζ = 0) smooth-pulse problem; final_fidelity presumably sets the
# required fidelity — confirm against UnitaryRobustnessProblem.
rob_prob = UnitaryRobustnessProblem(
    GATES[:Z],
    UnitarySmoothPulseProblem(system(0), U_goal, T, Δt; verbose=false);
    final_fidelity=0.9999,
    ipopt_options=Options(recalc_y="yes", recalc_y_feas_tol=1.0),
)

solve!(rob_prob; max_iter=500)

In [None]:
"""
    sweep_infidelities(traj, ζs, U_goal, system)

For each `ζ` in `ζs`, roll out the controls `traj.a` with `traj`'s timesteps
through `system(ζ)`. Return `1 - unitary_fidelity` of the final unitary with
respect to `U_goal`, one value per `ζ`.
"""
function sweep_infidelities(traj, ζs, U_goal, system)
    # Loop-invariant: the same timesteps and goal are used for every ζ.
    Δts = get_timesteps(traj)
    Ũ⃗_goal = operator_to_iso_vec(U_goal)
    return map(ζs) do ζ
        Ũ⃗_end = unitary_rollout(traj.a, Δts, system(ζ))[:, end]
        1 - unitary_fidelity(Ũ⃗_end, Ũ⃗_goal)
    end
end

# Sweep the drift scaling ζ over a wider range than the sampled systems.
ζs = range(-.1, .1, length=100)
infids = sweep_infidelities(prob.trajectory, ζs, U_goal, system)
default_infids = sweep_infidelities(default_prob.trajectory, ζs, U_goal, system)
rob_infids = sweep_infidelities(rob_prob.trajectory, ζs, U_goal, system)
;

In [None]:
# Compare infidelity versus ζ for the three solutions on a log scale.
f = Figure()
ax = Axis(f[1, 1]; xlabel="ζ", ylabel="Infidelity", yscale=log10)
for (curve, curve_label) in (
    (infids, "Sampling"),
    (default_infids, "Default"),
    (rob_infids, "FORE"),
)
    lines!(ax, ζs, curve; label=curve_label)
end
Legend(f[1, 2], ax)
f