Neural optimal control documentation #174

Closed
ChrisRackauckas opened this issue Feb 28, 2020 · 9 comments

@ChrisRackauckas (Member)

Here is a classical u(t) optimal control problem, with the control represented and trained as a neural network:

```julia
using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, Plots

tspan = (0.0f0, 1.0f0)
ann = FastChain(FastDense(1, 64, sin), FastDense(64, 64, sin), FastDense(64, 1))
θ = initial_params(ann)

# Double integrator with the neural network as the control: x1' = x2, x2' = u(t)
function dxdt_(dx, x, p, t)
    x1, x2 = x
    dx[1] = x2
    dx[2] = ann([t], p)[1]
end

x0 = [-4f0, 0f0]
ts = Float32.(collect(0.0:0.01:1.0))
prob = ODEProblem(dxdt_, x0, tspan, θ)
concrete_solve(prob, Tsit5(), x0, θ, abstol=1e-8, reltol=1e-6)

function predict_adjoint(θ)
  Array(concrete_solve(prob, Tsit5(), x0, θ, saveat=ts))
end

# Quadratic cost: drive x1 to 4, keep x2 small, and penalize the control effort
function loss_adjoint(θ)
  x = predict_adjoint(θ)
  sum(abs2, 4.0 .- x[1,:]) + 2sum(abs2, x[2,:]) + sum(abs2, [first(ann([t],θ)) for t in ts])
end

l = loss_adjoint(θ)
cb = function (θ, l)
  println(l)
  display(plot(solve(remake(prob, p=θ), Tsit5(), saveat=0.01), ylim=(-6,6)))
  return false
end

# Display the ODE with the current parameter values.
cb(θ, l)
loss1 = loss_adjoint(θ)
res = DiffEqFlux.sciml_train(loss_adjoint, θ, BFGS(), cb = cb)
```

We should add it to the docs.

@AlCap23 commented Mar 16, 2020

Just out of curiosity: is this an MWE that trains?

Works fine for me with

res = DiffEqFlux.sciml_train(loss_adjoint, θ, ADAM(5e-1),maxiters = 1000, cb = cb)

but it does not really train with BFGS() as the optimizer. The optimizer seems unable to get the loss below 6.4e3.

I am using Julia 1.3 and the current releases of all packages.

@ChrisRackauckas (Member, Author)

It trains. BFGS is a bad optimizer to start from, though, so I would suggest doing a few iterations of ADAM before switching to BFGS.
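
For reference, a minimal sketch of that warm start, reusing `loss_adjoint`, `θ`, and `cb` from the example above (the step size and iteration counts are illustrative guesses, not tuned values):

```julia
# A short ADAM run first to get into a reasonable basin, then refine with BFGS
# warm-started from the ADAM minimizer.
res_adam = DiffEqFlux.sciml_train(loss_adjoint, θ, ADAM(0.05), cb = cb, maxiters = 100)
res_bfgs = DiffEqFlux.sciml_train(loss_adjoint, res_adam.minimizer,
                                  BFGS(initial_stepnorm = 0.01), cb = cb, maxiters = 100)
```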

ChrisRackauckas added a commit that referenced this issue May 2, 2020
However, it has a weird Zygote bug:

```julia
Can't differentiate gc_preserve_end expression
error(::String) at error.jl:33
getindex at essentials.jl:591 [inlined]
(::typeof(∂(getindex)))(::Nothing) at interface2.jl:0
iterate at essentials.jl:603 [inlined]
(::typeof(∂(iterate)))(::Nothing) at interface2.jl:0
_compute_eltype at tuple.jl:117 [inlined]
(::typeof(∂(_compute_eltype)))(::Nothing) at interface2.jl:0
eltype at tuple.jl:110 [inlined]
eltype at namedtuple.jl:145 [inlined]
Pairs at iterators.jl:169 [inlined]
pairs at iterators.jl:226 [inlined]
(::typeof(∂(pairs)))(::Nothing) at interface2.jl:0
Type at diffeqfunction.jl:388 [inlined]
(::typeof(∂(Type##kw)))(::Nothing) at interface2.jl:0
NeuralODEMM at neural_de.jl:305 [inlined]
(::typeof(∂(λ)))(::Array{Float64,2}) at interface2.jl:0
predict_n_dae at neural_ode_mm.jl:26 [inlined]
(::typeof(∂(predict_n_dae)))(::Array{Float64,2}) at interface2.jl:0
loss at neural_ode_mm.jl:29 [inlined]
(::Zygote.var"#174#175"{typeof(∂(loss)),Tuple{Tuple{Nothing},Int64}})(::Tuple{Int64,Nothing}) at lib.jl:182
(::Zygote.var"#347#back#176"{Zygote.var"#174#175"{typeof(∂(loss)),Tuple{Tuple{Nothing},Int64}}})(::Tuple{Int64,Nothing}) at adjoint.jl:49
#34 at train.jl:176 [inlined]
(::typeof(∂(λ)))(::Int64) at interface2.jl:0
#36 at interface.jl:36 [inlined]
(::DiffEqFlux.var"#37#50"{DiffEqFlux.var"#34#47"{typeof(loss)}})(::Array{Float32,1}, ::Array{Float32,1}) at train.jl:199
value_gradient!!(::TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Array{Float32,1}) at interface.jl:82
initial_state(::BFGS{LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Nothing,Float64,Flat}, ::Optim.Options{Float64,DiffEqFlux.var"#_cb#46"{var"#11#12",Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}}}, ::TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Array{Float32,1}) at bfgs.jl:66
optimize(::TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Array{Float32,1}, ::BFGS{LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Nothing,Float64,Flat}, ::Optim.Options{Float64,DiffEqFlux.var"#_cb#46"{var"#11#12",Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}}}) at optimize.jl:33
sciml_train(::Function, ::Array{Float32,1}, ::BFGS{LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Nothing,Float64,Flat}, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, diffmode::DiffEqFlux.ZygoteDiffMode, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at train.jl:269
sciml_train at train.jl:163 [inlined]
(::DiffEqFlux.var"#sciml_train##kw")(::NamedTuple{(:cb, :maxiters),Tuple{var"#11#12",Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::BFGS{LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Nothing,Float64,Flat}) at train.jl:163
top-level scope at neural_ode_mm.jl:40
```

in

```julia
using Flux, DiffEqFlux, OrdinaryDiffEq, Optim, Test
#A desired MWE for now, not a test yet.
function f(du,u,p,t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] =  k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
    du[3] =  y₁ + y₂ + y₃ - 1
    nothing
end
u₀ = [1.0, 0, 0]
M = [1. 0  0
     0  1. 0
     0  0  0]
tspan = (0.0,1.0)
p = [0.04,3e7,1e4]
func = ODEFunction(f,mass_matrix=M)
prob = ODEProblem(func,u₀,tspan,p)
sol = solve(prob,Rodas5(),saveat=0.1)

dudt2 = FastChain(FastDense(3,64,tanh),FastDense(64,2))
ndae = NeuralODEMM(dudt2, (u,p,t) -> [u[1] + u[2] + u[3] - 1], tspan, M, Rodas5(autodiff=false),saveat=0.1)
ndae(u₀)

function predict_n_dae(p)
    ndae(u₀,p)
end
function loss(p)
    pred = predict_n_dae(p)
    loss = sum(abs2,sol .- pred)
    loss,pred
end

cb = function (p,l,pred) #callback function to observe training
  display(l)
  return false
end

l1 = first(loss(ndae.p))
res = DiffEqFlux.sciml_train(loss, ndae.p, BFGS(initial_stepnorm = 0.001), cb = cb, maxiters = 100)
```
@ChrisRackauckas (Member, Author)

I realized that the problem at the start isn't too interesting because the cost of using the control is so high that the control pretty much never does anything. Instead, the following is much more interesting:

```julia
using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, Plots, Statistics

tspan = (0.0f0, 8.0f0)
ann = FastChain(FastDense(1, 32, tanh), FastDense(32, 32, tanh), FastDense(32, 1))
θ = initial_params(ann)

# Now the control enters nonlinearly: x1' = x2, x2' = u(t)^3
function dxdt_(dx, x, p, t)
    x1, x2 = x
    dx[1] = x2
    dx[2] = ann([t], p)[1]^3
end

x0 = [-4f0, 0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
prob = ODEProblem(dxdt_, x0, tspan, θ)
concrete_solve(prob, Vern9(), x0, θ, abstol=1e-10, reltol=1e-10)

function predict_adjoint(θ)
  Array(concrete_solve(prob, Vern9(), x0, θ, saveat=ts,
                       sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end

# Start with a reduced control penalty (divided by 10) so the control actually gets used
function loss_adjoint(θ)
  x = predict_adjoint(θ)
  mean(abs2, 4.0 .- x[1,:]) + 2mean(abs2, x[2,:]) + mean(abs2, [first(ann([t],θ)) for t in ts])/10
end

l = loss_adjoint(θ)
cb = function (θ, l)
  println(l)
  p = plot(solve(remake(prob, p=θ), Tsit5(), saveat=0.01), ylim=(-6,6), lw=3)
  plot!(p, ts, [first(ann([t],θ)) for t in ts], label="u(t)", lw=3)
  display(p)
  return false
end

# Display the ODE with the current parameter values.
cb(θ, l)
loss1 = loss_adjoint(θ)
res1 = DiffEqFlux.sciml_train(loss_adjoint, θ, ADAM(0.005), cb = cb, maxiters = 100)
res2 = DiffEqFlux.sciml_train(loss_adjoint, res1.minimizer, BFGS(initial_stepnorm=0.01), cb = cb, maxiters = 100)

# Re-define the loss with the full control cost and continue from the previous result
function loss_adjoint(θ)
  x = predict_adjoint(θ)
  mean(abs2, 4.0 .- x[1,:]) + 2mean(abs2, x[2,:]) + mean(abs2, [first(ann([t],θ)) for t in ts])
end

res3 = DiffEqFlux.sciml_train(loss_adjoint, res2.minimizer, BFGS(initial_stepnorm=0.01), cb = cb, maxiters = 100)

l = loss_adjoint(res3.minimizer)
cb(res3.minimizer, l)
p = plot(solve(remake(prob, p=res3.minimizer), Tsit5(), saveat=0.01), ylim=(-6,6), lw=3)
plot!(p, ts, [first(ann([t],res3.minimizer)) for t in ts], label="u(t)", lw=3)
savefig("optimal_control.png")
```

(Plot: optimal_control.png, the trained state trajectory with the learned control u(t) overlaid.)

We might want to pick a problem with a known analytical solution for the documentation.
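
One candidate (a suggestion here, not something decided in this thread): the minimum-energy double integrator has a closed-form optimum via Pontryagin's principle, so the learned u(t) could be compared against it directly:

```math
\min_u \int_0^T u(t)^2\,dt \quad \text{s.t.} \quad \dot x_1 = x_2,\ \dot x_2 = u,\ x(0), x(T)\ \text{fixed};
\qquad H = u^2 + p_1 x_2 + p_2 u \ \Rightarrow\ u^* = -\tfrac{p_2}{2},\ \dot p_1 = 0,\ \dot p_2 = -p_1,
```

so the optimal control is linear in t and the optimal state trajectory is a cubic polynomial, which the trained network could be checked against.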

@agerlach

Am I reading this right that your dynamics are ẍ = u(t)^3?

I would look at the examples here. In particular, the Van der Pol oscillator in Example 6.

@ChrisRackauckas (Member, Author)

Yup that's it.

@ChrisRackauckas (Member, Author)

I added this example to the documentation, but will leave this open because it would be nice to add a full exploration of that Van der Pol example.
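
For whoever picks this up, a minimal sketch of the controlled Van der Pol dynamics in the same FastChain style as above (the μ value, network sizes, and time span are placeholders, not values taken from the linked Example 6):

```julia
using DiffEqFlux, OrdinaryDiffEq

# Controlled Van der Pol oscillator: x1' = x2, x2' = μ(1 - x1^2)x2 - x1 + u(t),
# with u(t) given by the neural network, as in the double-integrator example above.
vdp_ann = FastChain(FastDense(1, 32, tanh), FastDense(32, 32, tanh), FastDense(32, 1))
vdp_θ = initial_params(vdp_ann)

function vdp_dxdt!(dx, x, p, t)
    μ = 1.0f0                      # placeholder stiffness parameter
    x1, x2 = x
    dx[1] = x2
    dx[2] = μ * (1 - x1^2) * x2 - x1 + vdp_ann([t], p)[1]
end

vdp_prob = ODEProblem(vdp_dxdt!, [0.0f0, 1.0f0], (0.0f0, 5.0f0), vdp_θ)
```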

@agerlach

FYI, I am getting `ReverseDiffVJP not defined` and needed to add `using DiffEqSensitivity`.

  [aae7a2af] DiffEqFlux v1.10.0
  [587475ba] Flux v0.10.4
  [429524aa] Optim v0.21.0
  [1dea7af3] OrdinaryDiffEq v5.37.0
  [91a5bcdd] Plots v1.2.4
  [10745b16] Statistics 

@ChrisRackauckas (Member, Author)

Oh, do `using DiffEqSensitivity`.
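
i.e. the full import line for that example would be (a sketch; it just adds DiffEqSensitivity to the packages already used above):

```julia
using DiffEqFlux, DiffEqSensitivity, Flux, Optim, OrdinaryDiffEq, Plots, Statistics
```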

@ChrisRackauckas (Member, Author)

Yeah it was missing from the tutorial. Added.

ChrisRackauckas added a commit that referenced this issue May 30, 2020