Skip to content

Commit

Permalink
Merge pull request #328 from isaacsas/allocations_test_fix
Browse files Browse the repository at this point in the history
Allocations test fix
  • Loading branch information
isaacsas committed Jul 11, 2023
2 parents 029aa52 + f4a02fc commit d3a1e33
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions test/allocations.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using Test, JumpProcesses
using StableRNGs

# tests for https://github.com/SciML/JumpProcesses.jl/issues/305

let
rng = StableRNG(123)
save_positions = (false, false)

β = 0.1 / 1000.0
Expand Down Expand Up @@ -33,14 +35,14 @@ let
u₀ = [999, 10, 0]
tspan = (0.0, 250.0)
dprob = DiscreteProblem(u₀, tspan, p)
jprob = JumpProblem(dprob, Direct(), maj, jump, jump2; save_positions)
jprob = JumpProblem(dprob, Direct(), maj, jump, jump2; save_positions, rng)
sol = solve(jprob, SSAStepper())

al1 = @allocations solve(jprob, SSAStepper())

tspan2 = (0.0, 2500.0)
dprob2 = DiscreteProblem(u₀, tspan2, p)
jprob2 = JumpProblem(dprob2, Direct(), maj, jump, jump2; save_positions)
jprob2 = JumpProblem(dprob2, Direct(), maj, jump, jump2; save_positions, rng)
sol2 = solve(jprob2, SSAStepper())

al2 = @allocations solve(jprob2, SSAStepper())
Expand All @@ -54,7 +56,7 @@ let
end

function makeprob(; T = 100.0, alg = Direct(), save_positions = (false, false),
graphkwargs = (;))
graphkwargs = (;), rng)
r1(u, p, t) = rate(p[1], u[1], u[2], p[2]) * u[1]
r2(u, p, t) = rate(p[1], u[2], u[1], p[2]) * u[2]
r3(u, p, t) = p[3] * u[1]
Expand Down Expand Up @@ -84,7 +86,7 @@ let
ConstantRateJump(r3, aff3!),
ConstantRateJump(r4, aff4!), ConstantRateJump(r5, aff5!),
ConstantRateJump(r6, aff6!);
save_positions, graphkwargs...)
save_positions, rng, graphkwargs...)
return jprob
end

Expand All @@ -97,12 +99,14 @@ let
graphkwargs = (; dep_graph, vartojumps_map, jumptovars_map)

@testset "Allocations for $agg" for agg in JumpProcesses.JUMP_AGGREGATORS
jprob1 = makeprob(; alg = agg, T = 10.0, graphkwargs)
jprob2 = makeprob(; alg = agg, T = 100.0, graphkwargs)
jprob1 = makeprob(; alg = agg, T = 10.0, graphkwargs, rng = StableRNG(1234))
stepper = SSAStepper()
sol1 = solve(jprob1, stepper)
sol1 = solve(jprob1, stepper)
al1 = @allocated solve(jprob1, stepper)
sol2 = solve(jprob2, SSAStepper())
jprob2 = makeprob(; alg = agg, T = 100.0, graphkwargs, rng = StableRNG(1234))
sol2 = solve(jprob2, stepper)
sol2 = solve(jprob2, stepper)
al2 = @allocated solve(jprob2, stepper)
@test al1 == al2
end
Expand Down

0 comments on commit d3a1e33

Please sign in to comment.