In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

from model import lorenz63_fdm
from nmc import nmc

In [2]:
x0 = np.array([[8, 0, 30]]).T
end_time = 10
dt = 0.01
ts = np.arange(0, end_time, dt)

nature = lorenz63_fdm(x0, ts)
nature.shape

(3, 1000)

In [3]:
Pb = nmc(lorenz63_fdm, nature, dt, 1, 0.04)
Pb

array([[ 6.67803485,  7.53315421, -0.03577492],
       [ 7.53315421,  8.74734398, -0.33658703],
       [-0.03577492, -0.33658703,  3.5070809 ]])

In [4]:
obs_intv = 8
obs = nature + np.sqrt(2) * np.random.randn(*nature.shape)
obs = obs[:,::obs_intv]
obs.shape

(3, 125)

In [5]:
R = np.eye(3) * 2
R

array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]])

In [6]:
# a very bad initial condition
X_ini = x0 + np.array([[10, -10, 15]]).T
X_ini

array([[ 18],
       [-10],
       [ 45]])

In [7]:
N_ens = 30
X_ens_ini = np.random.multivariate_normal(X_ini.ravel(), Pb, size=N_ens).T  # (ndim, N_ens)
X_ens_ini.shape

(3, 30)

In [8]:
def da_rmse(nature, analysis, obs_intv):
    return np.sqrt(np.mean((analysis[:,::obs_intv] - nature[:,::obs_intv]) ** 2, axis=0))

def plot_assimilation_result(nature, obs, analysis, obs_intv):
    fig, axs = plt.subplots(nrows=4, figsize=(8, 8), sharex=True)
    for i in range(3):
        axs[i].plot(ts, nature[i,:], color='#024BC7', label='nature')
        axs[i].plot(ts[::obs_intv], obs[i,:], '.', color='#024BC7', label='obs')
        axs[i].plot(ts, analysis[i,:], color='#FFA500', label='analysis')
    axs[0].legend()
    axs[0].set_title('X')
    axs[1].set_title('Y')
    axs[2].set_title('Z')
    
    rmse = da_rmse(nature, analysis, obs_intv)
    axs[3].plot(ts[::obs_intv], rmse, '.-')
    axs[3].set_title('RMSE')
    
    plt.tight_layout()

In [53]:
from assimilation import EnKF

class ETKF(EnKF):
    def _analysis(self, xb, yo, R, H_func=None, loc_mo=None, loc_oo=None):
        print('aa')
        if H_func is None:
            H_func = lambda arr: arr
            
        N_ens = xb.shape[1]
        xb_mean = xb.mean(axis=1)[:,np.newaxis]   # (ndim_xb, 1)
        xb_pertb = xb - xb_mean   # (ndim_xb, N_ens)
        
        # assimilate ensemble mean
        Hxb_mean = H_func(xb).mean(axis=1)[:,np.newaxis]   # (ndim_yo, 1)
        Hxb_pertb = H_func(xb) - Hxb_mean   # (ndim_yo, N_ens)
        PfH_T = xb_pertb @ Hxb_pertb.T / (N_ens-1)
        HPfH_T = Hxb_pertb @ Hxb_pertb.T / (N_ens-1)
        K = PfH_T @ np.linalg.inv(HPfH_T + R)
        xa_mean = xb_mean + K @ (yo - H_func(xb_mean))
        
        # assimilate ensemble deviation    
        Y = Hxb_pertb   # (ndim_yo, N_ens)
        WW_T = np.linalg.inv(np.eye(N_ens) - Y.T @ np.linalg.inv(R) @ Y)
        print(xb)
        
        D, U = np.linalg.eig(WW_T)
        D = np.diag(D)
        
        W = U @ np.sqrt(D) @ U.T
        xa_pertb = xb_pertb @ W
        
        xa = xa_mean + xa_pertb
        return xa

In [54]:
etkf = ETKF(lorenz63_fdm, dt)
params = {
    'X_ens_ini': X_ens_ini,
    'obs': obs,
    'obs_interv': obs_intv,
    'R': R,
    'H_func': lambda arr: arr,
    'alpha': 0.3,
    'inflat': 1.4,
}
etkf.set_params(**params)
etkf.cycle()

aa
[[ 19.14312766  16.23891122  16.63499517  19.13582046  20.68592426
   18.10774106  17.06792126  21.48434828  20.03652221  16.8711954
   22.94118775  14.44049962  18.53334121  16.67453795  18.33759604
   22.90374456  17.54421099  14.79774362  19.80619077  19.4192304
   15.98344409  17.03758883  16.56505625  20.37378176  20.89073295
   17.87084733  22.26020152  15.84616215  18.83845726  16.30965359]
 [ -8.58240864 -11.04893289 -10.82389886  -9.47393714  -6.60522298
   -9.66621712 -10.76564873  -6.32608338  -7.09640066 -11.79195825
   -4.33043372 -13.3618757   -7.80615771 -11.79504859  -9.61370588
   -4.93537274 -11.11274486 -13.27839528  -7.56790755  -8.88962788
  -12.59893078 -11.07725213 -11.71953943  -6.88821612  -6.68858457
   -9.90870213  -5.49704561 -11.85794063  -8.78862966 -11.57423449]
 [ 47.14885322  44.86549575  44.11849609  47.87699699  42.96886585
   46.55492568  44.06826113  43.44759241  41.33116041  48.49310032
   45.06949705  44.06080868  45.91533825  45.31675053  45.6



LinAlgError: Array must not contain infs or NaNs