Skip to content

Commit

Permalink
testing performance tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsas committed Nov 5, 2018
1 parent d3d22ff commit 4f1a864
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 25 deletions.
16 changes: 9 additions & 7 deletions src/aggregators/rssa.jl
Expand Up @@ -22,6 +22,7 @@ mutable struct RSSAJumpAggregation{T,T2,S,F1,F2,RNG,VJMAP,JVMAP,BD,T2V} <: Abstr
bracket_data::BD
ulow::T2V
uhigh::T2V
eventcnt::Int
end

function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T,
Expand Down Expand Up @@ -57,7 +58,7 @@ function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T,

RSSAJumpAggregation{T,eltype(U),S,F1,F2,RNG,typeof(vtoj_map),typeof(jtov_map),typeof(bd),typeof(ulow)}(
nj, njt, et, crl_bnds, crh_bnds, sr, cs_bnds, maj, rs,
affs!, sps, rng, vtoj_map, jtov_map, bd, ulow, uhigh)
affs!, sps, rng, vtoj_map, jtov_map, bd, ulow, uhigh, 0)
end


Expand All @@ -73,6 +74,7 @@ function (p::RSSAJumpAggregation)(integrator)
execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t)
generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t)
register_next_jump_time!(integrator, p, integrator.t)
p.eventcnt += 1
nothing
end

Expand Down Expand Up @@ -185,8 +187,8 @@ end
crhigh = p.cur_rate_high
majumps = p.ma_jumps
num_majumps = get_num_majumps(majumps)
rerl = one(sum_rate)
#rerl = zero(sum_rate)
#rerl = one(sum_rate)
rerl = zero(sum_rate)
notdone = true
jidx = 0
@inbounds while notdone
Expand Down Expand Up @@ -218,14 +220,14 @@ end
end
end

rerl *= rand(p.rng)
#rerl += randexp(p.rng)
#rerl *= rand(p.rng)
rerl += randexp(p.rng)
end
p.next_jump = jidx

# update time to next jump
p.next_jump_time = t + (-one(sum_rate) / sum_rate) * log(rerl)
#p.next_jump_time = t + rerl / sum_rate
#p.next_jump_time = t + (-one(sum_rate) / sum_rate) * log(rerl)
p.next_jump_time = t + rerl / sum_rate

nothing
end
14 changes: 8 additions & 6 deletions test/bimolerx_test.jl
@@ -1,5 +1,5 @@
using DiffEqBase, DiffEqJump
using Test
using Test, Statistics

# using Plots; plotlyjs()
doplot = false
Expand All @@ -11,7 +11,7 @@ dotestmean = true
doprintmeans = false

# SSAs to test
SSAalgs = (Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), NRM())
SSAalgs = (Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), NRM(), RSSA())

Nsims = 32000
tf = .01
Expand Down Expand Up @@ -49,6 +49,8 @@ netstoch =
[1 => 3, 3 => -3]
]
rates = [1., 2., .5, .75, .25]
spec_to_dep_jumps = [[1,3],[2,3],[4,5]]
jump_to_dep_specs = [[1,2],[1,2],[1,2,3],[1,2,3],[1,3]]
majumps = MassActionJump(rates, reactstoch, netstoch)

# average number of proteins in a simulation
Expand All @@ -67,7 +69,7 @@ prob = DiscreteProblem(u0, (0.0, tf), rates)
# plotting one full trajectory
if doplot
for alg in SSAalgs
jump_prob = JumpProblem(prob, alg, majumps)
jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs)
sol = solve(jump_prob, SSAStepper())
plothand = plot(sol, seriestype=:steppost, reuse=false)
display(plothand)
Expand All @@ -78,7 +80,7 @@ end
if dotestmean
means = zeros(Float64,length(SSAalgs))
for (i,alg) in enumerate(SSAalgs)
jump_prob = JumpProblem(prob, alg, majumps, save_positions=(false,false))
jump_prob = JumpProblem(prob, alg, majumps, save_positions=(false,false), vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs)
means[i] = runSSAs(jump_prob)
relerr = abs(means[i] - expected_avg) / expected_avg
if doprintmeans
Expand All @@ -100,7 +102,7 @@ end
# # exact methods
# for alg in SSAalgs
# println("Solving with method: ", typeof(alg), ", using SSAStepper")
# jump_prob = JumpProblem(prob, alg, majumps)
# jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs)
# @btime solve($jump_prob, SSAStepper())
# end
# println()
Expand All @@ -115,7 +117,7 @@ if dotestmean
push!(majump_vec, MassActionJump(rates[i], reactstoch[i], netstoch[i]))
end
jset = JumpSet((),(),nothing,majump_vec)
jump_prob = JumpProblem(prob, Direct(), jset, save_positions=(false,false))
jump_prob = JumpProblem(prob, Direct(), jset, save_positions=(false,false), vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs)
meanval = runSSAs(jump_prob)
relerr = abs(meanval - expected_avg) / expected_avg
if doprintmeans
Expand Down
23 changes: 11 additions & 12 deletions test/geneexpr_test.jl
Expand Up @@ -3,8 +3,8 @@ using Test, Statistics

# using Plots; plotlyjs()
doplot = false
using BenchmarkTools
dobenchmark = false
# using BenchmarkTools
# dobenchmark = false

dotestmean = true
doprintmeans = false
Expand Down Expand Up @@ -97,19 +97,18 @@ if dotestmean
# @btime (runSSAs($jump_prob);)
# end


@test abs(means[i] - expected_avg) < reltol*expected_avg
end
end


# # benchmark performance
if dobenchmark
# exact methods
for alg in SSAalgs
println("Solving with method: ", typeof(alg), ", using SSAStepper")
jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs)
@btime solve($jump_prob, SSAStepper())
end
println()
end
# if dobenchmark
# # exact methods
# for alg in SSAalgs
# println("Solving with method: ", typeof(alg), ", using SSAStepper")
# jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs)
# @btime solve($jump_prob, SSAStepper())
# end
# println()
# end

0 comments on commit 4f1a864

Please sign in to comment.