diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index bdba437e8..0c1047c35 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -45,7 +45,9 @@ tighten_container_eltype(u) = u function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}}, alg::Union{AbstractDEAlgorithm, Nothing}, ensemblealg::BasicEnsembleAlgorithm; kwargs...) - solve(prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...) + # TODO: @invoke + invoke(__solve, Tuple{AbstractEnsembleProblem, typeof(alg), typeof(ensemblealg)}, + prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...) end function __solve(prob::AbstractEnsembleProblem, diff --git a/src/ensemble/ensemble_problems.jl b/src/ensemble/ensemble_problems.jl index 4a25e572a..65cb5ba78 100644 --- a/src/ensemble/ensemble_problems.jl +++ b/src/ensemble/ensemble_problems.jl @@ -15,7 +15,8 @@ DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false) DEFAULT_REDUCTION(u, data, I) = append!(u, data), false DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i] function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...) - EnsembleProblem(prob; kwargs..., prob_func=DEFAULT_VECTOR_PROB_FUNC) + # TODO: @invoke + invoke(EnsembleProblem, Tuple{Any}, prob; prob_func=DEFAULT_VECTOR_PROB_FUNC, kwargs...) end function EnsembleProblem(prob; output_func = DEFAULT_OUTPUT_FUNC, diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index 323e86423..a5fe82714 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -197,3 +197,19 @@ end end end end + + +Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s) + return [xi[s] for xi in x] +end + +Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...) + return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...) +end +Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...) + return [xi[args...] for xi in x] +end + +function (sol::AbstractEnsembleSolution)(args...; kwargs...) + [s(args...; kwargs...) for s in sol] +end diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index cc6040581..e8e7640ed 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -1,8 +1,19 @@ -using OrdinaryDiffEq, Test +using ModelingToolkit, OrdinaryDiffEq, Test + +@variables t, x(t) +D = Differential(t) + +@named sys1 = ODESystem([D(x) ~ 1.1*x]) +@named sys2 = ODESystem([D(x) ~ 1.2*x]) + +prob1 = ODEProblem(sys1, [2.0], (0.0, 1.0)) +prob2 = ODEProblem(sys2, [1.0], (0.0, 1.0)) -prob1 = ODEProblem((u, p, t) -> 0.99u, 0.55, (0.0, 1.1)) -prob1 = ODEProblem((u, p, t) -> 1.0u, 0.45, (0.0, 0.9)) -output_func(sol, i) = (last(sol), false) # test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately -ensemble_prob = EnsembleProblem([prob1, prob2], output_func = output_func) -sim = solve(ensemble_prob, Tsit5(), EnsembleThreads()) +ensemble_prob = EnsembleProblem([prob1, prob2]) +sol = solve(ensemble_prob, Tsit5(), EnsembleThreads()) +@test isapprox(sol[:, x], [2,1] .* map(Base.Fix1(map, exp), [1.1, 1.2] .* sol[:, t]), rtol=1e-4) +# Ensemble is a recursive array +@test sol(0.0, idxs=[x]) == sol[:, 1] == first.(sol[:, x], 1) +# TODO: fix the interpolation +@test sol(1.0, idxs=[x]) ≈ last.(sol[:, x], 1)