Skip to content

Commit

Permalink
Merge pull request #696 from AayushSabharwal/as/bv-diffeq-internals
Browse files Browse the repository at this point in the history
refactor: move type-pirated function from BoundaryValueDiffEq here, use Accessors.jl
  • Loading branch information
ChrisRackauckas committed Jun 5, 2024
2 parents 9185795 + 17a165c commit 8706b5c
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 86 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "2.39.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down Expand Up @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote"

[compat]
ADTypes = "0.2.5,1.0.0"
Accessors = "0.1.36"
ArrayInterface = "7.6"
ChainRules = "1.58.0"
ChainRulesCore = "1.18"
Expand All @@ -76,7 +78,7 @@ PyCall = "1.96"
PythonCall = "0.9.15"
RCall = "0.14.0"
RecipesBase = "1.3.4"
RecursiveArrayTools = "3.14.0"
RecursiveArrayTools = "3.22.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
Expand Down
10 changes: 1 addition & 9 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,7 @@ end
du, dprob
end
T = eltype(eltype(VA.u))
if dprob.u0 === nothing
N = 2
elseif dprob isa SciMLBase.BVProblem && !hasmethod(size, Tuple{typeof(dprob.u0)})
__u0 = hasmethod(dprob.u0, Tuple{typeof(dprob.p), typeof(first(dprob.tspan))}) ?
dprob.u0(dprob.p, first(dprob.tspan)) : dprob.u0(first(dprob.tspan))
N = length((size(__u0)..., length(du)))
else
N = length((size(dprob.u0)..., length(du)))
end
N = ndims(VA)
Δ′ = ODESolution{T, N}(du, nothing, nothing,
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
Expand Down
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import FunctionWrappersWrappers
import RuntimeGeneratedFunctions
import EnumX
import ADTypes: AbstractADType
import Accessors: @set, @reset

using Reexport
using SciMLOperators
Expand Down
6 changes: 4 additions & 2 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,8 @@ function check_error(integrator::DEIntegrator)
@warn("dt($(integrator.dt)) <= dtmin($(opts.dtmin)) at t=$(integrator.t)$EEst. Aborting. There is either an error in your model specification or the true solution is unstable.")
end
return ReturnCode.DtLessThanMin
elseif !step_accepted && integrator.t isa AbstractFloat && abs(integrator.dt) <= abs(eps(integrator.t))
elseif !step_accepted && integrator.t isa AbstractFloat &&
abs(integrator.dt) <= abs(eps(integrator.t))
if verbose
if isdefined(integrator, :EEst)
EEst = ", and step error estimate = $(integrator.EEst)"
Expand All @@ -634,7 +635,8 @@ function check_error(integrator::DEIntegrator)
return ReturnCode.Unstable
end
end
if step_accepted && opts.unstable_check(integrator.dt, integrator.u, integrator.p, integrator.t)
if step_accepted &&
opts.unstable_check(integrator.dt, integrator.u, integrator.p, integrator.t)
if verbose
@warn("Instability detected. Aborting")
end
Expand Down
101 changes: 34 additions & 67 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S,
original::O
end

function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution{T, N}}
ODESolution{T, N}
end

function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple)
u = get(patch, :u, sol.u)
N = u === nothing ? 2 : ndims(eltype(u)) + 1
T = eltype(eltype(u))
patch = merge(getproperties(sol), patch)
return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k,
patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
patch.alg_choice, patch.retcode, patch.resid, patch.original)
end

Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol)
if s === :destats
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
Expand Down Expand Up @@ -272,7 +286,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
prob.u0(prob.p, first(prob.tspan)) : prob.u0(first(prob.tspan))
N = length((size(__u0)..., length(u)))
else
N = length((size(prob.u0)..., length(u)))
N = ndims(eltype(u)) + 1
end

if prob.f isa Tuple
Expand Down Expand Up @@ -372,75 +386,31 @@ function calculate_solution_errors!(sol::AbstractODESolution; fill_uanalytic = t
end

function build_solution(sol::ODESolution{T, N}, u_analytic, errors) where {T, N}
ODESolution{T, N}(sol.u,
u_analytic,
errors,
sol.t,
sol.k,
sol.prob,
sol.alg,
sol.interp,
sol.dense,
sol.tslocation,
sol.stats,
sol.alg_choice,
sol.retcode,
sol.resid,
sol.original)
@reset sol.u_analytic = u_analytic
return @set sol.errors = errors
end

function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N}
ODESolution{T, N}(sol.u,
sol.u_analytic,
sol.errors,
sol.t,
sol.k,
sol.prob,
sol.alg,
sol.interp,
sol.dense,
sol.tslocation,
sol.stats,
sol.alg_choice,
retcode,
sol.resid,
sol.original)
return @set sol.retcode = retcode
end

function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N}
ODESolution{T, N}(sol.u,
sol.u_analytic,
sol.errors,
sol.t,
sol.k,
sol.prob,
sol.alg,
sol.interp,
sol.dense,
tslocation,
sol.stats,
sol.alg_choice,
sol.retcode,
sol.resid,
sol.original)
return @set sol.tslocation = tslocation
end

function solution_new_original_retcode(
sol::ODESolution{T, N}, original, retcode, resid) where {T, N}
@reset sol.original = original
@reset sol.retcode = retcode
return @set sol.resid = resid
end

function solution_slice(sol::ODESolution{T, N}, I) where {T, N}
ODESolution{T, N}(sol.u[I],
sol.u_analytic === nothing ? nothing : sol.u_analytic[I],
sol.errors,
sol.t[I],
sol.dense ? sol.k[I] : sol.k,
sol.prob,
sol.alg,
sol.interp,
false,
sol.tslocation,
sol.stats,
sol.alg_choice,
sol.retcode,
sol.resid,
sol.original)
@reset sol.u = sol.u[I]
@reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I]
@reset sol.t = sol.t[I]
@reset sol.k = sol.dense ? sol.k[I] : sol.k
return @set sol.alg = false
end

function sensitivity_solution(sol::ODESolution, u, t)
Expand All @@ -455,10 +425,7 @@ function sensitivity_solution(sol::ODESolution, u, t)
end

interp = enable_interpolation_sensitivitymode(sol.interp)
ODESolution{T, N}(u, sol.u_analytic, sol.errors,
t isa Vector ? t : collect(t),
sol.k, sol.prob,
sol.alg, interp,
sol.dense, sol.tslocation,
sol.stats, sol.alg_choice, sol.retcode, sol.resid, sol.original)
@reset sol.u = u
@reset sol.t = t isa Vector ? t : collect(t)
return @set sol.interp = interp
end
3 changes: 3 additions & 0 deletions test/aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ using Test
using SciMLBase
using Aqua

# https://github.com/JuliaArrays/FillArrays.jl/pull/163
@test_broken isempty(detect_ambiguities(SciMLBase))

@testset "Aqua tests (performance)" begin
# This tests that we don't accidentally run into
# https://github.com/JuliaLang/julia/issues/29393
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -35,5 +36,6 @@ SciMLSensitivity = "7.11"
SciMLStructures = "1.1"
Sundials = "4.11"
SymbolicIndexingInterface = "0.3"
SymbolicUtils = "<1.6"
Unitful = "1.12"
Zygote = "0.6"
7 changes: 4 additions & 3 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ x_val = vcat.(getindex.((sol,), x_idx, :)...)
y_val = sol[y_idx, :]
obs_val = sol[x[1] + y]

# checking inference for mixed-type arrays will always fail
# don't check inference for weird cases of nested arrays/tuples
for (sym, val, check_inference) in [
(x, x_val, true),
(y, y_val, true),
Expand All @@ -254,7 +254,7 @@ for (sym, val, check_inference) in [
((x, x), [(i, i) for i in x_val], true),
((x, x_idx), [(i, i) for i in x_val], true),
((x, x[1] + y), [(i, j) for (i, j) in zip(x_val, obs_val)], true),
((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], true),
((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false),
([x, [x[1] + y, y]], [[i, [k, j]] for (i, j, k) in zip(x_val, y_val, obs_val)], false),
((x, [x[1] + y, y], (x[1] + y, y_idx)),
[(i, [k, j], (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false),
Expand Down Expand Up @@ -311,6 +311,7 @@ end
pval = [1.0, 2.0, 3.0]
pval_new = [4.0, 5.0, 6.0]

# don't check inference for nested tuples/arrays
for (sym, oldval, newval, check_inference) in [
(p[1], pval[1], pval_new[1], true),
(p, pval, pval_new, true),
Expand All @@ -319,7 +320,7 @@ for (sym, oldval, newval, check_inference) in [
((p[1], p[2:3]), (pval[1], pval[2:3]), (pval_new[1], pval_new[2:3]), true),
([p[1], p[2:3]], [pval[1], pval[2:3]], [pval_new[1], pval_new[2:3]], false),
((p[1], (p[2],), [p[3]]), (pval[1], (pval[2],), [pval[3]]),
(pval_new[1], (pval_new[2],), [pval_new[3]]), true),
(pval_new[1], (pval_new[2],), [pval_new[3]]), false),
([p[1], (p[2],), [p[3]]], [pval[1], (pval[2],), [pval[3]]],
[pval_new[1], (pval_new[2],), [pval_new[3]]], false)
]
Expand Down
4 changes: 0 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
using Pkg
using SafeTestsets
using Test
using SciMLBase

# https://github.com/JuliaArrays/FillArrays.jl/pull/163
@test_broken isempty(detect_ambiguities(SciMLBase))

const GROUP = get(ENV, "GROUP", "All")
const is_APPVEYOR = (Sys.iswindows() && haskey(ENV, "APPVEYOR"))
Expand Down

0 comments on commit 8706b5c

Please sign in to comment.