In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.special import factorial, hermitenorm
from numba import njit

In [21]:
@njit
def ode(u, t, a=1):
    return -0.5 * u**2 + (a+1) * u

@njit
def ode_book(u,t, a=1):
    return -a*u

@njit
def fact(x):
    '''
    calculates factorial

    Arguments:
        x -- integer

    Returns:
        x!
    '''
    factors = np.arange(1,x+1)
    y = 1
    for i in range(x):
        y = y * factors[i]
    return y

@njit
def e_ijk(i,j,k):
    s = np.ceil((i+j+k)/2)
    return (fact(i)*fact(j)*fact(k)) / (fact(s-i)*fact(s-j)*fact(s-k))

@njit
def get_A(N, mu, sigma):
    A = np.zeros((N,N))
    a = np.array([mu, sigma])

    for j in range(N):
        for k in range(N):
            Ajk = 0
            for i, ai in enumerate(a):
                Ajk += ai * e_ijk(i,j,k)
            A[j,k] = -1/fact(k) * Ajk
    
    return A

@njit
def get_A_book(N, mu, sigma):
    A = np.zeros((N,N))
    a = np.array([mu, sigma])

    for j in range(N):
        for k in range(N):
            Ajk = 0
            for i, ai in enumerate(a):
                Ajk += ai * e_ijk(i,j,k)
            A[j,k] = -1/fact(k) * Ajk
    return A

@njit
def get_b(N, betha):
    b = np.zeros(N)
    b[0] = betha
    return b

def approximation_book(N, mu, sigma, betha):
    A = get_A(N,mu,sigma)
    b = get_b(N,betha)
    vhat = np.linalg.solve(A.T,b)
    
    approx = 0
    for i in range(N):
        approx += vhat[i] * hermitenorm(i)
    return approx


In [22]:
u0 = 0.1 # initial condition

In [23]:
N = 10000
dudt2s = np.zeros(N)
for i in range(N):
    a = np.random.normal(0,1)
    dudt2s[i] = solve_ivp(ode_book, (0,2), [u0], args = (a,)).y[0,-1]

In [24]:
print(np.mean(dudt2s))
print(np.var(dudt2s))

0.06430263188235995
3.9927942656197675


In [25]:
A = get_A_book(3,0,1)

In [26]:
b = get_b(3,0.1)

In [27]:
np.linalg.solve(A.T,b)

array([-0.2, -0. ,  0.1])

In [28]:
approx = approximation_book(10,0,1,0.5)

ValueError: non-broadcastable output operand with shape (1,) doesn't match the broadcast shape (2,)

In [20]:
approx

-1.5