Skip to content

Commit

Permalink
sde options type
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jan 24, 2017
1 parent 832b47c commit 992f271
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 14 deletions.
2 changes: 2 additions & 0 deletions src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ module StochasticDiffEq
import DiffEqBase: solve

include("algorithms.jl")
include("options_type.jl")
include("constants.jl")
include("alg_utils.jl")
include("solve.jl")
include("initdt.jl")
Expand Down
1 change: 1 addition & 0 deletions src/constants.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
@inline ODE_DEFAULT_ISOUTOFDOMAIN(t,u) = false
33 changes: 33 additions & 0 deletions src/options_type.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
type SDEOptions{uEltype,uEltypeNoUnits,tTypeNoUnits,tType,F2,F3,F4,F5,tstopsType,ECType}
maxiters::Int
timeseries_steps::Int
save_timeseries::Bool
adaptive::Bool
abstol::uEltype
reltol::uEltypeNoUnits
gamma::tTypeNoUnits
qmax::tTypeNoUnits
qmin::tTypeNoUnits
dtmax::tType
dtmin::tType
internalnorm::F2
tstops::tstopsType
saveat::tstopsType
d_discontinuities::tstopsType
userdata::ECType
progress::Bool
progress_steps::Int
progress_name::String
progress_message::F5
timeseries_errors::Bool
dense_errors::Bool
beta1::tTypeNoUnits
beta2::tTypeNoUnits
qoldinit::tTypeNoUnits
dense::Bool
callback::F3
isoutofdomain::F4
calck::Bool
advance_to_tstop::Bool
stop_at_next_tstop::Bool
end
98 changes: 84 additions & 14 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

function solve{uType,tType,isinplace,NoiseClass,F,F2,F3,algType<:AbstractSDEAlgorithm,recompile_flag}(
prob::AbstractSDEProblem{uType,tType,isinplace,NoiseClass,F,F2,F3},
alg::algType,timeseries=[],ts=[],ks=[],recompile::Type{Val{recompile_flag}}=Val{true};
alg::algType,timeseries_init=uType[],ts_init=tType[],ks_init=[],
recompile::Type{Val{recompile_flag}}=Val{true};
dt = tType(0),save_timeseries::Bool = true,
timeseries_steps::Int = 1,
dense = false,
saveat = tType[],tstops = tType[],d_discontinuities= tType[],
calck = (!isempty(setdiff(saveat,tstops)) || dense),
adaptive=isadaptive(alg),γ=9//10,
adaptive=isadaptive(alg),gamma=9//10,
abstol=1e-2,reltol=1e-2,
qmax=qmax_default(alg),qmin=qmin_default(alg),
qoldinit=1//10^4, fullnormalize=true,
Expand All @@ -21,6 +22,7 @@ function solve{uType,tType,isinplace,NoiseClass,F,F2,F3,algType<:AbstractSDEAlgo
dtmin=tType <: AbstractFloat ? tType(10)*eps(tType) : tType(1//10^(10)),
internalnorm=ODE_DEFAULT_NORM,
unstable_check = ODE_DEFAULT_UNSTABLE_CHECK,
isoutofdomain = ODE_DEFAULT_ISOUTOFDOMAIN,
advance_to_tstop = false,stop_at_next_tstop=false,
discard_length=1e-15,adaptivealg::Symbol=:RSwM3,
progress_steps=1000,
Expand All @@ -29,12 +31,15 @@ function solve{uType,tType,isinplace,NoiseClass,F,F2,F3,algType<:AbstractSDEAlgo
userdata=nothing,callback=nothing,
timeseries_errors = true, dense_errors=false,
kwargs...)

gamma
@unpack u0,noise,tspan = prob

tspan = prob.tspan
tdir = sign(tspan[end]-tspan[1])

T = tType(tspan[2])
t = tType(tspan[1])

if tspan[2]-tspan[1]<0 || length(tspan)>2
error("tspan must be two numbers and final time must be greater than starting time. Aborting.")
end
Expand Down Expand Up @@ -79,16 +84,84 @@ function solve{uType,tType,isinplace,NoiseClass,F,F2,F3,algType<:AbstractSDEAlgo
dt = sde_determine_initdt(u0,float(tspan[1]),tdir,dtmax,abstol,reltol,internalnorm,prob,order)
end

T = tType(tspan[2])
t = tType(tspan[1])
if sign(dt)!=tdir && dt!=tType(0)
error("dt has the wrong sign. Exiting")
end

if typeof(u) <: AbstractArray
rate_prototype = similar(u/zero(t),indices(u)) # rate doesn't need type info
else
rate_prototype = u/zero(t)
end
rateType = typeof(rate_prototype) ## Can be different if united

saveat_vec = convert(Vector{tType},collect(saveat))
if !isempty(saveat_vec) && saveat_vec[end] == tspan[2]
pop!(saveat_vec)
end

if tdir>0
saveat_internal = binary_minheap(saveat_vec)
else
saveat_internal = binary_maxheap(saveat_vec)
end

if !isempty(saveat_internal) && top(saveat_internal) == tspan[1]
pop!(saveat_internal)
end

d_discontinuities_vec = convert(Vector{tType},d_discontinuities_col)

if tdir>0
d_discontinuities_internal = binary_minheap(d_discontinuities_vec)
else
d_discontinuities_internal = binary_maxheap(d_discontinuities_vec)
end

callbacks_internal = CallbackSet(callback)

uEltypeNoUnits = typeof(recursive_one(u))
tTypeNoUnits = typeof(recursive_one(t))

### Algorithm-specific defaults ###
ksEltype = Vector{rateType}

timeseries = Vector{uType}(0)
push!(timeseries,u0)
ts = Vector{tType}(0)
push!(ts,t)
# Have to convert incase passed in wrong.
timeseries = convert(Vector{uType},timeseries_init)
ts = convert(Vector{tType},ts_init)
ks = convert(Vector{ksEltype},ks_init)
alg_choice = Int[]

copyat_or_push!(ts,1,t)
copyat_or_push!(timeseries,1,u)
copyat_or_push!(ks,1,[rate_prototype])

uEltype = eltype(u)

opts = SDEOptions(Int(maxiters),timeseries_steps,save_timeseries,adaptive,uEltype(uEltype(1)*abstol),
uEltypeNoUnits(reltol),tTypeNoUnits(gamma),tTypeNoUnits(qmax),tTypeNoUnits(qmin),
dtmax,dtmin,internalnorm,
tstops_internal,saveat_internal,d_discontinuities_internal,
userdata,
progress,progress_steps,
progress_name,progress_message,
timeseries_errors,dense_errors,
tTypeNoUnits(beta1),tTypeNoUnits(beta2),tTypeNoUnits(qoldinit),dense,
callbacks_internal,isoutofdomain,calck,advance_to_tstop,stop_at_next_tstop)

progress ? (prog = Juno.ProgressBar(name=progress_name)) : prog = nothing

notsaveat_idxs = Int[1]

k = ksEltype[]

if uType <: Array
uprev = copy(u)
else
uprev = deepcopy(u)
end


if !(uType <: AbstractArray)
rands = ChunkedArray(noise.noise_func)
randType = typeof(u/u) # Strip units and type info
Expand All @@ -98,9 +171,6 @@ function solve{uType,tType,isinplace,NoiseClass,F,F2,F3,algType<:AbstractSDEAlgo
randType = typeof(rand_prototype) # Strip units and type info
end

uEltypeNoUnits = typeof(recursive_one(u))
tTypeNoUnits = typeof(recursive_one(t))


Ws = Vector{randType}(0)
if !(uType <: AbstractArray)
Expand All @@ -120,13 +190,13 @@ function solve{uType,tType,isinplace,NoiseClass,F,F2,F3,algType<:AbstractSDEAlgo

rateType = typeof(u/t) ## Can be different if united

#@code_warntype sde_solve(SDEIntegrator{typeof(alg),typeof(u),eltype(u),ndims(u),ndims(u)+1,typeof(dt),typeof(tableau)}(f,g,u,t,dt,T,Int(maxiters),timeseries,Ws,ts,timeseries_steps,save_timeseries,adaptive,adaptivealg,δ,γ,abstol,reltol,qmax,dtmax,dtmin,internalnorm,discard_length,progress,atomloaded,progress_steps,rands,sqdt,W,Z,tableau))
#@code_warntype sde_solve(SDEIntegrator{typeof(alg),typeof(u),eltype(u),ndims(u),ndims(u)+1,typeof(dt),typeof(tableau)}(f,g,u,t,dt,T,Int(maxiters),timeseries,Ws,ts,timeseries_steps,save_timeseries,adaptive,adaptivealg,δ,gamma,abstol,reltol,qmax,dtmax,dtmin,internalnorm,discard_length,progress,atomloaded,progress_steps,rands,sqdt,W,Z,tableau))

u,t,W,timeseries,ts,Ws,maxstacksize,maxstacksize2 = sde_solve(
SDEIntegrator{typeof(alg),uType,uEltype,ndims(u),ndims(u)+1,tType,tTypeNoUnits,
uEltypeNoUnits,randType,rateType,typeof(internalnorm),typeof(progress_message),
typeof(unstable_check),F,F2}(f,g,u,t,dt,T,alg,Int(maxiters),timeseries,Ws,
ts,timeseries_steps,save_timeseries,adaptive,adaptivealg,δ,tTypeNoUnits(γ),
ts,timeseries_steps,save_timeseries,adaptive,adaptivealg,δ,tTypeNoUnits(gamma),
abstol,reltol,tTypeNoUnits(qmax),dtmax,dtmin,internalnorm,discard_length,
progress,progress_name,progress_steps,progress_message,
unstable_check,rands,sqdt,W,Z,
Expand Down

0 comments on commit 992f271

Please sign in to comment.