Neural optimal control documentation #174

ChrisRackauckas opened this issue Feb 28, 2020 · 9 comments



Here is a classical optimal control problem in which the control u(t) is a neural network that gets trained:

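using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, Plots
# DiffEqSensitivity is required for ReverseDiffVJP (added per the comments further down)
using DiffEqSensitivity
tspan = (0.0f0,1.0f0)
# The control u(t) is a small neural network that takes time as its input
ann = FastChain(FastDense(1,64,sin), FastDense(64,64,sin), FastDense(64,1))
θ = initial_params(ann)
# Double integrator: x1' = x2, x2' = u(t)
function dxdt_(dx,x,p,t)
    x1, x2 = x
    dx[1] = x[2]
    dx[2] = ann([t],p)[1]
end
x0 = [-4f0,0f0]
ts = Float32.(collect(0.0:0.01:1.0))
prob = ODEProblem(dxdt_,x0,tspan,θ)
function predict_adjoint(θ)
  # NOTE: the body of this function is missing from the post as captured here.
  # The line below is a reconstruction (an assumption), based on the solve call
  # used elsewhere in the thread and on the ReverseDiffVJP mentioned in the comments.
  Array(solve(prob,Tsit5(),p=θ,saveat=ts,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
function loss_adjoint(θ)
  x = predict_adjoint(θ)
  # Penalize the distance of x1 from 4, the velocity x2, and the control effort
  sum(abs2,4.0 .- x[1,:]) + 2sum(abs2,x[2,:]) + sum(abs2,[first(ann([t],θ)) for t in ts])
end
l = loss_adjoint(θ)
cb = function (θ,l)
  return false # returning false lets the optimizer continue
end
# Display the ODE with the current parameter values.
cb(θ,l)
loss1 = loss_adjoint(θ)
res = DiffEqFlux.sciml_train(loss_adjoint, θ, BFGS(), cb = cb)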

We should add it to the docs.
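
For reference, the problem that the snippet above appears to encode can be written out explicitly. This is a reading of the dynamics and loss in the code, not a statement from the post: the symbols $u_\theta$ (the network output), $t_k$ (the grid `ts`), and $J$ are notation introduced here, and it assumes the trajectory is sampled on the same grid `ts` as the control.

```latex
\begin{aligned}
\min_{\theta}\; J(\theta) &= \sum_{k} \Big[ \big(4 - x_1(t_k)\big)^2 + 2\,x_2(t_k)^2 + u_\theta(t_k)^2 \Big],
  \qquad t_k \in \{0,\, 0.01,\, \dots,\, 1\}, \\
\text{subject to}\quad \dot{x}_1 &= x_2, \qquad \dot{x}_2 = u_\theta(t), \qquad x(0) = (-4,\; 0)^\top .
\end{aligned}
```

Under that reading, the tracking error, the velocity penalty, and the control penalty are all plain sums over the 101 grid points, with the control effort weighted the same as the distance of $x_1$ from 4.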


AlCap23 commented Mar 16, 2020

Just out of curiosity: is this an MWE that trains?

It works fine for me with

res = DiffEqFlux.sciml_train(loss_adjoint, θ, ADAM(5e-1),maxiters = 1000, cb = cb)

but it does not really train with BFGS() as the optimizer. The optimizer seems to have trouble getting the loss below about 6.4e3.

I am using Julia 1.3 and the current releases of all packages.

ChrisRackauckas (Member, Author) commented

It trains. BFGS is horrible to start with though, so I would suggest doing a few iterations of ADAM before BFGS.

ChrisRackauckas added a commit that referenced this issue May 2, 2020
However, has a weird Zygote bug:

Can't differentiate gc_preserve_end expression
error(::String) at error.jl:33
getindex at essentials.jl:591 [inlined]
(::typeof(∂(getindex)))(::Nothing) at interface2.jl:0
iterate at essentials.jl:603 [inlined]
(::typeof(∂(iterate)))(::Nothing) at interface2.jl:0
_compute_eltype at tuple.jl:117 [inlined]
(::typeof(∂(_compute_eltype)))(::Nothing) at interface2.jl:0
eltype at tuple.jl:110 [inlined]
eltype at namedtuple.jl:145 [inlined]
Pairs at iterators.jl:169 [inlined]
pairs at iterators.jl:226 [inlined]
(::typeof(∂(pairs)))(::Nothing) at interface2.jl:0
Type at diffeqfunction.jl:388 [inlined]
(::typeof(∂(Type##kw)))(::Nothing) at interface2.jl:0
NeuralODEMM at neural_de.jl:305 [inlined]
(::typeof(∂(λ)))(::Array{Float64,2}) at interface2.jl:0
predict_n_dae at neural_ode_mm.jl:26 [inlined]
(::typeof(∂(predict_n_dae)))(::Array{Float64,2}) at interface2.jl:0
loss at neural_ode_mm.jl:29 [inlined]
(::Zygote.var"#174#175"{typeof(∂(loss)),Tuple{Tuple{Nothing},Int64}})(::Tuple{Int64,Nothing}) at lib.jl:182
(::Zygote.var"#347#back#176"{Zygote.var"#174#175"{typeof(∂(loss)),Tuple{Tuple{Nothing},Int64}}})(::Tuple{Int64,Nothing}) at adjoint.jl:49
#34 at train.jl:176 [inlined]
(::typeof(∂(λ)))(::Int64) at interface2.jl:0
#36 at interface.jl:36 [inlined]
(::DiffEqFlux.var"#37#50"{DiffEqFlux.var"#34#47"{typeof(loss)}})(::Array{Float32,1}, ::Array{Float32,1}) at train.jl:199
value_gradient!!(::TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Array{Float32,1}) at interface.jl:82
initial_state(::BFGS{LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Nothing,Float64,Flat}, ::Optim.Options{Float64,DiffEqFlux.var"#_cb#46"{var"#11#12",Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}}}, ::TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Array{Float32,1}) at bfgs.jl:66
optimize(::TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Array{Float32,1}, ::BFGS{LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Nothing,Float64,Flat}, ::Optim.Options{Float64,DiffEqFlux.var"#_cb#46"{var"#11#12",Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}}}) at optimize.jl:33
sciml_train(::Function, ::Array{Float32,1}, ::BFGS{LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Nothing,Float64,Flat}, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, diffmode::DiffEqFlux.ZygoteDiffMode, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at train.jl:269
sciml_train at train.jl:163 [inlined]
(::DiffEqFlux.var"#sciml_train##kw")(::NamedTuple{(:cb, :maxiters),Tuple{var"#11#12",Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::BFGS{LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Nothing,Float64,Flat}) at train.jl:163
top-level scope at neural_ode_mm.jl:40


using Flux, DiffEqFlux, OrdinaryDiffEq, Optim, Test
#A desired MWE for now, not a test yet.
# Reaction system written as a mass-matrix DAE: the third equation is the
# algebraic constraint y₁ + y₂ + y₃ = 1
function f(du,u,p,t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] =  k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
    du[3] =  y₁ + y₂ + y₃ - 1
end
u₀ = [1.0, 0, 0]
M = [1. 0  0
     0  1. 0
     0  0  0]
tspan = (0.0,1.0)
p = [0.04,3e7,1e4]
func = ODEFunction(f,mass_matrix=M)
prob = ODEProblem(func,u₀,tspan,p)
sol = solve(prob,Rodas5(),saveat=0.1)

dudt2 = FastChain(FastDense(3,64,tanh),FastDense(64,2))
ndae = NeuralODEMM(dudt2, (u,p,t) -> [u[1] + u[2] + u[3] - 1], tspan, M, Rodas5(autodiff=false),saveat=0.1)

function predict_n_dae(p)
    # NOTE: this body is missing from the text as captured here; calling the
    # layer on the initial condition is a reconstruction (an assumption).
    ndae(u₀,p)
end
function loss(p)
    pred = predict_n_dae(p)
    loss = sum(abs2,sol .- pred)
    loss,pred # the callback below expects both the loss and the prediction
end

cb = function (p,l,pred) #callback function to observe training
  return false
end

l1 = first(loss(ndae.p))
res = DiffEqFlux.sciml_train(loss, ndae.p, BFGS(initial_stepnorm = 0.001), cb = cb, maxiters = 100)

ChrisRackauckas (Member, Author) commented

I realized that the problem at the start isn't too interesting, because the cost of using the control is so high that it pretty much never does anything. The following is much more interesting:

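using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, Plots, Statistics
# DiffEqSensitivity is required for ReverseDiffVJP (added per the comments further down)
using DiffEqSensitivity
tspan = (0.0f0,8.0f0)
ann = FastChain(FastDense(1,32,tanh), FastDense(32,32,tanh), FastDense(32,1))
θ = initial_params(ann)
# Dynamics: x1' = x2, x2' = u(t)^3, where u(t) is the network output
function dxdt_(dx,x,p,t)
    x1, x2 = x
    dx[1] = x[2]
    dx[2] = ann([t],p)[1]^3
end
x0 = [-4f0,0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
prob = ODEProblem(dxdt_,x0,tspan,θ)
function predict_adjoint(θ)
  # NOTE: the body of this function is missing from the post as captured here.
  # The line below is a reconstruction (an assumption), based on the solve call
  # in the callback and on the ReverseDiffVJP mentioned in the comments.
  Array(solve(prob,Tsit5(),p=θ,saveat=0.01,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
function loss_adjoint(θ)
  x = predict_adjoint(θ)
  # Control penalty is down-weighted by 1/10 for the first training stages
  mean(abs2,4.0 .- x[1,:]) + 2mean(abs2,x[2,:]) + mean(abs2,[first(ann([t],θ)) for t in ts])/10
end
l = loss_adjoint(θ)
cb = function (θ,l)
  p = plot(solve(remake(prob,p=θ),Tsit5(),saveat=0.01),ylim=(-6,6),lw=3)
  plot!(p,ts,[first(ann([t],θ)) for t in ts],label="u(t)",lw=3)
  display(p) # show the plot; without this the figure is built but never drawn
  return false
end
# Display the ODE with the current parameter values.
cb(θ,l)
loss1 = loss_adjoint(θ)
# Warm up with ADAM, then refine with BFGS
res1 = DiffEqFlux.sciml_train(loss_adjoint, θ, ADAM(0.005), cb = cb,maxiters=100)
res2 = DiffEqFlux.sciml_train(loss_adjoint, res1.minimizer, BFGS(initial_stepnorm=0.01), cb = cb,maxiters=100)

# Redefine the loss with the full control penalty (weight 1 instead of 1/10)
function loss_adjoint(θ)
  x = predict_adjoint(θ)
  mean(abs2,4.0 .- x[1,:]) + 2mean(abs2,x[2,:]) + mean(abs2,[first(ann([t],θ)) for t in ts])
end

res3 = DiffEqFlux.sciml_train(loss_adjoint, res2.minimizer, BFGS(initial_stepnorm=0.01), cb = cb,maxiters=100)

l = loss_adjoint(res3.minimizer)
p = plot(solve(remake(prob,p=res3.minimizer),Tsit5(),saveat=0.01),ylim=(-6,6),lw=3)
plot!(p,ts,[first(ann([t],res3.minimizer)) for t in ts],label="u(t)",lw=3)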
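
Written out, the revised problem in the code above seems to be the following. This is a reading of the snippet rather than something stated in the post: $u_\theta$ denotes the network output, $\lambda$ is a symbol introduced here for the control-penalty weight, $\operatorname{mean}_k$ is the average over sample times $t_k$, and it assumes the trajectory is sampled on a uniform grid over $[0, 8]$.

```latex
\begin{aligned}
\min_{\theta}\; J_\lambda(\theta) &= \operatorname{mean}_k \big(4 - x_1(t_k)\big)^2
  + 2\,\operatorname{mean}_k\, x_2(t_k)^2
  + \lambda\,\operatorname{mean}_k\, u_\theta(t_k)^2, \\
\text{subject to}\quad \dot{x}_1 &= x_2, \qquad \dot{x}_2 = u_\theta(t)^3, \qquad x(0) = (-4,\; 0)^\top, \qquad t \in [0, 8], \\
\lambda &= \tfrac{1}{10} \;\text{ for } \texttt{res1},\ \texttt{res2}, \qquad \lambda = 1 \;\text{ for } \texttt{res3}.
\end{aligned}
```

Note that the penalty is applied to the network output $u_\theta$ itself, while the dynamics are driven by its cube.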


We might want to pick a problem with a known analytical solution for the documentation.

Copy link

Am I reading this right that your dynamics are xdotdot = u(t)^3?

I would look at the examples here. In particular, the Van der Pol oscillator in Example 6.

ChrisRackauckas (Member, Author) commented

Yup that's it.

ChrisRackauckas (Member, Author) commented

I added this example to the documentation, but will leave this open because it would be nice to add a full exploration of that Van der Pol example.

Copy link

FYI, I am getting a "ReverseDiffVJP not defined" error and need to add "using DiffEqSensitivity" to make it run.

  [aae7a2af] DiffEqFlux v1.10.0
  [587475ba] Flux v0.10.4
  [429524aa] Optim v0.21.0
  [1dea7af3] OrdinaryDiffEq v5.37.0
  [91a5bcdd] Plots v1.2.4
  [10745b16] Statistics 

ChrisRackauckas (Member, Author) commented

Oh, do using DiffEqSensitivity

ChrisRackauckas (Member, Author) commented

Yeah it was missing from the tutorial. Added.

ChrisRackauckas added a commit that referenced this issue May 30, 2020