<a id="test"></a>
# Define the truth, the forecast ensemble and the bias model

In [1]:
from essentials.physical_models import VdP
from essentials.create import create_truth, create_ensemble, create_bias_model
from essentials.DA import dataAssimilation
from essentials.bias_models import ESN
from essentials.plotResults import *

import numpy as np

# Integration time step, passed as `dt` to both the truth and the forecast ensemble.
dt_t = 2e-4

# Seeded generator so that the random initial state of the truth (`psi0`) is reproducible.
rng = np.random.default_rng(seed=0)


# The manual bias is a function of state and/or time
def manual_bias(y, t):
    # Linear function of the state
    return .2 * y + .3 * np.max(y, axis=0), 'linear'
    # Periodic function of the state
    # return 0.5 * np.max(y, axis=0) * np.cos(2 * y / np.max(y, axis=0)), 'periodic'
    # Time-varying bias
    # return .4 * y * np.sin((np.expand_dims(t, -1) * np.pi * 2) ** 2), 'time'


true_params = dict(model=VdP,
                   t_start=1.5,
                   t_stop=1.8,
                   t_max=2.5,
                   Nt_obs=30,
                   dt=dt_t,
                   psi0=rng.random(2)+5,
                   std_obs=0.1,
                   noise_type='gauss,additive',
                   manual_bias=manual_bias
                   )

truth = create_truth(**true_params)


y_obs, t_obs = [truth[key].copy() for key in ['y_obs', 't_obs']]

# # Visualize the truth and observations
# plot_truth(f_max=300, window=0.1, **truth)



In [6]:

forecast_params = dict(filter='rBA_EnKF',
                       m=10,
                       dt=dt_t,
                       model=VdP,
                       est_a=dict(zeta=(40, 80.),
                                 beta=(50, 80),
                                 kappa=(3, 5),
                                 ),
                       std_psi=0.3,
                       alpha_distr='uniform',
                       inflation=1.0,
                       regularization_factor=5.
                       )
ensemble = create_ensemble(**forecast_params)

train_params = dict(bias_model=ESN, 
                    upsample=2,
                    N_units=40,
                    N_wash=10,
                    t_train=ensemble.t_CR * 5,
                    t_test=ensemble.t_CR * 1,
                    t_val=ensemble.t_CR * 1,
                    # Training data generation options
                    augment_data=True,
                    biased_observations=True,
                    L=20,
                    # Hyperparameter search ranges
                    rho_range=(0.4, 1.),
                    sigma_in_range=(np.log10(1e-5), np.log10(1e1)),
                    tikh_range=[1e-16]
                    )

bias_model, wash_obs, wash_t = create_bias_model(ensemble, truth.copy(), bias_params=train_params)

ensemble.bias = bias_model.copy()


##  Run simulation

In [7]:

ensemble.regularization_factor = 5.

filter_ens = dataAssimilation(ensemble.copy(), y_obs=y_obs, t_obs=t_obs, std_obs=0.01, wash_t=wash_t, wash_obs=wash_obs)

##  Plot results

In [8]:

# Forecast the ensemble further without assimilation
ens = filter_ens.copy()

Nt = int(4 * true_params['Nt_obs'])
psi, t = ens.time_integrate(Nt, averaged=False)
ens.update_history(psi, t)

y = ens.get_observable_hist(Nt)
b, t_b = ens.bias.time_integrate(t=t, y=y)
ens.bias.update_history(b, t_b)


plot_timeseries(ens, truth.copy(), plot_ensemble_members=False, plot_bias=True)
plot_parameters(ens, truth.copy(), reference_p=truth['case'].alpha0)
