In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import time
import sys
import os
import torch.sparse as sparse
from scipy.sparse import csr_matrix
import scipy.sparse
import scipy.sparse.linalg
import math
import numpy as np
import scipy 
import scipy.io
import importlib
import Burger2D_SSPRK2 as Burger2D_SSP

In [2]:
# Reload the solver module so edits to Burger2D_SSPRK2.py are picked up without
# restarting the kernel; the module is echoed as the cell output.
Burger2D_SSP = importlib.reload(Burger2D_SSP)
Burger2D_SSP

<module 'Burger2D_SSPRK2' from 'D:\\Anaconda\\Code\\EnSF\\2DBurger\\Burgers2D_ETD\\BurgerCode\\Burger2D_SSPRK2.py'>

# Run this cell to compute the reference solution

In [4]:
# Domain [xa, xb] x [ya, yb] and solver options.
xa = -1
xb = 1
ya = -1
yb = 1
LF = 2  # 1--LF, 2--EO

## Final time T
T = 0.2
# T = 0.45
limiter = 1
Nx = 80
Ny = Nx
Nt = 800

hx = (xb - xa) / Nx
hy = (yb - ya) / Ny
dt = T / Nt

# Build the grid in float64. The reference state below is stored in float64;
# a default (float32) grid would limit the initial condition to single precision.
x = torch.linspace(xa, xb, Nx + 1, dtype=torch.float64)
xmid = (x[1:] + x[:-1]) / 2
y = torch.linspace(ya, yb, Ny + 1, dtype=torch.float64)
ymid = (y[1:] + y[:-1]) / 2

# Parameters of the initial condition
alpha = 1 / 4
beta = 1 / 2

def u0(x, y, alpha_input, beta_input):
    """Initial condition u0(x, y) = alpha + beta * sin(pi * (x + y))."""
    return alpha_input + beta_input * torch.sin(torch.pi * (x + y))

# Cell centres, flattened to column vectors (x index varies slowest: indexing='ij').
Xmid, Ymid = torch.meshgrid(xmid, ymid, indexing='ij')
X = Xmid.reshape(-1, 1)
Y = Ymid.reshape(-1, 1)

# Points offset by +-h/(2*sqrt(3)) from each cell centre (2-point Gauss-Legendre
# nodes in each direction). Use a Python float so everything stays float64.
sqrt3 = math.sqrt(3.0)
X1 = X - hx / (2 * sqrt3)
X2 = X + hx / (2 * sqrt3)
Y1 = Y - hy / (2 * sqrt3)
Y2 = Y + hy / (2 * sqrt3)

# Evaluate the initial condition at the four node combinations.
u11 = u0(X1, Y1, alpha, beta)
u12 = u0(X1, Y2, alpha, beta)
u21 = u0(X2, Y1, alpha, beta)
u22 = u0(X2, Y2, alpha, beta)

# Discrete initial condition: three coefficients per cell, interleaved with stride 3.
U0 = torch.zeros(3 * Nx * Ny, 1, dtype=torch.float64)
U0[0::3, :] = (u11 + u12 + u21 + u22) / 4  # lk = 00
U0[1::3, :] = 3 * (-u11 - u12 + u21 + u22) / (4 * sqrt3)  # lk = 10
U0[2::3, :] = 3 * (-u11 + u12 - u21 + u22) / (4 * sqrt3)  # lk = 01

device = 'cpu'
U0 = Burger2D_SSP.slope_limiter(U0, Nx, Ny, hx, hy, limiter)
ndim = 3*Nx*Ny
state_target = U0.clone()

## Simulation for the reference solution: row n of state_ref is the state at time n*dt.
state_ref = torch.zeros(Nt+1, ndim, device=device, dtype=torch.float64)
state_ref[[0], :] = torch.transpose(state_target, 0, 1)
for i in range(Nt):
    if i % 100 == 0:  # occasional progress instead of one line per step
        print(i)
    stateT = torch.transpose(state_ref[[i], :], 0, 1)
    sln_batch = Burger2D_SSP.GTS_RK2_onestep(stateT, dt, Nx, Ny, Nt, hx, hy, LF, limiter)
    state_ref[[i+1], :] = torch.transpose(sln_batch, 0, 1)

KeyboardInterrupt: 

# Save the reference solution for future use

In [None]:
# Save the reference trajectory. state_ref starts as zeros and is filled row by
# row by the simulation cell. If that cell was interrupted (e.g. KeyboardInterrupt),
# the trailing rows are still all-zero. Refuse to save such a partial result.
unfilled_rows = int((~torch.any(state_ref != 0, dim=1)).sum())
if unfilled_rows > 0:
    raise RuntimeError(
        f'state_ref has {unfilled_rows} all-zero time levels; '
        'rerun the reference simulation to completion before saving.'
    )

# Independent CPU copy so savemat works regardless of the tensor's device.
state_ref_save = state_ref.detach().cpu().clone().numpy()

# scipy.io.savemat('RefSol_SSPRK2_2DBurger_T045_v2.mat', {'ExactState':state_ref_save, 'hx':hx, 'hy':hy})
scipy.io.savemat('RefSol_SSPRK2_2DBurger_T02_v2.mat', {'ExactState':state_ref_save, 'hx':hx, 'hy':hy})

# EnSF code

In [None]:
def cond_alpha(t):
    """Conditional mean scale alpha_t of the forward process.

    Linear schedule from alpha_0 = 1 down to alpha_1 = eps_alpha. eps_alpha is
    a notebook global read at call time, so it must be defined before the
    first call.
    """
    slope = 1 - eps_alpha
    return 1 - slope * t


def cond_sigma_sq(t):
    """Conditional variance sigma_t^2 of the forward process.

    Linear schedule sigma_t^2 = t, so sigma2_t(0) = 0 and sigma2_t(1) = 1
    (equivalently sigma_t = sqrt(t)).
    """
    return t

def f(t):
    """Drift coefficient f(t) = d(log alpha_t)/dt of the forward process.

    With alpha_t = 1 - (1 - eps_alpha) * t this is -(1 - eps_alpha) / alpha_t.
    """
    numerator = -(1 - eps_alpha)
    return numerator / cond_alpha(t)

def g_sq(t):
    """Squared diffusion coefficient g(t)^2 = d(sigma_t^2)/dt - 2 f(t) sigma_t^2.

    The time derivative of sigma_t^2 is taken as the constant 1, consistent
    with the linear schedule sigma_t^2 = t.
    """
    return 1 - 2 * f(t) * cond_sigma_sq(t)

def g(t):
    """Diffusion coefficient g(t), the square root of g_sq(t) (via numpy)."""
    g_squared = g_sq(t)
    return np.sqrt(g_squared)

def reverse_SDE(idA_sub, size, x0, time_steps, C, score_likelihood=None, drift_fun=f,\
                diffuse_fun=g, alpha_fun=cond_alpha, sigma2_fun=cond_sigma_sq, save_path=False):
    """Sample by Euler-Maruyama integration of the reverse-time SDE from t = 1 to t = dt.

    Starting from standard Gaussian noise, each step moves the samples with drift
        drift_fun(t) * xt + diffuse^2 * (xt - alpha_t * x0) / sigma2_t
        [- diffuse^2 * score_likelihood(idA_sub, xt, t, C)]
    and adds sqrt(dt) * diffuse * N(0, I) noise. The term (xt - alpha_t*x0)/sigma2_t is
    minus the score of N(alpha_t * x0, sigma2_t I), taken row-wise against x0.

    Parameters
    ----------
    idA_sub, C : forwarded unchanged to ``score_likelihood``.
    size : unused; kept for call compatibility.
    x0 : 2-D tensor (n_samples, d) conditioning the prior score; also sets the
        dtype/device of the samples.
    time_steps : number of Euler steps; step size is 1 / time_steps.
    score_likelihood : optional callable ``(idA_sub, xt, t, C) -> tensor like xt``
        added to the prior score when given.
    drift_fun, diffuse_fun, alpha_fun, sigma2_fun : schedule functions of t.
    save_path : if True, also record every intermediate state.

    Returns
    -------
    The final samples as a float64 tensor shaped like x0, or, if ``save_path``,
    ``(path_all, t_vec)`` where ``path_all[k]`` is a snapshot of the state at
    time ``t_vec[k]``.
    """
    dt = 1.0 / time_steps

    # Draw the initial noise in x0's dtype/device so the integration runs at x0's
    # precision. Before, it was float32 on a global device while x0 is float64,
    # and the in-place updates silently stayed float32.
    xt = torch.randn(x0.shape[0], x0.shape[1], dtype=x0.dtype, device=x0.device)
    t = 1.0

    if save_path:
        # Store copies: xt is updated in place, so appending xt itself would make
        # every entry alias the final state.
        path_all = [xt.clone()]
        t_vec = [t]

    for _ in range(time_steps):
        alpha_t = alpha_fun(t)
        sigma2_t = sigma2_fun(t)
        diffuse = diffuse_fun(t)

        # Drift of the reverse SDE: forward drift plus prior-score term ...
        drift = drift_fun(t)*xt + diffuse**2*((xt - alpha_t*x0)/sigma2_t)
        # ... and, when provided, the (weighted) likelihood score.
        if score_likelihood is not None:
            drift = drift - diffuse**2*score_likelihood(idA_sub, xt, t, C)

        xt += -dt*drift + np.sqrt(dt)*diffuse*torch.randn_like(xt)

        # Advance time first, so the recorded time matches the stored state.
        t = t - dt
        if save_path:
            path_all.append(xt.clone())
            t_vec.append(t)

    xt = xt.to(torch.float64)
    if save_path:
        return path_all, t_vec
    return xt

# Settings for the EnSF algorithm

In [None]:
# Load the reference trajectory saved earlier (one row per reference time level).
data = scipy.io.loadmat('RefSol_SSPRK2_2DBurger_T02_v2.mat')
state_ref = torch.from_numpy(data['ExactState'])
# data = scipy.io.loadmat('RefSol_SSPRK2_2DBurger_T045_v2.mat')
# state_ref = torch.from_numpy(data['ExactState'])

t0 = 0
nttrue = Nt
ntEnSF = 80
filtering_steps = ntEnSF
timeTrue = torch.linspace(0, 1, nttrue+1)
tEnSF = torch.linspace(0, 1, filtering_steps+1)

# Reference row for each filtering time k/filtering_steps (normalized time):
# the first level j with j/nttrue >= k/filtering_steps, i.e.
# ceil(k*nttrue/filtering_steps). This is what searchsorted(timeTrue, tEnSF) was
# meant to return. Computing it in exact integer arithmetic avoids comparing two
# independently built float32 grids, where a one-ulp excess selects the next row.
# Numerator is non-negative, so floor division gives the intended ceiling.
filter_levels = torch.arange(filtering_steps + 1)
indices = torch.div(filter_levels * nttrue + filtering_steps - 1, filtering_steps,
                    rounding_mode='floor')
dtEnSF = (T - t0) / ntEnSF

# Reference states at the filtering times.
state_clone = state_ref[indices, :].clone()
# state_EnSF = state_ref.clone()

length = state_clone.shape[1]
num_indices = int(1*length)  # 100% of the DOFs are observed
# num_indices = int(0.6*length) #60%

# Random subset of observed DOFs, kept in ascending order.
spa_indices = torch.randperm(length)[:num_indices]
spa_indices, _ = spa_indices.sort()

num_arctan = int(spa_indices.shape[0])

if num_arctan == 0:
    idA_sub = torch.tensor([], dtype=torch.long, device=spa_indices.device)
else:
    idA_sub = torch.randperm(spa_indices.shape[0])[:num_arctan]
    idA_sub, _ = idA_sub.sort()

# Reference values of the observed DOFs at each filtering time.
state_EnSF = state_clone[:, spa_indices].clone()
state_EnSF = state_EnSF.to(torch.float64)

n_dim = 3*Nx*Ny

In [None]:
# Observation noise std (added to the arctan observations) and the alpha_t(1)
# value used by cond_alpha.
obs_sigma = 0.1
eps_alpha = 0.05
# Model-noise std for the 00 / 10 / 01 coefficient blocks (v1).
SDE_Sigma_00, SDE_Sigma_10, SDE_Sigma_01 = 0.01, 0.001, 0.001
# Ensemble sizes
ensemble_size = 80
ensemble_true = 1
# Number of forward Euler steps in reverse_SDE
# euler_steps = 300
euler_steps = 400

def g_tau(t):
    """Weight on the likelihood score in score_likelihood: 1 at t=0, 0 at t=1."""
    return 1-t

# saving file name
exp_name = 'EnSF_2DBurger_T02_arctan'
# exp_name = 'EnSF_2DBurger_T045_arctan'

# Initial ensemble for EnSF: wide Gaussian, one row per member.
x_state = 2 * torch.randn(ensemble_size, n_dim, device=device, dtype=torch.float64)

bytes_per_state = state_target.element_size() * state_target.nelement()
mem_state = bytes_per_state / 1e+6
mem_ensemble = ensemble_size * mem_state
print(f'single state memory: {mem_state:.2f} MB')
print(f'state ensemble memory: {mem_ensemble:.2f} MB')

rmse_all, obs_save = [], []
# Row k holds the ensemble-mean estimate at filtering step k.
est_save = torch.zeros(filtering_steps+1, n_dim, device=device, dtype=torch.float64)
est_save[[0], :] += x_state.mean(dim=0)

# Filtering process with EnSF

In [None]:
def _magnitude_scale_factors(values):
    """Per-entry powers of ten that lift small magnitudes toward the [0.1, 1) range.

    For k = 1..10, entries with -1e-k <= v < -1e-(k+1) or 1e-(k+1) <= v < 1e-k
    get factor 1e{k}. Entries with -1e-11 <= v < 1e-11 (including 0) get 1e11.
    All other entries (larger magnitudes, NaN) get 1. Boundaries are parsed from
    decimal strings so they equal the literals 1e-1 ... 1e-11 exactly.
    """
    factors = torch.ones_like(values)
    for k in range(1, 11):
        upper = float(f"1e-{k}")
        lower = float(f"1e-{k + 1}")
        in_bucket = ((-upper <= values) & (values < -lower)) | \
                    ((lower <= values) & (values < upper))
        factors[in_bucket] = float(f"1e{k}")
    tiny = float("1e-11")
    factors[(-tiny <= values) & (values < tiny)] = float("1e11")
    return factors


# Re-run the hyper-parameter cell before this one: it creates the initial ensemble
# x_state, which this cell replaces with the final analysis ensemble.
x0filter = x_state
limiterfilter = 1
limiterSDE = 1
rmse_all = []  # fresh list: the end of this cell turns rmse_all into a numpy array
for i in range(filtering_steps):
    print(f'step={i}:')
    t1 = time.time()

    # --- Observation at the next filtering time ---------------------------------
    # Lift small reference values by powers of ten, then observe through atan with
    # additive Gaussian noise.
    state_scale = state_EnSF[[i+1], :].clone()
    obs_factors = _magnitude_scale_factors(state_scale)  # same shape as state_scale (1, d)
    state_scale = state_scale * obs_factors

    obs = torch.atan(state_scale)
    obs += torch.randn_like(state_EnSF[[i+1], :])*obs_sigma

    def score_likelihood(idA_sub, xt, t, C):
        """Weighted score of the arctan observation model for the current `obs`."""
        # obs: (1, d); xt: (ensemble, d)
        score_x = -(torch.atan(xt) - obs)/obs_sigma**2 * (1./(1+xt**2))
        tau = g_tau(t)
        return tau*score_x/C

    # --- Forecast: one coarse model step plus additive model noise ---------------
    x0filterT = torch.transpose(x0filter, 0, 1)
    sln_state = Burger2D_SSP.GTS_RK2_onestep(x0filterT, dtEnSF, Nx, Ny, ntEnSF, hx, hy, LF, limiterfilter)
    x_state = torch.transpose(sln_state, 0, 1)
    x_state = Burger2D_SSP.ReOrderSol(x_state, Nx, Ny)

    # After ReOrderSol the columns are handled as three consecutive blocks of Nx*Ny
    # entries (00, 10, 01), each with its own noise level; drawn in that order.
    n_cells = Nx * Ny
    sqrt_dt = math.sqrt(dtEnSF)  # float64; torch.tensor(dtEnSF) would be float32
    noise = torch.cat(
        [sqrt_dt * sigma * torch.randn_like(x_state[:, :n_cells])
         for sigma in (SDE_Sigma_00, SDE_Sigma_10, SDE_Sigma_01)],
        dim=1,
    )
    x_state += noise
    x_state = Burger2D_SSP.ReverseOrderSol(x_state, Nx, Ny)

    # --- Analysis: 8 passes of reverse-SDE sampling with the same observation ------
    x0_EnSF = x_state[:, spa_indices].clone()
    for _ in range(8):
        # NOTE(review): the prior is lifted with factors from its own magnitudes,
        # but the posterior is mapped back with the observation-based factors
        # (obs_factors) -- confirm this asymmetry is intended.
        x0_EnSF = x0_EnSF * _magnitude_scale_factors(x0_EnSF)

        sln_bar = reverse_SDE(idA_sub, num_indices, x0=x0_EnSF.clone(), time_steps=euler_steps,
                              C=1, score_likelihood=score_likelihood)
        ## v1a
        sln_bar = sln_bar / obs_factors

        x_state[:, spa_indices] = sln_bar
        # x_state[:, 0::3] = torch.clamp(x_state[:, 0::3], min = -0.3, max = 0.82)
        x_state = torch.transpose(Burger2D_SSP.slope_limiter(torch.transpose(x_state, 0, 1), Nx, Ny, hx, hy,
                                                             limiterSDE), 0, 1)
        # Clamp the stride-3 "00" coefficients to a fixed range.
        x_state[:, 0::3] = torch.clamp(x_state[:, 0::3], min=-0.35, max=0.78)

        x0_EnSF = x_state[:, spa_indices].clone()

    # Analysis ensemble becomes next step's forecast input; store its mean.
    x0filter = x_state.clone()
    est_save[[i+1], :] = torch.mean(x_state, dim=0)  # assignment keeps re-runs idempotent
    # get rmse
    rmse_temp = torch.sqrt(torch.mean((est_save[[i+1], :] - state_clone[[i+1], :])**2)).item()

    # get time
    if x_state.device.type == 'cuda':
        torch.cuda.current_stream().synchronize()
    t2 = time.time()
    print(f'\t RMSE = {rmse_temp:.4f}')
    print(f'\t time = {t2-t1:.4f} ')

    # save information
    rmse_all.append(rmse_temp)
    # NaN/inf must also stop the run: `nan > 1000` is False.
    if not math.isfinite(rmse_temp) or rmse_temp > 1000:
        print('diverge!')
        break

state_savepy = state_ref.clone().cpu().numpy()
est_savepy = est_save.clone().cpu().numpy()
rmse_all = np.array(rmse_all)

# Save the results to a .mat file

In [None]:
# Estimated trajectory, per-step RMSE and the discretization sizes.
ensf_results = {
    'Est_State': est_savepy,
    'rmse': rmse_all,
    'Nx': Nx,
    'nttrue': nttrue,
    'ntEnSF': ntEnSF,
}
scipy.io.savemat('EnSF_2DBurger_MixedObs_SSPRK_T02_100Obs_allarctan_v1.mat', ensf_results)

# scipy.io.savemat('EnSF_2DBurger_MixedObs_SSPRK_T02_60Obs_allarctan_v1.mat',\
#                      {'Est_State':est_savepy, 'rmse':rmse_all, 'Nx':Nx, 'nttrue':nttrue, 'ntEnSF': ntEnSF})

# scipy.io.savemat('EnSF_2DBurger_MixedObs_SSPRK_T045_100Obs_allarctan_v1.mat',\
#                      {'Est_State':est_savepy, 'rmse':rmse_all, 'Nx':Nx, 'nttrue':nttrue, 'ntEnSF': ntEnSF})

# scipy.io.savemat('EnSF_2DBurger_MixedObs_SSPRK_T045_60Obs_allarctan_v1.mat',\
#                      {'Est_State':est_savepy, 'rmse':rmse_all, 'Nx':Nx, 'nttrue':nttrue, 'ntEnSF': ntEnSF})