Skip to content

Commit

Permalink
LambaEM initial attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Mar 31, 2018
1 parent 7ac5e8b commit 147755f
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,12 @@ module StochasticDiffEq
export StochasticDiffEqAlgorithm, StochasticDiffEqAdaptiveAlgorithm,
StochasticCompositeAlgorithm

export EM, PCEuler, RKMil, SRA, SRI, SRIW1, SRA1, SRA2, SRA3, SOSRA, SOSRA2, RKMilCommute,
export EM, LambaEM, PCEuler, RKMil, SRA, SRI, SRIW1,
SRA1, SRA2, SRA3,
SOSRA, SOSRA2, RKMilCommute,
SRIW2, SOSRI, SOSRI2, SKenCarp

export EulerHeun
export EulerHeun, LambaEulerHeun

export SplitEM, IIF1M, IIF2M, IIF1Mil

Expand Down
3 changes: 3 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
struct EM <: StochasticDiffEqAlgorithm end
struct SplitEM <: StochasticDiffEqAlgorithm end
struct EulerHeun <: StochasticDiffEqAlgorithm end

struct LambaEM <: StochasticDiffEqAdaptiveAlgorithm end
struct LambaEulerHeun <: StochasticDiffEqAdaptiveAlgorithm end
struct RKMil{interpretation} <: StochasticDiffEqAdaptiveAlgorithm end
Base.@pure RKMil(;interpretation=:Ito) = RKMil{interpretation}()

Expand Down
50 changes: 50 additions & 0 deletions src/caches/basic_method_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,56 @@ function alg_cache(alg::RandomEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prot
RandomEMCache(u,uprev,tmp,rtmp)
end

struct LambaEMConstantCache <: StochasticDiffEqConstantCache end
struct LambaEMCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
du1::rateType
du2::rateType
K::rateType
tmp::uType
L::rateType
gtmp::rateNoiseType
end

u_cache(c::LambaEMCache) = ()
du_cache(c::LambaEMCache) = (c.du1,c.du2,c.K,c.L)

alg_cache(alg::LambaEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{false}}) = LambaEMConstantCache()

function alg_cache(alg::LambaEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
du1 = zeros(rate_prototype); du2 = zeros(rate_prototype)
K = zeros(rate_prototype); tmp = similar(u);
L = zeros(noise_rate_prototype)
gtmp = zeros(noise_rate_prototype)
LambaEMCache(u,uprev,du1,du2,K,tmp,L,gtmp)
end

struct LambaEulerHeunConstantCache <: StochasticDiffEqConstantCache end
struct LambaEulerHeunCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
du1::rateType
du2::rateType
K::rateType
tmp::uType
L::rateType
gtmp::rateNoiseType
end

u_cache(c::LambaEulerHeunCache) = ()
du_cache(c::LambaEulerHeunCache) = (c.du1,c.du2,c.K,c.L)

alg_cache(alg::LambaEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{false}}) = LambaEulerHeunConstantCache()

function alg_cache(alg::LambaEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
du1 = zeros(rate_prototype); du2 = zeros(rate_prototype)
K = zeros(rate_prototype); tmp = similar(u);
L = zeros(noise_rate_prototype)
gtmp = zeros(noise_rate_prototype)
LambaEulerHeunCache(u,uprev,du1,du2,K,tmp,L,gtmp)
end

struct RKMilConstantCache <: StochasticDiffEqConstantCache end
struct RKMilCache{uType,rateType} <: StochasticDiffEqMutableCache
u::uType
Expand Down
98 changes: 97 additions & 1 deletion src/perform_step/low_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,102 @@ end
integrator.u = u
end

@muladd function perform_step!(integrator,cache::Union{LambaEMConstantCache,LambaEulerHeunConstantCache},f=integrator.f)
@unpack t,dt,uprev,u,W,p = integrator
du1 = integrator.f(uprev,p,t)
K = @muladd uprev + dt*du1
L = integrator.g(uprev,p,t)
mil_correction = zero(u)

u = K+L*W.dW

if integrator.opts.adaptive
du2 = integrator.f(K,p,t+dt)
Ed = dt*(du2 - du1)/2

if typeof(cache) <: LambaEMConstantCache
utilde = K + L*integrator.sqdt
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
En = ggprime.*(W.dW.^2 .- dt)./2
elseif typeof(cache) <: LambaEulerHeunConstantCache
utilde = uprev + L*integrator.sqdt
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
En = ggprime.*(W.dW.^2)./2
end

integrator.EEst = integrator.opts.internalnorm((Ed + En)/((integrator.opts.abstol + max.(abs(uprev),abs(u))*integrator.opts.reltol)))
end

integrator.u = u
end

@muladd function perform_step!(integrator,cache::Union{LambaEMCache,LambaEulerHeunCache},f=integrator.f)
@unpack du1,du2,K,tmp,L,gtmp = cache
@unpack t,dt,uprev,u,W,p = integrator
integrator.f(du1,uprev,p,t)
integrator.g(L,uprev,p,t)
@. K = @muladd uprev + dt*du1

if is_diagonal_noise(integrator.sol.prob)
@. tmp=L*W.dW
else
A_mul_B!(tmp,L,W.dW)
end

@. u = K+tmp

if integrator.opts.adaptive

if !is_diagonal_noise(integrator.sol.prob)
g_sized = norm(L,2)
else
g_sized = L
end

if typeof(cache) <: LambaEMCache
@. tmp = @muladd K + L*integrator.sqdt

if !is_diagonal_noise(integrator.sol.prob)
integrator.g(gtmp,z,p,t)
g_sized2 = norm(gtmp,2)
@. tmp = dW.^2 - dt
diff_tmp = integrator.opts.internalnorm(tmp)
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
@. tmp = En
else
integrator.g(gtmp,tmp,p,t)
@. tmp = (gtmp-L)/(2integrator.sqdt)*(W.dW.^2 - dt)
end

elseif typeof(cache) <: LambaEulerHeunCache
@. tmp = @muladd uprev + L*integrator.sqdt

if !is_diagonal_noise(integrator.sol.prob)
integrator.g(gtmp,z,p,t)
g_sized2 = norm(gtmp,2)
@. tmp = dW.^2
diff_tmp = integrator.opts.internalnorm(tmp)
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
@. tmp = En
else
integrator.g(gtmp,tmp,p,t)
@. tmp = (gtmp-L)/(2integrator.sqdt)*(W.dW.^2)
end

end

# Ed
integrator.f(du2,K,p,t+dt)
@. tmp += integrator.opts.internalnorm(dt*(du2 - du1)/2)


@tight_loop_macros for (i,atol,rtol) in zip(eachindex(u),Iterators.cycle(integrator.opts.abstol),Iterators.cycle(integrator.opts.reltol))
@inbounds tmp[i] = (tmp[i])/(atol + max(abs(uprev[i]),abs(u[i]))*rtol)
end
integrator.EEst = integrator.opts.internalnorm(tmp)
end
end

@muladd function perform_step!(integrator,cache::RKMilConstantCache,f=integrator.f)
@unpack t,dt,uprev,u,W,p = integrator
du1 = integrator.f(uprev,p,t)
Expand Down Expand Up @@ -229,6 +325,6 @@ end
@inbounds tmp[i] = (tmp[i])/(atol + max(abs(uprev[i]),abs(u[i]))*rtol)
end
integrator.EEst = integrator.opts.internalnorm(tmp)

end
end

0 comments on commit 147755f

Please sign in to comment.