Skip to content

Commit

Permalink
Fix the dispatch on polyalg
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 24, 2024
1 parent 7f41f1c commit ed53237
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 24 deletions.
6 changes: 3 additions & 3 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,9 @@ version = "0.1.6"

[[deps.GenericSchur]]
deps = ["LinearAlgebra", "Printf"]
git-tree-sha1 = "fb69b2a645fa69ba5f474af09221b9308b160ce6"
git-tree-sha1 = "af49a0851f8113fcfae2ef5027c6d49d0acec39b"
uuid = "c145ed77-6b09-5dd9-b285-bf645a82121e"
version = "0.5.3"
version = "0.5.4"

[[deps.Graphs]]
deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"]
Expand Down Expand Up @@ -608,7 +608,7 @@ version = "1.2.0"
[[deps.NonlinearSolve]]
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "TimerOutputs"]
git-tree-sha1 = "0e464ca0e5d44a88c91f394c3f9a9448523e378b"
repo-rev = "ap/tstable_findmin"
repo-rev = "master"
repo-url = "https://github.com/SciML/NonlinearSolve.jl.git"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
version = "3.8.2"
Expand Down
4 changes: 2 additions & 2 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function __solve_nlproblem!(

# NOTE: u_at_nodes is updated inplace
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(prob, alg.nlsolve)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, alg.nlsolve)
__solve(nlprob, nlsolve_alg; kwargs..., alias_u0 = true)

return nothing
Expand Down Expand Up @@ -188,7 +188,7 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_

# NOTE: u_at_nodes is updated inplace
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(prob, alg.nlsolve)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, alg.nlsolve)
__solve(nlprob, nlsolve_alg; kwargs..., alias_u0 = true)

return nothing
Expand Down
2 changes: 1 addition & 1 deletion src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
nlf = __unsafe_nonlinearfunction{iip}(
loss_fn; jac_prototype, resid_prototype, jac = jac_fn)
nlprob = __internal_nlsolve_problem(prob, resid_prototype, u0, nlf, vec(u0), prob.p)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(prob, alg.nlsolve)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, alg.nlsolve)
nlsol = __solve(nlprob, nlsolve_alg; nlsolve_kwargs..., verbose, kwargs...)

# There is no way to reinit with the same cache with different cache. But not saving
Expand Down
36 changes: 18 additions & 18 deletions test/shooting/nlls_tests.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
# FIXME: The nonlinear solve polyalgorithm for NLLS is currently broken because of Bastin
# Jv & Jᵀv computation with the cached ODE solve
@testitem "Overconstrained BVP" begin
using LinearAlgebra, JET

SOLVERS = [
# Shooting(Tsit5()),
SOLVERS = [Shooting(Tsit5()),
Shooting(
Tsit5(), LevenbergMarquardt(; autodiff = AutoForwardDiff(; chunksize = 2))), Shooting(
Tsit5(), LevenbergMarquardt(; autodiff = AutoFiniteDiff())),
Tsit5(), LevenbergMarquardt(; autodiff = AutoForwardDiff(; chunksize = 2))),
Shooting(Tsit5(), LevenbergMarquardt(; autodiff = AutoFiniteDiff())),
Shooting(Tsit5(), GaussNewton(; autodiff = AutoForwardDiff(; chunksize = 2))),
Shooting(Tsit5(), GaussNewton(; autodiff = AutoFiniteDiff())),
Shooting(Tsit5(), TrustRegion(; autodiff = AutoForwardDiff(; chunksize = 2))),
Shooting(Tsit5(), TrustRegion(; autodiff = AutoFiniteDiff())),
# MultipleShooting(10, Tsit5()),
MultipleShooting(10, Tsit5()),
MultipleShooting(
10, Tsit5(), LevenbergMarquardt(; autodiff = AutoForwardDiff(; chunksize = 2))), MultipleShooting(
10, Tsit5(), LevenbergMarquardt(; autodiff = AutoFiniteDiff())),
10, Tsit5(), LevenbergMarquardt(; autodiff = AutoForwardDiff(; chunksize = 2))),
MultipleShooting(10, Tsit5(), LevenbergMarquardt(; autodiff = AutoFiniteDiff())),
MultipleShooting(
10, Tsit5(), GaussNewton(; autodiff = AutoForwardDiff(; chunksize = 2))),
MultipleShooting(10, Tsit5(), GaussNewton(; autodiff = AutoFiniteDiff())),
MultipleShooting(
10, Tsit5(), TrustRegion(; autodiff = AutoForwardDiff(; chunksize = 2))),
MultipleShooting(10, Tsit5(), TrustRegion(; autodiff = AutoFiniteDiff()))]
JET_SKIP = fill(false, length(SOLVERS))
JET_BROKEN = fill(false, length(SOLVERS))
JET_OPT_BROKEN = fill(false, length(SOLVERS))
JET_CALL_BROKEN = fill(false, length(SOLVERS))
JET_CALL_BROKEN[1] = true
JET_CALL_BROKEN[8] = true

# OOP MP-BVP
f1(u, p, t) = [u[2], -u[1]]
Expand Down Expand Up @@ -52,11 +52,11 @@
@test_opt target_modules=(
SciMLBase, DiffEqBase, NonlinearSolve, BoundaryValueDiffEq) solve(
bvp1, solver; verbose = false, abstol = 1e-6, reltol = 1e-6,
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_BROKEN[i]
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_OPT_BROKEN[i]
@test_call target_modules=(
SciMLBase, DiffEqBase, NonlinearSolve, BoundaryValueDiffEq) solve(
bvp1, solver; verbose = false, abstol = 1e-6, reltol = 1e-6,
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_BROKEN[i]
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_CALL_BROKEN[i]
end

# IIP MP-BVP
Expand Down Expand Up @@ -91,11 +91,11 @@
@test_opt target_modules=(
SciMLBase, DiffEqBase, NonlinearSolve, BoundaryValueDiffEq) solve(
bvp2, solver; verbose = false, abstol = 1e-6, reltol = 1e-6,
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_BROKEN[i]
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_OPT_BROKEN[i]
@test_call target_modules=(
SciMLBase, DiffEqBase, NonlinearSolve, BoundaryValueDiffEq) solve(
bvp2, solver; verbose = false, abstol = 1e-6, reltol = 1e-6,
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_BROKEN[i]
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_CALL_BROKEN[i]
end

# OOP TP-BVP
Expand All @@ -118,11 +118,11 @@
@test_opt target_modules=(
SciMLBase, DiffEqBase, NonlinearSolve, BoundaryValueDiffEq) solve(
bvp3, solver; verbose = false, abstol = 1e-6, reltol = 1e-6,
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_BROKEN[i]
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_OPT_BROKEN[i]
@test_call target_modules=(
SciMLBase, DiffEqBase, NonlinearSolve, BoundaryValueDiffEq) solve(
bvp3, solver; verbose = false, abstol = 1e-6, reltol = 1e-6,
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_BROKEN[i]
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_CALL_BROKEN[i]
end

# IIP TP-BVP
Expand All @@ -145,10 +145,10 @@
@test_opt target_modules=(
SciMLBase, DiffEqBase, NonlinearSolve, BoundaryValueDiffEq) solve(
bvp4, solver; verbose = false, abstol = 1e-6, reltol = 1e-6,
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_BROKEN[i]
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_OPT_BROKEN[i]
@test_call target_modules=(
SciMLBase, DiffEqBase, NonlinearSolve, BoundaryValueDiffEq) solve(
bvp4, solver; verbose = false, abstol = 1e-6, reltol = 1e-6,
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_BROKEN[i]
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)) broken=JET_CALL_BROKEN[i]
end
end

0 comments on commit ed53237

Please sign in to comment.