Skip to content

Commit

Permalink
Merge pull request #151 from Jonas-a-Zimmermann/OU_Patch
Browse files Browse the repository at this point in the history
Ou patch
  • Loading branch information
ChrisRackauckas committed Jun 5, 2023
2 parents 6f441f0 + d16a044 commit 74da127
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/ornstein_uhlenbeck.jl
Expand Up @@ -10,9 +10,9 @@ function (X::OrnsteinUhlenbeck)(dW, W, dt, u, p, t, rng) #dist
else
rand_val = wiener_randn(rng, typeof(dW))
end
drift = X.μ .+ (W[end] .- X.μ) .* exp.(-X.Θ * dt)
drift = X.μ .+ (W.curW .- X.μ) .* exp.(-X.Θ * dt)
diffusion = X.σ .* sqrt.((1 .- exp.(-2X.Θ * dt)) ./ (2X.Θ))
drift .+ rand_val .* diffusion .- W[end]
drift .+ rand_val .* diffusion .- W.curW
end

#=
Expand Down Expand Up @@ -59,8 +59,8 @@ end

function (X::OrnsteinUhlenbeck!)(rand_vec, W, dt, u, p, t, rng) #dist!
wiener_randn!(rng, rand_vec)
@.. rand_vec = X.μ + (W[end] - X.μ) * exp(-X.Θ * dt) +
rand_vec * X.σ * sqrt((1 - exp.(-2 * X.Θ .* dt)) / (2 * X.Θ)) - W[end]
@.. rand_vec = X.μ + (W.curW - X.μ) * exp(-X.Θ * dt) +
rand_vec * X.σ * sqrt((1 - exp.(-2 * X.Θ .* dt)) / (2 * X.Θ)) - W.curW
end

@doc doc"""
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Expand Up @@ -28,4 +28,5 @@ using Test
include("reinit_test.jl")
include("BWT_test.jl")
include("pcn_test.jl")
include("savestep_test.jl")
end
27 changes: 27 additions & 0 deletions test/savestep_test.jl
@@ -0,0 +1,27 @@
@testset "save_everystep Keyword" begin
#Test whether the result of the process is dependent on 'save_everystep'.
using DiffEqNoiseProcess, DiffEqBase, Test, Statistics, Random
processes = [ OrnsteinUhlenbeckProcess(1., 1., 0.3, 0., 0., nothing ),
WienerProcess(0.,0.,nothing ),
CorrelatedWienerProcess([1. 0.; 0. 1.],0.,[0.; 0.],nothing ),
GeometricBrownianMotionProcess(1., 1., 0., 0., nothing )
]



@testset "Noise_process = $(proc.dist)" for proc in processes

cproc = deepcopy(proc)
cproc.save_everystep = true
prob = NoiseProblem(cproc, (0.0, 1.0); seed=1234)
sol_save = solve(prob; dt = 0.1)


cproc = deepcopy(proc)
cproc.save_everystep = false
prob = NoiseProblem(cproc, (0.0, 1.0); seed=1234)
sol_nosave = solve(prob; dt = 0.1)

@test sol_save.curW == sol_nosave.curW
end
end

0 comments on commit 74da127

Please sign in to comment.