In [1]:
import dabench as dab
import numpy as np
import jax
from timeit import default_timer as timer
import pandas as pd

from ray import train, tune
from hyperopt import hp
from ray.tune.search.hyperopt import HyperOptSearch

In [2]:
%%bash
# Create the output directory that the CSV exports at the end of each sweep write into.
mkdir -p out/l96

# Define parameters

In [7]:
# --- Lorenz-96 nature run and data split ---
system_dim = 36          # default state dimension (sweeps below pass their own)
spinup_size = 14400      # length of the spin-up segment (not evaluated)
valid_size = 5000        # segment used for hyperparameter tuning
test_size = 5000         # held-out segment for final evaluation
nr_steps = spinup_size + valid_size + test_size
delta_t = 0.01           # model integration time step

# --- Observations and error statistics ---
obs_sd = 0.5             # std. dev. of noise added to synthetic observations
# NOTE: train_backprop_4dvar derives these from obs_sd instead
# (obs_sd / 1.5 and obs_sd * 1.25), so they only matter for direct calls.
sigma_bg = 0.3           # background-error std. dev. (builds B)
sigma_obs = 0.625        # assumed observation-error std. dev. (builds R)
obs_location_count = 18  # number of observed state components

# --- DA cycling ---
analysis_window = 0.1    # model time per analysis window (10 steps at delta_t)
analysis_time_in_window = 0.05
random_seed = 5          # NOTE: train_backprop_4dvar seeds with system_dim instead
num_iters = 3            # optimizer iterations passed to Var4DBackprop
n_outer_loops = 3        # not used by the cells visible here

# Function definition: Backprop 4DVar

We need to set up and run Backprop-4DVar many times, so this function wraps the whole pipeline into a single call: nature run, observations, forecast model, DA cycling, and RMSE evaluation.

In [8]:
def run_backprop_4dvar(system_dim, nr_steps, spinup_size, valid_size, test_size, 
                       test_run, delta_t, obs_location_count, obs_sd, sigma_bg, 
                       sigma_obs, analysis_window, analysis_time_in_window, 
                       random_seed, num_iters, learning_rate, lr_decay,
                       obs_spacing=5, n_cycles=None):
    """Generate a Lorenz-96 experiment and run Backprop-4DVar on it.

    Builds a nature run, draws noisy observations from the validation (or
    test) segment, wraps a second Lorenz-96 instance as the forecast model,
    and cycles ``dab.dacycler.Var4DBackprop`` over the segment.

    Args:
        system_dim: Lorenz-96 state dimension.
        nr_steps: Total nature-run length (spin-up + valid + test).
        spinup_size, valid_size, test_size: Segment lengths passed to
            ``split_train_valid_test``.
        test_run: If True evaluate on the test segment, else on validation.
        delta_t: Model integration time step.
        obs_location_count: Number of observed state components
            (``random_location_count``; observers are stationary).
        obs_sd: Std. dev. of the noise added to the synthetic observations.
        sigma_bg: Background-error std. dev. used to build ``B``.
        sigma_obs: Observation-error std. dev. used to build ``R`` and passed
            to the cycler.
        analysis_window: Analysis-window length in model time; must be a
            positive whole number of ``delta_t`` steps.
        analysis_time_in_window: Passed through to ``dc.cycle``.
        random_seed: Seed for the nature run, observer, forecast model and
            the initial-condition perturbations.
        num_iters, learning_rate, lr_decay: Backprop-4DVar optimizer settings.
        obs_spacing: Model steps between observation times (default 5, the
            previously hard-coded value). Must evenly divide the number of
            steps per analysis window.
        n_cycles: Number of analysis cycles. Defaults to
            ``eval_steps // window_steps - 2``, which reproduces the former
            hard-coded 498 for a 5000-step segment with 10-step windows.

    Returns:
        Tuple ``(out_statevec, rmse, obs_vec_l96, nr_eval, da_time)``.
        ``rmse`` is computed against the truth over exactly the span covered
        by the analysis output. ``da_time`` is wall-clock seconds spent setting
        up and running the DA cycle.

    Raises:
        ValueError: If ``analysis_window`` is not a whole number of steps, if
            ``obs_spacing`` does not divide the window, or if the analysis
            output runs past the end of the evaluation segment.
    """
    np_rng = np.random.default_rng(random_seed)
    # Presumably clears JAX compilation caches so repeated calls (one per
    # Ray Tune trial) do not accumulate compiled functions.
    jax.clear_caches()

    # Window geometry, derived here instead of hard-coding values that only
    # agree for one particular analysis_window / delta_t / segment length.
    window_steps = int(round(analysis_window / delta_t))
    if window_steps < 1 or not np.isclose(window_steps * delta_t, analysis_window):
        raise ValueError(
            f"analysis_window={analysis_window} must be a positive whole "
            f"number of delta_t={delta_t} steps")
    if obs_spacing < 1 or window_steps % obs_spacing != 0:
        # Otherwise observation offsets would shift from window to window and
        # a single fixed obs_window_indices list would be wrong.
        raise ValueError(
            f"obs_spacing={obs_spacing} must evenly divide the "
            f"{window_steps} steps per analysis window")

    ### Nature Run
    nature_run = dab.data.Lorenz96(system_dim=system_dim, delta_t=delta_t,
                                   store_as_jax=True, random_seed=random_seed)

    x0_initial = np_rng.normal(size=system_dim, scale=1)
    nature_run.generate(n_steps=nr_steps, x0=x0_initial)
    # The spin-up segment is not used below.
    _, nr_valid, nr_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, test_size)
    nr_eval = nr_test if test_run else nr_valid

    ### Observations: stationary locations, one observation time every obs_spacing steps
    obs_l96 = dab.observer.Observer(
        nr_eval,
        time_indices=np.arange(0, nr_eval.time_dim, obs_spacing),
        random_location_count=obs_location_count,
        error_bias=0.0,
        error_sd=obs_sd,
        random_seed=random_seed,
        stationary_observers=True,
        store_as_jax=True
    )
    obs_vec_l96 = obs_l96.observe()

    ### Forecast Model
    model_l96 = dab.data.Lorenz96(system_dim=system_dim, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    class L96Model(dab.model.Model):
        """Defines model wrapper for Lorenz96 to test forecasting."""
        def forecast(self, state_vec, n_steps):
            # Integrate the wrapped model from state_vec and return its
            # generated values as a StateVector.
            self.model_obj.generate(x0=state_vec.values, n_steps=n_steps)
            new_vals = self.model_obj.values
            return dab.vector.StateVector(values=new_vals, store_as_jax=True)

    fc_model = L96Model(model_obj=model_l96)

    ### Set up DA matrices: H (observation), R (obs error), B (background error)
    # Observers are stationary, so the location indices of the first
    # observation time define H for every time.
    H = np.zeros((obs_location_count, system_dim))
    H[np.arange(H.shape[0]), obs_vec_l96.location_indices[0]] = 1
    R = (sigma_obs**2) * np.identity(obs_location_count)
    B = (sigma_bg**2) * np.identity(system_dim)

    if n_cycles is None:
        # Leave two windows of margin at the end of the segment: 498 cycles
        # for a 5000-step segment with 10-step windows (the former value).
        n_cycles = nr_eval.time_dim // window_steps - 2

    ### Run data assimilation
    da_time_start = timer()

    # Prep DA object
    dc = dab.dacycler.Var4DBackprop(
        system_dim=system_dim,
        delta_t=nr_eval.delta_t,
        H=H,
        B=B,
        R=R,
        num_iters=num_iters,
        loss_growth_limit=5,
        learning_rate=learning_rate,
        lr_decay=lr_decay,
        model_obj=fc_model,
        # In-window observation step indices, e.g. [0, 5, 10] for 10-step windows
        obs_window_indices=list(range(0, window_steps + 1, obs_spacing)),
        # window_steps + 1 because the window is inclusive of both step 0 and step window_steps
        steps_per_window=window_steps + 1,
        )

    # Initial condition: truth at the start of the segment plus N(0, 1) noise
    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(system_dim,),
                                                            scale=1)
    x0_sv = dab.vector.StateVector(
        values=x0_original,
        store_as_jax=True)

    # Execute
    out_statevec = dc.cycle(
        input_state=x0_sv,
        start_time=nr_eval.times[cur_tstep],
        obs_vector=obs_vec_l96,
        analysis_window=analysis_window,
        timesteps=n_cycles,
        obs_error_sd=sigma_obs,
        analysis_time_in_window=analysis_time_in_window)

    da_time = timer() - da_time_start

    # Score against the truth over exactly the span the analysis covers. The
    # old `[:-20]` slice only matched 498 ten-step cycles on 5000 steps.
    analysis_vals = out_statevec.values
    n_out = analysis_vals.shape[0]
    truth_vals = nr_eval.values[cur_tstep:cur_tstep + n_out]
    if truth_vals.shape != analysis_vals.shape:
        raise ValueError(
            f"analysis output shape {analysis_vals.shape} does not match "
            f"truth shape {truth_vals.shape}; is n_cycles too large?")
    rmse = np.sqrt(np.mean(np.square(truth_vals - analysis_vals)))

    return out_statevec, rmse, obs_vec_l96, nr_eval, da_time

# Hyperparameter optimization on validation set

We use Ray Tune with HyperOpt search to tune the learning rate (alpha) and its decay rate, minimizing analysis RMSE on the validation segment.

In [9]:
def train_backprop_4dvar(config, system_dim, num_obs, obs_sd):
    """Ray Tune trainable that scores one Backprop-4DVar setting.

    Args:
        config: Ray Tune sample containing ``'lr'`` (learning rate) and
            ``'lr_decay'``.
        system_dim: Lorenz-96 state dimension. It also serves as the random
            seed, so every trial with the same ``system_dim`` shares a seed.
        num_obs: Number of observed state components.
        obs_sd: Observation-noise std. dev. The assumed background and
            observation error std. devs. are derived from it as
            ``obs_sd / 1.5`` and ``obs_sd * 1.25``.

    Always evaluates on the validation segment (``test_run=False``). The
    remaining settings come from the notebook's config-cell globals: segment
    sizes, ``delta_t``, analysis window, and ``num_iters``. Reports
    ``{'rmse': ...}`` back to Ray Tune.
    """
    _, rmse_val, _, _, _ = run_backprop_4dvar(
        system_dim=system_dim,
        nr_steps=nr_steps,
        spinup_size=spinup_size,
        valid_size=valid_size,
        test_size=test_size,
        test_run=False,
        delta_t=delta_t,
        obs_location_count=num_obs,
        obs_sd=obs_sd,
        sigma_bg=obs_sd / 1.5,
        sigma_obs=obs_sd * 1.25,
        analysis_window=analysis_window,
        analysis_time_in_window=analysis_time_in_window,
        random_seed=system_dim,
        num_iters=num_iters,
        learning_rate=config['lr'],
        lr_decay=config['lr_decay'],
    )

    train.report({'rmse': rmse_val})

### System size experiments

In [6]:
# HyperOpt search space for the system-size sweep:
#   lr       ~ log-uniform on [e^-10, e^0] (about 4.5e-5 to 1)
#   lr_decay ~ uniform on [0.1, 0.99]
space = dict(
    lr=hp.loguniform("lr", -10, 0),
    lr_decay=hp.uniform("lr_decay", 0.1, 0.99),
)

In [7]:
all_results_df_list = []
system_dim_list = [6, 20, 36, 72, 144, 256]

# Iterate with a distinct name so the config-cell global `system_dim` is not
# left overwritten by the last (or interrupted) value of this sweep.
for cur_system_dim in system_dim_list:
    tune_time_start = timer()

    trainable_with_system_dim = tune.with_parameters(
        train_backprop_4dvar,
        system_dim=cur_system_dim,
        num_obs=cur_system_dim // 2,  # observe half of the state components
        obs_sd=0.5,
    )

    # Not running in parallel (16 CPUs per trial) to get a fairer time
    # comparison against regular 4DVar.
    trainable_with_resources = tune.with_resources(trainable_with_system_dim, {"cpu": 16})
    hyperopt_search = HyperOptSearch(space, metric="rmse", mode="min")
    tuner = tune.Tuner(
        trainable_with_resources,
        tune_config=tune.TuneConfig(
            num_samples=50,
            search_alg=hyperopt_search,
        ),
    )

    results = tuner.fit()
    # Time the tuning itself, not the DataFrame post-processing below.
    tune_time = timer() - tune_time_start

    cur_results_df = results.get_dataframe()
    cur_results_df['system_dim'] = cur_system_dim
    cur_results_df['total_tune_time'] = tune_time

    all_results_df_list.append(cur_results_df)

0,1
Current time:,2024-06-11 14:55:43
Running for:,00:06:09.40
Memory:,14.3/30.2 GiB

Trial name,status,loc,lr,lr_decay,iter,total time (s),rmse
train_backprop_4dvar_d81d8bcc,RUNNING,192.168.0.249:176836,0.135031,0.437126,,,
train_backprop_4dvar_9febde54,PENDING,,0.0038607,0.393393,,,
train_backprop_4dvar_f83c1671,TERMINATED,192.168.0.249:172253,6.41662e-05,0.711967,1.0,131.809,5.25383
train_backprop_4dvar_ac16fc1c,TERMINATED,192.168.0.249:174140,0.00155875,0.17249,1.0,17.1571,5.19851
train_backprop_4dvar_2b64d62c,TERMINATED,192.168.0.249:174604,0.00120958,0.890643,1.0,21.3363,5.23255
train_backprop_4dvar_fee5a340,TERMINATED,192.168.0.249:175278,0.000101951,0.910586,1.0,19.4053,5.21081
train_backprop_4dvar_da844860,TERMINATED,192.168.0.249:175610,0.00156493,0.418865,1.0,19.803,5.40696
train_backprop_4dvar_62c94ac2,TERMINATED,192.168.0.249:175908,0.140245,0.854714,1.0,17.6376,0.43664
train_backprop_4dvar_2b87265a,TERMINATED,192.168.0.249:176211,0.0170713,0.664552,1.0,17.6616,4.00711
train_backprop_4dvar_f493dac4,TERMINATED,192.168.0.249:176521,0.927751,0.468668,1.0,19.1296,1.07684




[36m(train_backprop_4dvar pid=172253)[0m [[ 28.07867463  28.07412639  28.07088946]
[36m(train_backprop_4dvar pid=172253)[0m  [ 71.57125221  71.550212    71.53523923]
[36m(train_backprop_4dvar pid=172253)[0m  [151.74321096 151.70264896 151.67378378]
[36m(train_backprop_4dvar pid=172253)[0m  ...
[36m(train_backprop_4dvar pid=172253)[0m  [273.30950889 273.26930963 273.24069862]
[36m(train_backprop_4dvar pid=172253)[0m  [208.25943168 208.22513458 208.20072482]
[36m(train_backprop_4dvar pid=172253)[0m  [183.21065004 183.17359444 183.14722231]]




[36m(train_backprop_4dvar pid=174140)[0m  [150.13942426 149.16726516 149.00121814]
[36m(train_backprop_4dvar pid=174140)[0m  [482.5793027  480.79438437 480.48859941]
[36m(train_backprop_4dvar pid=174140)[0m  [479.4559678  477.39235305 477.03900635]
[36m(train_backprop_4dvar pid=174140)[0m  [677.93424813 675.25120966 674.79169314]]
[36m(train_backprop_4dvar pid=174140)[0m [[ 28.07867463  27.96845592  27.94959719]
[36m(train_backprop_4dvar pid=174140)[0m  [ 71.27030986  70.76291536  70.67626381]
[36m(train_backprop_4dvar pid=174140)[0m  ...




[36m(train_backprop_4dvar pid=174604)[0m [[ 28.07867463  27.99309721  27.91746117]
[36m(train_backprop_4dvar pid=174604)[0m  [ 71.03377489  70.64109965  70.29466798]
[36m(train_backprop_4dvar pid=174604)[0m  [148.88640619 148.13773051 147.47713208]
[36m(train_backprop_4dvar pid=174604)[0m  ...
[36m(train_backprop_4dvar pid=174604)[0m  [106.16251928 105.83383137 105.54309229]
[36m(train_backprop_4dvar pid=174604)[0m  [178.91854114 178.35934068 177.86469654]
[36m(train_backprop_4dvar pid=174604)[0m  [197.95393857 197.3394204  196.7958674 ]]




[36m(train_backprop_4dvar pid=175278)[0m  [151.61585216 151.55146328 151.49287799]
[36m(train_backprop_4dvar pid=175278)[0m  [898.62833594 898.37422216 898.14297045]
[36m(train_backprop_4dvar pid=175278)[0m  [783.4940789  783.28499972 783.09472328]
[36m(train_backprop_4dvar pid=175278)[0m  [916.38875952 916.1667505  915.96470695]]
[36m(train_backprop_4dvar pid=175278)[0m [[ 28.07867463  28.07144858  28.06487292]
[36m(train_backprop_4dvar pid=175278)[0m  [ 71.54741615  71.51400017  71.48359644]
[36m(train_backprop_4dvar pid=175278)[0m  ...




[36m(train_backprop_4dvar pid=175610)[0m [[ 28.07867463  27.96802021  27.92207406]
[36m(train_backprop_4dvar pid=175610)[0m  [ 71.16309967  70.65448307  70.44372868]
[36m(train_backprop_4dvar pid=175610)[0m  [149.5707017  148.59835972 148.19538909]
[36m(train_backprop_4dvar pid=175610)[0m  ...
[36m(train_backprop_4dvar pid=175610)[0m  [404.60990543 403.14143743 402.53100662]
[36m(train_backprop_4dvar pid=175610)[0m  [269.52097317 268.46627122 268.02776845]
[36m(train_backprop_4dvar pid=175610)[0m  [474.64260565 472.865109   472.12591576]]




[36m(train_backprop_4dvar pid=175908)[0m [[28.07867463 20.40680594 18.37526037]
[36m(train_backprop_4dvar pid=175908)[0m  [36.71642512 20.58626618 17.93322187]
[36m(train_backprop_4dvar pid=175908)[0m  [30.81212654 19.34242801 17.37349936]
[36m(train_backprop_4dvar pid=175908)[0m  [13.43922756 11.67455148 11.15176768]
[36m(train_backprop_4dvar pid=175908)[0m  [11.85292821  9.75606906  9.11168035]
[36m(train_backprop_4dvar pid=175908)[0m  [ 4.13281062  3.33239705  3.1030025 ]]
[36m(train_backprop_4dvar pid=175908)[0m  ...
[36m(train_backprop_4dvar pid=175908)[0m 




[36m(train_backprop_4dvar pid=176211)[0m [[ 28.07867463  26.90213716  26.19806853]
[36m(train_backprop_4dvar pid=176211)[0m  [ 65.70196775  60.75008426  57.86266261]
[36m(train_backprop_4dvar pid=176211)[0m  [122.2329138  113.82162631 108.90820591]
[36m(train_backprop_4dvar pid=176211)[0m  ...
[36m(train_backprop_4dvar pid=176211)[0m  [310.96593196 297.63808016 289.55845101]
[36m(train_backprop_4dvar pid=176211)[0m  [223.31724365 214.82559955 209.6551506 ]
[36m(train_backprop_4dvar pid=176211)[0m  [119.57849178 114.18900252 110.93657296]]




[36m(train_backprop_4dvar pid=176521)[0m [[ 28.07867463  61.54478706  25.47938382]
[36m(train_backprop_4dvar pid=176521)[0m  [  2.7142332    3.7506663    2.39657495]][32m [repeated 5x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
[36m(train_backprop_4dvar pid=176521)[0m  ...


2024-06-11 14:55:43,578	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/ksolvik/ray_results/train_backprop_4dvar_2024-06-11_14-49-25' in 0.0272s.


KeyboardInterrupt: 

In [None]:
# Stack the per-system-size tuning results into one frame and persist it.
full_results_df = pd.concat(objs=all_results_df_list)
full_results_df.to_csv(path_or_buf='./out/l96/raytune_l96_hessian_v4.csv')

### Experiments varying number of observations and obs error

In [10]:
# Narrower HyperOpt search space for the observation sweeps:
#   lr       ~ log-uniform on [e^-5, e^0] (about 6.7e-3 to 1)
#   lr_decay ~ uniform on [0.1, 0.99]
space = dict(
    lr=hp.loguniform("lr", -5, 0),
    lr_decay=hp.uniform("lr_decay", 0.1, 0.99),
)

In [None]:
all_results_df_list = []

system_dim = 36
num_obs_list = [6, 12, 18, 24, 30, 36]
obs_error_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 1.5, 2.0]

# Distinct loop-variable names keep the config-cell globals
# `obs_location_count` and `obs_sd` from being overwritten by this sweep.
for cur_num_obs in num_obs_list:
    for cur_obs_sd in obs_error_list:
        tune_time_start = timer()

        trainable_with_params = tune.with_parameters(
            train_backprop_4dvar,
            system_dim=system_dim,
            num_obs=cur_num_obs,
            obs_sd=cur_obs_sd,
        )
        # NOTE: only 4 CPUs per trial, so (unlike the system-size sweep)
        # several trials can run concurrently; total_tune_time here is not
        # directly comparable to that sweep's.
        trainable_with_resources = tune.with_resources(trainable_with_params, {"cpu": 4})

        hyperopt_search = HyperOptSearch(space, metric="rmse", mode="min")
        tuner = tune.Tuner(
            trainable_with_resources,
            tune_config=tune.TuneConfig(
                num_samples=20,
                search_alg=hyperopt_search,
            ),
        )

        results = tuner.fit()
        # Time the tuning itself, not the DataFrame post-processing below.
        tune_time = timer() - tune_time_start

        cur_results_df = results.get_dataframe()
        cur_results_df['system_dim'] = system_dim
        cur_results_df['num_obs'] = cur_num_obs
        cur_results_df['obs_sd'] = cur_obs_sd
        cur_results_df['total_tune_time'] = tune_time

        all_results_df_list.append(cur_results_df)


0,1
Current time:,2024-06-11 15:01:06
Running for:,00:00:55.77
Memory:,14.8/30.2 GiB

Trial name,status,loc,lr,lr_decay,iter,total time (s),rmse
train_backprop_4dvar_22e61c85,RUNNING,192.168.0.249:180653,0.0489818,0.714696,,,
train_backprop_4dvar_d7e9b92e,RUNNING,192.168.0.249:180892,0.00770761,0.28534,,,
train_backprop_4dvar_30ca3b4b,PENDING,,0.162636,0.348245,,,
train_backprop_4dvar_1362ac85,TERMINATED,192.168.0.249:179748,0.065463,0.752681,1.0,18.0214,4.38151
train_backprop_4dvar_afe521ea,TERMINATED,192.168.0.249:179866,0.0761334,0.310982,1.0,17.7332,4.70712
train_backprop_4dvar_360314a0,TERMINATED,192.168.0.249:180014,0.935157,0.110244,1.0,17.3901,3.6583
train_backprop_4dvar_afe950b2,TERMINATED,192.168.0.249:180328,0.0104742,0.402635,1.0,15.4631,5.1357


In [None]:
# Stack the (num_obs, obs_sd) sweep results into one frame and persist it.
full_results_df = pd.concat(objs=all_results_df_list)
full_results_df.to_csv(path_or_buf='./out/l96/raytune_werrors_heatmap_hessian_v4_alldims_bigger_error.csv')