From 4f1a864efd03bb52eb9597f1c713932edd446329 Mon Sep 17 00:00:00 2001 From: Samuel Isaacson Date: Mon, 5 Nov 2018 13:08:23 -0500 Subject: [PATCH] testing performance tweaks --- src/aggregators/rssa.jl | 16 +++++++++------- test/bimolerx_test.jl | 14 ++++++++------ test/geneexpr_test.jl | 23 +++++++++++------------ 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/aggregators/rssa.jl b/src/aggregators/rssa.jl index 4d63fa58..dd23319f 100644 --- a/src/aggregators/rssa.jl +++ b/src/aggregators/rssa.jl @@ -22,6 +22,7 @@ mutable struct RSSAJumpAggregation{T,T2,S,F1,F2,RNG,VJMAP,JVMAP,BD,T2V} <: Abstr bracket_data::BD ulow::T2V uhigh::T2V + eventcnt::Int end function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, @@ -57,7 +58,7 @@ function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, RSSAJumpAggregation{T,eltype(U),S,F1,F2,RNG,typeof(vtoj_map),typeof(jtov_map),typeof(bd),typeof(ulow)}( nj, njt, et, crl_bnds, crh_bnds, sr, cs_bnds, maj, rs, - affs!, sps, rng, vtoj_map, jtov_map, bd, ulow, uhigh) + affs!, sps, rng, vtoj_map, jtov_map, bd, ulow, uhigh, 0) end @@ -73,6 +74,7 @@ function (p::RSSAJumpAggregation)(integrator) execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t) generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t) register_next_jump_time!(integrator, p, integrator.t) + p.eventcnt += 1 nothing end @@ -185,8 +187,8 @@ end crhigh = p.cur_rate_high majumps = p.ma_jumps num_majumps = get_num_majumps(majumps) - rerl = one(sum_rate) - #rerl = zero(sum_rate) + #rerl = one(sum_rate) + rerl = zero(sum_rate) notdone = true jidx = 0 @inbounds while notdone @@ -218,14 +220,14 @@ end end end - rerl *= rand(p.rng) - #rerl += randexp(p.rng) + #rerl *= rand(p.rng) + rerl += randexp(p.rng) end p.next_jump = jidx # update time to next jump - p.next_jump_time = t + (-one(sum_rate) / sum_rate) * log(rerl) - #p.next_jump_time = t + rerl / sum_rate + #p.next_jump_time = t + (-one(sum_rate) / sum_rate) * log(rerl) + p.next_jump_time = t + rerl / sum_rate nothing end diff --git a/test/bimolerx_test.jl b/test/bimolerx_test.jl index 9d8a805e..97035df9 100644 --- a/test/bimolerx_test.jl +++ b/test/bimolerx_test.jl @@ -1,5 +1,5 @@ using DiffEqBase, DiffEqJump -using Test +using Test, Statistics # using Plots; plotlyjs() doplot = false @@ -11,7 +11,7 @@ dotestmean = true doprintmeans = false # SSAs to test -SSAalgs = (Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), NRM()) +SSAalgs = (Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), NRM(), RSSA()) Nsims = 32000 tf = .01 @@ -49,6 +49,8 @@ netstoch = [1 => 3, 3 => -3] ] rates = [1., 2., .5, .75, .25] +spec_to_dep_jumps = [[1,3],[2,3],[4,5]] +jump_to_dep_specs = [[1,2],[1,2],[1,2,3],[1,2,3],[1,3]] majumps = MassActionJump(rates, reactstoch, netstoch) # average number of proteins in a simulation @@ -67,7 +69,7 @@ prob = DiscreteProblem(u0, (0.0, tf), rates) # plotting one full trajectory if doplot for alg in SSAalgs - jump_prob = JumpProblem(prob, alg, majumps) + jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs) sol = solve(jump_prob, SSAStepper()) plothand = plot(sol, seriestype=:steppost, reuse=false) display(plothand) @@ -78,7 +80,7 @@ end if dotestmean means = zeros(Float64,length(SSAalgs)) for (i,alg) in enumerate(SSAalgs) - jump_prob = JumpProblem(prob, alg, majumps, save_positions=(false,false)) + jump_prob = JumpProblem(prob, alg, majumps, save_positions=(false,false), vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs) means[i] = runSSAs(jump_prob) relerr = abs(means[i] - expected_avg) / expected_avg if doprintmeans @@ -100,7 +102,7 @@ end # # exact methods # for alg in SSAalgs # println("Solving with method: ", typeof(alg), ", using SSAStepper") -# jump_prob = JumpProblem(prob, alg, majumps) +# jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs) # @btime solve($jump_prob, SSAStepper()) # end # println() @@ -115,7 +117,7 @@ if dotestmean push!(majump_vec, MassActionJump(rates[i], reactstoch[i], netstoch[i])) end jset = JumpSet((),(),nothing,majump_vec) - jump_prob = JumpProblem(prob, Direct(), jset, save_positions=(false,false)) + jump_prob = JumpProblem(prob, Direct(), jset, save_positions=(false,false), vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs) meanval = runSSAs(jump_prob) relerr = abs(meanval - expected_avg) / expected_avg if doprintmeans diff --git a/test/geneexpr_test.jl b/test/geneexpr_test.jl index 691d2555..e0116529 100644 --- a/test/geneexpr_test.jl +++ b/test/geneexpr_test.jl @@ -3,8 +3,8 @@ using Test, Statistics # using Plots; plotlyjs() doplot = false -using BenchmarkTools -dobenchmark = false +# using BenchmarkTools +# dobenchmark = false dotestmean = true doprintmeans = false @@ -97,19 +97,18 @@ if dotestmean # @btime (runSSAs($jump_prob);) # end - @test abs(means[i] - expected_avg) < reltol*expected_avg end end # # benchmark performance -if dobenchmark - # exact methods - for alg in SSAalgs - println("Solving with method: ", typeof(alg), ", using SSAStepper") - jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs) - @btime solve($jump_prob, SSAStepper()) - end - println() -end +# if dobenchmark +# # exact methods +# for alg in SSAalgs +# println("Solving with method: ", typeof(alg), ", using SSAStepper") +# jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs) +# @btime solve($jump_prob, SSAStepper()) +# end +# println() +# end