Skip to content

Commit

Permalink
Merge pull request #229 from JuliaDiffEq/callback_cache
Browse files Browse the repository at this point in the history
Adds CallbackCache
  • Loading branch information
ChrisRackauckas committed Jun 4, 2019
2 parents 025c16e + dddb333 commit 0b2de6d
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ function get_condition(integrator::DEIntegrator, callback, abst)
end
integrator.sol.destats.ncondition += 1
if callback isa VectorContinuousCallback
callback.condition(@view(integrator.callback_cache[1:callback.len]),tmp,abst,integrator)
return @view(integrator.callback_cache[1:callback.len])
callback.condition(@view(integrator.callback_cache.tmp_condition[1:callback.len]),tmp,abst,integrator)
return @view(integrator.callback_cache.tmp_condition[1:callback.len])
else
return callback.condition(tmp,abst,integrator)
end
Expand Down Expand Up @@ -251,7 +251,7 @@ end
Θs = range(typeof(integrator.t)(0), stop=typeof(integrator.t)(1), length=callback.interp_points)
interp_index = 0
# Check if the event occured
previous_condition = @views(integrator.previous_condition[1:callback.len])
previous_condition = @views(integrator.callback_cache.previous_condition[1:callback.len])

if typeof(callback.idxs) <: Nothing
callback.condition(previous_condition,integrator.uprev,integrator.tprev,integrator)
Expand All @@ -261,8 +261,8 @@ end

integrator.sol.destats.ncondition += 1
ivec = integrator.vector_event_last_time
prev_sign = @view(integrator.callback_prev_sign[1:callback.len])
next_sign = @view(integrator.callback_next_sign[1:callback.len])
prev_sign = @view(integrator.callback_cache.prev_sign[1:callback.len])
next_sign = @view(integrator.callback_cache.next_sign[1:callback.len])


if integrator.event_last_time == counter && minimum(ODE_DEFAULT_NORM(previous_condition[ivec],integrator.t)) < 100ODE_DEFAULT_NORM(integrator.last_event_error,integrator.t)
Expand Down Expand Up @@ -628,4 +628,19 @@ function max_vector_callback_length(cs::CallbackSet)
end
end
maxlen
end

mutable struct CallbackCache{conditionType,signType}
tmp_condition::conditionType
previous_condition::conditionType
next_sign::signType
prev_sign::signType
end

function CallbackCache(max_len,conditionType::Type,signType::Type)
tmp_condition = zeros(conditionType, max_len)
previous_condition = zeros(conditionType, max_len)
next_sign = zeros(signType, max_len)
prev_sign = zeros(signType, max_len)
CallbackCache{Array{conditionType},Array{signType}}(tmp_condition,previous_condition,next_sign,prev_sign)
end

0 comments on commit 0b2de6d

Please sign in to comment.