Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ end
tighten_container_eltype(u::Vector{Any}) = map(identity, u)
tighten_container_eltype(u) = u

function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}},
alg::Union{AbstractDEAlgorithm, Nothing},
ensemblealg::BasicEnsembleAlgorithm; kwargs...)
solve(prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...)
end

function __solve(prob::AbstractEnsembleProblem,
alg::Union{AbstractDEAlgorithm, Nothing},
ensemblealg::BasicEnsembleAlgorithm;
Expand Down
4 changes: 4 additions & 0 deletions src/ensemble/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ end
DEFAULT_PROB_FUNC(prob, i, repeat) = prob
DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false)
DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
EnsembleProblem(prob; kwargs..., prob_func=DEFAULT_VECTOR_PROB_FUNC)
end
function EnsembleProblem(prob;
output_func = DEFAULT_OUTPUT_FUNC,
prob_func = DEFAULT_PROB_FUNC,
Expand Down
8 changes: 8 additions & 0 deletions test/downstream/ensemble_multi_prob.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using OrdinaryDiffEq, Test

prob1 = ODEProblem((u, p, t) -> 0.99u, 0.55, (0.0, 1.1))
prob1 = ODEProblem((u, p, t) -> 1.0u, 0.45, (0.0, 0.9))
output_func(sol, i) = (last(sol), false)
# test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately
ensemble_prob = EnsembleProblem([prob1, prob2], output_func = output_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads())
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ end
@time @safetestset "Timing first batch when solving Ensembles" begin
include("downstream/ensemble_first_batch.jl")
end
@time @safetestset "solving Ensembles with multiple problems" begin
include("downstream/ensemble_multi_prob.jl")
end
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
include("downstream/symbol_indexing.jl")
end
Expand Down