Skip to content

Commit

Permalink
fix indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Jul 17, 2023
1 parent 605824a commit 7ff39ae
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
8 changes: 4 additions & 4 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,16 @@ end
end


Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return VectorOfArray([xi[s] for xi in x])

Check warning on line 203 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L202-L203

Added lines #L202 - L203 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...)
return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...)
end
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...)
return VectorOfArray([xi[args...] for xi in x])
end
#Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, args::Int..., ::Colon)
# return VectorOfArray([xi[args...] for xi in x])
#end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
VectorOfArray([s(args...; kwargs...) for s in sol])

Check warning on line 214 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L214

Added line #L214 was not covered by tests
Expand Down
26 changes: 17 additions & 9 deletions test/downstream/ensemble_multi_prob.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
using ModelingToolkit, OrdinaryDiffEq, Test

@variables t, x(t)
@variables t, x(t), y(t)
D = Differential(t)

@named sys1 = ODESystem([D(x) ~ 1.1*x])
@named sys2 = ODESystem([D(x) ~ 1.2*x])
@named sys1 = ODESystem([D(x) ~ x,
D(y) ~ -y])
@named sys2 = ODESystem([D(x) ~ 2x,
D(y) ~ -2y])
@named sys3 = ODESystem([D(x) ~ 3x,
D(y) ~ -3y])

prob1 = ODEProblem(sys1, [2.0], (0.0, 1.0))
prob2 = ODEProblem(sys2, [1.0], (0.0, 1.0))
prob1 = ODEProblem(sys1, [1.0, 1.0], (0.0, 1.0))
prob2 = ODEProblem(sys2, [2.0, 2.0], (0.0, 1.0))
prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0))

# test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately
ensemble_prob = EnsembleProblem([prob1, prob2])
ensemble_prob = EnsembleProblem([prob1, prob2, prob3])
sol = solve(ensemble_prob, Tsit5(), EnsembleThreads())
@test isapprox(sol[:, x], [2,1] .* map(Base.Fix1(map, exp), [1.1, 1.2] .* sol[:, t]), rtol=1e-4)
for i in 1:3
@test sol[x, :][i] == sol[i][x]
@test sol[y, :][i] == sol[i][y]
end
# Ensemble is a recursive array
@test sol(0.0, idxs=[x]) == sol[:, 1] == first.(sol[:, x], 1)
@test Matrix(sol(0.0, idxs=[x])) == sol[1:1, 1, :] == Matrix(first(eachrow(sol[x, :]))')
# TODO: fix the interpolation
@test sol(1.0, idxs=[x]) last.(sol[:, x], 1)
@test vec(sol(1.0, idxs=[x])) last.(sol[x, :].u)

0 comments on commit 7ff39ae

Please sign in to comment.