diff --git a/src/callbacks.jl b/src/callbacks.jl index 3c42bf8991..dc88da2387 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -43,8 +43,43 @@ end previous_condition = callback.condition(@view(integrator.uprev[callback.idxs]),integrator.tprev,integrator) end - if integrator.event_last_time - prev_sign = 0.0 + if integrator.event_last_time && abs(previous_condition) < callback.abstol + + # abs(previous_condition) < callback.abstol is for multiple events: only + # condition this on the correct event + + # If there was a previous event, utilize the derivative at the start to + # chose the previous sign. If the derivative is positive at tprev, then + # we treat the value as positive, and derivative is negative then we + # treat the value as negative, reguardless of the postiivity/negativity + # of the true value due to it being =0 sans floating point issues. + + if callback.interp_points==0 + ode_addsteps!(integrator) + end + + if typeof(integrator.cache) <: OrdinaryDiffEqMutableCache + if typeof(callback.idxs) <: Void + tmp = integrator.cache.tmp + else !(typeof(callback.idxs) <: Number) + tmp = @view integrator.cache.tmp[callback.idxs] + end + end + + if typeof(integrator.cache) <: OrdinaryDiffEqMutableCache && !(typeof(callback.idxs) <: Number) + ode_interpolant!(tmp,100eps(typeof(integrator.tprev)), + integrator,callback.idxs,Val{0}) + else + + tmp = ode_interpolant(100eps(typeof(integrator.tprev)), + integrator,callback.idxs,Val{0}) + end + + tmp_condition = callback.condition(tmp,integrator.tprev + + 100eps(typeof(integrator.tprev)), + integrator) + + prev_sign = sign((tmp_condition-previous_condition)/integrator.dt) else prev_sign = sign(previous_condition) end @@ -76,14 +111,13 @@ end tmp = ode_interpolant(Θs[i],integrator,callback.idxs,Val{0}) end new_sign = callback.condition(tmp,integrator.tprev+integrator.dt*Θs[i],integrator) - if prev_sign == 0 - prev_sign = sign(new_sign) - prev_sign_index = i - end + if ((prev_sign<0 && !(typeof(callback.affect!)<:Void)) || (prev_sign>0 && !(typeof(callback.affect_neg!)<:Void))) && prev_sign*new_sign<0 event_occurred = true interp_index = i break + else + prev_sign_index = i end end end