Skip to content

Commit

Permalink
fix stochastic resizing bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Dec 23, 2017
1 parent 965f778 commit 9b012b4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
6 changes: 1 addition & 5 deletions src/callbacks.jl
Expand Up @@ -109,13 +109,9 @@ function find_callback_time(integrator,callback)
else
tmp = sde_interpolant(Θ,integrator,callback.idxs,Val{0})
end
@show tmp
callback.condition(integrator.tprev+Θ*integrator.dt,tmp,integrator)
end
@show (bottom_θ,top_Θ)
@show integrator.tprev,integrator.t
@show integrator.uprev,integrator.u
Θ = prevfloat(prevfloat(find_zero(zero_func,(bottom_θ,top_Θ),FalsePosition(),abstol = callback.abstol/10,verbose = true)))
Θ = prevfloat(prevfloat(find_zero(zero_func,(bottom_θ,top_Θ),FalsePosition(),abstol = callback.abstol/10)))
# 2 prevfloat guerentees that the new time is either 1 or 2 floating point
# numbers just before the event, but not after. If there's a barrier
# which is never supposed to be crossed, then this will ensure that
Expand Down
21 changes: 17 additions & 4 deletions src/integrators/integrator_interface.jl
Expand Up @@ -70,7 +70,7 @@ function resize_noise!(integrator,cache,bot_idx,i)
if alg_needs_extra_process(integrator.alg)
resize!(c[3],i)
end
if i > bot_idx # fill in rands
if i >= bot_idx # fill in rands
fill_new_noise_caches!(integrator,c,c[1],bot_idx:i)
end
end
Expand All @@ -79,23 +79,31 @@ function resize_noise!(integrator,cache,bot_idx,i)
if alg_needs_extra_process(integrator.alg)
resize!(c[3],i)
end
if i > bot_idx # fill in rands
if i >= bot_idx # fill in rands
fill_new_noise_caches!(integrator,c,c[1],bot_idx:i)
end
end
resize!(integrator.W.dW,i)
integrator.W.dW[end] = zero(eltype(integrator.u))
resize!(integrator.W.dWtilde,i)
integrator.W.dWtilde[end] = zero(eltype(integrator.u))
resize!(integrator.W.dWtmp,i)
integrator.W.dWtmp[end] = zero(eltype(integrator.u))
resize!(integrator.W.curW,i)
integrator.W.curW[end] = zero(eltype(integrator.u))
DiffEqNoiseProcess.resize_stack!(integrator.W,i)

if alg_needs_extra_process(integrator.alg)
resize!(integrator.W.dZ,i)
integrator.W.dZ[end] = zero(eltype(integrator.u))
resize!(integrator.W.dZtilde,i)
integrator.W.dZtilde[end] = zero(eltype(integrator.u))
resize!(integrator.W.dZtmp,i)
integrator.W.dZtmp[end] = zero(eltype(integrator.u))
resize!(integrator.W.curZ,i)
integrator.W.curZ[end] = zero(eltype(integrator.u))
end
if i > bot_idx # fill in rands
if i >= bot_idx # fill in rands
fill!(@view(integrator.W.curW[bot_idx:i]),zero(eltype(integrator.u)))
if alg_needs_extra_process(integrator.alg)
fill!(@view(integrator.W.curZ[bot_idx:i]),zero(eltype(integrator.u)))
Expand All @@ -115,10 +123,11 @@ end
c[3][idxs] .= integrator.noise(length(idxs),integrator,scaling_factor)
end
end

end

function resize_non_user_cache!(integrator::SDEIntegrator,cache,i)
bot_idx = length(integrator.u)
bot_idx = length(integrator.u) + 1
resize_noise!(integrator,cache,bot_idx,i)
for c in default_non_user_cache(integrator)
resize!(c,i)
Expand Down Expand Up @@ -203,10 +212,14 @@ function addat_noise!(integrator,cache,idxs)
end

addat!(integrator.W.dW,idxs)
integrator.W.dW[idxs] .= zero(eltype(integrator.u))
addat!(integrator.W.curW,idxs)
integrator.W.curW[idxs] .= zero(eltype(integrator.u))
if alg_needs_extra_process(integrator.alg)
addat!(integrator.W.dZ,idxs)
integrator.W.dZ[idxs] .= zero(eltype(integrator.u))
addat!(integrator.W.curZ,idxs)
integrator.W.curZ[idxs] .= zero(eltype(integrator.u))
end

i = length(integrator.u)
Expand Down
5 changes: 3 additions & 2 deletions test/cache_test.jl
Expand Up @@ -2,13 +2,13 @@ using StochasticDiffEq, DiffEqBase

function f(t,u,du)
for i in 1:length(u)
du[i] = (0.3/length(u))*u[i]
du[i] = (0.5/length(u))*u[i]
end
end

function g(t,u,du)
for i in 1:length(u)
du[i] = (0.3/length(u))*u[i]
du[i] = (0.5/length(u))*u[i]
end
end

Expand Down Expand Up @@ -43,6 +43,7 @@ p2 = plot(sol.t,map((x)->length(x),sol[:]),lw=3,
plot(p1,p2,layout=(2,1),size=(600,1000))
=#

srand(3)
sol = solve(prob,EM(),callback=callback,dt=1/4)
sol = solve(prob,RKMil(),callback=callback,dt=1/4)
sol = solve(prob,SRI(),callback=callback)
Expand Down

0 comments on commit 9b012b4

Please sign in to comment.