In [None]:
from pathlib import Path
from timeit import default_timer as timer

import dabench as dab
import jax
import jaxlib
import numpy as np
import pandas as pd

from ray import train, tune
from hyperopt import hp
from ray.tune.search.hyperopt import HyperOptSearch

# Define parameters

In [None]:
# Number of model time steps in one year (4380 steps of delta_t = 7200 s = 365 days).
year_in_timesteps = 4380

# Lengths of the consecutive nature-run segments, all in time steps.
spinup_size = year_in_timesteps * 5          # 5 years of spinup
valid_size = round(year_in_timesteps / 4)    # a quarter year for validation
transient_size = year_in_timesteps * 1       # 1 year between validation and test
test_size = year_in_timesteps * 1            # 1 year for testing

In [None]:
# Total nature-run length: the four segments laid end to end.
nr_steps = sum((spinup_size, valid_size, transient_size, test_size))

# Model time step.
delta_t = 7200

# Each analysis window spans 6 model steps; the analysis is placed at its midpoint.
analysis_window = delta_t * 6
analysis_time_in_window = delta_t * 3

# Number of optimiser iterations assigned to the DA cycler (dc.num_iters).
num_iters = 3

### Function definition: Backprop 4DVar

We'll need to prep and run Backprop-4DVar many times, so this wraps it all into functions.

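Note: Because Ray Tune calls the trainable once per trial, the work is split into a prep function (run once per system size) and a run function (run once per trial), so the expensive setup is not repeated for every hyperparameter sample.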

In [None]:
def prep_backprop_4dvar(system_dim_xy, nr_steps, spinup_size, valid_size, test_size,
                        test_run, delta_t, sigma_bg_multiplier, sigma_obs_multiplier, 
                        random_seed, transient_size=None):
    """Build everything needed to run Backprop-4DVar on a PyQGJax nature run.

    Generates the nature run, splits it into spinup / validation / transient /
    test segments, draws noisy observations from the evaluation segment, wraps
    a second PyQGJax instance as the forecast model and constructs a
    ``Var4DBackprop`` cycler. The cycler's ``num_iters``, ``learning_rate``
    and ``lr_decay`` are left as ``None`` for the caller to fill in.

    Args:
        system_dim_xy: Grid size, passed as both ``nx`` and ``ny`` to PyQGJax.
        nr_steps: Number of time steps to generate for the nature run.
        spinup_size: Length of the spinup segment, in time steps.
        valid_size: Length of the validation segment, in time steps.
        test_size: Length of the test segment, in time steps.
        test_run: If truthy, evaluate on the test segment; otherwise on the
            validation segment. Also added to ``random_seed`` to seed the
            observer.
        delta_t: Model time step passed to both PyQGJax instances.
        sigma_bg_multiplier: Factor applied to the observation SD to obtain
            the background error SD.
        sigma_obs_multiplier: Factor applied to the observation SD (at the
            observed locations) to obtain the SD used to build R.
        random_seed: Seed for the NumPy generator, both PyQGJax instances and
            (offset by ``test_run``) the observer.
        transient_size: Length of the transient segment between validation
            and test, in time steps. Defaults to whatever remains of
            ``nr_steps`` after the spinup, validation and test segments.

    Returns:
        Tuple ``(dc, x0_sv, start_time, obs_vec_pyqg, nr_eval, sigma_obs)``:
        the DA cycler, the perturbed initial state vector, the first time of
        the evaluation segment, the observation vector, the evaluation segment
        itself and the observation error SD at the observed locations.

    Raises:
        ValueError: If the transient segment length would be negative.
    """
    # Derive the transient length from the arguments instead of reading a
    # notebook-level global, so the function does not depend on hidden state.
    if transient_size is None:
        transient_size = nr_steps - spinup_size - valid_size - test_size
    if transient_size < 0:
        raise ValueError(
            'transient_size must be non-negative, got {}'.format(transient_size))

    np_rng = np.random.default_rng(random_seed)
    # NOTE(review): jax.clear_backends is deprecated in recent JAX releases —
    # confirm against the pinned JAX version.
    jax.clear_backends()

    
    ### Nature Run
    nature_run = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t, 
                                  store_as_jax=True, random_seed=random_seed)

    nature_run.generate(n_steps=nr_steps) 
    # Two-stage split: first peel off spinup and validation, then split the
    # remainder into the transient and test segments.
    nr_spinup, nr_valid, nr_transient_and_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, transient_size + test_size)
    nr_transient, nr_test, _ = nr_transient_and_test.split_train_valid_test(
        transient_size, test_size, 0)

    if not test_run:
        nr_eval = nr_valid
    else:
        nr_eval = nr_test
        
        
    ### Observations
    # Observe half of the state variables.
    obs_location_count = round(nature_run.system_dim/2)

    # First we need to calculate the per-variable SD for the PyQG model,
    # estimated over the spinup segment.
    obs_sd_scale = 0.1
    per_variable_sd = np.std(nr_spinup.values, axis=0)
    obs_sd = obs_sd_scale*per_variable_sd

    obs_pyqg = dab.observer.Observer(
        nr_eval,
        # Observe every 3rd time step of the evaluation segment.
        time_indices = np.arange(0, nr_eval.time_dim, 3),
        random_location_count = obs_location_count,
        error_bias = 0.0,
        error_sd = obs_sd,
        # test_run adds 0 or 1, so validation and test runs use different seeds.
        random_seed=random_seed+test_run,
        stationary_observers=True,
        store_as_jax=True
    )

    obs_vec_pyqg = obs_pyqg.observe()

    
    ### Forecast Model
    # Pass delta_t explicitly so the forecast model steps with the same time
    # step as the nature run rather than relying on the library default.
    model_pyqg = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    class PyQGModel(dab.model.Model):                                                                       
        """Defines model wrapper for forecasting."""
        def forecast(self, state_vec, n_steps):
            """Run the model n_steps from state_vec and return a StateVector."""
            gridded_values = state_vec.values.reshape(self.model_obj.original_dim)
            self.model_obj.generate(x0=gridded_values, n_steps=n_steps)
            new_vals = self.model_obj.values

            new_vec = dab.vector.StateVector(values=new_vals, store_as_jax=True)

            return new_vec

        def _forecast_x0(self, x0, n_steps):
            """Run the model n_steps from a raw x0 array and return raw values."""
            self.model_obj.generate(x0=x0.reshape(self.model_obj.original_dim),
                                    n_steps=n_steps)
            return self.model_obj.values
        
    fc_model = PyQGModel(model_obj=model_pyqg)

    
    ### Set up DA matrices: H (observation), R (obs error), B (background error)
    # Observers are stationary, so the first entry of location_indices is used
    # for every observation time.
    sigma_obs=sigma_obs_multiplier*obs_sd[obs_vec_pyqg.location_indices[0]]
    sigma_bg = sigma_bg_multiplier*obs_sd
    # H selects the observed state variables: one row per observation, with a
    # single 1 in the column of the observed variable.
    H = np.zeros((obs_location_count, nature_run.system_dim))
    H[np.arange(H.shape[0]), obs_vec_pyqg.location_indices[0]] = 1
    # Diagonal covariances: the identity masks the broadcast product so only
    # the per-variable variances on the diagonal remain.
    R = (sigma_obs**2) * np.identity(obs_location_count)
    B = (sigma_bg**2)*np.identity(nature_run.system_dim)
    
    
    ### Prep DA    
    dc = dab.dacycler.Var4DBackprop(
        system_dim=nature_run.system_dim,
        delta_t=nr_eval.delta_t,
        H=H,
        B=B,
        R=R,
        # These will be set later in run_backprop_4dvar
        num_iters=None, 
        learning_rate=None,
        lr_decay=None,
        model_obj=fc_model,
        obs_window_indices=[0,3,6],
        steps_per_window=7, # 7 instead of 6 because inclusive of 0 and 6
        )
    
    # Generate initial conditions: the true first state of the evaluation
    # segment plus Gaussian noise with the background error SD.
    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(nature_run.system_dim,),
                                                            scale=sigma_bg)
    x0_sv = dab.vector.StateVector(
        values=x0_original,
        store_as_jax=True)
    start_time = nr_eval.times[cur_tstep]
    
    # Return necessary objects for running DA with different LR/LR Decay
    return dc, x0_sv, start_time, obs_vec_pyqg, nr_eval, sigma_obs

In [None]:
def run_backprop_4dvar(config, dc, x0_sv, start_time, obs_vec, nr_eval,
                       sigma_obs, analysis_window, analysis_time_in_window,
                       num_iters):
    """Run one Backprop-4DVar trial and report its RMSE to Ray Tune.

    Sets the optimiser hyperparameters on the prepared cycler, cycles over the
    evaluation segment and reports the RMSE between the analysis and the
    nature run under the key ``'rmse'``. Trials that raise an XLA runtime
    error, or whose RMSE is NaN/inf, report a large penalty value instead so
    the searcher can still rank them.

    Args:
        config: Trial configuration with keys ``'lr'`` and ``'lr_decay'``.
        dc: DA cycler; its ``learning_rate``, ``lr_decay`` and ``num_iters``
            attributes are overwritten here.
        x0_sv: Initial state vector passed to ``dc.cycle``.
        start_time: Start time passed to ``dc.cycle``.
        obs_vec: Observation vector passed to ``dc.cycle``.
        nr_eval: Nature-run segment used as the truth for the RMSE.
        sigma_obs: Observation error SD passed to ``dc.cycle``.
        analysis_window: Analysis window length passed to ``dc.cycle``.
        analysis_time_in_window: Analysis time within each window, passed to
            ``dc.cycle``.
        num_iters: Number of optimiser iterations assigned to ``dc``.

    Returns:
        None. The result is delivered through ``train.report``.
    """
    # Penalty reported for failed or diverged trials.
    failed_rmse = 999999

    dc.learning_rate = config['lr']
    dc.lr_decay = config['lr_decay']
    dc.num_iters = num_iters
    
    try: 
        out_statevec = dc.cycle(
            input_state = x0_sv,
            start_time = start_time,
            obs_vector = obs_vec,
            analysis_window=analysis_window,
            # presumably 6 is the number of model steps per analysis window and
            # the -2 leaves a margin at the end of the segment — TODO confirm
            timesteps=int(nr_eval.time_dim/6) - 2,
            obs_error_sd=sigma_obs,
            analysis_time_in_window=analysis_time_in_window)

        # Compare only as many truth steps as the cycler produced. The float()
        # conversion stays inside the try so any error raised while the values
        # are materialised is still caught below.
        rmse = float(np.sqrt(np.mean(np.square(
            nr_eval.values[:out_statevec.values.shape[0]] - out_statevec.values
        ))))
        
    # Catch problem with exploding gradients
    except jaxlib.xla_extension.XlaRuntimeError:
        rmse = failed_rmse

    # A diverged run can also finish without raising and yield NaN/inf, which
    # a minimising searcher cannot rank; treat it like a failed trial.
    if not np.isfinite(rmse):
        rmse = failed_rmse

    train.report({'rmse':rmse})
    
    return 

# Hyperparameter optimization on validation set

Using Ray Tune (with HyperOpt search) to optimize the learning rate (alpha, `lr`) and the learning-rate decay (`lr_decay`).

In [None]:
# HyperOpt search space, keyed by the names read from `config` in
# run_backprop_4dvar.
space = dict(
    # Log-uniform learning rate: exp(u) with u drawn uniformly from [-5, 0].
    lr=hp.loguniform("lr", -5, 0),
    # Uniform learning-rate decay in [0.1, 0.99].
    lr_decay=hp.uniform("lr_decay", 0.1, 0.99),
)



In [None]:
%%time

# One results DataFrame per grid size is collected here and concatenated in a
# later cell.
all_results_df_list = []
system_dim_xy_list = [16, 24, 32]

for system_dim_xy in system_dim_xy_list:
    # Timer starts before prep, so the recorded time covers prep + tuning.
    tune_time_start = timer()

    # The grid size doubles as the random seed for this configuration.
    random_seed = system_dim_xy

    # Run prep once per grid size; test_run=False selects the validation set.
    run_dict = {
        'system_dim_xy': system_dim_xy,
        'nr_steps': nr_steps,
        'spinup_size': spinup_size,
        'valid_size': valid_size,
        'test_size': test_size,
        'test_run': False,
        'delta_t': delta_t,
        'sigma_bg_multiplier': 0.5,
        'sigma_obs_multiplier': 1.25,
        'random_seed': random_seed,
    }
    (dc, x0_sv, start_time,
     obs_vec, nr_eval, sigma_obs) = prep_backprop_4dvar(**run_dict)

    print('Starting... {} system dim'.format(system_dim_xy))

    # Bind the prepared objects to the trainable so each trial only receives
    # its hyperparameter config.
    trainable_with_system_dim = tune.with_parameters(
        run_backprop_4dvar,
        dc=dc,
        x0_sv=x0_sv,
        start_time=start_time,
        obs_vec=obs_vec,
        nr_eval=nr_eval,
        sigma_obs=sigma_obs,
        analysis_window=analysis_window,
        analysis_time_in_window=analysis_time_in_window,
        num_iters=num_iters,
    )

    # Searcher minimises the reported 'rmse'; seeded per grid size.
    hyperopt_search = HyperOptSearch(
        space,
        metric="rmse",
        mode="min",
        random_state_seed=22+system_dim_xy,
    )

    # Run tuner: 50 samples, at most 4 trials at a time.
    tuner = tune.Tuner(
        trainable_with_system_dim,
        tune_config=tune.TuneConfig(
            num_samples=50,
            max_concurrent_trials=4,
            search_alg=hyperopt_search,
        ),
    )
    results = tuner.fit()

    # Tag the per-trial results with the grid size and the elapsed wall time.
    cur_results_df = results.get_dataframe()
    cur_results_df['system_dim_xy'] = system_dim_xy

    tune_time = timer() - tune_time_start
    cur_results_df['total_tune_time'] = tune_time

    all_results_df_list.append(cur_results_df)

In [None]:
# Stack the per-grid-size results into a single table.
full_results_df = pd.concat(all_results_df_list)

# Create the output directory first: to_csv does not create missing parent
# directories, and failing here would lose the results of the tuning run.
results_csv_path = './out/pyqg_jax/pyqg_jax_raytune_sgdopt_werrors_v15_hessian_approx_3epochs_system_dims_16_24_32_smallsearch.csv'
Path(results_csv_path).parent.mkdir(parents=True, exist_ok=True)
full_results_df.to_csv(results_csv_path)