# 4. Example: 4DVar Backprop DA on Lorenz 96

In [24]:
import dabench as dab
import numpy as np
import matplotlib.pyplot as plt
# The JAX BiCGSTAB solver is exported as `bicgstab`.
from jax.scipy.sparse.linalg import bicgstab
from scipy.linalg import inv
from scipy.sparse.linalg import aslinearoperator
from copy import deepcopy
import jax.numpy as jnp
from jax import value_and_grad
from IPython.lib.deepreload import reload
import pickle

from ray import train, tune
from hyperopt import hp
from ray.tune.search.hyperopt import HyperOptSearch

In [2]:
# Single seed shared by the data generators, the observer, and NumPy's RNG.
random_seed = 50
# NumPy generator used for the initial-condition perturbation.
np_rng = np.random.default_rng(random_seed)

## A. Create Nature Run

In [30]:
# --- Nature-run segment lengths (number of generated steps per segment) ---
spinup_size = 10_000
val_size = 2_000
transient_size = 1_000
test_size = 2_000

# --- System size and observation network ---
system_dim = 36
obs_location_count = round(system_dim / 2)  # observe half of the variables

# --- Error standard deviations ---
obs_sd = 0.5       # noise added by the Observer (error_sd)
sigma_bg = 0.3     # background error, used to build B
sigma_obs = 0.625  # observation error assumed by the DA, used to build R

# --- Time stepping and analysis window ---
delta_t = 0.01
analysis_window = 0.1
analysis_time_in_window = 0.05

In [19]:
# Build the Lorenz 96 generator and integrate one trajectory long enough to
# cover every segment: spin-up, validation, transient, and test.
nature_run = dab.data.Lorenz96(
    system_dim=system_dim,
    delta_t=delta_t,
    store_as_jax=True,
    random_seed=random_seed,
)
nature_run.generate(
    n_steps=sum((spinup_size, val_size, transient_size, test_size)))

# First split: spin-up | validation | (transient + test).
nr_spinup, nr_valid, nr_transient_and_test = (
    nature_run.split_train_valid_test(
        spinup_size, val_size, transient_size + test_size))

# Second split: separate the transient from the test segment; the third
# piece is requested with size 0 and discarded.
nr_transient, nr_test, _ = (
    nr_transient_and_test.split_train_valid_test(
        transient_size, test_size, 0))


## B. Define Forecast Model

In [20]:
# Lorenz 96 generator that the forecast model drives; configured the same way
# as the nature run.
model_l96 = dab.data.Lorenz96(
    system_dim=system_dim,
    delta_t=delta_t,
    store_as_jax=True,
    random_seed=random_seed,
)


class L96Model(dab.model.Model):
    """Forecast-model wrapper around a Lorenz 96 data generator."""

    def forecast(self, state_vec, n_steps):
        """Integrate the wrapped generator forward from ``state_vec``.

        Args:
            state_vec: State vector whose ``values`` are used as the initial
                condition ``x0``.
            n_steps: Number of steps passed to the generator's ``generate``.

        Returns:
            A ``dab.vector.StateVector`` holding the generated values.
        """
        # generate() stores the trajectory on the generator object itself.
        self.model_obj.generate(n_steps=n_steps, x0=state_vec.values)
        return dab.vector.StateVector(
            values=self.model_obj.values, store_as_jax=True)


fc_model = L96Model(model_obj=model_l96)

## C. Create observer and DA matrices for validation set

In [21]:

# Define the observer: sample the validation nature run every 5th time step
# at `obs_location_count` randomly chosen locations, adding noise with
# standard deviation `obs_sd`.
obs_l96 = dab.observer.Observer(
    nr_valid,
    time_indices=np.arange(0, nr_valid.time_dim, 5),
    random_location_count=obs_location_count,
    error_bias=0.0,
    error_sd=obs_sd,
    random_seed=random_seed,
    stationary_observers=True,
    store_as_jax=True,
)

# Make the observations. This is done exactly once: a second observe() call
# would only repeat the work and overwrite the first result.
obs_vec_l96 = obs_l96.observe()

### Set up DA matrices: H (observation), R (obs error), B (background error)
# H selects the observed state variables: one row per observed location with
# a single 1 in that location's column.
H = np.zeros((obs_location_count, system_dim))
H[np.arange(H.shape[0]), obs_vec_l96.location_indices[0]] = 1
# R and B are diagonal, built from the assumed error variances.
R = (sigma_obs**2) * np.identity(obs_location_count)
B = (sigma_bg**2) * np.identity(system_dim)

## D. RayTune to find learning rate and learning rate decay

In [39]:
def run_4dvar_backprop(lr_config):
    """Run 4D-Var backprop DA on the validation set and report its RMSE.

    Used as the Ray Tune trainable: builds a ``Var4DBackprop`` cycler with the
    sampled hyperparameters, cycles it over the validation observations from
    a perturbed initial state, and reports the RMSE against the nature run.

    Args:
        lr_config: Dict with keys ``'learning_rate'`` and ``'lr_decay'``,
            both forwarded to ``dab.dacycler.Var4DBackprop``.

    Returns:
        None. The metric is reported via ``train.report`` under ``'rmse'``.
    """
    learning_rate = lr_config['learning_rate']
    lr_decay = lr_config['lr_decay']

    # Prep DA object
    dc = dab.dacycler.Var4DBackprop(
        system_dim=system_dim,
        delta_t=nr_valid.delta_t,
        H=H,
        B=B,
        R=R,
        learning_rate=learning_rate,
        lr_decay=lr_decay,
        model_obj=fc_model,
        obs_window_indices=[0, 5, 10],
        # 11 instead of 10 because the window includes both end points,
        # i.e. time indices 0 through 10.
        steps_per_window=11,
        )

    # Generate initial conditions: true validation state plus unit-variance
    # Gaussian noise.
    cur_tstep = 0
    x0_original = nr_valid.values[cur_tstep] + np_rng.normal(
        size=(system_dim,), scale=1)
    x0_sv = dab.vector.StateVector(
        values=x0_original,
        store_as_jax=True)

    # Execute
    n_cycles = 198
    out_statevec = dc.cycle(
        input_state=x0_sv,
        start_time=nr_valid.times[cur_tstep],
        obs_vector=obs_vec_l96,
        analysis_window=analysis_window,
        timesteps=n_cycles,
        obs_error_sd=sigma_obs,
        analysis_time_in_window=analysis_time_in_window)

    # Compare against the matching leading part of the nature run. The slice
    # length comes from the analysis itself rather than a hard-coded offset,
    # so it stays aligned if the cycle count or validation length changes.
    n_analysis_steps = out_statevec.values.shape[0]
    truth = nr_valid.values[:n_analysis_steps]
    rmse = np.sqrt(np.mean(np.square(truth - out_statevec.values)))

    # Report a plain Python float so the metric serializes cleanly.
    train.report({'rmse': float(rmse)})

In [40]:
# Define search space: learning rate is sampled log-uniformly over
# [exp(-10), 1]; learning-rate decay uniformly over [0.1, 0.99].
space = {
    "learning_rate": hp.loguniform("lr", -10, 0),
    "lr_decay": hp.uniform("lr_decay", 0.1, 0.99),
}

# Seed the searcher so the sequence of suggested hyperparameters is
# reproducible, consistent with the seeding used elsewhere in this notebook.
hyperopt_search = HyperOptSearch(
    space,
    metric="rmse",
    mode="min",
    random_state_seed=random_seed,
)
tuner = tune.Tuner(
    run_4dvar_backprop,
    tune_config=tune.TuneConfig(
        num_samples=50,
        max_concurrent_trials=4,
        search_alg=hyperopt_search,
    ),
)

results = tuner.fit()

0,1
Current time:,2024-06-21 15:28:59
Running for:,00:03:05.33
Memory:,7.4/30.2 GiB

Trial name,status,loc,learning_rate,lr_decay,iter,total time (s),rmse
run_4dvar_backprop_673c7dc3,TERMINATED,192.168.1.97:226014,0.745799,0.809063,1,6.31994,0.415743
run_4dvar_backprop_606ac3ff,TERMINATED,192.168.1.97:226112,0.00117245,0.823186,1,6.49358,4.78157
run_4dvar_backprop_6c2711ba,TERMINATED,192.168.1.97:226242,0.147453,0.969142,1,6.5042,0.376743
run_4dvar_backprop_cc693dc6,TERMINATED,192.168.1.97:226390,0.230598,0.861168,1,6.73677,0.35024
run_4dvar_backprop_870bafed,TERMINATED,192.168.1.97:226569,5.07984e-05,0.398777,1,6.57032,4.88589
run_4dvar_backprop_ef30a7b1,TERMINATED,192.168.1.97:226695,0.000614408,0.181215,1,6.72395,4.71815
run_4dvar_backprop_97a354ea,TERMINATED,192.168.1.97:226843,8.37834e-05,0.945128,1,6.46317,4.92967
run_4dvar_backprop_73d6a704,TERMINATED,192.168.1.97:227001,0.00189957,0.169936,1,6.47952,4.6311
run_4dvar_backprop_25d4e4cc,TERMINATED,192.168.1.97:227143,0.289999,0.330016,1,6.65484,0.352524
run_4dvar_backprop_b195ee6e,TERMINATED,192.168.1.97:227285,0.478601,0.613413,1,6.60309,0.342282


2024-06-21 15:28:59,244	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/ksolvik/ray_results/run_4dvar_backprop_2024-06-21_15-25-53' in 0.0108s.
2024-06-21 15:28:59,256	INFO tune.py:1041 -- Total run time: 185.35 seconds (185.32 seconds for the tuning loop).
