Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mass matrix hotfix #93

Merged
merged 6 commits into from
Jul 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@ abstract type StochasticDiffEqNewtonAlgorithm{CS,AD,Controller} <: StochasticDif
# Basics

struct EM{split} <: StochasticDiffEqAlgorithm end
Base.@pure EM(split=true) = EM{split}()
EM(split=true) = EM{split}()

struct SplitEM <: StochasticDiffEqAlgorithm end
struct EulerHeun <: StochasticDiffEqAlgorithm end

struct LambaEM{split} <: StochasticDiffEqAdaptiveAlgorithm end
Base.@pure LambaEM(split=true) = LambaEM{split}()
LambaEM(split=true) = LambaEM{split}()

struct LambaEulerHeun <: StochasticDiffEqAdaptiveAlgorithm end

struct RKMil{interpretation} <: StochasticDiffEqAdaptiveAlgorithm end
Base.@pure RKMil(;interpretation=:Ito) = RKMil{interpretation}()
RKMil(;interpretation=:Ito) = RKMil{interpretation}()

struct RKMilCommute{interpretation} <: StochasticDiffEqAdaptiveAlgorithm end
Base.@pure RKMilCommute(;interpretation=:Ito) = RKMilCommute{interpretation}()
RKMilCommute(;interpretation=:Ito) = RKMilCommute{interpretation}()

###############################################################################

Expand Down Expand Up @@ -95,17 +95,17 @@ struct SOSRA2 <: StochasticDiffEqAdaptiveAlgorithm end
struct IIF1M{F} <: StochasticDiffEqAlgorithm
nlsolve::F
end
Base.@pure IIF1M(;nlsolve=NLSOLVEJL_SETUP()) = IIF1M{typeof(nlsolve)}(nlsolve)
IIF1M(;nlsolve=NLSOLVEJL_SETUP()) = IIF1M{typeof(nlsolve)}(nlsolve)

struct IIF2M{F} <: StochasticDiffEqAlgorithm
nlsolve::F
end
Base.@pure IIF2M(;nlsolve=NLSOLVEJL_SETUP()) = IIF2M{typeof(nlsolve)}(nlsolve)
IIF2M(;nlsolve=NLSOLVEJL_SETUP()) = IIF2M{typeof(nlsolve)}(nlsolve)

struct IIF1Mil{F} <: StochasticDiffEqAlgorithm
nlsolve::F
end
Base.@pure IIF1Mil(;nlsolve=NLSOLVEJL_SETUP()) = IIF1Mil{typeof(nlsolve)}(nlsolve)
IIF1Mil(;nlsolve=NLSOLVEJL_SETUP()) = IIF1Mil{typeof(nlsolve)}(nlsolve)

################################################################################

Expand All @@ -123,7 +123,7 @@ struct ImplicitEM{CS,AD,F,S,K,T,T2,Controller} <: StochasticDiffEqNewtonAdaptive
new_jac_conv_bound::T2
symplectic::Bool
end
Base.@pure ImplicitEM(;chunk_size=0,autodiff=true,diff_type=Val{:central},
ImplicitEM(;chunk_size=0,autodiff=true,diff_type=Val{:central},
linsolve=DEFAULT_LINSOLVE,κ=nothing,tol=nothing,
extrapolant=:constant,min_newton_iter=1,
theta = 1/2,symplectic=false,
Expand Down Expand Up @@ -151,7 +151,7 @@ struct ImplicitEulerHeun{CS,AD,F,S,K,T,T2,Controller} <: StochasticDiffEqNewtonA
new_jac_conv_bound::T2
symplectic::Bool
end
Base.@pure ImplicitEulerHeun(;chunk_size=0,autodiff=true,diff_type=Val{:central},
ImplicitEulerHeun(;chunk_size=0,autodiff=true,diff_type=Val{:central},
linsolve=DEFAULT_LINSOLVE,κ=nothing,tol=nothing,
extrapolant=:constant,min_newton_iter=1,
theta = 1/2,symplectic = false,
Expand All @@ -178,7 +178,7 @@ struct ImplicitRKMil{CS,AD,F,S,K,T,T2,Controller,interpretation} <: StochasticDi
new_jac_conv_bound::T2
symplectic::Bool
end
Base.@pure ImplicitRKMil(;chunk_size=0,autodiff=true,diff_type=Val{:central},
ImplicitRKMil(;chunk_size=0,autodiff=true,diff_type=Val{:central},
linsolve=DEFAULT_LINSOLVE,κ=nothing,tol=nothing,
extrapolant=:constant,min_newton_iter=1,
theta = 1/2,symplectic = false,
Expand All @@ -205,7 +205,7 @@ struct ISSEM{CS,AD,F,S,K,T,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgor
new_jac_conv_bound::T2
symplectic::Bool
end
Base.@pure ISSEM(;chunk_size=0,autodiff=true,diff_type=Val{:central},
ISSEM(;chunk_size=0,autodiff=true,diff_type=Val{:central},
linsolve=DEFAULT_LINSOLVE,κ=nothing,tol=nothing,
extrapolant=:constant,min_newton_iter=1,
theta = 1,symplectic=false,
Expand Down Expand Up @@ -233,7 +233,7 @@ struct ISSEulerHeun{CS,AD,F,S,K,T,T2,Controller} <: StochasticDiffEqNewtonAdapti
new_jac_conv_bound::T2
symplectic::Bool
end
Base.@pure ISSEulerHeun(;chunk_size=0,autodiff=true,diff_type=Val{:central},
ISSEulerHeun(;chunk_size=0,autodiff=true,diff_type=Val{:central},
linsolve=DEFAULT_LINSOLVE,κ=nothing,tol=nothing,
extrapolant=:constant,min_newton_iter=1,
theta = 1,symplectic=false,
Expand All @@ -260,7 +260,7 @@ struct SKenCarp{CS,AD,F,FDT,K,T,T2,Controller} <: StochasticDiffEqNewtonAdaptive
new_jac_conv_bound::T2
end

Base.@pure SKenCarp(;chunk_size=0,autodiff=true,diff_type=Val{:central},
SKenCarp(;chunk_size=0,autodiff=true,diff_type=Val{:central},
linsolve=DEFAULT_LINSOLVE,κ=nothing,tol=nothing,
smooth_est=true,extrapolant=:min_correct,min_newton_iter=1,
max_newton_iter=7,new_jac_conv_bound = 1e-3,
Expand Down
4 changes: 2 additions & 2 deletions src/caches/basic_method_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ function alg_cache(alg::RKMilCommute,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_
du1 = zero(rate_prototype); du2 = zero(rate_prototype)
K = zero(rate_prototype); gtmp = zero(noise_rate_prototype);
L = zero(noise_rate_prototype); tmp = zero(rate_prototype)
I = zero(length(ΔW),length(ΔW));
Dg = zero(length(ΔW),length(ΔW)); mil_correction = zero(rate_prototype)
I = zeros(length(ΔW),length(ΔW));
Dg = zeros(length(ΔW),length(ΔW)); mil_correction = zero(rate_prototype)
Kj = zero(u); Dgj = zero(noise_rate_prototype)
RKMilCommuteCache(u,uprev,du1,du2,K,gtmp,L,I,Dg,mil_correction,Kj,Dgj,tmp)
end
2 changes: 1 addition & 1 deletion src/caches/kencarp_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function alg_cache(alg::SKenCarp,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prot
fsalfirst = zero(rate_prototype)
k = zero(rate_prototype)
tmp = zero(u); b = similar(u,axes(u));
atmp = zero(u,uEltypeNoUnits,axes(u))
atmp = fill!(similar(u,uEltypeNoUnits,axes(u)),0)

if typeof(f) <: SplitFunction
k1 = zero(u); k2 = zero(u)
Expand Down
2 changes: 1 addition & 1 deletion src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function calc_W!(integrator, cache::StochasticDiffEqMutableCache, γdt, repeat_s
@unpack J,W,jac_config = cache
is_compos = is_composite(alg)
alg = unwrap_alg(integrator, true)
mass_matrix = integrator.sol.prob.mass_matrix
mass_matrix = integrator.f.mass_matrix

new_W = true
if has_invW(f)
Expand Down
26 changes: 13 additions & 13 deletions src/perform_step/iif.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mutable struct RHS_IIF1M_Scalar{F,CType,tType,P} <: Function
end

function (f::RHS_IIF1M_Scalar)(resid,u)
resid[1] .= u[1] - f.tmp - f.dt*f.f[2](u[1],f.p,f.t+f.dt)[1]
resid[1] = u[1] - f.tmp - f.dt*f.f.f2(u[1],f.p,f.t+f.dt)[1]
end

mutable struct RHS_IIF2M_Scalar{F,CType,tType,P} <: Function
Expand All @@ -19,7 +19,7 @@ mutable struct RHS_IIF2M_Scalar{F,CType,tType,P} <: Function
end

function (f::RHS_IIF2M_Scalar)(resid,u)
resid[1] = u[1] - f.tmp - 0.5f.dt*f.f[2](u[1],f.p,f.t+f.dt)[1]
resid[1] = u[1] - f.tmp - 0.5f.dt*f.f.f2(u[1],f.p,f.t+f.dt)[1]
end

@muladd function initialize!(integrator,cache::Union{IIF1MConstantCache,IIF2MConstantCache,IIF1MilConstantCache},f=integrator.f)
Expand All @@ -30,13 +30,13 @@ end
@unpack t,dt,uprev,u,W,p = integrator
@unpack uhold,rhs,nl_rhs = cache
alg = unwrap_alg(integrator, true)
A = integrator.f[1](u,p,t)
A = integrator.f.f1(u,p,t)
if typeof(cache) <: IIF1MilConstantCache
error("Milstein correction does not work.")
elseif typeof(cache) <: IIF1MConstantCache
tmp = expm(A*dt)*(uprev + integrator.g(uprev,p,t)*W.dW)
tmp = exp(A*dt)*(uprev + integrator.g(uprev,p,t)*W.dW)
elseif typeof(cache) <: IIF2MConstantCache
tmp = expm(A*dt)*(uprev + 0.5dt*integrator.f[2](uprev,p,t) + integrator.g(uprev,p,t)*W.dW)
tmp = exp(A*dt)*(uprev + 0.5dt*integrator.f.f2(uprev,p,t) + integrator.g(uprev,p,t)*W.dW)
end

if integrator.iter > 1 && !integrator.u_modified
Expand Down Expand Up @@ -64,7 +64,7 @@ end
function (f::RHS_IIF1)(resid,u)
_du = get_du(f.dual_cache, eltype(u))
du = reinterpret(eltype(u),_du)
f.f[2](du,reshape(u,f.sizeu),f.p,f.t+f.dt)
f.f.f2(du,reshape(u,f.sizeu),f.p,f.t+f.dt)
@. resid = u - f.tmp - f.dt*du
end

Expand All @@ -80,7 +80,7 @@ end
function (f::RHS_IIF2)(resid,u)
_du = get_du(f.dual_cache, eltype(u))
du = reinterpret(eltype(u),_du)
f.f[2](du,reshape(u,f.sizeu),f.p,f.t+f.dt)
f.f.f2(du,reshape(u,f.sizeu),f.p,f.t+f.dt)
@. resid = u - f.tmp - 0.5f.dt*du
end

Expand All @@ -102,12 +102,12 @@ end
rtmp3 .+= uprev

if typeof(cache) <: IIF2MCache
integrator.f[2](rtmp1,uprev,p,t)
integrator.f.f2(rtmp1,uprev,p,t)
@. rtmp3 = @muladd 0.5dt*rtmp1 + rtmp3
end

A = integrator.f[1](rtmp1,uprev,p,t)
M = expm(A*dt)
A = integrator.f.f1(rtmp1,uprev,p,t)
M = exp(A*dt)
mul!(tmp,M,rtmp3)

if integrator.iter > 1 && !integrator.u_modified
Expand All @@ -134,8 +134,8 @@ end
dW = W.dW; sqdt = integrator.sqdt
f = integrator.f; g = integrator.g

A = integrator.f[1](t,uprev,rtmp1)
M = expm(A*dt)
A = integrator.f.f1(t,uprev,rtmp1)
M = exp(A*dt)

uidx = eachindex(u)
integrator.g(rtmp2,uprev,p,t)
Expand Down Expand Up @@ -168,7 +168,7 @@ end
end

if typeof(cache) <: IIF2MCache
integrator.f[2](t,uprev,rtmp1)
integrator.f.f2(t,uprev,rtmp1)
@. rtmp1 = @muladd 0.5dt*rtmp1 + uprev + rtmp3
mul!(tmp,M,rtmp1)
elseif !(typeof(cache) <: IIF1MilCache)
Expand Down
2 changes: 1 addition & 1 deletion src/perform_step/implicit_split_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ end
alg = unwrap_alg(integrator, true)
alg.symplectic ? a = dt/2 : a = dt
dW = integrator.W.dW
mass_matrix = integrator.sol.prob.mass_matrix
mass_matrix = integrator.f.mass_matrix
theta = alg.theta

repeat_step = false
Expand Down
2 changes: 1 addition & 1 deletion src/perform_step/sdirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ end
alg = unwrap_alg(integrator, true)
alg.symplectic ? a = dt/2 : a = dt
dW = integrator.W.dW
mass_matrix = integrator.sol.prob.mass_matrix
mass_matrix = integrator.f.mass_matrix
theta = alg.theta

repeat_step = false
Expand Down
8 changes: 4 additions & 4 deletions src/perform_step/split.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@muladd function perform_step!(integrator,cache::SplitEMConstantCache,f=integrator.f)
@unpack t,dt,uprev,u,W,p = integrator
u = dt*(integrator.f[1](uprev,p,t) +
integrator.f[2](uprev,p,t)) +
u = dt*(integrator.f.f1(uprev,p,t) +
integrator.f.f2(uprev,p,t)) +
integrator.g(uprev,p,t).*W.dW + uprev
integrator.u = u
end
Expand All @@ -17,8 +17,8 @@ end
mul!(rtmp3,rtmp2,W.dW)
end

integrator.f[1](t,uprev,rtmp1)
integrator.f.f1(t,uprev,rtmp1)
@. u = @muladd uprev + dt*rtmp1 + rtmp3
integrator.f[2](t,uprev,rtmp1)
integrator.f.f2(t,uprev,rtmp1)
@. u = @muladd u + dt*rtmp1
end
4 changes: 2 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ function __init(
end

if typeof(prob.f)<:Tuple
if min((mm != I for mm in prob.mass_matrix)...)
if min((mm != I for mm in prob.f.mass_matrix)...)
error("This solver is not able to use mass matrices.")
end
elseif prob.mass_matrix != I && !alg_mass_matrix_compatible(alg)
elseif prob.f.mass_matrix != I && !alg_mass_matrix_compatible(alg)
error("This solver is not able to use mass matrices.")
end

Expand Down
17 changes: 11 additions & 6 deletions test/commutative_tests.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,40 @@
using StochasticDiffEq, DiffEqDevTools, Test
using StochasticDiffEq, DiffEqDevTools, Test, LinearAlgebra
srand(100)
dts = 1./2.^(10:-1:2) #14->7 good plot

using SpecialMatrices
const σ_const = 0.87
const μ = 1.01

u0 = rand(2)
A = full(Strang(2))
A = [-2.0 1.0;1.0 -2.0]
B = Diagonal([σ_const for i in 1:2])

function f_commute(du,u,p,t)
mul!(du,A,u)
du .+= 1.01u
end
function (::typeof(f_commute))(::Type{Val{:analytic}},u0,p,t,W)

function f_commute_analytic(u0,p,t,W)
tmp = (A+1.01I-(B^2))*t + B*sum(W)
exp(tmp)*u0
end

function σ(du,u,p,t)
du[1,1] = σ_const*u[1]
du[1,2] = σ_const*u[1]
du[2,1] = σ_const*u[2]
du[2,2] = σ_const*u[2]
end

prob = SDEProblem(f_commute,σ,u0,(0.0,1.0),noise_rate_prototype=rand(2,2))
ff_commute = SDEFunction(f_commute,σ,analytic=f_commute_analytic)

prob = SDEProblem(ff_commute,σ,u0,(0.0,1.0),noise_rate_prototype=rand(2,2))

sol = solve(prob,RKMilCommute(),dt=1/2^(8))
sol = solve(prob,EM(),dt=1/2^(10))

dts = 1./2.^(10:-1:3) #14->7 good plot
dts = (1/2) .^ (10:-1:3) #14->7 good plot
sim2 = test_convergence(dts,prob,EM(),numMonte=Int(1e2))
sim2 = test_convergence(dts,prob,RKMilCommute(),numMonte=Int(1e2))

sim2.𝒪est[:final] - 1 < 0.2