Burgers' equation in time with joint state–parameter estimation: implementing the EnKF model

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
import prevision as prv

In [None]:
folder = 'Burgers_time_cx'

# Fine reference grid used to compute the "exact" solution.
Nx = 2**6
Nt = 2**18
dx = 1/Nx
dz = 4*dx
dt = 1/Nt
# Build the grids with an explicit point count (linspace) rather than
# arange(0, stop + h, h): with a non-integer step, arange can gain or lose the
# end point to floating-point round-off. For the power-of-two spacings used
# here the resulting grids are identical to the arange ones.
xx = np.linspace(0, 1, Nx + 1)
zz = np.linspace(0, 1, int(round(1 / dz)) + 1)
tt = np.linspace(0, 0.5, int(round(0.5 / dt)) + 1)
nu = 0.025
# NOTE(review): the initial condition is read from '../Data/...' while the FNO
# model is loaded from '../data/...'; on a case-sensitive filesystem only one of
# these spellings can be right — confirm.
u0 = np.load('../Data/Burgers_time/u0.npy')
# Number of fine time steps between stored snapshots: Nt/64 steps of size 1/Nt,
# i.e. one snapshot every 1/64 time units. Kept integral so the modulo test that
# selects snapshots is an exact integer comparison.
resolution = Nt // 64

def compute_ex_sol(xx, tt, u0, nu, resolution, *, dx=None, dt=None, c=1.2):
    """Integrate the viscous Burgers equation and return sampled augmented states.

    Solves ``u_t + c * (u**2 / 2)_x = nu * u_xx`` with forward Euler in time,
    a centred second difference for diffusion and a backward difference of
    ``u**2`` for convection. Only interior nodes are updated; the two boundary
    nodes are zero for every step after the initial one.

    Parameters
    ----------
    xx, tt : 1-D arrays
        Uniform spatial and temporal grids.
    u0 : array of the same length as ``xx``
        Initial condition.
    nu : float
        Viscosity.
    resolution : int or float
        A snapshot is stored at step ``j`` when ``j % resolution == 0``
        (the initial state is always stored first).
    dx, dt : float, optional
        Grid spacings. They default to ``xx[1] - xx[0]`` and ``tt[1] - tt[0]``.
        They are derived from the grids, not read from module globals, so the
        function stays correct after the script rebinds ``dt`` to the coarser
        filter step.
    c : float, optional
        Constant advection coefficient (default 1.2, the reference value).

    Returns
    -------
    ndarray of shape (n_snapshots, 2 * len(xx))
        Each row is ``[u(x, t_j), cx]``: the solution at a stored step followed
        by the constant coefficient field, i.e. the augmented filter state.
    """
    n_x = xx.shape[0]
    n_t = tt.shape[0]
    # Fallback spacings are only used when the grid is too short for the value
    # to influence any stored result.
    if dx is None:
        dx = xx[1] - xx[0] if n_x > 1 else 1.0
    if dt is None:
        dt = tt[1] - tt[0] if n_t > 1 else 0.0

    # The coefficient field does not change in time: build it once.
    cx = c * np.ones_like(u0)

    # Only the current time level is kept (O(n_x) memory) instead of the full
    # (n_x, n_t + 1) history; the float64 copy mirrors the original column cast.
    state = np.zeros(n_x)
    state[:] = u0

    snapshots = []
    for j in range(n_t):
        if j == 0:
            snapshots.append(np.concatenate((u0, cx), axis=None))
        elif j % resolution == 0:
            snapshots.append(np.concatenate((state, cx), axis=None))

        # Vectorised update of interior nodes i = 1..n_x-2. It is equivalent to
        # the per-node loop because every term reads only the previous level;
        # cx[:-2] is cx[i-1] for those i. Boundaries stay zero.
        nxt = np.zeros(n_x)
        nxt[1:-1] = (
            state[1:-1]
            + nu * dt * (state[2:] - 2 * state[1:-1] + state[:-2]) / (dx**2)
            - 0.5 * cx[:-2] * dt * (state[1:-1]**2 - state[:-2]**2) / dx
        )
        state = nxt
    return np.array(snapshots)

# Size of the augmented state: one solution value plus one coefficient value
# per spatial node.
dim = 2 * (Nx + 1)

# Filter starting guess: the true initial condition paired with a coefficient
# field 5% below the reference value 1.2 used in compute_ex_sol.
x0 = np.concatenate((u0, 0.95 * 1.2 * np.ones_like(u0)), axis=None)

# Reference ("exact") trajectory; each row is [u(x, t_k), cx].
# Must run before dt is rebound below, since compute_ex_sol may read the
# module-level dt of the fine grid.
u_ex = compute_ex_sol(xx, tt, u0, nu, resolution)

# Filter time stepping. NOTE: this rebinds the module-level Nt and dt used for
# the fine reference grid; later code that reads them sees the coarse values.
T, Nt = 0.5, 32
dt = T / Nt

# Define the measurement function
def hx(x):
    """Identity observation operator: the whole augmented state is observed."""
    measured = x
    return measured

# Define the transition function, backed by a pretrained FNO network.
# NOTE(review): this path uses '../data/' while u0 is loaded from '../Data/';
# on a case-sensitive filesystem only one spelling can be right — confirm.
FNO = keras.models.load_model(
    '../data/' + folder + '/Burgers_time_cx_FNO.h5', compile=False
)
def fxx(u, dt):
    """Propagate the augmented state ``u = [solution, coefficient]`` one step.

    The solution half and the coefficient half are stacked into a
    ``(1, 2, dim // 2)`` array and passed to the FNO. The network's first
    output is rescaled by ``max |solution|``; presumably the model predicts a
    normalised field — confirm against training. The coefficient half is
    carried over unchanged. ``dt`` is unused here; it is presumably required
    by the filter's ``f(x, dt)`` calling convention.
    """
    half = int(dim / 2)
    solution, coefficient = u[:half], u[half:]
    scale = np.amax(np.abs(solution))
    network_input = np.array([[solution, coefficient]])
    predicted = FNO(network_input)[0]
    return np.concatenate((scale * predicted, coefficient), axis=None)

    # Define the covariance matrix
P = np.cov(u_ex, rowvar=False)
# Define the measurament noise
R = 0.1*np.eye(dim)
# Define the process noise
Q = 0.1*np.eye(dim)

# Define the data acquisition function
def get_sensor_reading(t):
    """Return the reference state snapshot observed at time ``t``.

    Assumes consecutive rows of ``u_ex`` are ``dt`` apart in time, so the row
    index is ``t / dt``. The ratio is rounded to the nearest integer rather than
    truncated: ``t`` is a float supplied by the filter, and it can land a hair
    below an exact multiple of ``dt``. For example, ``(3 * 0.1) / 0.1`` is
    ``2.9999999999999996``, which truncation would map to the previous snapshot.
    """
    i = int(np.rint(t / dt))
    return u_ex[i, :]

# Build the ensemble Kalman filter from the prevision library.
f = prv.EnKF(
    f=fxx,
    h=hx,
    dim_x=dim,
    dim_z=dim,
    get_data=get_sensor_reading,
    dt=dt,
    t0=0,
)
# Initial mean x0 and covariances P (state), R (measurement), Q (process);
# N is presumably the ensemble size — confirm against prevision.
f.create_model(x0=x0, P=P, Q=Q, R=R, N=10000)

# Run the predict/update cycle up to time T; u_hat collects the estimates.
u_hat = f.loop(T, verbose=True)

# For every stored time, plot the estimated vs. exact solution and coefficient
# field, then save that frame as its own PNG.
d = int(dim / 2)  # split point between the solution and coefficient halves
plt.figure()
for t_index in range(u_ex.shape[0]):
    # Assumes u_hat has one row per stored time, aligned with u_ex — TODO confirm.
    plt.title(f'Solution at time: {t_index * dt}')
    plt.grid(True)
    # Solution half of the state (the x-axis is the node index, not xx).
    plt.plot(u_hat[t_index, :d], label='estimated solution')
    plt.plot(u_ex[t_index, :d], label='exact solution', linestyle='--')
    # Coefficient (parameter) half of the state.
    plt.plot(u_hat[t_index, d:], label='estimated parameter')
    plt.plot(u_ex[t_index, d:], label='exact parameter', linestyle='--')
    plt.xlabel('x [-]')
    plt.ylabel('u [-]')
    plt.ylim([-1, 1.5])
    plt.legend(loc='lower left')
    plt.savefig(f'../Burgers_time_cx_EnKF_{t_index}.png', dpi=300)
    plt.clf()  # reuse the same figure for the next frame