diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index 83a87bb28..bdba437e8 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -42,6 +42,12 @@ end tighten_container_eltype(u::Vector{Any}) = map(identity, u) tighten_container_eltype(u) = u +function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}}, + alg::Union{AbstractDEAlgorithm, Nothing}, + ensemblealg::BasicEnsembleAlgorithm; kwargs...) + solve(prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...) +end + function __solve(prob::AbstractEnsembleProblem, alg::Union{AbstractDEAlgorithm, Nothing}, ensemblealg::BasicEnsembleAlgorithm; diff --git a/src/ensemble/ensemble_problems.jl b/src/ensemble/ensemble_problems.jl index 0d09f97da..4a25e572a 100644 --- a/src/ensemble/ensemble_problems.jl +++ b/src/ensemble/ensemble_problems.jl @@ -13,6 +13,10 @@ end DEFAULT_PROB_FUNC(prob, i, repeat) = prob DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false) DEFAULT_REDUCTION(u, data, I) = append!(u, data), false +DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i] +function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...) + EnsembleProblem(prob; kwargs..., prob_func=DEFAULT_VECTOR_PROB_FUNC) +end function EnsembleProblem(prob; output_func = DEFAULT_OUTPUT_FUNC, prob_func = DEFAULT_PROB_FUNC, diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl new file mode 100644 index 000000000..cc6040581 --- /dev/null +++ b/test/downstream/ensemble_multi_prob.jl @@ -0,0 +1,8 @@ +using OrdinaryDiffEq, Test + +prob1 = ODEProblem((u, p, t) -> 0.99u, 0.55, (0.0, 1.1)) +prob1 = ODEProblem((u, p, t) -> 1.0u, 0.45, (0.0, 0.9)) +output_func(sol, i) = (last(sol), false) +# test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately +ensemble_prob = EnsembleProblem([prob1, prob2], output_func = output_func) +sim = solve(ensemble_prob, Tsit5(), EnsembleThreads()) diff --git a/test/runtests.jl b/test/runtests.jl index 5ca0c7f1a..e4c918a9b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,6 +63,9 @@ end @time @safetestset "Timing first batch when solving Ensembles" begin include("downstream/ensemble_first_batch.jl") end + @time @safetestset "solving Ensembles with multiple problems" begin + include("downstream/ensemble_multi_prob.jl") + end @time @safetestset "Symbol and integer based indexing of interpolated solutions" begin include("downstream/symbol_indexing.jl") end