Skip to content

Commit

Permalink
fix: fix incorrect dimensionality of ODESolution in `build_function…
Browse files Browse the repository at this point in the history
…` and `@set`
  • Loading branch information
AayushSabharwal committed Jun 4, 2024
1 parent 624b63d commit 317adb1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
10 changes: 1 addition & 9 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,7 @@ end
du, dprob
end
T = eltype(eltype(VA.u))
if dprob.u0 === nothing
N = 2
elseif dprob isa SciMLBase.BVProblem && !hasmethod(size, Tuple{typeof(dprob.u0)})
__u0 = hasmethod(dprob.u0, Tuple{typeof(dprob.p), typeof(first(dprob.tspan))}) ?
dprob.u0(dprob.p, first(dprob.tspan)) : dprob.u0(first(dprob.tspan))
N = length((size(__u0)..., length(du)))
else
N = length((size(dprob.u0)..., length(du)))
end
N = ndims(VA)
Δ′ = ODESolution{T, N}(du, nothing, nothing,
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
Expand Down
12 changes: 11 additions & 1 deletion src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution
ODESolution{T, N}

Check warning on line 129 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L128-L129

Added lines #L128 - L129 were not covered by tests
end

function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple)
u = get(patch, :u, sol.u)
N = u === nothing ? 2 : ndims(eltype(u)) + 1
T = eltype(eltype(u))
patch = merge(getproperties(sol), patch)
return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k,
patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
patch.alg_choice, patch.retcode, patch.resid, patch.original)
end

Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol)
if s === :destats
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
Expand Down Expand Up @@ -276,7 +286,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
prob.u0(prob.p, first(prob.tspan)) : prob.u0(first(prob.tspan))
N = length((size(__u0)..., length(u)))
else
N = length((size(prob.u0)..., length(u)))
N = ndims(eltype(u)) + 1
end

if prob.f isa Tuple
Expand Down

0 comments on commit 317adb1

Please sign in to comment.