In [1]:
using BenchmarkTools: @btime
import LinearAlgebra as la

In [2]:
struct HO_Funcs
    ω::Float64
    hermites::Vector{Float64}
    hos::Vector{Float64}
    ho_der::Vector{Float64}
    ho_dder::Vector{Float64}
    
    function HO_Funcs(l, ω)
        hermites = zeros(l)
        hermites[1] = 1.0
        
        hos = zeros(l)
        ho_der = zeros(l)
        ho_dder = zeros(l)
        return new(ω, hermites, hos, ho_der, ho_dder)
    end
end

In [67]:
function fast_ho_all(x, ho)
    (; ω, hermites, hos, ho_der, ho_dder) = ho
    
    ξ = √ω * x

    #hermites[1] = 1.0
    ho_fac = (ω / π)^0.25 * exp(-ξ^2 / 2)
    
    hos[1] = ho_fac * hermites[1]
    ho_der[1] = -ω * x * hos[1]
    ho_dder[1] = ω * (ω * x^2 - 1) * hos[1]
    
    hermites[2] = 2ξ
    ho_fac *= 1 / √2

    hos[2] = ho_fac * hermites[2]
    ho_der[2] =  ho_fac * (√ω * 2 * hermites[1] - ω * x * hermites[2])
    ho_dder[2] = ho_fac * ω * ((ω * x^2 - 1) * hermites[2] - √ω * x * 4 * hermites[1])

    @inbounds for n in 2:length(hos)-1
        hermites[n+1] = 2ξ * hermites[n] - 2(n-1) * hermites[n-1]
        ho_fac *= 1 / sqrt( 2n )

        hos[n+1] = ho_fac * hermites[n+1]
        ho_der[n+1] =  ho_fac * (√ω * 2 * n * hermites[n] - ω * x * hermites[n+1])
        ho_dder[n+1] = ho_fac * ω * ((ω * x^2 - 1) * hermites[n+1] - √ω * x * 4n * hermites[n] + 4 * (n-1)*n * hermites[n - 1])
    end
    
    return hos, ho_der, ho_dder
end

fast_ho_all (generic function with 1 method)

In [13]:
l = 50
ω = 0.25
ho = HO_Funcs(l, ω);
x = 8.1

8.1

In [54]:
@time fast_ho_all(x, ho);

-0.0015867778662118476
  0.054710 seconds (131.56 k allocations: 6.595 MiB, 99.45% compilation time)


In [55]:
function ho_test_2(x)
    n = 1
    x = √ω * x
    hermite = 2x
    return (ω/π)^0.25 * 1 / sqrt(2^n * factorial(n)) * hermite * exp(-x^2/2)
end

ho_test_2 (generic function with 1 method)

In [56]:
function ho_test(x)
    n = 5
    x = √ω * x
    hermite = 32x^5 - 160x^3 + 120x
    return (ω/π)^0.25 * 1 / sqrt(2^n * factorial(n)) * hermite * exp(-x^2/2)
end

ho_test (generic function with 1 method)

In [57]:
import ForwardDiff as fd

In [68]:
hos, ho_der, ho_dder = fast_ho_all(x, ho);

In [60]:
println(ho_test_2(x))
println(hos[2])

0.0008344684769560762
0.0008344684769560763


In [69]:
println( fd.derivative(ho_test_2, 8.1))
println(ho_der[2])

-0.0015867778662118472
-0.0015867778662118476


In [19]:
println( fd.derivative(ho_test, 8.1))
println(ho_der[6])

-0.0762265461506806
-0.07622654615068064


In [20]:
println( fd.derivative(x -> fd.derivative(ho_test, x), 8.1) )
println(ho_dder[6])

0.07851376582953867
0.07851376582953876
