In [None]:
using QuantumCollocation
using NamedTrajectories
using TrajectoryIndexingUtils
using Flux
using ReinforcementLearning
using IntervalSets
using LinearAlgebra
using Base
using Distributions
using Statistics
using Printf
using Reexport
using Revise
using DomainSets
using CairoMakie

includet("PPO.jl")
includet("AIRL.jl")
includet("GateEnvs.jl")

In [None]:
# Load the trajectory stored in "RZ_pretrained.jld2". Later cells pass it to
# GatePretrainingEnv and index it by component names built from an integer
# suffix (`a<idx>`, `da<idx>`, `dda<idx>`, `Ũ⃗<idx>`).
RZ_traj = load_traj("RZ_pretrained.jld2")

# Unit system: every quantity is rescaled by `Units`, so that GHz and ns both
# evaluate to 1 (frequencies are expressed in GHz, times in ns).
const Units = 1e9

# Frequency units.
const GHz = 1e9 / Units
const MHz = 1e6 / Units

# Time units.
const μs = 1e-6 * Units
const ns = 1e-9 * Units
;


# Operators: 2×2 complex matrices keyed by a one-letter label.
#
# NOTE(review): the "Y" entry is [0 im; -im 0], which is the negative of the
# textbook Pauli-Y matrix [0 -im; im 0]. The value is reproduced unchanged
# because data produced elsewhere (e.g. the trajectory loaded from disk) may
# rely on this sign — confirm which convention is intended.
const Paulis = Dict(
    "I" => ComplexF64[1 0; 0 1],
    "X" => ComplexF64[0 1; 1 0],
    "Y" => ComplexF64[0 im; -im 0],
    "Z" => ComplexF64[1 0; 0 -1],
)

# Parameterised Z rotation: the matrix exponential of -i/2 · theta[1] · Z.
# Only the first entry of `theta` is read.
function rz_op(theta)
    return exp(-im/2 * theta[1] * Paulis["Z"])
end

# Wrap the rotation in a `Gate`. The meaning of the leading `1` is defined by
# `Gate` (presumably the number of gate parameters — TODO confirm).
RZ = Gate(1, rz_op)

# Drive operators handed to QuantumSystem: the "X" and "Y" entries of Paulis.
H_drives = [Paulis["X"], Paulis["Y"]]
system = QuantumSystem(H_drives);

# Time grid: n_steps equally spaced points on [0, t_f]; Δt is the spacing
# between the first two grid points.
t_f = 10 * ns
n_steps = 51
times = range(0, t_f, n_steps)  # Alternative: collect(0:Δt:t_f)
Δt = times[2] - times[1]

# NOTE(review): H_drives holds two operators while n_controls is 1, and
# n_controls is not read by any other cell visible here — confirm the value.
n_controls = 1
n_qubits = 1;

# N is passed to GatePretrainingEnv together with RZ_traj.
N = 11
;

In [None]:
# Environment built around the loaded trajectory (receives N and RZ_traj in
# addition to the common arguments).
Pretraining_Env = GatePretrainingEnv(
    system, n_steps, RZ, Δt, N, RZ_traj;
    dda_bound=1.5,
)

# Environment without trajectory data; note the different dda_bound.
Training_Env = GateTrainingEnv(
    system, n_steps, RZ, Δt;
    dda_bound=0.5,
);

# One actor-critic policy per environment. `l` is forwarded unchanged
# (presumably the hidden-layer widths — TODO confirm in PPO.jl).
pretraining_𝒫 = ActorCriticPolicy(Pretraining_Env; l=[128, 128])
training_𝒫 = ActorCriticPolicy(Training_Env; l=[128, 128])

In [None]:
# `Unicode.normalize` is used below to build the Ũ⃗ component names; bring the
# stdlib module into scope explicitly instead of relying on another file
# having loaded it.
using Unicode

"""
    build_expert_dataset(traj; n_angles=N, n_cols=n_steps, n_expert_steps=48)

Collect `(state, action, next state)` samples from the trajectory `traj`.

For every index `idx in 1:n_angles` and every column `t in 1:n_expert_steps`:

- the state is `vcat(U, da, a, [t / n_cols], angle)`, where `U`, `da` and `a`
  are column `t` of the components `Ũ⃗<idx>`, `da<idx>` and `a<idx>`, and
  `angle` is the `idx`-th point of `range(0, 2π, n_angles)`;
- the action is column `t` of `dda<idx>`;
- the next state of sample `t` is the state of sample `t + 1`. For the last
  sample it is built from column `n_cols` of the trajectory with the time
  entry `(n_expert_steps + 1) / n_cols`.

# Keywords
- `n_angles`: number of trajectory copies / angles (defaults to the global `N`).
- `n_cols`: number of trajectory columns, also the time normaliser (defaults to
  the global `n_steps`).
- `n_expert_steps`: number of leading columns used as expert samples.

# Returns
A tuple `(states, acts, new_states)` of `Vector{Vector{Float32}}`, aligned
entry by entry.
"""
function build_expert_dataset(traj; n_angles=N, n_cols=n_steps, n_expert_steps=48)
    states = Vector{Vector{Float32}}()
    acts = Vector{Vector{Float32}}()
    new_states = Vector{Vector{Float32}}()

    # Loop-invariant: the angle grid is the same for every sample.
    angles = range(0, 2 * pi, n_angles)

    for idx in 1:n_angles
        θ = angles[idx]

        # Look each component up once per index instead of once per column.
        a = traj[Symbol("a" * string(idx))]
        da = traj[Symbol("da" * string(idx))]
        dda = traj[Symbol("dda" * string(idx))]
        U = traj[Symbol(Unicode.normalize("Ũ⃗" * string(idx)))]

        # State layout shared by regular and terminal samples.
        state_at(col, time_fraction) =
            vcat(U[:, col], da[:, col], a[:, col], [time_fraction], θ)

        for t in 1:n_expert_steps
            push!(states, state_at(t, t / n_cols))
            push!(acts, dda[:, t])

            # The state just stored is the successor of the previous sample.
            if t > 1
                push!(new_states, states[end])
            end

            # NOTE(review): the successor of the last sample reads the final
            # trajectory column (n_cols) but is labelled with the time of
            # column n_expert_steps + 1. This mirrors the original cell;
            # confirm it matches what the environment produces.
            if t == n_expert_steps
                push!(new_states, state_at(n_cols, (n_expert_steps + 1) / n_cols))
            end
        end
    end

    return states, acts, new_states
end

expert_states, expert_acts, expert_new_states = build_expert_dataset(RZ_traj)

behavior_clone(pretraining_𝒫, expert_states, expert_acts; epochs=50000, η=5f-5)

In [None]:
# Mean squared error between the policy's mean-network output (applied on top
# of its feature network) and the expert actions: averaged within each sample
# first, then across samples.
let policy = pretraining_𝒫
    features = policy.feature_network.(expert_states)
    predicted_acts = policy.mean_network.(features)
    residuals = predicted_acts - expert_acts
    per_sample_mse = [mean(abs2.(r)) for r in residuals]
    mean(per_sample_mse)
end

In [None]:
# Run PPO on Pretraining_Env, seeded with pretraining_𝒫 via `initial_policy`.
# Returns the resulting policy and the score history.
PPO_pretraining_𝒫, score_history = PPO(
    Pretraining_Env;
    η=3f-5,
    iterations=1000,
    n_steps=1,
    trajectory_batch_size=40,
    vf_ratio=5f-1,
    norm_adv=true,
    ϵ=1f-1,
    ent_ratio=1f-3,
    KL_targ=1f-2,
    initial_policy=pretraining_𝒫,
    clip_grad_tresh=1f3,
    use_log_rewards=false,
)

In [None]:
#using CairoMakie
#fig = Figure()
#ax = Axis(fig[1, 1])
#lines!(ax,1:length(score_history),score_history)
#fig

In [None]:
# Re-include AIRL.jl (Revise-tracked) before running it.
includet("AIRL.jl")

# Run AIRL on Pretraining_Env with the expert (state, action, next state)
# samples. No `initial_policy` is passed in this run.
AIRL_pretraining_𝒫, score_history, e_losses, s_losses, total_d_losses = AIRL(
    Pretraining_Env,
    expert_states,
    expert_acts,
    expert_new_states;
    η=3f-5,
    iterations=50,
    n_steps=1,
    trajectory_batch_size=11,
    vf_ratio=5f-1,
    norm_adv=true,
    ϵ=1f-1,
    ent_ratio=1f-3,
    KL_targ=1f-2,
    clip_grad_tresh=1f3,
    use_log_rewards=false,
)

In [None]:
using CairoMakie
# Line plot of `score_history` against its 1-based entry index. `lines`
# creates the figure and the axis at fig[1, 1] in one call.
fig, ax = lines(1:length(score_history), score_history)
fig

In [None]:
using CairoMakie
# Line plot of `e_losses` against its 1-based entry index. `lines` creates the
# figure and the axis at fig[1, 1] in one call.
fig, ax = lines(1:length(e_losses), e_losses)
fig

In [None]:
using CairoMakie
# Line plot of `s_losses` against its 1-based entry index. `lines` creates the
# figure and the axis at fig[1, 1] in one call.
fig, ax = lines(1:length(s_losses), s_losses)
fig

In [None]:
using CairoMakie
# Line plot of `total_d_losses` against its 1-based entry index. `lines`
# creates the figure and the axis at fig[1, 1] in one call.
fig, ax = lines(1:length(total_d_losses), total_d_losses)
fig

In [None]:
# Re-include AIRL.jl (Revise-tracked) before running it.
includet("AIRL.jl")

# Run AIRL on Pretraining_Env with the expert (state, action, next state)
# samples, this time seeded with pretraining_𝒫 via `initial_policy`.
bc_AIRL_pretraining_𝒫, score_history, e_losses, s_losses, total_d_losses = AIRL(
    Pretraining_Env,
    expert_states,
    expert_acts,
    expert_new_states;
    η=3f-5,
    iterations=50,
    n_steps=1,
    trajectory_batch_size=11,
    vf_ratio=5f-1,
    norm_adv=true,
    ϵ=1f-1,
    ent_ratio=1f-3,
    KL_targ=1f-2,
    initial_policy=pretraining_𝒫,
    clip_grad_tresh=1f3,
    use_log_rewards=false,
)

In [None]:
using CairoMakie
# Line plot of `score_history` against its 1-based entry index. `lines`
# creates the figure and the axis at fig[1, 1] in one call.
fig, ax = lines(1:length(score_history), score_history)
fig

In [None]:
using CairoMakie
# Line plot of `e_losses` against its 1-based entry index. `lines` creates the
# figure and the axis at fig[1, 1] in one call.
fig, ax = lines(1:length(e_losses), e_losses)
fig

In [None]:
using CairoMakie
# Line plot of `s_losses` against its 1-based entry index. `lines` creates the
# figure and the axis at fig[1, 1] in one call.
fig, ax = lines(1:length(s_losses), s_losses)
fig

In [None]:
using CairoMakie
# Line plot of `total_d_losses` against its 1-based entry index. `lines`
# creates the figure and the axis at fig[1, 1] in one call.
fig, ax = lines(1:length(total_d_losses), total_d_losses)
fig