-
-
Notifications
You must be signed in to change notification settings - Fork 154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Add neural DAE support #159
Conversation
This would need a test to build on, @ChrisRackauckas can you point me to an example how that should like? |
src/neural_de.jl
Outdated
@@ -167,3 +167,29 @@ function (n::NeuralCDDE)(x,p=n.p) | |||
prob = DDEProblem{false}(dudt_,x,n.hist,n.tspan,p,constant_lags = n.lags) | |||
concrete_solve(prob,n.solver,x,p,n.args...;sensealg=TrackerAdjoint(),n.kwargs...) | |||
end | |||
|
|||
struct NeuralDAE{P,M,RE,T,S,DV,A,K} <: NeuralDELayer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not cover the mass matrix form. That should be added as per https://docs.juliadiffeq.org/latest/tutorials/advanced_ode_example/#Handling-Mass-Matrices-1 and the discussion in the slack thread.
A mass matrix one would be more useful, but this also is just incorrect. How does one specify the constraint equations? The algebraic variables? For a test DAE, use something non-stiff. The ROBER equations for a short time are. function f(out,du,u,p,t)
out[1] = - 0.04u[1] + 1e4*u[2]*u[3] - du[1]
out[2] = + 0.04u[1] - 3e7*u[2]^2 - 1e4*u[2]*u[3] - du[2]
out[3] = u[1] + u[2] + u[3] - 1.0
end
u₀ = [1.0, 0, 0]
du₀ = [-0.04, 0.04, 0.0]
tspan = (0.0,10.0)
using OrdinaryDiffEq
differential_vars = [true,true,false]
prob = DAEProblem(f,du₀,u₀,tspan,differential_vars=differential_vars)
sol = solve(prob,DABDF2()) which should be differentiable with Tracker. |
3a0393b
to
7c90a6a
Compare
end | ||
dudt_(u,p,t) = f | ||
prob = DAEProblem(dudt_,du0,x,n.tspan,p,differential_vars=n.differential_vars) | ||
concrete_solve(prob,n.solver,x,p,n.args...;n.kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
default to sensealg = TrackerAdjoint()
src/neural_de.jl
Outdated
|
||
function (n::NeuralDAE)(x,du0=n.du0,p=n.p) | ||
function f(u,p,t) | ||
vcat([n.differential_vars[i] == 1 ? n.re(p)(u[i]) : n.constraints_model(u[i],p,t) for i in 1:length(n.differential_vars)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't quite correct. n.constraints_model(u,p,t)
should return a vector of constraints. n.re(p)(u)
returns a vector. Then n.differential_vars[i] == true
is a better check for decomposing / concatenting.
test/neural_dae.jl
Outdated
@@ -0,0 +1,22 @@ | |||
using Flux, DiffEqFlux, OrdinaryDiffEq |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this test isn't called until it's added to runtests.jl
Codecov Report
@@ Coverage Diff @@
## master #159 +/- ##
==========================================
- Coverage 75.86% 68.13% -7.73%
==========================================
Files 4 4
Lines 174 204 +30
==========================================
+ Hits 132 139 +7
- Misses 42 65 +23
Continue to review full report at Codecov.
|
8459b12
to
d502d98
Compare
src/neural_de.jl
Outdated
Flux.@functor NeuralDAE | ||
|
||
function (n::NeuralDAE)(x,p=n.p) | ||
function f(du,u,p,t) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use the oop form
src/neural_de.jl
Outdated
function f(du,u,p,t) | ||
nn_out = n.re(p)(u) | ||
alg_out = n.constraints_model(u,p,t) | ||
du .= vcat(nn_out,alg_out) | ||
nothing | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function f(du,u,p,t) | |
nn_out = n.re(p)(u) | |
alg_out = n.constraints_model(u,p,t) | |
du .= vcat(nn_out,alg_out) | |
nothing | |
end | |
function f(u,p,t) | |
nn_out = n.re(p)(u) | |
alg_out = n.constraints_model(u,p,t) | |
vcat(nn_out,alg_out) | |
end |
@ChrisRackauckas @Vaibhavdixit02