Skip to content

Commit

Permalink
Remove the static arrays special casing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 26, 2024
1 parent ae0bf10 commit fb2f20d
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 32 deletions.
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
SimpleNonlinearSolveTrackerExt = "Tracker"
SimpleNonlinearSolveZygoteExt = "Zygote"

Expand All @@ -40,7 +39,7 @@ AllocCheck = "0.1.1"
Aqua = "0.8"
ArrayInterface = "7.9"
CUDA = "5.2"
ChainRulesCore = "1.22"
ChainRulesCore = "1.23"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.149"
DiffResults = "1.1"
Expand All @@ -59,13 +58,14 @@ PrecompileTools = "1.2"
Random = "1.10"
ReTestItems = "1.23"
Reexport = "1.2"
ReverseDiff = "1.15"
ReverseDiff = "1.15.3"
SciMLBase = "2.37.0"
SciMLSensitivity = "7.58"
Setfield = "1.1.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.2"
Test = "1.10"
Tracker = "0.2.32"
Tracker = "0.2.33"
Zygote = "0.6.69"
julia = "1.10"

Expand Down
7 changes: 0 additions & 7 deletions ext/SimpleNonlinearSolveStaticArraysExt.jl

This file was deleted.

1 change: 1 addition & 0 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
end

Expand Down
6 changes: 4 additions & 2 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,12 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
end
else
# For small problems, nesting ForwardDiff is actually quite fast
_f = Base.Fix2(prob.f, newprob.p)
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) 50)
# TODO: Remove once DI has the value_and_pullback_split defined
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(_f, u, p)
_F = @closure (u, p) -> begin
_f = Base.Fix2(prob.f, p)
return __zygote_compute_nlls_vjp(_f, u, p)

Check warning on line 116 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L114-L116

Added lines #L114 - L116 were not covered by tests
end
else
_F = @closure (u, p) -> begin
_f = Base.Fix2(prob.f, p)
Expand Down
15 changes: 9 additions & 6 deletions src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
α_1 = one(T)
f_1 = fx_norm

history_f_k = if x isa SArray ||
(x isa Number && __is_extension_loaded(Val(:StaticArrays)))
ones(SVector{M, T}) * fx_norm
else
fill(fx_norm, M)
end
history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm :
__history_vec(fx_norm, Val(M))

# Generate the cache
@bb x_cache = similar(x)
Expand Down Expand Up @@ -150,6 +146,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
# Store function value
if history_f_k isa SVector
history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M))
elseif history_f_k isa NTuple
@set! history_f_k[mod1(k, M)] = fx_norm_new
else
history_f_k[mod1(k, M)] = fx_norm_new
end
Expand All @@ -158,3 +156,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...

return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end

@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M}

Check warning on line 160 in src/nlsolve/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/nlsolve/dfsane.jl#L160

Added line #L160 was not covered by tests
M 11 && return :(fill(fx_norm, M)) # Julia can't specialize here
return :(ntuple(Returns(fx_norm), $(M)))
end
14 changes: 7 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,15 @@ function value_and_jacobian(

if isinplace(prob)
if cache isa HasAnalyticJacobian
prob.f.jac(J, x, p)
prob.f.jac(J, x, prob.p)
f(y, x)
else
DI.jacobian!(f, y, J, ad, x, cache)
return y, J
end
return y, J
return DI.value_and_jacobian!(f, y, J, ad, x, cache)
else
cache isa HasAnalyticJacobian && return f(x), prob.f.jac(x, prob.p)
J === nothing && return DI.value_and_jacobian(f, ad, x, cache)
y, _ = DI.value_and_jacobian!(f, J, ad, x, cache)
y, J = DI.value_and_jacobian!(f, J, ad, x, cache)
return y, J
end
end
Expand All @@ -63,8 +62,9 @@ end
function compute_jacobian_and_hessian(
ad, prob::AbstractNonlinearProblem, f::F, y, x) where {F}
if x isa Number
df = @closure x -> DI.derivative(f, ad, x)
return f(x), df(x), DI.derivative(df, ad, x)
H = DI.second_derivative(f, ad, x)
v, J = DI.value_and_derivative(f, ad, x)
return v, J, H
end

if isinplace(prob)
Expand Down
9 changes: 4 additions & 5 deletions test/core/rootfind_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,14 @@ end

export quadratic_f, quadratic_f!, quadratic_f2, newton_fails, TERMINATION_CONDITIONS,
benchmark_nlsolve_oop, benchmark_nlsolve_iip

end

@testitem "First Order Methods" setup=[RootfindingTesting] tags=[:core] begin
@testset "$(alg)" for alg in (SimpleNewtonRaphson,
SimpleTrustRegion,
(args...; kwargs...) -> SimpleTrustRegion(
args...; nlsolve_update_rule = Val(true), kwargs...))
@testset "AutoDiff: $(nameof(typeof(autodiff))))" for autodiff in (
@testset "AutoDiff: $(nameof(typeof(autodiff)))" for autodiff in (
AutoFiniteDiff(), AutoForwardDiff(), AutoPolyesterForwardDiff())
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
[1.0, 1.0], @SVector[1.0, 1.0], 1.0)
Expand All @@ -59,7 +58,7 @@ end
end
end

@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand All @@ -79,7 +78,7 @@ end
end
end

@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand All @@ -104,7 +103,7 @@ end
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
end

@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down

0 comments on commit fb2f20d

Please sign in to comment.