In [1]:
using QuantumCollocation
using NamedTrajectories
using TrajectoryIndexingUtils
using Flux
using ReinforcementLearning
using IntervalSets
using LinearAlgebra
using Base
using Distributions
using Statistics
using Printf


In [2]:
# Environment used while a policy is fitted against a stored reference trajectory.
#
# Required keywords: `system`, `T` (number of time steps), `g` (target gate; its
# field `n` is the number of gate angles), `Δt`, `N` (number of grid points per
# angle on [0, 2π]) and `pretraining_trajectory`.
# The remaining fields hold the rollout state and start from the defaults below:
# identity operator in iso-vec form, one zero column for `a` and `da`, and an
# empty `dda` history with one row per drive.
Base.@kwdef mutable struct PretrainingGateEnv <: AbstractEnv
    system::AbstractQuantumSystem
    T::Int
    g::Gate
    Δt::Union{Float64, Vector{Float64}}
    N::Int64
    pretraining_trajectory::NamedTrajectory{Float64}

    # Actions are multiplied by this bound before being appended to `dda`.
    dda_bound::Float64 = 1.0
    current_op::AbstractVector{Float64} = operator_to_iso_vec(
        Matrix{ComplexF64}(I(size(system.H_drives[1], 1)))
    )
    # Normalized time; advanced by 1/T on every `act!`.
    time_step::Float64 = 1 / T

    a::AbstractMatrix{Float64} = zeros(length(system.H_drives), 1)
    da::AbstractMatrix{Float64} = zeros(length(system.H_drives), 1)
    dda::AbstractMatrix{Float64} = Matrix{Float64}(undef, length(system.H_drives), 0)
    # One uniformly drawn grid point per gate angle.
    angle::Vector{Float64} = range(0, 2 * pi, N)[rand(DiscreteUniform(1, N), g.n)]
end

# Environment used for training directly against the target gate.
#
# Required keywords: `system`, `T` (number of time steps), `g` (target gate; its
# field `n` is the number of gate angles) and `Δt`.
# The rollout state starts from the identity operator in iso-vec form, one zero
# column for `a` and `da`, an empty `dda` history, and angles drawn uniformly
# from [0, 2π].
Base.@kwdef mutable struct TrainingGateEnv <: AbstractEnv
    system::AbstractQuantumSystem
    T::Int
    g::Gate
    Δt::Union{Float64, Vector{Float64}}

    # Actions are multiplied by this bound before being appended to `dda`.
    dda_bound::Float64 = 1.0
    current_op::AbstractVector{Float64} = operator_to_iso_vec(
        Matrix{ComplexF64}(I(size(system.H_drives[1], 1)))
    )
    # Normalized time; advanced by 1/T on every `act!`.
    time_step::Float64 = 1 / T

    a::AbstractMatrix{Float64} = zeros(length(system.H_drives), 1)
    da::AbstractMatrix{Float64} = zeros(length(system.H_drives), 1)
    dda::AbstractMatrix{Float64} = Matrix{Float64}(undef, length(system.H_drives), 0)
    angle::Vector{Float64} = rand(Uniform(0, 2 * pi), g.n)
end

"""
    RLBase.is_terminated(env)

Return `true` once the normalized time `env.time_step` has reached `(T - 2) / T`.
"""
function RLBase.is_terminated(env::Union{PretrainingGateEnv,TrainingGateEnv})
    cutoff = (env.T - 2) / env.T
    return env.time_step >= cutoff
end

"""
    RLBase.action_space(env)

Product of one `-1..1` interval per drive Hamiltonian of `env.system`.
"""
function RLBase.action_space(env::Union{PretrainingGateEnv,TrainingGateEnv})
    n_drives = length(env.system.H_drives)
    return reduce(×, fill(-1..1, n_drives))
end

"""
    RLBase.state_space(env)

Product space matching the layout of `RLBase.state(env)`: one `-1..1` interval per
entry of `current_op`, `2 * n_drives` unbounded intervals, the time interval
`1/T..1`, and one `0..2π` interval per gate angle.
"""
function RLBase.state_space(env::Union{PretrainingGateEnv,TrainingGateEnv})
    op_box = reduce(×, fill(-1..1, length(env.current_op)))
    control_box = reduce(×, fill(-Inf..Inf, 2 * length(env.system.H_drives)))
    time_box = (1 / env.T)..1
    angle_box = reduce(×, fill(0..2*pi, env.g.n))
    return op_box × control_box × time_box × angle_box
end

"""
    RLBase.state(env)

Return the observation as a `Vector{Float32}`: the current operator (iso-vec), the
latest `da` column, the latest `a` column, the normalized time, and the gate angles,
concatenated in that order.
"""
function RLBase.state(env::Union{PretrainingGateEnv,TrainingGateEnv})
    observation = vcat(
        env.current_op, env.da[:, end], env.a[:, end], env.time_step, env.angle
    )
    return Vector{Float32}(observation)
end

"""
    RLBase.act!(env, action)

Advance `env` by one time step using `action` (scaled by `env.dda_bound`) as the new
second derivative of the controls.

The controls are integrated with forward-Euler updates: `a` is advanced with the
previous `da`, then `da` is advanced with the new `dda` column. `time_step` grows by
`1/T` and `current_op` is replaced by the last column returned by `unitary_rollout`.

When the step makes the environment terminal, three closing `dda` columns are
appended. The first two are chosen so that both `a` and `da` end at zero, and the
last is a zero column. The operator is then rolled out over the two extra `a`
columns.
"""
function RLBase.act!(env::Union{PretrainingGateEnv,TrainingGateEnv}, action::Vector{Float32})
    action = Vector{Float64}(action) * env.dda_bound
    env.dda = hcat(env.dda, action)
    # `a` must be advanced before `da` so that it uses the previous derivative.
    env.a = hcat(env.a, env.a[:, end] + env.da[:, end] * env.Δt)
    env.da = hcat(env.da, env.da[:, end] + env.dda[:, end] * env.Δt)

    env.time_step += 1 / env.T
    env.current_op = unitary_rollout(
        env.current_op, hcat(env.a[:, end], zeros(length(action))), env.Δt, env.system
    )[:, end]

    if RLBase.is_terminated(env)
        da0 = env.da[:, end]
        a0 = env.a[:, end]

        # Solves a0 + 2 * da0 * Δt + dda0 * Δt^2 == 0, i.e. `a` is zero two steps on.
        dda0 = (-a0 - da0 * 2 * env.Δt) / env.Δt^2
        env.dda = hcat(env.dda, dda0)
        env.a = hcat(env.a, env.a[:, end] + env.da[:, end] * env.Δt)
        env.da = hcat(env.da, env.da[:, end] + env.dda[:, end] * env.Δt)

        # Solves da0 + dda0 * Δt + dda1 * Δt == 0, i.e. `da` is zero two steps on.
        dda1 = (-da0 - dda0 * env.Δt) / env.Δt
        env.dda = hcat(env.dda, dda1)
        env.a = hcat(env.a, env.a[:, end] + env.da[:, end] * env.Δt)
        env.da = hcat(env.da, env.da[:, end] + env.dda[:, end] * env.Δt)

        # Use the environment's own system rather than a global `system` variable.
        env.dda = hcat(env.dda, zeros(length(env.system.H_drives)))

        env.current_op = unitary_rollout(
            env.current_op,
            hcat(env.a[:, end-1:end], zeros(length(action))),
            env.Δt,
            env.system,
        )[:, end]
    end
end

"""
    RLBase.reset!(env::PretrainingGateEnv; angle=nothing)

Restore `env` to its initial rollout state: identity operator in iso-vec form,
`time_step = 1/T`, a single zero column for `a` and `da`, and an empty `dda` history.

If `angle` is `nothing`, one grid point of `range(0, 2π, env.N)` is drawn uniformly
for each of the `env.g.n` gate angles; otherwise `angle` is stored as given.
All sizes are read from `env` itself, not from global variables.
"""
function RLBase.reset!(env::PretrainingGateEnv; angle::Union{Vector{Float64},Nothing}=nothing)
    n_drives = length(env.system.H_drives)
    env.current_op = operator_to_iso_vec(
        Matrix{ComplexF64}(I(size(env.system.H_drives[1], 1)))
    )
    env.time_step = 1 / env.T

    env.a = zeros(n_drives, 1)
    env.da = zeros(n_drives, 1)
    env.dda = Matrix{Float64}(undef, n_drives, 0)
    env.angle = if isnothing(angle)
        [range(0, 2 * pi, env.N)[i] for i in rand(DiscreteUniform(1, env.N), env.g.n)]
    else
        angle
    end
end

"""
    RLBase.reset!(env::TrainingGateEnv; angle=nothing)

Restore `env` to its initial rollout state: identity operator in iso-vec form,
`time_step = 1/T`, a single zero column for `a` and `da`, and an empty `dda` history.

If `angle` is `nothing`, `env.g.n` angles are drawn uniformly from [0, 2π];
otherwise `angle` is stored as given. The operator size is read from `env.system`,
not from a global `system` variable.
"""
function RLBase.reset!(env::TrainingGateEnv; angle::Union{Vector{Float64},Nothing}=nothing)
    n_drives = length(env.system.H_drives)
    env.current_op = operator_to_iso_vec(
        Matrix{ComplexF64}(I(size(env.system.H_drives[1], 1)))
    )
    env.time_step = 1 / env.T

    env.a = zeros(n_drives, 1)
    env.da = zeros(n_drives, 1)
    env.dda = Matrix{Float64}(undef, n_drives, 0)
    env.angle = isnothing(angle) ? rand(Uniform(0, 2 * pi), env.g.n) : angle
end

# Gaussian policy made of two networks:
# - `mean_network` is evaluated on the full state vector and yields the action means;
# - `std_network` is evaluated on the trailing gate-angle entries of the state, and
#   the exponential of its first output is used as a shared standard deviation.
struct GatePolicy
    mean_network::Chain
    std_network::Chain
end

"""
    GatePolicy(env; l=[16, 16])

Build a `GatePolicy` sized for `env`. `l` lists the hidden-layer widths shared by
both networks. The mean network maps the full state to one `softsign` output per
drive; the std network maps the `env.g.n` gate angles to a single unbounded output.
"""
function GatePolicy(env::Union{PretrainingGateEnv,TrainingGateEnv};l::Vector{Int64}=[16,16])
    n_actions = length(env.system.H_drives)
    n_state = length(RLBase.state(env))
    n_angles = env.g.n

    # Each call creates a fresh ReLU stack so the two networks share no parameters.
    hidden_stack() = [Dense(l[k] => l[k+1], relu) for k in 1:length(l)-1]

    mean_network = Chain(
        Dense(n_state => l[1], relu), hidden_stack()..., Dense(l[end] => n_actions, softsign)
    )
    std_network = Chain(
        Dense(n_angles => l[1], relu), hidden_stack()..., Dense(l[end] => 1)
    )

    return GatePolicy(mean_network, std_network)
end

# Sample an action for the current state of `env`.
# With `deterministic=true` the mean-network output is returned unchanged; otherwise
# zero-mean Gaussian noise is added, whose standard deviation is the exponential of
# the std network's first output on the trailing `env.g.n` state entries.
function (Policy::GatePolicy)(env::Union{PretrainingGateEnv,TrainingGateEnv}; deterministic::Bool = false)
    observation = Vector{Float32}(RLBase.state(env))
    μ = Policy.mean_network(observation)
    deterministic && return μ

    angle_features = observation[end-env.g.n+1:end]
    σ = exp(Policy.std_network(angle_features)[1])
    return μ + rand(Normal(0, σ), length(μ))
end

# function(Policy::GatePolicy)(state::Vector{Float32}; deterministic::Bool = false)
#     means = Policy.mean_network(state)
#     if(!deterministic)
#         std = exp(Policy.std_network(state[end-env.g.n+1:end])[1])
#         return means+rand(Normal(0,std),length(means))
#     else
#         return means
#     end        
# end

"""
    deepcopy(Policy::GatePolicy)

Return a new `GatePolicy` whose two networks are independent deep copies of the
originals.

This is defined as a method of `Base.deepcopy`. An unqualified
`function deepcopy(...)` would create a separate `Main.deepcopy` that shadows the
Base function for every other type, or fail if `deepcopy` had already been used.
"""
function Base.deepcopy(Policy::GatePolicy)
    return GatePolicy(
        Base.deepcopy(Policy.mean_network), Base.deepcopy(Policy.std_network)
    )
end

"""
    policy_prob(policy, state, action)

Density of `action` under the policy's diagonal Gaussian at `state`, returned as
`Float32`. The mean comes from `policy.mean_network(state)`; the shared standard
deviation is the exponential of the std network's first output, evaluated on the
last `n` entries of `state`, where `n` is the std network's input width.
"""
function policy_prob(policy::GatePolicy,state::Vector{Float32},action::Vector{Float32})
    n_angles = size(policy.std_network.layers[1].weight)[end]
    μ = policy.mean_network(state)
    σ = exp(policy.std_network(state[end-n_angles+1:end])[1])
    variance = σ^2
    residual = action - μ
    densities = exp.(-residual.^2/(2*variance))*1/sqrt(2 * pi * variance)
    return Float32(reduce(*, densities))
end

"""
    policy_log_prob(policy, state, action)

Log-density of `action` under the policy's diagonal Gaussian at `state`, returned as
`Float32`. The mean and the shared standard deviation are obtained exactly as in
`policy_prob`; the per-dimension log-densities are summed.
"""
function policy_log_prob(policy::GatePolicy,state::Vector{Float32},action::Vector{Float32})
    n_angles = size(policy.std_network.layers[1].weight)[end]
    μ = policy.mean_network(state)
    σ = exp(policy.std_network(state[end-n_angles+1:end])[1])
    variance = σ^2
    residual = action - μ
    per_dimension = (-residual.^2/(2*variance)).-1/2 * log(2 * pi * variance)
    return Float32(sum(per_dimension))
end

# Register GatePolicy with Flux so that `Flux.setup`, `Flux.withgradient` and
# `Flux.update!` (used in `clip_optimize`) can reach the parameters of both networks.
Flux.@functor GatePolicy

"""
    RLBase.reward(env::PretrainingGateEnv; action=nothing, S, S_a=S, S_da=S, S_dda=S)

Negative weighted squared distance between the rollout and the stored reference
trajectory for the current angle.

The reference is selected by converting each angle to its 1-based grid index and
combining the indices into a single mixed-radix number `idx`; the components are
read from `env.pretraining_trajectory` under the names `a<idx>`, `da<idx>` and
`dda<idx>`.

- Before termination, `action` is required. The latest `a` and `da` columns and the
  scaled action are compared with reference column `t = round(time_step * T)`.
- After termination, the last three columns of `a`, `da` and `dda` are compared with
  the last three reference columns.

Every term is scaled by `env.Δt^2 / 2` and its weight (`S` defaults to `2 / Δt^2`).
The grid index is rounded to the nearest integer because the floating-point
division is not guaranteed to be exact.
"""
function RLBase.reward(env::PretrainingGateEnv;
                action::Union{AbstractVector{Float32},Nothing}=nothing,
                S::Float64=2/(env.Δt)^2,
                S_a::Float64=S,
                S_da::Float64=S,
                S_dda::Float64=S)
    Δt = env.Δt
    n_angles = env.g.n

    grid = round.(Int, env.angle .* (env.N - 1) ./ (2 * pi) .+ 1)
    place_values = [env.N^(n_angles - i) for i in 1:n_angles-1]
    idx = sum((grid[1:n_angles-1] .- 1) .* place_values) + grid[end]

    ref_a = env.pretraining_trajectory[Symbol("a" * string(idx))]
    ref_da = env.pretraining_trajectory[Symbol("da" * string(idx))]
    ref_dda = env.pretraining_trajectory[Symbol("dda" * string(idx))]

    if !RLBase.is_terminated(env)
        isnothing(action) &&
            throw(ArgumentError("`action` is required before the episode terminates"))
        t = round(Int, env.time_step * env.T)
        scaled_action = Vector{Float64}(action) * env.dda_bound
        return -(
            sum((env.a[:, end] - ref_a[:, t]) .^ 2) * Δt^2 / 2 * S_a +
            sum((env.da[:, end] - ref_da[:, t]) .^ 2) * Δt^2 / 2 * S_da +
            sum((scaled_action - ref_dda[:, t]) .^ 2) * Δt^2 / 2 * S_dda
        )
    else
        return -(
            sum((env.a[:, end-2:end] - ref_a[:, end-2:end]) .^ 2) * Δt^2 / 2 * S_a +
            sum((env.da[:, end-2:end] - ref_da[:, end-2:end]) .^ 2) * Δt^2 / 2 * S_da +
            sum((env.dda[:, end-2:end] - ref_dda[:, end-2:end]) .^ 2) * Δt^2 / 2 * S_dda
        )
    end
end


"""
    RLBase.reward(env::TrainingGateEnv; action=nothing, R, Q, R_a=R, R_da=R, R_dda=R)

Negative control-regularization cost, plus a gate-infidelity penalty at termination.

- Before termination, `action` is required. The squared norms of the latest `a` and
  `da` columns and of the scaled action are penalized.
- After termination, the last three columns of `a`, `da` and `dda` are penalized and
  `Q * (1 - |tr(U' * G)| / d)` is added, where `U` is the current operator, `G` is
  `env.g(env.angle)` and `d` is the size of the first drive Hamiltonian.

Every regularization term is scaled by `env.Δt^2 / 2` and its weight (`R` defaults to
`2 / Δt^2`, `Q` to `1e4 * 2 / Δt^2`).
"""
function RLBase.reward(env::TrainingGateEnv;
                action::Union{AbstractVector{Float32},Nothing}=nothing,
                R::Float64=2/(env.Δt)^2,
                Q::Float64=1e4 * 2/(env.Δt)^2,
                R_a::Float64=R,
                R_da::Float64=R,
                R_dda::Float64=R)
    Δt = env.Δt

    if !RLBase.is_terminated(env)
        isnothing(action) &&
            throw(ArgumentError("`action` is required before the episode terminates"))
        scaled_action = Vector{Float64}(action) * env.dda_bound
        return -(
            sum(env.a[:, end] .^ 2) * Δt^2 / 2 * R_a +
            sum(env.da[:, end] .^ 2) * Δt^2 / 2 * R_da +
            sum(scaled_action .^ 2) * Δt^2 / 2 * R_dda
        )
    else
        overlap = tr(iso_vec_to_operator(env.current_op)' * env.g(env.angle))
        infidelity = 1 - abs(overlap / size(env.system.H_drives[1])[1])
        return -(
            sum(env.a[:, end-2:end] .^ 2) * Δt^2 / 2 * R_a +
            sum(env.da[:, end-2:end] .^ 2) * Δt^2 / 2 * R_da +
            sum(env.dda[:, end-2:end] .^ 2) * Δt^2 / 2 * R_dda +
            Q * infidelity
        )
    end
end

"""
    getTrajectoryLoss(env::PretrainingGateEnv; S, S_a=S, S_da=S, S_dda=S)

Negative weighted squared distance between the full `a`, `da` and `dda` histories
of `env` and the reference trajectory components `a<idx>`, `da<idx>`, `dda<idx>`
selected by the current angle.

Each term is scaled by `env.Δt^2 / 2` and its weight (`S` defaults to `2 / Δt^2`).
The grid index is rounded to the nearest integer because the floating-point
division is not guaranteed to be exact.
"""
function getTrajectoryLoss(env::PretrainingGateEnv;
                S::Float64=2/(env.Δt)^2,
                S_a::Float64=S,
                S_da::Float64=S,
                S_dda::Float64=S)
    Δt = env.Δt
    n_angles = env.g.n

    grid = round.(Int, env.angle .* (env.N - 1) ./ (2 * pi) .+ 1)
    place_values = [env.N^(n_angles - i) for i in 1:n_angles-1]
    idx = sum((grid[1:n_angles-1] .- 1) .* place_values) + grid[end]

    ref_a = env.pretraining_trajectory[Symbol("a" * string(idx))]
    ref_da = env.pretraining_trajectory[Symbol("da" * string(idx))]
    ref_dda = env.pretraining_trajectory[Symbol("dda" * string(idx))]

    return -(
        sum((env.a - ref_a) .^ 2) * Δt^2 / 2 * S_a +
        sum((env.da - ref_da) .^ 2) * Δt^2 / 2 * S_da +
        sum((env.dda - ref_dda) .^ 2) * Δt^2 / 2 * S_dda
    )
end

"""
    getTrajectoryLoss(env::TrainingGateEnv; R, Q, R_a=R, R_da=R, R_dda=R)

Negative total cost of the rollout stored in `env`: the squared norms of the full
`a`, `da` and `dda` histories (each scaled by `env.Δt^2 / 2` and its weight) plus
`Q * (1 - |tr(U' * G)| / d)`, where `U` is the current operator, `G` is
`env.g(env.angle)` and `d` is the size of the first drive Hamiltonian.
"""
function getTrajectoryLoss(env::TrainingGateEnv;
                R::Float64=2/(env.Δt)^2,
                Q::Float64=1e4*2/(env.Δt)^2,
                R_a::Float64=R,
                R_da::Float64=R,
                R_dda::Float64=R)
    Δt = env.Δt

    reg = sum(env.a .^ 2) * Δt^2 / 2 * R_a +
          sum(env.da .^ 2) * Δt^2 / 2 * R_da +
          sum(env.dda .^ 2) * Δt^2 / 2 * R_dda
    overlap = tr(iso_vec_to_operator(env.current_op)' * env.g(env.angle))
    infidelity = 1 - abs(overlap / size(env.system.H_drives[1])[1])
    return -(reg + Q * infidelity)
end

"""
    SampleTrajectory(Policy, env; deterministic=false, angle=nothing, kwargs...)

Reset `env` and roll out one episode with `Policy`.

At each step the state is recorded, an action is drawn (or taken as the policy mean
when `deterministic=true`), the step reward is computed, and the action is applied.
After the loop the terminal reward is added to the last step's reward.

# Keywords
- `deterministic`: forwarded to the policy call.
- `angle`: forwarded to `RLBase.reset!`. With `nothing` (the default) the
  environment draws its own angle; pass a vector to evaluate a specific angle.
- `kwargs...`: forwarded to every `RLBase.reward` call.

# Returns
`(rewards, actions, states)` as `Vector{Float32}`, `Vector{Vector{Float32}}` and
`Vector{Vector{Float32}}`.
"""
function SampleTrajectory(Policy::GatePolicy,env::Union{PretrainingGateEnv,TrainingGateEnv};deterministic::Bool=false, angle::Union{Vector{Float64},Nothing}=nothing, kwargs...)
    RLBase.reset!(env; angle=angle)
    # Rewards are accumulated in Float64 and narrowed once at the end.
    rewards = Vector{Float64}()
    actions = Vector{Vector{Float32}}()
    states = Vector{Vector{Float32}}()
    while !RLBase.is_terminated(env)
        push!(states, RLBase.state(env))
        action = Policy(env; deterministic=deterministic)
        push!(actions, action)
        # The step reward is evaluated before the action is applied.
        push!(rewards, RLBase.reward(env; action=action, kwargs...))
        RLBase.act!(env, action)
    end
    rewards[end] += RLBase.reward(env; kwargs...)
    return Vector{Float32}(rewards), actions, states
end

"""
    euler(dda, n_steps, Δt)

Reconstruct the control history `a` from its second derivative `dda` by two
forward-Euler integrations, one row per control.

`da` is first integrated from zero, then shifted by a constant equal to minus the
sum of its first `n_steps - 1` columns divided by `n_steps - 1`. `a` is the
integral of the shifted `da`, starting from zero. `n_steps` must equal the number
of columns of `dda`.
"""
function euler(dda::Matrix{Float64},n_steps::Int64,Δt::Float64)
    n_controls = size(dda, 1)
    origin = zeros(n_controls)

    da_from_zero = hcat(origin, cumsum(dda[:, 1:end-1] * Δt; dims=2))
    da_init = -sum(da_from_zero[:, 1:end-1]; dims=2) / (n_steps - 1)
    da = da_from_zero + repeat(da_init, 1, n_steps)

    return hcat(origin, cumsum(da[:, 1:end-1] * Δt; dims=2))
end

euler (generic function with 1 method)

In [3]:
# Reference trajectories used by the pretraining environment.
RZ_traj = load_traj("RZ_pretrained.jld2")

# Unit system: with Units = 1e9, one nanosecond equals 1.0 and one GHz equals 1.0.
const Units = 1e9
const MHz = 1e6 / Units
const GHz = 1e9 / Units
const ns = 1e-9 * Units
const μs = 1e-6 * Units
;


# Operators
# NOTE(review): the "Y" entry below is [0 im; -im 0], which is the negative of the
# usual Pauli-Y convention [0 -im; im 0] — confirm this sign is intended.
const Paulis = Dict(
    "I" => Matrix{ComplexF64}([1 0; 0 1]),
    "X" => Matrix{ComplexF64}([0 1; 1 0]),
    "Y" => Matrix{ComplexF64}([0 im; -im 0]),
    "Z" => Matrix{ComplexF64}([1 0; 0 -1]),
)

# Target operator: rotation about Z by the first entry of `theta`.
rz_op(theta) = exp(-im/2 * theta[1] * Paulis["Z"]);

# Gate with one angle parameter.
RZ = Gate(1,rz_op)

# Two drive Hamiltonians (X and Y), so every action has two components.
H_drives = [
     Paulis["X"],Paulis["Y"]
]
system = QuantumSystem(H_drives);
t_f = 10* ns
n_steps = 51
times = range(0, t_f, n_steps)  # Alternative: collect(0:Δt:t_f)
# NOTE(review): n_controls is 1 although H_drives has two entries; neither
# n_controls nor n_qubits is read elsewhere in the visible notebook.
n_controls=1
n_qubits=1;
Δt = times[2] - times[1]

# Number of grid points per angle for the pretraining environment.
N = 11
;

In [4]:
# Environment that scores rollouts against the stored RZ reference trajectories.
Pretraining_Env = PretrainingGateEnv(
                                    system = system,
                                    Δt=Δt,
                                    T=n_steps,
                                    g=RZ,
                                    N=11,
                                    pretraining_trajectory=RZ_traj;
                                    dda_bound=1.0
                                    )

# Environment that scores rollouts by control regularization and gate infidelity.
Training_Env = TrainingGateEnv(
                            system = system,
                            Δt=Δt,
                            T=n_steps,
                            g=RZ;
                            dda_bound=1.0
                            );

# Policy with the default hidden widths; both environments yield the same state
# length here, so the policy can be used with either.
policy = GatePolicy(Training_Env)


GatePolicy(Chain(Dense(14 => 16, relu), Dense(16 => 16, relu), Dense(16 => 2, softsign)), Chain(Dense(1 => 16, relu), Dense(16 => 16, relu), Dense(16 => 1)))

In [5]:
# Roll out one stochastic episode in the pretraining environment and keep the
# resulting control histories for the consistency checks in the next cells.
rewards,acts,states = SampleTrajectory(policy,Pretraining_Env;)
a=Pretraining_Env.a
da=Pretraining_Env.da
dda=Pretraining_Env.dda
;

In [6]:
# Consistency check: whole-trajectory loss minus the sum of the per-step rewards
# (which SampleTrajectory returns as Float32).
getTrajectoryLoss(Pretraining_Env;)-sum(rewards)

-0.00782041106140241

In [7]:
# Consistency check: `a` rebuilt from `dda` by `euler` should match the stored `a`.
sum((euler(dda,n_steps,Δt)-a).^2)

1.4236188424877247e-28

In [8]:
# Consistency check: a single rollout over the full `a` history from the identity
# should reproduce the operator accumulated step by step in the environment.
unitary_rollout(operator_to_iso_vec([1+0.0im 0; 0 1]),a,Δt,system)[:,end]-Pretraining_Env.current_op

8-element Vector{Float64}:
 -3.3306690738754696e-16
  0.0
 -1.1102230246251565e-16
  1.1102230246251565e-16
  0.0
 -3.3306690738754696e-16
  1.1102230246251565e-16
  1.1102230246251565e-16

In [9]:
# Repeat the rollout in the training environment and keep its control histories
# for the same three consistency checks.
rewards,acts,states = SampleTrajectory(policy,Training_Env;)
a=Training_Env.a
da=Training_Env.da
dda=Training_Env.dda
;

In [10]:
# Consistency check: whole-trajectory loss minus the sum of the per-step rewards
# (which SampleTrajectory returns as Float32).
getTrajectoryLoss(Training_Env;)-sum(rewards)

0.007854093215428293

In [11]:
# Consistency check: `a` rebuilt from `dda` by `euler` should match the stored `a`.
sum((euler(dda,n_steps,Δt)-a).^2)

3.7695618774468643e-28

In [12]:
# Consistency check: a single rollout over the full `a` history from the identity
# should reproduce the operator accumulated step by step in the environment.
unitary_rollout(operator_to_iso_vec([1+0.0im 0; 0 1]),a,Δt,system)[:,end]-Training_Env.current_op

8-element Vector{Float64}:
 -1.1102230246251565e-16
 -3.3306690738754696e-16
  4.718447854656915e-16
 -1.2906342661267445e-15
  3.3306690738754696e-16
 -1.1102230246251565e-16
 -1.2906342661267445e-15
 -4.718447854656915e-16

In [13]:
"""
    ValueNetwork(env; l=[16, 16])

Build a value network for `env`: a `Chain` that maps the state vector through ReLU
hidden layers of widths `l` to a single unbounded output.
"""
function ValueNetwork(env::Union{PretrainingGateEnv,TrainingGateEnv};l::Vector{Int64}=[16,16])
    n_state = length(RLBase.state(env))
    input_layer = Dense(n_state => l[1], relu)
    hidden_layers = [Dense(l[k] => l[k+1], relu) for k in 1:length(l)-1]
    output_layer = Dense(l[end] => 1)
    return Chain(input_layer, hidden_layers..., output_layer)
end

ValueNetwork (generic function with 1 method)

In [14]:
"""
    discount_cumsum(l, γ)

Return the discounted tail sums of `l`: entry `j` is `sum(l[j+i] * γ^i)` over all
`i >= 0` that stay inside `l`.
"""
function discount_cumsum(l::Vector{Float32}, γ::Float32)
    return map(eachindex(l)) do j
        weights = γ .^ (0:length(l)-j)
        sum(l[j:end] .* weights)
    end
end
"""
    GAE(states, rewards, actions, VNN; γ=0.99f0, λ=0.97f0)

Return `(rewards_to_go, advantages)` for one trajectory.

`rewards_to_go` is the `γ`-discounted tail sum of `rewards`. The advantages are the
`γ * λ`-discounted tail sums of the TD residuals `r + γ * V(s') - V(s)`, where the
value after the final state is taken equal to the value of the final state.
`actions` is accepted but not used.
"""
function GAE(states::Vector{Vector{Float32}},rewards::Vector{Float32},actions::Vector{Vector{Float32}},VNN::Chain;γ::Float32=Float32(0.99),λ::Float32=Float32(0.97))
    values = [VNN(s)[1] for s in states]
    # Bootstrap the value after the last state with the last state's own value.
    padded = vcat(values, values[end])
    td_residual = rewards + γ * padded[2:end] - padded[1:end-1]
    returns = discount_cumsum(rewards, γ)
    advantages = discount_cumsum(td_residual, γ * λ)
    return returns, advantages
end
"""
    FitValueVNN(VNN, states_list, rewards_to_go; max_iter=1000, lr=1f-3, tol=1f-4, batchsize=32)

Fit `VNN` in place to regress `rewards_to_go` from `states_list` (one sample per
column) with Adam and a mean-squared-error loss.

Training runs for at most `max_iter` epochs and stops early once more than two
epochs have been recorded and the last two epoch losses differ by at most `tol`.
Returns the vector of mean batch losses, one entry per epoch.
"""
function FitValueVNN(VNN::Chain,states_list::Matrix{Float32},rewards_to_go::Matrix{Float32}; max_iter::Int64 = 1000, lr::Float32 = 1f-3,tol::Float32=1f-4, batchsize::Int64=32)
    mse(prediction, target) = mean(abs2.(prediction .- target))
    optimiser = Flux.setup(Adam(lr), VNN)
    batches = Flux.DataLoader((states_list, rewards_to_go), batchsize=batchsize)
    history = Vector{Float32}()

    for _ in 1:max_iter
        per_batch = Vector{Float32}()
        for (inputs, targets) in batches
            batch_loss, grads = Flux.withgradient(net -> mse(net(inputs), targets), VNN)
            # The loss value comes from the forward pass of the gradient call.
            push!(per_batch, batch_loss)
            Flux.update!(optimiser, VNN, grads[1])
        end
        push!(history, mean(per_batch))

        plateaued = length(history) > 2 && abs(history[end] - history[end-1]) <= tol
        plateaued && break
    end
    return history
end

"""
    g(ϵ, A)

Clipping bound used by `clip_optimize`: `(1 + ϵ) * A` for non-negative `A`,
`(1 - ϵ) * A` otherwise.
"""
function g(ϵ, A)
    if A >= 0
        return (1 + ϵ) * A
    else
        return (1 - ϵ) * A
    end
end

"""
    clip_optimize(policy, rewards_to_go, states_list, acts_list, Adv_list; kwargs...)

Update `policy` with Adam on a clipped surrogate objective and return
`(epoch_losses, KL)`.

All matrix arguments hold one sample per column. `rewards_to_go` is only used for
its length (number of samples) and is carried through the batches; it does not
enter the loss.

# Keywords
- `ϵ`: clipping parameter passed to `g`.
- `max_iter`: maximum number of epochs over the data.
- `lr`: Adam learning rate.
- `tol`: stop once more than two epochs are recorded and the last two epoch losses
  differ by at most `tol`.
- `targ_KL`: stop once the KL estimate reaches `1.5 * targ_KL`.
- `batchsize`: mini-batch size of the data loader.
"""
function clip_optimize(policy::GatePolicy,
                      rewards_to_go::Matrix{Float32},
                      states_list::Matrix{Float32},
                      acts_list::Matrix{Float32},
                      Adv_list::Matrix{Float32};
                      ϵ::Float32 = 1f-1,
                      max_iter::Int64 = 1000, lr::Float32 = 1f-3,tol::Float32=1f-4, 
        targ_KL::Float32 = 1f-2,batchsize::Int64=32)
    # Number of action components; used to split states from actions in a batch.
    n = size(acts_list)[1]
    # Frozen copy of the policy that defines the reference log-probabilities.
    old_policy = deepcopy(policy)
    epoch_losses = Vector{Float32}()
    opt_state = Flux.setup(Adam(lr), policy)
    old_log_probs = reshape([policy_log_prob(old_policy,states_list[:,i],acts_list[:,i]) for i in 1:length(rewards_to_go)],size(Adv_list)...)
    # First tuple entry stacks rewards-to-go, advantages and old log-probs as rows;
    # the second stacks states on top of actions.
    data = Flux.DataLoader((reduce(vcat,[rewards_to_go,Adv_list,old_log_probs]), vcat(states_list,acts_list)), batchsize=batchsize)
    KL = 0
    for epoch in 1:max_iter
        batch_losses = Vector{Float32}()
        for (x,y) in data
            batch_rewards_to_go=x[1,:]
            batch_Adv=x[2,:]
            batch_old_log_probs=x[3,:]

            batch_states = y[1:end-n,:]
            batch_acts = y[end-n+1:end,:]

            # Loss: negative mean of min(ratio * A, g(ϵ, A)), where ratio is
            # exp(new log-prob - old log-prob) for each sample in the batch.
            val, grads = Flux.withgradient(policy) do policy
                -mean(minimum([exp.([policy_log_prob(policy,batch_states[:,i],batch_acts[:,i]) for i in 1:length(batch_rewards_to_go)].-batch_old_log_probs).*batch_Adv g.(ϵ,batch_Adv)],dims=2))
            end
            push!(batch_losses, val)
            Flux.update!(opt_state, policy, grads[1])   
        end
        push!(epoch_losses, mean(batch_losses))
        if(length(epoch_losses)>2 && abs(epoch_losses[end]-epoch_losses[end-1]) <= tol)
            break
        end   
        # KL estimate over all samples: mean((exp(r) - 1) - r) with r the log ratio.
        new_log_probs = reshape([policy_log_prob(policy,states_list[:,i],acts_list[:,i]) for i in 1:length(rewards_to_go)],size(Adv_list)...)
        log_ratio = new_log_probs.-old_log_probs
        KL = mean((exp.(log_ratio).- 1).-log_ratio)
        if((KL) >= 15f-1 * targ_KL)
            break
        end
    end
    return epoch_losses,KL
end


"""
    PPO(env; kwargs...)

Train a `GatePolicy` and a value network on `env` and return `(policy, VNN)`.

Each epoch samples `trajectory_batch_size` trajectories with the current policy,
computes rewards-to-go and advantages with `GAE`, standardizes the advantages, then
runs `clip_optimize` on the policy and `FitValueVNN` on the value network.

# Keywords
- `trajectory_batch_size`: trajectories sampled per epoch.
- `epochs`: number of epochs.
- `initial_policy`, `initial_VNN`: networks to continue training; created from `env`
  and `l` when `nothing`. Both optimizers are given these objects directly.
- `l`: hidden-layer widths for newly created networks.
- `max_iter`: maximum inner epochs for both the policy and the value fit.
- `vf_fit_lr`, `pi_fit_lr`: learning rates of the value and policy fits.
- `fit_batch_size`: mini-batch size for both fits.
- `ϵ`: clipping parameter forwarded to `clip_optimize`.
- `verbose`: print per-epoch diagnostics.
- `γ`, `λ`: discount parameters forwarded to `GAE`.
- `tol`: loss-change tolerance forwarded to both fits; the default of -1 can never
  be met by an absolute difference, so that early stop is disabled by default.
- `trajectory_kwargs...`: forwarded to `SampleTrajectory`.
"""
function PPO(env::Union{PretrainingGateEnv,TrainingGateEnv};
             trajectory_batch_size::Int64=20, 
             epochs::Int64=100,
             initial_policy::Union{GatePolicy,Nothing}=nothing,
             initial_VNN::Union{Chain,Nothing}=nothing,
             l::Vector{Int64}=[16,16,16],
             max_iter::Int64 = 100, 
             vf_fit_lr::Float32 = 1f-2,
             pi_fit_lr::Float32 = 3f-5,
             fit_batch_size::Int64=32,
             ϵ::Float32 = 2f-1,
             verbose::Bool=true,
             γ::Float32=Float32(0.99),
             λ::Float32=Float32(0.97),
            tol::Float32=Float32(-1),#1f-4,
             trajectory_kwargs...)

    policy = isnothing(initial_policy) ? GatePolicy(env;l=l) : initial_policy
    VNN = isnothing(initial_VNN) ? ValueNetwork(env;l=l) : initial_VNN

    for epoch in 1:epochs
        # Per-epoch buffers; samples from all trajectories are concatenated.
        rewards_to_go = Vector{Vector{Float32}}()
        Adv_list =   Vector{Vector{Float32}}()
        acts_list = Vector{Vector{Float32}}()
        states_list = Vector{Vector{Float32}}()
        for i in 1:trajectory_batch_size
            rewards,acts,states = SampleTrajectory(policy,env;trajectory_kwargs...)
            rtg,adv = GAE(states,rewards,acts,VNN;γ=γ,λ=λ)
            
            rewards_to_go = vcat(rewards_to_go,rtg)
            Adv_list = vcat(Adv_list,adv)
            
            states_list=vcat(states_list,states)
            acts_list=vcat(acts_list,acts)
        end
        
        # Convert to matrices with one sample per column.
        acts_list = Matrix{Float32}([acts_list[j][i] for i=1:size(acts_list[1])[1], j=1:size(acts_list)[1]])
        states_list = Matrix{Float32}([states_list[j][i] for i=1:size(states_list[1])[1], j=1:size(states_list)[1]])
        rewards_to_go =  Matrix{Float32}(reshape(rewards_to_go,1,length(rewards_to_go)))
        Adv_list =  Matrix{Float32}(reshape(Adv_list,1,length(Adv_list)))
        # Standardize advantages to zero mean and unit standard deviation.
        Adv_list = (Adv_list.-mean(Adv_list))/std(Adv_list)
        
        # The policy step uses clip_optimize's default `targ_KL`.
        policy_losses,KL = clip_optimize(policy,rewards_to_go,states_list,acts_list,Adv_list;ϵ=ϵ,max_iter= max_iter, lr = pi_fit_lr, tol = tol, batchsize=fit_batch_size)
        value_losses =FitValueVNN(VNN,states_list,rewards_to_go; max_iter= max_iter, lr= vf_fit_lr, tol = tol,batchsize=fit_batch_size)
        if(verbose)
            @printf "Epoch %i Complete\n" epoch
            @printf "Mean Rtg: %.2f\n" mean(rewards_to_go)
            @printf "KL: %.2f\n" KL
            @printf "Policy Iters: %i\n" length(policy_losses)
            @printf "Value Iters: %i\n" length(value_losses)
            @printf "Value Loss 1: %.2f\n" value_losses[1]
            @printf "Value Loss end: %.2f\n" value_losses[end]

            println("-------------------------")
            flush(stdout)
        end
    end
    return policy,VNN
end

PPO (generic function with 1 method)

In [None]:
# Pretrain a fresh policy and value network against the reference trajectories;
# this rebinds the global `policy` created earlier.
policy,vnn = PPO(Pretraining_Env;epochs=250,)

In [None]:
# Plot the raw std-network output over the angle range. The policy exponentiates
# this output to obtain the standard deviation, so the curve is the log-std.
x = range(0,2*pi,1000)
y = [policy.std_network([v])[1] for v in x]
using CairoMakie
lines(x,y)

In [None]:
# Reset the environment to the 5th of the 11 angle grid points.
# NOTE(review): a SampleTrajectory call that is not given an angle resets the
# environment itself, so the angle chosen here is replaced by a fresh random draw
# at the start of such a call.
RLBase.reset!(Pretraining_Env,angle = [range(0,2*pi,11)[5]])

In [None]:
# Deterministic rollout (policy means only), then returns and advantages for it.
# NOTE(review): no angle is passed here, so the rollout runs at whatever angle the
# internal reset draws — confirm before comparing against a specific reference.
rewards,acts,states =SampleTrajectory(policy,Pretraining_Env;deterministic=true)
rtg,adv = GAE(states,rewards,acts,vnn);

In [None]:
# NOTE(review): `get_t` is not defined anywhere in the visible notebook; this looks
# like an unfinished identifier — confirm or remove the cell.
get_t

In [None]:
# Displays the function object only; nothing is computed in this cell.
getTrajectoryLoss

In [None]:
# Total reward of the most recent rollout.
sum(rewards)

In [None]:
# Reference control component `a5` of the loaded trajectory.
RZ_traj[:a5]

In [None]:
# Compare the first control's `a` from the last rollout with the reference `a5`,
# and plot their squared difference.
# NOTE(review): the comparison is only meaningful if the rollout used the angle
# that `a5` belongs to — confirm.
fig = Figure()
ax = Axis(fig[1,1])
lines!(ax,1:n_steps,Pretraining_Env.a[1,:])
lines!(ax,1:n_steps,RZ_traj[:a5][1,:])
lines!(ax,1:n_steps,abs.(RZ_traj[:a5][1,:]-Pretraining_Env.a[1,:]).^2)

fig

In [None]:
# Compare the first control's `da` from the last rollout with the reference `da5`,
# and plot their squared difference.
# NOTE(review): the comparison is only meaningful if the rollout used the angle
# that `da5` belongs to — confirm.
fig = Figure()
ax = Axis(fig[1,1])
lines!(ax,1:n_steps,Pretraining_Env.da[1,:])
lines!(ax,1:n_steps,RZ_traj[:da5][1,:])
lines!(ax,1:n_steps,abs.(RZ_traj[:da5][1,:]-Pretraining_Env.da[1,:]).^2)

fig