Skip to content

Commit

Permalink
Remove the static arrays special casing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 26, 2024
1 parent ae0bf10 commit 7704d3d
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 19 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
SimpleNonlinearSolveTrackerExt = "Tracker"
SimpleNonlinearSolveZygoteExt = "Zygote"

Expand Down Expand Up @@ -62,6 +61,7 @@ Reexport = "1.2"
ReverseDiff = "1.15"
SciMLBase = "2.37.0"
SciMLSensitivity = "7.58"
Setfield = "1.1.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.2"
Test = "1.10"
Expand Down
7 changes: 0 additions & 7 deletions ext/SimpleNonlinearSolveStaticArraysExt.jl

This file was deleted.

1 change: 1 addition & 0 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
end

Expand Down
16 changes: 10 additions & 6 deletions src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
α_1 = one(T)
f_1 = fx_norm

history_f_k = if x isa SArray ||
(x isa Number && __is_extension_loaded(Val(:StaticArrays)))
ones(SVector{M, T}) * fx_norm
else
fill(fx_norm, M)
end
history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm :
__history_vec(fx_norm, Val(M))

# Generate the cache
@bb x_cache = similar(x)
Expand Down Expand Up @@ -150,6 +146,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
# Store function value
if history_f_k isa SVector
history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M))
elseif history_f_k isa NTuple
@set! history_f_k[mod1(k, M)] = fx_norm_new
else
history_f_k[mod1(k, M)] = fx_norm_new
end
Expand All @@ -158,3 +156,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...

return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end

@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M}

Check warning on line 160 in src/nlsolve/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/nlsolve/dfsane.jl#L160

Added line #L160 was not covered by tests
# Julia can't specialize here
M 11 && return :(fill(fx_norm, M))
return :(ntuple(Returns(fx_norm), $(M)))
end
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function value_and_jacobian(

if isinplace(prob)
if cache isa HasAnalyticJacobian
prob.f.jac(J, x, p)
prob.f.jac(J, x, prob.p)
f(y, x)

Check warning on line 36 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
else
DI.jacobian!(f, y, J, ad, x, cache)
Expand Down
6 changes: 3 additions & 3 deletions test/core/rootfind_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
end
end

@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand All @@ -79,7 +79,7 @@ end
end
end

@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand All @@ -104,7 +104,7 @@ end
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
end

@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down

0 comments on commit 7704d3d

Please sign in to comment.