diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index 86b81273..6898da8c 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -188,6 +188,38 @@ needs_vartojumps_map(aggregator::RSSACR) = true supports_variablerates(aggregator::AbstractAggregatorAlgorithm) = false supports_variablerates(aggregator::Coevolve) = true +# true if aggregator supports hops, e.g. diffusion is_spatial(aggregator::AbstractAggregatorAlgorithm) = false is_spatial(aggregator::NSM) = true is_spatial(aggregator::DirectCRDirect) = true + +# return the fastest aggregator out of the available ones +function select_aggregator(jumps::JumpSet; vartojumps_map=nothing, jumptovars_map=nothing, dep_graph=nothing, spatial_system=nothing, hopping_constants=nothing) + + # detect if a spatial SSA should be used + !isnothing(spatial_system) && !isnothing(hopping_constants) && return DirectCRDirect + + # if variable rate jumps are present, return one of the two SSAs that support them + if num_vrjs(jumps) != 0 + any(isbounded, vrjs) && return Coevolve + return Direct + end + + # if the number of jumps is small, return the Direct + num_jumps(jumps) < 10 && return Direct + + # if there are only massaction jumps, we can build the species-jumps dependency graphs + can_build_dependency_graphs = num_crjs(jumps) == 0 && num_vrjs(jumps) == 0 + have_species_to_jumps_dependency_graphs = !isnothing(vartojumps_map) && !isnothing(jumptovars_map) + + # if we have the species-jumps dependency graphs or can build them, use one of the Rejection-based methods + if can_build_dependency_graphs || have_species_to_jumps_dependency_graphs + num_jumps(jumps) < 100 && return RSSA + return RSSACR + # if we have the jumps-jumps dependency graph, use the Composition-Rejection Direct method + elseif !isnothing(dep_graph) + return DirectCR + else + return Direct + end +end \ No newline at end of file diff --git a/src/problem.jl b/src/problem.jl index 67f8045c..698e0c04 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -171,9 +171,12 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::Abstr kwargs...) JumpProblem(prob, aggregator, JumpSet(jumps...); kwargs...) end -function JumpProblem(prob, jumps::JumpSet; kwargs...) - JumpProblem(prob, NullAggregator(), jumps; kwargs...) +function JumpProblem(prob, jumps::JumpSet; vartojumps_map=nothing, jumptovars_map=nothing, dep_graph=nothing, spatial_system=nothing, hopping_constants=nothing, kwargs...) + aggregator = select_aggregator(jumps::JumpSet; vartojumps_map=vartojumps_map, jumptovars_map=jumptovars_map, dep_graph=dep_graph, spatial_system=spatial_system, hopping_constants=hopping_constants) + return JumpProblem(prob, aggregator(), jumps; vartojumps_map=vartojumps_map, jumptovars_map=jumptovars_map, dep_graph=dep_graph, spatial_system=spatial_system, hopping_constants=hopping_constants, kwargs...) end +# this makes it easier to test the aggregator selection +JumpProblem(prob, aggregator::NullAggregator, jumps::JumpSet; kwargs...) = JumpProblem(prob, jumps; kwargs...) make_kwarg(; kwargs...) = kwargs diff --git a/test/geneexpr_test.jl b/test/geneexpr_test.jl index 24e1e414..56240348 100644 --- a/test/geneexpr_test.jl +++ b/test/geneexpr_test.jl @@ -12,7 +12,7 @@ dotestmean = true doprintmeans = false # SSAs to test -SSAalgs = (RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), +SSAalgs = (JumpProcesses.NullAggregator(), RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), NRM(), RSSA(), DirectCR(), Coevolve()) # numerical parameters @@ -116,3 +116,12 @@ end # end # println() # end + +# no-aggregator tests +jump_prob = JumpProblem(prob, majumps, save_positions = (false, false), + vartojumps_map = spec_to_dep_jumps, + jumptovars_map = jump_to_dep_specs, rng = rng) +@test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg + +jump_prob = JumpProblem(prob, majumps, save_positions = (false, false), rng = rng) +@test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg diff --git a/test/linearreaction_test.jl b/test/linearreaction_test.jl index d169b571..e0e2112c 100644 --- a/test/linearreaction_test.jl +++ b/test/linearreaction_test.jl @@ -16,7 +16,7 @@ tf = 0.1 baserate = 0.1 A0 = 100 exactmean = (t, ratevec) -> A0 * exp(-sum(ratevec) * t) -SSAalgs = [RSSACR(), Direct(), RSSA()] +SSAalgs = [RSSACR(), Direct(), RSSA(), JumpProcesses.NullAggregator()] spec_to_dep_jumps = [collect(1:Nrxs), []] jump_to_dep_specs = [[1, 2] for i in 1:Nrxs] diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index c44358e8..2b706d84 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -56,6 +56,9 @@ jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, push!(jump_problems, JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false), rng = rng)) +push!(jump_problems, + JumpProblem(prob, majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false), rng = rng)) # setup flattenned jump prob push!(jump_problems, JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants,