diff --git a/src/integrators/symplectic_integrators.jl b/src/integrators/symplectic_integrators.jl index 9cc9a6a14c..fe6981808a 100644 --- a/src/integrators/symplectic_integrators.jl +++ b/src/integrators/symplectic_integrators.jl @@ -44,3 +44,34 @@ end end f[1](integrator.t,uprev,du,ku) end + +@inline function initialize!(integrator,cache::VelocityVerletCache,f=integrator.f) + integrator.kshortsize = 2 + @unpack k,fsalfirst = cache + integrator.fsalfirst = fsalfirst + integrator.fsallast = k + integrator.k = eltype(integrator.sol.k)(integrator.kshortsize) + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + uprev,duprev = integrator.uprev.x + f[2](integrator.t,uprev,duprev,integrator.k[1].x[2]) +end + +@inline function perform_step!(integrator,cache::VelocityVerletCache,f=integrator.f) + @unpack t,dt = integrator + uprev,duprev = integrator.uprev.x + u,du = integrator.u.x + kduprev = integrator.k[1].x[2] + kdu = integrator.k[2].x[2] + # x(t+Δt) = x(t) + v(t)*Δt + 1/2*a(t)*Δt^2 + f[2](integrator.t,uprev,duprev,kduprev) + @tight_loop_macros for i in eachindex(u) + @inbounds u[i] = @muladd uprev[i]+duprev[i]*dt+(1//2*kduprev[i])*dt^2 + end + f[2](integrator.t,u,duprev,kdu) + # v(t+Δt) = v(t) + 1/2*(a(t)+a(t+Δt))*Δt + @tight_loop_macros for i in eachindex(du) + du[i] = muladd(dt,(1//2*kduprev[i]+kdu[i]),duprev[i]) + end + f[2](integrator.t,uprev,duprev,kduprev) +end