Merge pull request #139 from SciML/ap/defaults
Default to using SimpleGMRES for the backward pass
avik-pal committed Dec 22, 2023
2 parents 8b1d48d + 64f84a1 commit f055519
Showing 4 changed files with 60 additions and 12 deletions.
6 changes: 4 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "DeepEquilibriumNetworks"
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <avikpal@mit.edu>"]
-version = "2.0.0"
+version = "2.0.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -18,11 +18,12 @@ SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
-DeepEquilibriumNetworksSciMLSensitivityExt = "SciMLSensitivity"
+DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
DeepEquilibriumNetworksZygoteExt = "Zygote"

[compat]
@@ -32,6 +33,7 @@ ConcreteStructs = "0.2"
ConstructionBase = "1"
DiffEqBase = "6.119"
LinearAlgebra = "1"
LinearSolve = "2.21.2"
Lux = "0.5.11"
Random = "1"
SciMLBase = "2"
38 changes: 38 additions & 0 deletions docs/src/api.md
@@ -9,6 +9,44 @@ To construct a continuous DEQ, any ODE solver compatible with `DifferentialEquat
can be passed as the solver. To construct a discrete DEQ, any root finding algorithm
compatible with `NonlinearSolve.jl` API can be passed as the solver.

## Choosing a Solver

### Root Finding Algorithms

Root-finding algorithms converge quickly when they converge at all, but they also tend to
be unstable. If you must use a root-finding algorithm, we recommend:

1. `NewtonRaphson` or `TrustRegion` for small models
2. `LimitedMemoryBroyden` for large Deep Learning applications (with well-conditioned
Jacobians)
3. `NewtonRaphson(; linsolve = KrylovJL_GMRES())` for cases when Broyden methods fail

Note that Krylov methods rely on efficient VJPs, which are not available for all Lux models.
If you think a missing VJP is causing a performance regression, please open an issue in
[Lux.jl](https://github.com/LuxDL/Lux.jl).
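
The recommendations above can be sketched on a toy problem. This is a minimal illustration, assuming NonlinearSolve.jl and LinearSolve.jl are installed; the residual `f`, the parameters `p`, and the 2-element state are hypothetical stand-ins for a real DEQ layer, not part of this package.

```julia
using NonlinearSolve, LinearSolve

# Hypothetical toy residual: the fixed point u* satisfies tanh.(W * u* .+ b) == u*.
f(u, p) = tanh.(p.W * u .+ p.b) .- u
p = (; W = [0.1 0.2; -0.3 0.1], b = [0.5, -0.5])
prob = NonlinearProblem(f, zeros(2), p)

# Newton with the default linear solve: reasonable for small models.
sol_newton = solve(prob, NewtonRaphson())

# Newton with a Krylov (GMRES) linear solve, as suggested when Broyden methods fail.
sol_krylov = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()))

# Largest absolute residual at each returned solution.
residual_newton = maximum(abs, f(sol_newton.u, p))
residual_krylov = maximum(abs, f(sol_krylov.u, p))
```

On a problem this small the two variants should behave similarly; the Krylov variant is likely to matter only once the Jacobian becomes too large to factorize.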

### ODE Solvers

ODE solvers converge more slowly but are more stable. We generally recommend these
methods over root-finding algorithms. If you use an implicit ODE solver, remember to pair
it with a Krylov linear solver; see the OrdinaryDiffEq.jl documentation for the available
options. For most cases, we recommend:

1. `VCAB3()` for high tolerance problems
2. `Tsit5()` for high tolerance problems where `VCAB3()` fails
3. In all other cases, follow the recommendations in the [OrdinaryDiffEq.jl](https://docs.sciml.ai/DiffEqDocs/stable/solvers/ode_solve/#ode_solve) documentation
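
As a rough sketch of the ODE-based route, the same kind of fixed point can be found by integrating to a steady state. This assumes SteadyStateDiffEq.jl and OrdinaryDiffEq.jl are installed; the dynamics `dudt`, the parameters `p`, and the direct use of `DynamicSS` are illustrative assumptions and not necessarily how this package wires the solver internally.

```julia
using SteadyStateDiffEq, OrdinaryDiffEq

# Hypothetical toy dynamics du/dt = tanh.(W * u .+ b) .- u.
# Its steady state is the fixed point of u -> tanh.(W * u .+ b).
dudt(u, p, t) = tanh.(p.W * u .+ p.b) .- u
p = (; W = [0.1 0.2; -0.3 0.1], b = [0.5, -0.5])
prob = SteadyStateProblem(dudt, zeros(2), p)

# Integrate until du/dt is approximately zero.
# `VCAB3()` can be swapped in for `Tsit5()` here.
sol = solve(prob, DynamicSS(Tsit5()))

# Largest absolute derivative at the returned state.
steady_state_residual = maximum(abs, dudt(sol.u, p, 0.0))
```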

### Sensitivity Analysis

1. For `MultiScaleNeuralODE`, we default to `GaussAdjoint(; autojacvec = ZygoteVJP())`. A
faster alternative is `BacksolveAdjoint(; autojacvec = ZygoteVJP())`, but it can be
unstable. Follow the recommendations in the [SciMLSensitivity.jl](https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/#Choosing-a-Sensitivity-Algorithm) documentation.
2. For Steady State Problems, we default to
`SteadyStateAdjoint(; linsolve = SimpleGMRES(; blocksize), linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3))`,
where `blocksize` is the product of all but the last dimension of `u0`.
This default will perform poorly on small models. For small models, we recommend passing
`sensealg = SteadyStateAdjoint()` or
`sensealg = SteadyStateAdjoint(; linsolve = LUFactorization())`.
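
The two steady-state options can be written out as follows. This is a sketch that mirrors the default defined in the extension file of this commit, assuming LinearSolve.jl and SciMLSensitivity.jl are installed; the array `u0` and its `(features..., batch)` layout are hypothetical and chosen only to make the `blocksize` computation concrete.

```julia
using LinearSolve, SciMLSensitivity

# Hypothetical state, assumed to be laid out as (features..., batch):
# a 32x32x3 feature map with a batch size of 16.
u0 = zeros(Float32, 32, 32, 3, 16)

# Product of every dimension except the last one.
blocksize = prod(size(u0)[1:(end - 1)])

# Iterative default aimed at large models.
large_sensealg = SteadyStateAdjoint(;
    linsolve = SimpleGMRES{true}(; blocksize),
    linsolve_kwargs = (; maxiters = 10, abstol = 1e-3, reltol = 1e-3),
    autojacvec = ZygoteVJP())

# Direct factorization, the alternative suggested for small models.
small_sensealg = SteadyStateAdjoint(; linsolve = LUFactorization())
```

Either object would then be passed through the `sensealg` keyword mentioned above. The loose GMRES tolerances trade adjoint accuracy for speed, which is why a direct solve is the safer choice when the Jacobian is small enough to factorize.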

## Standard Models

```@docs
18 changes: 18 additions & 0 deletions ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl
@@ -0,0 +1,18 @@
module DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt

# Linear Solve is a dependency of SciMLSensitivity, so we only need to load SciMLSensitivity
# to load this extension
using LinearSolve, SciMLBase, SciMLSensitivity
import DeepEquilibriumNetworks: __default_sensealg

@inline function __default_sensealg(prob::SteadyStateProblem)
    # We want to avoid the cost for cache construction for linsolve = nothing
    # For small problems we should use concrete jacobian but we assume users want to solve
    # large problems with this package so we default to GMRES and avoid runtime dispatches
    linsolve = SimpleGMRES{true}(; blocksize=prod(size(prob.u0)[1:(end - 1)]))
    linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)
    return SteadyStateAdjoint(; linsolve, linsolve_kwargs, autojacvec=ZygoteVJP())
end
@inline __default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())

end
10 changes: 0 additions & 10 deletions ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl

This file was deleted.

2 comments on commit f055519

@avik-pal
Member Author


@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/97635

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.0.1 -m "<description of version>" f0555193c1a497920fcffe38b81352dbb16fed3e
git push origin v2.0.1
