In [17]:
!hostname

nid004215


In [18]:
!pwd

/global/cfs/cdirs/m4334/jerry/climsim3_dev


# Import packages

In [3]:
import xarray as xr
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.cm import ScalarMappable
from matplotlib import gridspec
import matplotlib.lines as mlines
from matplotlib.transforms import blended_transform_factory
import os, gc, sys, glob, string, argparse
from tqdm import tqdm
import time
import string
import itertools
import sys
import pickle
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from climsim_utils.data_utils import *

# Set font sizes

In [4]:
# Notebook-wide font sizes for titles, axis labels, ticks, and legends.
_notebook_font_sizes = {
    "axes.titlesize":   14,
    "axes.labelsize":   13,
    "figure.titlesize": 17,
    "xtick.labelsize":  11,
    "ytick.labelsize":  11,
    "legend.fontsize":  12,
}
plt.rcParams.update(_notebook_font_sizes)

def scale_default(param_name, scale_factor):
    """Return the current matplotlib rcParams value for `param_name` multiplied by `scale_factor`."""
    base_value = plt.rcParams[param_name]
    return base_value * scale_factor

# Load grid info, normalization statistics, test data, and run configuration

In [5]:
# Every grid/normalization file lives under the ClimSim3 development tree on CFS.
_climsim3_dev_root = '/global/cfs/cdirs/m4334/jerry/climsim3_dev/'
_norm_inputs_dir = _climsim3_dev_root + 'preprocessing/normalizations/inputs/'
_norm_outputs_dir = _climsim3_dev_root + 'preprocessing/normalizations/outputs/'

grid_path = _climsim3_dev_root + 'grid_info/ClimSim_low-res_grid-info.nc'

# Normalization files for the v2_rh_mc input variable list.
input_mean_v2_rh_mc_file = 'input_mean_v2_rh_mc_pervar.nc'
input_max_v2_rh_mc_file = 'input_max_v2_rh_mc_pervar.nc'
input_min_v2_rh_mc_file = 'input_min_v2_rh_mc_pervar.nc'
output_scale_v2_rh_mc_file = 'output_scale_std_lowerthred_v2_rh_mc.nc'

# Normalization files for the v6 (expanded) input variable list.
input_mean_v6_file = 'input_mean_v6_pervar.nc'
input_max_v6_file = 'input_max_v6_pervar.nc'
input_min_v6_file = 'input_min_v6_pervar.nc'
output_scale_v6_file = 'output_scale_std_lowerthred_v6.nc'

lbd_qn_file = 'qn_exp_lambda_large.txt'

# Column areas normalized to sum to 1 serve as global area weights.
grid_info = xr.open_dataset(grid_path)
grid_area = grid_info['area'].values
area_weight = grid_area / grid_area.sum()
level = grid_info.lev.values

input_mean_v2_rh_mc = xr.open_dataset(_norm_inputs_dir + input_mean_v2_rh_mc_file)
input_max_v2_rh_mc = xr.open_dataset(_norm_inputs_dir + input_max_v2_rh_mc_file)
input_min_v2_rh_mc = xr.open_dataset(_norm_inputs_dir + input_min_v2_rh_mc_file)
output_scale_v2_rh_mc = xr.open_dataset(_norm_outputs_dir + output_scale_v2_rh_mc_file)

input_mean_v6 = xr.open_dataset(_norm_inputs_dir + input_mean_v6_file)
input_max_v6 = xr.open_dataset(_norm_inputs_dir + input_max_v6_file)
input_min_v6 = xr.open_dataset(_norm_inputs_dir + input_min_v6_file)
output_scale_v6 = xr.open_dataset(_norm_outputs_dir + output_scale_v6_file)

lbd_qn = np.loadtxt(_norm_inputs_dir + lbd_qn_file, delimiter = ',')

# Both data_utils instances share the grid and the flags qinput_log=False / normalize=False;
# only the normalization statistics and the selected variable list differ.
_shared_data_utils_kwargs = dict(grid_info = grid_info,
                                 qinput_log = False,
                                 normalize = False)

data_v2_rh_mc = data_utils(input_mean = input_mean_v2_rh_mc,
                           input_max = input_max_v2_rh_mc,
                           input_min = input_min_v2_rh_mc,
                           output_scale = output_scale_v2_rh_mc,
                           **_shared_data_utils_kwargs)
data_v2_rh_mc.set_to_v2_rh_mc_vars()

data_v6 = data_utils(input_mean = input_mean_v6,
                     input_max = input_max_v6,
                     input_min = input_min_v6,
                     output_scale = output_scale_v6,
                     **_shared_data_utils_kwargs)
data_v6.set_to_v6_vars()

# Test-set arrays for both input variable lists.
_ne4_preprocessing_dir = '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/'
_v2_rh_mc_test_dir = _ne4_preprocessing_dir + 'v2_rh_mc/test_set/'
_v6_test_dir = _ne4_preprocessing_dir + 'v6/test_set/'

actual_input_v2_rh_mc = np.load(_v2_rh_mc_test_dir + 'actual_input.npy')
actual_target_v2_rh_mc = np.load(_v2_rh_mc_test_dir + 'actual_target.npy')

actual_input_v6 = np.load(_v6_test_dir + 'actual_input.npy')
actual_target_v6 = np.load(_v6_test_dir + 'actual_target.npy')

# Both preprocessings must produce identical targets. Once that is checked, the two
# copies are freed and the reference targets come from original_test_vars instead.
assert np.array_equal(actual_target_v2_rh_mc, actual_target_v6)
actual_target = np.load(_v2_rh_mc_test_dir + 'original_test_vars/actual_target_v2.npy')
del actual_target_v2_rh_mc, actual_target_v6

# Surface pressure must agree between the two input sets (pressures below come from v2_rh_mc).
assert np.array_equal(actual_input_v2_rh_mc[:, :, data_v2_rh_mc.ps_index],
                      actual_input_v6[:, :, data_v6.ps_index])

_grid = data_v2_rh_mc

# Mid-layer pressure on the hybrid grid: p = hyam * p0 + hybm * ps, broadcast over the
# first two axes of the input array and the level axis.
surface_pressure = actual_input_v2_rh_mc[:, :, _grid.ps_index]
hyam_component = (_grid.hyam * _grid.p0)[np.newaxis, np.newaxis, :]
hybm_component = _grid.hybm[np.newaxis, np.newaxis, :] * surface_pressure[:, :, np.newaxis]
pressures = hyam_component + hybm_component
pressures_binned = _grid.zonal_bin_weight_3d(pressures)

lat_bin_mids, lats, lons = _grid.lat_bin_mids, _grid.lats, _grid.lons

# The .npy file holds model-level indices, presumably shaped (time, ncol). Per its name, each is
# the first level meeting a "p400/t10" criterion -- TODO confirm against the script that wrote it.
# Each index is mapped to its level value from grid_info.lev. Vectorized indexing yields a float
# array. The former element-wise write back into the loaded array truncated the level values
# whenever that array had an integer dtype.
_first_true_level_idx = np.load('/pscratch/sd/z/zeyuanhu/hu_etal2024_data/microphysics_hourly/first_true_indices_p400_t10.npy')
assert np.all(np.isfinite(_first_true_level_idx)), 'first_true_indices_p400_t10 contains non-finite indices'
idx_p400_t10 = level[_first_true_level_idx.astype(int)]

# Mean over the leading (presumably time) axis, kept 2-D with a leading length-1 axis
# for zonal_bin_weight_2d.
idx_p400_t10 = idx_p400_t10.mean(axis=0)[np.newaxis, :]

# Zonally binned result. Despite the "idx" names, these are level values, not indices.
idx_tropopause_zm = data_v2_rh_mc.zonal_bin_weight_2d(idx_p400_t10).flatten()

# Area weights zeroed outside each latitude band. Bands use strict inequalities, so a column
# sitting exactly at +/-30 deg belongs to no regional band.
area_weight_dict = {
    'global': area_weight,
    'nh': np.where(lats > 30, area_weight, 0),
    'sh': np.where(lats < -30, area_weight, 0),
    'tropics': np.where((lats > -30) & (lats < 30), area_weight, 0)
}

# Boolean column masks with singleton leading/trailing axes (shape (1, ncol, 1) if lats is 1-D),
# so they broadcast against 3-D arrays laid out like `pressures`.
lat_idx_dict = {
    '30S_30N': ((data_v2_rh_mc.lats < 30) & (data_v2_rh_mc.lats > -30))[None,:,None],
    '30N_60N': ((data_v2_rh_mc.lats < 60) & (data_v2_rh_mc.lats > 30))[None,:,None],
    '30S_60S': ((data_v2_rh_mc.lats < -30) & (data_v2_rh_mc.lats > -60))[None,:,None],
    '60N_90N': (data_v2_rh_mc.lats > 60)[None,:,None],
    '60S_90S': (data_v2_rh_mc.lats < -60)[None,:,None]
}

# `pressures` = hyam * p0 + hybm * PS carries the units of PS/p0. That is Pa, consistent with the
# sum(cld * dp) / 9.81 column integral used in this notebook. The 400 hPa split therefore has to
# be written in Pa; a bare 400 would split at 4 hPa.
pressure_split_pa = 400.0 * 100.0
pressure_idx_dict = {
    'below_400hPa': pressures >= pressure_split_pa,
    'above_400hPa': pressures < pressure_split_pa
}

# Display names for each training configuration. The keys are the configuration identifiers used
# elsewhere in this notebook (online_paths, sypd_dict, pc_dict, the read_* helpers).
config_names = {
    'standard': 'Standard',
    'conf_loss': 'Confidence Loss',
    'diff_loss': 'Difference Loss',
    'multirep': 'Multirepresentation',
    'v6': 'Expanded Variable List'
}

# Display names for each architecture. Iteration order matters: the per-run SYPD /
# parameter-count lists further down are consumed in this order.
model_names = {
    'unet': 'U-Net',
    'squeezeformer': 'Squeezeformer',
    'pure_resLSTM': 'Pure ResLSTM',
    'pao_model': 'Pao Model',
    'convnext': 'ConvNeXt',
    'encdec_lstm': 'Encoder-Decoder LSTM'
}

# Plot colour per architecture.
color_dict = {
    'unet': 'green',
    'squeezeformer': 'purple',
    'pure_resLSTM': 'blue',
    'pao_model': 'red',
    'convnext': 'gold',
    'encdec_lstm': 'orange',
}

# Plot colour per training configuration.
color_dict_config = {
    'standard': 'blue',
    'conf_loss': 'cyan',
    'diff_loss': 'red',
    'multirep': 'orange',
    'v6': 'green'
}

# Offline (test-set) plot settings per output variable:
#   var_title / unit - labels; scaling - multiplicative factor, presumably applied before plotting
#   (1e3 pairs with g/kg, 1e6 with mg/kg); vmin / vmax - colour limits (whether they refer to the
#   scaled or unscaled values is not visible here -- confirm); var_index - start column of the
#   variable in the target vector (the 60-column spacing suggests one 60-level profile each -- confirm).
offline_var_settings = {
    'DTPHYS': {'var_title': 'dT/dt', 'scaling': 1., 'unit': 'K/s', 'vmax': 5e-7, 'vmin': -5e-7, 'var_index':0},
    'DQ1PHYS': {'var_title': 'dQv/dt', 'scaling': 1e3, 'unit': 'g/kg/s', 'vmax': 1e-6, 'vmin': -1e-6, 'var_index':60},
    'DQ2PHYS': {'var_title': 'dQl/dt', 'scaling': 1e6, 'unit': 'mg/kg/s', 'vmax': 1e-3, 'vmin': -1e-3, 'var_index':120},
    'DQ3PHYS': {'var_title': 'dQi/dt', 'scaling': 1e6, 'unit': 'mg/kg/s', 'vmax': 1e-3, 'vmin': -1e-3, 'var_index':180},
    'DUPHYS': {'var_title': 'dU/dt', 'scaling': 1., 'unit': 'm/s/s', 'vmax': 5e-7, 'vmin': -5e-7, 'var_index':240},
    'DVPHYS': {'var_title': 'dV/dt', 'scaling': 1., 'unit': 'm/s/s', 'vmax': 5e-7, 'vmin': -5e-7, 'var_index':300}
}

# Online (coupled-run) plot settings per history variable; same fields as above minus var_index.
# TOTCLD and DQnPHYS are derived sums added by the read_* helpers. DVPHYS has no online entry.
online_var_settings = {
    'T': {'var_title': 'Temperature', 'scaling': 1.0, 'unit': 'K', 'vmax': 5, 'vmin': -5},
    'Q': {'var_title': 'Specific Humidity', 'scaling': 1000.0, 'unit': 'g/kg', 'vmax': 1, 'vmin': -1},
    'U': {'var_title': 'Zonal Wind', 'scaling': 1.0, 'unit': 'm/s', 'vmax': 4, 'vmin': -4},
    'V': {'var_title': 'Meridional Wind', 'scaling': 1.0, 'unit': 'm/s', 'vmax': 4, 'vmin': -4},
    'CLDLIQ': {'var_title': 'Liquid Cloud', 'scaling': 1e6, 'unit': 'mg/kg', 'vmax': 40, 'vmin': -40},
    'CLDICE': {'var_title': 'Ice Cloud', 'scaling': 1e6, 'unit': 'mg/kg', 'vmax': 5, 'vmin': -5},
    'TOTCLD': {'var_title': 'Total Cloud', 'scaling': 1e6, 'unit': 'mg/kg', 'vmax': 40, 'vmin': -40},
    'DTPHYS': {'var_title': 'Heating Tendency', 'scaling': 1., 'unit': 'K/s', 'vmax': 1.5e-5, 'vmin': -1.5e-5},
    'DQ1PHYS': {'var_title': 'Moistening Tendency', 'scaling': 1e3, 'unit': 'g/kg/s', 'vmax': 1.2e-5, 'vmin': -1.2e-5},
    'DQ2PHYS': {'var_title': 'Liquid Tendency', 'scaling': 1e6, 'unit': 'mg/kg/s', 'vmax': 0.0015, 'vmin': -0.0015},
    'DQ3PHYS': {'var_title': 'Ice Tendency', 'scaling': 1e6, 'unit': 'mg/kg/s', 'vmax': 0.0015, 'vmin': -0.0015},
    'DQnPHYS': {'var_title': 'Liquid + Ice Tendency', 'scaling': 1e6, 'unit': 'mg/kg/s', 'vmax': .0015, 'vmin': -.0015},
    'DUPHYS': {'var_title': 'Zonal Wind Tendency', 'scaling': 1., 'unit': 'm/s²', 'vmax': 2.2e-6, 'vmin': -2.2e-6}
}

# Each configuration's five-year coupled runs live in a sibling sub-directory of the ensemble root.
_ensemble_root = '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/online_runs/climsim3_ensembles_good/'
_config_subdirs = {
    'standard': 'standard',
    'conf_loss': 'conf',
    'diff_loss': 'diff',
    'multirep': 'multirep',
    'v6': 'v6',
}
online_paths = {config: f'{_ensemble_root}{subdir}/five_year_runs/' for config, subdir in _config_subdirs.items()}

# Training seeds, as integers and as "seed_<n>" labels (same order).
seed_numbers = [7, 43, 1024]
seeds = [f'seed_{number}' for number in seed_numbers]

_figures_root = '/global/cfs/cdirs/m4334/jerry/climsim3_figures'
climsim3_figures_save_path_offline = f'{_figures_root}/offline'
climsim3_figures_save_path_online = f'{_figures_root}/online'

# Per-run numbers for the online (coupled) simulations, one list per configuration.
# Each list has 18 entries in model-major, seed-minor order. The keying loop below walks
# model_names (unet, squeezeformer, pure_resLSTM, pao_model, convnext, encdec_lstm) and, for each
# model, seed_numbers (7, 43, 1024), consuming one entry per (model, seed).
# sypd_* : presumably simulated years per day of the coupled run -- confirm the source of these timings.
# pc_*   : presumably model parameter counts; each value repeats across a model's three seeds.
sypd_standard = [
    # NOTE(review): unet seeds 7 and 43 report identical SYPD -- confirm this is not a copy-paste.
    17.97901035,
    17.97901035,
    18.05589744,
    18.90966668,
    19.03178614,
    18.65975565,
    17.80983045,
    18.08638973,
    17.93128756,
    21.78946853,
    21.9847479,
    22.20038272,
    25.36398855,
    25.19805864,
    25.19227595,
    23.54557187,
    23.82794497,
    23.447502
]

sypd_conf_loss = [
    17.50453936,
    17.29971515,
    17.44196107,
    18.20260208,
    18.19053688,
    18.04921777,
    16.88591639,
    17.06639861,
    17.10229289,
    20.55525559,
    20.63385549,
    20.42521955,
    23.65465549,
    23.3813296,
    23.25625602,
    22.31773163,
    22.25777422,
    21.88176458
]

sypd_diff_loss =[
    16.97469344,
    16.92105472,
    17.16379509,
    18.99145235,
    19.24026684,
    18.69869139,
    17.75509967,
    16.87358759,
    16.66731672,
    20.77602886,
    19.70212994,
    19.22005471,
    21.00260056,
    # NOTE(review): convnext seeds 43 and 1024 report identical SYPD -- confirm this is not a copy-paste.
    25.47730606,
    25.47730606,
    24.0893883,
    23.99985625,
    23.78792837
]

sypd_multirep = [
    18.33255552,
    18.35861568,
    18.3448099,
    19.00295884,
    16.93296157,
    18.19204416,
    17.23723679,
    17.27385405,
    17.27657263,
    18.10913991,
    19.86397164,
    21.35810934,
    23.51060144,
    23.60014809,
    24.52549157,
    23.19360711,
    23.43123639,
    23.23164752
]

sypd_v6 = [
    18.21284028,
    18.31251632,
    18.41758393,
    # NOTE(review): the squeezeformer v6 runs are ~2-3x slower than in every other configuration -- confirm.
    6.894533848,
    8.95130655,
    7.890410959,
    17.84446786,
    17.79133537,
    17.73373234,
    21.39020294,
    17.90628317,
    21.40600448,
    24.47883654,
    24.13656852,
    24.3122588,
    23.34100196,
    23.5582043,
    23.5169076
]

pc_standard = [
    12975373,
    12975373,
    12975373,
    44785225,
    44785225,
    44785225,
    15395341,
    15395341,
    15395341,
    18876133,
    18876133,
    18876133,
    26805429,
    26805429,
    26805429,
    18582976,
    18582976,
    18582976
]

pc_conf_loss = [
    12980634,
    12980634,
    12980634,
    44811862,
    44811862,
    44811862,
    15402010,
    15402010,
    15402010,
    25407346,
    25407346,
    25407346,
    26839242,
    26839242,
    26839242,
    20723124,
    20723124,
    20723124
]

# Identical to pc_standard entry-for-entry (presumably the same architectures, only the loss differs).
pc_diff_loss = [
    12975373,
    12975373,
    12975373,
    44785225,
    44785225,
    44785225,
    15395341,
    15395341,
    15395341,
    18876133,
    18876133,
    18876133,
    26805429,
    26805429,
    26805429,
    18582976,
    18582976,
    18582976
]

pc_multirep = [
    12981517,
    12981517,
    12981517,
    44791369,
    44791369,
    44791369,
    15428109,
    15428109,
    15428109,
    18880613,
    18880613,
    18880613,
    26811573,
    26811573,
    26811573,
    21004960,
    21004960,
    21004960
]

# Matches pc_multirep except for pao_model (18880373 vs 18880613) -- confirm.
pc_v6 = [
    12981517,
    12981517,
    12981517,
    44791369,
    44791369,
    44791369,
    15428109,
    15428109,
    15428109,
    18880373,
    18880373,
    18880373,
    26811573,
    26811573,
    26811573,
    21004960,
    21004960,
    21004960
]

# Run keys "<model>_<seed>" in model-major, seed-minor order. This is the order in which the
# per-configuration SYPD / parameter-count lists above are laid out.
run_keys = [f"{model_name}_{seed_number}" for model_name in model_names.keys() for seed_number in seed_numbers]

def _key_by_run(values, list_name):
    """Pair `values` one-to-one with `run_keys`.

    Raises ValueError on a length mismatch. Previously, a manual counter silently ignored extra
    entries and raised a bare IndexError on short lists.
    """
    if len(values) != len(run_keys):
        raise ValueError(f"{list_name} has {len(values)} entries; expected {len(run_keys)} (one per model x seed)")
    return dict(zip(run_keys, values))

standard_sypd_dict = _key_by_run(sypd_standard, 'sypd_standard')
conf_loss_sypd_dict = _key_by_run(sypd_conf_loss, 'sypd_conf_loss')
diff_loss_sypd_dict = _key_by_run(sypd_diff_loss, 'sypd_diff_loss')
multirep_sypd_dict = _key_by_run(sypd_multirep, 'sypd_multirep')
v6_sypd_dict = _key_by_run(sypd_v6, 'sypd_v6')

standard_pc_dict = _key_by_run(pc_standard, 'pc_standard')
conf_loss_pc_dict = _key_by_run(pc_conf_loss, 'pc_conf_loss')
diff_loss_pc_dict = _key_by_run(pc_diff_loss, 'pc_diff_loss')
multirep_pc_dict = _key_by_run(pc_multirep, 'pc_multirep')
v6_pc_dict = _key_by_run(pc_v6, 'pc_v6')

# Nested lookup: config name -> run key -> value.
sypd_dict = {
    'standard': standard_sypd_dict,
    'conf_loss': conf_loss_sypd_dict,
    'diff_loss': diff_loss_sypd_dict,
    'multirep': multirep_sypd_dict,
    'v6': v6_sypd_dict
}
pc_dict = {
    'standard': standard_pc_dict,
    'conf_loss': conf_loss_pc_dict,
    'diff_loss': diff_loss_pc_dict,
    'multirep': multirep_pc_dict,
    'v6': v6_pc_dict
}

# Define helper functions

In [6]:
def ls(data_path = ""):
    """Return the stdout lines of the shell command `ls <data_path>`.

    The argument reaches /bin/sh unquoted, so glob patterns are expanded by the shell. A pattern
    with no matches yields [], because ls reports the error on stderr, which is not captured.
    Only pass trusted paths.
    """
    shell_command = "ls " + data_path
    with os.popen(shell_command) as listing:
        return listing.read().splitlines()

def offline_area_time_mean_3d(arr):
    """Zonally bin `arr` with data_v2_rh_mc, average over its leading axis, and label it.

    Returns a DataArray with dims ('hybrid pressure (hPa)', 'latitude'), using the module-level
    `level` and `lat_bin_mids` as coordinates.
    """
    mean_over_time = data_v2_rh_mc.zonal_bin_weight_3d(arr).mean(axis = 0)
    pressure_dim, lat_dim = 'hybrid pressure (hPa)', 'latitude'
    return xr.DataArray(mean_over_time.T,
                        dims = [pressure_dim, lat_dim],
                        coords = {pressure_dim: level, lat_dim: lat_bin_mids})

def online_area_time_mean_3d(ds, var):
    """Online counterpart of offline_area_time_mean_3d for history variable `ds[var]`.

    Drops the first record along axis 0 (reason not shown here -- presumably spin-up; confirm).
    It then swaps axes 1 and 2 so the level axis comes last, and defers to
    offline_area_time_mean_3d for binning, averaging, and labelling.
    """
    trimmed = ds[var].values[1:, :, :]
    return offline_area_time_mean_3d(np.transpose(trimmed, (0, 2, 1)))

def area_mean(ds, var):
    """Zonally bin `ds[var]` (per record) with data_v2_rh_mc after swapping axes 1 and 2.

    Despite the name, no time averaging is done; the result keeps the leading axis.
    """
    swapped = ds[var].values.transpose(0, 2, 1)
    return data_v2_rh_mc.zonal_bin_weight_3d(swapped)

def zonal_diff(ds_sp, ds_nn, var):
    """Time-mean of (NN run minus reference run) zonal bins of `var`, as a (level, lat) DataArray."""
    nn_minus_sp = area_mean(ds_nn, var) - area_mean(ds_sp, var)
    mean_diff = nn_minus_sp.mean(axis = 0)
    return xr.DataArray(mean_diff.T,
                        dims = ['level', 'lat'],
                        coords = {'level': level, 'lat': lat_bin_mids})

def get_dp(ds, n_interfaces = 61):
    """Return layer pressure thicknesses between consecutive hybrid interfaces.

    Interface pressures are hyai * P0 + hybi * PS, in the units of P0/PS. Depending on how `ds`
    was opened, the broadcast result is either interface-first or already time-first. The first
    case is detected by its leading axis having length `n_interfaces` and is swapped to put time
    first. Note that a time-first array with exactly `n_interfaces` records is mis-detected the
    same way.

    Args:
        ds: mapping/dataset exposing 'hyai', 'hybi', 'P0', 'PS' with xarray-style `.values`.
        n_interfaces: number of interface levels (61 for the 60-layer grid; default unchanged).

    Returns:
        Array of shape (time, n_interfaces - 1, ncol) with p[k+1] - p[k].
    """
    p_interface = (ds['hyai'] * ds['P0'] + ds['hybi'] * ds['PS']).values
    if p_interface.shape[0] == n_interfaces:
        p_interface = np.swapaxes(p_interface, 0, 1)
    # Same as p[:, 1:n] - p[:, 0:n-1] for the first n interfaces.
    return np.diff(p_interface[:, :n_interfaces, :], axis = 1)

def get_tcp_mean(ds, area_weight):
    """Area-weighted mean, per record, of the column integral sum(TOTCLD * dp) / 9.81.

    `area_weight` (shadowing the module-level array) weights columns along axis 1 of the
    integrated field.
    """
    column_cloud = (ds['TOTCLD'].values * get_dp(ds)).sum(axis = 1) / 9.81
    return np.average(column_cloud, weights = area_weight, axis = 1)

def get_tcp_std(ds, area_weight):
    """Area-weighted standard deviation, per record, of the column integral sum(TOTCLD * dp) / 9.81."""
    cloud = ds['TOTCLD'].values
    layer_dp = get_dp(ds)
    column_cloud = (cloud * layer_dp).sum(axis = 1) / 9.81
    weighted_mean = np.average(column_cloud, weights = area_weight, axis = 1)
    weighted_var = np.average((column_cloud - weighted_mean[:, None]) ** 2, weights = area_weight, axis = 1)
    return np.sqrt(weighted_var)

def read_mmf_online_data(num_years):
    """Open the two MMF reference runs' h0 files for model years 0003 onward.

    `num_years` (1-5) selects years 0003..0003+num_years-1 through a glob character class.
    Each dataset gains derived DQnPHYS = DQ2PHYS + DQ3PHYS, TOTCLD = CLDICE + CLDLIQ, and
    PRECT = PRECC + PRECL.

    Returns:
        (ds_mmf_1, ds_mmf_2) from the mmf_ref and mmf_b directories respectively.
    """
    assert 1 <= num_years <= 5
    year_digits = '34567'[:num_years]
    h0_root = '/pscratch/sd/z/zeyuanhu/hu_etal2024_data_v2/data/h0/5year/'
    ds_mmf_1 = xr.open_mfdataset(f'{h0_root}mmf_ref/control_fullysp_jan_wmlio_r3.eam.h0.000[{year_digits}]*.nc')
    ds_mmf_2 = xr.open_mfdataset(f'{h0_root}mmf_b/control_fullysp_jan_wmlio_r3_b.eam.h0.000[{year_digits}]*.nc')
    for ds_mmf in (ds_mmf_1, ds_mmf_2):
        ds_mmf['DQnPHYS'] = ds_mmf['DQ2PHYS'] + ds_mmf['DQ3PHYS']
        ds_mmf['TOTCLD'] = ds_mmf['CLDICE'] + ds_mmf['CLDLIQ']
        ds_mmf['PRECT'] = ds_mmf['PRECC'] + ds_mmf['PRECL']
    return ds_mmf_1, ds_mmf_2

def read_nn_online_data(config_name, model_name, seed, num_years):
    """Open the h0 history of one NN-coupled online run, or return None if it is missing/incomplete.

    Case directories are named "<model_name><suffix>_seed_<seed>" under online_paths[config_name].
    The standard configuration has no suffix. `num_years` (1-5) selects model years
    0003..0003+num_years-1 via a glob character class.

    Returns:
        The dataset with derived DQnPHYS, TOTCLD, and PRECT added. Returns None when no file matches
        or fewer than 12 * num_years time records are present.

    Raises:
        AssertionError: if num_years is outside 1..5.
        ValueError: for an unknown config_name (previously an UnboundLocalError on extract_path).
    """
    assert num_years <= 5 and num_years >= 1
    run_suffixes = {'standard': '', 'conf_loss': '_conf', 'diff_loss': '_diff', 'multirep': '_multirep', 'v6': '_v6'}
    if config_name not in run_suffixes:
        raise ValueError(f"Unknown config_name {config_name!r}; expected one of {list(run_suffixes)}")
    years_regexp = '34567'[:num_years]
    case_name = f'{model_name}{run_suffixes[config_name]}_seed_{seed}'
    extract_path = os.path.join(online_paths[config_name], case_name, 'run', f'{case_name}.eam.h0.000[{years_regexp}]*.nc')
    if len(ls(extract_path)) == 0:
        return None
    ds_nn = xr.open_mfdataset(extract_path)
    # Too few records (presumably monthly h0 output) means the run did not cover the requested years.
    if len(ds_nn['time']) < 12 * num_years:
        ds_nn.close()
        return None
    ds_nn['DQnPHYS'] = ds_nn['DQ2PHYS'] + ds_nn['DQ3PHYS']
    ds_nn['TOTCLD'] = ds_nn['CLDICE'] + ds_nn['CLDLIQ']
    ds_nn['PRECT'] = ds_nn['PRECC'] + ds_nn['PRECL']
    return ds_nn

def read_nn_online_precip_data(config_name, model_name, seed, num_years):
    """Return PRECT from one NN-coupled run's combined precipitation file, or None if missing/short.

    The file is <online_paths[config_name]>/<case>/run/precip_dir/combined_precip.nc, where
    <case> is "<model_name><suffix>_seed_<seed>" (no suffix for the standard configuration).
    `num_years` (1-5) only sets the minimum record count (365 * 24 * num_years, presumably hourly).
    The returned PRECT is not trimmed to num_years.

    Raises:
        AssertionError: if num_years is outside 1..5.
        ValueError: for an unknown config_name (previously an UnboundLocalError on extract_path).
    """
    assert num_years <= 5 and num_years >= 1
    run_suffixes = {'standard': '', 'conf_loss': '_conf', 'diff_loss': '_diff', 'multirep': '_multirep', 'v6': '_v6'}
    if config_name not in run_suffixes:
        raise ValueError(f"Unknown config_name {config_name!r}; expected one of {list(run_suffixes)}")
    case_name = f'{model_name}{run_suffixes[config_name]}_seed_{seed}'
    extract_path = os.path.join(online_paths[config_name], case_name, 'run', 'precip_dir', 'combined_precip.nc')
    if len(ls(extract_path)) == 0:
        return None
    ds_nn = xr.open_dataset(extract_path)
    if len(ds_nn['time']) < 365 * 24 * num_years:
        ds_nn.close()
        return None
    return ds_nn['PRECT']

def get_pressure_area_weights(ds, surface_type = None):
    """Return normalized averaging weights for online output.

    Args:
        ds: history dataset. With surface_type None it must satisfy get_dp(). Otherwise it must
            provide LANDFRAC, OCNFRAC, or ICEFRAC with `.values`.
        surface_type: None, 'land', 'ocean', or 'ice'.

    Returns:
        None: the mean over axis 0 of get_dp(ds) * area_weight, normalized to sum to 1.
        'land' / 'ocean' / 'ice': (fraction * grid_area) per (record, column), divided by that
            quantity's total over the column's latitude bin at the same record (0 where the bin total
            is 0). This assumes data_v2_rh_mc.lat_bin_indices gives each column's bin consistent with
            lat_bin_dict.

    Raises:
        ValueError: for any other surface_type.
    """
    if surface_type is None:
        ds_total_weight = (get_dp(ds) * area_weight[None, None, :]).mean(axis = 0)
        return ds_total_weight / ds_total_weight.sum()

    surface_fraction_vars = {'land': 'LANDFRAC', 'ocean': 'OCNFRAC', 'ice': 'ICEFRAC'}
    if surface_type not in surface_fraction_vars:
        raise ValueError("Invalid surface type. Choose from 'land', 'ocean', or 'ice'.")
    surface_area = ds[surface_fraction_vars[surface_type]].values * grid_area[None, :]
    return _lat_bin_normalized_area(surface_area)


def _lat_bin_normalized_area(surface_area):
    """Divide each (record, column) area by the total area of that column's latitude bin at that record."""
    lat_bin_dict = data_v2_rh_mc.lat_bin_dict
    # (records, n_bins): total area per latitude bin, in lat_bin_dict iteration order.
    bin_sums = np.stack([surface_area[:, lat_bin_dict[lat_bin]].sum(axis = 1) for lat_bin in lat_bin_dict.keys()], axis = 1)
    # Reciprocal bin totals; empty bins (total 0) get 0 instead of inf.
    bin_divs = np.divide(1, bin_sums, where = bin_sums != 0, out = np.zeros_like(bin_sums))
    # Expand per-bin reciprocals to per-column ones via each column's bin index.
    column_divs = bin_divs[:, np.asarray(data_v2_rh_mc.lat_bin_indices)]
    return surface_area * column_divs

def get_offline_precip_area_weights(nn_preds, precc_index = 363):
    """Area weights for columns grouped by predicted precipitation intensity.

    Channel `precc_index` of `nn_preds` (indexed as [:, :, precc_index]) is scaled by 86400 * 1000
    (presumably m/s -> mm/day). Exactly-zero entries form 'no_precip'. The remaining entries define
    decile edges, and 'active_precip_<lo>_<hi>_percentile' keeps entries in (previous edge, edge].
    The first bin's lower bound is 0, so strictly negative values never land in the 1-10 bin.
    Confirm that predictions are non-negative.

    Returns:
        dict (insertion order: no_precip, then deciles ascending) of arrays shaped like the 2-D
        precip field plus a trailing length-1 axis. Each array is the masked area_weight divided by
        its own total, or all zeros if the group is empty.
    """
    precc = nn_preds[:, :, precc_index] * 86400 * 1000
    row_area = area_weight[None, :]

    is_dry = precc == 0
    decile_edges = np.quantile(precc[~is_dry], [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])

    group_masks = {'no_precip': is_dry}
    lower_edge = 0
    for decile, upper_edge in enumerate(decile_edges, start = 1):
        low_pct = 1 if decile == 1 else 10 * (decile - 1)
        group_masks[f'active_precip_{low_pct}_{10 * decile}_percentile'] = (precc > lower_edge) & (precc <= upper_edge)
        lower_edge = upper_edge

    normalized = {}
    for group, mask in group_masks.items():
        masked_area = mask * row_area
        # Multiply by 1/total only where the area is non-zero; empty groups stay all zeros.
        inv_total = np.divide(np.ones_like(masked_area), np.sum(masked_area), out = np.zeros_like(masked_area), where = (masked_area != 0))
        normalized[group] = (masked_area * inv_total)[:, :, None]
    return normalized

# Legend labels for the groups produced by get_offline_precip_area_weights.
precip_percentile_labels = {
    'no_precip': 'No Precipitation',
    'active_precip_1_10_percentile': '1-10th percentile',
}
precip_percentile_labels.update(
    (f'active_precip_{low}_{low + 10}_percentile', f'{low}-{low + 10}th percentile')
    for low in range(10, 100, 10)
)

# Set up lazy loaders for offline predictions (arrays are read only when a loader is called with a seed key)

In [7]:
# v2_rh_mc-based configurations share one test_preds directory; v6 has its own tree.
_v2_rh_mc_test_preds_dir = '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/v2_rh_mc/test_set/test_preds/'
standard_save_path = _v2_rh_mc_test_preds_dir + 'standard/'
conf_loss_save_path = _v2_rh_mc_test_preds_dir + 'conf_loss/'
diff_loss_save_path = _v2_rh_mc_test_preds_dir + 'diff_loss/'
multirep_save_path = _v2_rh_mc_test_preds_dir + 'multirep/'
v6_save_path = '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/v6/test_set/test_preds/'

def load_seed_data(save_path, npz_file, seed_key):
    with np.load(os.path.join(save_path, npz_file)) as data:
        return data[seed_key]

print('loading standard preds')
standard_preds = {
    'unet': lambda seed_key: load_seed_data(standard_save_path, 'standard_unet_preds.npz', seed_key),
    'squeezeformer': lambda seed_key: load_seed_data(standard_save_path, 'standard_squeezeformer_preds.npz', seed_key),
    'pure_resLSTM': lambda seed_key: load_seed_data(standard_save_path, 'standard_pure_resLSTM_preds.npz', seed_key),
    'pao_model': lambda seed_key: load_seed_data(standard_save_path, 'standard_pao_model_preds.npz', seed_key),
    'convnext': lambda seed_key: load_seed_data(standard_save_path, 'standard_convnext_preds.npz', seed_key),
    'encdec_lstm': lambda seed_key: load_seed_data(standard_save_path, 'standard_encdec_lstm_preds.npz', seed_key)
}

print('loading conf loss preds')
conf_loss_preds = {
    'unet': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_unet_preds.npz', seed_key),
    'squeezeformer': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_squeezeformer_preds.npz', seed_key),
    'pure_resLSTM': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_pure_resLSTM_preds.npz', seed_key),
    'pao_model': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_pao_model_preds.npz', seed_key),
    'convnext': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_convnext_preds.npz', seed_key),
    'encdec_lstm': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_encdec_lstm_preds.npz', seed_key)
}

print('loading conf loss conf')
conf_loss_conf = {
    'unet': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_unet_conf.npz', seed_key),
    'squeezeformer': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_squeezeformer_conf.npz', seed_key),
    'pure_resLSTM': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_pure_resLSTM_conf.npz', seed_key),
    'pao_model': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_pao_model_conf.npz', seed_key),
    'convnext': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_convnext_conf.npz', seed_key),
    'encdec_lstm': lambda seed_key: load_seed_data(conf_loss_save_path, 'conf_loss_encdec_lstm_conf.npz', seed_key)
}

print('loading diff loss preds')
diff_loss_preds = {
    'unet': lambda seed_key: load_seed_data(diff_loss_save_path, 'diff_loss_unet_preds.npz', seed_key),
    'squeezeformer': lambda seed_key: load_seed_data(diff_loss_save_path, 'diff_loss_squeezeformer_preds.npz', seed_key),
    'pure_resLSTM': lambda seed_key: load_seed_data(diff_loss_save_path, 'diff_loss_pure_resLSTM_preds.npz', seed_key),
    'pao_model': lambda seed_key: load_seed_data(diff_loss_save_path, 'diff_loss_pao_model_preds.npz', seed_key),
    'convnext': lambda seed_key: load_seed_data(diff_loss_save_path, 'diff_loss_convnext_preds.npz', seed_key),
    'encdec_lstm': lambda seed_key: load_seed_data(diff_loss_save_path, 'diff_loss_encdec_lstm_preds.npz', seed_key)
}

print('loading multirep preds')
multirep_preds = {
    'unet': lambda seed_key: load_seed_data(multirep_save_path, 'multirep_unet_preds.npz', seed_key),
    'squeezeformer': lambda seed_key: load_seed_data(multirep_save_path, 'multirep_squeezeformer_preds.npz', seed_key),
    'pure_resLSTM': lambda seed_key: load_seed_data(multirep_save_path, 'multirep_pure_resLSTM_preds.npz', seed_key),
    'pao_model': lambda seed_key: load_seed_data(multirep_save_path, 'multirep_pao_model_preds.npz', seed_key),
    'convnext': lambda seed_key: load_seed_data(multirep_save_path, 'multirep_convnext_preds.npz', seed_key),
    'encdec_lstm': lambda seed_key: load_seed_data(multirep_save_path, 'multirep_encdec_lstm_preds.npz', seed_key)
}

print('loading v6 preds')
v6_preds = {
    'unet': lambda seed_key: load_seed_data(v6_save_path, 'v6_unet_preds.npz', seed_key),
    'squeezeformer': lambda seed_key: load_seed_data(v6_save_path, 'v6_squeezeformer_preds.npz', seed_key),
    'pure_resLSTM': lambda seed_key: load_seed_data(v6_save_path, 'v6_pure_resLSTM_preds.npz', seed_key),
    'pao_model': lambda seed_key: load_seed_data(v6_save_path, 'v6_pao_model_preds.npz', seed_key),
    'convnext': lambda seed_key: load_seed_data(v6_save_path, 'v6_convnext_preds.npz', seed_key),
    'encdec_lstm': lambda seed_key: load_seed_data(v6_save_path, 'v6_encdec_lstm_preds.npz', seed_key)
}

config_preds = {
    'standard': standard_preds,
    'conf_loss': conf_loss_preds,
    'diff_loss': diff_loss_preds,
    'multirep': multirep_preds,
    'v6': v6_preds
}

print('loading Kaggle preds')
kaggle_check_path = '/pscratch/sd/j/jerrylin/kaggle_check'

# Flat prediction rows are folded into (samples, num_latlon, 368); the captured
# output of this cell shows (4380, 384, 368) for both. Row ordering is assumed to
# be sample-major / column-minor — TODO confirm against the Kaggle export.
greysnow = np.load(os.path.join(kaggle_check_path, 'prds.npy')).reshape(
    -1, data_v2_rh_mc.num_latlon, 368
)

# The first parquet column is dropped before reshaping (presumably a sample id).
adam = (
    pd.read_parquet(os.path.join(kaggle_check_path, 'final_blend_v10.parquet'))
    .iloc[:, 1:]
    .values
    .reshape(-1, data_v2_rh_mc.num_latlon, 368)
)

print(greysnow.shape, adam.shape)

loading standard preds
loading conf loss preds
loading conf loss conf
loading diff loss preds
loading multirep preds
loading v6 preds
loading Kaggle preds
(4380, 384, 368) (4380, 384, 368)


# Load offline R2 scores

In [8]:
def _read_r2_pickle(save_path, file_name):
    """Unpickle one saved offline-R2 object from ``save_path/file_name``.

    NOTE(review): pickle.load can execute arbitrary code — only point this at
    trusted files.
    """
    pickle_path = os.path.join(save_path, file_name)
    with open(pickle_path, "rb") as handle:
        return pickle.load(handle)

standard_unet_r2 = _read_r2_pickle(standard_save_path, "standard_unet_r2.pkl")
standard_squeezeformer_r2 = _read_r2_pickle(standard_save_path, "standard_squeezeformer_r2.pkl")
standard_pure_resLSTM_r2 = _read_r2_pickle(standard_save_path, "standard_pure_resLSTM_r2.pkl")
standard_pao_model_r2 = _read_r2_pickle(standard_save_path, "standard_pao_model_r2.pkl")
standard_convnext_r2 = _read_r2_pickle(standard_save_path, "standard_convnext_r2.pkl")
standard_encdec_lstm_r2 = _read_r2_pickle(standard_save_path, "standard_encdec_lstm_r2.pkl")

conf_loss_unet_r2 = _read_r2_pickle(conf_loss_save_path, "conf_loss_unet_r2.pkl")
conf_loss_squeezeformer_r2 = _read_r2_pickle(conf_loss_save_path, "conf_loss_squeezeformer_r2.pkl")
conf_loss_pure_resLSTM_r2 = _read_r2_pickle(conf_loss_save_path, "conf_loss_pure_resLSTM_r2.pkl")
conf_loss_pao_model_r2 = _read_r2_pickle(conf_loss_save_path, "conf_loss_pao_model_r2.pkl")
conf_loss_convnext_r2 = _read_r2_pickle(conf_loss_save_path, "conf_loss_convnext_r2.pkl")
conf_loss_encdec_lstm_r2 = _read_r2_pickle(conf_loss_save_path, "conf_loss_encdec_lstm_r2.pkl")

diff_loss_unet_r2 = _read_r2_pickle(diff_loss_save_path, "diff_loss_unet_r2.pkl")
diff_loss_squeezeformer_r2 = _read_r2_pickle(diff_loss_save_path, "diff_loss_squeezeformer_r2.pkl")
diff_loss_pure_resLSTM_r2 = _read_r2_pickle(diff_loss_save_path, "diff_loss_pure_resLSTM_r2.pkl")
diff_loss_pao_model_r2 = _read_r2_pickle(diff_loss_save_path, "diff_loss_pao_model_r2.pkl")
diff_loss_convnext_r2 = _read_r2_pickle(diff_loss_save_path, "diff_loss_convnext_r2.pkl")
diff_loss_encdec_lstm_r2 = _read_r2_pickle(diff_loss_save_path, "diff_loss_encdec_lstm_r2.pkl")

multirep_unet_r2 = _read_r2_pickle(multirep_save_path, "multirep_unet_r2.pkl")
multirep_squeezeformer_r2 = _read_r2_pickle(multirep_save_path, "multirep_squeezeformer_r2.pkl")
multirep_pure_resLSTM_r2 = _read_r2_pickle(multirep_save_path, "multirep_pure_resLSTM_r2.pkl")
multirep_pao_model_r2 = _read_r2_pickle(multirep_save_path, "multirep_pao_model_r2.pkl")
multirep_convnext_r2 = _read_r2_pickle(multirep_save_path, "multirep_convnext_r2.pkl")
multirep_encdec_lstm_r2 = _read_r2_pickle(multirep_save_path, "multirep_encdec_lstm_r2.pkl")

v6_unet_r2 = _read_r2_pickle(v6_save_path, "v6_unet_r2.pkl")
v6_squeezeformer_r2 = _read_r2_pickle(v6_save_path, "v6_squeezeformer_r2.pkl")
v6_pure_resLSTM_r2 = _read_r2_pickle(v6_save_path, "v6_pure_resLSTM_r2.pkl")
v6_pao_model_r2 = _read_r2_pickle(v6_save_path, "v6_pao_model_r2.pkl")
v6_convnext_r2 = _read_r2_pickle(v6_save_path, "v6_convnext_r2.pkl")
v6_encdec_lstm_r2 = _read_r2_pickle(v6_save_path, "v6_encdec_lstm_r2.pkl")

# Load online runs

In [None]:
def calculate_rmse(ds1, ds2, total_weight, var, num_years):
    """Per-month weighted RMSE of ``var`` between a run and a reference run.

    Parameters
    ----------
    ds1 : dataset-like or falsy
        Run being evaluated (the NN-coupled runs elsewhere in this notebook).
        ``ds1[var]`` must have months on axis 0. If ``ds1`` is falsy (e.g.
        ``None`` or an empty dataset for a missing run), an all-NaN array is
        returned.
    ds2 : dataset-like
        Reference run. Must provide ``ds2[var]``; when ``total_weight`` is None
        it must also work with ``get_dp`` (defined elsewhere) and provide an
        ``'area'`` variable.
    total_weight : array-like or None
        3-D weights with months on axis 0, broadcastable against
        ``(ds1[var] - ds2[var]) ** 2`` and covering at least as many months as
        are compared. If None, computed as
        ``get_dp(ds2) * ds2['area'].values[:, None, :]`` (presumably pressure
        thickness times cell area — confirm in ``get_dp``).
    var : str
        Variable name; also the key into the global ``online_var_settings``
        whose ``'scaling'`` factor multiplies the result.
    num_years : int
        Output length in years (12 entries per year).

    Returns
    -------
    numpy.ndarray, shape (num_years * 12,)
        Scaled RMSE for each month present in ``ds1``; remaining months are NaN.
    """
    n_out = num_years * 12
    rmse_per_month = np.full(n_out, np.nan)
    if not ds1:
        return rmse_per_month

    # The caller-supplied weight used to be silently overwritten here, making the
    # argument dead. Honor it; only derive weights from the reference run when
    # none is given.
    if total_weight is None:
        total_weight = get_dp(ds2) * ds2['area'].values[:, None, :]

    # Never compare more months than the output can hold, so a run longer than
    # num_years no longer fails with a shape mismatch on assignment.
    num_months = min(ds1[var].shape[0], n_out)
    weights = total_weight[:num_months, :, :]

    # Weighted mean of squared differences over the two non-time axes.
    squared_diff = ((ds1[var] - ds2[var]) ** 2)[:num_months]
    weighted_sum = (squared_diff * weights).sum(axis=(1, 2))
    weight_sum = weights.sum(axis=(1, 2))
    rmse = np.sqrt(weighted_sum / weight_sum)

    # np.asarray handles both xarray results and plain numpy arrays.
    rmse_per_month[:num_months] = np.asarray(rmse)
    return rmse_per_month * online_var_settings[var]['scaling']

ds_mmf_1_5_year, ds_mmf_2_5_year = read_mmf_online_data(num_years = 5)
ds_mmf_1_4_year, ds_mmf_2_4_year = read_mmf_online_data(num_years = 4)
mmf_1_total_weight_5_year = get_pressure_area_weights(ds_mmf_1_5_year)
mmf_1_total_weight_4_year = get_pressure_area_weights(ds_mmf_1_4_year)

# MMF run 2 scored against MMF run 1 (the same reference the NN runs are scored
# against below), per online variable, for the 5- and 4-year horizons.
_online_rmse_vars = ('T', 'Q', 'U', 'V', 'CLDLIQ', 'CLDICE', 'DTPHYS', 'DQ1PHYS', 'DQ2PHYS', 'DQ3PHYS')

mmf_1_5_year_rmse = {
    var_name: calculate_rmse(ds_mmf_2_5_year, ds_mmf_1_5_year, mmf_1_total_weight_5_year, var=var_name, num_years=5)
    for var_name in _online_rmse_vars
}
mmf_1_4_year_rmse = {
    var_name: calculate_rmse(ds_mmf_2_4_year, ds_mmf_1_4_year, mmf_1_total_weight_4_year, var=var_name, num_years=4)
    for var_name in _online_rmse_vars
}


In [None]:
# 5-year online NN runs, nested as configuration -> model -> seed -> dataset.
# Runs are read in the same order as before: configuration, then model, then seed.
_ds_nn_by_config_5_year = {
    config_name: {
        model_name: {
            seed_number: read_nn_online_data(config_name, model_name, seed_number, num_years = 5)
            for seed_number in seed_numbers
        }
        for model_name in ('unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm')
    }
    for config_name in ('standard', 'conf_loss', 'diff_loss', 'multirep', 'v6')
}
ds_nn_standard_5_year = _ds_nn_by_config_5_year['standard']
ds_nn_conf_loss_5_year = _ds_nn_by_config_5_year['conf_loss']
ds_nn_diff_loss_5_year = _ds_nn_by_config_5_year['diff_loss']
ds_nn_multirep_5_year = _ds_nn_by_config_5_year['multirep']
ds_nn_v6_5_year = _ds_nn_by_config_5_year['v6']

In [None]:
# 4-year online NN runs, nested as configuration -> model -> seed -> dataset.
# Runs are read in the same order as before: configuration, then model, then seed.
_ds_nn_by_config_4_year = {
    config_name: {
        model_name: {
            seed_number: read_nn_online_data(config_name, model_name, seed_number, num_years = 4)
            for seed_number in seed_numbers
        }
        for model_name in ('unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm')
    }
    for config_name in ('standard', 'conf_loss', 'diff_loss', 'multirep', 'v6')
}
ds_nn_standard_4_year = _ds_nn_by_config_4_year['standard']
ds_nn_conf_loss_4_year = _ds_nn_by_config_4_year['conf_loss']
ds_nn_diff_loss_4_year = _ds_nn_by_config_4_year['diff_loss']
ds_nn_multirep_4_year = _ds_nn_by_config_4_year['multirep']
ds_nn_v6_4_year = _ds_nn_by_config_4_year['v6']

In [None]:
online_nn_data_save_path = '/global/cfs/cdirs/m4334/jerry/climsim3_figures/online/online_nn_data'
# Create the output directory if needed; otherwise the first dump fails on a fresh
# setup, after all the online runs above have already been read.
os.makedirs(online_nn_data_save_path, exist_ok=True)

# NOTE(review): if read_nn_online_data returns lazily-loaded (file-backed) xarray
# objects, pickling may store references to the source files rather than the data
# itself. Confirm before treating these pickles as a standalone cache.
for file_name, nn_runs in (
    ("ds_nn_standard_5_year.pkl", ds_nn_standard_5_year),
    ("ds_nn_conf_loss_5_year.pkl", ds_nn_conf_loss_5_year),
    ("ds_nn_diff_loss_5_year.pkl", ds_nn_diff_loss_5_year),
    ("ds_nn_multirep_5_year.pkl", ds_nn_multirep_5_year),
    ("ds_nn_v6_5_year.pkl", ds_nn_v6_5_year),
    ("ds_nn_standard_4_year.pkl", ds_nn_standard_4_year),
    ("ds_nn_conf_loss_4_year.pkl", ds_nn_conf_loss_4_year),
    ("ds_nn_diff_loss_4_year.pkl", ds_nn_diff_loss_4_year),
    ("ds_nn_multirep_4_year.pkl", ds_nn_multirep_4_year),
    ("ds_nn_v6_4_year.pkl", ds_nn_v6_4_year),
):
    with open(os.path.join(online_nn_data_save_path, file_name), "wb") as file:
        pickle.dump(nn_runs, file)

In [None]:
def _rmse_by_var_and_model(ds_nn_runs, ds_reference, reference_weight, num_years):
    """Build ``{var: {model_name: array(n_seeds, num_years*12)}}`` of monthly RMSE.

    Each seed's run ``ds_nn_runs[model_name][seed_number]`` is scored against
    ``ds_reference`` with ``calculate_rmse``. Iteration order is variable, then
    model (``model_names``), then seed (``seed_numbers``).
    """
    rmse_tables = {}
    for var_name in ('T', 'Q', 'U', 'V', 'CLDLIQ', 'CLDICE', 'DTPHYS', 'DQ1PHYS', 'DQ2PHYS', 'DQ3PHYS'):
        per_model = {}
        for model_name in model_names.keys():
            seed_rmses = [
                calculate_rmse(ds_nn_runs[model_name][seed_number], ds_reference, reference_weight,
                               var = var_name, num_years = num_years)
                for seed_number in seed_numbers
            ]
            per_model[model_name] = np.array(seed_rmses)
        rmse_tables[var_name] = per_model
    return rmse_tables

ds_nn_rmse_standard_5_year = _rmse_by_var_and_model(ds_nn_standard_5_year, ds_mmf_1_5_year, mmf_1_total_weight_5_year, 5)
ds_nn_rmse_conf_loss_5_year = _rmse_by_var_and_model(ds_nn_conf_loss_5_year, ds_mmf_1_5_year, mmf_1_total_weight_5_year, 5)
ds_nn_rmse_diff_loss_5_year = _rmse_by_var_and_model(ds_nn_diff_loss_5_year, ds_mmf_1_5_year, mmf_1_total_weight_5_year, 5)
ds_nn_rmse_multirep_5_year = _rmse_by_var_and_model(ds_nn_multirep_5_year, ds_mmf_1_5_year, mmf_1_total_weight_5_year, 5)
ds_nn_rmse_v6_4_year = {
    'T':{model_name: np.array([calculate_rmse(ds_nn_v6_4_year[model_name][seed_number], ds_mmf_1_4_year, mmf_1_total_weight_4_year, var = 'T', num_years = 4) for seed_number in seed_numbers]) for model_name in model_names.keys()},
    'Q':{model_name: np.array([calculate_rmse(ds_nn_v6_4_year[model_name][seed_number], ds_mmf_1_4_year, mmf_1_total_weight_4_year, var = 'Q', num_years = 4) for seed_number in seed_numbers]) for model_name in model_names.keys()},
    'U':{model_name: np.array([calculate_rmse(ds_nn_v6_4_year[model_name][seed_number], ds_mmf_1_4_year, mmf_1_total_weight_4_year, var = 'U', num_years = 4) for seed_number in seed_numbers]) for model_name in model_names.keys()},
    'V':{model_name: np.array([calculate_rmse(ds_nn_v6_4_year[model_name][seed_number], ds_mmf_1_4_year, mmf_1_total_weight_4_year, var = 'V', num_years = 4) for seed_number in seed_numbers]) for model_name in model_names.keys()},
    'CLDLIQ':{model_name: np.array([calculate_rmse(ds_nn_v6_4_year[model_name][seed_number], ds_mmf_1_4_year, mmf_1_total_weight_4_year, var = 'CLDLIQ', num_years = 4) for seed_number in seed_numbers]) for model_name in model_names.keys()},
    'CLDICE':{model_name: np.array([calculate_rmse(ds_nn_v6_4_year[model_name][seed_number], ds_mmf_1_4_year, mmf_1_total_weight_4_year, var = 'CLDICE', num_years = 4) for seed_number in seed_numbers]) for model_name in model_names.keys()},
    'DTPHYS':{model_name: np.array([calculate_rmse(ds_nn_v6_4_year[model_name][seed_number], ds_mmf_1_4_year, mmf_1_total_weight_4_year, var = 'DTPHYS', num_years = 4) for seed_number in seed_numbers]) for model_name in model_names.keys()},
    'DQ1PHYS':{model_name: np.array([calculate_rmse(ds_nn_v6_4_year[model_name][seed_number], ds_mmf_1_4_year, mmf_1_total_weight_4_year, var = 'DQ1PHYS', num_years = 4) for seed_number in seed_numbers]) for model_name in model_names.keys()},
    'DQ2PHYS':{model_name: np.array([calculate_rmse(ds_nn_v6_4_year[model_name][seed_number], ds_mmf_1_4_year, mmf_1_total_weight_4_year, var = 'DQ2PHYS', num_years = 4) for seed_number in seed_numbers]) for model_name in model_names.keys()},
    'DQ3PHYS':{model_name: np.array([calculate_rmse(ds_nn_v6_4_year[model_name][seed_number], ds_mmf_1_4_year, mmf_1_total_weight_4_year, var = 'DQ3PHYS', num_years = 4) for seed_number in seed_numbers]) for model_name in model_names.keys()},
}

In [None]:
online_rmse_growth_save_path = '/global/cfs/cdirs/m4334/jerry/climsim3_figures/online/online_rmse_growth'

def _save_rmse_growth_pickle(target_name, payload):
    """Pickle ``payload`` into ``online_rmse_growth_save_path/target_name``."""
    target_path = os.path.join(online_rmse_growth_save_path, target_name)
    with open(target_path, "wb") as handle:
        pickle.dump(payload, handle)

# MMF reference RMSE curves (5- and 4-year windows).
_save_rmse_growth_pickle("mmf_1_5_year_rmse.pkl", mmf_1_5_year_rmse)
_save_rmse_growth_pickle("mmf_1_4_year_rmse.pkl", mmf_1_4_year_rmse)
# NN online-run RMSE dicts, one file per configuration / window.
_save_rmse_growth_pickle("ds_nn_rmse_standard_5_year.pkl", ds_nn_rmse_standard_5_year)
_save_rmse_growth_pickle("ds_nn_rmse_conf_loss_5_year.pkl", ds_nn_rmse_conf_loss_5_year)
_save_rmse_growth_pickle("ds_nn_rmse_diff_loss_5_year.pkl", ds_nn_rmse_diff_loss_5_year)
_save_rmse_growth_pickle("ds_nn_rmse_multirep_5_year.pkl", ds_nn_rmse_multirep_5_year)
_save_rmse_growth_pickle("ds_nn_rmse_v6_4_year.pkl", ds_nn_rmse_v6_4_year)

In [None]:
# Hourly precipitation for the MMF reference (4-year window) and every NN online run
# (5- and 4-year windows). Values are multiplied by 86400 * 1000, i.e. converted from
# m/s to mm/day (presumably PRECT is stored in m/s -- TODO confirm).
mmf_1_hourly_prect_4_year = xr.open_mfdataset('/pscratch/sd/z/zeyuanhu/hu_etal2024_data_v2/data_hourly/precip_hourly/mmf_ref/PRECT*nc')['PRECT'].sel(time = slice(None, str(4 + 2).zfill(4))).values * 86400 * 1000
mmf_1_hourly_prect_4_year_flat = mmf_1_hourly_prect_4_year.flatten()


def _load_nn_hourly_prect(config_name, num_years):
    """Load hourly PRECT (scaled by 86400 * 1000) for every model/seed of one online config.

    The time axis is cut at model year ``num_years + 2``. The runs appear to start in
    year 0003 (cf. the ``000[34567]`` history globs); TODO confirm.

    A run is replaced by ``[]`` when:
      * ``read_nn_online_precip_data`` returns None, or
      * the kept array's first (time) axis is not exactly ``365 * 24 * num_years`` long.
    The second check assumes hourly output on a 365-day calendar.

    Returns ``{model_name: {seed_number: np.ndarray or []}}``.
    """
    last_year = str(num_years + 2).zfill(4)
    expected_num_hours = 365 * 24 * num_years

    # Open every run first (same order as before), then materialize and filter.
    raw_runs = {model_name: {seed_number: read_nn_online_precip_data(config_name, model_name, seed_number, num_years = num_years)
                             for seed_number in seed_numbers}
                for model_name in model_names.keys()}

    def _to_filtered_values(run):
        if run is None:
            return []
        values = run.sel(time = slice(None, last_year)).values * 86400 * 1000
        # Runs that do not span exactly the full window are discarded.
        return values if len(values) == expected_num_hours else []

    return {model_name: {seed_number: _to_filtered_values(raw_runs[model_name][seed_number])
                         for seed_number in seed_numbers}
            for model_name in model_names.keys()}


nn_hourly_prect_standard_5_year = _load_nn_hourly_prect('standard', 5)
nn_hourly_prect_conf_loss_5_year = _load_nn_hourly_prect('conf_loss', 5)
nn_hourly_prect_diff_loss_5_year = _load_nn_hourly_prect('diff_loss', 5)
nn_hourly_prect_multirep_5_year = _load_nn_hourly_prect('multirep', 5)
nn_hourly_prect_v6_5_year = _load_nn_hourly_prect('v6', 5)

nn_hourly_prect_standard_4_year = _load_nn_hourly_prect('standard', 4)
nn_hourly_prect_conf_loss_4_year = _load_nn_hourly_prect('conf_loss', 4)
nn_hourly_prect_diff_loss_4_year = _load_nn_hourly_prect('diff_loss', 4)
nn_hourly_prect_multirep_4_year = _load_nn_hourly_prect('multirep', 4)
nn_hourly_prect_v6_4_year = _load_nn_hourly_prect('v6', 4)

In [None]:
online_nn_hourly_prect_save_path = '/global/cfs/cdirs/m4334/jerry/climsim3_figures/online/hourly_prect'

def _save_hourly_prect_pickle(target_name, prect_by_model):
    """Pickle one ``{model: {seed: hourly PRECT}}`` dict into the hourly-PRECT directory."""
    target_path = os.path.join(online_nn_hourly_prect_save_path, target_name)
    with open(target_path, "wb") as handle:
        pickle.dump(prect_by_model, handle)

# 5-year window
_save_hourly_prect_pickle("hourly_prect_standard_5_year.pkl", nn_hourly_prect_standard_5_year)
_save_hourly_prect_pickle("hourly_prect_conf_loss_5_year.pkl", nn_hourly_prect_conf_loss_5_year)
_save_hourly_prect_pickle("hourly_prect_diff_loss_5_year.pkl", nn_hourly_prect_diff_loss_5_year)
_save_hourly_prect_pickle("hourly_prect_multirep_5_year.pkl", nn_hourly_prect_multirep_5_year)
_save_hourly_prect_pickle("hourly_prect_v6_5_year.pkl", nn_hourly_prect_v6_5_year)
# 4-year window
_save_hourly_prect_pickle("hourly_prect_standard_4_year.pkl", nn_hourly_prect_standard_4_year)
_save_hourly_prect_pickle("hourly_prect_conf_loss_4_year.pkl", nn_hourly_prect_conf_loss_4_year)
_save_hourly_prect_pickle("hourly_prect_diff_loss_4_year.pkl", nn_hourly_prect_diff_loss_4_year)
_save_hourly_prect_pickle("hourly_prect_multirep_4_year.pkl", nn_hourly_prect_multirep_4_year)
_save_hourly_prect_pickle("hourly_prect_v6_4_year.pkl", nn_hourly_prect_v6_4_year)

In [13]:
def get_global_rmse(config_name, num_years):
    """Weighted global RMSE of time-mean NN online fields against the MMF reference.

    For each variable and each of six NN architectures:
      1. Each seed's online run (from ``read_nn_online_data``) is averaged over time.
      2. It is scaled by ``online_var_settings[var]['scaling']``.
      3. It is compared with the equally scaled time-mean of the first MMF run
         from ``read_mmf_online_data(num_years)``.
    The squared error is averaged with weights from ``get_pressure_area_weights``
    and then square-rooted. Those weights are presumably pressure-thickness x area;
    confirm.

    Parameters
    ----------
    config_name : str
        Online configuration name forwarded to ``read_nn_online_data``.
    num_years : int
        Simulation length in years. An NN run counts as failed if the reader
        returns a falsy value or if it has fewer than ``12 * num_years`` time entries.

    Returns
    -------
    dict
        ``{var: {model_name: np.ndarray}}``, with one RMSE per entry of
        ``seed_numbers``. Failed runs give NaN.
    """
    rmse_model_names = ('unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm')
    variables = ['T', 'Q', 'U', 'V', 'CLDLIQ', 'CLDICE', 'DTPHYS', 'DQ1PHYS', 'DQnPHYS', 'DUPHYS']

    # Only the first MMF run is the reference here; the second is not used.
    ds_mmf_1, _ = read_mmf_online_data(num_years)
    mmf_1_total_weight = get_pressure_area_weights(ds_mmf_1)
    ds_nn = {
        model_name: {seed_number: read_nn_online_data(config_name, model_name, seed_number, num_years) for seed_number in seed_numbers}
        for model_name in rmse_model_names
    }

    def load_nn_var_time_mean(ds_nn_xr, var):
        """Time-mean of ``var`` for one NN run, or an all-NaN field if the run is missing/short."""
        if not ds_nn_xr or len(ds_nn_xr['time']) < num_years * 12:
            # NaN propagates through np.average, so failed runs yield NaN RMSE.
            return np.full((data_v2_rh_mc.num_levels, data_v2_rh_mc.num_latlon), np.nan)
        return ds_nn_xr[var].mean(dim = 'time').values

    ds_nn_rmse_global_dict = {}
    for var in variables:
        scaling = online_var_settings[var]['scaling']
        ds_mmf_1_mean = ds_mmf_1[var].mean(dim = 'time').values * scaling
        ds_nn_rmse_global_dict[var] = {
            model_name: np.array([
                np.sqrt(np.average((load_nn_var_time_mean(ds_nn[model_name][seed_number], var) * scaling - ds_mmf_1_mean) ** 2,
                                   weights = mmf_1_total_weight))
                for seed_number in seed_numbers
            ])
            for model_name in rmse_model_names
        }
    return ds_nn_rmse_global_dict

def _announced_global_rmse(config_name, num_years):
    """Print a progress line, then return ``get_global_rmse(config_name, num_years)``."""
    print(f'calculating {num_years} year {config_name} global rmse')
    return get_global_rmse(config_name, num_years)

standard_global_rmse_5_year = _announced_global_rmse('standard', 5)
conf_loss_global_rmse_5_year = _announced_global_rmse('conf_loss', 5)
diff_loss_global_rmse_5_year = _announced_global_rmse('diff_loss', 5)
multirep_global_rmse_5_year = _announced_global_rmse('multirep', 5)
v6_global_rmse_5_year = _announced_global_rmse('v6', 5)

standard_global_rmse_4_year = _announced_global_rmse('standard', 4)
conf_loss_global_rmse_4_year = _announced_global_rmse('conf_loss', 4)
diff_loss_global_rmse_4_year = _announced_global_rmse('diff_loss', 4)
multirep_global_rmse_4_year = _announced_global_rmse('multirep', 4)
v6_global_rmse_4_year = _announced_global_rmse('v6', 4)

dict_save_path = '/global/cfs/cdirs/m4334/jerry/climsim3_figures/online/online_global_rmse_config_dicts'

def _dump_global_rmse_pickle(target_path, rmse_by_var):
    """Pickle one global-RMSE result dict to ``target_path``."""
    with open(target_path, "wb") as handle:
        pickle.dump(rmse_by_var, handle)

# 5-year window
dict_file_path_standard_5_year = os.path.join(dict_save_path, 'standard_global_rmse_5_year.pkl')
dict_file_path_conf_loss_5_year = os.path.join(dict_save_path, 'conf_loss_global_rmse_5_year.pkl')
dict_file_path_diff_loss_5_year = os.path.join(dict_save_path, 'diff_loss_global_rmse_5_year.pkl')
dict_file_path_multirep_5_year = os.path.join(dict_save_path, 'multirep_global_rmse_5_year.pkl')
dict_file_path_v6_5_year = os.path.join(dict_save_path, 'v6_global_rmse_5_year.pkl')

_dump_global_rmse_pickle(dict_file_path_standard_5_year, standard_global_rmse_5_year)
_dump_global_rmse_pickle(dict_file_path_conf_loss_5_year, conf_loss_global_rmse_5_year)
_dump_global_rmse_pickle(dict_file_path_diff_loss_5_year, diff_loss_global_rmse_5_year)
_dump_global_rmse_pickle(dict_file_path_multirep_5_year, multirep_global_rmse_5_year)
_dump_global_rmse_pickle(dict_file_path_v6_5_year, v6_global_rmse_5_year)

# 4-year window
dict_file_path_standard_4_year = os.path.join(dict_save_path, 'standard_global_rmse_4_year.pkl')
dict_file_path_conf_loss_4_year = os.path.join(dict_save_path, 'conf_loss_global_rmse_4_year.pkl')
dict_file_path_diff_loss_4_year = os.path.join(dict_save_path, 'diff_loss_global_rmse_4_year.pkl')
dict_file_path_multirep_4_year = os.path.join(dict_save_path, 'multirep_global_rmse_4_year.pkl')
dict_file_path_v6_4_year = os.path.join(dict_save_path, 'v6_global_rmse_4_year.pkl')

_dump_global_rmse_pickle(dict_file_path_standard_4_year, standard_global_rmse_4_year)
_dump_global_rmse_pickle(dict_file_path_conf_loss_4_year, conf_loss_global_rmse_4_year)
_dump_global_rmse_pickle(dict_file_path_diff_loss_4_year, diff_loss_global_rmse_4_year)
_dump_global_rmse_pickle(dict_file_path_multirep_4_year, multirep_global_rmse_4_year)
_dump_global_rmse_pickle(dict_file_path_v6_4_year, v6_global_rmse_4_year)

calculating 5 year standard global rmse
calculating 5 year conf_loss global rmse
calculating 5 year diff_loss global rmse
calculating 5 year multirep global rmse
calculating 5 year v6 global rmse


ls: cannot access '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/online_runs/climsim3_ensembles_good/v6/five_year_runs/squeezeformer_v6_seed_7/run/squeezeformer_v6_seed_7.eam.h0.000[34567]*.nc': No such file or directory
ls: cannot access '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/online_runs/climsim3_ensembles_good/v6/five_year_runs/squeezeformer_v6_seed_43/run/squeezeformer_v6_seed_43.eam.h0.000[34567]*.nc': No such file or directory
ls: cannot access '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/online_runs/climsim3_ensembles_good/v6/five_year_runs/squeezeformer_v6_seed_1024/run/squeezeformer_v6_seed_1024.eam.h0.000[34567]*.nc': No such file or directory


calculating 4 year standard global rmse
calculating 4 year conf_loss global rmse
calculating 4 year diff_loss global rmse
calculating 4 year multirep global rmse
calculating 4 year v6 global rmse


ls: cannot access '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/online_runs/climsim3_ensembles_good/v6/five_year_runs/squeezeformer_v6_seed_7/run/squeezeformer_v6_seed_7.eam.h0.000[3456]*.nc': No such file or directory
ls: cannot access '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/online_runs/climsim3_ensembles_good/v6/five_year_runs/squeezeformer_v6_seed_43/run/squeezeformer_v6_seed_43.eam.h0.000[3456]*.nc': No such file or directory
ls: cannot access '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/online_runs/climsim3_ensembles_good/v6/five_year_runs/squeezeformer_v6_seed_1024/run/squeezeformer_v6_seed_1024.eam.h0.000[3456]*.nc': No such file or directory


In [24]:
def _rank_global_rmse(nn_global_rmse_by_config):
    """Flatten per-seed global RMSEs into per-variable lists sorted by RMSE, ascending.

    ``nn_global_rmse_by_config[config_name][var][model_name][seed_idx]`` holds
    one RMSE per run. NaN entries are skipped. Each list element is a dict with
    the keys ``config_name``, ``model_name``, ``seed_idx``, ``seed_number`` and
    ``rmse``.
    """
    variables = ['T', 'Q', 'U', 'V', 'CLDLIQ', 'CLDICE', 'DTPHYS', 'DQ1PHYS', 'DQnPHYS', 'DUPHYS']
    ranked = {var: [] for var in variables}
    for var in variables:
        entries = ranked[var]
        for config_name in config_names.keys():
            for model_name in model_names.keys():
                for idx, seed_number in enumerate(seed_numbers):
                    new_result = nn_global_rmse_by_config[config_name][var][model_name][idx]
                    if not np.isnan(new_result):
                        entries.append({'config_name': config_name,
                                        'model_name': model_name,
                                        'seed_idx': idx,
                                        'seed_number': seed_number,
                                        'rmse': new_result})
        # sorted() is stable, so ties keep config/model/seed iteration order.
        ranked[var] = sorted(entries, key = lambda entry: entry['rmse'])
    return ranked


def _best_by_variable(ranked):
    """Return ``{var: lowest-RMSE entry}`` from the output of ``_rank_global_rmse``.

    Raises ValueError (naming the variable) when a variable has no finite RMSE,
    instead of an anonymous IndexError on the empty list.
    """
    top = {}
    for var, entries in ranked.items():
        if not entries:
            raise ValueError(f'no finite global RMSE for variable {var!r}: every run is NaN')
        top[var] = entries[0]
    return top


nn_global_rmse_dict_5_year = {
    'standard': standard_global_rmse_5_year,
    'conf_loss': conf_loss_global_rmse_5_year,
    'diff_loss': diff_loss_global_rmse_5_year,
    'multirep': multirep_global_rmse_5_year,
    'v6': v6_global_rmse_5_year
}
global_rmse_dict_full_5_year = _rank_global_rmse(nn_global_rmse_dict_5_year)

nn_global_rmse_dict_4_year = {
    'standard': standard_global_rmse_4_year,
    'conf_loss': conf_loss_global_rmse_4_year,
    'diff_loss': diff_loss_global_rmse_4_year,
    'multirep': multirep_global_rmse_4_year,
    'v6': v6_global_rmse_4_year
}
global_rmse_dict_full_4_year = _rank_global_rmse(nn_global_rmse_dict_4_year)

top_model_dict_5_year = _best_by_variable(global_rmse_dict_full_5_year)
top_model_dict_4_year = _best_by_variable(global_rmse_dict_full_4_year)


In [10]:
dict_save_path = '/global/cfs/cdirs/m4334/jerry/climsim3_figures/online/online_global_rmse_config_dicts'
# NOTE(review): these file names lack the `_5_year` / `_4_year` suffix used by the
# earlier cell that writes this directory. This cell's execution count (In [10])
# also precedes that cell's, so these may be older results. Confirm which run
# length they hold.
dict_file_path_standard = os.path.join(dict_save_path, 'standard_global_rmse.pkl')
dict_file_path_conf_loss = os.path.join(dict_save_path, 'conf_loss_global_rmse.pkl')
dict_file_path_diff_loss = os.path.join(dict_save_path, 'diff_loss_global_rmse.pkl')
dict_file_path_multirep = os.path.join(dict_save_path, 'multirep_global_rmse.pkl')
dict_file_path_v6 = os.path.join(dict_save_path, 'v6_global_rmse.pkl')

def _load_global_rmse_pickle(source_path):
    """Unpickle one saved global-RMSE dict (only use on files produced by this project)."""
    with open(source_path, "rb") as handle:
        return pickle.load(handle)

standard_global_rmse = _load_global_rmse_pickle(dict_file_path_standard)
conf_loss_global_rmse = _load_global_rmse_pickle(dict_file_path_conf_loss)
diff_loss_global_rmse = _load_global_rmse_pickle(dict_file_path_diff_loss)
multirep_global_rmse = _load_global_rmse_pickle(dict_file_path_multirep)
v6_global_rmse = _load_global_rmse_pickle(dict_file_path_v6)

# config name -> {var: {model_name: per-seed RMSE array}}
nn_global_rmse_dict = dict(
    standard = standard_global_rmse,
    conf_loss = conf_loss_global_rmse,
    diff_loss = diff_loss_global_rmse,
    multirep = multirep_global_rmse,
    v6 = v6_global_rmse,
)

# Previous best global RMSE per variable; runs whose 2-decimal RMSE is strictly lower are "SOTA breakers".
prev_sota = {'T': .98, 'Q': .25, 'CLDLIQ': 5.39, 'CLDICE': 2.09, 'U': 1.68, 'V': .77}
sota_breakers = {sota_var: [] for sota_var in prev_sota}
nonnan_rmse_dict = {sota_var: [] for sota_var in prev_sota}

for var, sota_threshold in prev_sota.items():
    for config_name in config_names.keys():
        for model_name in model_names.keys():
            for idx, seed_number in enumerate(seed_numbers):
                new_result = np.round(nn_global_rmse_dict[config_name][var][model_name][idx], 2)
                if np.isnan(new_result):
                    continue
                run_key = f'{model_name}_{seed_number}'
                record = {'config_name': config_name,
                          'model_name': model_name,
                          'seed_idx': idx,
                          'seed_number': seed_number,
                          'rmse': new_result,
                          'sypd': sypd_dict[config_name][run_key],
                          'pc': pc_dict[config_name][run_key]}
                nonnan_rmse_dict[var].append(record)
                if new_result < sota_threshold:
                    # Separate copy so the two result lists never share a mutable record.
                    sota_breakers[var].append(dict(record))

for var in prev_sota.keys():
    sota_breakers[var].sort(key = lambda sota_dict: sota_dict['rmse'])
    nonnan_rmse_dict[var].sort(key = lambda nonnan_dict: nonnan_dict['rmse'])

zeyuan_path = '/pscratch/sd/z/zeyuanhu/hu_etal2024_data_v2/data/h0/5year/unet_v5/'
def read_nn_online_data_zeyuan(config_name, num_years):
    assert num_years <= 5 and num_years >= 1
    years_regexp = '34567'[:num_years]
    assert config_name in ['huber_rop', 'huber_step']
    if config_name == 'huber_rop':
        extract_path = os.path.join(zeyuan_path, config_name, f'v5_noclassifier_huber_1y_noaggressive_rop2_5year_3node.eam.h0.000[{years_regexp}]*.nc')
    elif config_name == 'huber_step':
        extract_path = os.path.join(zeyuan_path, config_name, f'v5_noclassifier_huber_1y_noaggressive_5year_3node.eam.h0.000[{years_regexp}]*.nc')
    ds_nn = xr.open_mfdataset(extract_path)
    if len(ds_nn['time']) < 12 * num_years:
        return None
    ds_nn['DQnPHYS'] = ds_nn['DQ2PHYS'] + ds_nn['DQ3PHYS']
    ds_nn['TOTCLD'] = ds_nn['CLDICE'] + ds_nn['CLDLIQ']
    ds_nn['PRECT'] = ds_nn['PRECC'] + ds_nn['PRECL']
    return ds_nn

# Both Hu et al. runs over the full 5-year window (None if a run is too short).
huber_rop_run, huber_step_run = (read_nn_online_data_zeyuan(run_config, 5) for run_config in ('huber_rop', 'huber_step'))

# Two reference global-RMSE values per variable from Hu et al. The meaning of each
# tuple position is not stated here (possibly the huber_rop / huber_step runs) -- TODO confirm.
huetal_sota_dict = dict(
    T = (0.99, 1.21),
    Q = (0.33, 0.25),
    CLDLIQ = (13.05, 5.40),
    CLDICE = (2.10, 2.29),
    U = (1.70, 1.98),
    V = (0.79, 0.89),
)

# Per-variable MMF reference values, presumably the RMSE between the two MMF
# runs (an internal-variability floor) -- TODO confirm.
mmf_ref_dict = dict(
    T = 0.18,
    Q = 0.06,
    CLDLIQ = 0.80,
    CLDICE = 0.65,
    U = 0.44,
    V = 0.35,
)