In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import time
import sys
import os
import torch.sparse as sparse
from scipy.sparse import csr_matrix
import scipy.sparse
import scipy.sparse.linalg
import math
import numpy as np
import scipy 
import scipy.io
import importlib
import Burger2D_SSPRK2 as Burger2D_SSP

# Localization

In [None]:
def local_indices_pytorch(p, Nx, Ny, radius, device=None):
    """
    Flattened 0-based indices of the square neighbourhood of half-width
    ``radius`` around grid point ``p`` on an Nx-by-Ny mesh.

    The mesh is flattened row-major with ``Ny`` entries per row, i.e.
    ``p = i * Ny + j`` with ``0 <= i < Nx`` and ``0 <= j < Ny``. The
    neighbourhood is clipped at the mesh boundary (no periodic wrap-around).

    Parameters
    ----------
    p : int
        Flattened index of the centre point.
    Nx, Ny : int
        Number of rows / columns of the mesh.
    radius : int
        Half-width of the neighbourhood, in grid points.
    device : torch.device or str, optional
        Device on which the index tensor is created (CPU when omitted).

    Returns
    -------
    torch.LongTensor
        1-D tensor of flattened indices, in row-major order of the patch.
    """
    if device is None:
        device = torch.device('cpu')
    # Decode p -> (i, j). The divisor must be the row length Ny, matching the
    # encoding `Ii * Ny + Jj` used below (dividing by Nx is only correct when
    # Nx == Ny).
    i = p // Ny
    j = p % Ny
    # Row / column ranges, clipped to the mesh.
    row_min = max(0, i - radius)
    row_max = min(Nx - 1, i + radius)
    col_min = max(0, j - radius)
    col_max = min(Ny - 1, j + radius)

    rows = torch.arange(row_min, row_max + 1, device=device)
    cols = torch.arange(col_min, col_max + 1, device=device)
    # Cartesian product of the selected rows and columns.
    Ii, Jj = torch.meshgrid(rows, cols, indexing='ij')
    flat_inds = (Ii * Ny + Jj).reshape(-1)
    return flat_inds.long()

In [None]:
def local_indices_block_pytorch(p, Nx, Ny, radius, device=None):
    """
    Neighbourhood indices for a state made of stacked blocks of Nx*Ny entries.

    The global index ``p`` is split into a block number and an offset inside
    that block; the spatial patch of half-width ``radius`` is computed for the
    offset with ``local_indices_pytorch`` and then shifted back into the same
    block, so the returned indices never leave the block that contains ``p``.
    """
    block_size = Nx * Ny
    block_id = p // block_size                    # which stacked block p lies in
    offset_in_block = p - block_id * block_size   # position of p inside that block
    patch = local_indices_pytorch(offset_in_block, Nx, Ny, radius, device)
    return patch + block_id * block_size

# LETKF

In [None]:
def LETKF_local_ptorch(Xb, y, idx_obs, rho, R, Nx, Ny, gird_xy, obs_xy, radius, n_local_obs=15):
    """
    Local ensemble transform Kalman filter (LETKF) analysis in PyTorch.

    Every state variable is updated separately, using only the
    ``n_local_obs`` observations whose coordinates are closest to the grid
    point of that variable. The observation operator is ``atan`` applied to
    the observed state entries.

    Inputs:
      Xb          : (m_g x k) forecast ensemble tensor (one column per member)
      y           : (l_g,) or (l_g x 1) observation tensor
      idx_obs     : (l_g,) indices of the observed entries of the state
      rho         : multiplicative covariance inflation factor
      R           : (l_g x l_g) observation error covariance tensor
                    (assumed symmetric; `eigh` below relies on it)
      Nx, Ny      : grid dimensions (integers); state index p belongs to grid
                    point ``p % (Nx*Ny)``
      gird_xy     : (Nx*Ny x 2) grid-point coordinates, or that table stacked
                    several times. The misspelled name is kept so existing
                    callers are unaffected.
      obs_xy      : (l_g x 2) coordinates of the observations, same order as y
      radius      : kept for interface compatibility. It has no influence on
                    the result: the analysis of variable p only needs row p of
                    the ensemble, so no model-space patch is built.
      n_local_obs : number of nearest observations used per state variable
                    (capped at the number of available observations)

    Returns:
      Xa          : (m_g x k) analysis ensemble tensor
    """
    device = Xb.device
    m_g, k = Xb.shape
    n_grid = Nx * Ny
    grid_xy = gird_xy

    # Steps 1-2: ensemble mean and perturbations in model space.
    Xb_mean = Xb.mean(dim=1, keepdim=True)         # (m_g x 1)
    Xb_pert = Xb - Xb_mean                         # (m_g x k)

    # Project the ensemble to observation space: h(x) = atan(x[idx_obs]).
    Yb_raw = torch.atan(Xb[idx_obs, :])            # (l_g x k)
    Yb_mean = Yb_raw.mean(dim=1)                   # (l_g,)
    Yb_pert = Yb_raw - Yb_mean.unsqueeze(1)        # (l_g x k)

    y_vec = y.reshape(-1)                          # accept (l_g,) and (l_g x 1)
    n_near = min(n_local_obs, obs_xy.shape[0])     # topk raises if k exceeds the size

    # Loop invariants.
    eye_k = torch.eye(k, device=device, dtype=Xb.dtype)
    sqrt_km1 = math.sqrt(k - 1)

    Xa = torch.zeros_like(Xb)
    for p in range(m_g):
        # Step 3: observations nearest to the grid point of variable p.
        xy = grid_xy[p % n_grid]
        d2 = ((obs_xy - xy) ** 2).sum(dim=1)
        loc_obs = torch.topk(d2, n_near, largest=False).indices

        # Step 4: restrict observation-space quantities to those observations.
        yb_m_loc = Yb_mean[loc_obs]                # (l_loc,)
        Yb_loc = Yb_pert[loc_obs, :]               # (l_loc x k)
        yo_loc = y_vec[loc_obs]                    # (l_loc,)
        R_loc = R[loc_obs][:, loc_obs]             # (l_loc x l_loc)

        # Step 5: C = (R_loc^{-1} Yb_loc)^T, obtained by a solve instead of an
        # explicit inverse (equals Yb_loc^T R_loc^{-1} for symmetric R_loc).
        C = torch.linalg.solve(R_loc, Yb_loc).T    # (k x l_loc)

        # Step 6: Pa = [(k-1)/rho * I + C Yb_loc]^{-1} via an eigendecomposition,
        # which also gives the symmetric square root needed for Wa.
        M = (k - 1) / rho * eye_k + C @ Yb_loc
        w, Q = torch.linalg.eigh(M)
        Pa = (Q * (1.0 / w).unsqueeze(0)) @ Q.T                          # (k x k)

        # Step 7a: perturbation weights Wa = [(k-1) Pa]^{1/2}.
        Wa = sqrt_km1 * (Q * (1.0 / torch.sqrt(w)).unsqueeze(0)) @ Q.T   # (k x k)

        # Step 7b: mean weight increment, added to every column of Wa.
        wabar = Pa @ (C @ (yo_loc - yb_m_loc))     # (k,)
        Wana = Wa + wabar.unsqueeze(1)             # (k x k)

        # Step 8: map back to model space for variable p only.
        Xa[p, :] = Xb_mean[p, 0] + Xb_pert[p, :] @ Wana

    return Xa

In [None]:
# --- Domain [xa, xb] x [ya, yb], scheme switches and resolution --------------
xa, xb = -1, 1
ya, yb = -1, 1
LF = 2  # 1--LF, 2--EO
T = 0.2
# T = 0.45
limiter = 1
Nx = 80
Ny = Nx
Nt = 800

# Cell sizes and time step.
hx = (xb - xa) / Nx
hy = (yb - ya) / Ny
dt = T / Nt

# Cell edges and cell midpoints in each direction.
x = torch.linspace(xa, xb, Nx + 1)
xmid = (x[:-1] + x[1:]) / 2
y = torch.linspace(ya, yb, Ny + 1)
ymid = (y[:-1] + y[1:]) / 2

# Parameters of the initial condition.
alpha = 1 / 4
beta = 1 / 2


def u0(x, y, alpha_input, beta_input):
    """Initial condition: alpha + beta * sin(pi * (x + y))."""
    return alpha_input + beta_input * torch.sin(torch.pi * (x + y))


# Cell midpoints as column vectors (row-major over the 'ij' mesh).
Xmid, Ymid = torch.meshgrid(xmid, ymid, indexing='ij')
X = Xmid.reshape(-1, 1)
Y = Ymid.reshape(-1, 1)

# Sample points offset by +/- h / (2*sqrt(3)) from each midpoint
# (the two-point Gauss-Legendre nodes of the cell in each direction).
gauss_dx = hx / (2 * torch.sqrt(torch.tensor(3.0)))
gauss_dy = hy / (2 * torch.sqrt(torch.tensor(3.0)))
X1, X2 = X - gauss_dx, X + gauss_dx
Y1, Y2 = Y - gauss_dy, Y + gauss_dy

# Initial condition at the four sample points of every cell.
u11 = u0(X1, Y1, alpha, beta)
u12 = u0(X1, Y2, alpha, beta)
u21 = u0(X2, Y1, alpha, beta)
u22 = u0(X2, Y2, alpha, beta)

# Three coefficients per cell, stored interleaved (stride 3).
# NOTE(review): other cells of this notebook treat the state as three stacked
# blocks of Nx*Ny entries (block-wise noise, `grid_xy.repeat(3, 1)`, block
# localisation) -- confirm which layout Burger2D_SSP actually uses.
U0 = torch.zeros(3 * Nx * Ny, 1, dtype=torch.float64)
four_sqrt3 = 4 * torch.sqrt(torch.tensor(3.0))
U0[0::3, :] = (u11 + u12 + u21 + u22) / 4  # lk = 00
U0[1::3, :] = 3 * (-u11 - u12 + u21 + u22) / four_sqrt3  # lk = 10
U0[2::3, :] = 3 * (-u11 + u12 - u21 + u22) / four_sqrt3  # lk = 01

device = 'cpu'
U0 = Burger2D_SSP.slope_limiter(U0, Nx, Ny, hx, hy, limiter)
ndim = 3*Nx*Ny
state_target = U0.clone()

# Reference trajectory: row n holds the state after n time steps.
# NOTE(review): a later cell replaces `state_ref` with data loaded from a .mat
# file, so this Nt-step integration may be redundant.
state_ref = torch.zeros(Nt+1, ndim, device = device, dtype=torch.float64)
state_ref[[0], :] += state_target.transpose(0, 1)
for i in range(Nt):
    print(i)
    stateT = state_ref[[i], :].transpose(0, 1)
    sln_batch = Burger2D_SSP.GTS_RK2_onestep(stateT, dt, Nx, Ny, Nt, hx, hy, LF, limiter)
    state_ref[[i+1], :] += sln_batch.transpose(0, 1)

In [None]:
def nearest_k_obs_torch(i, grid_xy, obs_xy, K):
    """
    Indices of the K observations closest to the grid point of state index i.

    Parameters
    ----------
    i : int
        Global state index. It is wrapped with the number of rows of
        ``grid_xy``, so ``grid_xy`` may hold one row per grid point or that
        table stacked once per state component.
    grid_xy : (n_points x 2) tensor
        Grid-point coordinates.
    obs_xy : (n_obs x 2) tensor
        Observation coordinates.
    K : int
        Number of neighbours requested; capped at ``n_obs`` because
        ``torch.topk`` raises when asked for more elements than exist.

    Returns
    -------
    torch.LongTensor
        Indices into ``obs_xy``, ordered from nearest to farthest.
    """
    # Use the table's own length rather than notebook globals Nx, Ny.
    xy = grid_xy[i % grid_xy.shape[0]]
    d2 = ((obs_xy - xy) ** 2).sum(dim=1)   # squared Euclidean distances
    return torch.topk(d2, min(K, obs_xy.shape[0]), largest=False).indices

In [None]:
# Load the precomputed reference trajectory (this replaces any `state_ref`
# computed earlier in the notebook).
data = scipy.io.loadmat('RefSol_SSPRK2_2DBurger_T02_v2.mat')
state_ref = torch.from_numpy(data['ExactState'])
# data = scipy.io.loadmat('RefSol_SSPRK2_2DBurger_T045_v2.mat')
# state_ref = torch.from_numpy(data['ExactState'])

t0 = 0
nttrue = Nt
ntEnSF = 80
filtering_steps = ntEnSF
# Normalised time axes of the reference run and of the filter.
timeTrue = torch.linspace(0, 1, nttrue+1)
tEnSF = torch.linspace(0, 1, filtering_steps+1)
# Reference time level for every filtering time: the first reference step whose
# time is >= the filtering time, i.e. ceil(j * nttrue / filtering_steps).
# This is computed in exact integer arithmetic; comparing the two float32
# linspace grids with searchsorted can be off by one when the same time is
# rounded differently on the two grids.
indices = torch.div(torch.arange(filtering_steps + 1) * nttrue + (filtering_steps - 1),
                    filtering_steps, rounding_mode='floor')
dtEnSF = (T - t0) / ntEnSF

# Reference states at the filtering times only.
state_clone = state_ref[indices, :].clone()
# state_EnSF = state_ref.clone()

length = state_clone.shape[1]
num_indices = int(0.1*length) #10%
# num_indices = int(1*length) #100%
# num_indices = int(0.6*length) #60%

# Randomly chosen observed state entries, sorted.
# NOTE(review): no random seed is set, so the observed set changes between runs.
spa_indices = torch.randperm(length)[:num_indices]  # Random permutation, then take the first num_indices
spa_indices, _ = spa_indices.sort()

# Subset of the observed entries that is passed through atan (here: all of them).
num_arctan = int(1*spa_indices.shape[0])
idA_sub = torch.randperm(spa_indices.shape[0])[:num_arctan]
idA_sub, _ = idA_sub.sort()

# Reference values at the observed entries.
state_EnSF = state_clone[:, spa_indices].clone()

state_EnSF = state_EnSF.to(torch.float64)

n_dim = 3*Nx*Ny

# Coordinates of the cell midpoints, one row per grid cell.
grid_xy = torch.stack([Xmid.reshape(-1), Ymid.reshape(-1)], dim=1)  # (Nx*Ny)×2

# The same table stacked three times, one copy per state block.
# NOTE(review): this assumes the state is stored as three stacked blocks;
# confirm against the layout used by Burger2D_SSP.
grid_xy_total = grid_xy.repeat(3, 1)  # (3*Nx*Ny)×2

total_pts = grid_xy_total.size(0)             # = 3*Nx*Ny

# Coordinates of the observed entries.
obs_xy    = grid_xy_total[spa_indices]

In [None]:
# --- Observation noise and model-noise amplitudes ----------------------------
obs_sigma = 0.1
eps_alpha = 0.05
SDE_Sigma_00, SDE_Sigma_10, SDE_Sigma_01 = 0.01, 0.001, 0.001  # v1

# Ensemble size.
ensemble_size = 100


def g_tau(t):
    """Linear schedule 1 - t."""
    return 1 - t


# Name used when saving results.
exp_name = 'EnSF_2DBurger_T02_LETKF'
# exp_name = 'EnSF_2DBurger_T045_LETKF'

# Initial ensemble: independent Gaussian entries with standard deviation 2.
x_state = torch.randn(ensemble_size, n_dim, device=device, dtype=torch.float64) * 2

# Memory footprint of one state and of the whole ensemble, in MB.
mem_state = state_target.nelement() * state_target.element_size() / 1e+6
mem_ensemble = ensemble_size * mem_state
print(f'single state memory: {mem_state:.2f} MB')
print(f'state ensemble memory: {mem_ensemble:.2f} MB')

# Diagnostics: RMSE per step, observations, and the ensemble-mean estimate at
# every filtering time (row 0 is the mean of the initial ensemble).
rmse_all = []
obs_save = []
est_save = torch.zeros(filtering_steps+1, n_dim, device = device, dtype=torch.float64)
est_save[0] += x_state.mean(dim=0)

In [None]:
# Decade bands used to rescale small-magnitude values: band s covers
# 10^-(s+1) <= |v| < 10^-s (closed at -10^-s on the negative side, at 10^-(s+1)
# on the positive side) and is multiplied by 10^s; the last band reaches to 0.
_BAND_EDGES = (1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 0)
_BAND_FACTORS = (1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8)


def _magnitude_band_masks(values):
    """Return one boolean mask per decade band, all evaluated on `values` as given."""
    masks = []
    for upper, lower in zip(_BAND_EDGES[:-1], _BAND_EDGES[1:]):
        in_negative = (-upper <= values) & (values < -lower)
        in_positive = (lower <= values) & (values < upper)
        masks.append(in_negative | in_positive)
    return masks


x0filter = x_state
limiterfilter = 1
limiterSDE = 1
for i in range(filtering_steps):
    print(f'step={i}:')
    t1 = time.time()

    # ---- Synthetic observation at filtering time i+1 ------------------------
    # Band membership is decided on the unscaled reference values; the column
    # indices are kept because they are reused to undo the scaling after the
    # analysis.
    state_scale = state_EnSF[[i+1], :].clone()
    obs_band_cols = [torch.nonzero(mask, as_tuple=True)[1]
                     for mask in _magnitude_band_masks(state_scale)]
    for cols, factor in zip(obs_band_cols, _BAND_FACTORS):
        state_scale[:, cols] *= factor

    # atan on the selected entries, then additive Gaussian observation noise.
    obs = state_scale.clone()
    obs[:, idA_sub] = torch.atan(state_scale[:, idA_sub].clone())
    obs += torch.randn_like(state_EnSF[[i+1], :])*obs_sigma

    R = obs_sigma**2 * torch.eye(idA_sub.shape[0], dtype=torch.float64)

    # ---- Forecast: one solver step for every ensemble member ----------------
    sln_state = Burger2D_SSP.GTS_RK2_onestep(x0filter.transpose(0, 1), dtEnSF, Nx, Ny, ntEnSF,
                                             hx, hy, LF, limiterfilter)
    x_state = sln_state.transpose(0, 1)

    # Additive model noise, with a separate amplitude for each third of the
    # state vector.
    # NOTE(review): the thirds are contiguous blocks here, whereas the clamp
    # further down addresses entries with stride 3 -- confirm the state layout.
    noise_scale = torch.sqrt(torch.tensor(dtEnSF))
    noise00 = noise_scale*SDE_Sigma_00*torch.randn_like(x_state[:, :Nx*Ny])
    noise10 = noise_scale*SDE_Sigma_10*torch.randn_like(x_state[:, Nx*Ny+torch.arange(0, Nx*Ny)])
    noise01 = noise_scale*SDE_Sigma_01*torch.randn_like(x_state[:, 2*Nx*Ny+torch.arange(0, Nx*Ny)])
    x_state += torch.cat((noise00, noise10, noise01), dim=1)

    # ---- Rescale the observed entries of every member -----------------------
    # Here the band is chosen per member and per entry from the member's own
    # value (all masks are computed before any entry is modified).
    x0_EnSF = x_state[:, spa_indices].clone()
    ens_band_masks = _magnitude_band_masks(x0_EnSF)
    for mask, factor in zip(ens_band_masks, _BAND_FACTORS):
        x0_EnSF[mask] *= factor
    x_state[:, spa_indices] = x0_EnSF.clone()

    # ---- LETKF analysis ------------------------------------------------------
    Xb_LETKF = x_state.clone()
    Xb = Xb_LETKF.transpose(0, 1)
    y  = obs.clone().transpose(0, 1)
#     rho = 2
    rho = 4
    radius = 2
    Xa_LETKF = LETKF_local_ptorch(Xb, y, spa_indices, rho, R, Nx, Ny, grid_xy_total, obs_xy, radius)

    # ---- Undo the scaling on the observed entries ---------------------------
    # NOTE(review): the division uses the bands of the reference values
    # (obs_band_cols), not the per-member bands used for the multiplication
    # above -- confirm this is intended.
    x_scale = Xa_LETKF.clone().transpose(0, 1)
    x0_EnSF = x_scale[:, spa_indices].clone()
    for cols, factor in zip(obs_band_cols, _BAND_FACTORS):
        x0_EnSF[:, cols] /= factor

    # Only the observed entries are taken from the analysis; every other entry
    # of x_state keeps its forecast value.
    x_state[:, spa_indices] = x0_EnSF.clone()
    x_state = Burger2D_SSP.slope_limiter(x_state.transpose(0, 1), Nx, Ny, hx, hy,
                                         limiterSDE).transpose(0, 1)

    # Clamp every third entry to [-0.3, 0.81].
    x_state[:, 0::3] = torch.clamp(x_state[:, 0::3], min = -0.3, max = 0.81)

    # Starting ensemble for the next step (a copy of the current one).
    x0filter = torch.zeros_like(x_state)
    x0filter += x_state

    # Ensemble-mean estimate and its RMSE against the reference state.
    est_save[i+1] += x_state.mean(dim=0)
    rmse_temp = (est_save[[i+1], :] - state_clone[[i+1], :]).pow(2).mean().sqrt().item()

    # Wall-clock time of this step (wait for pending GPU work first).
    if x_state.device.type == 'cuda':
        torch.cuda.current_stream().synchronize()
    t2 = time.time()
    print(f'\t RMSE = {rmse_temp:.4f}')
    print(f'\t time = {t2-t1:.4f} ')

    rmse_all.append(rmse_temp)
    # Stop as soon as the filter has clearly diverged.
    if rmse_temp > 1000:
        print('diverge!')
        break

# NumPy copies for saving.
# NOTE(review): rmse_all becomes an ndarray here, so this cell cannot be re-run
# without first re-running the cell that resets rmse_all to a list.
state_savepy = state_ref.clone().cpu().numpy()
est_savepy = est_save.clone().cpu().numpy()
rmse_all = np.array(rmse_all)

In [None]:
# Write the ensemble-mean estimates, the RMSE history and the run sizes to a
# MATLAB file.
results = {
    'Est_State': est_savepy,
    'rmse': rmse_all,
    'Nx': Nx,
    'nttrue': nttrue,
    'ntEnSF': ntEnSF,
}
scipy.io.savemat('LETKF_2DBurger_MixedObs_SSPRK_T02_10Obs_v1.mat', results)

# scipy.io.savemat('LETKF_2DBurger_MixedObs_SSPRK_T045_10Obs_v1.mat',\
#                      {'Est_State':est_savepy,'rmse':rmse_all, 'Nx':Nx, 'nttrue':nttrue, 'ntEnSF': ntEnSF})