In [None]:
import dabench as dab
import numpy as np
import jax
from timeit import default_timer as timer
import pandas as pd

from ray import train, tune
from hyperopt import hp
from ray.tune.search.hyperopt import HyperOptSearch

In [None]:
%%bash
# Create the directory the result CSVs are written into later in this notebook.
# -p creates parent directories as needed and does not fail if it already exists.
mkdir -p out/l96

# Define parameters

In [None]:
# --- Lorenz96 system and nature-run layout ---------------------------------
system_dim = 36      # number of Lorenz96 state variables
delta_t = 0.01       # model time step
spinup_size = 14400  # leading nature-run steps, split off and not evaluated
valid_size = 5000    # steps in the validation split (used for tuning)
test_size = 5000     # steps in the test split
nr_steps = spinup_size + valid_size + test_size  # total nature-run length

# --- Observations and error statistics -------------------------------------
# NOTE: train_backprop_4dvar receives obs_sd / the observation count as
# arguments and derives sigma_bg, sigma_obs and the seed itself, so the
# values in this group and `random_seed` are not what the tuning runs use.
obs_location_count = 18  # number of observed state locations
obs_sd = 0.5             # std. dev. of the noise added to observations
sigma_bg = 0.3           # background error std. dev. (B = sigma_bg**2 * I)
sigma_obs = 0.625        # assumed obs error std. dev. (R = sigma_obs**2 * I)

# --- DA cycling --------------------------------------------------------------
analysis_window = 0.1           # forwarded to the DA cycler's cycle() call
analysis_time_in_window = 0.05  # forwarded to the DA cycler's cycle() call

# --- Optimisation / reproducibility ----------------------------------------
random_seed = 5
num_iters = 3      # forwarded to Var4DBackprop as num_iters
n_outer_loops = 3  # not referenced by any other cell visible in this notebook

# Function definition: Backprop 4DVar

We'll need to set up and run Backprop-4DVar many times, so the function below wraps the whole pipeline (nature run, observations, forecast model, and DA cycling) into a single call.

In [None]:
def run_backprop_4dvar(system_dim, nr_steps, spinup_size, valid_size, test_size, 
                       test_run, delta_t, obs_location_count, obs_sd, sigma_bg, 
                       sigma_obs, analysis_window, analysis_time_in_window, 
                       random_seed, num_iters, learning_rate, lr_decay):
    """Run one Backprop-4DVar experiment on Lorenz96 and score it.

    The function generates a Lorenz96 nature run, picks the validation or
    test split as the evaluation segment, creates noisy observations of it,
    runs ``dab.dacycler.Var4DBackprop`` from a perturbed initial state, and
    computes the RMSE of the result against the nature run.

    Args:
        system_dim: Lorenz96 state dimension (also sizes H and B).
        nr_steps: Number of steps generated for the nature run.
        spinup_size: Size of the first split passed to
            ``split_train_valid_test``; that split is not used afterwards.
        valid_size: Size of the validation split.
        test_size: Size of the test split.
        test_run: If truthy, evaluate on the test split; otherwise on the
            validation split.
        delta_t: Time step for both the nature run and the forecast model.
        obs_location_count: Number of randomly chosen observed locations
            (number of rows of H).
        obs_sd: Standard deviation of the noise the Observer adds.
        sigma_bg: Background error std. dev.; B = sigma_bg**2 * I.
        sigma_obs: Observation error std. dev.; R = sigma_obs**2 * I. Also
            passed to ``cycle`` as ``obs_error_sd``.
        analysis_window: Forwarded to ``Var4DBackprop.cycle``.
        analysis_time_in_window: Forwarded to ``Var4DBackprop.cycle``.
        random_seed: Seeds the NumPy generator, both Lorenz96 objects and
            the Observer.
        num_iters: Forwarded to ``Var4DBackprop``.
        learning_rate: Forwarded to ``Var4DBackprop``.
        lr_decay: Forwarded to ``Var4DBackprop``.

    Returns:
        Tuple ``(out_statevec, rmse, obs_vec_l96, nr_eval, da_time)``:
        the cycler output, its RMSE against ``nr_eval.values[:-20]``, the
        observation vector, the evaluation segment of the nature run, and
        the wall-clock seconds spent building the cycler and cycling.

    NOTE(review): ``timesteps=498``, the ``[:-20]`` trim,
    ``obs_window_indices=[0,5,10]`` and ``steps_per_window=11`` are
    hard-coded rather than derived from the arguments. They look matched to
    a 5000-step evaluation segment with analysis_window / delta_t == 10 --
    confirm before calling with other sizes.
    """
    np_rng = np.random.default_rng(random_seed)
    # Presumably clears compiled-function caches so repeated runs (e.g. with
    # different system sizes) do not accumulate them -- TODO confirm intent.
    jax.clear_caches()

    ### Nature Run
    nature_run = dab.data.Lorenz96(system_dim=system_dim, delta_t=delta_t,
                                   store_as_jax=True, random_seed=random_seed)

    # First draw from np_rng: random initial state for the nature run.
    x0_initial = np_rng.normal(size=system_dim, scale=1)
    nature_run.generate(n_steps=nr_steps, x0 = x0_initial) 
    nr_spinup, nr_valid, nr_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, test_size)

    # Tune on the validation split; only test runs touch the test split.
    if not test_run:
        nr_eval = nr_valid
    else:
        nr_eval = nr_test


    ### Observations
    # Observe every 5th time step at a fixed set of randomly chosen locations.
    obs_l96 = dab.observer.Observer(
        nr_eval,
        time_indices = np.arange(0, nr_eval.time_dim, 5),
        random_location_count = obs_location_count,
        error_bias = 0.0,
        error_sd = obs_sd,
        random_seed=random_seed,
        stationary_observers=True,
        store_as_jax=True
    )
    obs_vec_l96 = obs_l96.observe()

    
    ### Forecast Model
    # Separate Lorenz96 instance so forecasting does not touch the nature run.
    model_l96 = dab.data.Lorenz96(system_dim=system_dim, delta_t=delta_t, 
                                  store_as_jax=True, random_seed=random_seed)

    class L96Model(dab.model.Model):                                                                       
        """Defines model wrapper for Lorenz96 to test forecasting."""
        def forecast(self, state_vec, n_steps):
            # Integrate from the given state and wrap the resulting values
            # in a new StateVector.
            self.model_obj.generate(x0=state_vec.values, n_steps=n_steps)
            new_vals = self.model_obj.values 

            new_vec = dab.vector.StateVector(values=new_vals, store_as_jax=True)

            return new_vec

    fc_model = L96Model(model_obj=model_l96)
    
    ### Set up DA matrices: H (observation), R (obs error), B (background error)
    # H selects the observed components: one row per observed location with a
    # single 1 in that location's column. Only the first time step's location
    # indices are used (observers were created with stationary_observers=True).
    H = np.zeros((obs_location_count, system_dim))
    H[np.arange(H.shape[0]), obs_vec_l96.location_indices[0]] = 1
    R = (sigma_obs**2)* np.identity(obs_location_count)
    B = (sigma_bg**2)*np.identity(system_dim)

    
    ### Run data assimilation
    # Timing starts here, so da_time covers cycler construction, initial
    # condition setup and the cycling itself.
    da_time_start = timer()
    
    # Prep DA object
    dc = dab.dacycler.Var4DBackprop(
        system_dim=system_dim,
        delta_t=nr_eval.delta_t,
        H=H,
        B=B,
        R=R,
        num_iters=num_iters,
        loss_growth_limit=5,
        learning_rate=learning_rate,
        lr_decay=lr_decay,
        model_obj=fc_model,
        obs_window_indices=[0,5,10],
        steps_per_window=11, # 11 rather than 10: the window includes both endpoints, indices 0 and 10
        )

    # Generate initial conditions
    # Second draw from np_rng: perturb the true initial state of the
    # evaluation segment with unit-scale Gaussian noise.
    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(system_dim,), 
                                                            scale=1)
    x0_sv = dab.vector.StateVector(
        values=x0_original,
        store_as_jax=True)
    
    # Execute
    # NOTE(review): timesteps is hard-coded, not derived from the split size.
    out_statevec = dc.cycle(
        input_state = x0_sv,
        start_time = nr_eval.times[cur_tstep],
        obs_vector = obs_vec_l96,
        analysis_window=analysis_window,
        timesteps=498,
        obs_error_sd=sigma_obs,
        analysis_time_in_window=analysis_time_in_window)
    
    da_time = timer()-da_time_start
    # The last 20 nature-run steps are dropped, presumably so the truth has
    # the same length as the cycler output -- TODO confirm.
    rmse = np.sqrt(np.mean(np.square(nr_eval.values[:-20] - out_statevec.values)))
    
    return out_statevec, rmse, obs_vec_l96, nr_eval, da_time

# Hyperparameter optimization on validation set

Use Ray Tune (with a HyperOpt search) to optimize the learning rate (alpha) and its decay factor, scoring each trial by RMSE on the validation split.

In [None]:
def train_backprop_4dvar(config, system_dim, num_obs, obs_sd):
    """Ray Tune trainable: run Backprop-4DVar once and report its RMSE.

    Args:
        config: Trial configuration; ``config['lr']`` is used as the
            learning rate and ``config['lr_decay']`` as its decay.
        system_dim: Lorenz96 state dimension. Also used as the random seed
            for the run.
        num_obs: Number of observed locations.
        obs_sd: Standard deviation of the observation noise. The background
            and assumed observation error std. devs. are derived from it as
            ``obs_sd/1.5`` and ``obs_sd*1.25``.

    The run always uses the validation split (``test_run=False``). All other
    settings are read from the notebook-level parameters. The result is
    reported to Ray Tune under the key ``'rmse'``; nothing is returned.
    """
    # Only the RMSE (second element of the 5-tuple) is needed for tuning.
    _, error_bp, _, _, _ = run_backprop_4dvar(
        system_dim=system_dim,
        nr_steps=nr_steps,
        spinup_size=spinup_size,
        valid_size=valid_size,
        test_size=test_size,
        test_run=False,
        delta_t=delta_t,
        obs_location_count=num_obs,
        obs_sd=obs_sd,
        sigma_bg=obs_sd/1.5,
        sigma_obs=obs_sd*1.25,
        analysis_window=analysis_window,
        analysis_time_in_window=analysis_time_in_window,
        random_seed=system_dim,  # seed is tied to the system size
        num_iters=num_iters,
        learning_rate=config['lr'],
        lr_decay=config['lr_decay'],
    )

    train.report({'rmse': error_bp})

### System size experiments

In [None]:
# Search space handed to HyperOptSearch.
space = dict(
    # Log-uniform learning rate: exp(u) with u ~ U(-5, 0), i.e. ~0.0067 to 1.
    lr=hp.loguniform("lr", -5, 0),
    # Learning-rate decay sampled uniformly between 0.1 and 0.99.
    lr_decay=hp.uniform("lr_decay", 0.1, 0.99),
)

In [None]:
# Tune (lr, lr_decay) once per Lorenz96 system size. Half of the state
# variables are observed and the observation noise is fixed at 0.5.
system_dim_list = [6, 20, 36, 72, 144, 256]
all_results_df_list = []

for system_dim in system_dim_list:
    tune_time_start = timer()

    # Reproducible search, seeded differently for each system size.
    hyperopt_search = HyperOptSearch(
        space,
        metric="rmse",
        mode="min",
        random_state_seed=22+system_dim,
    )

    # Bind the experiment settings; Ray Tune itself only supplies `config`.
    trainable_with_system_dim = tune.with_parameters(
        train_backprop_4dvar,
        system_dim=system_dim,
        num_obs=system_dim // 2,
        obs_sd=0.5,
    )

    tuner = tune.Tuner(
        trainable_with_system_dim,
        tune_config=tune.TuneConfig(
            search_alg=hyperopt_search,
            num_samples=50,
            max_concurrent_trials=4,
        ),
    )
    results = tuner.fit()

    # One row per trial, tagged with the system size and the wall-clock
    # time of the whole tuning run (same value on every row).
    cur_results_df = results.get_dataframe()
    cur_results_df['system_dim'] = system_dim
    tune_time = timer() - tune_time_start
    cur_results_df['total_tune_time'] = tune_time

    all_results_df_list.append(cur_results_df)

In [None]:
# Stack the per-system-size trial tables into one frame and write it to CSV.
# NOTE(review): the file name says "hessian" although this notebook tunes
# Backprop-4DVar -- confirm it cannot overwrite results from another notebook.
full_results_df = pd.concat(all_results_df_list)
full_results_df.to_csv('./out/l96/raytune_l96_hessian_v6.csv')

### Experiments varying number of observations and obs error

In [None]:
# Search space handed to HyperOptSearch for the observation experiments.
space = dict(
    # Log-uniform learning rate: exp(u) with u ~ U(-5, 0), i.e. ~0.0067 to 1.
    lr=hp.loguniform("lr", -5, 0),
    # Learning-rate decay sampled uniformly between 0.1 and 0.99.
    lr_decay=hp.uniform("lr_decay", 0.1, 0.99),
)

In [None]:
# Tune (lr, lr_decay) over a grid of observation counts and observation
# noise levels at a fixed system size.
system_dim = 36
num_obs_list = [6, 12, 18, 24, 30, 36]
obs_error_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 1.5, 2.0]
all_results_df_list = []

for obs_location_count in num_obs_list:
    for obs_i, obs_sd in enumerate(obs_error_list):
        tune_time_start = timer()

        # Search seed depends on the observation count and the index of the
        # noise level. NOTE: different grid points can share a seed, e.g.
        # (6 obs, index 6) and (12 obs, index 0) both give 34.
        hyperopt_search = HyperOptSearch(
            space,
            metric="rmse",
            mode="min",
            random_state_seed=22+obs_location_count+obs_i,
        )

        # Bind the experiment settings; Ray Tune only supplies `config`.
        trainable_with_system_dim = tune.with_parameters(
            train_backprop_4dvar,
            system_dim=system_dim,
            num_obs=obs_location_count,
            obs_sd=obs_sd,
        )

        tuner = tune.Tuner(
            trainable_with_system_dim,
            tune_config=tune.TuneConfig(
                search_alg=hyperopt_search,
                num_samples=20,
                max_concurrent_trials=4,
            ),
        )
        results = tuner.fit()

        # One row per trial, tagged with the grid point and the wall-clock
        # time of the whole tuning run (same value on every row).
        cur_results_df = results.get_dataframe()
        cur_results_df['system_dim'] = system_dim
        cur_results_df['num_obs'] = obs_location_count
        cur_results_df['obs_sd'] = obs_sd
        tune_time = timer() - tune_time_start
        cur_results_df['total_tune_time'] = tune_time

        all_results_df_list.append(cur_results_df)


In [None]:
# Stack the trial tables from every (num_obs, obs_sd) grid point into one
# frame and write it to CSV.
# NOTE(review): the file name says "hessian" although this notebook tunes
# Backprop-4DVar -- confirm it cannot overwrite results from another notebook.
full_results_df = pd.concat(all_results_df_list)
full_results_df.to_csv('./out/l96/raytune_werrors_heatmap_hessian_v5.csv')