Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/problems/sdeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,20 @@ end
t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
eval_module, check_compatibility, sparse, expression, kwargs...)

noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
# Only calculate noise and noise_rate_prototype if not provided by user
if !haskey(kwargs, :noise) && !haskey(kwargs, :noise_rate_prototype)
noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
elseif !haskey(kwargs, :noise)
noise, _ = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
noise_rate_prototype = kwargs[:noise_rate_prototype]
elseif !haskey(kwargs, :noise_rate_prototype)
_, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
noise = kwargs[:noise]
else
noise = kwargs[:noise]
noise_rate_prototype = kwargs[:noise_rate_prototype]
end

kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
op, kwargs...)

Expand Down
45 changes: 45 additions & 0 deletions test/sdesystem.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using ModelingToolkit, StaticArrays, LinearAlgebra
using StochasticDiffEq, OrdinaryDiffEq, SparseArrays
using DiffEqNoiseProcess: NoiseWrapper
using Random, Test
using Setfield
using Statistics
Expand Down Expand Up @@ -953,3 +954,47 @@ end
@test ModelingToolkit.isbrownian(p)
@test ModelingToolkit.isbrownian(q)
end

@testset "noise kwarg propagation (issue #3664)" begin
@parameters σ ρ β
@variables x(tt) y(tt) z(tt)

u0 = [1.0, 0.0, 0.0]
T = (0.0, 5.0)

eqs = [D(x) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z]
noiseeqs = [3.0,
3.0,
3.0]
@mtkbuild sde_lorentz = SDESystem(eqs, noiseeqs, tt, [x, y, z], [σ, ρ, β])
parammap = [σ, ρ, β] .=> [10, 28.0, 8 / 3]

# Test that user-provided noise is respected
Random.seed!(1)
noise1 = StochasticDiffEq.RealWienerProcess(0.0, 0.0, 0.0; save_everystep = true)
u0_dict = Dict(unknowns(sde_lorentz) .=> u0)
prob1 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T; noise = noise1)
sol1 = solve(prob1, SRIW1())

# Verify noise was actually used (curW should be modified)
@test noise1.curW != 0.0

# Test that using the same noise via NoiseWrapper gives deterministic results
noise2 = NoiseWrapper(noise1)
prob2 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T; noise = noise2)
sol2 = solve(prob2, SRIW1())

# Same noise should give same results
@test sol1.u[end] ≈ sol2.u[end]

# Test that without providing noise, different results are obtained
Random.seed!(1)
prob3 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T)
Random.seed!(2)
prob4 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T)
sol3 = solve(prob3, SRIW1(), seed = 1)
sol4 = solve(prob4, SRIW1(), seed = 2)
@test !(sol3.u[end] ≈ sol4.u[end])
end
Loading