exact solution을 steady state가 아닌 일반해(시간에 따라 변하는 해)로 잡고 error를 보려고 했으나, 코드 확인 차 간단하게 돌려봤을 때조차 계산 시간이 너무 오래 걸려서 포기했다.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# L2-norm error of the numerical solution as a function of the time step dt.
# The dt sweep further down appends one entry per tested dt to each list.
dt_list = []
error_list = []

def S(x, y):
    """Return the time-independent source term 2 * (2 - x**2 - y**2).

    This equals minus the Laplacian of (1 - x**2) * (1 - y**2). Works
    element-wise on scalars and NumPy arrays.
    """
    bracket = 2 - x**2 - y**2
    return 2 * bracket

def exact_pi(x, y, t, M=20):
    """Evaluate the truncated series solution of the forced heat equation.

    The solution is the steady part ``(1 - x**2) * (1 - y**2)`` plus a
    transient expanded in the Dirichlet eigenfunctions
    ``sin(m*pi*(x+1)/2) * sin(n*pi*(y+1)/2)`` of the square [-1, 1]^2. Each
    mode is damped by ``exp(-pi**2 * (m**2 + n**2) / 4 * t)``, i.e. unit
    diffusivity. The mode coefficients come from ``fourier_coeff_mn``.

    Parameters
    ----------
    x, y : float or ndarray
        Evaluation points; must broadcast against each other.
    t : float
        Time at which the series is evaluated.
    M : int, optional
        Number of modes kept in each direction (M*M terms in total).

    Returns
    -------
    float or ndarray
        Truncated series value, broadcast over ``x`` and ``y``.
    """
    phi_ho = (1 - x**2) * (1 - y**2)
    phi_non = 0.0

    # Each 1-D eigenfunction depends on a single mode index, so evaluate it
    # once per index (O(M) grid-wide sin calls) instead of once per (m, n)
    # pair (O(M**2)).
    sin_x = [np.sin((m*np.pi/2)*(x + 1)) for m in range(1, M+1)]
    sin_y = [np.sin((n*np.pi/2)*(y + 1)) for n in range(1, M+1)]

    for m, sx in enumerate(sin_x, start=1):
        for n, sy in enumerate(sin_y, start=1):
            Cmn = fourier_coeff_mn(m, n)
            decay = np.exp(-(np.pi**2*(n**2+m**2)/4)*t)
            phi_non += Cmn * sx * sy * decay

    return phi_ho + phi_non

def fourier_coeff_mn(m, n):
    """Return the (m, n) sine-series coefficient of the initial transient.

    The transient must cancel the steady part at t = 0, so its expansion is
    that of ``-(1 - x**2) * (1 - y**2)`` in the eigenfunctions
    ``sin(k*pi*(s+1)/2)`` on [-1, 1]. Those eigenfunctions satisfy
    ``integral_{-1}^{1} sin(k*pi*(s+1)/2)**2 ds == 1``. The coefficient is
    therefore the plain projection integral, with no extra normalisation
    factor. (The previous factor of 4 made the series start at
    -3 * (1 - x**2) * (1 - y**2) instead of 0.)

    Parameters
    ----------
    m, n : int
        Mode indices (>= 1) in the x and y directions.

    Returns
    -------
    float
        The coefficient; it is (numerically) zero when m or n is even.
    """
    from scipy.integrate import quad

    def _profile(s, k):
        # One 1-D factor of the separable integrand.
        return (1 - s**2) * np.sin(k * np.pi / 2 * (s + 1))

    # The double integral factorises into two 1-D integrals. Computing each
    # once avoids re-running an inner quad at every outer sample point.
    ix = quad(_profile, -1, 1, args=(m,))[0]
    iy = quad(_profile, -1, 1, args=(n,))[0]
    return -ix * iy

# Spatial grid on [-1, 1] x [-1, 1] with n nodes per direction. Boundary
# values stay zero (homogeneous Dirichlet), so only the (n-2) x (n-2)
# interior is advanced in time.
n = 101
x_hani = np.linspace(-1, 1, n)
x_list = x_hani[1:-1]
y_hani = np.linspace(-1, 1, n)
y_list = y_hani[1:-1]
X, Y = np.meshgrid(x_hani, y_hani)
h = x_hani[1] - x_hani[0]
pi = np.zeros((n, n))  # initial condition: zero on the whole grid
alpha = 1  # diffusion coefficient

# Final time of every run. The original experiment used 10*k steps of
# dt = 1/k, i.e. T_END = 10. exact_pi is time dependent, so a smaller T_END
# compares against the transient instead of (effectively) the steady state.
T_END = 10

# The source term does not depend on time: evaluate it on the interior once.
S_interior = S(X[1:-1, 1:-1], Y[1:-1, 1:-1])

for k in [20, 80, 140, 200]:
    dt = 1/k
    dt_list.append(dt)
    # NOTE: every time level is kept (n*n floats each); this cell only reads
    # pi_list[-1] after the loop.
    pi_list = [pi,]
    beta = alpha*dt/(2*h**2)

    # 1-D interior operators, with L = tridiag(1, -2, 1):
    #   I_bLx  = I + beta*L  (right-hand-side operator)
    #   I__bLx = I - beta*L  (left-hand-side operator)
    # Both depend only on dt, so they are built once per run.
    I_bLx_main = (1-2*beta)*np.eye(n-2)
    I_bLx_upper = (beta)*np.eye(n-2, k=1)
    I_bLx_lower = (beta)*np.eye(n-2, k=-1)
    I_bLx = I_bLx_lower + I_bLx_main + I_bLx_upper

    I__bLx_main = (1+2*beta)*np.eye(n-2)
    I__bLx_upper = (-1*beta)*np.eye(n-2, k=1)
    I__bLx_lower = (-1*beta)*np.eye(n-2, k=-1)
    I__bLx = I__bLx_lower + I__bLx_main + I__bLx_upper

    source = S_interior*dt
    n_steps = int(round(T_END*k))
    for j in range(n_steps):
        # Factored step: A @ u_new @ A.T = B @ u_old @ B.T + dt*S,
        # with A = I__bLx and B = I_bLx applied along both grid axes.
        R = I_bLx @ (I_bLx @ pi_list[j][1:-1, 1:-1].T).T + source
        psi = np.linalg.solve(I__bLx, R)
        # psi @ inv(A.T) == solve(A, psi.T).T -- solve instead of inverting.
        pi_new = np.linalg.solve(I__bLx, psi.T).T
        pi_all = np.zeros((n, n))
        pi_all[1:-1, 1:-1] = pi_new
        pi_list.append(pi_all)
    t = n_steps*dt

    # Discrete L2 norm: sqrt(h**2 * sum(e**2)) == h * Frobenius norm of e.
    # (np.linalg.norm(e, 2) on a 2-D array is the spectral norm instead.)
    # NOTE(review): at t = 10 the slowest transient mode of exact_pi is damped
    # by exp(-pi**2/2 * 10) ~ 4e-22, so this effectively compares against the
    # steady state (1 - x**2)(1 - y**2). That function is quadratic in x and
    # in y, so the 3-point second difference reproduces it exactly. The error
    # left here may therefore be dominated by weakly damped high-frequency
    # modes (large beta) rather than O(dt**2) truncation -- confirm before
    # reading the fitted slope as the order of accuracy.
    error = np.linalg.norm(exact_pi(X, Y, t) - pi_list[-1])*h
    error_list.append(float(error))

print(dt_list)
print(error_list)

plt.plot(np.log10(dt_list), np.log10(error_list), marker='o', color='m', label='CN error')
plt.plot(np.log10(dt_list), 2*(np.log10(dt_list)-np.log10(dt_list[0]))+np.log10(error_list[0]), linestyle='--', color='b', label='기울기 2')
plt.xlabel('log10(dt)')
plt.ylabel('log10(L2norm error)')
plt.title('L2norm error by changing dt')
plt.grid()
plt.legend()
plt.show()

print(f'order of accuracy: {linregress(np.log10(dt_list), np.log10(error_list)).slope}')
