In [1]:
from pathlib import Path
from timeit import default_timer as timer

import jax
import jaxlib
import numpy as np
import pandas as pd
from hyperopt import hp
from ray import train, tune
from ray.tune.search.hyperopt import HyperOptSearch

import dabench as dab

# Define parameters

In [2]:
# One model year in timesteps: with 2-hour steps (delta_t = 7200 in the next
# cell) there are 12 steps per day, and 12 * 365 = 4380.
year_in_timesteps = 12 * 365
spinup_size = year_in_timesteps * 5           # 5 years
valid_size = round(year_in_timesteps / 4)     # a quarter of a year
transient_size = year_in_timesteps            # 1 year
test_size = year_in_timesteps                 # 1 year

In [3]:
# Total nature-run length: all segments laid end to end.
nr_steps = sum((spinup_size, valid_size, transient_size, test_size))
# Model timestep (assumed seconds: 7200 s = 2 h, consistent with 4380 steps/year).
delta_t = 7200
# Each 4DVar window spans 6 timesteps; the analysis time is 3 steps in (the midpoint).
analysis_window = delta_t * 6
analysis_time_in_window = delta_t * 3
# Optimization epochs per analysis window.
num_epochs = 3

### Function definition: Backprop 4DVar

Backprop-4DVar has to be prepared and run many times, so the whole workflow is wrapped in functions.

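Note: because we use Ray Tune, the workflow is split into two functions. The prep function builds the nature run, observations and DA matrices, and runs once per system size. The run function performs one DA experiment and is called once per trial. This way the expensive setup is not repeated for every hyperparameter configuration.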

In [4]:
def prep_backprop_4dvar(system_dim_xy, nr_steps, spinup_size, valid_size, test_size,
                        test_run, delta_t, sigma_bg_multiplier, sigma_obs_multiplier,
                        random_seed, transient_size=None, obs_sd_scale=0.1):
    """Build everything needed to run Backprop-4DVar on a PyQG-JAX system.

    Generates a nature run and splits it into spinup / validation / transient /
    test segments. Creates synthetic observations on the evaluation segment and
    wraps a second PyQG-JAX instance as the forecast model. Builds the H, R and
    B matrices and constructs a ``Var4DBackprop`` cycler. The cycler's
    learning-rate settings are left as ``None`` and filled in per trial by
    ``run_backprop_4dvar``.

    Args:
        system_dim_xy: Grid size in x and y (``nx = ny = system_dim_xy``).
        nr_steps: Total number of nature-run timesteps to generate.
        spinup_size: Timesteps in the spinup segment (used for per-variable SD).
        valid_size: Timesteps in the validation segment.
        test_size: Timesteps in the test segment.
        test_run: If falsy, evaluate on the validation segment; if truthy, on
            the test segment. It is also added to the observer's random seed.
        delta_t: Timestep for both the nature run and the forecast model.
        sigma_bg_multiplier: Background error SD as a multiple of the obs SD.
        sigma_obs_multiplier: Observation error SD assumed in R, as a multiple
            of the SD used to generate the observations.
        random_seed: Seed for the nature run, forecast model, observer and
            initial-condition perturbation.
        transient_size: Timesteps in the transient segment between validation
            and test. Defaults to ``nr_steps - spinup_size - valid_size -
            test_size``, i.e. whatever remains of the nature run.
        obs_sd_scale: Observation error SD as a fraction of each variable's SD
            over the spinup segment.

    Returns:
        Tuple ``(dc, x0_sv, start_time, obs_vec, nr_eval, sigma_obs)``.

    Raises:
        ValueError: If ``transient_size`` is not given and ``nr_steps`` is
            shorter than ``spinup_size + valid_size + test_size``.
    """
    if transient_size is None:
        # The nature run is laid out as spinup | valid | transient | test, so
        # the transient segment is whatever is left over.
        transient_size = nr_steps - spinup_size - valid_size - test_size
        if transient_size < 0:
            raise ValueError(
                'nr_steps is shorter than spinup_size + valid_size + test_size')

    np_rng = np.random.default_rng(random_seed)
    # NOTE(review): jax.clear_backends() is deprecated in recent JAX releases;
    # revisit when upgrading JAX.
    jax.clear_backends()

    ### Nature Run
    nature_run = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    nature_run.generate(n_steps=nr_steps)
    nr_spinup, nr_valid, nr_transient_and_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, transient_size + test_size)
    # nr_transient is not used below.
    nr_transient, nr_test, _ = nr_transient_and_test.split_train_valid_test(
        transient_size, test_size, 0)

    # Tune on the validation segment; evaluate final runs on the test segment.
    nr_eval = nr_test if test_run else nr_valid

    ### Observations
    # Observe half of the state variables (random, stationary locations).
    obs_location_count = round(nature_run.system_dim/2)

    # Observation error SD: a fixed fraction of each variable's SD over the
    # spinup segment.
    per_variable_sd = np.std(nr_spinup.values, axis=0)
    obs_sd = obs_sd_scale*per_variable_sd

    obs_pyqg = dab.observer.Observer(
        nr_eval,
        # Observe every 3rd timestep.
        time_indices=np.arange(0, nr_eval.time_dim, 3),
        random_location_count=obs_location_count,
        error_bias=0.0,
        error_sd=obs_sd,
        # A truthy test_run shifts the seed by 1, giving different observations.
        random_seed=random_seed+test_run,
        stationary_observers=True,
        store_as_jax=True
    )

    obs_vec_pyqg = obs_pyqg.observe()

    ### Forecast Model
    # Second PyQG-JAX instance on the same grid. It steps with the same delta_t
    # as the nature run, since the cycler below is told nr_eval.delta_t.
    model_pyqg = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    class PyQGModel(dab.model.Model):
        """Model wrapper that lets the DA cycler run PyQG-JAX forecasts and TLMs."""
        def forecast(self, state_vec, n_steps):
            """Run n_steps from state_vec and return the generated values as a StateVector."""
            gridded_values = state_vec.values.reshape(self.model_obj.original_dim)
            self.model_obj.generate(x0=gridded_values, n_steps=n_steps)
            new_vals = self.model_obj.values

            new_vec = dab.vector.StateVector(values=new_vals, store_as_jax=True)

            return new_vec

        def _forecast_x0(self, x0, n_steps):
            """Array-in/array-out forecast from a flat x0, so jax can differentiate it."""
            self.model_obj.generate(x0=x0.reshape(self.model_obj.original_dim),
                                    n_steps=n_steps)
            return self.model_obj.values

        def compute_tlm(self, state_vec, n_steps):
            """Return (Jacobian of the forecast w.r.t. x0, forecast from x0)."""
            x0 = state_vec.values
            return jax.jacrev(self._forecast_x0, argnums=0)(x0, n_steps), self._forecast_x0(x0, n_steps)

    fc_model = PyQGModel(model_obj=model_pyqg)

    ### Set up DA matrices: H (observation), R (obs error), B (background error)
    # location_indices[0]: observed locations at the first observation time
    # (observers are stationary, so presumably the same at every time).
    sigma_obs = sigma_obs_multiplier*obs_sd[obs_vec_pyqg.location_indices[0]]
    sigma_bg = sigma_bg_multiplier*obs_sd
    # H picks out the observed state variables: a single 1 per row.
    H = np.zeros((obs_location_count, nature_run.system_dim))
    H[np.arange(H.shape[0]), obs_vec_pyqg.location_indices[0]] = 1
    # A 1-D SD vector broadcast against the identity gives diag(sigma**2).
    R = (sigma_obs**2) * np.identity(obs_location_count)
    B = (sigma_bg**2)*np.identity(nature_run.system_dim)

    ### Prep DA
    dc = dab.dacycler.Var4DBackprop(
        system_dim=nature_run.system_dim,
        delta_t=nr_eval.delta_t,
        H=H,
        B=B,
        R=R,
        # These will be set later in run_backprop_4dvar
        num_epochs=None,
        learning_rate=None,
        lr_decay=None,
        model_obj=fc_model,
        obs_window_indices=[0,3,6],
        steps_per_window=7, # 7 instead of 6 because inclusive of 0 and 6
        )

    # Initial background: truth at the first evaluation step plus N(0, sigma_bg) noise.
    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(nature_run.system_dim,),
                                                            scale=sigma_bg)
    x0_sv = dab.vector.StateVector(
        values=x0_original,
        store_as_jax=True)
    start_time = nr_eval.times[cur_tstep]

    # Return necessary objects for running DA with different LR/LR Decay
    return dc, x0_sv, start_time, obs_vec_pyqg, nr_eval, sigma_obs

In [1]:

def run_backprop_4dvar(config, dc, x0_sv, start_time, obs_vec, nr_eval,
                       sigma_obs, analysis_window, analysis_time_in_window,
                       num_epochs):
    """Ray Tune trainable: run one Backprop-4DVar experiment and report its RMSE.

    Args:
        config: Ray Tune trial config with keys ``'lr'`` and ``'lr_decay'``.
        dc: Prepared ``Var4DBackprop`` cycler (from ``prep_backprop_4dvar``).
        x0_sv: Initial background state vector.
        start_time: Time of the first analysis cycle.
        obs_vec: Observation vector to assimilate.
        nr_eval: Nature-run segment used as truth for the RMSE.
        sigma_obs: Observation error SD, passed to ``dc.cycle`` as ``obs_error_sd``.
        analysis_window: Length of each analysis window.
        analysis_time_in_window: Offset of the analysis time within each window.
        num_epochs: Optimization epochs per analysis window.

    Reports ``{'rmse': ...}`` via ``train.report``. A run that diverges, either
    by raising ``XlaRuntimeError`` or by producing a non-finite RMSE, reports
    the sentinel 999999. The search then penalises it instead of receiving NaN.
    """
    # Sentinel metric for diverged runs; far worse than any real RMSE.
    failed_rmse = 999999

    # Trial-specific optimizer settings.
    dc.learning_rate = config['lr']
    dc.lr_decay = config['lr_decay']
    dc.num_epochs = num_epochs

    try:
        out_statevec = dc.cycle(
            input_state=x0_sv,
            start_time=start_time,
            obs_vector=obs_vec,
            analysis_window=analysis_window,
            # NOTE(review): 6 presumably equals analysis_window / delta_t
            # (timesteps per window). The last 2 windows are dropped; TODO
            # derive this instead of hard-coding it.
            timesteps=int(nr_eval.time_dim/6) - 2,
            obs_error_sd=sigma_obs,
            analysis_time_in_window=analysis_time_in_window)

        # Compare against the truth over the span covered by the analyses.
        rmse = float(np.sqrt(np.mean(np.square(
            nr_eval.values[:out_statevec.values.shape[0]] - out_statevec.values
        ))))
        # Divergence does not always raise; it can also surface as NaN/inf.
        if not np.isfinite(rmse):
            rmse = failed_rmse
        train.report({'rmse': rmse})

    # Catch problem with exploding gradients
    except jaxlib.xla_extension.XlaRuntimeError:
        train.report({'rmse': failed_rmse})

# Hyperparameter optimization on validation set

Use Ray Tune with a HyperOpt search to optimize the learning rate (alpha, `lr`) and its decay factor (`lr_decay`).

In [6]:
# HyperOpt search space:
#   lr       -> exp(U(-10, 0)), i.e. log-uniform in [e**-10, 1]
#   lr_decay -> U(0.1, 0.99)
space = dict(
    lr=hp.loguniform("lr", -10, 0),
    lr_decay=hp.uniform("lr_decay", 0.1, 0.99),
)



In [7]:
%%time

all_results_df_list = []
system_dim_xy_list = [16, 24, 32]

for system_dim_xy in system_dim_xy_list:
    tune_time_start = timer()
    print('Starting... {} system dim'.format(system_dim_xy))

    # One seed per system size drives the nature run, observations, initial
    # perturbation and the hyperparameter search, so each run is reproducible.
    random_seed = system_dim_xy

    # Expensive setup, done once per system size and shared by all trials.
    run_dict = dict(
        system_dim_xy=system_dim_xy,
        nr_steps=nr_steps,
        spinup_size=spinup_size,
        valid_size=valid_size,
        test_size=test_size,
        test_run=False,
        delta_t=delta_t,
        sigma_bg_multiplier=0.5,
        sigma_obs_multiplier=1.25,
        random_seed=random_seed)
    dc, x0_sv, start_time, obs_vec, nr_eval, sigma_obs = prep_backprop_4dvar(**run_dict)

    # Bind the prepared objects to the trainable; only `config` varies per trial.
    trainable_with_system_dim = tune.with_parameters(
        run_backprop_4dvar,
        dc=dc,
        x0_sv=x0_sv,
        start_time=start_time,
        obs_vec=obs_vec,
        nr_eval=nr_eval,
        sigma_obs=sigma_obs,
        analysis_window=analysis_window,
        analysis_time_in_window=analysis_time_in_window,
        num_epochs=num_epochs
    )

    # Set CPU requirement to num cpus on your machine to NOT run in parallel
    # in order to get a better time comparison. Reduce this to run in parallel
    # and speed up the process. It must not exceed the CPUs Ray can see, or
    # trials cannot be scheduled.
    trainable_with_resources = tune.with_resources(trainable_with_system_dim, {"cpu": 16})
    # Seed the search so the sampled (lr, lr_decay) sequence is reproducible.
    hyperopt_search = HyperOptSearch(space, metric="rmse", mode="min",
                                     random_state_seed=random_seed)

    # Run tuner
    tuner = tune.Tuner(
        trainable_with_resources,
        tune_config=tune.TuneConfig(
            num_samples=50,
            search_alg=hyperopt_search,
        ),
    )
    results = tuner.fit()

    cur_results_df = results.get_dataframe()
    cur_results_df['system_dim_xy'] = system_dim_xy

    # Wall time for prep + tuning of this system size.
    tune_time = timer() - tune_time_start
    cur_results_df['total_tune_time'] = tune_time

    all_results_df_list.append(cur_results_df)

0,1
Current time:,2024-03-25 22:32:24
Running for:,00:32:09.37
Memory:,8.7/30.2 GiB

Trial name,status,loc,lr,lr_decay,iter,total time (s),rmse
run_backprop_4dvar_3e4cf865,TERMINATED,192.168.0.249:907508,0.0521733,0.28766,1,43.3981,8.74666e-07
run_backprop_4dvar_233794b6,TERMINATED,192.168.0.249:907508,0.000751894,0.916643,1,37.7306,2.17627e-06
run_backprop_4dvar_cf42d1e0,TERMINATED,192.168.0.249:907508,0.000331725,0.450786,1,38.0711,2.35044e-06
run_backprop_4dvar_4de3792f,TERMINATED,192.168.0.249:907508,0.0010385,0.971847,1,38.5359,2.09045e-06
run_backprop_4dvar_9d54401c,TERMINATED,192.168.0.249:907508,9.06624e-05,0.363874,1,38.3749,2.40873e-06
run_backprop_4dvar_97581bd6,TERMINATED,192.168.0.249:907508,0.0206727,0.841053,1,37.9243,9.93514e-07
run_backprop_4dvar_49e32204,TERMINATED,192.168.0.249:907508,0.00025193,0.188882,1,38.1495,2.38363e-06
run_backprop_4dvar_1d1cafe0,TERMINATED,192.168.0.249:907508,0.194802,0.839249,1,38.4097,4.75683e-07
run_backprop_4dvar_d7ecb36c,TERMINATED,192.168.0.249:907508,0.0337546,0.255542,1,38.3943,1.05311e-06
run_backprop_4dvar_9d60f7ff,TERMINATED,192.168.0.249:907508,0.00111487,0.171681,1,38.7264,2.2531e-06


[36m(run_backprop_4dvar pid=907508)[0m 2024-03-25 22:00:22.538342: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
[36m(run_backprop_4dvar pid=907508)[0m 2024-03-25 22:00:22.970447: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
[36m(run_backprop_4dvar pid=907508)[0m 2024-03-25 22:00:23.083461: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
[36m(run_backprop_4dvar pid=907508)[0m 2024-03-25 22:00:26.246313: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
[36m(run_backprop_4dvar pid=907508)[0m 2024-03-25 22:00:26.296158: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none


[36m(run_backprop_4dvar pid=907508)[0m [[2221.41596682 2183.85257707 2175.2235743 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2173.01625192 2134.58696133 2125.7590487 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2136.60414203 2099.90330435 2091.4825808 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2410.10002074 2362.12413619 2351.0899557 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2457.61213631 2410.06449095 2399.1392089 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2329.42541594 2282.78466374 2272.0634738 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2499.69977601 2446.16737891 2433.8724011 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2683.91305304 2622.91824781 2608.8878787 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2727.92786552 2662.19174647 2647.0755678 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2789.38912044 2719.7306154  2703.7191148 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2818.6376595  2751.58649937 2736.1605504 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2932.266013

[36m(run_backprop_4dvar pid=907508)[0m [[ 2221.41596682  2221.16099743  2221.04622127]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2202.82741775  2202.55292362  2202.42935851]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2166.81132249  2166.55081168  2166.43354177]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2124.07377144  2123.80257025  2123.6804884 ]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2256.0632236   2255.78388874  2255.65814561]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2270.23078726  2269.93913119  2269.80784364]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2262.37710274  2262.0670631   2261.92749943]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2283.42185917  2283.09557466  2282.94869921]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2332.52740957  2332.1836802   2332.02895185]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2484.2875439   2483.90356861  2483.73072425]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2539.34479673  2538.94379352  2538.76328409]
[36m(run_backprop_4d

[36m(run_backprop_4dvar pid=907508)[0m [[ 2221.41596682  2221.34626173  2221.32090729]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2203.01443686  2202.93937049  2202.91206596]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2167.13959137  2167.068329    2167.04240817]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2124.5219813   2124.44777994  2124.4207901 ]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2256.64962702  2256.57318487  2256.54537997]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2271.00906383  2270.92922499  2270.90018471]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2263.3876474   2263.30275115  2263.27187125]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2284.75531936  2284.66593573  2284.63342366]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2334.28169377  2334.18748466  2334.15321738]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2486.44648976  2486.34122215  2486.30293252]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2541.98946096  2541.87947294  2541.83946632]
[36m(run_backprop_4d

[36m(run_backprop_4dvar pid=907508)[0m [[ 2221.41596682  2221.22231022  2221.18576776]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2202.93613415  2202.72761003  2202.68826217]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2167.00212321  2166.80418946  2166.76684008]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2124.33424777  2124.12816796  2124.08928153]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2256.40396649  2256.19168124  2256.15162389]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2270.68296796  2270.46127901  2270.41944768]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2262.96415583  2262.728453    2262.68397714]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2284.19641953  2283.94830372  2283.90148579]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2333.54629421  2333.28483772  2333.23550244]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2485.54127822  2485.24916388  2485.19404376]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2540.88036597  2540.57521333  2540.51763296]
[36m(run_backprop_4d

[36m(run_backprop_4dvar pid=907508)[0m  [2334.2906146  2304.90248012 2298.3828229 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2727.78300533 2682.43648729 2672.3781291 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2895.17130225 2842.3670275  2830.6518964 ]
[36m(run_backprop_4dvar pid=907508)[0m  [3026.6079422  2970.45245893 2957.9912365 ]
[36m(run_backprop_4dvar pid=907508)[0m  [3033.48923659 2978.07905734 2965.7738214 ]
[36m(run_backprop_4dvar pid=907508)[0m  [3092.23898508 3035.93243338 3023.4275862 ]
[36m(run_backprop_4dvar pid=907508)[0m  [4016.9095242  3929.42571299 3910.0066371 ]
[36m(run_backprop_4dvar pid=907508)[0m  [3991.45107363 3902.79916831 3883.1157049 ]
[36m(run_backprop_4dvar pid=907508)[0m  [4138.31101132 4046.41867179 4026.0207425 ]
[36m(run_backprop_4dvar pid=907508)[0m  [5419.51966816 5280.79423275 5250.0134339 ]
[36m(run_backprop_4dvar pid=907508)[0m  [5536.68684347 5391.68173444 5359.50065   ]
[36m(run_backprop_4dvar pid=907508)[0m  [5398.767008

[36m(run_backprop_4dvar pid=907508)[0m  [2315.07994717 2296.96104954 2287.4973119 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2500.03404023 2475.67246188 2462.9479821 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2813.72814944 2783.78324988 2768.1421247 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2853.99255824 2822.44438929 2805.9676925 ]
[36m(run_backprop_4dvar pid=907508)[0m  [3382.01381918 3337.13691084 3313.6927979 ]
[36m(run_backprop_4dvar pid=907508)[0m  [3914.31199713 3860.68697286 3832.6679271 ]
[36m(run_backprop_4dvar pid=907508)[0m  [4313.97967201 4250.56928215 4217.4432335 ]
[36m(run_backprop_4dvar pid=907508)[0m  [4920.79250985 4841.56485742 4800.1886972 ]
[36m(run_backprop_4dvar pid=907508)[0m  [5859.57834508 5762.21547746 5711.3666989 ]
[36m(run_backprop_4dvar pid=907508)[0m  [5737.15937395 5641.94825254 5592.2148642 ]
[36m(run_backprop_4dvar pid=907508)[0m  [6450.83386251 6338.58181235 6279.9518726 ]
[36m(run_backprop_4dvar pid=907508)[0m  [6061.744957

[36m(run_backprop_4dvar pid=907508)[0m [[ 2221.41596682  2220.29655179  2220.10889952]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2202.2859797   2201.08193649  2200.8800992 ]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2165.86213779  2164.72036194  2164.52896615]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2122.7795252   2121.59157292  2121.39243903]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2254.37197702  2253.14910649  2252.94412007]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2267.98860117  2266.7129455   2266.49912397]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2259.46922535  2258.11431001  2257.88719759]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2279.5885885   2278.16448278  2277.92577791]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2327.48974935  2325.99163929  2325.74052868]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2478.09588271  2476.42363472  2476.14333883]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2531.77062598  2530.02659608  2529.73426796]
[36m(run_backprop_4d

[36m(run_backprop_4dvar pid=907508)[0m [[ 2221.41596682  2220.35937567  2219.37665464]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2201.33050202  2200.19584855  2199.14053211]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2164.19145095  2163.11701037  2162.11771736]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2120.50785625  2119.39107062  2118.35240897]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2251.41087472  2250.26242319  2249.19431504]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2264.071912    2262.87574883  2261.76334636]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2254.40225292  2253.1336541   2251.95385335]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2272.92275024  2271.59230007  2270.35500888]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2318.7487922   2317.3526877   2316.05433035]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2467.38140539  2465.82503924  2464.37766088]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2518.70094424  2517.08159587  2515.57564229]
[36m(run_backprop_4d

[36m(run_backprop_4dvar pid=907508)[0m [[ 2221.41596682  2221.26617227  2221.16881338]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2202.89288205  2202.73159903  2202.62677334]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2166.92620604  2166.77312402  2166.67362884]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2124.23059518  2124.07122023  2123.96763517]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2256.26835835  2256.10419196  2255.99749279]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2270.50299319  2270.33156676  2270.22015001]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2262.73047521  2262.54822473  2262.42977257]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2283.88807295  2283.69624382  2283.57156653]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2333.14064596  2332.93852571  2332.80715971]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2485.0420719   2484.81626522  2484.66950475]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2540.26886392  2540.03300441  2539.87971021]
[36m(run_backprop_4d

[36m(run_backprop_4dvar pid=907508)[0m [[ 2221.41596682  2215.56073449  2211.78803213]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2195.8734296   2189.64312056  2185.62889635]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2154.75535232  2148.90322324  2145.13304054]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2107.83055371  2101.78088609  2097.88370332]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2235.06110097  2228.87379497  2224.88800196]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2242.65327747  2236.26427365  2232.15005418]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2226.98249458  2220.26033164  2215.93082065]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2237.15734883  2230.1921746   2225.70666639]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2272.27810834  2265.06807091  2260.42460809]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2411.05161874  2403.06615107  2397.92360685]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2450.79830406  2442.5920691   2437.30718721]
[36m(run_backprop_4d

[36m(run_backprop_4dvar pid=907508)[0m [[ 2221.41596682  2213.87113478  2212.38811297]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2197.53103628  2189.48044574  2187.89808419]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2157.60224746  2150.02195223  2148.53220308]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2111.62739485  2103.77864025  2102.23625305]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2239.92637498  2231.88619454  2230.30621849]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2248.98809043  2240.66503731  2239.03016097]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2235.03938715  2226.26168158  2224.53716828]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2247.61173175  2238.48428288  2236.69130453]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2285.78395202  2276.29748837  2274.43387784]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2427.30869241  2416.78299425  2414.71537917]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2470.25139868  2459.39599581  2457.26355261]
[36m(run_backprop_4d

[36m(run_backprop_4dvar pid=907508)[0m [[ 2221.41596682  2221.17725526  2221.10743421]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2202.88195928  2202.62494451  2202.54977006]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2166.90703561  2166.66309377  2166.59174332]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2124.20442334  2123.95045648  2123.87617401]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2256.23412087  2255.97252187  2255.89600715]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2270.45755642  2270.18439388  2270.10449805]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2262.67148416  2262.3810787   2262.29613912]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2283.8102384   2283.50457788  2283.41517685]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2333.03825672  2332.71620761  2332.62201306]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2484.91607824  2484.55629393  2484.45106269]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2540.11454376  2539.73875239  2539.62883931]
[36m(run_backprop_4d

[36m(run_backprop_4dvar pid=907508)[0m  [1923.99572435 1921.49667499 1863.7685443 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1956.96912608 1954.71960047 1898.0921126 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2025.27022827 2023.97069773 1962.1302856 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1972.52504309 1968.19028692 1910.1326596 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2042.47708529 2039.66844114 1978.5608887 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1962.04028607 1959.98996362 1901.3546083 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2042.02831065 2039.59219088 1979.9339992 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2036.96262813 2032.74880827 1972.0013373 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1983.9208479  1977.41664374 1923.0121885 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2044.47374001 2038.83667082 1983.4870849 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1936.31603904 1935.56462383 1876.5671499 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2132.858593

[36m(run_backprop_4dvar pid=907508)[0m  [1957.82830439 1849.53578596 1848.3374568 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1948.09902979 1844.48864281 1843.3378179 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1981.1349421  1867.63384735 1866.381185  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2083.676726   1962.6695237  1961.3830123 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2038.0601353  1918.74532562 1917.471143  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2090.24718208 1959.99109719 1958.4952152 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2112.53375498 1984.32332317 1983.0570045 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2079.34961183 1961.39219457 1960.0052117 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2107.91347134 1981.33951214 1979.8865523 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2126.7444176  1998.73278987 1997.3535875 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2084.11622784 1961.61887523 1960.2762749 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2094.550658

[36m(run_backprop_4dvar pid=907508)[0m  [2099.57560468 2045.03961207 2041.1105781 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2103.51474369 2047.92592151 2043.9235443 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2193.56997031 2135.43997648 2131.2507801 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2168.40061729 2111.67970274 2107.5944926 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2170.9875697  2110.29280848 2105.9070026 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2385.27752048 2306.66938162 2301.0049665 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2562.12196379 2476.66185463 2470.481938  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2627.74515223 2541.29909425 2535.039782  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2505.66477881 2425.78552746 2420.0012765 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2971.13145228 2858.83438629 2850.7109674 ]
[36m(run_backprop_4dvar pid=907508)[0m  [3313.77756487 3175.79245719 3165.7918207 ]
[36m(run_backprop_4dvar pid=907508)[0m  [3315.487162

[36m(run_backprop_4dvar pid=907508)[0m [[2221.41596682 2077.05485793 2064.4772679 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2030.08276434 1927.76781308 1918.7999337 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1948.93203904 1853.94492884 1845.2632629 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2044.32961065 1940.32036224 1930.8044567 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2035.68644954 1929.3625689  1920.0364526 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2114.95159558 2009.41084141 1999.6336656 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2031.39258726 1918.3792938  1908.4536171 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2052.32321501 1946.39011384 1936.8741053 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1959.50286516 1855.47717917 1846.072391  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2022.36788734 1906.07517186 1895.7187035 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2080.9486638  1965.89327141 1955.411588  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2038.799836

[36m(run_backprop_4dvar pid=907508)[0m  [1972.39610272 1870.59390366 1864.636519  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2096.99048934 1986.08639474 1979.5093328 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2067.2756573  1961.69191062 1955.3246193 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2113.1692311  1993.27526561 1986.2065353 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2066.58928873 1953.38375011 1946.6377692 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2202.1778491  2082.10585967 2074.8467109 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2205.11604179 2083.67019417 2076.3776099 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2134.26588998 2014.18036714 2007.0231037 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2168.57232443 2044.94668508 2037.5383932 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2211.8120608  2092.49627772 2085.3278058 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2144.38896029 2024.54910637 2017.2669858 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2138.158800

[36m(run_backprop_4dvar pid=907508)[0m [[2221.41596682 2124.46696839 2110.1337372 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1991.69789757 1922.06523286 1911.6719304 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2086.34977349 2012.54617557 2001.5080119 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2118.33415908 2031.29103339 2018.4238593 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2040.19731619 1962.63246149 1951.0473991 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2176.25577182 2094.56719899 2082.4131613 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2146.24032913 2062.69001292 2050.2455737 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2142.85021554 2059.29272017 2046.8463976 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2179.97962638 2095.32208847 2082.6116431 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2118.27139456 2036.52638258 2024.2683505 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2118.28550813 2043.15245766 2031.8678834 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2109.001598

[36m(run_backprop_4dvar pid=907508)[0m  [2048.61343053 2005.1362226  1989.688648  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2138.29417645 2093.53745005 2077.6377193 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2262.05326906 2206.68460097 2186.9927816 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2334.63705426 2276.4226063  2255.6583807 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2368.21646529 2301.91438181 2278.341712  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2348.80959939 2290.29374858 2269.4299609 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2488.02447775 2419.7510311  2395.3995972 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2475.67129911 2403.68781523 2378.095082  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2548.68479299 2475.75132352 2449.8236648 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2647.65510026 2573.52011056 2547.0549159 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2739.03441008 2660.77515318 2632.8307144 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2899.954535

[36m(run_backprop_4dvar pid=907508)[0m  [2010.27020344 1920.84346853 1908.2007546 ]
[36m(run_backprop_4dvar pid=907508)[0m  [1955.61550928 1874.8476123  1863.391852  ]
[36m(run_backprop_4dvar pid=907508)[0m  [1950.43758968 1871.72525478 1860.7926258 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2100.42508906 2010.49279375 1997.0233751 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2000.22307271 1913.98800686 1902.1552306 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2122.90229001 2026.01642815 2012.732523  ]
[36m(run_backprop_4dvar pid=907508)[0m  [2063.37784826 1973.94484765 1961.4490218 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2063.09520537 1976.62119975 1963.9584168 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2074.79303739 1983.21064364 1970.4460282 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2052.01878033 1964.57407222 1951.9428144 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2050.80393459 1959.84292261 1946.5751282 ]
[36m(run_backprop_4dvar pid=907508)[0m  [2083.855201

[36m(run_backprop_4dvar pid=907508)[0m [[ 2221.41596682  2219.71899953  2219.47289899]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2201.90745283  2200.08335236  2199.81881755]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2165.19959662  2163.47081351  2163.22010859]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2121.8776681   2120.07968137  2119.81894549]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2253.19528311  2251.34519212  2251.07690174]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2266.43073789  2264.50200079  2264.2223314 ]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2257.45188682  2255.40451467  2255.10763187]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2276.93259241  2274.78255926  2274.47080031]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2324.00393571  2321.74441337  2321.41677553]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2473.81858388  2471.29773347  2470.93220889]
[36m(run_backprop_4dvar pid=907508)[0m  [ 2526.54741697  2523.92083167  2523.53997455]
[36m(run_backprop_4d

2024-03-25 22:32:24,650	INFO tune.py:1047 -- Total run time: 1929.41 seconds (1928.66 seconds for the tuning loop).


CPU times: user 1min 29s, sys: 15.6 s, total: 1min 45s
Wall time: 52min 31s


In [8]:
# Combine the per-system-size tuning results and write them to CSV. The file
# name records the epoch count and system sizes actually used, and the output
# directory is created if missing so to_csv does not fail on a fresh checkout.
system_dims_tag = '_'.join(str(dim) for dim in system_dim_xy_list)
results_path = Path('./out/pyqg_jax') / (
    'pyqg_jax_raytune_sgdopt_werrors_v13_hessian_approx_'
    f'{num_epochs}epochs_system_dims_{system_dims_tag}.csv')
results_path.parent.mkdir(parents=True, exist_ok=True)

full_results_df = pd.concat(all_results_df_list)
full_results_df.to_csv(results_path)