In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

from model import lorenz63_fdm
from nmc import nmc

In [2]:
from scipy.linalg import sqrtm

In [3]:
x0 = np.array([[8, 0, 30]]).T
end_time = 10
dt = 0.01
ts = np.arange(0, end_time, dt)

nature = lorenz63_fdm(x0, ts)
nature.shape

(3, 1000)

In [4]:
Pb = nmc(lorenz63_fdm, nature, dt, 1, 0.04)
Pb

array([[ 2.2841561 ,  2.15577354, -0.57402692],
       [ 2.15577354,  2.7865047 , -0.16715354],
       [-0.57402692, -0.16715354,  6.00974005]])

In [5]:
obs_intv = 8
obs = nature + np.sqrt(2) * np.random.randn(*nature.shape)
obs = obs[:,::obs_intv]
obs.shape

(3, 125)

In [6]:
R = np.eye(3) * 2
R

array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]])

In [7]:
# a very bad initial condition
X_ini = x0 + np.array([[10, -10, 15]]).T
X_ini

array([[ 18],
       [-10],
       [ 45]])

In [8]:
N_ens = 30
X_ens_ini = np.random.multivariate_normal(X_ini.ravel(), Pb, size=N_ens).T  # (ndim, N_ens)
X_ens_ini.shape

(3, 30)

In [9]:
def da_rmse(nature, analysis, obs_intv):
    return np.sqrt(np.mean((analysis[:,::obs_intv] - nature[:,::obs_intv]) ** 2, axis=0))

def plot_assimilation_result(nature, obs, analysis, obs_intv):
    fig, axs = plt.subplots(nrows=4, figsize=(8, 8), sharex=True)
    for i in range(3):
        axs[i].plot(ts, nature[i,:], color='#024BC7', label='nature')
        axs[i].plot(ts[::obs_intv], obs[i,:], '.', color='#024BC7', label='obs')
        axs[i].plot(ts, analysis[i,:], color='#FFA500', label='analysis')
    axs[0].legend()
    axs[0].set_title('X')
    axs[1].set_title('Y')
    axs[2].set_title('Z')
    
    rmse = da_rmse(nature, analysis, obs_intv)
    axs[3].plot(ts[::obs_intv], rmse, '.-')
    axs[3].set_title('RMSE')
    
    plt.tight_layout()

In [25]:
class DiagWarning(UserWarning):
    """Used in serially assimilation when R is not diagonal"""
    pass


class DAbase:
    def __init__(self, model, dt, store_history=False):
        self._isstore = store_history
        self._params = {'alpha': 0, 'inflat': 1}
        self.model = model
        self.dt = dt
        self.X_ini = None
        
    def set_params(self, param_list, **kwargs):
        for key, value in kwargs.items():
            if key in param_list:
                self._params[key] = kwargs.get(key)
            else:
                raise ValueError(f'Invalid parameter: {key}')
        
    def _check_params(self, param_list):
        missing_params = []
        for var in param_list:
            if self._params.get(var) is None:
                missing_params.append(var)
        return missing_params
    
    
class EnsembleBase(DAbase):    
    def __init__(self, model, dt, store_history=False):
        super().__init__(model, dt, store_history)
        self._param_list = [
            'X_ens_ini', 
            'obs', 
            'obs_interv', 
            'R', 
            'H_func', 
            'alpha', 
            'inflat',
            'local',
        ]
    
    def list_params(self):
        return self._param_list
    
    def set_params(self, **kwargs):
        local = kwargs.get('local')
        if local is not None and not isinstance(local, (tuple, list)):
            kwargs['local'] = tuple(local)
        super().set_params(self._param_list, **kwargs)
    
    def _check_params(self):
        if self._params.get('H_func') is None:
            H_func = lambda arr: arr
            self._params['H_func']
        
        missing_params = super()._check_params(self._param_list)
        if missing_params:
            raise ValueError(f"Missing parameters: {missing_params}")
            
    def _analysis(self):
        pass
            
    def cycle(self, **kwargs):
        self._check_params()
        
        model = self.model
        dt = self.dt
        cycle_len = self._params['obs_interv']
        cycle_num = self._params['obs'].shape[1]
        
        xb = self._params['X_ens_ini'].copy()
        obs = self._params['obs']
        R = self._params['R']
        H_func = self._params['H_func']
        alpha = self._params['alpha']
        inflat = self._params['inflat']
        local = self._params['local']
        
        ndim, N_ens = xb.shape
        background = np.zeros((N_ens, ndim, cycle_len*cycle_num))
        analysis = np.zeros_like(background)
        
        t_start = 0
        ts = np.linspace(t_start, (cycle_len-1)*dt, cycle_len)
        
        for nc in range(cycle_num):
            # analysis
            xa = self._analysis(xb, obs[:,[nc]], R, H_func, *local, **kwargs)
            
            # inflat
            xa_perturb = xa - xa.mean(axis=1)[:,np.newaxis]
            xa_perturb *= inflat
            xa = xa.mean(axis=1)[:,np.newaxis] + xa_perturb
            
            # ensemble forecast
            for iens in range(N_ens):
                x_forecast = model(xa[:,iens], ts)   # (ndim, ts.size)
                
                idx1 = nc*cycle_len
                idx2 = (nc+1)*cycle_len
                analysis[iens,:,idx1:idx2] = x_forecast
                background[iens,:,[idx1]] = xb[:,iens]
                background[iens,:,(idx1+1):idx2] = x_forecast[:,1:]
                
                # xb for next cycle
                xb[:,iens] = x_forecast[:,-1]
                
            # for next cycle
            t_start = int(ts[-1] + dt)
            ts = np.linspace(t_start, t_start+(cycle_len-1)*dt, cycle_len)
            
        self.background = background
        self.analysis = analysis


class EnKF(EnsembleBase):
    def _check_params(self):          
        if self._params.get('local') is None:
            ndim_model = self._params.get('X_ens_ini').shape[0]
            ndim_obs = self._params.get('obs').shape[0]
            loc_mo = np.ones((ndim_model, ndim_obs))
            loc_oo = np.ones((ndim_obs, ndim_obs))
            self._params['local'] = (loc_mo, loc_oo)            
        super()._check_params()
 
    def _analysis(self, xb, yo, R, H_func, loc_mo, loc_oo):
        """xb.shape = (n_dim, n_ens)"""
        N_ens = xb.shape[1]
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (ndim_xb, 1)
        Xb_perturb = xb - xb_mean   # (ndim_xb, N_ens)
        Hxb_mean = H_func(xb).mean(axis=1)[:,np.newaxis]   # (ndim_yo, 1)
        HXb_perturb = H_func(xb) - Hxb_mean   # (ndim_yo, N_ens)
        
        PfH_T = Xb_perturb @ HXb_perturb.T / (N_ens-1)
        HPfH_T = HXb_perturb @ HXb_perturb.T / (N_ens-1)
        K = loc_mo * PfH_T @ np.linalg.inv(loc_oo * HPfH_T + R)
        
        yo_ens = np.random.multivariate_normal(yo.ravel(), R, size=N_ens).T   # (ndim_yo, N_ens)
        xa_ens = np.zeros((xb.shape[0], N_ens))
        for iens in range(N_ens):            
            xa_ens[:,[iens]] = xb[:,[iens]] + K @ (yo_ens[:,[iens]] - H_func(xb[:,[iens]]))
            
        return xa_ens


class EnSRF(EnsembleBase):         
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            ndim_model = self._params.get('X_ens_ini').shape[0]
            ndim_obs = self._params.get('obs').shape[0]
            loc_mo = np.ones((ndim_model, ndim_obs))
            self._params['local'] = (loc_mo,)            
        
        # check parameters
        super()._check_params()
        
        # check if R is diagonal matrix
        R = self._params['R']
        Rnew = np.zeros_like(R)
        np.fill_diagonal(Rnew, R.diagonal())
        if not np.all(R == Rnew):
            messg = 'EnSRF assimilates observations serially. It suggests that R should be diagonal matrix.'
            warnings.warn(messg, DiagWarning)
        
    def _analysis(self, xb, yo, R, H_func=None, loc_mo=None):
        """xb.shape = (n_dim, n_ens)"""
        if H_func is None:
            H_func = lambda arr: arr
        
        N_ens = xb.shape[1]
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (ndim_xb, 1)
        xb_pertb = xb - xb_mean   # (ndim_xb, N_ens)
        
        # update `xb_mean`
        Hxb_mean = H_func(xb).mean(axis=1)[:,np.newaxis]   # (ndim_yo, 1)
        Hxb_pertb = H_func(xb) - Hxb_mean   # (ndim_yo, N_ens)
        PfH_T = xb_pertb @ Hxb_pertb.T / (N_ens-1)
        HPfH_T = Hxb_pertb @ Hxb_pertb.T / (N_ens-1)
        K = loc_mo * PfH_T @ np.linalg.inv(HPfH_T + R)
        xa_mean = xb_mean + K @ (yo - H_func(xb_mean))
        
        # update `xb_pertb`
        xa_pertb = xb_pertb.copy()
        for j_ens in range(N_ens):
            # assimilate one observation at a time
            for io, y in enumerate(yo):
                iR = R[io,io]
                iHxb_pertb = Hxb_pertb[[io],:]   # (1, N_ens)
                
                PfH_T = xb_pertb @ iHxb_pertb.T / (N_ens-1)
                HPfH_T = iHxb_pertb @ iHxb_pertb.T / (N_ens-1)
                gamma = 1 / (1 + np.sqrt(iR / (HPfH_T+iR)))
                K = loc_mo[:,[io]] * PfH_T / (HPfH_T + iR)
                
                xa_pertb_j = xa_pertb[:,[j_ens]]
                xa_pertb[:,[j_ens]] = xa_pertb_j - gamma * K * iHxb_pertb[0,j_ens]
                
        xa_ens = xa_mean + xa_pertb
        return xa_ens
        

class ETKF(EnsembleBase):
    """
    Ensemble Transform Kalman Filter
    
    It should note that localization is only used for updating ensemble mean 
    of K method (e.g etkf.cycle(mean_method='K')). There is no localization
    for w method (e.g etkf.cycle(mean_method='w')).
    
    And localization is for ensemble mean only, there is no localization for 
    updating ensemble perturbation.
    
    *Reference
    Update ensemble mean of w method:
        Harlim and Hunt: Local Ensemble Transform Kalman Filter: An efficient
        scheme for assimilating atmospheric data.
        https://www.atmos.umd.edu/~ekalnay/pubs/harlim_hunt05.pdf
    Update ensemble perturbation:
        Tippett, M. K., J. L. Anderson, C. H. Bishop, T. M. Hamill, and J. S. 
        Whitaker, 2003: Ensemble square root filters.
        https://journals.ametsoc.org/doi/pdf/10.1175/1520-0493%282003%29131%3C1485%3AESRF%3E2.0.CO%3B2      
    """
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            ndim_model = self._params.get('X_ens_ini').shape[0]
            ndim_obs = self._params.get('obs').shape[0]
            loc_mo = np.ones((ndim_model, ndim_obs))
            loc_oo = np.ones((ndim_obs, ndim_obs))
            self._params['local'] = (loc_mo, loc_oo)            
        
        # check parameters
        super()._check_params()
    
    def _analysis_mean_w(self, xb_mean, xb_pertb, Hxb_mean, Hxb_pertb, N_ens, yo, R):
        """
        Using the w vector in Harlim and Hunt* to update background ensemble
        mean to analysis mean.
        *Reference: 
        https://www.atmos.umd.edu/~ekalnay/pubs/harlim_hunt05.pdf
        """
        P_tilt = np.linalg.inv(Hxb_pertb.T @ np.linalg.inv(R) @ Hxb_pertb + (N_ens-1) * np.eye(N_ens))
        w = P_tilt @ Hxb_pertb.T @ np.linalg.inv(R) @ (yo - Hxb_mean)
        xa_mean = xb_mean + xb_pertb @ w
        return xa_mean
    
    def _analysis_mean_K(self, xb_mean, xb_pertb, Hxb_pertb, N_ens, yo, R, H_func, loc_mo, loc_oo):
        """
        Using the K matrix (Kalman gain matrix) in traditional Kalman filter to
        upate background ensemble mean to analysis ensemble mean.
        """
        PfH_T = xb_pertb @ Hxb_pertb.T / (N_ens-1)
        HPfH_T = Hxb_pertb @ Hxb_pertb.T / (N_ens-1)
        K = loc_mo * PfH_T @ np.linalg.inv(loc_oo * HPfH_T + R)
        xa_mean = xb_mean + K @ (yo - H_func(xb_mean))
        return xa_mean
    
    def _analysis_perturb(self, xb_pertb, Hxb_pertb, N_ens, R):
        """
        Update background ensemble perturbation tp analysis ensemble perturbation.
        *Reference:
        https://journals.ametsoc.org/doi/pdf/10.1175/1520-0493%282003%29131%3C1485%3AESRF%3E2.0.CO%3B2
        """
        Z = xb_pertb / np.sqrt(N_ens-1)
        HZ = Hxb_pertb / np.sqrt(N_ens-1)
        eigval, C = np.linalg.eig(HZ.T @ np.linalg.inv(R) @ HZ)
        S = np.diag(eigval)
        T = C @ np.linalg.inv(sqrtm(S+np.eye(N_ens)))
        T = T.real   # imag part is likely due to numerical error
        xa_pertb = xb_pertb @ T
        return xa_pertb
        
    def _analysis(self, xb, yo, R, H_func, loc_mo, loc_oo, mean_method='w'):       
        N_ens = xb.shape[1]
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (ndim_xb, 1)
        xb_pertb = xb - xb_mean   # (ndim_xb, N_ens)
        Hxb_mean = H_func(xb).mean(axis=1)[:,np.newaxis]   # (ndim_yo, 1)
        Hxb_pertb = H_func(xb) - Hxb_mean   # (ndim_yo, N_ens)
        
        if mean_method == 'w':
            xa_mean = self._analysis_mean_w(xb_mean, xb_pertb, Hxb_mean, Hxb_pertb, N_ens, yo, R)
        elif mean_method == 'K':
            xa_mean = self._analysis_mean_K(xb_mean, xb_pertb, Hxb_pertb, N_ens, yo, R, H_func, loc_mo, loc_oo)
        else:
            raise TypeError('`mean_method` should be "w" or "K"')
            
        xa_pertb = self._analysis_perturb(xb_pertb, Hxb_pertb, N_ens, R)
        xa = xa_mean + xa_pertb
        return xa
    
    def cycle(self, mean_method='w'):
        super().cycle(mean_method=mean_method)
        

class EAKF(EnsembleBase): 
    """
    Ensemble Adjustment Kalman Filter
    
    It based on the 2-step procedure of Anderson (2003), and followed the 
    step-by-step introduction of Shen et al. (2018).
    
    *Reference
    Zheqi Shen, Youmin Tang, Xiaojing Li, Yanqiu Gao, and Junde Li, 2018:
    On the localization in strongly coupled ensemble data assimilationusing 
    a two-scale Lorenz model
    https://www.nonlin-processes-geophys-discuss.net/npg-2018-50/
    Anderson, 2003: A local least squares framework for ensemble filtering
    https://doi.org/10.1175/1520-0493(2003)131<0634:ALLSFF>2.0.CO;2
    """
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            ndim_model = self._params.get('X_ens_ini').shape[0]
            ndim_obs = self._params.get('obs').shape[0]
            loc_mo = np.ones((ndim_model, ndim_obs))
            self._params['local'] = (loc_mo,)            
        
        # check parameters
        super()._check_params()
    
    def _analysis(self, xb, yo, R, H_func, loc_mo):            
        N_x, N_ens = xb.shape
        
        # serially assimilation
        xa = xb.copy()
        for io, iyo in enumerate(yo):
            ### step 1
            # estimate background field at the observation space
            yp = np.empty(N_ens)
            for iens in range(N_ens):
                yp[iens] = H_func(xa[:,[iens]])[io]
                
            # analysis for the background field at the observation space
            yp_mean = yp.mean()
            yp_var = yp.var()
            r = R[io,io]
            yu_var = 1 / (1/yp_var + 1/r)
            yu_mean = yu_var * (yp_mean / yp_var + iyo / r)
            yu = np.sqrt(yu_var / yp_var) * (yp - yp_mean) + yu_mean   # (N_ens,)
            increment_y = yu - yp   # (N_ens,)
            
            ### step 2 
            for jstate in range(N_x):
                cov_xy = np.cov(xa[jstate,:], yp)[0,1]
                increment_x = cov_xy / yp_var * increment_y
                xa[jstate,:] = xa[jstate,:] + loc_mo[jstate,io] * increment_x
                
        return xa

In [26]:
for Da in [EnKF, EnSRF, ETKF, EAKF]:
    da = Da(lorenz63_fdm, dt)
    params = {
        'X_ens_ini': X_ens_ini,
        'obs': obs,
        'obs_interv': obs_intv,
        'R': R,
        'H_func': lambda arr: arr,
        'alpha': 0.4,
        'inflat': 1.5,
    }
    da.set_params(**params)
    da.cycle()
    analysis = da.analysis.mean(axis=0)
    
    name = da.__str__().split('.')[1].split()[0]
    print(name + ', Mean RMSE: ', da_rmse(nature, analysis, obs_intv).mean())

EnKF, Mean RMSE:  0.8579315684896274
EnSRF, Mean RMSE:  0.9565964031879403
ETKF, Mean RMSE:  0.8916293338355986
EAKF, Mean RMSE:  0.8847446160796265


In [10]:
from model import lorenz96_fdm

x0 = np.random.randint(low=0, high=9, size=40)
x0 = x0[:,np.newaxis]  # (40, 1)
end_time = 10
dt = 0.01
ts = np.arange(0, end_time, dt)
nature = lorenz96_fdm(x0, ts)

Pb = nmc(lorenz96_fdm, nature, dt, 1, 0.1)

obs_intv = 8
obs = nature + np.sqrt(2) * np.random.randn(*nature.shape)
obs = obs[:,::obs_intv]

R = np.eye(40) * 2

X_ini = x0 + np.random.randint(-15, 15, size=x0.shape)
N_ens = 30
X_ens_ini = np.random.multivariate_normal(X_ini.ravel(), Pb, size=N_ens).T  # (ndim, N_ens)

In [11]:
def dis_oo(i, j, L=2):
    if j < i:
        i, j = j, i
    dis_idx = j - i
    if dis_idx > 20:
        dis_idx = 40 - dis_idx
    return np.exp(-dis_idx**2 / (2*L**2))

def dis_mo(i, j, L=2):
    return dis_oo(i, j, L)

# localization for model to observation
loc1 = np.zeros((40, 40))
for i in range(40):
    for j in range(40):
        loc1[i,j] = dis_mo(i, j)
        
# localization for observation to observation
loc2 = np.zeros((40, 40))
for i in range(40):
    for j in range(40):
        loc2[i,j] = dis_oo(i, j)

In [29]:
for Da in [EnKF, EnSRF, ETKF, EAKF]:
    da = Da(lorenz96_fdm, dt)
    params = {
        'X_ens_ini': X_ens_ini,
        'obs': obs,
        'obs_interv': obs_intv,
        'R': R,
        'H_func': lambda arr: arr,
        'alpha': 0.4,
        'inflat': 1.5,
        'local': (loc1, loc2)
    }
    if Da in [EnKF, ETKF]:
        params['local'] = (loc1, loc2)
    else:
        params['local'] = (loc1,)
    da.set_params(**params)
    
    if Da is ETKF:
        da.cycle(mean_method='K')
    else:
        da.cycle()
    analysis = da.analysis.mean(axis=0)
    
    name = da.__str__().split('.')[1].split()[0]
    print(name + ', Mean RMSE: ', da_rmse(nature, analysis, obs_intv).mean())

EnKF, Mean RMSE:  0.9450801334140161


  x[:,idx+1] = xn + dt * ((xn_p1-xn_m2) * xn_m1 - xn + F)
  ret = umr_sum(arr, axis, dtype, out, keepdims)


EnSRF, Mean RMSE:  nan
ETKF, Mean RMSE:  1.1254783389162129
EAKF, Mean RMSE:  0.9449260180624719


## modified EnSRF

In [30]:
# original one

class EnSRF2(EnsembleBase):         
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            ndim_model = self._params.get('X_ens_ini').shape[0]
            ndim_obs = self._params.get('obs').shape[0]
            loc_mo = np.ones((ndim_model, ndim_obs))
            self._params['local'] = (loc_mo,)            
        
        # check parameters
        super()._check_params()
        
        # check if R is diagonal matrix
        R = self._params['R']
        Rnew = np.zeros_like(R)
        np.fill_diagonal(Rnew, R.diagonal())
        if not np.all(R == Rnew):
            messg = 'EnSRF assimilates observations serially. It suggests that R should be diagonal matrix.'
            warnings.warn(messg, DiagWarning)
           
    def _analysis(self, xb, yo, R, H_func, loc_mo):
        """xb.shape = (n_dim, n_ens)"""
        xb = xb.copy()
        N_x, N_ens = xb.shape
        
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (N_x, 1)
        xb_pertb = xb - xb_mean   # (N_x, N_ens) 
        
        for io, y in enumerate(yo):
            # update mean field
            Hxb_mean = H_func(xb)[io,:].mean()   # scalar
            Hxb_pertb = H_func(xb)[io,:] - Hxb_mean   # (N_ens,)
                
            HPfH_T = np.cov(Hxb_pertb)   # scalar
            PfH_T = np.empty((N_x, 1))
            for ix in range(N_x):
                PfH_T[ix] = np.cov(xb_pertb[ix,:], Hxb_pertb)[0,1]
                
            K = loc_mo[:,[io]] * PfH_T / (HPfH_T + R[io,io])   # (N_x, 1)
            xa_mean = xb_mean + K * (y - H_func(xb)[io,:].mean())
                
            # update perturbation field
            D = R[io,io] + HPfH_T
            gamma = 1 / (1 + np.sqrt(R[io,io] / D))   # scalar
            innovation_H = H_func(xb)[io,:] - H_func(xb)[io,:].mean()   # (N_ens,)
            xa_pertb = xb_pertb - gamma * K * innovation_H   # (N_x, N_ens)
            
            # for next loop
            xb_mean = xa_mean
            xb_pertb = xa_pertb
        
        xa = xa_mean + xa_pertb
        return xa

In [31]:
# equivalent to pervious one but try to speed up

def covariance(m1, v2, n):
    """
    Calculate the covariance between each row of `m1` and `v2`.
    Parameters:
        m1: numpy matrix with shape=(k, n)
        v2: numpy array with shape=(n,)
    Return:
        covariance with shape=(k,) where i'th element is the covariance
        between m1[i,:] and v2
    """
    return ((m1 - m1.mean(axis=1)[:,np.newaxis]) * (v2 - v2.mean())).sum(axis=1) / (n-1)

class EnSRF2(EnsembleBase):         
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            ndim_model = self._params.get('X_ens_ini').shape[0]
            ndim_obs = self._params.get('obs').shape[0]
            loc_mo = np.ones((ndim_model, ndim_obs))
            self._params['local'] = (loc_mo,)            
        
        # check parameters
        super()._check_params()
        
        # check if R is diagonal matrix
        R = self._params['R']
        Rnew = np.zeros_like(R)
        np.fill_diagonal(Rnew, R.diagonal())
        if not np.all(R == Rnew):
            messg = 'EnSRF assimilates observations serially. It suggests that R should be diagonal matrix.'
            warnings.warn(messg, DiagWarning)
           
    def _analysis(self, xb, yo, R, H_func, loc_mo):
        """xb.shape = (N_x, N_ens)"""
        xb = xb.copy()
        N_x, N_ens = xb.shape
        
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (N_x, 1)
        xb_pertb = xb - xb_mean   # (N_x, N_ens) 
        
        Hxb = H_func(xb)   # (N_y, N_ens)
        Hxb_mean = Hxb.mean(axis=1)[:,np.newaxis]   # (N_y, 1)
        Hxb_pertb = Hxb - Hxb_mean   # (N_y, N_ens)
        
        for io, y in enumerate(yo):
            # update mean field
            iHxb_mean = Hxb_mean[io]   # scalar
            iHxb_pertb = Hxb_pertb[io,:]   # (N_ens,)
                
            HPfH_T = np.cov(iHxb_pertb)   # scalar
            PfH_T = covariance(xb_pertb, iHxb_pertb, N_ens)[:,np.newaxis]   # (N_x, 1)
                
            K = loc_mo[:,[io]] * PfH_T / (HPfH_T + R[io,io])   # (N_x, 1)
            xa_mean = xb_mean + K * (y - iHxb_mean)
                
            # update perturbation field
            D = R[io,io] + HPfH_T
            gamma = 1 / (1 + np.sqrt(R[io,io] / D))   # scalar
            innovation_H = Hxb[io,:] - iHxb_mean   # (N_ens,)
            xa_pertb = xb_pertb - gamma * K * innovation_H   # (N_x, N_ens)
            
            # for next loop
            xb_mean = xa_mean
            xb_pertb = xa_pertb
        
        xa = xa_mean + xa_pertb
        return xa

In [32]:
import time

for Da in [EnKF, EnSRF2, ETKF, EAKF]:
    time1 = time.time()
    
    da = Da(lorenz96_fdm, dt)
    params = {
        'X_ens_ini': X_ens_ini,
        'obs': obs,
        'obs_interv': obs_intv,
        'R': R,
        'H_func': lambda arr: arr,
        'alpha': 0.4,
        'inflat': 1.5,
        'local': (loc1, loc2)
    }
    if Da in [EnKF, ETKF]:
        params['local'] = (loc1, loc2)
    else:
        params['local'] = (loc1,)
    da.set_params(**params)
    
    if Da is ETKF:
        da.cycle(mean_method='K')
    else:
        da.cycle()
    analysis = da.analysis.mean(axis=0)
    
    name = da.__str__().split('.')[1].split()[0]
    print(name + ', Mean RMSE: ', da_rmse(nature, analysis, obs_intv).mean())
    
    time2 = time.time()
    print(time2 - time1)

EnKF, Mean RMSE:  0.9561451261750175
2.703463077545166
EnSRF2, Mean RMSE:  1.0097440720225648
3.4122581481933594
ETKF, Mean RMSE:  1.1254783389162129
2.942962169647217
EAKF, Mean RMSE:  0.9449260180624719
18.19165349006653


In [33]:
# change the position of H_func and np.cov

def covariance(m1, v2, n):
    """
    Calculate the covariance between each row of `m1` and `v2`.
    Parameters:
        m1: numpy matrix with shape=(k, n)
        v2: numpy array with shape=(n,)
    Return:
        covariance with shape=(k,) where i'th element is the covariance
        between m1[i,:] and v2
    """
    return ((m1 - m1.mean(axis=1)[:,np.newaxis]) * (v2 - v2.mean())).sum(axis=1) / (n-1)

class EnSRF2(EnsembleBase):         
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            ndim_model = self._params.get('X_ens_ini').shape[0]
            ndim_obs = self._params.get('obs').shape[0]
            loc_mo = np.ones((ndim_model, ndim_obs))
            self._params['local'] = (loc_mo,)            
        
        # check parameters
        super()._check_params()
        
        # check if R is diagonal matrix
        R = self._params['R']
        Rnew = np.zeros_like(R)
        np.fill_diagonal(Rnew, R.diagonal())
        if not np.all(R == Rnew):
            messg = 'EnSRF assimilates observations serially. It suggests that R should be diagonal matrix.'
            warnings.warn(messg, DiagWarning)
           
    def _analysis(self, xb, yo, R, H_func, loc_mo):
        """xb.shape = (N_x, N_ens)"""
        xb = xb.copy()
        N_x, N_ens = xb.shape
        
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (N_x, 1)
        xb_pertb = xb - xb_mean   # (N_x, N_ens) 
        
        for io, y in enumerate(yo):
            Hxb = H_func(xb)   # (N_y, N_ens)
            Hxb_mean = Hxb.mean(axis=1)[:,np.newaxis]   # (N_y, 1)
            Hxb_pertb = Hxb - Hxb_mean   # (N_y, N_ens)
            iHxb_mean = Hxb_mean[io]   # scalar
            iHxb_pertb = Hxb_pertb[io,:]   # (N_ens,)
                    
            # update mean field
            HPfH_T = np.sum(iHxb_pertb**2) / (N_ens-1)   # scalar
            PfH_T = covariance(xb_pertb, iHxb_pertb, N_ens)[:,np.newaxis]   # (N_x, 1)
            K = loc_mo[:,[io]] * PfH_T / (HPfH_T + R[io,io])   # (N_x, 1)
            xa_mean = xb_mean + K * (y - iHxb_mean)
                
            # update perturbation field
            D = R[io,io] + HPfH_T
            gamma = 1 / (1 + np.sqrt(R[io,io] / D))   # scalar
            innovation_H = Hxb[io,:] - iHxb_mean   # (N_ens,)
            xa_pertb = xb_pertb - gamma * K * innovation_H   # (N_x, N_ens)
            
            # for next loop
            xb_mean = xa_mean
            xb_pertb = xa_pertb
            xb = xb_mean + xb_pertb
        
        xa = xa_mean + xa_pertb
        return xa

In [34]:
import time

for Da in [EnKF, EnSRF2, ETKF]:
    time1 = time.time()
    
    da = Da(lorenz96_fdm, dt)
    params = {
        'X_ens_ini': X_ens_ini,
        'obs': obs,
        'obs_interv': obs_intv,
        'R': R,
        'H_func': lambda arr: arr,
        'alpha': 0.4,
        'inflat': 1.5,
        'local': (loc1, loc2)
    }
    if Da in [EnKF, ETKF]:
        params['local'] = (loc1, loc2)
    else:
        params['local'] = (loc1,)
    da.set_params(**params)
    
    if Da is ETKF:
        da.cycle(mean_method='K')
    else:
        da.cycle()
    analysis = da.analysis.mean(axis=0)
    
    name = da.__str__().split('.')[1].split()[0]
    print(name + ', Mean RMSE: ', da_rmse(nature, analysis, obs_intv).mean())
    
    time2 = time.time()
    print(time2 - time1)

EnKF, Mean RMSE:  0.9506942170540202
2.677966833114624
EnSRF2, Mean RMSE:  0.9437663575378138
3.061962604522705
ETKF, Mean RMSE:  1.1254783389162129
2.834963798522949


In [56]:
# change covariance

def covariance(m1, v2, n):
    """
    Calculate the covariance between each row of `m1` and `v2`.
    Parameters:
        m1: numpy matrix with shape=(k, n)
        v2: numpy array with shape=(n,)
    Return:
        covariance with shape=(k,) where i'th element is the covariance
        between m1[i,:] and v2
    """
    #return ((m1 - m1.mean(axis=1)[:,np.newaxis]) * (v2 - v2.mean())).sum(axis=1) / (n-1)
    return np.dot(m1 - m1.mean(axis=1)[:,np.newaxis], v2 - v2.mean()) / (n-1)

class EnSRF2(EnsembleBase):         
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            ndim_model = self._params.get('X_ens_ini').shape[0]
            ndim_obs = self._params.get('obs').shape[0]
            loc_mo = np.ones((ndim_model, ndim_obs))
            self._params['local'] = (loc_mo,)            
        
        # check parameters
        super()._check_params()
        
        # check if R is diagonal matrix
        R = self._params['R']
        Rnew = np.zeros_like(R)
        np.fill_diagonal(Rnew, R.diagonal())
        if not np.all(R == Rnew):
            messg = 'EnSRF assimilates observations serially. It suggests that R should be diagonal matrix.'
            warnings.warn(messg, DiagWarning)
           
    def _analysis(self, xb, yo, R, H_func, loc_mo):
        """xb.shape = (N_x, N_ens)"""
        xb = xb.copy()
        N_x, N_ens = xb.shape
        
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (N_x, 1)
        xb_pertb = xb - xb_mean   # (N_x, N_ens) 
        
        for io, y in enumerate(yo):
            Hxb = H_func(xb)   # (N_y, N_ens)
            Hxb_mean = Hxb.mean(axis=1)[:,np.newaxis]   # (N_y, 1)
            Hxb_pertb = Hxb - Hxb_mean   # (N_y, N_ens)
            iHxb_mean = Hxb_mean[io]   # scalar
            iHxb_pertb = Hxb_pertb[io,:]   # (N_ens,)
                    
            # update mean field
            HPfH_T = np.sum(iHxb_pertb**2) / (N_ens-1)   # scalar
            PfH_T = covariance(xb_pertb, iHxb_pertb, N_ens)[:,np.newaxis]   # (N_x, 1)
            K = loc_mo[:,[io]] * PfH_T / (HPfH_T + R[io,io])   # (N_x, 1)
            xa_mean = xb_mean + K * (y - iHxb_mean)
                
            # update perturbation field
            D = R[io,io] + HPfH_T
            gamma = 1 / (1 + np.sqrt(R[io,io] / D))   # scalar
            innovation_H = Hxb[io,:] - iHxb_mean   # (N_ens,)
            xa_pertb = xb_pertb - gamma * K * innovation_H   # (N_x, N_ens)
            
            # for next loop
            xb_mean = xa_mean
            xb_pertb = xa_pertb
            xb = xb_mean + xb_pertb
        
        xa = xa_mean + xa_pertb
        return xa

In [59]:
import time

for Da in [EnKF, EnSRF2, ETKF]:
    time1 = time.time()
    
    da = Da(lorenz96_fdm, dt)
    params = {
        'X_ens_ini': X_ens_ini,
        'obs': obs,
        'obs_interv': obs_intv,
        'R': R,
        'H_func': lambda arr: arr,
        'alpha': 0.4,
        'inflat': 1.5,
        'local': (loc1, loc2)
    }
    if Da in [EnKF, ETKF]:
        params['local'] = (loc1, loc2)
    else:
        params['local'] = (loc1,)
    da.set_params(**params)
    
    if Da is ETKF:
        da.cycle(mean_method='K')
    else:
        da.cycle()
    analysis = da.analysis.mean(axis=0)
    
    name = da.__str__().split('.')[1].split()[0]
    print(name + ', Mean RMSE: ', da_rmse(nature, analysis, obs_intv).mean())
    
    time2 = time.time()
    print(time2 - time1)

EnKF, Mean RMSE:  0.9428712362712797
2.7121872901916504
EnSRF2, Mean RMSE:  0.9437663575378128
3.092928886413574
ETKF, Mean RMSE:  1.1254783389162129
3.453951120376587


### modified EAKF

In [62]:
class EAKF2(EnsembleBase): 
    """
    Ensemble Adjustment Kalman Filter
    
    It based on the 2-step procedure of Anderson (2003), and followed the 
    step-by-step introduction of Shen et al. (2018) or Liu el al. (2007).
    
    *Reference
    [1]
    Zheqi Shen, Youmin Tang, Xiaojing Li, Yanqiu Gao, and Junde Li, 2018:
    On the localization in strongly coupled ensemble data assimilationusing 
    a two-scale Lorenz model
    https://www.nonlin-processes-geophys-discuss.net/npg-2018-50/
    
    [2]
    Anderson, 2003: A local least squares framework for ensemble filtering
    https://doi.org/10.1175/1520-0493(2003)131<0634:ALLSFF>2.0.CO;2
    
    [3]
    Liu, H., J. Anderson, Y.-H. Kuo, and K. Raeder, 2007: Importance of 
    forecast error multivariate correlations in idealized assimilation of GPS
    radio occultation data with the ensemble adjustment filter. 
    https://journals.ametsoc.org/doi/abs/10.1175/MWR3270.1
    """
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            ndim_model = self._params.get('X_ens_ini').shape[0]
            ndim_obs = self._params.get('obs').shape[0]
            loc_mo = np.ones((ndim_model, ndim_obs))
            self._params['local'] = (loc_mo,)            
        
        # check parameters
        super()._check_params()
        
        # check if R is diagonal matrix
        R = self._params['R']
        Rnew = np.zeros_like(R)
        np.fill_diagonal(Rnew, R.diagonal())
        if not np.all(R == Rnew):
            messg = 'EAKF assimilates observations serially. It suggests that R should be diagonal matrix.'
            warnings.warn(messg, DiagWarning)
    
    def _analysis(self, xb, yo, R, H_func, loc_mo):  
        """xb.shape = (N_x, N_ens)"""
        N_x, N_ens = xb.shape
        
        # serially assimilation
        xa = xb.copy()
        for io, iyo in enumerate(yo):
            ### step 1
            # estimate background field at the observation space
            #yp = np.empty(N_ens)
            #for iens in range(N_ens):
            #    yp[iens] = H_func(xa[:,[iens]])[io]
            yp = H_func(xa)[io,:]   # (N_ens,)
                
            # analysis for the background field at the observation space
            yp_mean = yp.mean()
            #yp_var = yp.var()
            yp_var = np.sum((yp - yp_mean)**2) / (N_ens-1)
            r = R[io,io]
            yu_var = 1 / (1/yp_var + 1/r)
            yu_mean = yu_var * (yp_mean / yp_var + iyo / r)
            yu = np.sqrt(yu_var / yp_var) * (yp - yp_mean) + yu_mean   # (N_ens,)
            increment_y = yu - yp   # (N_ens,)
            
            ### step 2 
            #cov_xy_states = covariance(xa, yp, N_ens)   # (N_x,)
            #for jstate in range(N_x):
            #    cov_xy = cov_xy_states[jstate]
            #    increment_x = cov_xy / yp_var * increment_y   # (N_ens,)
            #    xa[jstate,:] += loc_mo[jstate,io] * increment_x
            cov_xy_states = covariance(xa, yp, N_ens)   # (N_x,)
            increment_x_states = cov_xy_states[:,np.newaxis] / yp_var * increment_y   # (N_x, N_ens)
            xa += loc_mo[:,[io]] * increment_x_states
                
        return xa

In [63]:
import time

for Da in [EAKF2]:
    time1 = time.time()
    
    da = Da(lorenz96_fdm, dt)
    params = {
        'X_ens_ini': X_ens_ini,
        'obs': obs,
        'obs_interv': obs_intv,
        'R': R,
        'H_func': lambda arr: arr,
        'alpha': 0.4,
        'inflat': 1.5,
        'local': (loc1, loc2)
    }
    if Da in [EnKF, ETKF]:
        params['local'] = (loc1, loc2)
    else:
        params['local'] = (loc1,)
    da.set_params(**params)
    
    if Da is ETKF:
        da.cycle(mean_method='K')
    else:
        da.cycle()
    analysis = da.analysis.mean(axis=0)
    
    name = da.__str__().split('.')[1].split()[0]
    print(name + ', Mean RMSE: ', da_rmse(nature, analysis, obs_intv).mean())
    
    time2 = time.time()
    print(time2 - time1)

EAKF2, Mean RMSE:  0.9437663575378132
3.1144251823425293


In [64]:
%load_ext line_profiler

In [65]:
def test():        
    da = EAKF2(lorenz96_fdm, dt)
    params = {
        'X_ens_ini': X_ens_ini,
        'obs': obs,
        'obs_interv': obs_intv,
        'R': R,
        'H_func': lambda arr: arr,
        'alpha': 0.4,
        'inflat': 1.5,
        'local': (loc1,)
    }
    da.set_params(**params)
    da.cycle()
    analysis = da.analysis.mean(axis=0)
    
    name = da.__str__().split('.')[1].split()[0]
    print(name + ', Mean RMSE: ', da_rmse(nature, analysis, obs_intv).mean())

In [66]:
%lprun -f da._analysis test()

EAKF2, Mean RMSE:  0.9437663575378132


Timer unit: 2.68282e-07 s

Total time: 0.951675 s
File: <ipython-input-62-fc0aa67c53dc>
Function: _analysis at line 44

Line #      Hits         Time  Per Hit   % Time  Line Contents
    44                                               def _analysis(self, xb, yo, R, H_func, loc_mo):  
    45                                                   """xb.shape = (N_x, N_ens)"""
    46       125       1223.0      9.8      0.0          N_x, N_ens = xb.shape
    47                                                   
    48                                                   # serially assimilation
    49       125       2577.0     20.6      0.1          xa = xb.copy()
    50      5125      37841.0      7.4      1.1          for io, iyo in enumerate(yo):
    51                                                       ### step 1
    52                                                       # estimate background field at the observation space
    53                                                       #yp

### modified ETKF

In [67]:
class ETKF2(EnsembleBase):
    """
    Ensemble Transform Kalman Filter
    
    It should note that localization is only used for updating ensemble mean 
    of K method (e.g etkf.cycle(mean_method='K')). There is no localization
    for w method (e.g etkf.cycle(mean_method='w')).
    
    And localization is for ensemble mean only, there is no localization for 
    updating ensemble perturbation.
    
    *Reference
    Update ensemble mean of w method:
        Harlim and Hunt: Local Ensemble Transform Kalman Filter: An efficient
        scheme for assimilating atmospheric data.
        https://www.atmos.umd.edu/~ekalnay/pubs/harlim_hunt05.pdf
    Update ensemble perturbation:
        Tippett, M. K., J. L. Anderson, C. H. Bishop, T. M. Hamill, and J. S. 
        Whitaker, 2003: Ensemble square root filters.
        https://journals.ametsoc.org/doi/pdf/10.1175/1520-0493%282003%29131%3C1485%3AESRF%3E2.0.CO%3B2      
    """
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            ndim_model = self._params.get('X_ens_ini').shape[0]
            ndim_obs = self._params.get('obs').shape[0]
            loc_mo = np.ones((ndim_model, ndim_obs))
            loc_oo = np.ones((ndim_obs, ndim_obs))
            self._params['local'] = (loc_mo, loc_oo)            
        
        # check parameters
        super()._check_params()
    
    def _analysis_mean_w(self, xb_mean, xb_pertb, Hxb_mean, Hxb_pertb, N_ens, yo, R):
        """
        Using the w vector in Harlim and Hunt* to update background ensemble
        mean to analysis mean.
        *Reference: 
        https://www.atmos.umd.edu/~ekalnay/pubs/harlim_hunt05.pdf
        """
        P_tilt = np.linalg.inv(Hxb_pertb.T @ np.linalg.inv(R) @ Hxb_pertb + (N_ens-1) * np.eye(N_ens))
        w = P_tilt @ Hxb_pertb.T @ np.linalg.inv(R) @ (yo - Hxb_mean)
        xa_mean = xb_mean + xb_pertb @ w
        return xa_mean
    
    def _analysis_mean_K(self, xb_mean, xb_pertb, Hxb_pertb, N_ens, yo, R, H_func, loc_mo, loc_oo):
        """
        Using the K matrix (Kalman gain matrix) in traditional Kalman filter to
        upate background ensemble mean to analysis ensemble mean.
        """
        PfH_T = xb_pertb @ Hxb_pertb.T / (N_ens-1)
        HPfH_T = Hxb_pertb @ Hxb_pertb.T / (N_ens-1)
        K = loc_mo * PfH_T @ np.linalg.inv(loc_oo * HPfH_T + R)
        xa_mean = xb_mean + K @ (yo - H_func(xb_mean))
        return xa_mean
    
    def _analysis_perturb(self, xb_pertb, Hxb_pertb, N_ens, R):
        """
        Update background ensemble perturbation tp analysis ensemble perturbation.
        *Reference:
        https://journals.ametsoc.org/doi/pdf/10.1175/1520-0493%282003%29131%3C1485%3AESRF%3E2.0.CO%3B2
        """
        Z = xb_pertb / np.sqrt(N_ens-1)
        HZ = Hxb_pertb / np.sqrt(N_ens-1)
        eigval, C = np.linalg.eig(HZ.T @ np.linalg.inv(R) @ HZ)
        T = C @ np.diag(1 / np.sqrt(1 + eigval))
        T = T.real   # imag part is likely due to numerical error
        xa_pertb = xb_pertb @ T
        return xa_pertb
        
    def _analysis(self, xb, yo, R, H_func, loc_mo, loc_oo, mean_method='w'):       
        N_ens = xb.shape[1]
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (ndim_xb, 1)
        xb_pertb = xb - xb_mean   # (ndim_xb, N_ens)
        Hxb_mean = H_func(xb).mean(axis=1)[:,np.newaxis]   # (ndim_yo, 1)
        Hxb_pertb = H_func(xb) - Hxb_mean   # (ndim_yo, N_ens)
        
        if mean_method == 'w':
            xa_mean = self._analysis_mean_w(xb_mean, xb_pertb, Hxb_mean, Hxb_pertb, N_ens, yo, R)
        elif mean_method == 'K':
            xa_mean = self._analysis_mean_K(xb_mean, xb_pertb, Hxb_pertb, N_ens, yo, R, H_func, loc_mo, loc_oo)
        else:
            raise TypeError('`mean_method` should be "w" or "K"')
            
        xa_pertb = self._analysis_perturb(xb_pertb, Hxb_pertb, N_ens, R)
        xa = xa_mean + xa_pertb
        return xa
    
    def cycle(self, mean_method='w'):
        super().cycle(mean_method=mean_method)

In [68]:
import time

time1 = time.time()
da = ETKF2(lorenz96_fdm, dt)
params = {
    'X_ens_ini': X_ens_ini,
    'obs': obs,
    'obs_interv': obs_intv,
    'R': R,
    'H_func': lambda arr: arr,
    'alpha': 0.4,
    'inflat': 1.5,
    'local': (loc1, loc2)
}
da.set_params(**params)
da.cycle(mean_method='K')

analysis = da.analysis.mean(axis=0)
name = da.__str__().split('.')[1].split()[0]
print(name + ', Mean RMSE: ', da_rmse(nature, analysis, obs_intv).mean())
time2 = time.time()
print(time2 - time1)

ETKF2, Mean RMSE:  1.1513501733208094
2.910527467727661


### sort out

In [15]:
def _covariance(m1, v2, n):
    """
    Calculate the covariance between each row of `m1` and `v2`.
    Parameters:
        m1: numpy matrix with shape=(k, n)
        v2: numpy array with shape=(n,)
    Return:
        covariance with shape=(k,) where i'th element is the covariance
        between m1[i,:] and v2
    """
    #return ((m1 - m1.mean(axis=1)[:,np.newaxis]) * (v2 - v2.mean())).sum(axis=1) / (n-1)
    return np.dot(m1 - m1.mean(axis=1)[:,np.newaxis], v2 - v2.mean()) / (n-1)

def _isdiag(matrix):
    """Check if `matrix` is a diagonal matrix. Used in serial assimilation."""
    i, j = matrix.shape
    assert i == j 
    test = matrix.reshape(-1)[:-1].reshape(i-1,j+1)
    return ~np.any(test[:,1:])


class DiagWarning(UserWarning):
    """Used in serially assimilation when R is not diagonal"""
    pass


class DAbase:
    def __init__(self, model, dt, store_history=False):
        self._isstore = store_history
        self._params = {'alpha': 0, 'inflat': 1}
        self.model = model
        self.dt = dt
        self.X_ini = None
        
    def set_params(self, param_list, **kwargs):
        for key, value in kwargs.items():
            if key in param_list:
                self._params[key] = kwargs.get(key)
            else:
                raise ValueError(f'Invalid parameter: {key}')
        
    def _check_params(self, param_list):
        missing_params = []
        for var in param_list:
            if self._params.get(var) is None:
                missing_params.append(var)
        return missing_params
    
    
class EnsembleBase(DAbase):    
    def __init__(self, model, dt, store_history=False):
        super().__init__(model, dt, store_history)
        self._param_list = [
            'X_ens_ini', 
            'obs', 
            'obs_interv', 
            'R', 
            'H_func', 
            'alpha', 
            'inflat',
            'local',
        ]
    
    def list_params(self):
        return self._param_list
    
    def set_params(self, **kwargs):
        local = kwargs.get('local')
        if local is not None and not isinstance(local, (tuple, list)):
            kwargs['local'] = tuple(local)
        super().set_params(self._param_list, **kwargs)
    
    def _check_params(self):
        if self._params.get('H_func') is None:
            H_func = lambda arr: arr
            self._params['H_func']
        
        missing_params = super()._check_params(self._param_list)
        if missing_params:
            raise ValueError(f"Missing parameters: {missing_params}")
            
    def _check_R_diag(self):
        if not _isdiag(self._params['R']):
            name = self.__class__.__name__
            messg = f'{name} assimilates observations serially. It suggests that R should be diagonal matrix.'
            warnings.warn(messg, DiagWarning)
            
    def _default_local(self, is_loc_mo=False, is_loc_oo=False):
        ndim_model = self._params.get('X_ens_ini').shape[0]
        ndim_obs = self._params.get('obs').shape[0]
        loc_mo = np.ones((ndim_model, ndim_obs))
        loc_oo = np.ones((ndim_obs, ndim_obs))
        
        if is_loc_mo and is_loc_oo:
            return loc_mo, loc_oo
        elif is_loc_mo and not is_loc_oo:
            return loc_mo
        elif not is_loc_mo and is_loc_oo:
            return loc_oo
            
    def _analysis(self):
        pass
            
    def cycle(self, **kwargs):
        self._check_params()
        
        model = self.model
        dt = self.dt
        cycle_len = self._params['obs_interv']
        cycle_num = self._params['obs'].shape[1]
        
        xb = self._params['X_ens_ini'].copy()
        obs = self._params['obs']
        R = self._params['R']
        H_func = self._params['H_func']
        alpha = self._params['alpha']
        inflat = self._params['inflat']
        local = self._params['local']
        
        ndim, N_ens = xb.shape
        background = np.zeros((N_ens, ndim, cycle_len*cycle_num))
        analysis = np.zeros_like(background)
        
        t_start = 0
        ts = np.linspace(t_start, (cycle_len-1)*dt, cycle_len)
        
        for nc in range(cycle_num):
            # analysis
            xa = self._analysis(xb, obs[:,[nc]], R, H_func, *local, **kwargs)
            
            # inflat
            xa_perturb = xa - xa.mean(axis=1)[:,np.newaxis]
            xa_perturb *= inflat
            xa = xa.mean(axis=1)[:,np.newaxis] + xa_perturb
            
            # ensemble forecast
            for iens in range(N_ens):
                x_forecast = model(xa[:,iens], ts)   # (ndim, ts.size)
                
                idx1 = nc*cycle_len
                idx2 = (nc+1)*cycle_len
                analysis[iens,:,idx1:idx2] = x_forecast
                background[iens,:,[idx1]] = xb[:,iens]
                background[iens,:,(idx1+1):idx2] = x_forecast[:,1:]
                
                # xb for next cycle
                xb[:,iens] = x_forecast[:,-1]
                
            # for next cycle
            t_start = int(ts[-1] + dt)
            ts = np.linspace(t_start, t_start+(cycle_len-1)*dt, cycle_len)
            
        self.background = background
        self.analysis = analysis


class EnKF(EnsembleBase):
    def _check_params(self):          
        if self._params.get('local') is None:
            loc_mo, loc_oo = self._default_local(is_loc_mo=True, is_loc_oo=True)
            self._params['local'] = (loc_mo, loc_oo)
        super()._check_params()
 
    def _analysis(self, xb, yo, R, H_func, loc_mo, loc_oo, random_state=None):
        """xb.shape = (n_dim, n_ens)"""
        N_ens = xb.shape[1]
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (ndim_xb, 1)
        Xb_perturb = xb - xb_mean   # (ndim_xb, N_ens)
        Hxb_mean = H_func(xb).mean(axis=1)[:,np.newaxis]   # (ndim_yo, 1)
        HXb_perturb = H_func(xb) - Hxb_mean   # (ndim_yo, N_ens)
        
        PfH_T = Xb_perturb @ HXb_perturb.T / (N_ens-1)
        HPfH_T = HXb_perturb @ HXb_perturb.T / (N_ens-1)
        K = loc_mo * PfH_T @ np.linalg.inv(loc_oo * HPfH_T + R)
        
        rst = np.random.RandomState(seed=random_state)
        yo_ens = rst.multivariate_normal(yo.ravel(), R, size=N_ens).T   # (ndim_yo, N_ens)
        xa_ens = np.zeros((xb.shape[0], N_ens))
        for iens in range(N_ens):            
            xa_ens[:,[iens]] = xb[:,[iens]] + K @ (yo_ens[:,[iens]] - H_func(xb[:,[iens]]))
            
        return xa_ens
    
    def cycle(self, random_state=None):
        super().cycle(random_state=random_state)


class EnSRF(EnsembleBase):         
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            loc_mo = self._default_local(is_loc_mo=True)
            self._params['local'] = (loc_mo,)
        
        # check parameters and check if R is diagonal matrix
        super()._check_params()
        self._check_R_diag()
           
    def _analysis(self, xb, yo, R, H_func, loc_mo):
        """xb.shape = (N_x, N_ens)"""
        xb = xb.copy()
        N_x, N_ens = xb.shape
        
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (N_x, 1)
        xb_pertb = xb - xb_mean   # (N_x, N_ens) 
        
        for io, y in enumerate(yo):
            Hxb = H_func(xb)   # (N_y, N_ens)
            Hxb_mean = Hxb.mean(axis=1)[:,np.newaxis]   # (N_y, 1)
            Hxb_pertb = Hxb - Hxb_mean   # (N_y, N_ens)
            iHxb_mean = Hxb_mean[io]   # scalar
            iHxb_pertb = Hxb_pertb[io,:]   # (N_ens,)
                    
            # update mean field
            HPfH_T = np.sum(iHxb_pertb**2) / (N_ens-1)   # scalar
            PfH_T = _covariance(xb_pertb, iHxb_pertb, N_ens)[:,np.newaxis]   # (N_x, 1)
            K = loc_mo[:,[io]] * PfH_T / (HPfH_T + R[io,io])   # (N_x, 1)
            xa_mean = xb_mean + K * (y - iHxb_mean)
                
            # update perturbation field
            D = R[io,io] + HPfH_T
            gamma = 1 / (1 + np.sqrt(R[io,io] / D))   # scalar
            innovation_H = Hxb[io,:] - iHxb_mean   # (N_ens,)
            xa_pertb = xb_pertb - gamma * K * innovation_H   # (N_x, N_ens)
            
            # for next loop
            xb_mean = xa_mean
            xb_pertb = xa_pertb
            xb = xb_mean + xb_pertb
        
        xa = xa_mean + xa_pertb
        return xa
    
    
class ETKF(EnsembleBase):
    """
    Ensemble Transform Kalman Filter
    
    It should note that localization is only used for updating ensemble mean 
    of K method (e.g etkf.cycle(mean_method='K')). There is no localization
    for w method (e.g etkf.cycle(mean_method='w')).
    
    And localization is for ensemble mean only, there is no localization for 
    updating ensemble perturbation.
    
    *Reference
    Update ensemble mean of w method:
        Harlim and Hunt: Local Ensemble Transform Kalman Filter: An efficient
        scheme for assimilating atmospheric data.
        https://www.atmos.umd.edu/~ekalnay/pubs/harlim_hunt05.pdf
    Update ensemble perturbation:
        Tippett, M. K., J. L. Anderson, C. H. Bishop, T. M. Hamill, and J. S. 
        Whitaker, 2003: Ensemble square root filters.
        https://journals.ametsoc.org/doi/pdf/10.1175/1520-0493%282003%29131%3C1485%3AESRF%3E2.0.CO%3B2      
    """
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            loc_mo, loc_oo = self._default_local(is_loc_mo=True, is_loc_oo=True)
            self._params['local'] = (loc_mo, loc_oo)         
        
        # check parameters
        super()._check_params()
    
    def _analysis_mean_w(self, xb_mean, xb_pertb, Hxb_mean, Hxb_pertb, N_ens, yo, R):
        """
        Using the w vector in Harlim and Hunt* to update background ensemble
        mean to analysis mean.
        *Reference: 
        https://www.atmos.umd.edu/~ekalnay/pubs/harlim_hunt05.pdf
        """
        P_tilt = np.linalg.inv(Hxb_pertb.T @ np.linalg.inv(R) @ Hxb_pertb + (N_ens-1) * np.eye(N_ens))
        w = P_tilt @ Hxb_pertb.T @ np.linalg.inv(R) @ (yo - Hxb_mean)
        xa_mean = xb_mean + xb_pertb @ w
        return xa_mean
    
    def _analysis_mean_K(self, xb_mean, xb_pertb, Hxb_pertb, N_ens, yo, R, H_func, loc_mo, loc_oo):
        """
        Using the K matrix (Kalman gain matrix) in traditional Kalman filter to
        upate background ensemble mean to analysis ensemble mean.
        """
        PfH_T = xb_pertb @ Hxb_pertb.T / (N_ens-1)
        HPfH_T = Hxb_pertb @ Hxb_pertb.T / (N_ens-1)
        K = loc_mo * PfH_T @ np.linalg.inv(loc_oo * HPfH_T + R)
        xa_mean = xb_mean + K @ (yo - H_func(xb_mean))
        return xa_mean
    
    def _analysis_perturb(self, xb_pertb, Hxb_pertb, N_ens, R):
        """
        Update background ensemble perturbation tp analysis ensemble perturbation.
        *Reference:
        https://journals.ametsoc.org/doi/pdf/10.1175/1520-0493%282003%29131%3C1485%3AESRF%3E2.0.CO%3B2
        """
        Z = xb_pertb / np.sqrt(N_ens-1)
        HZ = Hxb_pertb / np.sqrt(N_ens-1)
        eigval, C = np.linalg.eig(HZ.T @ np.linalg.inv(R) @ HZ)
        T = C @ np.diag(1 / np.sqrt(1 + eigval))
        T = T.real   # imag part is likely due to numerical error
        xa_pertb = xb_pertb @ T
        return xa_pertb
        
    def _analysis(self, xb, yo, R, H_func, loc_mo, loc_oo, mean_method='w'):       
        N_ens = xb.shape[1]
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (ndim_xb, 1)
        xb_pertb = xb - xb_mean   # (ndim_xb, N_ens)
        Hxb_mean = H_func(xb).mean(axis=1)[:,np.newaxis]   # (ndim_yo, 1)
        Hxb_pertb = H_func(xb) - Hxb_mean   # (ndim_yo, N_ens)
        
        if mean_method == 'w':
            xa_mean = self._analysis_mean_w(xb_mean, xb_pertb, Hxb_mean, Hxb_pertb, N_ens, yo, R)
        elif mean_method == 'K':
            xa_mean = self._analysis_mean_K(xb_mean, xb_pertb, Hxb_pertb, N_ens, yo, R, H_func, loc_mo, loc_oo)
        else:
            raise TypeError('`mean_method` should be "w" or "K"')
            
        xa_pertb = self._analysis_perturb(xb_pertb, Hxb_pertb, N_ens, R)
        xa = xa_mean + xa_pertb
        return xa
    
    def cycle(self, mean_method='w'):
        super().cycle(mean_method=mean_method)
        
        
class EAKF(EnsembleBase): 
    """
    Ensemble Adjustment Kalman Filter
    
    It based on the 2-step procedure of Anderson (2003), and followed the 
    step-by-step introduction of Shen et al. (2018) or Liu el al. (2007).
    
    *Reference
    [1]
    Zheqi Shen, Youmin Tang, Xiaojing Li, Yanqiu Gao, and Junde Li, 2018:
    On the localization in strongly coupled ensemble data assimilationusing 
    a two-scale Lorenz model
    https://www.nonlin-processes-geophys-discuss.net/npg-2018-50/
    
    [2]
    Anderson, 2003: A local least squares framework for ensemble filtering
    https://doi.org/10.1175/1520-0493(2003)131<0634:ALLSFF>2.0.CO;2
    
    [3]
    Liu, H., J. Anderson, Y.-H. Kuo, and K. Raeder, 2007: Importance of 
    forecast error multivariate correlations in idealized assimilation of GPS
    radio occultation data with the ensemble adjustment filter. 
    https://journals.ametsoc.org/doi/abs/10.1175/MWR3270.1
    """
    def _check_params(self):
        # default setting
        if self._params.get('local') is None:
            loc_mo = self._default_local(is_loc_mo=True)
            self._params['local'] = (loc_mo,)          
        
        # check parameters and check if R is diagonal matrix
        super()._check_params()
        self._check_R_diag()
    
    def _analysis(self, xb, yo, R, H_func, loc_mo):  
        """xb.shape = (N_x, N_ens)"""
        N_x, N_ens = xb.shape
        
        # serially assimilation
        xa = xb.copy()
        for io, iyo in enumerate(yo):
            ### step 1
            # estimate background field at the observation space
            yp = H_func(xa)[io,:]   # (N_ens,)
                
            # analysis for the background field at the observation space
            yp_mean = yp.mean()
            yp_var = np.sum((yp - yp_mean)**2) / (N_ens-1)
            r = R[io,io]
            yu_var = 1 / (1/yp_var + 1/r)
            yu_mean = yu_var * (yp_mean / yp_var + iyo / r)
            yu = np.sqrt(yu_var / yp_var) * (yp - yp_mean) + yu_mean   # (N_ens,)
            increment_y = yu - yp   # (N_ens,)
            
            ### step 2 
            cov_xy_states = _covariance(xa, yp, N_ens)   # (N_x,)
            increment_x_states = cov_xy_states[:,np.newaxis] / yp_var * increment_y   # (N_x, N_ens)
            xa += loc_mo[:,[io]] * increment_x_states
                
        return xa

In [25]:
for Da in [EnKF, EnSRF, ETKF, EAKF]:
    da = Da(lorenz96_fdm, dt)
    params = {
        'X_ens_ini': X_ens_ini,
        'obs': obs,
        'obs_interv': obs_intv,
        'R': R,
        'H_func': lambda arr: arr,
        'alpha': 0.4,
        'inflat': 1.5,
    }
    if Da in [EnKF, ETKF]:
        params['local'] = (loc1, loc2)
    else:
        params['local'] = (loc1,)
    da.set_params(**params)
    
    if Da is ETKF:
        da.cycle(mean_method='K')
    elif Da is EnKF:
        da.cycle()
    else:
        da.cycle()
    analysis = da.analysis.mean(axis=0)
    
    name = da.__str__().split('.')[1].split()[0]
    print(name + ', Mean RMSE: ', da_rmse(nature, analysis, obs_intv).mean())

EnKF, Mean RMSE:  0.9300048445638943
EnSRF, Mean RMSE:  0.9266270182116392
ETKF, Mean RMSE:  1.102719872230467
EAKF, Mean RMSE:  0.9266270182116393


In [24]:
for seed in range(25):
    da = EnKF(lorenz96_fdm, dt)
    params = {
        'X_ens_ini': X_ens_ini,
        'obs': obs,
        'obs_interv': obs_intv,
        'R': R,
        'H_func': lambda arr: arr,
        'alpha': 0.4,
        'inflat': 1.5,
        'local': (loc1, loc2)
    }
    da.set_params(**params)
    da.cycle(random_state=seed)
    analysis = da.analysis.mean(axis=0)
    print(seed, da_rmse(nature, analysis, obs_intv).mean())

0 1.1127781706911606
1 1.1138001155678579
2 1.1206119752916517
3 1.1291263315300275
4 1.1116639054646738
5 1.106930572901526
6 1.1044092336795808
7 1.111391046434195
8 1.1261910641683837
9 1.1204035986623504
10 1.1036391830926335
11 1.109070014051735
12 1.1162063540256706
13 1.118571829946433
14 1.13563489437062
15 1.1173907891599326
16 1.1021783784254295
17 1.1265696100427387
18 1.1359512872950592
19 1.1163339959102339
20 1.1215044842242528
21 1.1178066296798521
22 1.1311339905649345
23 1.1245671805787556
24 1.124292506724144
