Skip to content

Commit

Permalink
all resize with multiscale
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Feb 14, 2017
1 parent a6c5fc3 commit feee402
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 70 deletions.
130 changes: 66 additions & 64 deletions src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,33 @@ function alg_cache{T,algType<:StochasticCompositeAlgorithm}(alg::algType,u,ΔW,
end

immutable EMConstantCache <: StochasticDiffEqConstantCache end
immutable EMCache{uType} <: StochasticDiffEqMutableCache
immutable EMCache{uType,rateType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
tmp::uType
utmp2::uType
rtmp1::rateType
rtmp2::rateType
end

u_cache(c::EMCache) = ()
du_cache(c::EMCache) = (c.utmp2,)
du_cache(c::EMCache) = (c.rtmp1,c.rtmp2)

alg_cache(alg::EM,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{false}}) = EMConstantCache()

function alg_cache(alg::EM,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
tmp = similar(u); utmp2 = similar(u)
EMCache(u,uprev,tmp,utmp2)
tmp = similar(u); rtmp1 = zeros(rate_prototype); rtmp2 = zeros(rate_prototype)
EMCache(u,uprev,tmp,rtmp1,rtmp2)
end

immutable RKMilConstantCache <: StochasticDiffEqConstantCache end
immutable RKMilCache{uType} <: StochasticDiffEqMutableCache
immutable RKMilCache{uType,rateType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
du1::uType
du2::uType
K::uType
du1::rateType
du2::rateType
K::rateType
tmp::uType
L::uType
L::rateType
end

u_cache(c::RKMilCache) = ()
Expand All @@ -48,36 +49,37 @@ du_cache(c::RKMilCache) = (c.du1,c.du2,c.K,c.L)
alg_cache(alg::RKMil,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{false}}) = RKMilConstantCache()

function alg_cache(alg::RKMil,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
du1 = similar(u); du2 = similar(u)
K = similar(u); tmp = similar(u); L = similar(u)
du1 = zeros(rate_prototype); du2 = zeros(rate_prototype)
K = zeros(rate_prototype); tmp = similar(u); L = zeros(rate_prototype)
RKMilCache(u,uprev,du1,du2,K,tmp,L)
end

immutable SRA1ConstantCache <: StochasticDiffEqConstantCache end
alg_cache(alg::SRA1,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{false}}) = SRA1ConstantCache()

immutable SRA1Cache{randType,uType} <: StochasticDiffEqMutableCache
immutable SRA1Cache{randType,rateType,uType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
chi2::randType
tmp1::uType
E₁::uType
E₂::uType
gt::uType
k₁::uType
k₂::uType
gpdt::uType
E₁::rateType
E₂::rateType
gt::rateType
k₁::rateType
k₂::rateType
gpdt::rateType
tmp::uType
end

u_cache(c::SRA1Cache) = ()
du_cache(c::SRA1Cache) = (c.chi2,c.tmp1,c.E₁,c.E₂,c.gt,c.k₁,c.k₂,c.gpdt)
du_cache(c::SRA1Cache) = (c.chi2,c.E₁,c.E₂,c.gt,c.k₁,c.k₂,c.gpdt)
user_cache(c::SRA1Cache) = (c.u,c.uprev,c.tmp,c.tmp1)

function alg_cache(alg::SRA1,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
chi2 = similar(ΔW)
tmp1 = zeros(u)
E₁ = zeros(u); gt = zeros(u); gpdt = zeros(u)
E₂ = zeros(u); k₁ = zeros(u); k₂ = zeros(u)
E₁ = zeros(rate_prototype); gt = zeros(rate_prototype); gpdt = zeros(rate_prototype)
E₂ = zeros(rate_prototype); k₁ = zeros(rate_prototype); k₂ = zeros(rate_prototype)
tmp = zeros(u)
SRA1Cache(u,uprev,chi2,tmp1,E₁,E₂,gt,k₁,k₂,gpdt,tmp)
end
Expand Down Expand Up @@ -105,38 +107,39 @@ function alg_cache(alg::SRA,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits
SRAConstantCache(alg.tableau,rate_prototype)
end

immutable SRACache{uType,tabType} <: StochasticDiffEqMutableCache
immutable SRACache{uType,rateType,tabType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
H0::Vector{uType}
A0temp::uType
B0temp::uType
ftmp::uType
gtmp::uType
chi2::uType
atemp::uType
btemp::uType
E₁::uType
E₁temp::uType
E₂::uType
A0temp::rateType
B0temp::rateType
ftmp::rateType
gtmp::rateType
chi2::rateType
atemp::rateType
btemp::rateType
E₁::rateType
E₁temp::rateType
E₂::rateType
tmp::uType
tab::tabType
end

u_cache(c::SRACache) = ()
du_cache(c::SRACache) = (c.A0temp,c.B0temp,c.ftmp,c.gtmp,c.chi2,c.chi2,c.atemp,
c.btemp,c.E₁,c.E₁temp,c.E₂,c.H0...)
c.btemp,c.E₁,c.E₁temp,c.E₂)
user_cache(c::SRACache) = (c.u,c.uprev,c.tmp,c.H0...)

function alg_cache(alg::SRA,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
H0 = Vector{typeof(u)}(0)
tab = SRAConstantCache(alg.tableau,rate_prototype)
for i = 1:tab.stages
push!(H0,zeros(u))
end
A0temp = zeros(u); B0temp = zeros(u)
ftmp = zeros(u); gtmp = zeros(u); chi2 = zeros(u)
atemp = zeros(u); btemp = zeros(u); E₂ = zeros(u); E₁temp = zeros(u)
E₁ = zeros(u)
A0temp = zeros(rate_prototype); B0temp = zeros(rate_prototype)
ftmp = zeros(rate_prototype); gtmp = zeros(rate_prototype); chi2 = zeros(rate_prototype)
atemp = zeros(rate_prototype); btemp = zeros(rate_prototype); E₂ = zeros(rate_prototype); E₁temp = zeros(rate_prototype)
E₁ = zeros(rate_prototype)
tmp = zeros(u)
SRACache(u,uprev,H0,A0temp,B0temp,ftmp,gtmp,chi2,atemp,btemp,E₁,E₁temp,E₂,tmp,tab)
end
Expand Down Expand Up @@ -171,26 +174,26 @@ function alg_cache(alg::SRI,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits
SRIConstantCache(alg.tableau,rate_prototype,alg.error_terms)
end

immutable SRICache{randType,uType,tabType} <: StochasticDiffEqMutableCache
immutable SRICache{randType,uType,rateType,tabType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
H0::Vector{uType}
H1::Vector{uType}
A0temp::uType
A1temp::uType
B0temp::uType
B1temp::uType
A0temp2::uType
A1temp2::uType
B0temp2::uType
B1temp2::uType
atemp::uType
btemp::uType
E₁::uType
E₂::uType
E₁temp::uType
ftemp::uType
gtemp::uType
A0temp::rateType
A1temp::rateType
B0temp::rateType
B1temp::rateType
A0temp2::rateType
A1temp2::rateType
B0temp2::rateType
B1temp2::rateType
atemp::rateType
btemp::rateType
E₁::rateType
E₂::rateType
E₁temp::rateType
ftemp::rateType
gtemp::rateType
chi1::randType
chi2::randType
chi3::randType
Expand All @@ -201,9 +204,8 @@ end
u_cache(c::SRICache) = ()
du_cache(c::SRICache) = (c.A0temp,c.A1temp,c.B0temp,c.B1temp,c.A0temp2,c.A1temp2,
c.B0temp2,c.B1temp2,c.atemp,c.btemp,c.E₁,c.E₂,c.E₁temp,
c.ftemp,c.gtemp,c.chi1,c.chi2,c.chi3,
c.H0...,c.H1...)

c.ftemp,c.gtemp,c.chi1,c.chi2,c.chi3)
user_cache(c::SRICache) = (c.u,c.uprev,c.tmp,c.H0...,c.H1...)

function alg_cache(alg::SRI,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
H0 = Vector{typeof(u)}(0)
Expand All @@ -214,13 +216,13 @@ function alg_cache(alg::SRI,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits
push!(H1,zeros(u))
end
#TODO Reduce memory
A0temp = zeros(u); A1temp = zeros(u)
B0temp = zeros(u); B1temp = zeros(u)
A0temp2 = zeros(u); A1temp2 = zeros(u)
B0temp2 = zeros(u); B1temp2 = zeros(u)
atemp = zeros(u); btemp = zeros(u)
E₁ = zeros(u); E₂ = zeros(u); E₁temp = zeros(u)
ftemp = zeros(u); gtemp = zeros(u)
A0temp = zeros(rate_prototype); A1temp = zeros(rate_prototype)
B0temp = zeros(rate_prototype); B1temp = zeros(rate_prototype)
A0temp2 = zeros(rate_prototype); A1temp2 = zeros(rate_prototype)
B0temp2 = zeros(rate_prototype); B1temp2 = zeros(rate_prototype)
atemp = zeros(rate_prototype); btemp = zeros(rate_prototype)
E₁ = zeros(rate_prototype); E₂ = zeros(rate_prototype); E₁temp = zeros(rate_prototype)
ftemp = zeros(rate_prototype); gtemp = zeros(rate_prototype)
chi1 = similar(ΔW); chi2 = similar(ΔW); chi3 = similar(ΔW)
tmp = zeros(u)
SRICache(u,uprev,H0,H1,A0temp,A1temp,B0temp,B1temp,A0temp2,A1temp2,B0temp2,B1temp2,
Expand Down
8 changes: 4 additions & 4 deletions src/integrators/low_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
end

@inline function perform_step!(integrator,cache::EMCache,f=integrator.f)
@unpack tmp,utmp2 = cache
@unpack rtmp1,rtmp2 = cache
@unpack t,dt,uprev,u,ΔW = integrator
integrator.f(t,uprev,tmp)
integrator.g(t,uprev,utmp2)
integrator.f(t,uprev,rtmp1)
integrator.g(t,uprev,rtmp2)
for i in eachindex(u)
u[i] = @muladd uprev[i] + dt*tmp[i] + utmp2[i]*ΔW[i]
u[i] = @muladd uprev[i] + dt*rtmp1[i] + rtmp2[i]*ΔW[i]
end
@pack integrator = t,dt,u
end
Expand Down
6 changes: 4 additions & 2 deletions src/integrators/sri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
B1temp[k] = @muladd B1temp[k] + B₁[j,i]*gtemp[k]
end
end
H0[i] = uprev + A0temp*dt + B0temp.*chi2
H1[i] = uprev + A1temp*dt + B1temp*integrator.sqdt
for k in eachindex(u)
H0[i][k] = uprev[k] + A0temp[k]*dt + B0temp[k]*chi2[k]
H1[i][k] = uprev[k] + A1temp[k]*dt + B1temp[k]*integrator.sqdt
end
end
fill!(atemp,zero(eltype(integrator.u)))
fill!(btemp,zero(eltype(integrator.u)))
Expand Down

0 comments on commit feee402

Please sign in to comment.