Skip to content

Commit

Permalink
Use macro for shared caches
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 14, 2024
1 parent dccc1dd commit 75f1874
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 78 deletions.
6 changes: 4 additions & 2 deletions src/algorithms/multistep.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing,

Check warning on line 1 in src/algorithms/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/multistep.jl#L1

Added line #L1 was not covered by tests
scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing,
vjp_autodiff = nothing, linesearch = NoLineSearch())
scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff))
forward_ad = ifelse(autodiff isa ADTypes.AbstractForwardMode, autodiff, nothing)
scheme_concrete = apply_patch(

Check warning on line 5 in src/algorithms/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/multistep.jl#L4-L5

Added lines #L4 - L5 were not covered by tests
scheme, (; autodiff, vjp_autodiff, jvp_autodiff = forward_ad))
descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs)
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme),

Check warning on line 8 in src/algorithms/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/multistep.jl#L7-L8

Added lines #L7 - L8 were not covered by tests
descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff)
descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff, forward_ad)
end
6 changes: 1 addition & 5 deletions src/descent/damped_newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,7 @@ function __internal_init(
shared::Val{N} = Val(1), kwargs...) where {INV, N}
length(fu) != length(u) &&
@assert !INV "Precomputed Inverse for Non-Square Jacobian doesn't make sense."
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end

δu, δus = @shared_caches N (@bb δu = similar(u))
normal_form_damping = returns_norm_form_damping(alg.damping_fn)
normal_form_linsolve = __needs_square_A(alg.linsolve, u)
if u isa Number
Expand Down
5 changes: 1 addition & 4 deletions src/descent/dogleg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u;
linsolve_kwargs, abstol, reltol, shared, kwargs...)
cauchy_cache = __internal_init(prob, alg.steepest_descent, J, fu, u; pre_inverted,
linsolve_kwargs, abstol, reltol, shared, kwargs...)
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
δu, δus = @shared_caches N (@bb δu = similar(u))
@bb δu_cache_1 = similar(u)
@bb δu_cache_2 = similar(u)
@bb δu_cache_mul = similar(u)
Expand Down
5 changes: 1 addition & 4 deletions src/descent/geodesic_acceleration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::GeodesicAccelerati
abstol = nothing, reltol = nothing, internalnorm::F = DEFAULT_NORM,
kwargs...) where {INV, N, F}
T = promote_type(eltype(u), eltype(fu))
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
δu, δus = @shared_caches N (@bb δu = similar(u))
descent_cache = __internal_init(prob, alg.descent, J, fu, u; shared = Val(N * 2),
pre_inverted, linsolve_kwargs, abstol, reltol, kwargs...)
@bb Jv = similar(fu)
Expand Down
56 changes: 24 additions & 32 deletions src/descent/multistep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,36 @@ function Base.show(io::IO, mss::AbstractMultiStepScheme)
print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])")

Check warning on line 15 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L14-L15

Added lines #L14 - L15 were not covered by tests
end

alg_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = alg_steps(T())
newton_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = newton_steps(T())

Check warning on line 18 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L18

Added line #L18 was not covered by tests

struct __PotraPtak3 <: AbstractMultiStepScheme end
const PotraPtak3 = __PotraPtak3()

alg_steps(::__PotraPtak3) = 2
newton_steps(::__PotraPtak3) = 2
nintermediates(::__PotraPtak3) = 1

Check warning on line 24 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L23-L24

Added lines #L23 - L24 were not covered by tests

@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
jvp_autodiff = nothing
end
const SinghSharma4 = __SinghSharma4()

alg_steps(::__SinghSharma4) = 3
newton_steps(::__SinghSharma4) = 4
nintermediates(::__SinghSharma4) = 2

Check warning on line 32 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L31-L32

Added lines #L31 - L32 were not covered by tests

@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
jvp_autodiff = nothing
end
const SinghSharma5 = __SinghSharma5()

alg_steps(::__SinghSharma5) = 3
newton_steps(::__SinghSharma5) = 4
nintermediates(::__SinghSharma5) = 2

Check warning on line 40 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L39-L40

Added lines #L39 - L40 were not covered by tests

@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
jvp_autodiff = nothing
end
const SinghSharma7 = __SinghSharma7()

alg_steps(::__SinghSharma7) = 4
newton_steps(::__SinghSharma7) = 6

Check warning on line 47 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L47

Added line #L47 was not covered by tests

@generated function display_name(alg::T) where {T <: AbstractMultiStepScheme}
res = Symbol(first(split(last(split(string(T), ".")), "{"; limit = 2))[3:end])
Expand Down Expand Up @@ -75,6 +77,8 @@ supports_trust_region(::GenericMultiStepDescent) = false
fus
internal_cache
internal_caches
extra
extras
scheme::S
timer
nf::Int
Expand All @@ -91,49 +95,37 @@ function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = ca
end

function __internal_multistep_caches(

Check warning on line 97 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L97

Added line #L97 was not covered by tests
scheme::MSS.__PotraPtak3, alg::GenericMultiStepDescent,
prob, args...; shared::Val{N} = Val(1), kwargs...) where {N}
scheme::Union{MSS.__PotraPtak3, MSS.__SinghSharma4, MSS.__SinghSharma5},
alg::GenericMultiStepDescent, prob, args...;
shared::Val{N} = Val(1), kwargs...) where {N}
internal_descent = NewtonDescent(; alg.linsolve, alg.precs)
internal_cache = __internal_init(
return @shared_caches N __internal_init(

Check warning on line 102 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L101-L102

Added lines #L101 - L102 were not covered by tests
prob, internal_descent, args...; kwargs..., shared = Val(2))
internal_caches = N 1 ? nothing :
map(2:N) do i
__internal_init(prob, internal_descent, args...; kwargs..., shared = Val(2))
end
return internal_cache, internal_caches
end

__extras_cache(::MSS.AbstractMultiStepScheme, args...; kwargs...) = nothing, nothing

Check warning on line 106 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L106

Added line #L106 was not covered by tests

function __internal_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},

Check warning on line 108 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L108

Added line #L108 was not covered by tests
alg::GenericMultiStepDescent, J, fu, u; shared::Val{N} = Val(1),
pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
abstol = nothing, reltol = nothing, timer = get_timer_output(),
kwargs...) where {INV, N}
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
fu_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
δu, δus = @shared_caches N (@bb δu = similar(u))
fu_cache, fus_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i
@bb xx = similar(fu)

Check warning on line 115 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L113-L115

Added lines #L113 - L115 were not covered by tests
end
fus_cache = N 1 ? nothing : map(2:N) do i
ntuple(MSS.nintermediates(alg.scheme)) do j
@bb xx = similar(fu)
end
end
u_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
end)
u_cache, us_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i
@bb xx = similar(u)

Check warning on line 118 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L117-L118

Added lines #L117 - L118 were not covered by tests
end
us_cache = N 1 ? nothing : map(2:N) do i
ntuple(MSS.nintermediates(alg.scheme)) do j
@bb xx = similar(u)
end
end
end)
internal_cache, internal_caches = __internal_multistep_caches(

Check warning on line 120 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L120

Added line #L120 was not covered by tests
alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
abstol, reltol, timer, kwargs...)
extra, extras = __extras_cache(

Check warning on line 123 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L123

Added line #L123 was not covered by tests
alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
abstol, reltol, timer, kwargs...)
return GenericMultiStepDescentCache(

Check warning on line 126 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L126

Added line #L126 was not covered by tests
prob.f, prob.p, δu, δus, u_cache, us_cache, fu_cache, fus_cache,
internal_cache, internal_caches, alg.scheme, timer, 0)
internal_cache, internal_caches, extra, extras, alg.scheme, timer, 0)
end

function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J,

Check warning on line 131 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L131

Added line #L131 was not covered by tests
Expand Down
10 changes: 2 additions & 8 deletions src/descent/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ function __internal_init(prob::NonlinearProblem, alg::NewtonDescent, J, fu, u;
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
abstol = nothing, reltol = nothing, timer = get_timer_output(),
kwargs...) where {INV, N}
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
δu, δus = @shared_caches N (@bb δu = similar(u))
INV && return NewtonDescentCache{true, false}(δu, δus, nothing, nothing, nothing, timer)
lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol,
linsolve_kwargs...)
Expand All @@ -64,10 +61,7 @@ function __internal_init(prob::NonlinearLeastSquaresProblem, alg::NewtonDescent,
end
lincache = LinearSolverCache(alg, alg.linsolve, A, b, _vec(u); abstol, reltol,
linsolve_kwargs...)
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
δu, δus = @shared_caches N (@bb δu = similar(u))
return NewtonDescentCache{false, normal_form}(δu, δus, lincache, JᵀJ, Jᵀfu, timer)
end

Expand Down
5 changes: 1 addition & 4 deletions src/descent/steepest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ end
linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
timer = get_timer_output(), kwargs...) where {INV, N}
INV && @assert length(fu)==length(u) "Non-Square Jacobian Inverse doesn't make sense."
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
δu, δus = @shared_caches N (@bb δu = similar(u))
if INV
lincache = LinearSolverCache(alg, alg.linsolve, transpose(J), _vec(fu), _vec(u);
abstol, reltol, linsolve_kwargs...)
Expand Down
38 changes: 38 additions & 0 deletions src/internal/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,41 @@ function __internal_caches(__source__, __module__, cType, internal_cache_names::
end
end)
end

"""
apply_patch(scheme, patch::NamedTuple{names})
Applies the patch to the scheme, returning the new scheme. If some of the `names` are not,
present in the scheme, they are ignored.
"""
@generated function apply_patch(scheme, patch::NamedTuple{names}) where {names}
exprs = []
for name in names
hasfield(scheme, name) || continue
push!(exprs, quote
lens = PropertyLens{$(Meta.quot(name))}()
return set(scheme, lens, getfield(patch, $(Meta.quot(name))))

Check warning on line 284 in src/internal/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/helpers.jl#L278-L284

Added lines #L278 - L284 were not covered by tests
end)
end
push!(exprs, :(return scheme))
return Expr(:block, exprs...)

Check warning on line 288 in src/internal/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/helpers.jl#L286-L288

Added lines #L286 - L288 were not covered by tests
end

"""
@shared_caches N expr
Create a shared cache and a vector of caches. If `N` is 1, then the vector of caches is
`nothing`.
"""
macro shared_caches(N, expr)
@gensym cache caches
return esc(quote
begin

Check warning on line 300 in src/internal/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/helpers.jl#L300

Added line #L300 was not covered by tests
$(cache) = $(expr)
$(caches) = $(N) 1 ? nothing : map(2:($(N))) do i
$(expr)
end
($cache, $caches)
end
end)
end
19 changes: 0 additions & 19 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,3 @@ Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the i
"""
@inline pickchunksize(x) = pickchunksize(length(x))
@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)

"""
apply_patch(scheme, patch::NamedTuple{names})
Applies the patch to the scheme, returning the new scheme. If some of the `names` are not,
present in the scheme, they are ignored.
"""
@generated function apply_patch(scheme, patch::NamedTuple{names}) where {names}
exprs = []
for name in names
hasfield(scheme, name) || continue
push!(exprs, quote
lens = PropertyLens{$(Meta.quot(name))}()
return set(scheme, lens, getfield(patch, $(Meta.quot(name))))
end)
end
push!(exprs, :(return scheme))
return Expr(:block, exprs...)
end

0 comments on commit 75f1874

Please sign in to comment.