In [None]:
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

# Stochastic Van der Pol oscillator (diagonal noise, one Wiener process per equation):
#   dx₁ = x₂ dt                                            (no noise on x₁)
#   dx₂ = [ϵ(1 - x₁²)x₂ - ω_n² x₁ + A cos(ω_g t)] dt + A dW

@variables x₁(t) x₂(t)
@parameters ϵ ω_n ω_g A

# deterministic (drift) part of the dynamics
drift_eqs = [
    D(x₁) ~ x₂,
    D(x₂) ~ ϵ * (1 - x₁^2) * x₂ - ω_n^2 * x₁ + A * cos(ω_g * t),
]
# noise amplitudes, one entry per state
diff_eqs = [0; A]

vdp_model = SDESystem(drift_eqs, diff_eqs, t, [x₁, x₂], [ϵ, ω_n, ω_g, A]; name = :VdP)

ps = [ϵ => 0.1, ω_n => 120π, ω_g => 120π, A => 2.5]  # parameter values
u0map = [x₁ => 0.1, x₂ => 0.1]                        # initial state
tspan = (0.0, 0.1)                                    # integration window

In [None]:
using MomentClosure, Latexify

# Raw moment equations of the Van der Pol SDE up to second order, rendered as LaTeX.
moment_eqs = generate_raw_moment_eqs(vdp_model, 2)
moment_eqs |> latexify

In [None]:
print(latexify(moment_eqs), "\n")  # dump the raw LaTeX source instead of rendering it

In [None]:
using OrdinaryDiffEqTsit5
using StochasticDiffEq

# Truncate the moment hierarchy via derivative matching, then integrate the closed ODEs.
closed_eqs = moment_closure(moment_eqs, "derivative matching")
oprob = ODEProblem(closed_eqs, u0map, tspan, ps)
sol_MA = solve(oprob, Tsit5(); saveat = 1e-4);

In [None]:
using DifferentialEquations.EnsembleAnalysis, Plots

# Monte Carlo reference: simulate 100 sample paths of the Van der Pol SDE.
prob_SDE = SDEProblem(complete(vdp_model), u0map, tspan, ps)
@time sol_SDE = solve(EnsembleProblem(prob_SDE), SRIW1(), saveat=0.0001, trajectories=100)
# Mean over all trajectories at each saved time step.
means_SDE = timeseries_steps_mean(sol_SDE)

# Compare the moment-approximation (MA) mean with the ensemble mean. Each curve is drawn
# against its own time grid, so the comparison does not silently rely on the ODE and SDE
# solves having saved at exactly the same instants.
# NOTE(review): assumes the first unknown of the closed system is ⟨x₁⟩ — confirm.
plot(sol_MA.t, sol_MA[1, :], lw=2, label="MA", ylabel="⟨x₁⟩", xlabel="time")
plot!(means_SDE.t, means_SDE[1, :], lw=2, label="SDE", linecolor=:red,
    linestyle=:dash, background_color_legend=nothing, legend=:topright, grid=false)

In [None]:
# Stochastic pendulum with a velocity-proportional friction term:
#   dx₁ = x₂ dt
#   dx₂ = [-(k/m) x₂ - (g/l) sin(x₁)] dt + (1/m) dW
@variables x₁(t), x₂(t)
@parameters k, l, m, g

drift_eqs = [D(x₁) ~ x₂;
             D(x₂) ~ -k/m*x₂ - g/l*sin(x₁)]
diff_eqs = [0; 1/m]  # noise acts on x₂ only

pendulum_model = SDESystem(drift_eqs, diff_eqs, t, [x₁, x₂], [k, l, m, g], name = :pendulum)
# Float64 literals (as in the Van der Pol cell) so the problem does not depend on
# ModelingToolkit promoting integer values before the integrators write fractional states.
ps = [k => 10.0, m => 10.0, l => 10.0, g => 10.0]
u0map = [x₁ => 3.0, x₂ => 3.0]
tspan = (0., 15.)

In [None]:
# Show the pendulum SDE system; as the cell's final value it is rendered by the notebook.
pendulum_model

In [None]:
# Central moment equations up to 2nd order. The third argument presumably sets the
# Taylor-expansion order used for the non-polynomial sin(x₁) drift term — confirm
# against the MomentClosure documentation.
moment_eqs = generate_central_moment_eqs(pendulum_model, 2, 3)
moment_eqs |> latexify  # the rendered LaTeX is rather cluttered for this system

In [None]:
# Close the central-moment hierarchy with the gamma closure and integrate it.
closed_eqs = moment_closure(moment_eqs, "gamma")

oprob = ODEProblem(closed_eqs, u0map, tspan, ps)
sol_MA = solve(oprob, Tsit5(), saveat=0.01)

# Monte Carlo reference: 100 SDE sample paths, averaged at each saved time step.
prob_SDE = SDEProblem(complete(pendulum_model), u0map, tspan, ps)
sol_SDE = solve(EnsembleProblem(prob_SDE), SRIW1(), saveat=0.01, trajectories=100)
means_SDE = timeseries_steps_mean(sol_SDE)

# Compare sin of the mean angle from both approaches; each curve uses its own time grid
# so the plot does not depend on the two solves sharing identical save points.
# NOTE(review): assumes the first unknown of the closed system is ⟨x₁⟩ — confirm.
plot(sol_MA.t, sin.(sol_MA[1, :]), lw=2, label="MA", ylabel="sin(⟨x₁⟩)", xlabel="time")
plot!(means_SDE.t, sin.(means_SDE[1, :]), lw=2, label="SDE", linecolor=:red,
    linestyle=:dash, background_color_legend=nothing, legend=:topright, grid=false)