From 5646fae2d4e98df3bdbba18f985f1608d8b3d95c Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 11 Jul 2023 08:53:04 -0400 Subject: [PATCH 1/6] better support for automatically handling vector ensembles --- src/ensemble/basic_ensemble_solve.jl | 6 ++++++ src/ensemble/ensemble_problems.jl | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index 83a87bb28..00bf89860 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -42,6 +42,12 @@ end tighten_container_eltype(u::Vector{Any}) = map(identity, u) tighten_container_eltype(u) = u +function __solve(prob::EnsembleProblem{<:AbstractVector}, + alg::Union{AbstractDEAlgorithm, Nothing}, + ensemblealg::BasicEnsembleAlgorithm; kwargs...) + solve(prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...) +end + function __solve(prob::AbstractEnsembleProblem, alg::Union{AbstractDEAlgorithm, Nothing}, ensemblealg::BasicEnsembleAlgorithm; diff --git a/src/ensemble/ensemble_problems.jl b/src/ensemble/ensemble_problems.jl index 0d09f97da..7d00f90df 100644 --- a/src/ensemble/ensemble_problems.jl +++ b/src/ensemble/ensemble_problems.jl @@ -13,6 +13,10 @@ end DEFAULT_PROB_FUNC(prob, i, repeat) = prob DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false) DEFAULT_REDUCTION(u, data, I) = append!(u, data), false +DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = remakie(prob[i]) +function EnsembleProblem(prob::AbstractVector; kwargs...) + EnsembleProblem(prob; kwargs..., prob_func=DEFAULT_VECTOR_PROB_FUNC) +end function EnsembleProblem(prob; output_func = DEFAULT_OUTPUT_FUNC, prob_func = DEFAULT_PROB_FUNC, From aff7c1a7c403c77207a3686a7b085589456a4d64 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 11 Jul 2023 09:50:14 -0400 Subject: [PATCH 2/6] fix typo and add test --- src/ensemble/basic_ensemble_solve.jl | 2 +- src/ensemble/ensemble_problems.jl | 4 ++-- test/downstream/ensemble_first_batch.jl | 6 +----- test/runtests.jl | 3 +++ 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index 00bf89860..bdba437e8 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -42,7 +42,7 @@ end tighten_container_eltype(u::Vector{Any}) = map(identity, u) tighten_container_eltype(u) = u -function __solve(prob::EnsembleProblem{<:AbstractVector}, +function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}}, alg::Union{AbstractDEAlgorithm, Nothing}, ensemblealg::BasicEnsembleAlgorithm; kwargs...) solve(prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...) diff --git a/src/ensemble/ensemble_problems.jl b/src/ensemble/ensemble_problems.jl index 7d00f90df..4a25e572a 100644 --- a/src/ensemble/ensemble_problems.jl +++ b/src/ensemble/ensemble_problems.jl @@ -13,8 +13,8 @@ end DEFAULT_PROB_FUNC(prob, i, repeat) = prob DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false) DEFAULT_REDUCTION(u, data, I) = append!(u, data), false -DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = remakie(prob[i]) -function EnsembleProblem(prob::AbstractVector; kwargs...) +DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i] +function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...) EnsembleProblem(prob; kwargs..., prob_func=DEFAULT_VECTOR_PROB_FUNC) end function EnsembleProblem(prob; diff --git a/test/downstream/ensemble_first_batch.jl b/test/downstream/ensemble_first_batch.jl index 09043453c..ec2c6eabd 100644 --- a/test/downstream/ensemble_first_batch.jl +++ b/test/downstream/ensemble_first_batch.jl @@ -8,8 +8,4 @@ reduction(u, batch, I) = (append!(u, mean(batch)), false) # make sure first batch is timed (test using 1 batch but reduction) ensemble_prob = EnsembleProblem(prob, prob_func = prob_func, output_func = output_func, reduction = reduction) -sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 1000, - batch_size = 1000) - -@test sim.elapsedTime > 1000 * @elapsed for i in 2:1 -end +sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), batch_size = 2) diff --git a/test/runtests.jl b/test/runtests.jl index 5ca0c7f1a..e4c918a9b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,6 +63,9 @@ end @time @safetestset "Timing first batch when solving Ensembles" begin include("downstream/ensemble_first_batch.jl") end + @time @safetestset "solving Ensembles with multiple problems" begin + include("downstream/ensemble_multi_prob.jl") + end @time @safetestset "Symbol and integer based indexing of interpolated solutions" begin include("downstream/symbol_indexing.jl") end From 201fc0a1fa0a8acc4d8f0fa24dd3634e398c7998 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 11 Jul 2023 10:05:10 -0400 Subject: [PATCH 3/6] oops --- test/downstream/ensemble_first_batch.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/downstream/ensemble_first_batch.jl b/test/downstream/ensemble_first_batch.jl index ec2c6eabd..09043453c 100644 --- a/test/downstream/ensemble_first_batch.jl +++ b/test/downstream/ensemble_first_batch.jl @@ -8,4 +8,8 @@ reduction(u, batch, I) = (append!(u, mean(batch)), false) # make sure first batch is timed (test using 1 batch but reduction) ensemble_prob = EnsembleProblem(prob, prob_func = prob_func, output_func = output_func, reduction = reduction) -sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), batch_size = 2) +sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 1000, + batch_size = 1000) + +@test sim.elapsedTime > 1000 * @elapsed for i in 2:1 +end From d5f59186e4100b867d54729eedac461d56d3632f Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 11 Jul 2023 10:50:44 -0400 Subject: [PATCH 4/6] actually add the tests --- test/downstream/ensemble_multi_prob.jl | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 test/downstream/ensemble_multi_prob.jl diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl new file mode 100644 index 000000000..4e684342f --- /dev/null +++ b/test/downstream/ensemble_multi_prob.jl @@ -0,0 +1,8 @@ +using OrdinaryDiffEq, Test + +prob1 = ODEProblem((u, p, t) -> 0.99u, 0.55, (0.0, 1.1)) +prob1 = ODEProblem((u, p, t) -> 1.0u, 0.45, (0.0, 0.9)) +output_func(sol, i) = (last(sol), false) +# test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately +ensemble_prob = EnsembleProblem([prob1, prob2], output_func = output_func, reduction = reduction) +sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), batch_size = 2) From b9cd12c8f10038756914fae8f888fa61bd004171 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 11 Jul 2023 11:17:02 -0400 Subject: [PATCH 5/6] Update test/downstream/ensemble_multi_prob.jl --- test/downstream/ensemble_multi_prob.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index 4e684342f..89ce717d6 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -5,4 +5,4 @@ prob1 = ODEProblem((u, p, t) -> 1.0u, 0.45, (0.0, 0.9)) output_func(sol, i) = (last(sol), false) # test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately ensemble_prob = EnsembleProblem([prob1, prob2], output_func = output_func, reduction = reduction) -sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), batch_size = 2) +sim = solve(ensemble_prob, Tsit5(), EnsembleThreads()) From c9cac8f00b4fb34a8519ce8c8a8a3409e0ac0436 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 11 Jul 2023 11:18:10 -0400 Subject: [PATCH 6/6] Update test/downstream/ensemble_multi_prob.jl --- test/downstream/ensemble_multi_prob.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index 89ce717d6..cc6040581 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -4,5 +4,5 @@ prob1 = ODEProblem((u, p, t) -> 0.99u, 0.55, (0.0, 1.1)) prob1 = ODEProblem((u, p, t) -> 1.0u, 0.45, (0.0, 0.9)) output_func(sol, i) = (last(sol), false) # test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately -ensemble_prob = EnsembleProblem([prob1, prob2], output_func = output_func, reduction = reduction) +ensemble_prob = EnsembleProblem([prob1, prob2], output_func = output_func) sim = solve(ensemble_prob, Tsit5(), EnsembleThreads())