In [1]:
import dabench as dab
import numpy as np
import jax.numpy as jnp
import jax
from timeit import default_timer as timer
import pandas as pd
import pickle
from dabench.dacycler._var4d_backprop_exacthessian import Var4DBackpropExactHessian

# Set up

### Read in Ray Tune results

In [2]:
# Load the Ray Tune hyperparameter-search results (one row per trial).
raytune_system_dim_results = pd.read_csv('./out/pyqg_jax/pyqg_jax_raytune_sgdopt_werrors_v13_hessian_approx_3epochs_system_dims_16_24_32.csv')
# Keep the as-read row label as the trial number, then use a plain 0..n-1 index.
raytune_system_dim_results['trialnum'] = raytune_system_dim_results.index
raytune_system_dim_results.index = np.arange(raytune_system_dim_results.shape[0])
# For each grid size, find the row label of the trial with the lowest RMSE.
# Only the 'rmse' column is searched, so unrelated columns (e.g. ones that are
# entirely NaN within a group) cannot affect or break the lookup.
rows_to_get = raytune_system_dim_results.groupby('system_dim_xy')['rmse'].idxmin()
best_results_system_dim = raytune_system_dim_results.loc[rows_to_get]

### Define parameters

In [4]:
# Number of model timesteps in one year (4380 steps x 7200 s = 365 days).
year_in_timesteps = 4380

# Lengths, in timesteps, of the consecutive periods of the nature run.
spinup_size = year_in_timesteps * 5
valid_size = round(year_in_timesteps / 4)
transient_size = year_in_timesteps * 1
test_size = year_in_timesteps * 1  # For system_dim_xy=64, used: round(year_in_timesteps/4)

In [5]:
# Model timestep (presumably seconds -- confirm against PyQGJax) and the
# analysis-window geometry expressed as multiples of it.
delta_t = 7200
analysis_window = delta_t * 6
analysis_time_in_window = delta_t * 3

# Total nature-run length: all four periods back to back.
nr_steps = sum([spinup_size, valid_size, transient_size, test_size])

# Function for running Backprop-4DVar

In [7]:
def run_backprop_4dvar(system_dim_xy, nr_steps, spinup_size, valid_size, test_size,
                       test_run, delta_t, sigma_bg_multiplier, sigma_obs_multiplier,
                       analysis_window, analysis_time_in_window, random_seed, num_iters,
                       learning_rate, lr_decay):
    """Run one Backprop-4DVar experiment against a PyQGJax nature run.

    Generates a nature run, observes a random half of the state variables at
    every third timestep, cycles ``dab.dacycler.Var4DBackprop`` over the
    evaluation period and scores the output against the nature run.

    Note: the transient-period length is read from the notebook-level
    ``transient_size`` variable, which must be defined before calling.

    Args:
        system_dim_xy: Grid size, used as both ``nx`` and ``ny`` for PyQGJax.
        nr_steps: Number of nature-run timesteps to generate.
        spinup_size: Length of the spin-up period; its values are also used
            to estimate the per-variable standard deviation.
        valid_size: Length of the validation period.
        test_size: Length of the test period.
        test_run: If truthy, evaluate on the test period, otherwise on the
            validation period. Also added to ``random_seed`` for the observer.
        delta_t: Timestep for the nature run and the forecast model.
        sigma_bg_multiplier: Scales the observation SD to get the background
            error SD (used for B and for perturbing the initial state).
        sigma_obs_multiplier: Scales the observation SD at the observed
            locations to get the SD used for R.
        analysis_window: Passed to ``dc.cycle``.
        analysis_time_in_window: Passed to ``dc.cycle``.
        random_seed: Seed for the nature run, forecast model, observer and
            initial-condition noise.
        num_iters: Passed to ``Var4DBackprop``.
        learning_rate: Passed to ``Var4DBackprop``.
        lr_decay: Passed to ``Var4DBackprop``.

    Returns:
        Tuple ``(rmse, da_time, out_statevec, nr_eval)``: RMSE of the cycler
        output against the nature run, elapsed ``timer()`` time from cycler
        construction through cycling, the cycler output, and the nature-run
        segment that was evaluated.
    """
    np_rng = np.random.default_rng(random_seed)
    jax.clear_backends()


    ### Nature Run
    nature_run = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    nature_run.generate(n_steps=nr_steps)
    # Split into spin-up | validation | (transient + test), then peel the
    # transient period off the front of the last chunk.
    nr_spinup, nr_valid, nr_transient_and_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, transient_size + test_size)
    _, nr_test, _ = nr_transient_and_test.split_train_valid_test(
        transient_size, test_size, 0)

    nr_eval = nr_test if test_run else nr_valid


    ### Observations
    obs_location_count = round(nature_run.system_dim/2)

    # Observation error SD is a fixed fraction of each variable's SD over the
    # spin-up period of the nature run.
    obs_sd_scale = 0.1
    per_variable_sd = np.std(nr_spinup.values, axis=0)
    obs_sd = obs_sd_scale*per_variable_sd

    obs_pyqg = dab.observer.Observer(
        nr_eval,
        time_indices = np.arange(0, nr_eval.time_dim, 3),
        random_location_count = obs_location_count,
        error_bias = 0.0,
        error_sd = obs_sd,
        random_seed=random_seed+test_run,
        stationary_observers=True,
        store_as_jax=True
    )

    obs_vec_pyqg = obs_pyqg.observe()


    ### Forecast Model
    # Step the forecast model with the same delta_t as the nature run and the
    # cycler (previously it was left at the PyQGJax default).
    model_pyqg = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    class PyQGModel(dab.model.Model):
        """Defines model wrapper for forecasting."""
        def forecast(self, state_vec, n_steps):
            gridded_values = state_vec.values.reshape(self.model_obj.original_dim)
            self.model_obj.generate(x0=gridded_values, n_steps=n_steps)
            new_vals = self.model_obj.values

            new_vec = dab.vector.StateVector(values=new_vals, store_as_jax=True)

            return new_vec

        def _forecast_x0(self, x0, n_steps):
            self.model_obj.generate(x0=x0.reshape(self.model_obj.original_dim),
                                    n_steps=n_steps)
            return self.model_obj.values

    fc_model = PyQGModel(model_obj=model_pyqg)


    ### Set up DA matrices: H (observation), R (obs error), B (background error)
    sigma_obs = sigma_obs_multiplier*obs_sd[obs_vec_pyqg.location_indices[0]]
    sigma_bg = sigma_bg_multiplier*obs_sd
    # H selects the observed state variables (one 1 per row).
    H = np.zeros((obs_location_count, nature_run.system_dim))
    H[np.arange(H.shape[0]), obs_vec_pyqg.location_indices[0]] = 1
    R = (sigma_obs**2) * np.identity(obs_location_count)
    B = (sigma_bg**2)*np.identity(nature_run.system_dim)


    ### Run Data Assimilation
    da_time_start = timer()

    # Create DA Cycler object
    dc = dab.dacycler.Var4DBackprop(
        system_dim=nature_run.system_dim,
        delta_t=nr_eval.delta_t,
        H=H,
        B=B,
        R=R,
        num_iters=num_iters,
        learning_rate=learning_rate,
        lr_decay=lr_decay,
        loss_growth_limit=2,
        model_obj=fc_model,
        obs_window_indices=[0,3,6],
        steps_per_window=7, # 7 instead of 6 because inclusive of 0 and 6
        )

    # Generate initial conditions: true state plus background-sized noise
    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(nature_run.system_dim,),
                                                            scale=sigma_bg)

    x0_sv = dab.vector.StateVector(
        values=x0_original,
        store_as_jax=True)

    # Execute
    # NOTE(review): the 6 below presumably equals analysis_window/delta_t;
    # revisit together with obs_window_indices/steps_per_window if the
    # window length changes.
    out_statevec = dc.cycle(
        input_state = x0_sv,
        start_time = nr_eval.times[cur_tstep],
        obs_vector = obs_vec_pyqg,
        analysis_window=analysis_window,
        timesteps=int(nr_eval.time_dim/6) - 2,
        obs_error_sd=sigma_obs,
        analysis_time_in_window=analysis_time_in_window)

    da_time = timer() - da_time_start

    # Score only the timesteps the cycler actually produced.
    rmse = np.sqrt(np.mean(np.square(
        nr_eval.values[:out_statevec.values.shape[0]] - out_statevec.values)))

    return rmse, da_time, out_statevec, nr_eval

# Function for running 4DVar

In [6]:
def run_4dvar(system_dim_xy, nr_steps, spinup_size, valid_size, test_size,
              test_run, delta_t, sigma_bg_multiplier, sigma_obs_multiplier,
              analysis_window, analysis_time_in_window, random_seed, n_outer_loops):
    """Run one 4D-Var experiment against a PyQGJax nature run.

    Generates a nature run, observes a random half of the state variables at
    every third timestep, cycles ``dab.dacycler.Var4D`` over the evaluation
    period and scores the output against the nature run.

    Note: the transient-period length is read from the notebook-level
    ``transient_size`` variable, which must be defined before calling.

    Args:
        system_dim_xy: Grid size, used as both ``nx`` and ``ny`` for PyQGJax.
        nr_steps: Number of nature-run timesteps to generate.
        spinup_size: Length of the spin-up period; its values are also used
            to estimate the per-variable standard deviation.
        valid_size: Length of the validation period.
        test_size: Length of the test period.
        test_run: If truthy, evaluate on the test period, otherwise on the
            validation period. Also added to ``random_seed`` for the observer.
        delta_t: Timestep for the nature run and the forecast model.
        sigma_bg_multiplier: Scales the observation SD to get the background
            error SD (used for B and for perturbing the initial state).
        sigma_obs_multiplier: Scales the observation SD at the observed
            locations to get the SD used for R.
        analysis_window: Passed to ``dc.cycle``.
        analysis_time_in_window: Passed to ``dc.cycle``.
        random_seed: Seed for the nature run, forecast model, observer and
            initial-condition noise.
        n_outer_loops: Passed to ``Var4D``.

    Returns:
        Tuple ``(rmse, da_time, out_statevec)``: RMSE of the cycler output
        against the nature run, elapsed ``timer()`` time from cycler
        construction through cycling, and the cycler output.
    """
    np_rng = np.random.default_rng(random_seed)
    jax.clear_backends()


    ### Nature Run
    nature_run = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    nature_run.generate(n_steps=nr_steps)
    # Split into spin-up | validation | (transient + test), then peel the
    # transient period off the front of the last chunk.
    nr_spinup, nr_valid, nr_transient_and_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, transient_size + test_size)
    _, nr_test, _ = nr_transient_and_test.split_train_valid_test(
        transient_size, test_size, 0)

    nr_eval = nr_test if test_run else nr_valid


    ### Observations
    obs_location_count = round(nature_run.system_dim/2)

    # Observation error SD is a fixed fraction of each variable's SD over the
    # spin-up period of the nature run.
    obs_sd_scale = 0.1
    per_variable_sd = np.std(nr_spinup.values, axis=0)
    obs_sd = obs_sd_scale*per_variable_sd

    obs_pyqg = dab.observer.Observer(
        nr_eval,
        time_indices = np.arange(0, nr_eval.time_dim, 3),
        random_location_count = obs_location_count,
        error_bias = 0.0,
        error_sd = obs_sd,
        random_seed=random_seed+test_run,
        stationary_observers=True,
        store_as_jax=True
    )

    obs_vec_pyqg = obs_pyqg.observe()


    ### Forecast Model
    # Step the forecast model with the same delta_t as the nature run and the
    # cycler (previously it was left at the PyQGJax default).
    model_pyqg = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    class PyQGModel(dab.model.Model):
        """Defines model wrapper for forecasting and computing TLM."""
        def forecast(self, state_vec, n_steps):
            gridded_values = state_vec.values.reshape(self.model_obj.original_dim)
            self.model_obj.generate(x0=gridded_values, n_steps=n_steps)
            new_vals = self.model_obj.values

            new_vec = dab.vector.StateVector(values=new_vals, store_as_jax=True)

            return new_vec

        def _forecast_x0(self, x0, n_steps):
            self.model_obj.generate(x0=x0.reshape(self.model_obj.original_dim),
                                    n_steps=n_steps)
            return self.model_obj.values

        def compute_tlm(self, state_vec, n_steps):
            # Returns (Jacobian of the forecast w.r.t. x0 via reverse-mode
            # autodiff, the forecast itself).
            x0 = state_vec.values
            return jax.jacrev(self._forecast_x0, argnums=0)(x0, n_steps), self._forecast_x0(x0, n_steps)

    fc_model = PyQGModel(model_obj=model_pyqg)


    ### Set up DA matrices: H (observation), R (obs error), B (background error)
    sigma_obs = sigma_obs_multiplier*obs_sd[obs_vec_pyqg.location_indices[0]]
    sigma_bg = sigma_bg_multiplier*obs_sd
    # H selects the observed state variables (one 1 per row).
    H = np.zeros((obs_location_count, nature_run.system_dim))
    H[np.arange(H.shape[0]), obs_vec_pyqg.location_indices[0]] = 1
    R = (sigma_obs**2) * np.identity(obs_location_count)
    B = (sigma_bg**2)*np.identity(nature_run.system_dim)

    ### Prep DA
    da_time_start = timer()

    # Create DA Cycler object
    dc = dab.dacycler.Var4D(
        system_dim=nature_run.system_dim,
        delta_t=nr_eval.delta_t,
        H=H,
        B=B,
        R=R,
        n_outer_loops=n_outer_loops,
        model_obj=fc_model,
        obs_window_indices=[0,3,6],
        steps_per_window=7, # 0 and 6 inclusive
        )

    # Generate initial conditions: true state plus background-sized noise
    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(nature_run.system_dim,), scale=sigma_bg)

    x0_sv = dab.vector.StateVector(
        values=x0_original,
        store_as_jax=True)

    # Execute
    # NOTE(review): the 6 below presumably equals analysis_window/delta_t;
    # revisit together with obs_window_indices/steps_per_window if the
    # window length changes.
    out_statevec = dc.cycle(
        input_state = x0_sv,
        start_time = nr_eval.times[cur_tstep],
        obs_vector = obs_vec_pyqg,
        analysis_window=analysis_window,
        timesteps=int(nr_eval.time_dim/6) - 2,
        obs_error_sd=sigma_obs,
        analysis_time_in_window=analysis_time_in_window)

    da_time = timer()-da_time_start

    # Score only the timesteps the cycler actually produced.
    rmse = np.sqrt(np.mean(np.square(nr_eval.values[:out_statevec.values.shape[0]] - out_statevec.values)))

    return rmse, da_time, out_statevec

# Function for running Backprop-4DVar with exact Hessian

In [None]:
def run_backprop_4dvar_exact_hessian(
        system_dim_xy, nr_steps, spinup_size, valid_size,test_size,
        test_run, delta_t, sigma_bg_multiplier, sigma_obs_multiplier,
        analysis_window, analysis_time_in_window, random_seed, num_iters,
        learning_rate, lr_decay):
    """Run one Backprop-4DVar (exact Hessian) experiment on a PyQGJax nature run.

    Same experiment layout as the approximate-Hessian Backprop-4DVar run, but
    cycles ``Var4DBackpropExactHessian`` instead of
    ``dab.dacycler.Var4DBackprop``.

    Note: the transient-period length is read from the notebook-level
    ``transient_size`` variable, which must be defined before calling.

    Args:
        system_dim_xy: Grid size, used as both ``nx`` and ``ny`` for PyQGJax.
        nr_steps: Number of nature-run timesteps to generate.
        spinup_size: Length of the spin-up period; its values are also used
            to estimate the per-variable standard deviation.
        valid_size: Length of the validation period.
        test_size: Length of the test period.
        test_run: If truthy, evaluate on the test period, otherwise on the
            validation period. Also added to ``random_seed`` for the observer.
        delta_t: Timestep for the nature run and the forecast model.
        sigma_bg_multiplier: Scales the observation SD to get the background
            error SD (used for B and for perturbing the initial state).
        sigma_obs_multiplier: Scales the observation SD at the observed
            locations to get the SD used for R.
        analysis_window: Passed to ``dc.cycle``.
        analysis_time_in_window: Passed to ``dc.cycle``.
        random_seed: Seed for the nature run, forecast model, observer and
            initial-condition noise.
        num_iters: Passed to ``Var4DBackpropExactHessian``.
        learning_rate: Passed to ``Var4DBackpropExactHessian``.
        lr_decay: Passed to ``Var4DBackpropExactHessian``.

    Returns:
        Tuple ``(rmse, da_time, out_statevec, nr_eval)``: RMSE of the cycler
        output against the nature run, elapsed ``timer()`` time from cycler
        construction through cycling, the cycler output, and the nature-run
        segment that was evaluated.
    """
    np_rng = np.random.default_rng(random_seed)
    jax.clear_backends()


    ### Nature Run
    nature_run = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    nature_run.generate(n_steps=nr_steps)
    # Split into spin-up | validation | (transient + test), then peel the
    # transient period off the front of the last chunk.
    nr_spinup, nr_valid, nr_transient_and_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, transient_size + test_size)
    _, nr_test, _ = nr_transient_and_test.split_train_valid_test(
        transient_size, test_size, 0)

    nr_eval = nr_test if test_run else nr_valid


    ### Observations
    obs_location_count = round(nature_run.system_dim/2)

    # Observation error SD is a fixed fraction of each variable's SD over the
    # spin-up period of the nature run.
    obs_sd_scale = 0.1
    per_variable_sd = np.std(nr_spinup.values, axis=0)
    obs_sd = obs_sd_scale*per_variable_sd

    obs_pyqg = dab.observer.Observer(
        nr_eval,
        time_indices = np.arange(0, nr_eval.time_dim, 3),
        random_location_count = obs_location_count,
        error_bias = 0.0,
        error_sd = obs_sd,
        random_seed=random_seed+test_run,
        stationary_observers=True,
        store_as_jax=True
    )

    obs_vec_pyqg = obs_pyqg.observe()


    ### Forecast Model
    # Step the forecast model with the same delta_t as the nature run and the
    # cycler (previously it was left at the PyQGJax default).
    model_pyqg = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    class PyQGModel(dab.model.Model):
        """Defines model wrapper for forecasting."""
        def forecast(self, state_vec, n_steps):
            gridded_values = state_vec.values.reshape(self.model_obj.original_dim)
            self.model_obj.generate(x0=gridded_values, n_steps=n_steps)
            new_vals = self.model_obj.values

            new_vec = dab.vector.StateVector(values=new_vals, store_as_jax=True)

            return new_vec

        def _forecast_x0(self, x0, n_steps):
            self.model_obj.generate(x0=x0.reshape(self.model_obj.original_dim),
                                    n_steps=n_steps)
            return self.model_obj.values

    fc_model = PyQGModel(model_obj=model_pyqg)


    ### Set up DA matrices: H (observation), R (obs error), B (background error)
    sigma_obs = sigma_obs_multiplier*obs_sd[obs_vec_pyqg.location_indices[0]]
    sigma_bg = sigma_bg_multiplier*obs_sd
    # H selects the observed state variables (one 1 per row).
    H = np.zeros((obs_location_count, nature_run.system_dim))
    H[np.arange(H.shape[0]), obs_vec_pyqg.location_indices[0]] = 1
    R = (sigma_obs**2) * np.identity(obs_location_count)
    B = (sigma_bg**2)*np.identity(nature_run.system_dim)


    ### Run Data Assimilation
    da_time_start = timer()

    # Create DA Cycler object
    dc = Var4DBackpropExactHessian(
        system_dim=nature_run.system_dim,
        delta_t=nr_eval.delta_t,
        H=H,
        B=B,
        R=R,
        num_iters=num_iters,
        learning_rate=learning_rate,
        lr_decay=lr_decay,
        loss_growth_limit=2,
        model_obj=fc_model,
        obs_window_indices=[0,3,6],
        steps_per_window=7, # 7 instead of 6 because inclusive of 0 and 6
        )

    # Generate initial conditions: true state plus background-sized noise
    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(nature_run.system_dim,),
                                                            scale=sigma_bg)

    x0_sv = dab.vector.StateVector(
        values=x0_original,
        store_as_jax=True)

    # Execute
    # NOTE(review): the 6 below presumably equals analysis_window/delta_t;
    # revisit together with obs_window_indices/steps_per_window if the
    # window length changes.
    out_statevec = dc.cycle(
        input_state = x0_sv,
        start_time = nr_eval.times[cur_tstep],
        obs_vector = obs_vec_pyqg,
        analysis_window=analysis_window,
        timesteps=int(nr_eval.time_dim/6) - 2,
        obs_error_sd=sigma_obs,
        analysis_time_in_window=analysis_time_in_window)

    da_time = timer() - da_time_start

    # Score only the timesteps the cycler actually produced.
    rmse = np.sqrt(np.mean(np.square(
        nr_eval.values[:out_statevec.values.shape[0]] - out_statevec.values)))

    return rmse, da_time, out_statevec, nr_eval

# Function for baserun without DA

In [9]:
def run_baserun(system_dim_xy, nr_steps, spinup_size, valid_size, test_size,
                test_run, delta_t, random_seed, sigma_bg_multiplier):
    """Run a free forecast (no data assimilation) as a baseline.

    Generates the same nature run as the DA experiments, perturbs its first
    evaluation-period state with background-sized noise, and integrates the
    forecast model over the whole evaluation period without any corrections.

    Note: the transient-period length is read from the notebook-level
    ``transient_size`` variable, which must be defined before calling.

    Args:
        system_dim_xy: Grid size, used as both ``nx`` and ``ny`` for PyQGJax.
        nr_steps: Number of nature-run timesteps to generate.
        spinup_size: Length of the spin-up period; its values are also used
            to estimate the per-variable standard deviation.
        valid_size: Length of the validation period.
        test_size: Length of the test period.
        test_run: If truthy, evaluate on the test period, otherwise on the
            validation period.
        delta_t: Timestep for the nature run and the forecast model.
        random_seed: Seed for the nature run, forecast model and
            initial-condition noise.
        sigma_bg_multiplier: Scales the observation SD (0.1 x per-variable
            spin-up SD) to get the SD of the initial-condition noise.

    Returns:
        Tuple ``(rmse, da_time, state_vec)``: RMSE of the forecast against the
        nature run (last 12 steps excluded), elapsed ``timer()`` time, and the
        forecast trajectory as a ``dab.vector.StateVector``.
    """
    np_rng = np.random.default_rng(random_seed)
    jax.clear_backends()


    ### Nature Run
    nature_run = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    nature_run.generate(n_steps=nr_steps)
    # Split into spin-up | validation | (transient + test), then peel the
    # transient period off the front of the last chunk.
    nr_spinup, nr_valid, nr_transient_and_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, transient_size + test_size)
    _, nr_test, _ = nr_transient_and_test.split_train_valid_test(
        transient_size, test_size, 0)

    nr_eval = nr_test if test_run else nr_valid


    ### Background error SD
    # Derived exactly as in the DA run functions so that, for a given seed,
    # the baseline starts from the same perturbed state. Previously sigma_bg
    # was not defined in this function and was silently picked up from a
    # notebook global (NameError on a fresh kernel).
    obs_sd_scale = 0.1
    per_variable_sd = np.std(nr_spinup.values, axis=0)
    obs_sd = obs_sd_scale*per_variable_sd
    sigma_bg = sigma_bg_multiplier*obs_sd


    ### Forecast Model
    # Step the forecast model with the same delta_t as the nature run
    # (previously it was left at the PyQGJax default).
    model_pyqg = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)


    ### Initial conditions
    da_time_start = timer()

    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(nature_run.system_dim,), scale=sigma_bg)
    x0_gridded = x0_original.reshape(nature_run.original_dim)


    ### Run forecast model
    model_pyqg.generate(x0=x0_gridded, n_steps=nr_eval.time_dim)

    # NOTE(review): the last 12 steps are dropped, presumably to match the
    # length of the DA output (int(time_dim/6) - 2 windows of 6 steps) -- confirm.
    rmse = np.sqrt(np.mean(np.square(model_pyqg.values[:-12] - nr_eval.values[:-12])))

    da_time = timer()-da_time_start

    return rmse, da_time, dab.vector.StateVector(values=model_pyqg.values, times=model_pyqg.times, store_as_jax=True)

# Run DA for test period

### Baserun (No DA)

In [10]:
# Free-forecast baseline on the test period, one run per grid size.
system_dim_xy_list = [16, 24, 32]
test_run = True
all_results_df_list_baserun = []
all_statevecs_baserun = []

for system_dim_xy in system_dim_xy_list:

    # Each grid size is seeded with its own size.
    random_seed = system_dim_xy

    run_dict = {
        'system_dim_xy': system_dim_xy,
        'nr_steps': nr_steps,
        'spinup_size': spinup_size,
        'valid_size': valid_size,
        'test_size': test_size,
        'test_run': test_run,
        'delta_t': delta_t,
        'sigma_bg_multiplier': 0.5,
        'random_seed': random_seed,
    }

    error_br, da_time, out_sv = run_baserun(**run_dict)

    # One-row summary: the run settings plus the resulting score and timing.
    out_df = pd.DataFrame(run_dict, index=[0]).assign(rmse=error_br, da_time=da_time)

    all_results_df_list_baserun.append(out_df)
    all_statevecs_baserun.append(out_sv)

2024-04-23 15:59:38.053268: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 15:59:38.082234: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 15:59:38.108554: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none


Initial condition not set. Start with random IC.


2024-04-23 15:59:38.555965: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 15:59:38.810713: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 15:59:38.835959: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 15:59:39.248251: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 15:59:39.285309: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 15:59:39.286581: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:00:12.293788: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:00:16.395715: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:00:32.917223: E external/o

1.8253622554103463e-06
2.919982246996369


In [11]:
# Write out csv
# Stack the one-row result frames (one per grid size) into a single table
# and save it.
full_out_df_baserun = pd.concat(all_results_df_list_baserun)
full_out_df_baserun.to_csv('./out/pyqg_jax/pyqg_baserun_results_test_v1.csv')

In [12]:
# Write out statevec: one pickle file per grid size
for i in range(full_out_df_baserun.shape[0]):
    sysdimxy = full_out_df_baserun['system_dim_xy'].values[i]
    out_file = './out/pyqg_jax/pyqg_baserun_results_sysdimxy_{}_v1.pkl'.format(sysdimxy)
    out_vec = all_statevecs_baserun[i]

    # Binary mode is required for pickle; the with-block closes the file,
    # so no explicit close() is needed afterwards.
    with open(out_file, 'wb') as f:
        pickle.dump(out_vec, f)  # serialize the state vector

### 4DVar: 3 outer loops

In [13]:
# 4D-Var on the test period, one run per grid size.
n_outer_loops = 3
system_dim_xy_list = [16, 24, 32]
test_run = True
all_results_df_list_4d = []
all_statevecs_4d = []

for system_dim_xy in system_dim_xy_list:

    # Each grid size is seeded with its own size.
    random_seed = system_dim_xy

    run_dict = {
        'system_dim_xy': system_dim_xy,
        'nr_steps': nr_steps,
        'spinup_size': spinup_size,
        'valid_size': valid_size,
        'test_size': test_size,
        'test_run': test_run,
        'delta_t': delta_t,
        'sigma_bg_multiplier': 0.5,
        'sigma_obs_multiplier': 1.25,
        'analysis_window': analysis_window,
        'analysis_time_in_window': analysis_time_in_window,
        'random_seed': random_seed,
        'n_outer_loops': n_outer_loops,
    }

    error_4d, da_time, out_sv = run_4dvar(**run_dict)

    # One-row summary: the run settings plus the resulting score and timing.
    out_df = pd.DataFrame(run_dict, index=[0]).assign(rmse=error_4d, da_time=da_time)

    all_results_df_list_4d.append(out_df)
    all_statevecs_4d.append(out_sv)

In [14]:
# Write csv
# Stack the one-row result frames (one per grid size) into a single table,
# show it, and save it.
full_out_df_4d = pd.concat(all_results_df_list_4d)
print(full_out_df_4d)
full_out_df_4d.to_csv('./out/pyqg_jax/pyqg_4dvar_results_test_v4_3outer_sysdimxy_all.csv')

In [15]:
# Write statevecs: one pickle file per grid size
for i in range(full_out_df_4d.shape[0]):
    sysdimxy = full_out_df_4d['system_dim_xy'].values[i]
    out_file = './out/pyqg_jax/pyqg_4dvar_results_sysdimxy_{}_3outer_v4.pkl'.format(sysdimxy)
    out_vec = all_statevecs_4d[i]

    # The with-block closes the file, so no explicit close() is needed.
    with open(out_file, 'wb') as f:
        pickle.dump(out_vec, f)

### Backprop-4DVar: Approx hessian, 3 iters

In [13]:
# Backprop-4DVar (approximate Hessian) on the test period, one run per grid
# size, using the best learning-rate settings found by Ray Tune.
num_iters = 3
system_dim_xy_list = [16, 24, 32]
test_run = True
all_results_df_list_bp = []
all_statevecs_bp = []
all_losses_bp = []

for system_dim_xy in system_dim_xy_list:

    # Each grid size is seeded with its own size.
    random_seed = system_dim_xy

    # Look up the tuned optimiser settings for this grid size.
    is_this_dim = best_results_system_dim['system_dim_xy'] == system_dim_xy
    raytune_results = best_results_system_dim.loc[is_this_dim]
    lr = raytune_results['config/lr'].values[0]
    lr_decay = raytune_results['config/lr_decay'].values[0]

    run_dict = {
        'system_dim_xy': system_dim_xy,
        'nr_steps': nr_steps,
        'spinup_size': spinup_size,
        'valid_size': valid_size,
        'test_size': test_size,
        'test_run': test_run,
        'delta_t': delta_t,
        'sigma_bg_multiplier': 0.5,
        'sigma_obs_multiplier': 1.25,
        'analysis_window': analysis_window,
        'analysis_time_in_window': analysis_time_in_window,
        'random_seed': random_seed,
        'num_iters': num_iters,
        'learning_rate': lr,
        'lr_decay': lr_decay,
    }

    error_bp, da_time, out_sv, nr = run_backprop_4dvar(**run_dict)

    # One-row summary: the run settings plus the resulting score and timing.
    out_df = pd.DataFrame(run_dict, index=[0]).assign(rmse=error_bp, da_time=da_time)
    all_results_df_list_bp.append(out_df)
    all_statevecs_bp.append(out_sv)

2024-04-23 16:00:47.531586: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:00:47.572784: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:00:47.613974: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none


Initial condition not set. Start with random IC.


2024-04-23 16:00:48.134879: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:00:48.453977: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:00:48.489364: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:00:48.933868: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:00:48.976229: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:00:48.979162: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:01:22.854453: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:01:26.512400: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-23 16:01:39.391428: E external/o

(8192, 8192)
float64
[[8706.85376882 8112.3725353  8083.54870652]
 [8123.02825321 7659.12954144 7637.29871796]
 [8054.45971577 7606.223307   7586.15061114]
 [8119.79059862 7674.82630802 7654.74147019]
 [8138.87392013 7671.56240651 7650.52627788]
 [8116.43660138 7684.27389402 7664.97779847]
 [8230.64942185 7781.16643517 7760.61611543]
 [8284.7205033  7827.6981503  7806.83298094]
 [8263.09431792 7789.29053418 7766.67952608]
 [8102.82467296 7649.90036716 7628.70191163]
 [8178.88898891 7739.94441641 7720.21366087]
 [8263.68438628 7791.54069614 7769.65155523]
 [8141.76586592 7683.13207673 7662.60497206]
 [8266.26724526 7782.06830598 7760.32412993]
 [8293.6902023  7828.77044042 7807.91759905]
 [8213.81675037 7727.21584246 7704.97001326]
 [8264.36397552 7772.57513431 7749.23813162]
 [8437.67387061 7944.76060867 7922.68906907]
 [8348.53753043 7872.50567186 7850.03578748]
 [8368.64326154 7883.10989434 7859.93284118]
 [8296.65942487 7803.46060039 7780.83369934]
 [8339.38671855 7849.39705915 7826

In [14]:
# Stack the one-row result frames (one per grid size) into a single table
# and save it.
full_out_df_bp = pd.concat(all_results_df_list_bp)
full_out_df_bp.to_csv('./out/pyqg_jax/pyqg_bp_results_test_v8_hessian_approx_3epoch.csv')

In [15]:
# Write statevecs: one pickle file per grid size
for i in range(full_out_df_bp.shape[0]):
    sysdimxy = full_out_df_bp['system_dim_xy'].values[i]
    out_file = './out/pyqg_jax/pyqg_bp_results_sysdimxy_{}_3epoch_hessian_approx_v8.pkl'.format(sysdimxy)
    out_vec = all_statevecs_bp[i]

    # Binary mode is required for pickle; the with-block closes the file,
    # so no explicit close() is needed afterwards.
    with open(out_file, 'wb') as f:
        pickle.dump(out_vec, f)  # serialize the state vector

### Backprop-4DVar: Exact Hessian, 3 iters

In [23]:
# Backprop-4DVar (exact Hessian) on the test period, one run per grid size,
# with a fixed learning rate of 1 and no decay.
# NOTE: the accumulator names below are the same ones the approximate-Hessian
# section uses, so running this cell replaces those results in the kernel.
num_iters = 3
system_dim_xy_list = [16, 24, 32]
test_run = True
all_results_df_list_bp = []
all_statevecs_bp = []
all_losses_bp = []

for system_dim_xy in system_dim_xy_list:
    # Each grid size is seeded with its own size.
    random_seed = system_dim_xy
    lr = 1.0
    lr_decay = 1.0

    run_dict = {
        'system_dim_xy': system_dim_xy,
        'nr_steps': nr_steps,
        'spinup_size': spinup_size,
        'valid_size': valid_size,
        'test_size': test_size,
        'test_run': test_run,
        'delta_t': delta_t,
        'sigma_bg_multiplier': 0.5,
        'sigma_obs_multiplier': 1.25,
        'analysis_window': analysis_window,
        'analysis_time_in_window': analysis_time_in_window,
        'random_seed': random_seed,
        'num_iters': num_iters,
        'learning_rate': lr,
        'lr_decay': lr_decay,
    }

    error_bp, da_time, out_sv, nr = run_backprop_4dvar_exact_hessian(**run_dict)

    # One-row summary: the run settings plus the resulting score and timing.
    out_df = pd.DataFrame(run_dict, index=[0]).assign(rmse=error_bp, da_time=da_time)

    all_results_df_list_bp.append(out_df)
    all_statevecs_bp.append(out_sv)


2024-04-18 22:00:09.861479: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:00:09.883884: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:00:09.906283: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none


Initial condition not set. Start with random IC.


2024-04-18 22:00:10.309719: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:00:10.533243: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:00:10.558670: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:00:10.859391: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:00:10.881004: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:00:10.881940: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:00:14.486677: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:00:14.659838: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:00:15.727946: E external/o

[[499.70457884 460.6505436  460.65053852]
 [556.86785065 516.41138025 516.4113744 ]
 [503.78349967 471.65423206 471.65423057]
 [496.81478534 470.22728029 470.2272792 ]
 [490.56735776 464.4775804  464.47757923]
 [523.20147885 490.42153133 490.4215295 ]
 [481.21216642 447.74750858 447.74750619]
 [507.25197971 481.29090786 481.29090658]
 [456.69031399 428.19454677 428.19454413]
 [461.24866837 430.25006223 430.25006037]
 [518.75713563 486.13757616 486.13757343]
 [524.34610924 491.51137069 491.51136813]
 [461.77028504 434.84543654 434.84543569]
 [454.00790843 426.33871858 426.33871627]
 [477.66486769 450.22352656 450.22352523]
 [448.53234161 421.99033605 421.99033485]
 [509.8157372  481.0448819  481.04488121]
 [507.44636457 477.38828543 477.38828383]
 [544.21535942 512.88715144 512.88714864]
 [548.52998388 521.09031183 521.09030952]
 [470.49877976 440.44120535 440.44120349]
 [519.49200934 487.52217324 487.52216963]
 [536.10480889 502.56162166 502.56161988]
 [482.16355873 455.02714686 455.02

2024-04-18 22:04:38.320685: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:04:38.353239: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:04:38.388500: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none


Initial condition not set. Start with random IC.


2024-04-18 22:04:38.832330: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:04:39.062427: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:04:39.096297: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:04:39.484091: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:04:39.522111: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:04:39.523208: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:04:45.526315: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:04:45.921380: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:04:47.035177: E external/o

[[1249.60632189 1154.02185152 1154.02184582]
 [1234.73012742 1160.74742037 1160.74741463]
 [1198.89531881 1130.07776314 1130.07776045]
 [1135.41207415 1065.40881967 1065.40881656]
 [1172.51682808 1096.48661647 1096.48661406]
 [1092.91031835 1028.83830229 1028.83829888]
 [1134.50849782 1064.52935417 1064.529352  ]
 [1113.94125714 1043.45712025 1043.4571181 ]
 [1126.79726883 1063.07779073 1063.07778869]
 [1127.28483786 1062.78611018 1062.78610821]
 [1088.76970365 1024.48416634 1024.4841623 ]
 [1050.23845363  991.37390039  991.37389733]
 [1096.95366019 1032.72828258 1032.72828025]
 [1078.32236752 1015.59506058 1015.59505909]
 [1122.91311967 1053.61431326 1053.61431073]
 [1144.15338355 1077.1956134  1077.19560764]
 [1112.08524173 1046.76237782 1046.76237477]
 [1195.38830286 1119.06379754 1119.06379325]
 [1127.54764057 1065.26991264 1065.26991073]
 [1091.88554396 1031.66580871 1031.66580754]
 [1101.49727081 1032.53297893 1032.53297521]
 [1128.23950359 1058.75652911 1058.75652533]
 [1055.587

2024-04-18 22:24:27.417076: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:24:27.458174: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:24:27.499201: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none


Initial condition not set. Start with random IC.


2024-04-18 22:24:28.002435: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:24:28.226849: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:24:28.250999: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:24:28.558638: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:24:28.580298: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:24:28.581239: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:24:37.604131: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:24:38.265128: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:461] fastpath_data is none
2024-04-18 22:24:40.048953: E external/o

[[2197.84394749 2021.10042308 2021.1004159 ]
 [1992.68583221 1866.34692403 1866.34692104]
 [1940.10642346 1830.92505776 1830.92505546]
 [1913.45564699 1809.5990573  1809.59905539]
 [1945.67057868 1836.77227087 1836.77226815]
 [1991.21301871 1883.38393037 1883.38392779]
 [2061.74582092 1942.68488474 1942.68488009]
 [1973.82400842 1855.90949028 1855.9094875 ]
 [2011.01229146 1891.47057301 1891.47056932]
 [1976.88410319 1862.76202165 1862.76201886]
 [2023.88659238 1910.23492811 1910.2349254 ]
 [2075.29119243 1949.72488126 1949.72487764]
 [1978.9255286  1864.4963815  1864.49637847]
 [1940.3577523  1824.44856092 1824.44855891]
 [2030.95283828 1903.19264668 1903.19264221]
 [2064.12771559 1950.89135585 1950.89135354]
 [2034.17770014 1928.12644495 1928.12644274]
 [2001.99750784 1893.29923325 1893.29923034]
 [1957.19826401 1841.71305995 1841.71305789]
 [1916.02323924 1812.58270003 1812.58269814]
 [1964.66665464 1864.90268672 1864.90268515]
 [1914.98547877 1808.55003966 1808.55003645]
 [1952.183

In [24]:
# Stack the one-row result frames (one per grid size) into a single table
# and save it.
# NOTE: all_results_df_list_bp is also the accumulator name used by the
# approximate-Hessian section; this writes whichever run filled it last.
full_out_df_bp = pd.concat(all_results_df_list_bp)
full_out_df_bp.to_csv('./out/pyqg_jax/pyqg_bp_results_test_v8_hessian_approx_3epoch_exact_hessian_noraytune_1_1.csv')

In [25]:
# Write statevecs: one pickle file per grid size
for i in range(full_out_df_bp.shape[0]):
    sysdimxy = full_out_df_bp['system_dim_xy'].values[i]
    out_file = './out/pyqg_jax/pyqg_bp_results_sysdimxy_{}_3epoch_hessian_exact_v8.pkl'.format(sysdimxy)
    out_vec = all_statevecs_bp[i]

    # The with-block closes the file, so no explicit close() is needed.
    with open(out_file, 'wb') as f:
        pickle.dump(out_vec, f)