Skip to content

Commit

Permalink
fix non-diagonal adaptivity
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Apr 3, 2018
1 parent 384ceea commit 13e8f07
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 28 deletions.
18 changes: 14 additions & 4 deletions src/caches/implicit_split_step_caches.jl
@@ -1,4 +1,6 @@
mutable struct ISSEMCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F} <: StochasticDiffEqMutableCache
mutable struct ISSEMCache{uType,rateType,J,JC,UF,
uEltypeNoUnits,noiseRateType,F,dWType} <:
StochasticDiffEqMutableCache
u::uType
uprev::uType
du1::rateType
Expand All @@ -18,6 +20,7 @@ mutable struct ISSEMCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F}
κ::uEltypeNoUnits
tol::uEltypeNoUnits
newton_iters::Int
dW_cache::dWType
end

u_cache(c::ISSEMCache) = (c.uprev2,c.z,c.dz)
Expand Down Expand Up @@ -52,12 +55,14 @@ function alg_cache(alg::ISSEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototy

if is_diagonal_noise(prob)
gtmp2 = gtmp
dW_cache = nothing
else
gtmp2 = similar(rate_prototype)
dW_cache = similar(ΔW)
end

ISSEMCache(u,uprev,du1,fsalfirst,k,z,dz,tmp,gtmp,gtmp2,J,W,jac_config,linsolve,uf,
ηold,κ,tol,10000)
ηold,κ,tol,10000,dW_cache)
end

mutable struct ISSEMConstantCache{F,uEltypeNoUnits} <: StochasticDiffEqConstantCache
Expand Down Expand Up @@ -88,7 +93,9 @@ function alg_cache(alg::ISSEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototy
ISSEMConstantCache(uf,ηold,κ,tol,100000)
end

mutable struct ISSEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F} <: StochasticDiffEqMutableCache
mutable struct ISSEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,
noiseRateType,F,dWType} <:
StochasticDiffEqMutableCache
u::uType
uprev::uType
du1::rateType
Expand All @@ -109,6 +116,7 @@ mutable struct ISSEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRate
κ::uEltypeNoUnits
tol::uEltypeNoUnits
newton_iters::Int
dW_cache::dWType
end

u_cache(c::ISSEulerHeunCache) = (c.uprev2,c.z,c.dz)
Expand Down Expand Up @@ -145,12 +153,14 @@ function alg_cache(alg::ISSEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_

if is_diagonal_noise(prob)
gtmp3 = gtmp2
dW_cache = nothing
else
gtmp3 = similar(noise_rate_prototype)
dW_cache = similar(ΔW)
end

ISSEulerHeunCache(u,uprev,du1,fsalfirst,k,z,dz,tmp,gtmp,gtmp2,gtmp3,
J,W,jac_config,linsolve,uf,ηold,κ,tol,10000)
J,W,jac_config,linsolve,uf,ηold,κ,tol,10000,dW_cache)
end

mutable struct ISSEulerHeunConstantCache{F,uEltypeNoUnits} <: StochasticDiffEqConstantCache
Expand Down
20 changes: 16 additions & 4 deletions src/caches/lamba_caches.jl
@@ -1,5 +1,5 @@
struct LambaEMConstantCache <: StochasticDiffEqConstantCache end
struct LambaEMCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
struct LambaEMCache{uType,rateType,rateNoiseType,dWType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
du1::rateType
Expand All @@ -8,6 +8,7 @@ struct LambaEMCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCach
tmp::uType
L::rateType
gtmp::rateNoiseType
dW_cache::dWType
end

u_cache(c::LambaEMCache) = ()
Expand All @@ -20,11 +21,16 @@ function alg_cache(alg::LambaEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_proto
K = zeros(rate_prototype); tmp = similar(u);
L = zeros(noise_rate_prototype)
gtmp = zeros(noise_rate_prototype)
LambaEMCache(u,uprev,du1,du2,K,tmp,L,gtmp)
if is_diagonal_noise(prob)
dW_cache = nothing
else
dW_cache = similar(ΔW)
end
LambaEMCache(u,uprev,du1,du2,K,tmp,L,gtmp,dW_cache)
end

struct LambaEulerHeunConstantCache <: StochasticDiffEqConstantCache end
struct LambaEulerHeunCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
struct LambaEulerHeunCache{uType,rateType,rateNoiseType,dWType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
du1::rateType
Expand All @@ -33,6 +39,7 @@ struct LambaEulerHeunCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMuta
tmp::uType
L::rateType
gtmp::rateNoiseType
dW_cache::dWType
end

u_cache(c::LambaEulerHeunCache) = ()
Expand All @@ -45,5 +52,10 @@ function alg_cache(alg::LambaEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rat
K = zeros(rate_prototype); tmp = similar(u);
L = zeros(noise_rate_prototype)
gtmp = zeros(noise_rate_prototype)
LambaEulerHeunCache(u,uprev,du1,du2,K,tmp,L,gtmp)
if is_diagonal_noise(prob)
dW_cache = nothing
else
dW_cache = similar(ΔW)
end
LambaEulerHeunCache(u,uprev,du1,du2,K,tmp,L,gtmp,dW_cache)
end
14 changes: 10 additions & 4 deletions src/caches/sdirk_caches.jl
@@ -1,4 +1,4 @@
mutable struct ImplicitEMCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F} <: StochasticDiffEqMutableCache
mutable struct ImplicitEMCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F,dWType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
du1::rateType
Expand All @@ -18,6 +18,7 @@ mutable struct ImplicitEMCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateTy
κ::uEltypeNoUnits
tol::uEltypeNoUnits
newton_iters::Int
dW_cache::dWType
end

u_cache(c::ImplicitEMCache) = (c.uprev2,c.z,c.dz)
Expand Down Expand Up @@ -52,12 +53,14 @@ function alg_cache(alg::ImplicitEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_pr

if is_diagonal_noise(prob)
gtmp2 = gtmp
dW_cache = nothing
else
gtmp2 = similar(rate_prototype)
dW_cache = similar(ΔW)
end

ImplicitEMCache(u,uprev,du1,fsalfirst,k,z,dz,tmp,gtmp,gtmp2,J,W,jac_config,linsolve,uf,
ηold,κ,tol,10000)
ηold,κ,tol,10000,dW_cache)
end

mutable struct ImplicitEMConstantCache{F,uEltypeNoUnits} <: StochasticDiffEqConstantCache
Expand Down Expand Up @@ -88,7 +91,7 @@ function alg_cache(alg::ImplicitEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_pr
ImplicitEMConstantCache(uf,ηold,κ,tol,100000)
end

mutable struct ImplicitEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F} <: StochasticDiffEqMutableCache
mutable struct ImplicitEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F,dWType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
du1::rateType
Expand All @@ -109,6 +112,7 @@ mutable struct ImplicitEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,nois
κ::uEltypeNoUnits
tol::uEltypeNoUnits
newton_iters::Int
dW_cache::dWType
end

u_cache(c::ImplicitEulerHeunCache) = (c.uprev2,c.z,c.dz)
Expand Down Expand Up @@ -145,12 +149,14 @@ function alg_cache(alg::ImplicitEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_

if is_diagonal_noise(prob)
gtmp3 = gtmp2
dW_cache = nothing
else
gtmp3 = similar(noise_rate_prototype)
dW_cache = similar(ΔW)
end

ImplicitEulerHeunCache(u,uprev,du1,fsalfirst,k,z,dz,tmp,gtmp,gtmp2,gtmp3,
J,W,jac_config,linsolve,uf,ηold,κ,tol,10000)
J,W,jac_config,linsolve,uf,ηold,κ,tol,10000,dW_cache)
end

mutable struct ImplicitEulerHeunConstantCache{F,uEltypeNoUnits} <: StochasticDiffEqConstantCache
Expand Down
10 changes: 5 additions & 5 deletions src/perform_step/implicit_split_step.jl
Expand Up @@ -124,7 +124,7 @@ end
ISSEulerHeunCache},
f=integrator.f)
@unpack t,dt,uprev,u,p = integrator
@unpack uf,du1,dz,z,k,J,W,jac_config,gtmp,gtmp2,tmp = cache
@unpack uf,du1,dz,z,k,J,W,jac_config,gtmp,gtmp2,tmp,dW_cache = cache
integrator.alg.symplectic ? a = dt/2 : a = dt
dW = integrator.W.dW
mass_matrix = integrator.sol.prob.mass_matrix
Expand Down Expand Up @@ -308,8 +308,8 @@ end
if !is_diagonal_noise(integrator.sol.prob)
integrator.g(gtmp,z,p,t)
g_sized2 = norm(gtmp,2)
@. dz = dW.^2 - dt
diff_tmp = integrator.opts.internalnorm(dz)
@. dW_cache = dW.^2 - dt
diff_tmp = integrator.opts.internalnorm(dW_cache)
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
@. dz = En
else
Expand All @@ -324,8 +324,8 @@ end
if !is_diagonal_noise(integrator.sol.prob)
integrator.g(gtmp,z,p,t)
g_sized2 = norm(gtmp,2)
@. dz = dW.^2
diff_tmp = integrator.opts.internalnorm(dz)
@. dW_cache = dW.^2
diff_tmp = integrator.opts.internalnorm(dW_cache)
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
@. dz = En
else
Expand Down
12 changes: 6 additions & 6 deletions src/perform_step/lamba.jl
Expand Up @@ -28,7 +28,7 @@
end

@muladd function perform_step!(integrator,cache::LambaEMCache,f=integrator.f)
@unpack du1,du2,K,tmp,L,gtmp = cache
@unpack du1,du2,K,tmp,L,gtmp,dW_cache = cache
@unpack t,dt,uprev,u,W,p = integrator

integrator.f(du1,uprev,p,t)
Expand Down Expand Up @@ -61,8 +61,8 @@ end
if !is_diagonal_noise(integrator.sol.prob)
integrator.g(gtmp,tmp,p,t)
g_sized2 = norm(gtmp,2)
@. tmp = dW.^2 - dt
diff_tmp = integrator.opts.internalnorm(tmp)
@. dW_cache = dW.^2 - dt
diff_tmp = integrator.opts.internalnorm(dW_cache)
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
@. tmp = En
else
Expand Down Expand Up @@ -118,7 +118,7 @@ end
end

@muladd function perform_step!(integrator,cache::LambaEulerHeunCache,f=integrator.f)
@unpack du1,du2,K,tmp,L,gtmp = cache
@unpack du1,du2,K,tmp,L,gtmp,dW_cache = cache
@unpack t,dt,uprev,u,W,p = integrator
integrator.f(du1,uprev,p,t)
integrator.g(L,uprev,p,t)
Expand Down Expand Up @@ -168,8 +168,8 @@ end
if !is_diagonal_noise(integrator.sol.prob)
integrator.g(gtmp,tmp,p,t)
g_sized2 = norm(gtmp,2)
@. tmp = dW.^2
diff_tmp = integrator.opts.internalnorm(tmp)
@. dW_cache = dW.^2
diff_tmp = integrator.opts.internalnorm(dW_cache)
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
@. tmp = En
else
Expand Down
10 changes: 5 additions & 5 deletions src/perform_step/sdirk.jl
Expand Up @@ -318,7 +318,7 @@ end
# k is Ed
# dz is En
if typeof(cache) <: Union{ImplicitEMCache,ImplicitEulerHeunCache}

dW_cache = cache.dW_cache
if !is_diagonal_noise(integrator.sol.prob)
g_sized = norm(gtmp,2)
else
Expand All @@ -331,8 +331,8 @@ end
if !is_diagonal_noise(integrator.sol.prob)
integrator.g(gtmp,z,p,t)
g_sized2 = norm(gtmp,2)
@. dz = dW.^2 - dt
diff_tmp = integrator.opts.internalnorm(dz)
@. dW_cache = dW.^2 - dt
diff_tmp = integrator.opts.internalnorm(dW_cache)
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
@. dz = En
else
Expand All @@ -347,8 +347,8 @@ end
if !is_diagonal_noise(integrator.sol.prob)
integrator.g(gtmp,z,p,t)
g_sized2 = norm(gtmp,2)
@. dz = dW.^2
diff_tmp = integrator.opts.internalnorm(dz)
@. dW_cache = dW.^2
diff_tmp = integrator.opts.internalnorm(dW_cache)
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
@. dz = En
else
Expand Down

0 comments on commit 13e8f07

Please sign in to comment.