Skip to content

Commit

Permalink
Merge ac1a771 into 888a36b
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsas committed Apr 3, 2023
2 parents 888a36b + ac1a771 commit ce57918
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 33 deletions.
75 changes: 44 additions & 31 deletions src/aggregators/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,43 @@ end
function DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S,
rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG;
kwargs...) where {T, S, F1, F2, RNG}
DirectJumpAggregation{T, S, F1, F2, RNG}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps,
rng)
affecttype = F2 <: Tuple ? F2 : Any
DirectJumpAggregation{T, S, F1, affecttype, RNG}(nj, nj, njt, et, crs, sr, maj, rs,
affs!, sps, rng)
end

######################### dispatches for type stablity #######################

@inline function concretize_affects!(p::DirectJumpAggregation, ::I) where {I <: DiffEqBase.DEIntegrator}
if p.affects! isa Vector{Any}
AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}
p.affects! = AffectWrapper[AffectWrapper(aff) for aff in p.affects!]
end
nothing
end

@inline function concretize_affects!(p::DirectJumpAggregation{T, S, F1, F2}, ::I) where {T, S, F1, F2 <: Tuple, I <: DiffEqBase.DEIntegrator}
nothing
end

# executing jump at the next jump time
function (p::DirectJumpAggregation)(integrator::I) where {I <: DiffEqBase.DEIntegrator}
affects! = p.affects!
if affects! isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}
execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t, affects!)
else
error("Error, invalid affects! type in $(typeof(p))")
end
generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t)
register_next_jump_time!(integrator, p, integrator.t)
nothing
end

function (p::DirectJumpAggregation{T,S,F1,F2})(integrator::I) where {T, S, F1, F2 <: Tuple, I <: DiffEqBase.DEIntegrator}
execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t, p.affects!)
generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t)
register_next_jump_time!(integrator, p, integrator.t)
nothing
end

############################# Required Functions #############################
Expand All @@ -36,7 +71,7 @@ function aggregate(aggregator::DirectFW, u, p, t, end_time, constant_jumps,
ma_jumps, save_positions, rng; kwargs...)

# handle constant jumps using function wrappers
rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps)
rates, affects! = get_jump_info_fwrappers_direct(u, p, t, constant_jumps)

build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps,
rates, affects!, save_positions, rng; kwargs...)
Expand All @@ -50,8 +85,8 @@ function initialize!(p::DirectJumpAggregation, integrator, u, params, t)
end

# execute one jump, changing the system state
@inline function execute_jumps!(p::DirectJumpAggregation, integrator, u, params, t)
update_state!(p, integrator, u)
@inline function execute_jumps!(p::DirectJumpAggregation, integrator, u, params, t, affects!)
update_state!(p, integrator, u, affects!)
nothing
end

Expand All @@ -66,8 +101,8 @@ end
######################## SSA specific helper routines ########################

# tuple-based constant jumps
function time_to_next_jump(p::DirectJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: Tuple, F2, RNG}
function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params,
t) where {T, S, F1 <: Tuple}
prev_rate = zero(t)
new_rate = zero(t)
cur_rates = p.cur_rates
Expand Down Expand Up @@ -108,8 +143,8 @@ end
end

# function wrapper-based constant jumps
function time_to_next_jump(p::DirectJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: AbstractArray, F2, RNG}
function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params,
t) where {T, S, F1 <: AbstractArray}
prev_rate = zero(t)
new_rate = zero(t)
cur_rates = p.cur_rates
Expand All @@ -136,25 +171,3 @@ function time_to_next_jump(p::DirectJumpAggregation{T, S, F1, F2, RNG}, u, param
@inbounds sum_rate = cur_rates[end]
sum_rate, randexp(p.rng) / sum_rate
end

@generated function update_state!(p::DirectJumpAggregation{T, S, F1, F2}, integrator,
u) where {T, S, F1 <: Tuple, F2 <: Tuple}
quote
@unpack ma_jumps, next_jump = p
num_ma_rates = get_num_majumps(ma_jumps)
if next_jump <= num_ma_rates # is next jump a mass action jump
if u isa SVector
integrator.u = executerx(u, next_jump, ma_jumps)
else
@inbounds executerx!(u, next_jump, ma_jumps)
end
else
idx = next_jump - num_ma_rates
Base.Cartesian.@nif $(fieldcount(F2)) i->(i == idx) i->(@inbounds p.affects![i](integrator)) i->(@inbounds p.affects![fieldcount(F2)](integrator))
end

# save jump that was just executed
p.prev_jump = next_jump
return integrator.u
end
end
33 changes: 31 additions & 2 deletions src/aggregators/ssajump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ end

# setting up a new simulation
function (p::AbstractSSAJumpAggregator)(dj, u, t, integrator) # initialize
concretize_affects!(p, integrator)
initialize!(p, integrator, u, integrator.p, t)
register_next_jump_time!(integrator, p, integrator.t)
u_modified!(integrator, false)
Expand Down Expand Up @@ -170,7 +171,7 @@ end
Execute `p.next_jump`.
"""
@inline function update_state!(p::AbstractSSAJumpAggregator, integrator, u)
@inline function update_state!(p::AbstractSSAJumpAggregator, integrator, u, affects!)
@unpack ma_jumps, next_jump = p
num_ma_rates = get_num_majumps(ma_jumps)
if next_jump <= num_ma_rates # is next jump a mass action jump
Expand All @@ -181,14 +182,40 @@ Execute `p.next_jump`.
end
else
idx = next_jump - num_ma_rates
@inbounds p.affects![idx](integrator)
@inbounds affects![idx](integrator)
end

# save jump that was just executed
p.prev_jump = next_jump
return integrator.u
end

@generated function update_state!(p::AbstractSSAJumpAggregator, integrator, u,
affects!::T) where {T <: Tuple}
quote
@unpack ma_jumps, next_jump = p
num_ma_rates = get_num_majumps(ma_jumps)
if next_jump <= num_ma_rates # is next jump a mass action jump
if u isa SVector
integrator.u = executerx(u, next_jump, ma_jumps)
else
@inbounds executerx!(u, next_jump, ma_jumps)
end
else
idx = next_jump - num_ma_rates
Base.Cartesian.@nif $(fieldcount(T)) i->(i == idx) i->(@inbounds affects![i](integrator)) i->(@inbounds affects![fieldcount(T)](integrator))
end

# save jump that was just executed
p.prev_jump = next_jump
return integrator.u
end
end

@inline update_state!(p::AbstractSSAJumpAggregator, integrator, u) =
update_state!(p, integrator, u, p.affects!)


"""
nomorejumps!(p, sum_rate) :: Bool
Expand Down Expand Up @@ -252,3 +279,5 @@ Perform rejection sampling test (used in RSSA methods).
end
return true
end

concretize_affects!(p::AbstractSSAJumpAggregator, integrator) = nothing
15 changes: 15 additions & 0 deletions src/jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -703,3 +703,18 @@ function get_jump_info_fwrappers(u, p, t, constant_jumps)

rates, affects!
end

function get_jump_info_fwrappers_direct(u, p, t, constant_jumps)
RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),
Tuple{typeof(u), typeof(p), typeof(t)}}

if (constant_jumps !== nothing) && !isempty(constant_jumps)
rates = [RateWrapper(c.rate) for c in constant_jumps]
affects! = Any[(x -> (c.affect!(x); nothing)) for c in constant_jumps]
else
rates = Vector{RateWrapper}()
affects! = Any[]
end

rates, affects!
end

0 comments on commit ce57918

Please sign in to comment.