# 四阶龙格-库塔 (Runge–Kutta) 方法

In [1]:
# Import the required libraries
import math
import numpy as np
# Latex/display render the solution tables as LaTeX in the notebook output
from IPython.display import Latex, display
# Type aliases used in the function signatures of this notebook
from typing import Callable
# One-argument function g(x); passed as the exact solution to get_Standard_Deviation / solve_question
ufunc = Callable[[float],float]
# Two-argument function f(x, y); the right-hand side of y' = f(x, y) passed to Runge_Kutta
bfunc = Callable[[float,float],float]
from numpy import ndarray

In [2]:
def Runge_Kutta(
    f: bfunc,
    a: float,
    b: float,
    alpha: float,
    N: int
) -> ndarray:
    """Solve y' = f(x, y), y(a) = alpha on [a, b] with the classical 4th-order Runge-Kutta method.

    The interval is split into N equal steps of size h = (b - a) / N.
    As an explicit method, RK4 can become unstable when h is large
    compared with the stiffness of the problem.

    Args:
        f: right-hand side f(x, y) of the ODE. ``y`` may also be a numpy
            array, in which case a first-order system is integrated.
        a: left end of the interval (initial point).
        b: right end of the interval; must be greater than ``a``.
        alpha: initial value y(a) (scalar or 1-D array for a system).
        N: number of steps; must be at least 1.

    Returns:
        Array of shape (N + 1, 1 + k). Column 0 holds the grid points
        x_0 = a, ..., x_N = b, and the remaining k columns hold the
        approximations y_i (k = 1 for a scalar ``alpha``).

    Raises:
        ValueError: if ``a >= b`` or ``N < 1``.
    """
    if a >= b:
        raise ValueError('a must be less than b')
    if N < 1:
        raise ValueError('N must be a positive integer')

    x_vector = np.linspace(a, b, N + 1)
    h = (b - a) / N
    y = alpha
    ylist = [alpha]
    # Take one step from every grid node except the last one. Slicing avoids
    # the fragile float comparison `x != b` previously used to drop b.
    for x in x_vector[:-1]:
        K1 = h * f(x, y)
        K2 = h * f(x + h / 2, y + K1 / 2)
        K3 = h * f(x + h / 2, y + K2 / 2)
        K4 = h * f(x + h, y + K3)
        y = y + 1.0 / 6 * (K1 + 2 * K2 + 2 * K3 + K4)
        ylist.append(y)

    # reshape(N + 1, -1) turns scalar solutions into a column and leaves
    # vector (system) solutions as one row per grid point.
    y_vector = np.array(ylist).reshape(N + 1, -1)
    return np.hstack([x_vector.reshape(N + 1, -1), y_vector])

In [3]:
def get_Standard_Deviation(
    g: ufunc,
    res: ndarray
) -> float:
    """Return the root-mean-square error of a numerical solution table.

    Despite its name, the function does not subtract a mean. It computes
    sqrt( sum_i (y_i - g(x_i))**2 / n ) over the n rows of ``res``.

    Args:
        g: exact solution, evaluated once per grid point.
        res: table whose first column holds the x values and whose last
            column holds the approximate y values (e.g. the output of
            Runge_Kutta for a scalar problem).

    Returns:
        The RMS of the pointwise errors.
    """
    nodes = res[:, 0]
    approx = res[:, -1]
    residual = approx - np.array(list(map(g, nodes)))
    mean_square = np.sum(np.square(residual)) / len(approx)
    return np.sqrt(mean_square)


In [4]:
def solve_question(
    f: bfunc,
    a: float,
    b: float,
    alpha: float,
    Nlist: list[int],
    g: ufunc
) -> list[tuple[int, ndarray, float]]:
    """Solve one initial-value problem with RK4 for several step counts.

    For each N in ``Nlist`` (in order), this runs ``Runge_Kutta(f, a, b,
    alpha, N)`` and scores the resulting table against the exact solution
    ``g`` with ``get_Standard_Deviation``.

    Returns:
        A list of ``(N, table, S)`` tuples, one per entry of ``Nlist``.

    Raises:
        ValueError: if ``a >= b`` or ``Nlist`` is empty.
    """
    if a >= b or len(Nlist) == 0:
        raise ValueError

    def _solve_one(steps):
        # Integrate with `steps` steps and attach the error of that run.
        table = Runge_Kutta(f, a, b, alpha, steps)
        return steps, table, get_Standard_Deviation(g, table)

    return [_solve_one(steps) for steps in Nlist]

In [5]:
def trans_to_latex(a: ndarray) -> str:
    """Render a 2-D array as a LaTeX bmatrix string.

    Each entry is printed with 8 decimal places. Entries in a row are
    separated by " & ", and every non-empty row is terminated by a LaTeX
    row break followed by a newline.

    Raises:
        ValueError: if ``a`` is not two-dimensional.
    """
    if len(a.shape) != 2:
        raise ValueError
    rows = (
        ' & '.join(f'{value:.8f}' for value in row) + ' \\\\\n'
        for row in a
        if len(row) > 0
    )
    return '\\begin{bmatrix}\n' + ''.join(rows) + '\\end{bmatrix}\n'

## 问题求解
### 问题 1
#### (1)
$$
\frac{\mathrm{d}y}{\mathrm{d}x}=x+y,~~x\in[0,1],~~N=5,10,20,~~y(0)=-1
$$
精确解: $y=-x-1$

In [6]:
# Problem 1 (1): y' = x + y on [0, 1], y(0) = -1; exact solution y = -x - 1.
res = solve_question(
    f=lambda x, y: x + y,
    a=0,
    b=1,
    alpha=-1,
    Nlist=[5, 10, 20],
    g=lambda x: -x - 1,
)
# Report each run: step count N, error S, then the (x, y) table as LaTeX.
for N, R, S in res:
    print(f'N = {N:2d}, S = {S:.6e}')
    print('xylist:')
    s = f'$$\n[x,y]={trans_to_latex(R)}$$'
    display(Latex(s))

N =  5, S = 1.570092e-16
xylist:


<IPython.core.display.Latex object>

N = 10, S = 5.142448e-16
xylist:


<IPython.core.display.Latex object>

N = 20, S = 4.988655e-16
xylist:


<IPython.core.display.Latex object>

#### (2)
$$
\frac{\mathrm{d}y}{\mathrm{d}x}=-y^2,~~x\in[0,1],~~N=5,10,20,~~y(0)=1
$$
精确解: $y=\frac{1}{x+1}$

In [7]:
# Problem 1 (2): y' = -y^2 on [0, 1], y(0) = 1; exact solution y = 1 / (x + 1).
res = solve_question(
    f=lambda x, y: -y ** 2,
    a=0,
    b=1,
    alpha=1,
    Nlist=[5, 10, 20],
    g=lambda x: 1 / (x + 1),
)
# Report each run: step count N, error S, then the (x, y) table as LaTeX.
for N, R, S in res:
    print(f'N = {N:2d}, S = {S:.6e}')
    print('xylist:')
    s = f'$$\n[x,y]={trans_to_latex(R)}$$'
    display(Latex(s))

N =  5, S = 5.069083e-06
xylist:


<IPython.core.display.Latex object>

N = 10, S = 3.581699e-07
xylist:


<IPython.core.display.Latex object>

N = 20, S = 2.316655e-08
xylist:


<IPython.core.display.Latex object>

### 问题 2

#### (1)

$$
\frac{\mathrm{d}y}{\mathrm{d}x}=\frac{2y}{x}+x^2e^x,~~x\in[1,3],~~N=5,10,20,~~y(1)=0
$$
精确解: $y=x^2(e^x-e)$

In [8]:
# Problem 2 (1): y' = 2y/x + x^2 e^x on [1, 3], y(1) = 0;
# exact solution y = x^2 (e^x - e).
res = solve_question(
    f=lambda x, y: 2 * y / x + x ** 2 * np.exp(x),
    a=1,
    b=3,
    alpha=0,
    Nlist=[5, 10, 20],
    g=lambda x: x ** 2 * (np.exp(x) - math.e),
)
# Report each run: step count N, error S, then the (x, y) table as LaTeX.
for N, R, S in res:
    print(f'N = {N:2d}, S = {S:.6e}')
    print('xylist:')
    s = f'$$\n[x,y]={trans_to_latex(R)}$$'
    display(Latex(s))

N =  5, S = 4.227449e-02
xylist:


<IPython.core.display.Latex object>

N = 10, S = 3.559349e-03
xylist:


<IPython.core.display.Latex object>

N = 20, S = 2.593603e-04
xylist:


<IPython.core.display.Latex object>

#### (2)

$$
\frac{\mathrm{d}y}{\mathrm{d}x}=\frac{1}{x}(y^2+y),~~x\in[1,3],~~N=5,10,20,~~y(1)=-2
$$
精确解: $y=\frac{2x}{1-2x}$

In [9]:
# Problem 2 (2): y' = (y^2 + y) / x on [1, 3], y(1) = -2;
# exact solution y = 2x / (1 - 2x).
res = solve_question(
    f=lambda x, y: (y + y ** 2) / x,
    a=1,
    b=3,
    alpha=-2,
    Nlist=[5, 10, 20],
    g=lambda x: 2 * x / (1 - 2 * x),
)
# Report each run: step count N, error S, then the (x, y) table as LaTeX.
for N, R, S in res:
    print(f'N = {N:2d}, S = {S:.6e}')
    print('xylist:')
    s = f'$$\n[x,y]={trans_to_latex(R)}$$'
    display(Latex(s))

N =  5, S = 8.636723e-04
xylist:


<IPython.core.display.Latex object>

N = 10, S = 2.098017e-05
xylist:


<IPython.core.display.Latex object>

N = 20, S = 2.991036e-07
xylist:


<IPython.core.display.Latex object>

### 问题 3
#### (1)

$$
\frac{\mathrm{d}y}{\mathrm{d}x}=-20(y-x^2)+2x,~~x\in[0,1],~~N=5,10,20,~~y(0)=\frac{1}{3}
$$
精确解: $y=x^2+\frac{1}{3}e^{-20x}$

In [10]:
# Problem 3 (1): y' = -20(y - x^2) + 2x on [0, 1], y(0) = 1/3;
# exact solution y = x^2 + e^(-20x) / 3.
res = solve_question(
    f=lambda x, y: -20 * (y - x ** 2) + 2 * x,
    a=0,
    b=1,
    alpha=1 / 3,
    Nlist=[5, 10, 20],
    g=lambda x: x ** 2 + 1 / 3 * np.exp(-20 * x),
)
# Report each run: step count N, error S, then the (x, y) table as LaTeX.
for N, R, S in res:
    print(f'N = {N:2d}, S = {S:.6e}')
    print('xylist:')
    s = f'$$\n[x,y]={trans_to_latex(R)}$$'
    display(Latex(s))

N =  5, S = 4.513822e+02
xylist:


<IPython.core.display.Latex object>

N = 10, S = 2.328359e-02
xylist:


<IPython.core.display.Latex object>

N = 20, S = 7.217970e-04
xylist:


<IPython.core.display.Latex object>

#### (2)

$$
\frac{\mathrm{d}y}{\mathrm{d}x}=-20y+20\sin x +\cos x,~~x\in[0,1],~~N=5,10,20,~~y(0)=1
$$
精确解: $y=e^{-20x}+\sin x$

In [11]:
# Problem 3 (2): y' = -20y + 20 sin x + cos x on [0, 1], y(0) = 1;
# exact solution y = e^(-20x) + sin x.
res = solve_question(
    f=lambda x, y: -20 * y + 20 * np.sin(x) + np.cos(x),
    a=0,
    b=1,
    alpha=1,
    Nlist=[5, 10, 20],
    g=lambda x: np.exp(-20 * x) + np.sin(x),
)
# Report each run: step count N, error S, then the (x, y) table as LaTeX.
for N, R, S in res:
    print(f'N = {N:2d}, S = {S:.6e}')
    print('xylist:')
    s = f'$$\n[x,y]={trans_to_latex(R)}$$'
    display(Latex(s))

N =  5, S = 1.301231e+03
xylist:


<IPython.core.display.Latex object>

N = 10, S = 6.681782e-02
xylist:


<IPython.core.display.Latex object>

N = 20, S = 2.070013e-03
xylist:


<IPython.core.display.Latex object>

#### (3)

$$
\frac{\mathrm{d}y}{\mathrm{d}x}=-20(y-e^x\sin x)+e^x(\sin x + \cos x),~~x\in[0,1],~~N=5,10,20,~~y(0)=0
$$
精确解: $y=e^x\sin x$

In [12]:
# Problem 3 (3): y' = -20(y - e^x sin x) + e^x (sin x + cos x) on [0, 1], y(0) = 0;
# exact solution y = e^x sin x.
res = solve_question(
    f=lambda x, y: -20 * (y - np.exp(x) * np.sin(x)) + np.exp(x) * (np.sin(x) + np.cos(x)),
    a=0,
    b=1,
    alpha=0,
    Nlist=[5, 10, 20],
    g=lambda x: np.exp(x) * np.sin(x),
)
# Report each run: step count N, error S, then the (x, y) table as LaTeX.
for N, R, S in res:
    print(f'N = {N:2d}, S = {S:.6e}')
    print('xylist:')
    s = f'$$\n[x,y]={trans_to_latex(R)}$$'
    display(Latex(s))

N =  5, S = 1.902084e+01
xylist:


<IPython.core.display.Latex object>

N = 10, S = 3.148600e-03
xylist:


<IPython.core.display.Latex object>

N = 20, S = 1.101164e-04
xylist:


<IPython.core.display.Latex object>