Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NNODE for DAE Problems #695

Closed
wants to merge 7 commits into from

Conversation

sdesai1287
Copy link
Contributor

@sdesai1287 sdesai1287 commented Jun 27, 2023

Initial changes to ode_solve, added NNDAE to get this solver working for DAE problems

@sdesai1287 sdesai1287 closed this Jun 27, 2023
@sdesai1287 sdesai1287 reopened this Jun 27, 2023
@sdesai1287 sdesai1287 changed the title Initial changes to ode_solve, added NNDAE to get this solver working … NNODE for DAE Problems Jun 27, 2023
@sdesai1287 sdesai1287 marked this pull request as ready for review June 27, 2023 04:52
@sdesai1287 sdesai1287 marked this pull request as draft June 27, 2023 04:53
@sdesai1287
Copy link
Contributor Author

sdesai1287 commented Jun 27, 2023

Hi @ChrisRackauckas, I wanted to check with you to see if the theory that I have added in is correct. I am writing out some tests to verify that it works and will add them in the test file shortly

@sdesai1287
Copy link
Contributor Author

sdesai1287 commented Jun 27, 2023

I solved a bunch of bugs with this, but when I run this test file I get this long error that I do not really know how to solve. I believe the error starts at line 300, but my logic appears correct there so not sure what the problem is
Screen Shot 2023-06-27 at 1 28 41 AM

@xtalax
Copy link
Member

xtalax commented Jul 5, 2023

I think the problem is the argument signature that you have called f with. You have provided a path for the case where f is an out of place function like f(u, p, t), but you must also provide a path for the case where it is in place like f!(du, u, p, t), which commonly happens as it is a pattern for better speed in iterative solves.

You also need to be careful about doing anything to p without checking if It is a SciMLBase.NullParameters object, which doesn't support indexing or any other operation really. All operations on p should happen in an if else block, first checking whether it is of this type before doing anything fancy.

@sdesai1287
Copy link
Contributor Author

sdesai1287 commented Jul 18, 2023

Hi @xtalax I have tried a several things but nothing seems to work. However, I think NNODE itself doesnt work for in place problems, so modifying it for the DAE case for in place methods is probably not going to work either. I am trying to find a DAE problem to test that is out of place, but both the DAEProblems from the DAE Problem Library (from DiffEqProblemLibrary.jl), and their corresponding DAEFunctions are in place. If you can direct me to an out of place DAE problem to test (or some other solution for this issue), that would be greatly appreciated

Furthermore, if you think I should pursue making NNODE (and then DAE Problems) work for in place problems, I can do that too. Lastly, I did not intend to do anything to p, and the solver should not need p to work. However, if you see something where I am unintentionally indexing p, please let me know and I will do my best to address it

@ChrisRackauckas
Copy link
Member

I think the problem is the argument signature that you have called f with. You have provided a path for the case where f is an out of place function like f(u, p, t), but you must also provide a path for the case where it is in place like f!(du, u, p, t), which commonly happens as it is a pattern for better speed in iterative solves.

Kind of. DAEProblems are represented via f(out,du,u,p,t) and f(du,u,p,t). See https://docs.sciml.ai/DiffEqDocs/stable/tutorials/dae_example/#Implicitly-Defined-Differential-Algebraic-Equations-(DAEs).

src/ode_solve.jl Outdated Show resolved Hide resolved
src/ode_solve.jl Outdated Show resolved Hide resolved
@ChrisRackauckas
Copy link
Member

The error is pointing to the fact that you're using the wrong function signature. The out of place DAE is 4 arguments, not 3. It needs to know the current state, not just the derivative.

@sdesai1287
Copy link
Contributor Author

Hi @ChrisRackauckas I tried adding u like you said but it still gives me an error for u being undefined at line 302. Going to keep working on this

@ChrisRackauckas
Copy link
Member

Your test case wasn't a valid out of place DAEProblem. See my suggestion. That should do fine.

src/ode_solve.jl Outdated Show resolved Hide resolved
src/ode_solve.jl Outdated Show resolved Hide resolved
@sdesai1287
Copy link
Contributor Author

sdesai1287 commented Jul 26, 2023

I made some fixes to this but I still have the following error:
Screen Shot 2023-07-26 at 7 47 51 PM

I think my new problem definition is a valid out of place DAE problem. My primary confusion lies with when to use phi vs phi(t, θ) vs Array(phi(t, θ)) when generating the loss function. The error is caught here

@ChrisRackauckas
Copy link
Member

what's your MWE?

@sdesai1287
Copy link
Contributor Author

what's your MWE?

Does that mean minimum working example? If so, I dont really have one. I tried using a simpler chain but it solved nothing, and I used a problem directly from the DAEProblemLibrary (and from you) so I doubt that is the issue either. Not too sure how to resolve this

src/ode_solve.jl Outdated

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
p) where {C, T, U <: Number}
sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(ode_dfdx(phi, t, θ, autodiff), phi, p, t))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't make sense.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the output of a dae function is the residual, so there's nothing to subtract

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I just make that line sum(abs2, ode_dfdx(phi, t, θ, autodiff)) then?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, you need to evaluate f... sum(abs2,f(ode_dfdx(phi, t, θ, autodiff), phi(t, θ), p, t)))

src/ode_solve.jl Outdated
p) where {C, T, U <: Number}
out = phi(t, θ)
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
sum(abs2, f(dxdtguess[i], phi, p, t[i]) for i in 1:size(out, 2)) / length(t)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can't pass phi to f, it's a function.

src/ode_solve.jl Outdated
p) where {C, T, U <: Number}
out = phi(t, θ)
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
sum(abs2, f(dxdtguess[i], phi, p, t[i]) for i in 1:size(out, 2)) / length(t)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sum(abs2, f(dxdtguess[i], phi, p, t[i]) for i in 1:size(out, 2)) / length(t)
sum(abs2, f(dxdtguess[i], out, p, t[i]) for i in 1:size(out, 2)) / length(t)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but you need out at each t

Copy link
Contributor Author

@sdesai1287 sdesai1287 Aug 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do I need out or out[:, i]? I think it should be out[:, i] but not 100% sure

src/ode_solve.jl Outdated
Comment on lines 294 to 302
sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(ode_dfdx(phi, t, θ, autodiff), phi, p, t))
end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p) where {C, T, U}
out = Array(phi(t, θ))
arrt = Array(t)
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
sum(abs2, f(dxdtguess[:, i], phi, p, arrt[i]) for i in 1:size(out, 2)) / length(t)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two have the same errors as above.

@ChrisRackauckas
Copy link
Member

What example are you working with to test? It should point exactly to the 4 issues I pointed out above.

@sdesai1287
Copy link
Contributor Author

What example are you working with to test? It should point exactly to the 4 issues I pointed out above.

Here is my test code, which should also be in the commit

using Optimisers, OptimizationOptimisers, Sundials
using NeuralPDE, Lux, Test, Statistics, Plots

f = function (yp, y, p, tres)
    [-0.04 * y[1] + 1.0e4 * y[2] * y[3] - yp[1],
     -(-0.04 * y[1] + 1.0e4 * y[2] * y[3]) - 3.0e7 * y[2] * y[2] - yp[2],
      y[1] + y[2] + y[3] - 1.0]
end
u0 = [1.0, 0, 0]
du0 = [-0.04, 0.04, 0.0]

prob_oop = DAEProblem{false}(f, du0, u0, (0.0, 100000.0))
true_sol = solve(prob_oop, IDA(), saveat = 0.01)

func = Lux.σ
N = 12
chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
                    Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))

opt = Optimisers.Adam(0.01)
dx = 0.05
alg = NeuralPDE.NNDAE(chain, opt, autodiff = false, strategy = NeuralPDE.GridTraining(dx))
sol = solve(prob_oop, alg, verbose=true, maxiters = 100000, saveat = 0.01)

This should be the Robertson biochemical reactions in DAE form. I've checked this over several times, so I am pretty sure this is not the problem, but there could certainly be something I missed

@sdesai1287
Copy link
Contributor Author

@ChrisRackauckas @xtalax I implemented the new edits that you suggested and I am pretty sure that the loss function generation is working properly now. However, the last step here of calling Julia's solve method within my DAEProblem solver runs infinitely without printing anything as part of the callback function. I think I am now very close on this, and any tips with debugging this would be greatly appreciated

@sdesai1287 sdesai1287 closed this Aug 17, 2023
@sdesai1287 sdesai1287 reopened this Aug 17, 2023
@ChrisRackauckas
Copy link
Member

Done in #790

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants