In [None]:
import dabench as dab
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
from timeit import default_timer as timer
import pandas as pd
import pickle

# Read in Ray Tune results

In [None]:
raytune_system_dim_results = pd.read_csv('./pyqg_jax_raytune_sgdopt_werrors_v10_10epochs_system_dims_16_20_24_32.csv')
# Keep the original row label as an explicit trial number, then use a clean
# 0..n-1 index so the row labels returned by idxmin can go straight into .loc.
raytune_system_dim_results['trialnum'] = raytune_system_dim_results.index
raytune_system_dim_results.index = np.arange(raytune_system_dim_results.shape[0])
# Row label of the lowest-RMSE trial for each grid size. Only the 'rmse' column
# is reduced, so all-NA groups in unrelated numeric columns cannot trigger
# idxmin warnings/errors.
rows_to_get = raytune_system_dim_results.groupby('system_dim_xy')['rmse'].idxmin()
best_results_system_dim = raytune_system_dim_results.loc[rows_to_get]

In [None]:
# Display the best (lowest-RMSE) Ray Tune trial for each grid size.
best_results_system_dim

# Define experiment parameters

In [None]:
# Nature-run segment lengths, all counted in model timesteps.
# 4380 * 7200 = 31,536,000 = 365 days, assuming delta_t (set below) is in seconds.
year_in_timesteps = 4380
spinup_size = year_in_timesteps * 5           # five years of spinup
valid_size = round(year_in_timesteps / 4)     # a quarter year for validation
transient_size = test_size = year_in_timesteps  # one year each

In [None]:
delta_t = 7200
# The nature run must cover every segment defined in the previous cell.
nr_steps = sum((spinup_size, valid_size, transient_size, test_size))
analysis_window = delta_t * 6          # six model steps per assimilation window
analysis_time_in_window = delta_t * 3  # analysis time is the window midpoint
num_epochs = 10     # training epochs for Backprop-4DVar
n_outer_loops = 3   # outer loops for classic 4D-Var

# Function for running 4DVar

In [None]:
def run_4dvar(system_dim_xy, nr_steps, spinup_size, valid_size, test_size, test_run, delta_t,
              sigma_bg_multiplier, sigma_obs_multiplier, analysis_window, analysis_time_in_window,
              random_seed, n_outer_loops, transient_size=None, obs_sd_scale=0.1):
    """Run a classic (outer-loop) 4D-Var experiment on the PyQG-JAX model.

    A nature run of ``nr_steps`` steps is generated and split into spinup /
    validation / transient / test segments. Noisy observations are drawn from
    the validation segment (``test_run=False``) or the test segment
    (``test_run=True``) every 3rd timestep at a fixed random half of the state
    locations. A perturbed initial state is then cycled through
    ``dab.dacycler.Var4D``.

    Args:
        system_dim_xy: Grid size in x and y (square grid).
        nr_steps: Total nature-run length, in timesteps.
        spinup_size, valid_size, test_size: Segment lengths, in timesteps.
        test_run: If True evaluate on the test segment, else on validation.
        delta_t: Model timestep, used for both the nature run and the
            forecast model.
        sigma_bg_multiplier: Background error SD as a multiple of the
            observation error SD.
        sigma_obs_multiplier: Observation error SD assumed in R, as a multiple
            of the SD actually used to perturb the observations.
        analysis_window, analysis_time_in_window: Forwarded to ``dc.cycle``.
        random_seed: Seed for the nature run, forecast model, observer and
            initial-state perturbation.
        n_outer_loops: Number of 4D-Var outer loops.
        transient_size: Length of the segment between validation and test.
            If None it is inferred as
            ``nr_steps - spinup_size - valid_size - test_size``.
        obs_sd_scale: Observation error SD as a fraction of the per-variable
            SD over the spinup segment.

    Returns:
        ``(rmse, da_time, out_statevec, losses)``: analysis RMSE against the
        nature run, wall-clock DA time in seconds, and the state vector and
        losses returned by ``dc.cycle``.

    Raises:
        ValueError: If the segment sizes exceed ``nr_steps``.
    """
    if transient_size is None:
        # Derive from the arguments instead of silently reading a notebook global.
        transient_size = nr_steps - spinup_size - valid_size - test_size
    if transient_size < 0:
        raise ValueError(
            f'Segment sizes exceed nr_steps={nr_steps}: spinup={spinup_size}, '
            f'valid={valid_size}, test={test_size}, transient={transient_size}')

    np_rng = np.random.default_rng(random_seed)
    # Reset JAX backend state between successive runs in the same kernel.
    # NOTE(review): jax.clear_backends is deprecated in recent JAX releases; confirm
    # against the pinned JAX version.
    jax.clear_backends()

    ### Nature Run
    nature_run = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)
    nature_run.generate(n_steps=nr_steps)
    nr_spinup, nr_valid, nr_transient_and_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, transient_size + test_size)
    _, nr_test, _ = nr_transient_and_test.split_train_valid_test(transient_size, test_size, 0)
    nr_eval = nr_test if test_run else nr_valid

    ### Observations
    obs_location_count = round(nature_run.system_dim / 2)
    per_variable_sd = np.std(nr_spinup.values, axis=0)
    obs_sd = obs_sd_scale * per_variable_sd

    obs_pyqg = dab.observer.Observer(
        nr_eval,
        time_indices=np.arange(0, nr_eval.time_dim, 3),  # observe every 3rd timestep
        random_location_count=obs_location_count,  # number of randomly chosen observed locations
        error_bias=0.0,  # mean of the Gaussian observation error
        error_sd=obs_sd,  # per-variable SD of the Gaussian observation error
        random_seed=random_seed + test_run,  # offset so validation and test draw different obs
        stationary_observers=True,
        store_as_jax=True,
    )
    obs_vec_pyqg = obs_pyqg.observe()

    ### Forecast model (same grid and timestep as the nature run)
    model_pyqg = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    class PyQGModel(dab.model.Model):
        """Forecast wrapper around a PyQGJax generator for the DA cycler."""

        def forecast(self, state_vec, n_steps):
            """Run the model from a flat state; wrap ``model_obj.values`` as a StateVector."""
            gridded_values = state_vec.values.reshape(self.model_obj.original_dim)
            self.model_obj.generate(x0=gridded_values, n_steps=n_steps)
            return dab.vector.StateVector(values=self.model_obj.values, store_as_jax=True)

        def _forecast_x0(self, x0, n_steps):
            """Pure-array forecast from a flat x0, differentiable by jax."""
            self.model_obj.generate(x0=x0.reshape(self.model_obj.original_dim), n_steps=n_steps)
            return self.model_obj.values

        def compute_tlm(self, state_vec, n_steps):
            """Return (Jacobian of the forecast w.r.t. x0, the forecast itself)."""
            x0 = state_vec.values
            return jax.jacrev(self._forecast_x0, argnums=0)(x0, n_steps), self._forecast_x0(x0, n_steps)

    fc_model = PyQGModel(model_obj=model_pyqg)

    ### DA matrices
    sigma_obs = sigma_obs_multiplier * obs_sd[obs_vec_pyqg.location_indices[0]]
    sigma_bg = sigma_bg_multiplier * obs_sd
    # H selects the observed locations from the full state.
    H = np.zeros((obs_location_count, nature_run.system_dim))
    H[np.arange(obs_location_count), obs_vec_pyqg.location_indices[0]] = 1
    # Broadcasting a vector against the identity puts sigma_i**2 on the diagonal.
    R = (sigma_obs ** 2) * np.identity(obs_location_count)
    B = (sigma_bg ** 2) * np.identity(nature_run.system_dim)

    da_time_start = timer()

    dc = dab.dacycler.Var4D(
        system_dim=nature_run.system_dim,
        delta_t=nr_eval.delta_t,
        H=H,
        B=B,
        R=R,
        n_outer_loops=n_outer_loops,
        model_obj=fc_model,
        obs_window_indices=[0, 3, 6],
        steps_per_window=7,  # window covers steps 0..6 inclusive
    )

    ### Execute from a background-perturbed copy of the true initial state
    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(nature_run.system_dim,), scale=sigma_bg)
    x0_sv = dab.vector.StateVector(values=x0_original, store_as_jax=True)

    out_statevec, losses = dc.cycle(
        input_state=x0_sv,
        start_time=nr_eval.times[cur_tstep],
        obs_vector=obs_vec_pyqg,
        analysis_window=analysis_window,
        # NOTE(review): presumably one cycle per 6-step window, dropping the last
        # two windows; confirm.
        timesteps=int(nr_eval.time_dim / 6) - 2,
        obs_error_sd=sigma_obs,
        analysis_time_in_window=analysis_time_in_window)

    da_time = timer() - da_time_start

    # Score only the timesteps the analysis actually covers.
    n_analysed = out_statevec.values.shape[0]
    rmse = np.sqrt(np.mean(np.square(nr_eval.values[:n_analysed] - out_statevec.values)))

    return rmse, da_time, out_statevec, losses

# Function for running Backprop-4DVar

In [None]:
def run_backprop_4dvar(system_dim_xy, nr_steps, spinup_size, valid_size, test_size, test_run, delta_t,
                       sigma_bg_multiplier, sigma_obs_multiplier, analysis_window, analysis_time_in_window,
                       random_seed, num_epochs, learning_rate, lr_decay, transient_size=None,
                       obs_sd_scale=0.1):
    """Run a Backprop-4DVar experiment (``dab.dacycler.Var4DBackprop``) on PyQG-JAX.

    Setup (nature run, observations, perturbed initial state) mirrors the
    classic 4D-Var runner. The analysis is instead obtained by gradient-based
    optimisation with ``num_epochs`` / ``learning_rate`` / ``lr_decay``.

    Args:
        system_dim_xy: Grid size in x and y (square grid).
        nr_steps: Total nature-run length, in timesteps.
        spinup_size, valid_size, test_size: Segment lengths, in timesteps.
        test_run: If True evaluate on the test segment, else on validation.
        delta_t: Model timestep, used for both the nature run and the
            forecast model.
        sigma_bg_multiplier: Background error SD as a multiple of the
            observation error SD.
        sigma_obs_multiplier: Observation error SD assumed in R, as a multiple
            of the SD actually used to perturb the observations.
        analysis_window, analysis_time_in_window: Forwarded to ``dc.cycle``.
        random_seed: Seed for the nature run, forecast model, observer and
            initial-state perturbation.
        num_epochs, learning_rate, lr_decay: Optimiser settings forwarded to
            ``Var4DBackprop``.
        transient_size: Length of the segment between validation and test.
            If None it is inferred as
            ``nr_steps - spinup_size - valid_size - test_size``.
        obs_sd_scale: Observation error SD as a fraction of the per-variable
            SD over the spinup segment.

    Returns:
        ``(rmse, da_time, out_statevec, losses)``: analysis RMSE against the
        nature run, wall-clock DA time in seconds, and the state vector and
        losses returned by ``dc.cycle``.

    Raises:
        ValueError: If the segment sizes exceed ``nr_steps``.
    """
    if transient_size is None:
        # Derive from the arguments instead of silently reading a notebook global.
        transient_size = nr_steps - spinup_size - valid_size - test_size
    if transient_size < 0:
        raise ValueError(
            f'Segment sizes exceed nr_steps={nr_steps}: spinup={spinup_size}, '
            f'valid={valid_size}, test={test_size}, transient={transient_size}')

    np_rng = np.random.default_rng(random_seed)
    # Reset JAX backend state between successive runs in the same kernel.
    # NOTE(review): jax.clear_backends is deprecated in recent JAX releases; confirm
    # against the pinned JAX version.
    jax.clear_backends()

    ### Nature Run
    nature_run = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)
    nature_run.generate(n_steps=nr_steps)
    nr_spinup, nr_valid, nr_transient_and_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, transient_size + test_size)
    _, nr_test, _ = nr_transient_and_test.split_train_valid_test(transient_size, test_size, 0)
    nr_eval = nr_test if test_run else nr_valid

    ### Observations
    # Observation steps inside one analysis window (steps 0..6 inclusive).
    obs_window_indices = [0, 3, 6]
    obs_location_count = round(nature_run.system_dim / 2)
    per_variable_sd = np.std(nr_spinup.values, axis=0)
    obs_sd = obs_sd_scale * per_variable_sd

    obs_pyqg = dab.observer.Observer(
        nr_eval,
        time_indices=np.arange(0, nr_eval.time_dim, 3),  # observe every 3rd timestep
        random_location_count=obs_location_count,  # number of randomly chosen observed locations
        error_bias=0.0,  # mean of the Gaussian observation error
        error_sd=obs_sd,  # per-variable SD of the Gaussian observation error
        random_seed=random_seed + test_run,  # offset so validation and test draw different obs
        stationary_observers=True,
        store_as_jax=True,
    )
    obs_vec_pyqg = obs_pyqg.observe()

    ### Forecast model (same grid and timestep as the nature run)
    model_pyqg = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    class PyQGModel(dab.model.Model):
        """Forecast wrapper around a PyQGJax generator for the DA cycler."""

        def forecast(self, state_vec, n_steps):
            """Run the model from a flat state; wrap ``model_obj.values`` as a StateVector."""
            gridded_values = state_vec.values.reshape(self.model_obj.original_dim)
            self.model_obj.generate(x0=gridded_values, n_steps=n_steps)
            return dab.vector.StateVector(values=self.model_obj.values, store_as_jax=True)

        def _forecast_x0(self, x0, n_steps):
            """Pure-array forecast from a flat x0, differentiable by jax."""
            self.model_obj.generate(x0=x0.reshape(self.model_obj.original_dim), n_steps=n_steps)
            return self.model_obj.values

        def compute_tlm(self, state_vec, n_steps):
            """Return (Jacobian of the forecast w.r.t. x0, the forecast itself)."""
            x0 = state_vec.values
            return jax.jacrev(self._forecast_x0, argnums=0)(x0, n_steps), self._forecast_x0(x0, n_steps)

    fc_model = PyQGModel(model_obj=model_pyqg)

    ### DA matrices
    # R presumably covers every observation time in a window at once, so the
    # per-location obs SDs are repeated once per entry of obs_window_indices.
    n_obs_times = len(obs_window_indices)
    obs_total_size = obs_location_count * n_obs_times
    sigma_obs = sigma_obs_multiplier * np.tile(obs_sd[obs_vec_pyqg.location_indices[0]], n_obs_times)
    sigma_bg = sigma_bg_multiplier * obs_sd
    # H selects the observed locations from the full state (one observation time).
    H = np.zeros((obs_location_count, nature_run.system_dim))
    H[np.arange(obs_location_count), obs_vec_pyqg.location_indices[0]] = 1
    # Broadcasting a vector against the identity puts sigma_i**2 on the diagonal.
    R = (sigma_obs ** 2) * np.identity(obs_total_size)
    B = (sigma_bg ** 2) * np.identity(nature_run.system_dim)

    ### DA cycler
    da_time_start = timer()

    dc = dab.dacycler.Var4DBackprop(
        system_dim=nature_run.system_dim,
        delta_t=nr_eval.delta_t,
        H=H,
        B=B,
        R=R,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        lr_decay=lr_decay,
        model_obj=fc_model,
        obs_window_indices=obs_window_indices,
        steps_per_window=7,  # window covers steps 0..6 inclusive
    )

    ### Execute from a background-perturbed copy of the true initial state
    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(nature_run.system_dim,), scale=sigma_bg)
    x0_sv = dab.vector.StateVector(values=x0_original, store_as_jax=True)

    out_statevec, losses = dc.cycle(
        input_state=x0_sv,
        start_time=nr_eval.times[cur_tstep],
        obs_vector=obs_vec_pyqg,
        analysis_window=analysis_window,
        # NOTE(review): presumably one cycle per 6-step window, dropping the last
        # two windows; confirm.
        timesteps=int(nr_eval.time_dim / 6) - 2,
        obs_error_sd=sigma_obs,
        analysis_time_in_window=analysis_time_in_window)

    da_time = timer() - da_time_start

    # Score only the timesteps the analysis actually covers.
    n_analysed = out_statevec.values.shape[0]
    rmse = np.sqrt(np.mean(np.square(nr_eval.values[:n_analysed] - out_statevec.values)))

    return rmse, da_time, out_statevec, losses

# Function for the baseline run (free model run, no DA)

In [None]:
def run_baserun(system_dim_xy, nr_steps, spinup_size, valid_size, test_size, test_run, delta_t, random_seed,
                sigma_bg_multiplier, transient_size=None, obs_sd_scale=0.1):
    """Integrate PyQG-JAX freely (no DA) from a perturbed initial state.

    This is the no-assimilation baseline. The nature run and the
    initial-state perturbation are built the same way as in the DA runners,
    but the perturbed state is simply forecast over the evaluation segment.

    Args:
        system_dim_xy: Grid size in x and y (square grid).
        nr_steps: Total nature-run length, in timesteps.
        spinup_size, valid_size, test_size: Segment lengths, in timesteps.
        test_run: If True evaluate on the test segment, else on validation.
        delta_t: Model timestep, used for both the nature run and the free run.
        random_seed: Seed for the nature run, model and initial perturbation.
        sigma_bg_multiplier: Initial-perturbation SD as a multiple of the
            observation error SD.
        transient_size: Length of the segment between validation and test.
            If None it is inferred as
            ``nr_steps - spinup_size - valid_size - test_size``.
        obs_sd_scale: Observation error SD as a fraction of the per-variable
            SD over the spinup segment.

    Returns:
        ``(rmse, da_time)``:
        - rmse: RMSE of the free run against the nature run, excluding the
          last 12 steps.
        - da_time: wall-clock seconds for the forecast and RMSE computation.
          It is named to match the DA runners' output.

    Raises:
        ValueError: If the segment sizes exceed ``nr_steps``.
    """
    if transient_size is None:
        # Derive from the arguments instead of silently reading a notebook global.
        transient_size = nr_steps - spinup_size - valid_size - test_size
    if transient_size < 0:
        raise ValueError(
            f'Segment sizes exceed nr_steps={nr_steps}: spinup={spinup_size}, '
            f'valid={valid_size}, test={test_size}, transient={transient_size}')

    np_rng = np.random.default_rng(random_seed)
    # Reset JAX backend state between successive runs in the same kernel.
    # NOTE(review): jax.clear_backends is deprecated in recent JAX releases; confirm
    # against the pinned JAX version.
    jax.clear_backends()

    ### Nature Run
    nature_run = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)
    nature_run.generate(n_steps=nr_steps)
    nr_spinup, nr_valid, nr_transient_and_test = nature_run.split_train_valid_test(
        spinup_size, valid_size, transient_size + test_size)
    _, nr_test, _ = nr_transient_and_test.split_train_valid_test(transient_size, test_size, 0)
    nr_eval = nr_test if test_run else nr_valid

    # Perturbation size uses the same scaling as the DA runners' background error.
    per_variable_sd = np.std(nr_spinup.values, axis=0)
    obs_sd = obs_sd_scale * per_variable_sd
    sigma_bg = sigma_bg_multiplier * obs_sd

    ### Free-running model (same grid and timestep as the nature run)
    model_pyqg = dab.data.PyQGJax(nx=system_dim_xy, ny=system_dim_xy, delta_t=delta_t,
                                  store_as_jax=True, random_seed=random_seed)

    ### Forecast from a perturbed copy of the true initial state
    da_time_start = timer()

    cur_tstep = 0
    x0_original = nr_eval.values[cur_tstep] + np_rng.normal(size=(nature_run.system_dim,), scale=sigma_bg)
    x0_gridded = x0_original.reshape(nature_run.original_dim)

    model_pyqg.generate(x0=x0_gridded, n_steps=nr_eval.time_dim)
    # NOTE(review): the last 12 steps are excluded, presumably to match the DA runs,
    # which stop two 6-step windows before the end; confirm.
    rmse = np.sqrt(np.mean(np.square(model_pyqg.values[:-12] - nr_eval.values[:-12])))

    da_time = timer() - da_time_start

    return rmse, da_time

# Run DA for test period

### Baserun

In [None]:
all_results_df_list_baserun = []
system_dim_xy_list = [16, 20, 24, 32]
test_run = True
# One no-DA baseline per grid size; the grid size doubles as the random seed.
for system_dim_xy in system_dim_xy_list:
    random_seed = system_dim_xy
    run_dict = {
        'system_dim_xy': system_dim_xy,
        'nr_steps': nr_steps,
        'spinup_size': spinup_size,
        'valid_size': valid_size,
        'test_size': test_size,
        'test_run': test_run,
        'delta_t': delta_t,
        'sigma_bg_multiplier': 0.5,
        'random_seed': random_seed,
    }
    error_br, da_time = run_baserun(**run_dict)
    # One summary row: the run configuration plus its score and runtime.
    out_df = pd.DataFrame(run_dict, index=[0]).assign(rmse=error_br, da_time=da_time)
    print(error_br)
    print(da_time)
    all_results_df_list_baserun.append(out_df)


In [None]:
# Stack the per-grid-size baseline rows into one table, show it, and save it.
baserun_csv_path = './out/pyqg_jax/pyqg_baserun_results_test_v1.csv'
full_out_df_baserun = pd.concat(all_results_df_list_baserun)
print(full_out_df_baserun)
full_out_df_baserun.to_csv(baserun_csv_path)

### 4Dvar

In [None]:
all_results_df_list_4d = []
all_statevecs_4d = []
system_dim_xy_list = [16, 20, 24, 32]
test_run = True
# One classic 4D-Var experiment per grid size; the grid size doubles as the seed.
for system_dim_xy in system_dim_xy_list:
    random_seed = system_dim_xy
    run_dict = {
        'system_dim_xy': system_dim_xy,
        'nr_steps': nr_steps,
        'spinup_size': spinup_size,
        'valid_size': valid_size,
        'test_size': test_size,
        'test_run': test_run,
        'delta_t': delta_t,
        'sigma_bg_multiplier': 0.5,
        'sigma_obs_multiplier': 1.25,
        'analysis_window': analysis_window,
        'analysis_time_in_window': analysis_time_in_window,
        'random_seed': random_seed,
        'n_outer_loops': n_outer_loops,
    }
    error_4d, da_time, out_sv, loss_vals = run_4dvar(**run_dict)
    # One summary row: the run configuration plus its score and runtime.
    out_df = pd.DataFrame(run_dict, index=[0]).assign(rmse=error_4d, da_time=da_time)
    print(error_4d)
    print(da_time)
    all_results_df_list_4d.append(out_df)
    all_statevecs_4d.append(out_sv)
    


In [None]:
# Stack the per-grid-size 4D-Var rows into one table, show it, and save it.
var4d_csv_path = './out/pyqg_jax/pyqg_4dvar_results_test_v2_sysdimxy_all.csv'
full_out_df_4d = pd.concat(all_results_df_list_4d)
print(full_out_df_4d)
full_out_df_4d.to_csv(var4d_csv_path)

In [None]:
# Save each 4D-Var analysis StateVector to its own pickle, named by grid size.
# The results table and the state-vector list are appended together in the run
# loop, so their entries must line up one-to-one.
assert len(all_statevecs_4d) == full_out_df_4d.shape[0], \
    'results table and saved state vectors are out of sync'
for sysdimxy, out_vec in zip(full_out_df_4d['system_dim_xy'].values, all_statevecs_4d):
    out_file = './out/pyqg_jax/pyqg_4dvar_results_sysdimxy_{}.pkl'.format(sysdimxy)
    # pickle needs a binary-mode file; the with-block closes it on exit.
    with open(out_file, 'wb') as f:
        pickle.dump(out_vec, f)

In [None]:
# Gather 4D-Var summaries (rmse, da_time, grid size) from separate earlier runs.
# NOTE(review): the system_dim_xy=16 numbers are hard-coded rather than read from a
# results file; record where they came from.
full_rmse_df_list = []
summary_cols = ['rmse', 'da_time', 'system_dim_xy']
full_df_20_24 = pd.read_csv('./out/pyqg_jax/pyqg_4dvar_results_test_v2_sysdimxy_20_24.csv')
full_df_32 = pd.read_csv('./out/pyqg_jax/pyqg_4dvar_results_test_v2_sysdimxy_32_fullyear.csv')
temp_df_16 = pd.DataFrame({
    'rmse': [2.401127956217721e-7],
    'da_time': [2827.1763451290026],
    'system_dim_xy': [16],
})
full_rmse_df_list += [
    temp_df_16,
    full_df_20_24.loc[:, summary_cols],
    full_df_32.loc[:, summary_cols],
]

In [None]:
# Combined 4D-Var summary across all grid sizes.
combined_rmse_df_4d = pd.concat(full_rmse_df_list)
combined_rmse_df_4d.to_csv('./out/pyqg_jax/pyqg_4dvar_results_test_v2_sysdimxy_all_fullyear.csv')

### Backprop-4DVar

In [None]:
all_results_df_list_bp = []
all_statevecs_bp = []
all_losses_bp = []
system_dim_xy_list = [16, 20, 24, 32]
test_run = True
# One Backprop-4DVar experiment per grid size, using that grid size's best Ray Tune
# learning rate and decay; the grid size doubles as the random seed.
# `test_size` comes from the parameter cell and is not modified here: nr_steps
# was computed from it, so changing it in this loop would desync the data split.
for system_dim_xy in system_dim_xy_list:
    random_seed = system_dim_xy
    raytune_results = best_results_system_dim.loc[best_results_system_dim['system_dim_xy'] == system_dim_xy]
    if raytune_results.empty:
        raise ValueError(f'No Ray Tune result found for system_dim_xy={system_dim_xy}')
    lr = raytune_results['config/lr'].values[0]
    lr_decay = raytune_results['config/lr_decay'].values[0]
    run_dict = dict(
        system_dim_xy=system_dim_xy,
        nr_steps=nr_steps,
        spinup_size=spinup_size,
        valid_size=valid_size,
        test_size=test_size,
        test_run=test_run,
        delta_t=delta_t,
        sigma_bg_multiplier=0.5,
        sigma_obs_multiplier=1.25,
        analysis_window=analysis_window,
        analysis_time_in_window=analysis_time_in_window,
        random_seed=random_seed,
        num_epochs=num_epochs,
        learning_rate=lr,
        lr_decay=lr_decay)
    error_bp, da_time, out_sv, out_losses = run_backprop_4dvar(**run_dict)
    out_df = pd.DataFrame(run_dict, index=[0])
    out_df['rmse'] = error_bp
    out_df['da_time'] = da_time
    all_results_df_list_bp.append(out_df)
    all_statevecs_bp.append(out_sv)
    all_losses_bp.append(out_losses)


In [None]:
# Stack the per-grid-size Backprop-4DVar rows into one table and save it.
# NOTE(review): the file name says "25epochs" and "dim32fullyear", but the parameter
# cell sets num_epochs = 10 and this run covers every size in system_dim_xy_list;
# confirm the intended file name.
bp_csv_path = './out/pyqg_jax/pyqg_4dvar_results_test_v5_25epochs_dim32fullyear.csv'
full_out_df_bp = pd.concat(all_results_df_list_bp)
full_out_df_bp.to_csv(bp_csv_path)

In [None]:
# Display the combined Backprop-4DVar results table.
full_out_df_bp