# 5. Example: 4D-Var Backprop DA Using a Reservoir Computing ML Model on the QGS Model — Hyperparameter Tuning with Ray Tune

This example uses the reservoir computing (RC) macro-parameters for the QGS model reported by Jason Platt and co-authors: Platt, J. A., Wong, A., Clark, R., Penny, S. G. & Abarbanel, H. D. I. Robust forecasting using predictive generalized synchronization in reservoir computing. Chaos: An Interdisciplinary Journal of Nonlinear Science 31, 123118 (2021).


In [None]:
import dabench as dab
import numpy as np
import matplotlib.pyplot as plt
from qgs.params.params import QgParams
import pandas as pd

from ray import train, tune
from hyperopt import hp
from ray.tune.search.hyperopt import HyperOptSearch

In [None]:
%%bash
# Create the output directory used later for the RC weights and RayTune results (-p: no error if it already exists)
mkdir -p out/qgs

In [None]:
# One seed shared by the QGS nature run and the notebook-level NumPy generator
random_seed = 50
np_rng = np.random.default_rng(random_seed)

## A. Create Nature Run

In [None]:
# --- Time parameters ---
dt = 0.5                # model time step (passed as delta_t to the QGS generator)
transient_time = 1e5    # transient time to reach the attractor
integration_time = 1e4  # integration time on the attractor
# NOTE(review): transient_time / integration_time are not referenced in the
# visible notebook cells; the spin-up length is set via n_steps instead.

# --- Model parameters, with some non-default specs ---
model_params = QgParams({
    'phi0_npi': np.deg2rad(50.0) / np.pi,
    'hd': 0.1,
})
# Truncate atmospheric Fourier modes at wavenumber 2 in both x and y
model_params.set_atmospheric_channel_fourier_modes(2, 2)

# Increase the orography depth and the meridional temperature gradient
model_params.ground_params.set_orography(0.2, 1)
model_params.atemperature_params.set_thetas(0.2, 0)

In [None]:
# Length, in time steps, of each segment of the nature run
train_size, transient_size = 100_000, 1_000
valid_size, test_size = 10_000, 10_000

In [None]:
# Spin-up run from a small random initial state (stride=1000); its final
# state becomes the starting point for the nature run used below.
nature_run = dab.data.QGS(
    model_params=model_params,
    delta_t=dt,
    store_as_jax=False,
    random_seed=random_seed,
)
nature_run.generate(
    x0=np_rng.random(model_params.ndim) * 0.001,
    n_steps=200 * 1000,
    stride=1000,
    mxstep=5000,
)
x0 = nature_run.values[-1]

In [None]:
# Regenerate from the spun-up state, long enough to cover every data split
nature_run.generate(
    x0=x0,
    n_steps=train_size + valid_size + transient_size + test_size,
    mxstep=5000,
)

In [None]:
# Split into train / valid / (transient + test), then split the last piece
# again so that the test period is preceded by a transient segment.
# NOTE: This raises a Parameter dimensional conversion warning that can be safely ignored
nr_train, nr_valid, nr_transient_and_test = nature_run.split_train_valid_test(
    train_size,
    valid_size,
    transient_size + test_size,
)
nr_transient, nr_test, _ = nr_transient_and_test.split_train_valid_test(
    transient_size, test_size, 0,
)

In [None]:
# Visualize the first six state variables of the nature run over the test period
fig, axes = plt.subplots(6, 1, sharex=True, figsize=(10, 8))
for j, ax in enumerate(axes):
    ax.plot(nr_test.times, nr_test.values[:, j], lw=3, label='Nature Run')
    ax.set_ylabel(f'$x_{j}$', fontsize=16)

## B. Generate Observations

In [None]:
# Observation error SD: a fixed fraction (obs_sd_scale) of each state
# variable's standard deviation over the training period.
obs_sd_scale = 0.1
per_variable_sd = np.std(nr_train.values, axis=0)
# Use obs_sd_scale rather than repeating the literal, so changing the scale
# above actually changes the observation error.
obs_sd = obs_sd_scale * per_variable_sd

In [None]:
# Observers for the test and validation periods. Both share identical
# settings: an observation every 3rd time step at 10 randomly chosen,
# stationary locations with error SD obs_sd.
# NOTE(review): both use random_seed=93, presumably so they sample the same
# locations (H below is built from the validation locations) — confirm.
def _make_qgs_observer(nr):
    """Return a dab Observer over data generator ``nr`` with the shared settings."""
    return dab.observer.Observer(
        nr,  # Data generator object
        time_indices=np.arange(0, nr.time_dim, 3),  # Observation every 3rd timestep
        random_location_count=10,
        error_bias=0.0,
        error_sd=obs_sd,
        random_seed=93,
        stationary_observers=True,
        store_as_jax=False,
    )


obs_qgs_test = _make_qgs_observer(nr_test)
obs_qgs_valid = _make_qgs_observer(nr_valid)

# Making observations
obs_vec_valid = obs_qgs_valid.observe()
obs_vec_test = obs_qgs_test.observe()

## C. Define and train model

In [None]:
# Reservoir computing surrogate for the 20-variable QGS system.
# The macro-parameters (sigma, sigma_bias, spectral_radius, leak_rate,
# log_beta) follow Platt et al. (2021).
forecast_model = dab.model.RCModel(
    system_dim=20,
    input_dim=20,
    reservoir_dim=2000,
    sparse_adj_matrix=False,
    sparsity=0.99,
    readout_method='linear',
    spectral_radius=0.376752115791648,  # SR
    leak_rate=0.5343730100231164,       # alpha
    sigma=0.98765777241154,
    sigma_bias=0.675882947305197,
    log_beta=-20.219432227197363,
)

In [None]:
# Train from scratch (takes a few minutes): uncomment to retrain and save the W_out weights
# forecast_model.weights_init()
# forecast_model.train(nr_train)
# forecast_model.save_weights('./out/qgs/rc_weights.pkl')

In [None]:
# Shortcut: instead of retraining, initialise the reservoir weights and then
# load the W_out weights saved from a previous training run.
weights_path = './out/qgs/rc_weights.pkl'
forecast_model.weights_init()
forecast_model.load_weights(weights_path)

## D. Set up and run DA Cycler for validation period

In [None]:
# Drive the reservoir with the training data. The per-unit standard deviation
# of the resulting reservoir states sets the background error scale (sigma_bg).
train_res_values = forecast_model.generate(nr_train.values)
train_res_sd = np.std(train_res_values, axis=0)

In [None]:
# Error covariances and observation operator for DA in reservoir space.
obs_locs = obs_vec_valid.location_indices[0]  # observed state-variable indices

# Observation error SD, inflated by 25% over the SD used to generate the obs
sigma_obs = obs_sd[obs_locs] * 1.25
# Background error SD: 10% of each reservoir unit's SD over the training run
sigma_bg = 0.1 * train_res_sd

R = (sigma_obs**2) * np.identity(sigma_obs.shape[0])
B = (sigma_bg**2) * np.identity(forecast_model.reservoir_dim)

# S selects the observed state variables. Size it from the data rather than
# hard-coding (10, 20), so it stays consistent if the observer's location
# count or the system size changes.
S = np.zeros((len(obs_locs), nr_train.system_dim))
S[np.arange(S.shape[0]), obs_locs] = 1
# Compose with the readout so that H acts on reservoir states
# (assumes Wout.T maps reservoir state -> system state; confirm in RCModel).
H = S @ forecast_model.Wout.T

In [None]:
# Initial reservoir state for DA:
# 1. Drive the reservoir with the last 1000 training states, perturbed by
#    observation-level noise.
# 2. Take the final reservoir state.
# 3. Pass it to forecast_model.update along with the last clean training
#    state (presumably one more reservoir step — confirm).
# The noise comes from the seeded np_rng, not the unseeded global np.random
# state, so the notebook is reproducible end to end.
r0_original = forecast_model.update(
    forecast_model.generate(
        nr_train.values[-1000:]
        + np_rng.normal(scale=obs_sd, size=(1000, nr_train.system_dim))
    )[-1],
    nr_train.values[-1],
)

In [None]:
def raytune_v4d_bp_valid(config, num_iters, H, B, R,
                         delta_t, forecast_model, r0_original,
                         nr, obs_vec, sigma_obs):
    """Ray Tune trainable: run 4D-Var backprop DA over ``nr`` and report RMSE.

    Args:
        config: Trial hyperparameters; must contain ``'lr'`` (learning rate)
            and ``'lr_decay'``.
        num_iters: Optimization iterations passed to ``Var4DBackprop``.
        H: Observation operator acting on reservoir states.
        B: Background error covariance (reservoir space).
        R: Observation error covariance.
        delta_t: Model time step.
        forecast_model: Trained RC model. It is the DA forecast model and also
            maps analysed reservoir states back to system space via ``readout``.
        r0_original: Initial reservoir state.
        nr: Nature-run segment. It supplies the start time, the number of DA
            cycles, and the ground truth for the RMSE.
        obs_vec: Observations over ``nr``.
        sigma_obs: Observation error standard deviations.

    Reports ``{'rmse': ...}`` via ``train.report``; returns None.
    """
    # 10 instead of 9 because the window is inclusive of its start and end,
    # so each DA cycle advances steps_per_window - 1 model steps.
    steps_per_window = 10
    steps_per_cycle = steps_per_window - 1

    dc = dab.dacycler.Var4DBackprop(
        system_dim=forecast_model.reservoir_dim,
        delta_t=delta_t,
        H=H,
        B=B,
        R=R,
        learning_rate=config['lr'],
        lr_decay=config['lr_decay'],
        num_iters=num_iters,
        model_obj=forecast_model,
        obs_window_indices=[0, 2, 5, 8],
        steps_per_window=steps_per_window,
    )

    r0_sv = dab.vector.StateVector(values=r0_original, store_as_jax=True)

    out_statevec = dc.cycle(
        input_state=r0_sv,
        start_time=nr.times[0],
        obs_vector=obs_vec,
        analysis_window=4.5,  # 9 steps of dt = 0.5 in this notebook
        # Use the nature-run segment passed in, not the global nr_valid
        timesteps=int(nr.time_dim / steps_per_cycle) - 2,
        obs_error_sd=sigma_obs,
        analysis_time_in_window=2.25)

    # Compare against the truth over the span the cycler actually produced.
    # Slicing the truth to the prediction length is identical to the old
    # hard-coded [:-19] when time_dim == 10000, but also works for other
    # segment lengths.
    pred = forecast_model.readout(out_statevec.values)
    truth = nr.values[:pred.shape[0]]
    rmse = np.sqrt(np.mean(np.square(truth - pred)))

    train.report({'rmse': rmse})

### RayTune

In [None]:
# HyperOpt search space for the DA optimizer settings
space = dict(
    lr=hp.loguniform("lr", -10, 0),  # exp(U(-10, 0)): learning rate in [e^-10, 1]
    lr_decay=hp.uniform("lr_decay", 0.1, 0.99),
)
# Optimization iterations passed to the DA cycler (held fixed, not tuned)
num_iters = 3

In [None]:
# Bind every fixed (non-tuned) argument so Tune only has to supply `config`
trainable_w_num_iters = tune.with_parameters(
    raytune_v4d_bp_valid,
    num_iters=num_iters,
    H=H,
    B=B,
    R=R,
    delta_t=nature_run.delta_t,
    forecast_model=forecast_model,
    r0_original=r0_original,
    nr=nr_valid,
    obs_vec=obs_vec_valid,
    sigma_obs=sigma_obs,
)

In [None]:
# Run 50 HyperOpt trials, one at a time, minimising the reported RMSE
hyperopt_search = HyperOptSearch(space, metric="rmse", mode="min")
tuner = tune.Tuner(
    trainable_w_num_iters,
    tune_config=tune.TuneConfig(
        search_alg=hyperopt_search,
        num_samples=50,
        max_concurrent_trials=1,
    ),
)
results = tuner.fit()

# Collect per-trial results and tag them with the fixed iteration count
cur_results_df = results.get_dataframe()
cur_results_df['num_iters'] = num_iters

In [None]:
# Save the per-trial RayTune results.
# NOTE(review): the filename mentions "hessian", but this notebook uses
# Var4DBackprop — confirm the name is intended.
cur_results_df.to_csv(path_or_buf='./out/qgs/raytune_qgs_v4_hessian.csv')