Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ tighten_container_eltype(u) = u
function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}},
alg::Union{AbstractDEAlgorithm, Nothing},
ensemblealg::BasicEnsembleAlgorithm; kwargs...)
solve(prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...)
# TODO: @invoke
invoke(__solve, Tuple{AbstractEnsembleProblem, typeof(alg), typeof(ensemblealg)},
prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...)
end

function __solve(prob::AbstractEnsembleProblem,
Expand Down
3 changes: 2 additions & 1 deletion src/ensemble/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false)
DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
EnsembleProblem(prob; kwargs..., prob_func=DEFAULT_VECTOR_PROB_FUNC)
# TODO: @invoke
invoke(EnsembleProblem, Tuple{Any}, prob; prob_func=DEFAULT_VECTOR_PROB_FUNC, kwargs...)
end
function EnsembleProblem(prob;
output_func = DEFAULT_OUTPUT_FUNC,
Expand Down
16 changes: 16 additions & 0 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,19 @@ end
end
end
end


Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...)
return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...)
end
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...)
return [xi[args...] for xi in x]
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
[s(args...; kwargs...) for s in sol]
end
23 changes: 17 additions & 6 deletions test/downstream/ensemble_multi_prob.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
using OrdinaryDiffEq, Test
using ModelingToolkit, OrdinaryDiffEq, Test

@variables t, x(t)
D = Differential(t)

@named sys1 = ODESystem([D(x) ~ 1.1*x])
@named sys2 = ODESystem([D(x) ~ 1.2*x])

prob1 = ODEProblem(sys1, [2.0], (0.0, 1.0))
prob2 = ODEProblem(sys2, [1.0], (0.0, 1.0))

prob1 = ODEProblem((u, p, t) -> 0.99u, 0.55, (0.0, 1.1))
prob1 = ODEProblem((u, p, t) -> 1.0u, 0.45, (0.0, 0.9))
output_func(sol, i) = (last(sol), false)
# test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately
ensemble_prob = EnsembleProblem([prob1, prob2], output_func = output_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads())
ensemble_prob = EnsembleProblem([prob1, prob2])
sol = solve(ensemble_prob, Tsit5(), EnsembleThreads())
@test isapprox(sol[:, x], [2,1] .* map(Base.Fix1(map, exp), [1.1, 1.2] .* sol[:, t]), rtol=1e-4)
# Ensemble is a recursive array
@test sol(0.0, idxs=[x]) == sol[:, 1] == first.(sol[:, x], 1)
# TODO: fix the interpolation
@test sol(1.0, idxs=[x]) ≈ last.(sol[:, x], 1)