In [1]:
!hostname

nid004146


In [2]:
!pwd

/global/cfs/cdirs/m4334/jerry/climsim3_dev


# Import packages

In [3]:
import xarray as xr
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.cm import ScalarMappable
from matplotlib import gridspec
import os, gc, sys, glob, string, argparse
from tqdm import tqdm
import time
import itertools
import sys
import pickle
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from climsim_utils.data_utils import *

# Load grid info, normalization statistics, test data, and plot settings

In [4]:
# --- Grid description -------------------------------------------------------
grid_path = '/global/cfs/cdirs/m4334/jerry/climsim3_dev/grid_info/ClimSim_low-res_grid-info.nc'
grid_info = xr.open_dataset(grid_path)
grid_area = grid_info['area'].values
# Cell areas rescaled so that the weights add up to one.
area_weight = grid_area / grid_area.sum()
level = grid_info.lev.values

# --- Normalization file names, one set per variable list ---------------------
input_mean_v2_rh_mc_file = 'input_mean_v2_rh_mc_pervar.nc'
input_max_v2_rh_mc_file = 'input_max_v2_rh_mc_pervar.nc'
input_min_v2_rh_mc_file = 'input_min_v2_rh_mc_pervar.nc'
output_scale_v2_rh_mc_file = 'output_scale_std_lowerthred_v2_rh_mc.nc'

input_mean_v6_file = 'input_mean_v6_pervar.nc'
input_max_v6_file = 'input_max_v6_pervar.nc'
input_min_v6_file = 'input_min_v6_pervar.nc'
output_scale_v6_file = 'output_scale_std_lowerthred_v6.nc'

lbd_qn_file = 'qn_exp_lambda_large.txt'

# Directories holding the input and output normalization files.
_norm_inputs_dir = '/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/inputs/'
_norm_outputs_dir = '/global/cfs/cdirs/m4334/jerry/climsim3_dev/preprocessing/normalizations/outputs/'

# --- Open the normalization datasets ------------------------------------------
input_mean_v2_rh_mc = xr.open_dataset(_norm_inputs_dir + input_mean_v2_rh_mc_file)
input_max_v2_rh_mc = xr.open_dataset(_norm_inputs_dir + input_max_v2_rh_mc_file)
input_min_v2_rh_mc = xr.open_dataset(_norm_inputs_dir + input_min_v2_rh_mc_file)
output_scale_v2_rh_mc = xr.open_dataset(_norm_outputs_dir + output_scale_v2_rh_mc_file)

input_mean_v6 = xr.open_dataset(_norm_inputs_dir + input_mean_v6_file)
input_max_v6 = xr.open_dataset(_norm_inputs_dir + input_max_v6_file)
input_min_v6 = xr.open_dataset(_norm_inputs_dir + input_min_v6_file)
output_scale_v6 = xr.open_dataset(_norm_outputs_dir + output_scale_v6_file)

lbd_qn = np.loadtxt(_norm_inputs_dir + lbd_qn_file, delimiter = ',')

# Keyword arguments common to both helpers: same grid, no log-transformed
# moisture input, and no normalization applied.
_shared_data_utils_kwargs = dict(grid_info = grid_info,
                                 qinput_log = False,
                                 normalize = False)

# Helper configured for the v2_rh_mc variable list.
data_v2_rh_mc = data_utils(input_mean = input_mean_v2_rh_mc,
                           input_max = input_max_v2_rh_mc,
                           input_min = input_min_v2_rh_mc,
                           output_scale = output_scale_v2_rh_mc,
                           **_shared_data_utils_kwargs)
data_v2_rh_mc.set_to_v2_rh_mc_vars()

# Helper configured for the v6 variable list.
data_v6 = data_utils(input_mean = input_mean_v6,
                     input_max = input_max_v6,
                     input_min = input_min_v6,
                     output_scale = output_scale_v6,
                     **_shared_data_utils_kwargs)
data_v6.set_to_v6_vars()

# Test-set inputs and targets for the two variable lists (v2_rh_mc and v6).
actual_input_v2_rh_mc = np.load('/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/v2_rh_mc/test_set/actual_input.npy')
actual_target_v2_rh_mc = np.load('/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/v2_rh_mc/test_set/actual_target.npy')

actual_input_v6 = np.load('/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/v6/test_set/actual_input.npy')
actual_target_v6 = np.load('/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/v6/test_set/actual_target.npy')

# Both variable lists must share identical targets.
assert np.array_equal(actual_target_v2_rh_mc, actual_target_v6)
# NOTE(review): `actual_target` comes from a third file (original_test_vars/actual_target_v2.npy),
# not from either array compared above -- confirm this is the intended reference.
actual_target = np.load('/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/v2_rh_mc/test_set/original_test_vars/actual_target_v2.npy')
# The per-variable-list targets were only needed for the check; drop the names to free memory.
del actual_target_v2_rh_mc
del actual_target_v6

# The surface-pressure input must agree between the two variable lists.
assert np.array_equal(actual_input_v2_rh_mc[:,:,data_v2_rh_mc.ps_index], 
                      actual_input_v6[:,:,data_v6.ps_index])

# Pressure for every sample, column and level from the hybrid coefficients:
# p = hyam * p0 + hybm * ps.
# NOTE(review): the unit of `pressures` follows the units of p0 and of the surface-pressure
# input; neither is visible in this notebook -- confirm before comparing against hPa thresholds.
surface_pressure = actual_input_v2_rh_mc[:, :, data_v2_rh_mc.ps_index]
hyam_component = (data_v2_rh_mc.hyam * data_v2_rh_mc.p0)[None,None,:]
hybm_component = data_v2_rh_mc.hybm[None,None,:] * surface_pressure[:,:,None]
pressures = hyam_component + hybm_component
pressures_binned = data_v2_rh_mc.zonal_bin_weight_3d(pressures)
# Short aliases for the grid coordinates used by the plotting code.
lat_bin_mids = data_v2_rh_mc.lat_bin_mids
lats = data_v2_rh_mc.lats
lons = data_v2_rh_mc.lons

# Level indices stored per (row, column) of the file; the file name suggests a
# "p < 400 / T < 10" first-true criterion -- TODO confirm with the data producer.
_first_true_level_idx = np.load('/pscratch/sd/z/zeyuanhu/hu_etal2024_data/microphysics_hourly/first_true_indices_p400_t10.npy')

# Map every index to the value of `level` at that index in one vectorised lookup.
# Indexing creates a new array with the dtype of `level`, so the level values are
# not cast back to the dtype of the loaded index array (an element-by-element
# write into an integer index array would silently truncate them).
idx_p400_t10 = level[_first_true_level_idx.astype(int)]

# Average over the first axis and restore a leading singleton axis so the result
# can be passed to the 2-D zonal binning helper.
idx_p400_t10 = idx_p400_t10.mean(axis=0)
idx_p400_t10 = idx_p400_t10[np.newaxis,:]

idx_tropopause_zm = data_v2_rh_mc.zonal_bin_weight_2d(idx_p400_t10).flatten()

area_weight_dict = {
    'global': area_weight,
    'nh': np.where(lats > 30, area_weight, 0),
    'sh': np.where(lats < -30, area_weight, 0),
    'tropics': np.where((lats > -30) & (lats < 30), area_weight, 0)
}

lat_idx_dict = {
    '30S_30N': ((data_v2_rh_mc.lats < 30) & (data_v2_rh_mc.lats > -30))[None,:,None],
    '30N_60N': ((data_v2_rh_mc.lats < 60) & (data_v2_rh_mc.lats > 30))[None,:,None],
    '30S_60S': ((data_v2_rh_mc.lats < -30) & (data_v2_rh_mc.lats > -60))[None,:,None],
    '60N_90N': (data_v2_rh_mc.lats > 60)[None,:,None],
    '60S_90S': (data_v2_rh_mc.lats < -60)[None,:,None]
}

pressure_idx_dict = {
    'below_400hPa': pressures >= 400,
    'above_400hPa': pressures < 400
}

# Display names for the training configurations (insertion order is preserved).
config_names = dict(
    standard='Standard',
    conf_loss='Confidence Loss',
    diff_loss='Difference Loss',
    multirep='Multirepresentation',
    v6='Expanded Variable List',
)

# Display names for the model architectures (insertion order is preserved).
model_names = dict(
    unet='U-Net',
    squeezeformer='Squeezeformer',
    pure_resLSTM='Pure ResLSTM',
    pao_model='Pao Model',
    convnext='ConvNeXt',
    encdec_lstm='Encoder-Decoder LSTM',
)

# One plotting colour per architecture, in the same order as `model_names`.
color_dict = dict(zip(model_names, ['green', 'purple', 'blue', 'red', 'gold', 'orange']))

# One plotting colour per configuration, in the same order as `config_names`.
color_dict_config = dict(zip(config_names, ['blue', 'cyan', 'red', 'orange', 'green']))

# Plot settings for the offline tendencies. Each row is
# (variable, title, scaling, unit, symmetric colour limit); `var_index` advances
# by 60 per variable, following the order of the rows.
offline_var_settings = {
    name: {'var_title': title, 'scaling': scaling, 'unit': unit,
           'vmax': limit, 'vmin': -limit, 'var_index': 60 * position}
    for position, (name, title, scaling, unit, limit) in enumerate([
        ('DTPHYS', 'dT/dt', 1., 'K/s', 5e-7),
        ('DQ1PHYS', 'dQv/dt', 1e3, 'g/kg/s', 1e-6),
        ('DQ2PHYS', 'dQl/dt', 1e6, 'mg/kg/s', 1e-3),
        ('DQ3PHYS', 'dQi/dt', 1e6, 'mg/kg/s', 1e-3),
        ('DUPHYS', 'dU/dt', 1., 'm/s/s', 5e-7),
        ('DVPHYS', 'dV/dt', 1., 'm/s/s', 5e-7),
    ])
}

# Plot settings for the online diagnostics, same row layout without `var_index`.
online_var_settings = {
    name: {'var_title': title, 'scaling': scaling, 'unit': unit,
           'vmax': limit, 'vmin': -limit}
    for name, title, scaling, unit, limit in [
        ('T', 'Temperature', 1.0, 'K', 5),
        ('Q', 'Specific Humidity', 1000.0, 'g/kg', 1),
        ('U', 'Zonal Wind', 1.0, 'm/s', 4),
        ('V', 'Meridional Wind', 1.0, 'm/s', 4),
        ('CLDLIQ', 'Liquid Cloud', 1e6, 'mg/kg', 40),
        ('CLDICE', 'Ice Cloud', 1e6, 'mg/kg', 5),
        ('TOTCLD', 'Total Cloud', 1e6, 'mg/kg', 40),
        ('DTPHYS', 'Heating Tendency', 1., 'K/s', 1.5e-5),
        ('DQ1PHYS', 'Moistening Tendency', 1e3, 'g/kg/s', 1.2e-5),
        ('DQ2PHYS', 'Liquid Tendency', 1e6, 'mg/kg/s', 0.0015),
        ('DQ3PHYS', 'Ice Tendency', 1e6, 'mg/kg/s', 0.0015),
        ('DQnPHYS', 'Liquid + Ice Tendency', 1e6, 'mg/kg/s', .0015),
        ('DUPHYS', 'Zonal Wind Tendency', 1., 'm/s²', 2.2e-6),
    ]
}

# Directory of the five-year online runs for every configuration. The on-disk
# folder name differs from the configuration key for the two loss variants.
online_paths = {
    config: '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/online_runs/climsim3_ensembles_good/' + folder + '/five_year_runs/'
    for config, folder in [
        ('standard', 'standard'),
        ('conf_loss', 'conf'),
        ('diff_loss', 'diff'),
        ('multirep', 'multirep'),
        ('v6', 'v6'),
    ]
}

# Random seeds of the ensemble members, as numbers and as 'seed_<n>' strings.
seed_numbers = [7, 43, 1024]
seeds = ['seed_' + str(number) for number in seed_numbers]

# Output directories for the generated figures.
climsim3_figures_save_path_offline, climsim3_figures_save_path_online = (
    '/global/cfs/cdirs/m4334/jerry/climsim3_figures/' + kind for kind in ('offline', 'online')
)

# Throughput tables ("sypd" -- presumably simulated years per day; confirm), one per
# training configuration. Each list holds 18 values: the six architectures in the
# order of `model_names`, with three consecutive entries per architecture for
# seeds 7, 43 and 1024. The loop that fills the `*_sypd_dict` dictionaries relies
# on exactly this ordering.

# Standard configuration.
# NOTE(review): the first two entries are identical to every digit -- check for a copy/paste slip.
sypd_standard = [
    17.97901035,
    17.97901035,
    18.05589744,
    18.90966668,
    19.03178614,
    18.65975565,
    17.80983045,
    18.08638973,
    17.93128756,
    21.78946853,
    21.9847479,
    22.20038272,
    25.36398855,
    25.19805864,
    25.19227595,
    23.54557187,
    23.82794497,
    23.447502
]

# Confidence-loss configuration.
sypd_conf_loss = [
    17.50453936,
    17.29971515,
    17.44196107,
    18.20260208,
    18.19053688,
    18.04921777,
    16.88591639,
    17.06639861,
    17.10229289,
    20.55525559,
    20.63385549,
    20.42521955,
    23.65465549,
    23.3813296,
    23.25625602,
    22.31773163,
    22.25777422,
    21.88176458
]

# Difference-loss configuration.
# NOTE(review): entries 14 and 15 (25.47730606) are identical to every digit -- check for a copy/paste slip.
sypd_diff_loss =[
    16.97469344,
    16.92105472,
    17.16379509,
    18.99145235,
    19.24026684,
    18.69869139,
    17.75509967,
    16.87358759,
    16.66731672,
    20.77602886,
    19.70212994,
    19.22005471,
    21.00260056,
    25.47730606,
    25.47730606,
    24.0893883,
    23.99985625,
    23.78792837
]

# Multirepresentation configuration.
sypd_multirep = [
    18.33255552,
    18.35861568,
    18.3448099,
    19.00295884,
    16.93296157,
    18.19204416,
    17.23723679,
    17.27385405,
    17.27657263,
    18.10913991,
    19.86397164,
    21.35810934,
    23.51060144,
    23.60014809,
    24.52549157,
    23.19360711,
    23.43123639,
    23.23164752
]

# Expanded-variable-list (v6) configuration.
sypd_v6 = [
    18.21284028,
    18.31251632,
    18.41758393,
    6.894533848,
    8.95130655,
    7.890410959,
    17.84446786,
    17.79133537,
    17.73373234,
    21.39020294,
    17.90628317,
    21.40600448,
    24.47883654,
    24.13656852,
    24.3122588,
    23.34100196,
    23.5582043,
    23.5169076
]

# "pc" tables (presumably parameter counts -- confirm), one per training
# configuration. Every architecture contributes a single value that is repeated
# three times, once per seed, so each list has 18 entries ordered like the
# `sypd_*` tables (architectures in `model_names` order, seeds innermost).

pc_standard = [
    count
    for count in (12975373, 44785225, 15395341, 18876133, 26805429, 18582976)
    for _ in range(3)
]

pc_conf_loss = [
    count
    for count in (12980634, 44811862, 15402010, 25407346, 26839242, 20723124)
    for _ in range(3)
]

# Same values as the standard configuration.
pc_diff_loss = [
    count
    for count in (12975373, 44785225, 15395341, 18876133, 26805429, 18582976)
    for _ in range(3)
]

pc_multirep = [
    count
    for count in (12981517, 44791369, 15428109, 18880613, 26811573, 21004960)
    for _ in range(3)
]

pc_v6 = [
    count
    for count in (12981517, 44791369, 15428109, 18880373, 26811573, 21004960)
    for _ in range(3)
]

# One lookup key per run, "<model>_<seed number>", generated in the same order in
# which the flat sypd_* / pc_* tables are laid out (models outermost, seeds innermost).
_run_keys = [
    f"{model_name}_{seed_number}"
    for model_name in model_names.keys()
    for seed_number in seed_numbers
]

# Every table must provide exactly one value per run; a table of the wrong length
# would otherwise attach values to the wrong model/seed or drop entries silently.
assert all(
    len(table) == len(_run_keys)
    for table in (sypd_standard, sypd_conf_loss, sypd_diff_loss, sypd_multirep, sypd_v6,
                  pc_standard, pc_conf_loss, pc_diff_loss, pc_multirep, pc_v6)
), "every sypd_*/pc_* table needs one entry per (model, seed) pair"

# Throughput value of each run, per configuration.
standard_sypd_dict = dict(zip(_run_keys, sypd_standard))
conf_loss_sypd_dict = dict(zip(_run_keys, sypd_conf_loss))
diff_loss_sypd_dict = dict(zip(_run_keys, sypd_diff_loss))
multirep_sypd_dict = dict(zip(_run_keys, sypd_multirep))
v6_sypd_dict = dict(zip(_run_keys, sypd_v6))

# "pc" value of each run, per configuration.
standard_pc_dict = dict(zip(_run_keys, pc_standard))
conf_loss_pc_dict = dict(zip(_run_keys, pc_conf_loss))
diff_loss_pc_dict = dict(zip(_run_keys, pc_diff_loss))
multirep_pc_dict = dict(zip(_run_keys, pc_multirep))
v6_pc_dict = dict(zip(_run_keys, pc_v6))

# Define helper functions

In [5]:
def ls(data_path = ""):
    """List paths by running the shell command ``ls <data_path>``.

    Parameters
    ----------
    data_path : str
        Text appended to ``ls``. It is passed to the shell unquoted on purpose so
        that glob patterns are expanded by the shell; for the same reason it must
        only ever come from trusted code, never from external input.

    Returns
    -------
    list of str
        The lines ``ls`` printed on standard output. The list is empty when
        nothing matched (the error message goes to standard error).
    """
    # Close the pipe deterministically instead of leaving it to garbage collection.
    with os.popen(" ".join(["ls", data_path])) as pipe:
        return pipe.read().splitlines()

def offline_area_time_mean_3d(arr):
    """Zonally bin `arr`, average over its first axis and label the result.

    `arr` is handed unchanged to `data_v2_rh_mc.zonal_bin_weight_3d`
    (presumably laid out as (time, column, level) -- confirm against that helper).

    Returns an `xr.DataArray` with dims ('hybrid pressure (hPa)', 'latitude'),
    using `level` and `lat_bin_mids` as coordinates.
    """
    time_mean_by_lat = data_v2_rh_mc.zonal_bin_weight_3d(arr).mean(axis = 0)
    return xr.DataArray(
        time_mean_by_lat.T,
        dims = ['hybrid pressure (hPa)', 'latitude'],
        coords = {'hybrid pressure (hPa)': level, 'latitude': lat_bin_mids},
    )

def online_area_time_mean_3d(ds, var):
    """Zonal-bin, time-mean view of the 3-D variable `var` from an online run.

    The first entry along the leading axis is discarded and the last two axes
    are swapped before the array goes to `data_v2_rh_mc.zonal_bin_weight_3d`.

    Returns an `xr.DataArray` with dims ('hybrid pressure (hPa)', 'latitude'),
    using `level` and `lat_bin_mids` as coordinates.
    """
    # Skip the first record, then swap the two trailing axes for the binning helper.
    without_first_step = ds[var].values[1:, :, :]
    columns_before_levels = without_first_step.transpose(0, 2, 1)
    time_mean_by_lat = data_v2_rh_mc.zonal_bin_weight_3d(columns_before_levels).mean(axis = 0)
    return xr.DataArray(
        time_mean_by_lat.T,
        dims = ['hybrid pressure (hPa)', 'latitude'],
        coords = {'hybrid pressure (hPa)': level, 'latitude': lat_bin_mids},
    )

def area_mean(ds, var):
    """Return `ds[var]` zonally binned by `data_v2_rh_mc.zonal_bin_weight_3d`.

    The last two axes of the variable are swapped before binning; no averaging
    over the leading axis is done here.
    """
    swapped = np.transpose(ds[var].values, (0, 2, 1))
    return data_v2_rh_mc.zonal_bin_weight_3d(swapped)

def zonal_diff(ds_sp, ds_nn, var):
    """Mean over the leading axis of the zonally binned difference `ds_nn - ds_sp` for `var`.

    Returns an `xr.DataArray` with dims ('level', 'lat'), using `level` and
    `lat_bin_mids` as coordinates.
    """
    nn_binned = area_mean(ds_nn, var)
    sp_binned = area_mean(ds_sp, var)
    mean_difference = (nn_binned - sp_binned).mean(axis = 0)
    return xr.DataArray(
        mean_difference.T,
        dims = ['level', 'lat'],
        coords = {'level': level, 'lat': lat_bin_mids},
    )

def get_dp(ds):
    """Pressure thickness between consecutive interface levels.

    Interface pressure is computed as ``hyai * P0 + hybi * PS`` from the entries
    of `ds`; the result is the difference between neighbouring interfaces along
    axis 1, which gives 60 layers for 61 interfaces.

    Parameters
    ----------
    ds : mapping
        Must provide 'hyai', 'P0', 'hybi' and 'PS'; the combined expression must
        expose a 3-D array through `.values`.

    Returns
    -------
    numpy.ndarray
        Layer thickness with the layer axis in position 1.
    """
    p_interface = (ds['hyai'] * ds['P0'] + ds['hybi'] * ds['PS']).values
    # The interface axis can end up first depending on how the coefficients
    # broadcast; move it to position 1 in that case.
    # NOTE(review): the test keys on the size 61 only, so a leading axis that
    # happens to have length 61 for another reason would be swapped as well.
    if p_interface.shape[0] == 61:
        p_interface = np.swapaxes(p_interface, 0, 1)
    # Neighbour differences over (at most) the first 61 interfaces.
    return np.diff(p_interface[:, :61, :], axis = 1)

def get_tcp_mean(ds, area_weight):
    """Area-weighted mean of the vertically integrated 'TOTCLD' field.

    'TOTCLD' is multiplied by the layer thickness from `get_dp`, summed over
    axis 1 and divided by 9.81 (presumably gravitational acceleration in m/s^2).
    The result is then averaged over axis 1 with `area_weight` as weights, which
    leaves one value per entry of the leading axis.
    """
    layer_content = ds['TOTCLD'].values * get_dp(ds)
    column_total = layer_content.sum(axis = 1) / 9.81
    return np.average(column_total, weights = area_weight, axis = 1)

def get_tcp_std(ds, area_weight):
    """Area-weighted standard deviation of the vertically integrated 'TOTCLD' field.

    The column total is built as in `get_tcp_mean` ('TOTCLD' times layer
    thickness, summed over axis 1, divided by 9.81). The spread is the square
    root of the `area_weight`-weighted mean squared deviation from the weighted
    mean, evaluated separately for each entry of the leading axis.
    """
    layer_content = ds['TOTCLD'].values * get_dp(ds)
    column_total = layer_content.sum(axis = 1) / 9.81
    weighted_mean = np.average(column_total, weights = area_weight, axis = 1)
    deviation_sq = np.square(column_total - weighted_mean[:, None])
    weighted_variance = np.average(deviation_sq, weights = area_weight, axis = 1)
    return np.sqrt(weighted_variance)

def read_mmf_online_data(num_years):
    """Open the two MMF reference runs and add the derived variables.

    Parameters
    ----------
    num_years : int
        Number of years to read, between 1 and 5; the first `num_years`
        characters of '34567' select the files through a glob character class.

    Returns
    -------
    tuple
        The two datasets, each extended with 'DQnPHYS' (DQ2PHYS + DQ3PHYS),
        'TOTCLD' (CLDICE + CLDLIQ) and 'PRECT' (PRECC + PRECL).
    """
    assert 1 <= num_years <= 5
    years_regexp = '34567'[:num_years]
    file_patterns = (
        f'/pscratch/sd/z/zeyuanhu/hu_etal2024_data_v2/data/h0/5year/mmf_ref/control_fullysp_jan_wmlio_r3.eam.h0.000[{years_regexp}]*.nc',
        f'/pscratch/sd/z/zeyuanhu/hu_etal2024_data_v2/data/h0/5year/mmf_b/control_fullysp_jan_wmlio_r3_b.eam.h0.000[{years_regexp}]*.nc',
    )
    ds_mmf_1, ds_mmf_2 = [xr.open_mfdataset(pattern) for pattern in file_patterns]
    # Sums of existing fields: (new name, first term, second term).
    derived_fields = (
        ('DQnPHYS', 'DQ2PHYS', 'DQ3PHYS'),
        ('TOTCLD', 'CLDICE', 'CLDLIQ'),
        ('PRECT', 'PRECC', 'PRECL'),
    )
    for new_name, first, second in derived_fields:
        for ds_mmf in (ds_mmf_1, ds_mmf_2):
            ds_mmf[new_name] = ds_mmf[first] + ds_mmf[second]
    return ds_mmf_1, ds_mmf_2

def read_nn_online_data(config_name, model_name, seed, num_years):
    """Open the monthly output of one hybrid (NN) online run.

    Parameters
    ----------
    config_name : str
        One of 'standard', 'conf_loss', 'diff_loss', 'multirep', 'v6'.
    model_name : str
        Architecture key used in the run-directory name.
    seed
        Seed value inserted after 'seed_' in the run-directory name.
    num_years : int
        Number of years to read, between 1 and 5.

    Returns
    -------
    The opened dataset extended with 'DQnPHYS', 'TOTCLD' and 'PRECT', or None
    when no file matches or fewer than ``12 * num_years`` time entries exist.

    Raises
    ------
    ValueError
        If `config_name` is not a known configuration.
    """
    assert num_years <= 5 and num_years >= 1
    # Tag placed between the model name and 'seed_<n>' in run-directory names.
    run_tags = {
        'standard': '',
        'conf_loss': '_conf',
        'diff_loss': '_diff',
        'multirep': '_multirep',
        'v6': '_v6',
    }
    if config_name not in run_tags:
        # Fail with a clear message instead of an unbound-variable error further down.
        raise ValueError(f"Unknown config_name {config_name!r}. Choose from {list(run_tags)}.")
    years_regexp = '34567'[:num_years]
    run_name = f'{model_name}{run_tags[config_name]}_seed_{seed}'
    extract_path = os.path.join(online_paths[config_name], run_name, 'run', f'{run_name}.eam.h0.000[{years_regexp}]*.nc')
    if len(ls(extract_path)) == 0:
        return None
    ds_nn = xr.open_mfdataset(extract_path)
    if len(ds_nn['time']) < 12 * num_years:
        # Incomplete run: release the file handles before giving up on it.
        ds_nn.close()
        return None
    ds_nn['DQnPHYS'] = ds_nn['DQ2PHYS'] + ds_nn['DQ3PHYS']
    ds_nn['TOTCLD'] = ds_nn['CLDICE'] + ds_nn['CLDLIQ']
    ds_nn['PRECT'] = ds_nn['PRECC'] + ds_nn['PRECL']
    return ds_nn

def read_nn_online_precip_data(config_name, model_name, seed, num_years):
    """Read the combined precipitation file of one hybrid (NN) online run.

    Parameters
    ----------
    config_name : str
        One of 'standard', 'conf_loss', 'diff_loss', 'multirep', 'v6'.
    model_name : str
        Architecture key used in the run-directory name.
    seed
        Seed value inserted after 'seed_' in the run-directory name.
    num_years : int
        Number of years required, between 1 and 5.

    Returns
    -------
    The 'PRECT' variable of ``precip_dir/combined_precip.nc``, or None when the
    file is missing or has fewer than ``365 * 24 * num_years`` time entries.

    Raises
    ------
    ValueError
        If `config_name` is not a known configuration.
    """
    assert num_years <= 5 and num_years >= 1
    # Tag placed between the model name and 'seed_<n>' in run-directory names.
    run_tags = {
        'standard': '',
        'conf_loss': '_conf',
        'diff_loss': '_diff',
        'multirep': '_multirep',
        'v6': '_v6',
    }
    if config_name not in run_tags:
        # Fail with a clear message instead of an unbound-variable error further down.
        raise ValueError(f"Unknown config_name {config_name!r}. Choose from {list(run_tags)}.")
    run_name = f'{model_name}{run_tags[config_name]}_seed_{seed}'
    extract_path = os.path.join(online_paths[config_name], run_name, 'run', 'precip_dir', 'combined_precip.nc')
    if len(ls(extract_path)) == 0:
        return None
    ds_nn = xr.open_dataset(extract_path)
    if len(ds_nn['time']) < 365 * 24 * num_years:
        # Incomplete run: release the file handle before giving up on it.
        ds_nn.close()
        return None
    return ds_nn['PRECT']

def get_pressure_area_weights(ds, surface_type = None):
    """Weights for averaging fields of `ds`.

    Parameters
    ----------
    ds
        Dataset of an online run.
    surface_type : {None, 'land', 'ocean', 'ice'}
        * None: layer thickness (`get_dp`) times `area_weight`, averaged over
          the leading axis and normalised so that all weights sum to one.
        * 'land' / 'ocean' / 'ice': the matching surface fraction ('LANDFRAC',
          'OCNFRAC' or 'ICEFRAC') times `grid_area`, divided by the total of
          that product inside each column's latitude bin, separately for every
          entry of the leading axis. Bins whose total is zero get weight zero.

    Raises
    ------
    ValueError
        If `surface_type` is none of the values above.
    """
    fraction_vars = {'land': 'LANDFRAC', 'ocean': 'OCNFRAC', 'ice': 'ICEFRAC'}
    if surface_type is None:
        # Only this branch needs the layer thickness, so compute it only here.
        ds_total_weight = get_dp(ds) * area_weight[None, None, :]
        ds_total_weight = ds_total_weight.mean(axis = 0)
        return ds_total_weight/ds_total_weight.sum()
    if surface_type not in fraction_vars:
        raise ValueError("Invalid surface type. Choose from 'land', 'ocean', or 'ice'.")
    surface_area = ds[fraction_vars[surface_type]].values * grid_area[None, :]
    # Total surface area per latitude bin for every entry of the leading axis.
    bin_area_sums = np.array([[np.sum(surface_area[t,:][data_v2_rh_mc.lat_bin_dict[lat_bin]]) for lat_bin in data_v2_rh_mc.lat_bin_dict.keys()] for t in range(surface_area.shape[0])])
    # Reciprocal of the bin total looked up through `lat_bin_indices`
    # (presumably one bin index per column -- confirm); zero where the total is zero.
    bin_area_divs = np.stack([np.divide(1, bin_area_sums[:, bin_index], where=~(bin_area_sums[:, bin_index] == 0), out=np.zeros_like(bin_area_sums[:, bin_index])) for bin_index in data_v2_rh_mc.lat_bin_indices], axis=1)
    return surface_area * bin_area_divs

def get_offline_precip_area_weights(nn_preds, precc_index = 363):
    """Area weights of the samples in each precipitation class.

    The field ``nn_preds[:, :, precc_index]`` is multiplied by 86400 * 1000
    (presumably m/s to mm/day -- confirm) and split into a 'no_precip' class
    (exactly zero) plus ten classes bounded by the 10 %, 20 %, ..., 100 %
    quantiles of all non-zero values. The first active class starts strictly
    above zero.

    NOTE(review): negative values count as non-zero when the quantiles are
    computed but can fall into no class when the lowest quantile is negative;
    confirm that this is acceptable for the predictions passed in.

    Returns
    -------
    dict
        Class key -> array with a trailing singleton axis holding `area_weight`
        inside the class, scaled so that each class sums to one (all zeros for
        an empty class).
    """
    nn_precc = nn_preds[:,:,precc_index] * 86400 * 1000
    no_precip_mask = nn_precc == 0
    decile_edges = np.quantile(nn_precc[~no_precip_mask], [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])

    # Class masks in output order: dry class first, then the ten decile classes.
    class_masks = {'no_precip': no_precip_mask}
    percent_bounds = [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
    lower_edge = 0
    for position, upper_edge in enumerate(decile_edges):
        class_key = f'active_precip_{percent_bounds[position]}_{percent_bounds[position + 1]}_percentile'
        class_masks[class_key] = (nn_precc > lower_edge) & (nn_precc <= upper_edge)
        lower_edge = upper_edge

    precip_area_dict = {}
    for class_key, class_mask in class_masks.items():
        class_area = class_mask * area_weight[None,:]
        # 1 / (total class area) wherever the class has weight, zero elsewhere.
        inverse_total = np.divide(np.ones_like(class_area), np.sum(class_area), out = np.zeros_like(class_area), where = (class_area != 0))
        precip_area_dict[class_key] = (class_area * inverse_total)[:,:,None]
    return precip_area_dict

# Human-readable labels for the precipitation classes: the dry class followed by
# the ten decile classes, in increasing order.
precip_percentile_labels = {
    'no_precip': 'No Precipitation',
    **{
        f'active_precip_{lower}_{upper}_percentile': f'{lower}-{upper}th percentile'
        for lower, upper in zip(
            [1, 10, 20, 30, 40, 50, 60, 70, 80, 90],
            [10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
        )
    },
}

# Load offline preds

In [6]:
# Directories holding the saved offline test-set predictions. The four v2_rh_mc
# configurations share one parent directory; v6 has its own test-set tree.
standard_save_path, conf_loss_save_path, diff_loss_save_path, multirep_save_path = (
    '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/v2_rh_mc/test_set/test_preds/' + config + '/'
    for config in ('standard', 'conf_loss', 'diff_loss', 'multirep')
)
v6_save_path = '/pscratch/sd/j/jerrylin/hugging/E3SM-MMF_ne4/preprocessing/v6/test_set/test_preds/'

def load_seed_data(save_path, npz_file, seed_key):
    """Return the array stored under `seed_key` in the archive `save_path/npz_file`.

    The archive is closed again before the function returns.
    """
    archive = np.load(os.path.join(save_path, npz_file))
    try:
        return archive[seed_key]
    finally:
        archive.close()

def _lazy_npz_loader(save_path, npz_file):
    """Return a callable ``loader(seed_key)`` that reads one array from `npz_file` on demand.

    Nothing is read when the loader is created; each call opens the archive via
    `load_seed_data` and returns the array stored under `seed_key`.
    """
    return lambda seed_key: load_seed_data(save_path, npz_file, seed_key)

# Architectures for which prediction archives exist, in the order of `model_names`.
_pred_architectures = ['unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm']

# NOTE: the print statements only mark progress through the cell. The dictionaries
# hold lazy loaders, so no prediction data is read until a loader is called.
print('loading standard preds')
standard_preds = {
    arch: _lazy_npz_loader(standard_save_path, f'standard_{arch}_preds.npz')
    for arch in _pred_architectures
}

print('loading conf loss preds')
conf_loss_preds = {
    arch: _lazy_npz_loader(conf_loss_save_path, f'conf_loss_{arch}_preds.npz')
    for arch in _pred_architectures
}

print('loading conf loss conf')
conf_loss_conf = {
    arch: _lazy_npz_loader(conf_loss_save_path, f'conf_loss_{arch}_conf.npz')
    for arch in _pred_architectures
}

print('loading diff loss preds')
diff_loss_preds = {
    arch: _lazy_npz_loader(diff_loss_save_path, f'diff_loss_{arch}_preds.npz')
    for arch in _pred_architectures
}

print('loading multirep preds')
multirep_preds = {
    arch: _lazy_npz_loader(multirep_save_path, f'multirep_{arch}_preds.npz')
    for arch in _pred_architectures
}

print('loading v6 preds')
v6_preds = {
    arch: _lazy_npz_loader(v6_save_path, f'v6_{arch}_preds.npz')
    for arch in _pred_architectures
}

# Prediction loaders grouped by configuration key.
config_preds = {
    'standard': standard_preds,
    'conf_loss': conf_loss_preds,
    'diff_loss': diff_loss_preds,
    'multirep': multirep_preds,
    'v6': v6_preds
}

print('loading Kaggle preds')
kaggle_check_path = '/pscratch/sd/j/jerrylin/kaggle_check'

# Both prediction sets are reshaped to (samples, columns, 368); 368 is
# presumably the number of output variables -- confirm.
_kaggle_pred_shape = (-1, data_v2_rh_mc.num_latlon, 368)

greysnow = np.load(os.path.join(kaggle_check_path, 'prds.npy')).reshape(_kaggle_pred_shape)

# The first parquet column is dropped before the values are reshaped.
adam = (
    pd.read_parquet(os.path.join(kaggle_check_path, 'final_blend_v10.parquet'))
    .iloc[:,1:]
    .values
    .reshape(_kaggle_pred_shape)
)

print(greysnow.shape, adam.shape)

loading standard preds
loading conf loss preds
loading conf loss conf
loading diff loss preds
loading multirep preds
loading v6 preds
loading Kaggle preds
(4380, 384, 368) (4380, 384, 368)


# Load offline R² scores

In [7]:
# Offline R2 pickles: one file per (configuration, architecture) pair, stored
# as "<config>_<arch>_r2.pkl" inside that configuration's save directory.
# Each object is bound to a notebook-level variable with the same stem
# (e.g. standard_unet_r2) because later cells refer to those names directly;
# hence the explicit write into globals().
# NOTE: pickle.load can execute arbitrary code -- only point these paths at
# trusted, locally generated files.
_r2_dirs = {
    "standard": standard_save_path,
    "conf_loss": conf_loss_save_path,
    "diff_loss": diff_loss_save_path,
    "multirep": multirep_save_path,
    "v6": v6_save_path,
}
_r2_archs = ("unet", "squeezeformer", "pure_resLSTM", "pao_model", "convnext", "encdec_lstm")

# Iteration order (configuration-major, then architecture) matches the order
# in which the files were originally read.
for _cfg, _cfg_dir in _r2_dirs.items():
    for _arch in _r2_archs:
        _r2_name = f"{_cfg}_{_arch}_r2"
        with open(os.path.join(_cfg_dir, _r2_name + ".pkl"), "rb") as f:
            globals()[_r2_name] = pickle.load(f)

# Load Offline R2 (zonal)

In [8]:
# Zonal offline R2 for the "standard" configuration: one pickle per
# (architecture, variable tag) under <standard_save_path>/zonal/.
# Every object is bound to a notebook-level variable named
# "standard_<arch>_zonal_<tag>_r2", since later cells use those names directly.
# NOTE: pickle.load can execute arbitrary code -- trusted files only.
for _arch in ("unet", "squeezeformer", "pure_resLSTM", "pao_model", "convnext", "encdec_lstm"):
    for _tag in ("dTdt", "dQvdt", "dQldt", "dQidt", "dUdt", "dVdt"):
        _stem = f"standard_{_arch}_zonal_{_tag}_r2"
        with open(os.path.join(standard_save_path, "zonal", _stem + ".pkl"), "rb") as f:
            globals()[_stem] = pickle.load(f)

# Zonal offline R2 for the "conf_loss" configuration: one pickle per
# (architecture, variable tag) under <conf_loss_save_path>/zonal/.
# Every object is bound to a notebook-level variable named
# "conf_loss_<arch>_zonal_<tag>_r2", since later cells use those names directly.
# NOTE: pickle.load can execute arbitrary code -- trusted files only.
for _arch in ("unet", "squeezeformer", "pure_resLSTM", "pao_model", "convnext", "encdec_lstm"):
    for _tag in ("dTdt", "dQvdt", "dQldt", "dQidt", "dUdt", "dVdt"):
        _stem = f"conf_loss_{_arch}_zonal_{_tag}_r2"
        with open(os.path.join(conf_loss_save_path, "zonal", _stem + ".pkl"), "rb") as f:
            globals()[_stem] = pickle.load(f)

# Zonal offline R2 for the "diff_loss" configuration: one pickle per
# (architecture, variable tag) under <diff_loss_save_path>/zonal/.
# Every object is bound to a notebook-level variable named
# "diff_loss_<arch>_zonal_<tag>_r2", since later cells use those names directly.
# NOTE: pickle.load can execute arbitrary code -- trusted files only.
for _arch in ("unet", "squeezeformer", "pure_resLSTM", "pao_model", "convnext", "encdec_lstm"):
    for _tag in ("dTdt", "dQvdt", "dQldt", "dQidt", "dUdt", "dVdt"):
        _stem = f"diff_loss_{_arch}_zonal_{_tag}_r2"
        with open(os.path.join(diff_loss_save_path, "zonal", _stem + ".pkl"), "rb") as f:
            globals()[_stem] = pickle.load(f)

# Zonal offline R2 for the "multirep" configuration: one pickle per
# (architecture, variable tag) under <multirep_save_path>/zonal/.
# Every object is bound to a notebook-level variable named
# "multirep_<arch>_zonal_<tag>_r2", since later cells use those names directly.
# NOTE: pickle.load can execute arbitrary code -- trusted files only.
for _arch in ("unet", "squeezeformer", "pure_resLSTM", "pao_model", "convnext", "encdec_lstm"):
    for _tag in ("dTdt", "dQvdt", "dQldt", "dQidt", "dUdt", "dVdt"):
        _stem = f"multirep_{_arch}_zonal_{_tag}_r2"
        with open(os.path.join(multirep_save_path, "zonal", _stem + ".pkl"), "rb") as f:
            globals()[_stem] = pickle.load(f)

# Zonal offline R2 for the v6 U-Net: one pickle per variable tag, each bound
# to a notebook-level variable named "v6_unet_zonal_<tag>_r2".
#
# FIX: these six "v6_unet_*" files were previously read from
# multirep_save_path, whereas every other v6 artefact in this notebook is read
# from v6_save_path -- almost certainly a copy-paste slip from the multirep
# section. They are now read from <v6_save_path>/zonal/. If a file is missing
# there, the old <multirep_save_path>/zonal/ location is tried (with a printed
# warning) so that setups where the pickles really were written to the
# multirep directory keep working.
# NOTE(review): confirm where the v6 U-Net zonal pickles are actually written,
# then drop the fallback.
# NOTE: pickle.load can execute arbitrary code -- trusted files only.
for _tag in ("dTdt", "dQvdt", "dQldt", "dQidt", "dUdt", "dVdt"):
    _stem = f"v6_unet_zonal_{_tag}_r2"
    _pkl_path = os.path.join(v6_save_path, "zonal", _stem + ".pkl")
    if not os.path.exists(_pkl_path):
        _legacy_path = os.path.join(multirep_save_path, "zonal", _stem + ".pkl")
        print(f"WARNING: {_pkl_path} not found; falling back to {_legacy_path}")
        _pkl_path = _legacy_path
    with open(_pkl_path, "rb") as f:
        globals()[_stem] = pickle.load(f)

# Zonal offline R2 for the remaining v6 architectures (the v6 U-Net files are
# read separately): one pickle per (architecture, variable tag) under
# <v6_save_path>/zonal/. Every object is bound to a notebook-level variable
# named "v6_<arch>_zonal_<tag>_r2", since later cells use those names directly.
# NOTE: pickle.load can execute arbitrary code -- trusted files only.
for _arch in ("squeezeformer", "pure_resLSTM", "pao_model", "convnext", "encdec_lstm"):
    for _tag in ("dTdt", "dQvdt", "dQldt", "dQidt", "dUdt", "dVdt"):
        _stem = f"v6_{_arch}_zonal_{_tag}_r2"
        with open(os.path.join(v6_save_path, "zonal", _stem + ".pkl"), "rb") as f:
            globals()[_stem] = pickle.load(f)

# Offline plotting functions

### Offline R2 lines plot

In [9]:
def _r2_spread_across_seeds(model_runs, index, seed_keys):
    """Summarise R^2 across seeds for every model at one output index or slice.

    Parameters
    ----------
    model_runs : sequence
        One entry per model. Each entry is read as ``runs[seed][0][index]``.
    index : int or slice
        Position(s) in the flattened output vector to extract.
    seed_keys : re-iterable
        Keys used to look up each seed's result in every ``runs`` entry.

    Returns
    -------
    dict
        Keys ``'min'``, ``'median'`` and ``'max'``. Each value is an array with
        models along axis 0; for a slice ``index`` the selected positions are
        along axis 1, for an integer ``index`` the array is 1-D.
    """
    # One (n_seeds, ...) array per model.
    per_model = [
        np.stack([runs[seed][0][index] for seed in seed_keys])
        for runs in model_runs
    ]
    return {
        'min': np.stack([stack.min(axis=0) for stack in per_model], axis=0),
        'median': np.stack([np.median(stack, axis=0) for stack in per_model], axis=0),
        'max': np.stack([stack.max(axis=0) for stack in per_model], axis=0),
    }


def plot_offline_R2_lines(unet_r2,
                          squeezeformer_r2,
                          pure_resLSTM_r2,
                          pao_model_r2,
                          convnext_r2,
                          encdec_lstm_r2,
                          config_name = None,
                          show = True,
                          save_path = None):
    """Plot offline R^2 for six models: vertical profiles plus a scalar bar chart.

    The figure has six profile panels (dT/dt, dQv/dt, dQl/dt, dQi/dt, dU/dt,
    dV/dt), each showing the across-seed median as a line and the min-max range
    as a shaded band, and one bottom panel with bars (median) and error bars
    (min-max) for the eight scalar outputs.

    Relies on notebook globals: ``seeds`` (seed keys), ``data_v2_rh_mc``
    (for ``grid_info['lev']``) and ``config_names`` (display names).

    Parameters
    ----------
    unet_r2, squeezeformer_r2, pure_resLSTM_r2, pao_model_r2, convnext_r2, encdec_lstm_r2
        Per-model R^2 results, read as ``r2[seed][0][k]`` for every ``seed`` in
        ``seeds``. Positions 0-359 are used as six consecutive 60-value
        profiles and positions 360-367 as the eight scalars, in the order
        listed in the tables below.
    config_name : str or None
        Key into ``config_names`` used for the figure title and the output
        file name. When None, the configuration is omitted from both.
    show : bool
        If True call ``plt.show()``; otherwise close the figure.
    save_path : str or None
        Directory in which to write the PNG; created if it does not exist.
        Nothing is written when falsy.
    """
    model_runs = [unet_r2, squeezeformer_r2, pure_resLSTM_r2,
                  pao_model_r2, convnext_r2, encdec_lstm_r2]
    models = ['U-Net', 'Squeezeformer', 'Pure ResLSTM', 'Pao Model', 'ConvNeXt', 'Encoder-Decoder LSTM']
    colors = ['green', 'purple', 'blue', 'red', 'gold', 'orange']

    # --------------------------
    # 1) Gather min/median/max across seeds
    # --------------------------
    # Panel title -> slice of the flattened output vector holding that profile.
    profile_slices = {
        r'(a) $R^2$: dT/dt': slice(0, 60),
        r'(b) $R^2$: dQ$_v$/dt': slice(60, 120),
        r'(c) $R^2$: dQ$_l$/dt': slice(120, 180),
        r'(d) $R^2$: dQ$_i$/dt': slice(180, 240),
        r'(e) $R^2$: dU/dt': slice(240, 300),
        r'(f) $R^2$: dV/dt': slice(300, 360),
    }
    # Scalar output name -> its position in the flattened output vector.
    scalar_indices = {
        'NETSW': 360,
        'FLWDS': 361,
        'PRECSC': 362,
        'PRECC': 363,
        'SOLS': 364,
        'SOLL': 365,
        'SOLSD': 366,
        'SOLLD': 367,
    }

    r2_profiles = {
        title: _r2_spread_across_seeds(model_runs, level_slice, seeds)
        for title, level_slice in profile_slices.items()
    }
    r2_scalars = {
        name: _r2_spread_across_seeds(model_runs, position, seeds)
        for name, position in scalar_indices.items()
    }

    # Vertical coordinate shared by all profile panels.
    sigma_pressure_levels = data_v2_rh_mc.grid_info['lev'].values

    # --------------------------
    # 2) Set up the figure
    # --------------------------
    fig = plt.figure(figsize=(12, 10))
    gs = gridspec.GridSpec(3, 3, height_ratios=[1, 1, 0.7], hspace=0.3, wspace=0.2)

    for idx, var in enumerate(r2_profiles.keys()):
        row = idx // 3
        col = idx % 3
        ax = fig.add_subplot(gs[row, col])

        # plot each model's profile
        for m, model in enumerate(models):
            ax.fill_betweenx(
                sigma_pressure_levels,           # y-axis (pressure levels)
                r2_profiles[var]['min'][m],      # lower bound (min profile)
                r2_profiles[var]['max'][m],      # upper bound (max profile)
                color=colors[m],                 # color for the model
                alpha=0.4                        # transparency for the filled area
            )
            ax.plot(r2_profiles[var]['median'][m], sigma_pressure_levels,
                    color=colors[m], label=model, alpha = 1, linewidth = .64)

        ax.set_xlim(0, 1)
        ax.set_ylim(1000, 0)         # invert y
        ax.set_yticks([0, 200, 400, 600, 800, 1000])
        if col == 0:
            ax.set_ylabel("Hybrid pressure (hPa)")
        else:
            ax.set_yticklabels([])   # no y-labels on inner panels

        ax.set_title(var, fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.7)

    # --------------------------
    # 3) The bottom bar chart
    # --------------------------
    axb = fig.add_subplot(gs[2, :])

    r2_scalars_labels = list(r2_scalars.keys())
    num_vars = len(r2_scalars_labels)
    num_models = len(models)

    # Each variable has an array of 6 values (one per model)
    r2_scalars_min = np.array([r2_scalars[var]['min'] for var in r2_scalars_labels])  # shape (num_vars, num_models)
    r2_scalars_median = np.array([r2_scalars[var]['median'] for var in r2_scalars_labels])  # shape (num_vars, num_models)
    r2_scalars_max = np.array([r2_scalars[var]['max'] for var in r2_scalars_labels])  # shape (num_vars, num_models)

    bar_width = 0.12
    indices = np.arange(num_vars)

    for i in range(num_models):
        y = r2_scalars_median[:, i]
        # Asymmetric error bars: distance from the median down to min / up to max.
        err_high = r2_scalars_max[:, i] - y
        err_low = y - r2_scalars_min[:, i]
        y_err = np.vstack([err_low, err_high])
        axb.bar(indices + i * bar_width,
                y,
                bar_width,
                label=models[i],
                color=colors[i],
                alpha = .6,
                yerr = y_err,
                capsize = 1,
                error_kw = dict(elinewidth=4, ecolor = colors[i], alpha=1.0))

    # Labels and title; ticks are centred under each group of bars.
    axb.set_xticks(indices + bar_width * (num_models - 1) / 2)
    axb.set_xticklabels(r2_scalars_labels)
    axb.set_ylim(0, 1.05)
    axb.set_ylabel(r'$R^2$')
    axb.set_title('(g) Fluxes')

    # Add horizontal grid lines for better readability
    axb.yaxis.grid(True, linestyle='--', alpha=0.7)
    leg = axb.legend(loc='lower left', ncol=2, frameon=True)
    # Bars are semi-transparent; make the legend swatches opaque.
    for lh in leg.legend_handles:
        lh.set_alpha(1.0)

    # The default config_name=None has no display name to look up, so leave the
    # configuration out of the title instead of failing on the lookup.
    if config_name is None:
        title = "Offline $R^2$ for each model"
        file_name = 'offline_r2_lines.png'
    else:
        title = f"Offline $R^2$ for each model ({config_names[config_name]} Configuration)"
        file_name = f'offline_r2_lines_{config_name}.png'
    fig.suptitle(title, fontsize=14, y=0.95)

    plt.tight_layout()
    if save_path:
        # savefig does not create missing directories.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, file_name), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

In [10]:
# Save (without displaying) one offline R^2 summary figure per configuration.
# Each tuple follows the positional order expected by plot_offline_R2_lines:
# U-Net, Squeezeformer, Pure ResLSTM, Pao model, ConvNeXt, Encoder-Decoder LSTM.
r2_lines_save_dir = os.path.join(climsim3_figures_save_path_offline, 'r2_lines')
offline_r2_by_config = {
    "standard": (standard_unet_r2, standard_squeezeformer_r2, standard_pure_resLSTM_r2,
                 standard_pao_model_r2, standard_convnext_r2, standard_encdec_lstm_r2),
    "conf_loss": (conf_loss_unet_r2, conf_loss_squeezeformer_r2, conf_loss_pure_resLSTM_r2,
                  conf_loss_pao_model_r2, conf_loss_convnext_r2, conf_loss_encdec_lstm_r2),
    "diff_loss": (diff_loss_unet_r2, diff_loss_squeezeformer_r2, diff_loss_pure_resLSTM_r2,
                  diff_loss_pao_model_r2, diff_loss_convnext_r2, diff_loss_encdec_lstm_r2),
    "multirep": (multirep_unet_r2, multirep_squeezeformer_r2, multirep_pure_resLSTM_r2,
                 multirep_pao_model_r2, multirep_convnext_r2, multirep_encdec_lstm_r2),
    "v6": (v6_unet_r2, v6_squeezeformer_r2, v6_pure_resLSTM_r2,
           v6_pao_model_r2, v6_convnext_r2, v6_encdec_lstm_r2),
}
for config_key, model_r2_runs in offline_r2_by_config.items():
    plot_offline_R2_lines(*model_r2_runs,
                          config_name = config_key,
                          show = False,
                          save_path = r2_lines_save_dir)

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


### Zonal R2 plotting function

In [None]:
def plot_zonal_r2(
    zonal_dTdt_r2,
    zonal_dQvdt_r2,
    zonal_dQldt_r2,
    zonal_dQidt_r2,
    zonal_dUdt_r2,
    zonal_dVdt_r2,
    model_name,
    config_name,
    show = True,
    save_path = None):
    """Plot seed-averaged zonal R^2 for six tendencies as a 2x3 grid of panels.

    Each panel shades R^2 (0 to 1, 'Blues') against sine of latitude and
    pressure, with contours drawn at R^2 = 0.7 (orange) and 0.9 (yellow).

    Relies on notebook globals: ``seeds``, ``lat_bin_mids``,
    ``pressures_binned``, ``data_v2_rh_mc`` (for ``zonal_bin_weight_3d``),
    ``model_names`` and ``config_names``.

    Parameters
    ----------
    zonal_dTdt_r2, zonal_dQvdt_r2, zonal_dQldt_r2, zonal_dQidt_r2, zonal_dUdt_r2, zonal_dVdt_r2
        Per-tendency R^2 fields, read as ``zonal_r2[seed]`` for every ``seed``
        in ``seeds``; the per-seed arrays are stacked and averaged.
    model_name : str
        Key into ``model_names`` (title) and part of the output file name.
    config_name : str
        Key into ``config_names`` (title) and part of the output file name.
    show : bool
        If True call ``plt.show()``; otherwise close the figure.
    save_path : str or None
        Directory in which to write the PNG; created if it does not exist.
        Nothing is written when falsy.
    """
    fig = plt.figure(figsize=(18, 10))
    gs = gridspec.GridSpec(2, 3, height_ratios=[1, 1], hspace=0.2, wspace=0.15)

    # X: sine of the latitude-bin midpoints, repeated for each of 60 levels.
    X, _ = np.meshgrid(np.sin(np.pi*lat_bin_mids/180), np.arange(60))
    # Y: binned pressures averaged over their first axis, scaled by 1/100
    # (presumably Pa -> hPa, matching the axis label; confirm upstream units).
    Y = (1/100) * np.mean(pressures_binned, axis = 0).T

    def seed_mean_zonal(zonal_r2):
        """Average one tendency over seeds, then apply zonal_bin_weight_3d."""
        seed_mean = np.mean(np.stack([zonal_r2[seed] for seed in seeds]), axis = 0)
        # zonal_bin_weight_3d is fed a leading singleton axis, stripped again here.
        return data_v2_rh_mc.zonal_bin_weight_3d(seed_mean[None,:,:])[0]

    zonal_r2_profiles = {
        r'(a) $R^2$: dT/dt': seed_mean_zonal(zonal_dTdt_r2),
        r'(b) $R^2$: dQ$_v$/dt': seed_mean_zonal(zonal_dQvdt_r2),
        r'(c) $R^2$: dQ$_l$/dt': seed_mean_zonal(zonal_dQldt_r2),
        r'(d) $R^2$: dQ$_i$/dt': seed_mean_zonal(zonal_dQidt_r2),
        r'(e) $R^2$: dU/dt': seed_mean_zonal(zonal_dUdt_r2),
        r'(f) $R^2$: dV/dt': seed_mean_zonal(zonal_dVdt_r2)
    }

    for idx, var in enumerate(zonal_r2_profiles.keys()):
        row = idx // 3
        col = idx % 3
        ax = fig.add_subplot(gs[row, col])

        field = zonal_r2_profiles[var].T
        ax.pcolor(X, Y, field, cmap = 'Blues', vmin = 0, vmax = 1)
        ax.contour(X, Y, field, [0.7], colors = 'orange', linewidths = [4])
        ax.contour(X, Y, field, [0.9], colors = 'yellow', linewidths = [4])
        ax.set_ylim(ax.get_ylim()[::-1])   # flip the y-axis
        ax.set_title(var, fontsize = 15, pad = 14)
        if col == 0:
            ax.yaxis.set_ticks([1000,800,600,400,200,0])
            ax.yaxis.set_tick_params(labelsize = 14)
            if row == 0:
                # Single y-label, shifted down so it sits between the two rows.
                ax.set_ylabel("Pressure [hPa]", fontsize = 16)
                ax.yaxis.set_label_coords(-0.2,-0.09)
        else:
            ax.set_yticklabels([])   # no y-labels on inner panels

        if row == 0:
            ax.set_xticks([])
        else:
            ax.xaxis.set_ticks([np.sin(-50/180*np.pi), 0, np.sin(50/180*np.pi)])
            # Raw strings: '\c' is not a valid escape in a normal string literal.
            ax.xaxis.set_ticklabels([r'50$^\circ$S', r'0$^\circ$', r'50$^\circ$N'], fontsize = 16)

    # Figure-level title; set once rather than on every panel iteration.
    fig.suptitle(f'{model_names[model_name]}' + r' Offline Zonal $R^2$ averaged across seeds ' +  f'({config_names[config_name]} Configuration)', y = .97, fontsize = 18)

    if save_path:
        # savefig does not create missing directories.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, f'offline_zonal_r2_{model_name}_{config_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

In [None]:
# Save (without displaying) the zonal R^2 figure for each configuration/model
# pair below. The per-tendency inputs are the notebook globals named
# "<config>_<model>_zonal_<tendency>_r2", passed in plot_zonal_r2's positional
# order. The diff_loss Encoder-Decoder LSTM figure is produced by the explicit
# call that follows this loop.
_zonal_models = ('unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm')
_zonal_tendencies = ('dTdt', 'dQvdt', 'dQldt', 'dQidt', 'dUdt', 'dVdt')
_zonal_jobs = (
    ('standard', _zonal_models),
    ('conf_loss', _zonal_models),
    ('diff_loss', _zonal_models[:5]),
)
for _config_key, _model_keys in _zonal_jobs:
    for _model_key in _model_keys:
        plot_zonal_r2(
            *(globals()[f'{_config_key}_{_model_key}_zonal_{_tendency}_r2']
              for _tendency in _zonal_tendencies),
            model_name = _model_key,
            config_name = _config_key,
            show = False,
            save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
del _zonal_models, _zonal_tendencies, _zonal_jobs, _config_key, _model_keys, _model_key
plot_zonal_r2(
    diff_loss_encdec_lstm_zonal_dTdt_r2,
    diff_loss_encdec_lstm_zonal_dQvdt_r2,
    diff_loss_encdec_lstm_zonal_dQldt_r2,
    diff_loss_encdec_lstm_zonal_dQidt_r2,
    diff_loss_encdec_lstm_zonal_dUdt_r2,
    diff_loss_encdec_lstm_zonal_dVdt_r2,
    model_name = 'encdec_lstm',
    config_name = 'diff_loss',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    multirep_unet_zonal_dTdt_r2,
    multirep_unet_zonal_dQvdt_r2,
    multirep_unet_zonal_dQldt_r2,
    multirep_unet_zonal_dQidt_r2,
    multirep_unet_zonal_dUdt_r2,
    multirep_unet_zonal_dVdt_r2,
    model_name = 'unet',
    config_name = 'multirep',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    multirep_squeezeformer_zonal_dTdt_r2,
    multirep_squeezeformer_zonal_dQvdt_r2,
    multirep_squeezeformer_zonal_dQldt_r2,
    multirep_squeezeformer_zonal_dQidt_r2,
    multirep_squeezeformer_zonal_dUdt_r2,
    multirep_squeezeformer_zonal_dVdt_r2,
    model_name = 'squeezeformer',
    config_name = 'multirep',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    multirep_pure_resLSTM_zonal_dTdt_r2,
    multirep_pure_resLSTM_zonal_dQvdt_r2,
    multirep_pure_resLSTM_zonal_dQldt_r2,
    multirep_pure_resLSTM_zonal_dQidt_r2,
    multirep_pure_resLSTM_zonal_dUdt_r2,
    multirep_pure_resLSTM_zonal_dVdt_r2,
    model_name = 'pure_resLSTM',
    config_name = 'multirep',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    multirep_pao_model_zonal_dTdt_r2,
    multirep_pao_model_zonal_dQvdt_r2,
    multirep_pao_model_zonal_dQldt_r2,
    multirep_pao_model_zonal_dQidt_r2,
    multirep_pao_model_zonal_dUdt_r2,
    multirep_pao_model_zonal_dVdt_r2,
    model_name = 'pao_model',
    config_name = 'multirep',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    multirep_convnext_zonal_dTdt_r2,
    multirep_convnext_zonal_dQvdt_r2,
    multirep_convnext_zonal_dQldt_r2,
    multirep_convnext_zonal_dQidt_r2,
    multirep_convnext_zonal_dUdt_r2,
    multirep_convnext_zonal_dVdt_r2,
    model_name = 'convnext',
    config_name = 'multirep',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    multirep_encdec_lstm_zonal_dTdt_r2,
    multirep_encdec_lstm_zonal_dQvdt_r2,
    multirep_encdec_lstm_zonal_dQldt_r2,
    multirep_encdec_lstm_zonal_dQidt_r2,
    multirep_encdec_lstm_zonal_dUdt_r2,
    multirep_encdec_lstm_zonal_dVdt_r2,
    model_name = 'encdec_lstm',
    config_name = 'multirep',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    v6_unet_zonal_dTdt_r2,
    v6_unet_zonal_dQvdt_r2,
    v6_unet_zonal_dQldt_r2,
    v6_unet_zonal_dQidt_r2,
    v6_unet_zonal_dUdt_r2,
    v6_unet_zonal_dVdt_r2,
    model_name = 'unet',
    config_name = 'v6',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    v6_squeezeformer_zonal_dTdt_r2,
    v6_squeezeformer_zonal_dQvdt_r2,
    v6_squeezeformer_zonal_dQldt_r2,
    v6_squeezeformer_zonal_dQidt_r2,
    v6_squeezeformer_zonal_dUdt_r2,
    v6_squeezeformer_zonal_dVdt_r2,
    model_name = 'squeezeformer',
    config_name = 'v6',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    v6_pure_resLSTM_zonal_dTdt_r2,
    v6_pure_resLSTM_zonal_dQvdt_r2,
    v6_pure_resLSTM_zonal_dQldt_r2,
    v6_pure_resLSTM_zonal_dQidt_r2,
    v6_pure_resLSTM_zonal_dUdt_r2,
    v6_pure_resLSTM_zonal_dVdt_r2,
    model_name = 'pure_resLSTM',
    config_name = 'v6',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    v6_pao_model_zonal_dTdt_r2,
    v6_pao_model_zonal_dQvdt_r2,
    v6_pao_model_zonal_dQldt_r2,
    v6_pao_model_zonal_dQidt_r2,
    v6_pao_model_zonal_dUdt_r2,
    v6_pao_model_zonal_dVdt_r2,
    model_name = 'pao_model',
    config_name = 'v6',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    v6_convnext_zonal_dTdt_r2,
    v6_convnext_zonal_dQvdt_r2,
    v6_convnext_zonal_dQldt_r2,
    v6_convnext_zonal_dQidt_r2,
    v6_convnext_zonal_dUdt_r2,
    v6_convnext_zonal_dVdt_r2,
    model_name = 'convnext',
    config_name = 'v6',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))
plot_zonal_r2(
    v6_encdec_lstm_zonal_dTdt_r2,
    v6_encdec_lstm_zonal_dQvdt_r2,
    v6_encdec_lstm_zonal_dQldt_r2,
    v6_encdec_lstm_zonal_dQidt_r2,
    v6_encdec_lstm_zonal_dUdt_r2,
    v6_encdec_lstm_zonal_dVdt_r2,
    model_name = 'encdec_lstm',
    config_name = 'v6',
    show = False,
    save_path = os.path.join(climsim3_figures_save_path_offline, 'zonal_r2'))

### Offline bias function (single model, multiple variables)

In [None]:
def plot_zonal_mean_tendency_bias(config_name, model_name, seed, show = True, save_path = None):
    """Draw a 3x2 panel figure of offline tendency biases for one model run.

    The bias is the prediction returned by
    ``config_preds[config_name][model_name](seed)`` minus ``actual_target``.
    Its last axis is cut into six consecutive 60-wide slices (dT/dt, dQv/dt,
    dQl/dt, dQi/dt, dU/dt, dV/dt). Each slice is reduced with
    ``offline_area_time_mean_3d``, multiplied by that variable's ``'scaling'``
    entry in ``offline_var_settings`` and drawn with the ``'RdBu_r'`` colormap
    between the variable's ``'vmin'`` and ``'vmax'`` entries.

    Parameters
    ----------
    config_name
        Key into ``config_preds`` and ``config_names``.
    model_name
        Key into ``config_preds[config_name]`` and ``model_names``.
    seed : str
        Passed to the prediction loader and used in the output file name.
        Everything from index 5 onwards is shown in the figure title
        (e.g. ``'seed_7'`` is displayed as ``7``).
    show : bool, optional
        If true the figure is shown with ``plt.show()``; otherwise it is closed.
    save_path : str or None, optional
        Directory that receives
        ``offline_tendency_bias_{model_name}_{config_name}_{seed}.png``.
        The directory is created if it does not exist. Nothing is written
        when this is falsy.
    """
    # (offline_var_settings key, tendency name used in the panel title), listed
    # in the order the variables are stacked along the last axis of the
    # prediction array and in row-major panel order.
    panels = [
        ('DTPHYS', 'dT/dt'),
        ('DQ1PHYS', 'dQv/dt'),
        ('DQ2PHYS', 'dQl/dt'),
        ('DQ3PHYS', 'dQi/dt'),
        ('DUPHYS', 'dU/dt'),
        ('DVPHYS', 'dV/dt'),
    ]
    slice_width = 60  # number of consecutive entries on the last axis per variable
    latitude_ticks = [-60, -30, 0, 30, 60]
    latitude_labels = ['60S', '30S', '0', '30N', '60N']

    nn_diff = config_preds[config_name][model_name](seed) - actual_target

    fig, axs = plt.subplots(3, 2, figsize=(14, 17))

    # Reduce and scale every variable before anything is drawn.
    # NOTE(review): offline_area_time_mean_3d is defined elsewhere; its result
    # must expose an xarray-style .plot(ax=...) method - confirm.
    biases = {}
    for position, (var_key, _) in enumerate(panels):
        start = position * slice_width
        biases[var_key] = (
            offline_var_settings[var_key]['scaling']
            * offline_area_time_mean_3d(nn_diff[:, :, start:start + slice_width])
        )

    # One panel per variable, labelled (a)..(f).
    plotted_artists = {}
    for ax, letter, (var_key, tendency) in zip(axs.flat, string.ascii_lowercase, panels):
        settings = offline_var_settings[var_key]
        plotted_artists[var_key] = biases[var_key].plot(
            ax=ax, add_colorbar=False, cmap='RdBu_r',
            vmin=settings['vmin'], vmax=settings['vmax'])
        ax.set_title(f'({letter}) {tendency} Bias ({settings["unit"]})')
        ax.invert_yaxis()
        ax.set_xlabel('Latitude')

    # Every panel gets its own colorbar because each variable has its own limits.
    for ax, (var_key, _) in zip(axs.flat, panels):
        img = plotted_artists[var_key]
        if img is not None:
            fig.colorbar(img, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
        else:
            print(f"Warning: No artist found for variable {var_key} to create colorbar.")

    for ax in axs.flat:
        ax.set_xticks(latitude_ticks)  # Set the positions for the ticks
        ax.set_xticklabels(latitude_labels)  # Set the custom text labels

    fig.suptitle(f'Offline Tendency Biases for {model_names[model_name]} {config_names[config_name]} (seed {seed[5:]})', fontsize=14, x = .6, y = .95)
    fig.subplots_adjust(right=1, top=.9)
    if save_path:
        # savefig does not create missing directories, so make sure the
        # target exists instead of failing on a fresh output tree.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, f'offline_tendency_bias_{model_name}_{config_name}_{seed}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        # Close this specific figure so repeated calls do not accumulate open figures.
        plt.close(fig)

In [None]:
# Save the zonal-mean tendency-bias figure for every configuration / model /
# seed combination. The loops are nested so that the seed varies fastest, then
# the model, then the configuration.
_tendency_bias_dir = os.path.join(climsim3_figures_save_path_offline, 'zonal_mean_tendency_bias')
_tendency_bias_configs = ('standard', 'conf_loss', 'diff_loss', 'multirep', 'v6')
_tendency_bias_models = ('unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm')
_tendency_bias_seeds = ('seed_7', 'seed_43', 'seed_1024')
for _tendency_bias_config in _tendency_bias_configs:
    for _tendency_bias_model in _tendency_bias_models:
        for _tendency_bias_seed in _tendency_bias_seeds:
            plot_zonal_mean_tendency_bias(
                _tendency_bias_config,
                _tendency_bias_model,
                _tendency_bias_seed,
                show = False,
                save_path = _tendency_bias_dir)

### Offline bias function (multiple models, single variable)

In [None]:
def plot_offline_bias_model_comparison(config_name, var, show = True, save_path = None):
    """Plot the seed-averaged offline bias of one variable for six models.

    For each model, the predictions returned by
    ``config_preds[config_name][model](seed)`` are averaged over every seed in
    ``seeds``. The global ``actual_target`` is subtracted over the 60 output
    columns starting at the variable's ``var_index``; the difference is reduced
    with ``offline_area_time_mean_3d`` and multiplied by the variable's
    ``scaling``. The six results are drawn on a 2x3 grid, one panel and one
    colorbar per model.

    Args:
        config_name: Key into ``config_preds`` and ``config_names``.
        var: Key into ``offline_var_settings`` (e.g. ``'DTPHYS'``).
        show: If True the figure is displayed, otherwise it is closed.
        save_path: Optional directory. When truthy, the figure is written
            there as ``offline_bias_model_comparison_{config_name}_{var}.png``.
    """
    settings = offline_var_settings[var]
    first_col = settings['var_index']
    columns = slice(first_col, first_col + 60)

    # (key in config_preds[config_name], name shown in the panel title),
    # listed in row-major panel order.
    panels = [
        ('unet', 'U-Net'),
        ('squeezeformer', 'SqueezeFormer'),
        ('pure_resLSTM', 'Pure ResLSTM'),
        ('pao_model', 'Pao Model'),
        ('convnext', 'ConvNeXt'),
        ('encdec_lstm', 'Encoder-Decoder LSTM'),
    ]

    def seed_averaged_bias(model_key):
        # Mean over seeds first, then prediction-minus-target on the variable's columns.
        seed_mean = np.mean(np.array([config_preds[config_name][model_key](seed) for seed in seeds]), axis = 0)
        return settings['scaling'] * offline_area_time_mean_3d(seed_mean[:, :, columns] - actual_target[:, :, columns])

    # All biases are computed before the figure is created.
    biases = {model_key: seed_averaged_bias(model_key) for model_key, _ in panels}

    fig, axs = plt.subplots(2, 3, figsize=(18, 10))
    labels = [f"({letter})" for letter in string.ascii_lowercase[:6]]
    latitude_ticks = [-60, -30, 0, 30, 60]
    latitude_labels = ['60S', '30S', '0', '30N', '60N']

    # One bias panel per model; keep the returned artist for the colorbar pass.
    artists = {}
    for ax, label, (model_key, display_name) in zip(axs.flat, labels, panels):
        artists[model_key] = biases[model_key].plot(ax=ax, add_colorbar=False, cmap='RdBu_r', vmin=settings['vmin'], vmax=settings['vmax'])
        ax.set_title(f'{label} {display_name} {settings["var_title"]} Bias ({settings["unit"]})')
        ax.invert_yaxis()
        ax.set_xlabel('Latitude')

    # Attach a vertical colorbar to every panel.
    for ax, (model_key, _) in zip(axs.flat, panels):
        mesh = artists[model_key]
        if mesh is None:
            print(f"Warning: No artist found for variable {model_key} to create colorbar.")
        else:
            fig.colorbar(mesh, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)

    # Same latitude ticks and text labels on every panel.
    for ax in axs.flat:
        ax.set_xticks(latitude_ticks)
        ax.set_xticklabels(latitude_labels)

    plt.suptitle(f'Offline {settings["var_title"]} Tendency Biases Averaged Across Seeds ({config_names[config_name]} Configuration)', fontsize=14, x = .6, y = .95)
    plt.subplots_adjust(right=1, top=.9)
    if save_path:
        plt.savefig(os.path.join(save_path, f'offline_bias_model_comparison_{config_name}_{var}.png'), dpi=300, bbox_inches='tight')
    if not show:
        plt.close()
    else:
        plt.show()

In [None]:
# Save one model-comparison figure per (configuration, variable) pair,
# configuration-major / variable-minor.
comparison_dir = os.path.join(climsim3_figures_save_path_offline, 'offline_bias_model_comparison')
for comparison_config in ('standard', 'conf_loss', 'diff_loss', 'multirep', 'v6'):
    for comparison_var in ('DTPHYS', 'DQ1PHYS', 'DQ2PHYS', 'DQ3PHYS', 'DUPHYS', 'DVPHYS'):
        plot_offline_bias_model_comparison(comparison_config, comparison_var, show = False, save_path = comparison_dir)

### Offline bias function (Kaggle, single model, multiple variables)

In [None]:
def plot_kaggle_bias(nn_preds, team, actual_target, show = True, save_path = None):
    """Plot the offline bias of six tendency variables for one Kaggle entry.

    The difference ``nn_preds - actual_target`` is split along its last axis
    into six consecutive 60-column blocks (columns 0-359). Each block is
    reduced with ``offline_area_time_mean_3d``, multiplied by the matching
    ``scaling`` from ``offline_var_settings`` and drawn on a 3x2 grid with its
    own colorbar.

    Args:
        nn_preds: Predictions, indexed as ``[:, :, column]``; expected to hold
            the six 60-column blocks in the order listed in ``panels`` below.
        team: Name used in the figure title and the output file name.
        actual_target: Target values, same layout as ``nn_preds``.
        show: If True the figure is displayed, otherwise it is closed.
        save_path: Optional directory. When truthy, the figure is written
            there as ``offline_kaggle_comparison_{team}.png``.
    """
    nn_diff = nn_preds - actual_target

    # (key in offline_var_settings, symbol used in the panel title), in
    # row-major panel order; entry k uses columns 60*k .. 60*k + 59.
    panels = [
        ('DTPHYS', 'dT/dt'),
        ('DQ1PHYS', 'dQv/dt'),
        ('DQ2PHYS', 'dQl/dt'),
        ('DQ3PHYS', 'dQi/dt'),
        ('DUPHYS', 'dU/dt'),
        ('DVPHYS', 'dV/dt'),
    ]
    raw_diffs = {var_key: nn_diff[:, :, 60 * k:60 * (k + 1)] for k, (var_key, _) in enumerate(panels)}

    fig, axs = plt.subplots(3, 2, figsize=(14, 17))
    labels = [f"({letter})" for letter in string.ascii_lowercase[:6]]
    latitude_ticks = [-60, -30, 0, 30, 60]
    latitude_labels = ['60S', '30S', '0', '30N', '60N']

    # Reduce and scale every block before any panel is drawn.
    biases = {
        var_key: offline_var_settings[var_key]['scaling'] * offline_area_time_mean_3d(raw_diffs[var_key])
        for var_key, _ in panels
    }

    # One panel per variable; keep the returned artist for the colorbar pass.
    artists = {}
    for ax, label, (var_key, symbol) in zip(axs.flat, labels, panels):
        settings = offline_var_settings[var_key]
        artists[var_key] = biases[var_key].plot(ax=ax, add_colorbar=False, cmap='RdBu_r', vmin=settings['vmin'], vmax=settings['vmax'])
        ax.set_title(f'{label} {symbol} Bias ({settings["unit"]})')
        ax.invert_yaxis()
        ax.set_xlabel('Latitude')

    # Attach a vertical colorbar to every panel.
    for ax, (var_key, _) in zip(axs.flat, panels):
        mesh = artists[var_key]
        if mesh is None:
            print(f"Warning: No artist found for variable {var_key} to create colorbar.")
            continue
        cbar = fig.colorbar(mesh, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
        if var_key in ('DQ2PHYS', 'DQ3PHYS'):
            # These two colorbars get scientific-notation tick labels.
            cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f"{x:.0e}"))

    # Same latitude ticks and text labels on every panel.
    for ax in axs.flat:
        ax.set_xticks(latitude_ticks)
        ax.set_xticklabels(latitude_labels)

    plt.suptitle(f'Offline Tendency Biases for {team}', fontsize=14, x = .6, y = .95)
    plt.subplots_adjust(right=1, top=.9)
    if save_path:
        plt.savefig(os.path.join(save_path, f'offline_kaggle_comparison_{team}.png'), dpi=300, bbox_inches='tight')
    if not show:
        plt.close()
    else:
        plt.show()

In [None]:
# Save the offline bias figure for each Kaggle entry.
kaggle_bias_dir = os.path.join(climsim3_figures_save_path_offline, 'offline_kaggle_bias')
for kaggle_team, kaggle_preds in (('Greysnow', greysnow), ('Adam', adam)):
    plot_kaggle_bias(kaggle_preds, kaggle_team, actual_target, show = False, save_path = kaggle_bias_dir)

# Online plotting functions

### Online Zonal Mean Bias (single model)

In [None]:
def plot_online_zonal_mean_bias(config_name, model_name, seed, num_years, show = True, save_path = None):
    """Plot online zonal-mean biases of nine variables for one NN run.

    For each plotted variable the bias is
    ``online_area_time_mean_3d(ds_nn, var) - online_area_time_mean_3d(ds_mmf_1, var)``
    multiplied by the variable's ``scaling`` from ``online_var_settings``,
    where ``ds_mmf_1`` is the first dataset returned by
    ``read_mmf_online_data(num_years)``. Panels are laid out on a 3x3 grid,
    labelled (a) through (i), each with its own colorbar.

    Args:
        config_name: Configuration key, passed to ``read_nn_online_data`` and
            looked up in ``config_names`` for the title.
        model_name: Model key, passed to ``read_nn_online_data`` and looked up
            in ``model_names`` for the title.
        seed: Seed identifier, passed to ``read_nn_online_data`` and used in
            the title and file name.
        num_years: Passed to both readers and used in the title and file name.
        show: If True the figure is displayed, otherwise it is closed.
        save_path: Optional directory. When truthy, the figure is written there as
            ``online_{num_years}_year_zonal_mean_bias_{model_name}_{config_name}_{seed}.png``.

    Returns:
        None. Returns early, without creating a figure, when the NN data is falsy.
    """
    # Variables in row-major panel order.
    panel_vars = ['T', 'Q', 'U', 'CLDLIQ', 'CLDICE', 'DTPHYS', 'DQ1PHYS', 'DQnPHYS', 'DUPHYS']
    labels = [f"({letter})" for letter in string.ascii_lowercase[:len(panel_vars)]]
    latitude_ticks = [-60, -30, 0, 30, 60]
    latitude_labels = ['60S', '30S', '0', '30N', '60N']

    # Only the first MMF dataset is used as the reference here.
    ds_mmf_1, _ = read_mmf_online_data(num_years)
    ds_nn = read_nn_online_data(config_name, model_name, seed, num_years)
    # NOTE(review): presumably read_nn_online_data returns a falsy value when
    # the run could not be loaded -- confirm against its definition.
    if not ds_nn:
        print(f'Load unsuccessful for {config_name}, {model_name}, {seed}. Skipping')
        return

    fig, axs = plt.subplots(3, 3, figsize=(12.8, 8), constrained_layout = True)
    n_rows, n_cols = axs.shape

    # Scaled NN-minus-MMF bias, computed only for the variables that are plotted.
    zonal_mean_bias = {
        var: online_var_settings[var]['scaling'] * (online_area_time_mean_3d(ds_nn, var) - online_area_time_mean_3d(ds_mmf_1, var))
        for var in panel_vars
    }

    for idx, (ax, var) in enumerate(zip(axs.flat, panel_vars)):
        row, col = divmod(idx, n_cols)
        settings = online_var_settings[var]

        mesh = zonal_mean_bias[var].plot(ax=ax, add_colorbar=False, cmap='RdBu_r', vmin=settings['vmin'], vmax=settings['vmax'])
        if var == 'CLDICE':
            # Dashed overlay on the cloud-ice panel only; presumably the
            # zonal-mean tropopause level -- TODO confirm.
            ax.plot(lat_bin_mids, idx_tropopause_zm, 'k--')

        # labels[idx] gives every panel its own letter; the bottom row used to
        # repeat "(g)" three times.
        ax.set_title(f"{labels[idx]} {settings['var_title']} Bias ({settings['unit']})")
        ax.invert_yaxis()
        # Axis labels only on the outer edge: y on the left column, x on the bottom row.
        ax.set_xlabel("Latitude" if row == n_rows - 1 else '')
        ax.set_ylabel("Hybrid pressure (hPa)" if col == 0 else '')

        # Name the parent axes explicitly so the colorbar does not depend on
        # matplotlib inferring it from the mappable or the current axes.
        fig.colorbar(mesh, ax=ax)

        ax.set_xticks(latitude_ticks)  # Set the positions for the ticks
        ax.set_xticklabels(latitude_labels)  # Set the custom text labels

    plt.suptitle(f"{num_years} year zonal mean difference for {model_names[model_name]} ({config_names[config_name]} Configuration, Seed {seed})", fontsize=14)

    if save_path:
        plt.savefig(os.path.join(save_path, f'online_{num_years}_year_zonal_mean_bias_{model_name}_{config_name}_{seed}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

In [None]:
# Save the 4-year online zonal-mean bias figure for every
# (configuration, model, seed) combination listed below,
# configuration-major, then model, then seed.
online_bias_dir = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias')
for online_config in ('standard', 'conf_loss', 'diff_loss', 'multirep'):
    for online_model in ('unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm'):
        for online_seed in (7, 43, 1024):
            plot_online_zonal_mean_bias(online_config, online_model, online_seed, 4, show = False, save_path = online_bias_dir)

plot_online_zonal_mean_bias('v6', 'unet', 7, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'unet', 43, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'unet', 1024, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'squeezeformer', 7, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'squeezeformer', 43, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'squeezeformer', 1024, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'pure_resLSTM', 7, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'pure_resLSTM', 43, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'pure_resLSTM', 1024, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'pao_model', 7, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'pao_model', 43, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'pao_model', 1024, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'convnext', 7, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'convnext', 43, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'convnext', 1024, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'encdec_lstm', 7, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'encdec_lstm', 43, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))
plot_online_zonal_mean_bias('v6', 'encdec_lstm', 1024, 4, show = False, save_path = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias'))

In [None]:
# 5-year zonal-mean bias figures for every configuration / architecture / seed.
# itertools.product varies the right-most iterable fastest, so the calls are
# issued in the order configuration -> architecture -> seed.
_zonal_bias_dir_5yr = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias')
_sweep_configs_5yr = ('standard', 'conf_loss', 'diff_loss', 'multirep', 'v6')
_sweep_archs_5yr = ('unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm')
_sweep_seeds_5yr = (7, 43, 1024)

for _config, _arch, _seed in itertools.product(_sweep_configs_5yr, _sweep_archs_5yr, _sweep_seeds_5yr):
    plot_online_zonal_mean_bias(_config, _arch, _seed, 5, show = False, save_path = _zonal_bias_dir_5yr)

### Online zonal mean bias model comparison

In [None]:
def plot_online_zonal_mean_bias_model_comparison(config_name, var, seed, num_years, show = True, save_path = None):
    """Plot the online zonal-mean bias of six NN architectures side by side.

    For one configuration, variable and seed, draws a 2x3 grid with one panel per
    architecture. Each panel shows ``online_area_time_mean_3d`` of the NN run minus
    that of the first MMF run returned by ``read_mmf_online_data``, multiplied by
    ``online_var_settings[var]['scaling']``.

    Args:
        config_name: Configuration key; passed to ``read_nn_online_data`` and used
            to look up the display name in ``config_names``.
        var: Variable key into ``online_var_settings`` (provides 'scaling', 'vmin',
            'vmax', 'var_title' and 'unit').
        seed: Seed identifying the NN run to load.
        num_years: Simulation length, passed to the data readers and used in the
            figure title and file name.
        show: If True the figure is shown, otherwise it is closed.
        save_path: Optional directory in which the PNG is written.

    Returns:
        None. If any of the six NN datasets is falsy (e.g. missing), nothing is
        plotted or saved.
    """
    # Panel order, row-major across the 2x3 grid.
    panel_models = ['unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm']
    n_cols = 3
    latitude_ticks = [-60, -30, 0, 30, 60]
    latitude_labels = ['60S', '30S', '0', '30N', '60N']

    # Only the first MMF run is used as the reference here.
    ds_mmf_1, _ = read_mmf_online_data(num_years)
    ds_nn = {model: read_nn_online_data(config_name, model, seed, num_years) for model in panel_models}
    # Bail out *before* creating the figure so skipped combinations do not leave
    # an open (leaked) matplotlib figure behind during large sweeps.
    if not all(ds_nn.values()):
        return

    settings = online_var_settings[var]
    # The reference zonal mean is identical for every panel; compute it once.
    mmf_zonal_mean = online_area_time_mean_3d(ds_mmf_1, var)

    fig, axs = plt.subplots(2, n_cols, figsize=(12.8, 6), constrained_layout = True)
    for panel_idx, model in enumerate(panel_models):
        row, col = divmod(panel_idx, n_cols)
        ax = axs[row, col]
        bias = settings['scaling'] * (online_area_time_mean_3d(ds_nn[model], var) - mmf_zonal_mean)
        mesh = bias.plot(ax=ax, add_colorbar=False, cmap='RdBu_r', vmin=settings['vmin'], vmax=settings['vmax'])
        ax.set_title(f"({string.ascii_lowercase[panel_idx]}) {model_names[model]}")
        ax.invert_yaxis()
        # Axis labels only on the outer edge of the grid.
        ax.set_xlabel("Latitude" if row == axs.shape[0] - 1 else '')
        ax.set_ylabel("Hybrid pressure (hPa)" if col == 0 else '')
        # One colorbar per panel, attached explicitly to that panel's axes.
        fig.colorbar(mesh, ax=ax)

        if var == 'CLDICE':
            # Dashed black overlay of idx_tropopause_zm vs. lat_bin_mids
            # (presumably the zonal-mean tropopause level -- defined elsewhere in the notebook).
            ax.plot(lat_bin_mids, idx_tropopause_zm, 'k--')

        ax.set_xticks(latitude_ticks)  # Set the positions for the ticks
        ax.set_xticklabels(latitude_labels)  # Set the custom text labels

    fig.suptitle(f"{num_years} year {settings['var_title']} ({settings['unit']}) zonal mean difference ({config_names[config_name]} Configuration, Seed {seed})", fontsize=14)

    if save_path:
        fig.savefig(os.path.join(save_path, f'online_{num_years}_year_zonal_mean_{var}_bias_model_comparison_{config_name}_{seed}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

In [None]:
# Render the model-comparison zonal-mean bias figure for every
# (simulation length, variable, configuration, seed) combination.
# One tqdm progress bar over the variables is shown per simulation length.
years_to_try = [4,5]
model_comparison_save_dir = os.path.join(climsim3_figures_save_path_online, 'online_zonal_mean_bias_model_comparison')
for years in years_to_try:
    for online_var in tqdm(online_var_settings.keys()):
        for config_name, seed in itertools.product(['standard', 'conf_loss', 'diff_loss', 'multirep', 'v6'], [7, 43, 1024]):
            plot_online_zonal_mean_bias_model_comparison(config_name, online_var, seed, years, show=False, save_path=model_comparison_save_dir)

### Online RMSE comparison (models)

In [None]:
def plot_online_rmse_model_comparison(config_name, num_years, show = True, save_path = None):
    """Plot monthly online RMSE of every NN architecture against the reference MMF run.

    Draws a 2x2 figure (temperature, moisture, zonal wind, total cloud). Each panel shows:

    * the RMSE between the two MMF runs returned by ``read_mmf_online_data``
      (labelled 'Internal Variability'),
    * for each entry of ``model_names``, a shaded min-max envelope across ``seed_numbers``,
    * for each entry of ``model_names``, a dashed line for the seed whose mean absolute
      distance from the internal-variability curve is closest to the median of that
      distance across seeds.

    Args:
        config_name: Configuration key; passed to ``read_nn_online_data`` and used
            to look up the display name in ``config_names``.
        num_years: Simulation length in years; the x-axis spans ``12 * num_years`` months.
        show: If True the figure is shown, otherwise it is closed.
        save_path: Optional directory in which the PNG is written.

    Returns:
        None.
    """
    months = np.arange(1, num_years * 12 + 1)
    # (variable key, panel title, y-axis label, y ticks); the last tick is also the upper y-limit.
    panel_specs = [
        ('T', '(a) Temperature', 'Online RMSE (K)', [0.5, 1, 2, 5, 10, 20, 50, 100]),
        ('Q', '(b) Moisture', 'Online RMSE (g/kg)', [0.1, 0.2, 0.5, 1, 2, 5, 10]),
        ('U', '(c) Zonal Wind', 'Online RMSE (m/s)', [0.5, 1, 2, 5, 10, 20, 50, 100]),
        ('TOTCLD', '(d) Total Cloud', 'Online RMSE (mg/kg)', [2, 5, 10, 20, 50, 100, 200, 500]),
    ]
    plotted_vars = [spec[0] for spec in panel_specs]

    ds_mmf_1, ds_mmf_2 = read_mmf_online_data(num_years)
    # Weight applied to the squared error (reduced over axes 1 and 2 below). It only
    # depends on the reference run, so it is built once instead of on every RMSE call.
    # NOTE(review): the previous helper accepted get_pressure_area_weights(ds_mmf_1) as an
    # argument but always overwrote it with this expression; the effective weighting is kept.
    reference_weight = get_dp(ds_mmf_1) * ds_mmf_1['area'].values[:,None,:]

    def calculate_rmse(ds, var):
        """Weighted monthly RMSE of ``ds[var]`` against ``ds_mmf_1[var]``, NaN-padded to ``len(months)``."""
        rmse_per_month = np.full(len(months), np.nan)
        if not ds:
            # Missing run: keep an all-NaN series so downstream array shapes stay uniform.
            return rmse_per_month

        # A run may contain fewer months than requested; only those are filled in.
        num_months = ds[var].shape[0]
        weight = reference_weight[:num_months, :, :]

        squared_diff = (ds[var] - ds_mmf_1[var]) ** 2
        weighted_sum = (squared_diff * weight).sum(axis=(1, 2))
        weighted_mean_squared_diff = weighted_sum / weight.sum(axis=(1, 2))
        rmse_per_month[:num_months] = np.sqrt(weighted_mean_squared_diff).values
        return rmse_per_month * online_var_settings[var]['scaling']

    # RMSE between the two MMF runs, computed once per plotted variable.
    mmf_rmse = {var: calculate_rmse(ds_mmf_2, var) for var in plotted_vars}

    ds_nn = {
        model_name: {seed_number: read_nn_online_data(config_name, model_name, seed_number, num_years) for seed_number in seed_numbers}
        for model_name in model_names.keys()
    }
    # nn_rmse[var][model_name] is an array of shape (number of seeds, number of months).
    nn_rmse = {
        var: {
            model_name: np.array([calculate_rmse(ds_nn[model_name][seed_number], var) for seed_number in seed_numbers])
            for model_name in model_names.keys()
        }
        for var in plotted_vars
    }

    fig, axes = plt.subplots(2, 2, figsize=(18, 8))
    year_tick_positions = np.arange(0, 12 * (num_years + 1), 12)
    year_tick_labels = np.arange(0, num_years + 1)

    for ax, (var, title, ylabel, yticks) in zip(axes.flat, panel_specs):
        ax.plot(months, mmf_rmse[var], label='Internal Variability', color='black', marker='o', markersize = 3)
        for model_name in model_names.keys():
            seed_rmse = nn_rmse[var][model_name]
            # Envelope across seeds.
            ax.fill_between(
                months,
                np.nanmin(seed_rmse, axis = 0),
                np.nanmax(seed_rmse, axis = 0),
                color = color_dict[model_name],
                alpha = 0.15
            )
            if not np.all(np.isnan(seed_rmse)):
                # Representative seed: the one whose mean distance from the internal-variability
                # curve is closest to the median of that distance over all seeds.
                abs_diff = np.nanmean(np.abs(seed_rmse - mmf_rmse[var]), axis = 1)
                representative = np.nanargmin(np.abs(abs_diff - np.nanmedian(abs_diff)))
                ax.plot(months, seed_rmse[representative, :], label = model_names[model_name], color = color_dict[model_name], linestyle = '--')
            else:
                # No data for any seed: draw an empty line so the model still gets a legend entry.
                ax.plot(months, np.full(months.shape, np.nan), label = model_names[model_name], color = color_dict[model_name])
        ax.set_xticks(year_tick_positions, year_tick_labels)
        ax.set_yscale('log')
        ax.set_yticks(yticks)
        ax.set_ylim(None, yticks[-1])
        # All panels share the same tick-label size (panel (c) previously used 20).
        ax.set_yticklabels([str(tick) for tick in yticks], fontsize=12)
        ax.set_ylabel(ylabel, fontsize=15)
        ax.set_title(title, fontsize=15)
        ax.grid(True)

    # A single legend, on the first panel only.
    axes[0,0].legend(fontsize=13, loc='upper left', ncol = 2)

    axes[1,0].set_xlabel('Year', fontsize=15)
    axes[1,1].set_xlabel('Year', fontsize=15)

    fig.suptitle(f'{num_years} Year Online Root Mean Squared Error ({config_names[config_name]} Configuration)', fontsize = 16)

    if save_path:
        fig.savefig(os.path.join(save_path, f'online_{num_years}_year_RMSE_model_comparison_{config_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

### Online RMSE comparison (configurations)

In [None]:
def plot_online_rmse_config_comparison(model_name, num_years, show = True, save_path = None):
    """Plot monthly online RMSE of one NN architecture across training configurations.

    Draws a 2x2 figure (temperature, moisture, zonal wind, total cloud). Each
    panel shows the RMSE between the two MMF runs ('Internal Variability') and,
    for every key of ``config_names``, a shaded min-max envelope over
    ``seed_numbers`` plus a dashed line for one representative seed. The
    representative seed is the one whose mean absolute RMSE difference from the
    internal-variability curve is closest to the median of that quantity
    across seeds.

    Parameters
    ----------
    model_name : key of ``model_names``; forwarded to ``read_nn_online_data``.
    num_years : int
        Number of years to plot; the x axis spans ``12 * num_years`` months.
    show : bool
        If True the figure is shown with ``plt.show()``, otherwise it is closed.
    save_path : str or None
        Directory in which the PNG is written; nothing is saved when falsy.
    """
    months = np.arange(1, num_years * 12 + 1)

    def calculate_rmse(ds1, ds2, total_weight, var):
        """Return the weighted RMSE of ``ds1[var]`` against ``ds2[var]`` per month.

        The result always has ``len(months)`` entries; months beyond the
        length of ``ds1`` stay NaN, and a falsy ``ds1`` yields all NaN.
        ``total_weight`` is truncated along its first axis to the number of
        months in ``ds1``. Values are multiplied by
        ``online_var_settings[var]['scaling']`` before being returned.
        """
        rmse_per_month = np.full(len(months), np.nan)
        if not ds1:
            return rmse_per_month

        # Only the months actually present in ds1 can be evaluated.
        num_months = ds1[var].shape[0]
        total_weight_sliced = total_weight[:num_months, :, :]

        squared_diff = (ds1[var] - ds2[var]) ** 2
        weighted_sum = (squared_diff * total_weight_sliced).sum(axis=(1, 2))
        total_weight_sum = total_weight_sliced.sum(axis=(1, 2))
        rmse_existing_months = np.sqrt(weighted_sum / total_weight_sum)

        rmse_per_month[:num_months] = rmse_existing_months.values
        return rmse_per_month * online_var_settings[var]['scaling']

    ds_mmf_1, ds_mmf_2 = read_mmf_online_data(num_years)
    # Every comparison uses ds_mmf_1 as reference, so the weights are built
    # once here (same expression that was previously re-evaluated inside
    # calculate_rmse on every call) and handed to the helper.
    mmf_1_total_weight = get_dp(ds_mmf_1) * ds_mmf_1['area'].values[:, None, :]

    # (variable, panel title, y label, y ticks); the last tick is also the upper y limit.
    panel_specs = [
        ('T', '(a) Temperature', 'Online RMSE (K)', [0.5, 1, 2, 5, 10, 20, 50, 100]),
        ('Q', '(b) Moisture', 'Online RMSE (g/kg)', [0.1, 0.2, 0.5, 1, 2, 5, 10]),
        ('U', '(c) Zonal Wind', 'Online RMSE (m/s)', [0.5, 1, 2, 5, 10, 20, 50, 100]),
        ('TOTCLD', '(d) Total Cloud', 'Online RMSE (mg/kg)', [2, 5, 10, 20, 50, 100, 200, 500]),
    ]
    plotted_vars = [spec[0] for spec in panel_specs]

    # RMSE between the two MMF runs, computed once per variable and reused by its panel.
    mmf_1_rmse = {
        var: calculate_rmse(ds_mmf_2, ds_mmf_1, mmf_1_total_weight, var=var)
        for var in plotted_vars
    }
    # Online NN runs for every configuration and seed of this architecture.
    ds_nn = {
        config_name: {
            seed_number: read_nn_online_data(config_name, model_name, seed_number, num_years)
            for seed_number in seed_numbers
        }
        for config_name in config_names.keys()
    }
    # ds_nn_rmse[var][config_name] has one row per seed and one column per month.
    ds_nn_rmse = {
        var: {
            config_name: np.array([
                calculate_rmse(ds_nn[config_name][seed_number], ds_mmf_1, mmf_1_total_weight, var = var)
                for seed_number in seed_numbers
            ])
            for config_name in config_names.keys()
        }
        for var in plotted_vars
    }

    fig, axes = plt.subplots(2, 2, figsize=(18, 8))
    for panel_index, (var, title, ylabel, yticks) in enumerate(panel_specs):
        ax = axes.flat[panel_index]
        ax.plot(months, mmf_1_rmse[var], label='Internal Variability', color='black', marker='o', markersize = 3)
        for config_name in config_names.keys():
            seed_rmse = ds_nn_rmse[var][config_name]
            ax.fill_between(
                months,
                np.nanmin(seed_rmse, axis = 0),
                np.nanmax(seed_rmse, axis = 0),
                color = color_dict_config[config_name],
                alpha = 0.15
            )
            if not np.all(np.isnan(seed_rmse)):
                abs_diff = np.nanmean(np.abs(seed_rmse - mmf_1_rmse[var]), axis = 1)
                representative = seed_rmse[np.nanargmin(np.abs(abs_diff - np.nanmedian(abs_diff))), :]
            else:
                # No data for any seed: draw an all-NaN line so the legend entry still exists.
                representative = np.full(months.shape, np.nan)
            ax.plot(months, representative, label = config_names[config_name], color = color_dict_config[config_name], linestyle = '--')
        # Ticks sit on year boundaries (every 12 months) and are labelled with the year number.
        ax.set_xticks(np.arange(0,12*(num_years+1),12), np.arange(0,num_years+1))
        ax.set_yscale('log')
        ax.set_yticks(yticks)
        ax.set_ylim(None, yticks[-1])
        # Same tick-label size on every panel (the zonal-wind panel used 20 before).
        ax.set_yticklabels([str(tick) for tick in yticks], fontsize=12)
        ax.set_ylabel(ylabel, fontsize=15)
        ax.set_title(title, fontsize=15)
        if panel_index == 0:
            # One legend for the whole figure, placed on the first panel.
            ax.legend(fontsize=13, loc='upper left', ncol = 2)
        ax.grid(True)

    axes[1,0].set_xlabel('Year', fontsize=15)
    axes[1,1].set_xlabel('Year', fontsize=15)

    fig.suptitle(f'{num_years} Year Online Root Mean Squared Error ({model_names[model_name]})', fontsize = 16)

    if save_path:
        plt.savefig(os.path.join(save_path, f'online_{num_years}_year_RMSE_config_comparison_{model_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close()

### Area-weighted online mean values (model comparison)

In [None]:
def plot_online_area_mean_model_comparison(config_name, num_years, show = True, save_path = None):
    """Plot area-weighted monthly means of online runs for all NN architectures of one configuration.

    Produces a 3x4 grid: rows are the variables ``'T'``, ``'TMQ'`` and
    ``'TCP'``; columns follow the iteration order of ``area_weight_dict``.
    Each panel shows both MMF runs, a min-max envelope over ``seed_numbers``
    for every key of ``model_names``, and a dashed line for one representative
    seed per architecture (the seed whose mean absolute difference from the
    first MMF run is closest to the median of that quantity across seeds).

    Parameters
    ----------
    config_name : key of ``config_names``; forwarded to ``read_nn_online_data``.
    num_years : int
        Number of years to plot; the x axis spans ``12 * num_years`` months.
    show : bool
        If True the figure is shown with ``plt.show()``, otherwise it is closed.
    save_path : str or None
        Directory in which the PNG is written; nothing is saved when falsy.
    """
    months = np.arange(1, num_years * 12 + 1)
    ds_mmf_1, ds_mmf_2 = read_mmf_online_data(num_years)
    column_titles = ['Global mean', '30N-90N mean', '30S-90S mean', '30S-30N mean']
    row_ylabels = [
        'T$_{59}$ (K)',
        'Precipitable water (kg/m$^2$)',
        'Total cloud path (kg/m$^2$)'
    ]
    variables = ['T', 'TMQ', 'TCP']
    # Architectures drawn as lines, in legend order (first two go in the
    # column-0 legend next to the MMF runs, the rest in the column-1 legend).
    nn_model_keys = ('unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm')
    # Fixed (min, max) y limits of the total-cloud-path row, one pair per column.
    tcp_ylims = [(.05, .4), (.05, .6), (.05, .4), (.05, .4)]

    def calculate_mean(ds, w, var):
        """Return the ``w``-weighted mean of ``var`` per month, NaN-padded to ``len(months)``.

        A falsy ``ds`` yields an all-NaN array. ``'T'`` uses
        ``ds['T'][:, -1, :]`` and ``'TCP'`` is delegated to ``get_tcp_mean``.
        """
        mean_per_month = np.full(len(months), np.nan)
        if not ds:
            return mean_per_month
        n_available = len(ds['time'])
        if var == 'T':
            mean_per_month[:n_available] = np.average(ds['T'][:, -1, :].values, weights=w, axis=1)
        elif var == 'TMQ':
            mean_per_month[:n_available] = np.average(ds['TMQ'][:, :].values, weights=w, axis=1)
        elif var == 'TCP':
            mean_per_month[:n_available] = get_tcp_mean(ds, w)
        return mean_per_month

    ds_nn = {
        key: {seed_number: read_nn_online_data(config_name, key, seed_number, num_years) for seed_number in seed_numbers}
        for key in nn_model_keys
    }
    fig, axes = plt.subplots(3, 4, figsize=(18, 8))
    for row, var in enumerate(variables):
        for col, weight_key in enumerate(area_weight_dict.keys()):
            ax = axes[row, col]
            weight = area_weight_dict[weight_key]
            ds_mmf_1_var_vals = calculate_mean(ds_mmf_1, weight, var)
            line_mmf_1, = ax.plot(months, ds_mmf_1_var_vals, label = 'MMF', color = 'black')
            line_mmf_2, = ax.plot(months, calculate_mean(ds_mmf_2, weight, var), label = 'MMF2', color = 'black', linestyle = 'dashed')
            # One row per seed, one column per month.
            ds_nn_mean = {
                key: np.array([calculate_mean(ds_nn[key][seed_number], weight, var) for seed_number in seed_numbers])
                for key in nn_model_keys
            }
            for model_name in model_names.keys():
                ax.fill_between(
                    months,
                    np.nanmin(ds_nn_mean[model_name], axis = 0),
                    np.nanmax(ds_nn_mean[model_name], axis = 0),
                    color = color_dict[model_name],
                    alpha=0.15
                )
            nn_lines = {}
            for key in nn_model_keys:
                seed_means = ds_nn_mean[key]
                if not np.all(np.isnan(seed_means)):
                    abs_diff = np.nanmean(np.abs(seed_means - ds_mmf_1_var_vals), axis = 1)
                    representative = seed_means[np.nanargmin(np.abs(abs_diff - np.nanmedian(abs_diff))), :]
                else:
                    # No data for any seed: all-NaN line keeps a handle for the legend.
                    representative = np.full(months.shape, np.nan)
                nn_lines[key], = ax.plot(months, representative, label = model_names[key], color = color_dict[key], linestyle = '--')

            if row != 2:
                # Pad the MMF range by 1.6 standard deviations and clamp the lower
                # bound at 0. The former np.max(value, 0) passed 0 as ``axis`` and
                # never clamped. NaN-aware reductions tolerate the NaN padding
                # produced by calculate_mean for short runs.
                spread = 1.6 * np.nanstd(ds_mmf_1_var_vals)
                ylim_min = max(np.nanmin(ds_mmf_1_var_vals) - spread, 0)
                ylim_max = np.nanmax(ds_mmf_1_var_vals) + spread
            else:
                ylim_min, ylim_max = tcp_ylims[col]
            ax.set_ylim(ylim_min, ylim_max)
            ax.grid(True)
            ax.tick_params(axis='both', labelsize=12)
            # Ticks sit on year boundaries (every 12 months) and are labelled with the year number.
            ax.set_xticks(np.arange(0,12*(num_years+1),12), np.arange(0,num_years+1))
            # Set column titles
            if row == 0:
                ax.set_title(column_titles[col],fontsize=14)

            # Set row y-labels
            if col == 0:
                ax.set_ylabel(row_ylabels[row],fontsize=12)

            # Set x-label for the last row
            if row == 2:
                ax.set_xlabel("Years",fontsize=12)

            # The legend is split over the first two panels of the last row.
            if row == 2 and col == 0:
                ax.legend(handles = [line_mmf_1, line_mmf_2] + [nn_lines[key] for key in nn_model_keys[:2]],
                          labels = ['MMF', 'MMF2'] + [model_names[key] for key in nn_model_keys[:2]], fontsize=10,loc='upper left')
            elif row == 2 and col == 1:
                ax.legend(handles = [nn_lines[key] for key in nn_model_keys[2:]],
                          labels = [model_names[key] for key in nn_model_keys[2:]], fontsize=10,loc='upper left')
    fig.suptitle(f'{num_years} Year Area-Weighted Mean Values ({config_names[config_name]})', fontsize = 16)
    # Adjust layout and display the plot
    plt.tight_layout()
    if save_path:
        plt.savefig(os.path.join(save_path, f'online_{num_years}_year_area_mean_model_comparison_{config_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close()

### Area-weighted online mean values (config comparison)

In [None]:
def plot_online_area_mean_config_comparison(model_name, num_years, show = True, save_path = None):
    """Plot area-weighted monthly means of online runs for all configurations of one NN architecture.

    Produces a 3x4 grid: rows are the variables ``'T'``, ``'TMQ'`` and
    ``'TCP'``; columns follow the iteration order of ``area_weight_dict``.
    Each panel shows both MMF runs, a min-max envelope over ``seed_numbers``
    for every key of ``config_names``, and a dashed line for one
    representative seed per configuration (the seed whose mean absolute
    difference from the first MMF run is closest to the median of that
    quantity across seeds).

    Parameters
    ----------
    model_name : key of ``model_names``; forwarded to ``read_nn_online_data``.
    num_years : int
        Number of years to plot; the x axis spans ``12 * num_years`` months.
    show : bool
        If True the figure is shown with ``plt.show()``, otherwise it is closed.
    save_path : str or None
        Directory in which the PNG is written; nothing is saved when falsy.
    """
    months = np.arange(1, num_years * 12 + 1)
    ds_mmf_1, ds_mmf_2 = read_mmf_online_data(num_years)
    column_titles = ['Global mean', '30N-90N mean', '30S-90S mean', '30S-30N mean']
    row_ylabels = [
        'T$_{59}$ (K)',
        'Precipitable water (kg/m$^2$)',
        'Total cloud path (kg/m$^2$)'
    ]
    variables = ['T', 'TMQ', 'TCP']
    # Configurations drawn as lines, in legend order (first two go in the
    # column-0 legend next to the MMF runs, the rest in the column-1 legend).
    nn_config_keys = ('standard', 'conf_loss', 'diff_loss', 'multirep', 'v6')
    # Fixed (min, max) y limits of the total-cloud-path row, one pair per column.
    tcp_ylims = [(.05, .4), (.05, .6), (.05, .4), (.05, .4)]

    def calculate_mean(ds, w, var):
        """Return the ``w``-weighted mean of ``var`` per month, NaN-padded to ``len(months)``.

        A falsy ``ds`` yields an all-NaN array. ``'T'`` uses
        ``ds['T'][:, -1, :]`` and ``'TCP'`` is delegated to ``get_tcp_mean``.
        """
        mean_per_month = np.full(len(months), np.nan)
        if not ds:
            return mean_per_month
        n_available = len(ds['time'])
        if var == 'T':
            mean_per_month[:n_available] = np.average(ds['T'][:, -1, :].values, weights=w, axis=1)
        elif var == 'TMQ':
            mean_per_month[:n_available] = np.average(ds['TMQ'][:, :].values, weights=w, axis=1)
        elif var == 'TCP':
            mean_per_month[:n_available] = get_tcp_mean(ds, w)
        return mean_per_month

    ds_nn = {
        key: {seed_number: read_nn_online_data(key, model_name, seed_number, num_years) for seed_number in seed_numbers}
        for key in nn_config_keys
    }
    fig, axes = plt.subplots(3, 4, figsize=(18, 8))
    for row, var in enumerate(variables):
        for col, weight_key in enumerate(area_weight_dict.keys()):
            ax = axes[row, col]
            weight = area_weight_dict[weight_key]
            ds_mmf_1_var_vals = calculate_mean(ds_mmf_1, weight, var)
            line_mmf_1, = ax.plot(months, ds_mmf_1_var_vals, label = 'MMF', color = 'black')
            line_mmf_2, = ax.plot(months, calculate_mean(ds_mmf_2, weight, var), label = 'MMF2', color = 'black', linestyle = 'dashed')
            # One row per seed, one column per month.
            ds_nn_mean = {
                key: np.array([calculate_mean(ds_nn[key][seed_number], weight, var) for seed_number in seed_numbers])
                for key in nn_config_keys
            }
            for config_name in config_names.keys():
                ax.fill_between(
                    months,
                    np.nanmin(ds_nn_mean[config_name], axis = 0),
                    np.nanmax(ds_nn_mean[config_name], axis = 0),
                    color = color_dict_config[config_name],
                    alpha=0.15
                )
            nn_lines = {}
            for key in nn_config_keys:
                seed_means = ds_nn_mean[key]
                if not np.all(np.isnan(seed_means)):
                    abs_diff = np.nanmean(np.abs(seed_means - ds_mmf_1_var_vals), axis = 1)
                    representative = seed_means[np.nanargmin(np.abs(abs_diff - np.nanmedian(abs_diff))), :]
                else:
                    # No data for any seed: all-NaN line keeps a handle for the legend.
                    representative = np.full(months.shape, np.nan)
                nn_lines[key], = ax.plot(months, representative, label = config_names[key], color = color_dict_config[key], linestyle = '--')

            if row != 2:
                # Pad the MMF range by 1.6 standard deviations and clamp the lower
                # bound at 0. The former np.max(value, 0) passed 0 as ``axis`` and
                # never clamped. NaN-aware reductions tolerate the NaN padding
                # produced by calculate_mean for short runs.
                spread = 1.6 * np.nanstd(ds_mmf_1_var_vals)
                ylim_min = max(np.nanmin(ds_mmf_1_var_vals) - spread, 0)
                ylim_max = np.nanmax(ds_mmf_1_var_vals) + spread
            else:
                ylim_min, ylim_max = tcp_ylims[col]
            ax.set_ylim(ylim_min, ylim_max)
            ax.grid(True)
            ax.tick_params(axis='both', labelsize=12)
            # Ticks sit on year boundaries (every 12 months) and are labelled with the year number.
            ax.set_xticks(np.arange(0,12*(num_years+1),12), np.arange(0,num_years+1))
            # Set column titles
            if row == 0:
                ax.set_title(column_titles[col],fontsize=14)

            # Set row y-labels
            if col == 0:
                ax.set_ylabel(row_ylabels[row],fontsize=12)

            # Set x-label for the last row; tick labels are year numbers, so the
            # axis is labelled "Years" (it previously read "Months").
            if row == 2:
                ax.set_xlabel("Years",fontsize=12)

            # The legend is split over the first two panels of the last row.
            if row == 2 and col == 0:
                ax.legend(handles = [line_mmf_1, line_mmf_2] + [nn_lines[key] for key in nn_config_keys[:2]],
                          labels = ['MMF', 'MMF2'] + [config_names[key] for key in nn_config_keys[:2]], fontsize=10,loc='upper left')
            elif row == 2 and col == 1:
                ax.legend(handles = [nn_lines[key] for key in nn_config_keys[2:]],
                          labels = [config_names[key] for key in nn_config_keys[2:]], fontsize=10,loc='upper left')
    fig.suptitle(f'{num_years} Year Area-Weighted Mean Values ({model_names[model_name]})', fontsize = 16)
    # Adjust layout and display the plot
    plt.tight_layout()
    if save_path:
        plt.savefig(os.path.join(save_path, f'online_{num_years}_year_area_mean_config_comparison_{model_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close()


### Area-weighted online standard deviation values (model comparison)

In [None]:
def plot_online_area_std_model_comparison(config_name, num_years, show = True, save_path = None):
    """Plot monthly area-weighted spatial standard deviations, comparing model architectures.

    Draws a 3 x 4 grid of panels: one row per variable (``T`` at the last
    level index, ``TMQ`` and ``TCP``) and one column per entry of
    ``area_weight_dict``. Each panel shows the two MMF reference runs, a shaded
    min/max envelope across seeds for every architecture and, per architecture,
    the seed whose mean absolute deviation from the first MMF run is closest to
    the median of that deviation across seeds.

    Args:
        config_name: Key of ``config_names``; forwarded to ``read_nn_online_data``
            and used in the figure title and the output file name.
        num_years: Number of simulated years; the x-axis spans ``12 * num_years`` months.
        show: If True the figure is shown with ``plt.show()``, otherwise it is closed.
        save_path: Optional directory. When truthy, a PNG of the figure is written there.
    """
    months = np.arange(1, num_years * 12 + 1)
    ds_mmf_1, ds_mmf_2 = read_mmf_online_data(num_years)
    # NOTE(review): these titles still say "mean" although the panels show
    # standard deviations, and they assume area_weight_dict iterates in this
    # region order -- confirm before publishing.
    column_titles = ['Global mean', '30N-90N mean', '30S-90S mean', '30S-30N mean']
    row_ylabels = [
        'T$_{59}$ (K)',
        'Precipitable water (kg/m$^2$)',
        'Total cloud path (kg/m$^2$)'
    ]
    variables = ['T', 'TMQ', 'TCP']
    model_keys = ('unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm')
    # Fixed (min, max) y-limits of the TCP row, one pair per column.
    tcp_ylims = [(.05, .4), (.05, .6), (.05, .4), (.05, .4)]

    def weighted_std(field, w):
        # Weighted standard deviation over axis 1 of a (time, column) array.
        weighted_mean = np.average(field, weights=w, axis=1)
        squared_diff = (field - weighted_mean[:, None]) ** 2
        return np.sqrt(np.average(squared_diff, weights=w, axis=1))

    def calculate_std(ds, w, var):
        # Series of length len(months); stays NaN where the run has no data.
        std_per_month = np.full(len(months), np.nan)
        if not ds:
            return std_per_month
        num_months = len(ds['time'])
        if var == 'T':
            std_per_month[:num_months] = weighted_std(ds['T'][:, -1, :].values, w)
        elif var == 'TMQ':
            std_per_month[:num_months] = weighted_std(ds['TMQ'][:, :].values, w)
        elif var == 'TCP':
            std_per_month[:num_months] = get_tcp_std(ds, w)
        return std_per_month

    ds_nn = {
        model_key: {
            seed_number: read_nn_online_data(config_name, model_key, seed_number, num_years)
            for seed_number in seed_numbers
        }
        for model_key in model_keys
    }
    fig, axes = plt.subplots(3, 4, figsize=(18, 8))
    for row, var in enumerate(variables):
        for col, weight_key in enumerate(area_weight_dict.keys()):
            ax = axes[row, col]
            weight = area_weight_dict[weight_key]
            ds_mmf_1_var_vals = calculate_std(ds_mmf_1, weight, var)
            line_mmf_1, = ax.plot(months, ds_mmf_1_var_vals, label = 'MMF', color = 'black')
            line_mmf_2, = ax.plot(months, calculate_std(ds_mmf_2, weight, var), label = 'MMF2', color = 'black', linestyle = 'dashed')
            ds_nn_std = {
                model_key: np.array([calculate_std(ds_nn[model_key][seed_number], weight, var) for seed_number in seed_numbers])
                for model_key in model_keys
            }
            # Shaded envelope: min/max over seeds for every architecture.
            for model_key in model_keys:
                ax.fill_between(
                    months,
                    np.nanmin(ds_nn_std[model_key], axis = 0),
                    np.nanmax(ds_nn_std[model_key], axis = 0),
                    color = color_dict[model_key],
                    alpha=0.15
                )
            # Representative curve: the seed whose mean |NN - MMF| is closest to
            # the median over seeds; an all-NaN placeholder keeps the legend
            # handle available when no seed has data.
            model_lines = {}
            for model_key in model_keys:
                seed_stds = ds_nn_std[model_key]
                if not np.all(np.isnan(seed_stds)):
                    abs_diff = np.nanmean(np.abs(seed_stds - ds_mmf_1_var_vals), axis = 1)
                    median_seed = np.nanargmin(np.abs(abs_diff - np.nanmedian(abs_diff)))
                    curve = seed_stds[median_seed, :]
                else:
                    curve = np.full(months.shape, np.nan)
                model_lines[model_key], = ax.plot(months, curve, label = model_names[model_key], color = color_dict[model_key], linestyle = '--')

            if row != 2:
                # NaN-aware statistics: the MMF series is NaN-padded when the run
                # is shorter than num_years. np.maximum clamps the lower limit at
                # zero (np.max(x, 0) would treat 0 as an axis, not a bound).
                spread = 1.6 * np.nanstd(ds_mmf_1_var_vals)
                ylim_min = np.maximum(np.nanmin(ds_mmf_1_var_vals) - spread, 0)
                ylim_max = np.nanmax(ds_mmf_1_var_vals) + spread
            else:
                ylim_min, ylim_max = tcp_ylims[col]
            # Matplotlib rejects NaN/Inf limits; keep autoscaling in that case.
            if np.isfinite(ylim_min) and np.isfinite(ylim_max):
                ax.set_ylim(ylim_min, ylim_max)
            ax.grid(True)
            ax.tick_params(axis='both', labelsize=12)
            # One tick every 12 months, labelled with the elapsed year.
            ax.set_xticks(np.arange(0,12*(num_years+1),12), np.arange(0,num_years+1))
            # Set column titles
            if row == 0:
                ax.set_title(column_titles[col],fontsize=14)

            # Set row y-labels
            if col == 0:
                ax.set_ylabel(row_ylabels[row],fontsize=12)

            # Set x-label for the last row
            if row == 2:
                ax.set_xlabel("Years",fontsize=12)

            # The legend is split across the first two panels of the last row.
            if row == 2 and col == 0:
                ax.legend(handles = [line_mmf_1, line_mmf_2, model_lines['unet'], model_lines['squeezeformer']],
                          labels = ['MMF', 'MMF2', model_names['unet'], model_names['squeezeformer']], fontsize=10,loc='upper left')
            elif row == 2 and col == 1:
                ax.legend(handles = [model_lines['pure_resLSTM'], model_lines['pao_model'], model_lines['convnext'], model_lines['encdec_lstm']],
                          labels = [model_names['pure_resLSTM'], model_names['pao_model'], model_names['convnext'], model_names['encdec_lstm']], fontsize=10,loc='upper left')
    fig.suptitle(f'{num_years} Year Area-Weighted Standard Deviation Values ({config_names[config_name]})', fontsize = 16)
    # Adjust layout and display the plot
    plt.tight_layout()
    if save_path:
        plt.savefig(os.path.join(save_path, f'online_{num_years}_year_area_std_model_comparison_{config_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

### Area-weighted online standard deviation values (config comparison)

In [None]:
def plot_online_area_std_config_comparison(model_name, num_years, show = True, save_path = None):
    """Plot monthly area-weighted spatial standard deviations, comparing training configurations.

    Draws a 3 x 4 grid of panels: one row per variable (``T`` at the last
    level index, ``TMQ`` and ``TCP``) and one column per entry of
    ``area_weight_dict``. Each panel shows the two MMF reference runs, a shaded
    min/max envelope across seeds for every configuration and, per
    configuration, the seed whose mean absolute deviation from the first MMF
    run is closest to the median of that deviation across seeds.

    Args:
        model_name: Key of ``model_names``; forwarded to ``read_nn_online_data``
            and used in the figure title and the output file name.
        num_years: Number of simulated years; the x-axis spans ``12 * num_years`` months.
        show: If True the figure is shown with ``plt.show()``, otherwise it is closed.
        save_path: Optional directory. When truthy, a PNG of the figure is written there.
    """
    months = np.arange(1, num_years * 12 + 1)
    ds_mmf_1, ds_mmf_2 = read_mmf_online_data(num_years)
    # NOTE(review): these titles still say "mean" although the panels show
    # standard deviations, and they assume area_weight_dict iterates in this
    # region order -- confirm before publishing.
    column_titles = ['Global mean', '30N-90N mean', '30S-90S mean', '30S-30N mean']
    row_ylabels = [
        'T$_{59}$ (K)',
        'Precipitable water (kg/m$^2$)',
        'Total cloud path (kg/m$^2$)'
    ]
    variables = ['T', 'TMQ', 'TCP']
    config_keys = ('standard', 'conf_loss', 'diff_loss', 'multirep', 'v6')
    # Fixed (min, max) y-limits of the TCP row, one pair per column.
    tcp_ylims = [(.05, .4), (.05, .6), (.05, .4), (.05, .4)]

    def weighted_std(field, w):
        # Weighted standard deviation over axis 1 of a (time, column) array.
        weighted_mean = np.average(field, weights=w, axis=1)
        squared_diff = (field - weighted_mean[:, None]) ** 2
        return np.sqrt(np.average(squared_diff, weights=w, axis=1))

    def calculate_std(ds, w, var):
        # Series of length len(months); stays NaN where the run has no data.
        std_per_month = np.full(len(months), np.nan)
        if not ds:
            return std_per_month
        num_months = len(ds['time'])
        if var == 'T':
            std_per_month[:num_months] = weighted_std(ds['T'][:, -1, :].values, w)
        elif var == 'TMQ':
            std_per_month[:num_months] = weighted_std(ds['TMQ'][:, :].values, w)
        elif var == 'TCP':
            std_per_month[:num_months] = get_tcp_std(ds, w)
        return std_per_month

    ds_nn = {
        config_key: {
            seed_number: read_nn_online_data(config_key, model_name, seed_number, num_years)
            for seed_number in seed_numbers
        }
        for config_key in config_keys
    }
    fig, axes = plt.subplots(3, 4, figsize=(18, 8))
    for row, var in enumerate(variables):
        for col, weight_key in enumerate(area_weight_dict.keys()):
            ax = axes[row, col]
            weight = area_weight_dict[weight_key]
            ds_mmf_1_var_vals = calculate_std(ds_mmf_1, weight, var)
            line_mmf_1, = ax.plot(months, ds_mmf_1_var_vals, label = 'MMF', color = 'black')
            line_mmf_2, = ax.plot(months, calculate_std(ds_mmf_2, weight, var), label = 'MMF2', color = 'black', linestyle = 'dashed')
            ds_nn_std = {
                config_key: np.array([calculate_std(ds_nn[config_key][seed_number], weight, var) for seed_number in seed_numbers])
                for config_key in config_keys
            }
            # Shaded envelope: min/max over seeds for every configuration.
            for config_key in config_keys:
                ax.fill_between(
                    months,
                    np.nanmin(ds_nn_std[config_key], axis = 0),
                    np.nanmax(ds_nn_std[config_key], axis = 0),
                    color = color_dict_config[config_key],
                    alpha=0.15
                )
            # Representative curve: the seed whose mean |NN - MMF| is closest to
            # the median over seeds; an all-NaN placeholder keeps the legend
            # handle available when no seed has data.
            config_lines = {}
            for config_key in config_keys:
                seed_stds = ds_nn_std[config_key]
                if not np.all(np.isnan(seed_stds)):
                    abs_diff = np.nanmean(np.abs(seed_stds - ds_mmf_1_var_vals), axis = 1)
                    median_seed = np.nanargmin(np.abs(abs_diff - np.nanmedian(abs_diff)))
                    curve = seed_stds[median_seed, :]
                else:
                    curve = np.full(months.shape, np.nan)
                config_lines[config_key], = ax.plot(months, curve, label = config_names[config_key], color = color_dict_config[config_key], linestyle = '--')

            if row != 2:
                # NaN-aware statistics: the MMF series is NaN-padded when the run
                # is shorter than num_years. np.maximum clamps the lower limit at
                # zero (np.max(x, 0) would treat 0 as an axis, not a bound).
                spread = 1.6 * np.nanstd(ds_mmf_1_var_vals)
                ylim_min = np.maximum(np.nanmin(ds_mmf_1_var_vals) - spread, 0)
                ylim_max = np.nanmax(ds_mmf_1_var_vals) + spread
            else:
                ylim_min, ylim_max = tcp_ylims[col]
            # Matplotlib rejects NaN/Inf limits; keep autoscaling in that case.
            if np.isfinite(ylim_min) and np.isfinite(ylim_max):
                ax.set_ylim(ylim_min, ylim_max)
            ax.grid(True)
            ax.tick_params(axis='both', labelsize=12)
            # One tick every 12 months, labelled with the elapsed year.
            ax.set_xticks(np.arange(0,12*(num_years+1),12), np.arange(0,num_years+1))
            # Set column titles
            if row == 0:
                ax.set_title(column_titles[col],fontsize=14)

            # Set row y-labels
            if col == 0:
                ax.set_ylabel(row_ylabels[row],fontsize=12)

            # Set x-label for the last row; tick labels are years, not months.
            if row == 2:
                ax.set_xlabel("Years",fontsize=12)

            # The legend is split across the first two panels of the last row.
            if row == 2 and col == 0:
                ax.legend(handles = [line_mmf_1, line_mmf_2, config_lines['standard'], config_lines['conf_loss']],
                          labels = ['MMF', 'MMF2', config_names['standard'], config_names['conf_loss']], fontsize=10,loc='upper left')
            elif row == 2 and col == 1:
                ax.legend(handles = [config_lines['diff_loss'], config_lines['multirep'], config_lines['v6']],
                          labels = [config_names['diff_loss'], config_names['multirep'], config_names['v6']], fontsize=10,loc='upper left')
    fig.suptitle(f'{num_years} Year Area-Weighted Standard Deviation Values ({model_names[model_name]})', fontsize = 16)
    # Adjust layout and display the plot
    plt.tight_layout()
    if save_path:
        plt.savefig(os.path.join(save_path, f'online_{num_years}_year_area_std_config_comparison_{model_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)


### Five Year Global Mean RMSE (model comparison)

In [None]:
def plot_online_global_rmse_model_comparison(config_name, num_years, show = False, save_path = None):
    """Plot vertical RMSE profiles of time-mean state variables, comparing model architectures.

    For each of ``T``, ``Q``, ``CLDLIQ``, ``CLDICE``, ``U`` and ``V`` the time
    mean of every run is compared against the time mean of the first MMF run.
    Each panel shows the area-weighted RMSE per level for the second MMF run,
    a shaded min/max envelope across seeds for every architecture and, per
    architecture, the seed whose mean absolute profile difference from the MMF
    profile is closest to the median across seeds. The per-panel legend lists
    the scalar RMSE of each plotted run, weighted by
    ``get_pressure_area_weights(ds_mmf_1)``.

    Runs that are missing or shorter than ``12 * num_years`` time steps are
    treated as NaN and therefore do not contribute to the envelope or curves.

    Args:
        config_name: Key of ``config_names``; forwarded to ``read_nn_online_data``
            and used in the figure title and the output file name.
        num_years: Number of simulated years required of every hybrid run.
        show: If True the figure is shown with ``plt.show()``, otherwise it is closed.
        save_path: Optional directory. When truthy, a PNG of the figure is written there.
    """
    ds_mmf_1, ds_mmf_2 = read_mmf_online_data(num_years)
    mmf_1_total_weight = get_pressure_area_weights(ds_mmf_1)
    model_keys = ('unet', 'squeezeformer', 'pure_resLSTM', 'pao_model', 'convnext', 'encdec_lstm')
    ds_nn = {
        model_key: {
            seed_number: read_nn_online_data(config_name, model_key, seed_number, num_years)
            for seed_number in seed_numbers
        }
        for model_key in model_keys
    }
    variables = ['T', 'Q', 'CLDLIQ', 'CLDICE', 'U', 'V']
    # Upper x-axis limit (RMSE, in scaled units) for each variable.
    xlim_upper = {
        'T': 5,
        'Q': 0.7,
        'CLDLIQ': 60,
        'CLDICE': 8,
        'U': 11,
        'V': 5
    }
    # tight_layout() below determines the final layout; requesting
    # constrained_layout as well only triggers a warning before it is overridden.
    fig, axes = plt.subplots(2, 3, figsize=(8, 7), sharey=True)  # 2 rows, 3 columns
    axes = axes.flatten()  # Flatten axes for easier iteration

    def load_nn_var_time_mean(ds_nn_xr, var, num_years):
        # Time mean of `var`, or an all-NaN field for missing/incomplete runs.
        if not ds_nn_xr or len(ds_nn_xr['time']) < num_years * 12:
            return np.full((data_v2_rh_mc.num_levels, data_v2_rh_mc.num_latlon), np.nan)
        return ds_nn_xr[var].mean(dim = 'time').values

    for ax, var in zip(axes, variables):
        scaling = online_var_settings[var]['scaling']
        ds_mmf_1_mean = ds_mmf_1[var].mean(dim = 'time').values * scaling
        ds_mmf_2_mean = ds_mmf_2[var].mean(dim = 'time').values * scaling
        mmf_sq_err = (ds_mmf_2_mean - ds_mmf_1_mean) ** 2
        mmf_rmse = np.sqrt(np.average(mmf_sq_err, axis = 1, weights = area_weight))
        mmf_rmse_global = np.sqrt(np.average(mmf_sq_err, weights = mmf_1_total_weight))

        # Squared errors are computed once per seed and reused for both the
        # per-level profile and the scalar global RMSE.
        ds_nn_rmse = {}
        ds_nn_rmse_global = {}
        for model_key in model_keys:
            seed_sq_errs = [
                (load_nn_var_time_mean(ds_nn[model_key][seed_number], var, num_years) * scaling - ds_mmf_1_mean) ** 2
                for seed_number in seed_numbers
            ]
            ds_nn_rmse[model_key] = np.array([np.sqrt(np.average(sq_err, axis = 1, weights = area_weight)) for sq_err in seed_sq_errs])
            ds_nn_rmse_global[model_key] = np.array([np.sqrt(np.average(sq_err, weights = mmf_1_total_weight)) for sq_err in seed_sq_errs])

        # Shaded envelope: min/max over seeds for every architecture.
        for model_key in model_keys:
            ax.fill_betweenx(
                level,
                np.nanmin(ds_nn_rmse[model_key], axis = 0),
                np.nanmax(ds_nn_rmse[model_key], axis = 0),
                color = color_dict[model_key],
                alpha=0.15
            )
        line_mmf, = ax.plot(mmf_rmse, level, label=f'{mmf_rmse_global:.2f}', linestyle='-.', color='black')
        # Representative profile: the seed whose mean |NN - MMF| profile
        # difference is closest to the median over seeds; an all-NaN
        # placeholder labelled 'N/A' is drawn when no seed has data.
        model_lines = {}
        for model_key in model_keys:
            seed_profiles = ds_nn_rmse[model_key]
            if not np.all(np.isnan(seed_profiles)):
                abs_diff = np.nanmean(np.abs(seed_profiles - mmf_rmse), axis = 1)
                argidx = np.nanargmin(np.abs(abs_diff - np.nanmedian(abs_diff)))
                profile = seed_profiles[argidx,:]
                profile_label = f"{ds_nn_rmse_global[model_key][argidx]:.2f}"
            else:
                profile = np.full(level.shape, np.nan)
                profile_label = 'N/A'
            model_lines[model_key], = ax.plot(profile, level, label = profile_label, color = color_dict[model_key], linestyle = '-.')

        ax.set_xlim(left = 0, right = xlim_upper[var])
        ax.tick_params(axis='both', labelsize=12)
        ax.set_title(f"{online_var_settings[var]['var_title']} ({online_var_settings[var]['unit']})", fontsize=14, loc='center')  # Add main title with subplot label
        ax.set_xlabel(f"{online_var_settings[var]['unit']}", fontsize=14)  # Keep unit in x-label
        ax.legend(fontsize=8, ncol = 2)

    # The panels share one y-axis, so inverting a single panel flips all of
    # them. Inverting once, and only if not already inverted, keeps pressure
    # increasing downwards regardless of how many panels are drawn (toggling
    # inside the loop only worked for an even number of panels plus one extra
    # toggle afterwards).
    if not axes[0].yaxis_inverted():
        axes[0].invert_yaxis()

    # Figure-level legend maps colours to model names; handles come from the last panel.
    handles = [line_mmf] + [model_lines[model_key] for model_key in model_keys]
    labels = ['MMF2'] + [model_names[model_key] for model_key in model_keys]

    fig.legend(handles, labels, loc='center right', bbox_to_anchor=(1.27, 0.5), title='Model')
    fig.suptitle(f'{num_years} Year Global Mean Root Mean Squared Error ({config_names[config_name]} configuration)', fontsize=16)
    # Set a shared y-label for the first column
    axes[0].set_ylabel('Hybrid pressure (hPa)', fontsize=14)
    axes[3].set_ylabel('Hybrid pressure (hPa)', fontsize=14)
    plt.tight_layout()
    if save_path:
        plt.savefig(os.path.join(save_path, f'online_{num_years}_year_global_RMSE_model_comparison_{config_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

### Five Year Global Mean RMSE (config comparison)

In [None]:
def plot_online_global_rmse_config_comparison(model_name, num_years, show = False, save_path = None):
    """Plot RMSE-vs-level profiles of one model for every training configuration.

    One panel is drawn for each of T, Q, CLDLIQ, CLDICE, U and V. In each panel:

    * the black line is the area-weighted RMSE between the time means of the
      two MMF datasets returned by ``read_mmf_online_data``;
    * for each configuration, the shaded band spans the minimum to maximum
      RMSE over ``seed_numbers`` and the line is the seed whose mean absolute
      departure from the MMF curve is closest to the median over seeds;
    * the number in each per-panel legend entry is the corresponding RMSE
      computed with the weights from ``get_pressure_area_weights(ds_mmf_1)``.

    NN runs that are missing, or that contain fewer than ``num_years * 12``
    time steps, are treated as NaN; a configuration with no usable seed is
    listed as 'N/A'.

    Args:
        model_name: Key into ``model_names`` selecting the architecture.
        num_years: Simulation length in years passed to the data readers.
        show: If True call ``plt.show()``; otherwise the figure is closed.
        save_path: Directory in which to save the PNG, or None to skip saving.
    """
    config_keys = ('standard', 'conf_loss', 'diff_loss', 'multirep', 'v6')
    variables = ['T', 'Q', 'CLDLIQ', 'CLDICE', 'U', 'V']
    # Upper limit of the (horizontal) RMSE axis for each variable.
    xlim_upper = {
        'T': 5,
        'Q': 0.7,
        'CLDLIQ': 60,
        'CLDICE': 8,
        'U': 11,
        'V': 5
    }

    ds_mmf_1, ds_mmf_2 = read_mmf_online_data(num_years)
    mmf_1_total_weight = get_pressure_area_weights(ds_mmf_1)
    ds_nn = {
        config_name: {seed_number: read_nn_online_data(config_name, model_name, seed_number, num_years) for seed_number in seed_numbers}
        for config_name in config_keys
    }

    def load_nn_var_time_mean(ds_nn_xr, var, num_years):
        # Missing or incomplete runs become an all-NaN field so they drop out of the nan-aware statistics.
        if not ds_nn_xr or len(ds_nn_xr['time']) < num_years * 12:
            return np.full((data_v2_rh_mc.num_levels, data_v2_rh_mc.num_latlon), np.nan)
        return ds_nn_xr[var].mean(dim = 'time').values

    # tight_layout() is applied below; it replaces a constrained layout engine (with a warning),
    # so constrained_layout is deliberately not requested here.
    fig, axes = plt.subplots(2, 3, figsize=(8, 7), sharey=True)  # 2 rows, 3 columns
    axes = axes.flatten()  # Flatten axes for easier iteration

    for ax, var in zip(axes, variables):
        scaling = online_var_settings[var]['scaling']
        ds_mmf_1_mean = ds_mmf_1[var].mean(dim = 'time').values * scaling
        ds_mmf_2_mean = ds_mmf_2[var].mean(dim = 'time').values * scaling
        mmf_sq_err = (ds_mmf_2_mean - ds_mmf_1_mean) ** 2
        mmf_rmse = np.sqrt(np.average(mmf_sq_err, axis = 1, weights = area_weight))
        mmf_rmse_global = np.sqrt(np.average(mmf_sq_err, weights = mmf_1_total_weight))

        ds_nn_rmse = {}
        ds_nn_rmse_global = {}
        for config_name in config_keys:
            # Compute each seed's squared error once and reuse it for both the profile and the scalar RMSE.
            seed_sq_errs = [
                (load_nn_var_time_mean(ds_nn[config_name][seed_number], var, num_years) * scaling - ds_mmf_1_mean) ** 2
                for seed_number in seed_numbers
            ]
            ds_nn_rmse[config_name] = np.array([np.sqrt(np.average(sq_err, axis = 1, weights = area_weight)) for sq_err in seed_sq_errs])
            ds_nn_rmse_global[config_name] = np.array([np.sqrt(np.average(sq_err, weights = mmf_1_total_weight)) for sq_err in seed_sq_errs])

        # Seed spread (min to max) for every configuration.
        for config_name in config_keys:
            ax.fill_betweenx(
                level,
                np.nanmin(ds_nn_rmse[config_name], axis = 0),
                np.nanmax(ds_nn_rmse[config_name], axis = 0),
                color = color_dict_config[config_name],
                alpha=0.15
            )
        line_mmf, = ax.plot(mmf_rmse, level, label=f'{mmf_rmse_global:.2f}', linestyle='-.', color='black')

        config_lines = {}
        for config_name in config_keys:
            rmse_by_seed = ds_nn_rmse[config_name]
            if not np.all(np.isnan(rmse_by_seed)):
                # Representative seed: mean |RMSE - MMF RMSE| closest to the median over seeds.
                abs_diff = np.nanmean(np.abs(rmse_by_seed - mmf_rmse), axis = 1)
                argidx = np.nanargmin(np.abs(abs_diff - np.nanmedian(abs_diff)))
                config_lines[config_name], = ax.plot(rmse_by_seed[argidx,:], level, label = f"{ds_nn_rmse_global[config_name][argidx]:.2f}", color = color_dict_config[config_name], linestyle = '-.')
            else:
                # Placeholder line so the configuration still gets a legend entry.
                config_lines[config_name], = ax.plot(np.full(level.shape, np.nan), level, label = 'N/A', color = color_dict_config[config_name], linestyle = '-.')

        ax.set_xlim(left = 0, right = xlim_upper[var])
        ax.tick_params(axis='both', labelsize=12)
        ax.set_title(f"{online_var_settings[var]['var_title']} ({online_var_settings[var]['unit']})", fontsize=14, loc='center')
        ax.set_xlabel(f"{online_var_settings[var]['unit']}", fontsize=14)  # Keep unit in x-label
        ax.legend(fontsize=8, ncol = 2)

    # The y-axis is shared, so invert it exactly once; toggling it per panel would
    # only leave it inverted for an odd number of calls.
    if not axes[0].yaxis_inverted():
        axes[0].invert_yaxis()

    handles = [line_mmf] + [config_lines[config_name] for config_name in config_keys]
    labels = ['MMF2'] + [config_names[config_name] for config_name in config_keys]

    # The legend entries are training configurations, not models.
    fig.legend(handles, labels, loc='center right', bbox_to_anchor=(1.27, 0.5), title='Configuration')
    fig.suptitle(f'{num_years} Year Global Mean Root Mean Squared Error ({model_names[model_name]})', fontsize=16)
    # Set a y-label on the first column of each row
    axes[0].set_ylabel('Hybrid pressure (hPa)', fontsize=14)
    axes[3].set_ylabel('Hybrid pressure (hPa)', fontsize=14)
    fig.tight_layout()
    if save_path:
        fig.savefig(os.path.join(save_path, f'online_{num_years}_year_global_RMSE_config_comparison_{model_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

# Precipitation section

### Precipitation Distribution (Model Comparison)

In [None]:
def plot_precc_dist_model_comparison(num_years, config_name, surface_type = 'global', show = False, save_path = None):
    """Compare the precipitation distribution of all models for one configuration.

    Panel (a) shows the time-mean, latitude-binned PRECT (from
    ``data_v2_rh_mc.zonal_bin_weight_2d``) of the first MMF dataset and of each
    model. Panel (b) shows area-weighted, density-normalised histograms of
    hourly PRECT. For each model the shaded band spans the minimum to maximum
    over ``seed_numbers`` and the dashed line is the seed closest to the
    across-seed median. All precipitation values are scaled by
    ``86400 * 1000`` and labelled as mm/day.

    NN monthly runs that are missing or do not have ``12 * num_years`` time
    steps, and hourly runs that are missing or do not have
    ``365 * 24 * num_years`` time steps, are ignored.

    Args:
        num_years: Simulation length in years passed to the data readers.
        config_name: Key into ``config_names`` selecting the configuration.
        surface_type: One of 'global', 'land', 'ocean' or 'ice'. Non-global
            types weight by the matching surface fraction of the MMF dataset.
        show: If True call ``plt.show()``; otherwise the figure is closed.
        save_path: Directory in which to save the PNG, or None to skip saving.
            The file name gets a ``_<surface_type>`` suffix for non-global
            surface types so they do not overwrite the global figure.
    """
    assert surface_type in ['global', 'land', 'ocean', 'ice']
    # Surface-fraction variable used to weight each non-global surface type.
    surface_frac_vars = {'land': 'LANDFRAC', 'ocean': 'OCNFRAC', 'ice': 'ICEFRAC'}
    model_keys = list(model_names.keys())
    last_hourly_year = str(num_years + 2).zfill(4)

    ds_mmf_1, _ = read_mmf_online_data(num_years)
    mmf_1_precc = ds_mmf_1['PRECT'].values
    nn_precc = {model_name: {seed_number: read_nn_online_data(config_name, model_name, seed_number, num_years) for seed_number in seed_numbers} for model_name in model_keys}
    mmf_1_hourly_prect = xr.open_mfdataset('/pscratch/sd/z/zeyuanhu/hu_etal2024_data_v2/data_hourly/precip_hourly/mmf_ref/PRECT*nc')['PRECT'].sel(time = slice(None, last_hourly_year)).values * 86400 * 1000
    mmf_1_hourly_prect_flat = mmf_1_hourly_prect.flatten()

    def load_hourly_prect(model_name, seed_number):
        # Returns the scaled hourly array, or [] when the run is missing or not exactly num_years long.
        hourly = read_nn_online_precip_data(config_name, model_name, seed_number, num_years)
        if hourly is None:
            return []
        hourly = hourly.sel(time = slice(None, last_hourly_year)).values * 86400 * 1000
        return hourly if len(hourly) == 365 * 24 * num_years else []

    nn_hourly_prect = {model_name: {seed_number: load_hourly_prect(model_name, seed_number) for seed_number in seed_numbers} for model_name in model_keys}

    prefix_str = surface_type.capitalize()
    if surface_type == 'global':
        mmf_1_custom_weighting = None
        nn_custom_weighting = {model_name: {seed_number: None for seed_number in seed_numbers} for model_name in model_keys}
        tile_area = data_v2_rh_mc.grid_info_area
    else:
        mmf_1_custom_weighting = get_pressure_area_weights(ds_mmf_1, surface_type = surface_type)
        # Runs without data fall back to the MMF weights; their means are set to NaN below anyway.
        nn_custom_weighting = {model_name: {seed_number: get_pressure_area_weights(nn_precc[model_name][seed_number], surface_type = surface_type) if nn_precc[model_name][seed_number] is not None else mmf_1_custom_weighting
                                            for seed_number in seed_numbers} for model_name in model_keys}
        tile_area = data_v2_rh_mc.grid_info_area * np.mean(ds_mmf_1[surface_frac_vars[surface_type]], axis = 0)
    # One copy of the per-column weights for every hourly time step, matching the flattened data.
    mmf_flat_area_weights = np.tile(tile_area, mmf_1_hourly_prect.shape[0])

    def zonal_time_mean(ds_nn, custom_weighting):
        # NaN profile for missing runs or runs without exactly 12 * num_years monthly steps.
        if ds_nn is None or ds_nn['PRECT'].shape[0] != 12 * num_years:
            return np.full(lat_bin_mids.shape, np.nan)
        return np.mean(data_v2_rh_mc.zonal_bin_weight_2d(ds_nn['PRECT'].values, custom_weighting = custom_weighting) * 86400 * 1000, axis = 0)

    mmf_1_precc_mean = np.mean(data_v2_rh_mc.zonal_bin_weight_2d(mmf_1_precc, custom_weighting = mmf_1_custom_weighting) * 86400 * 1000, axis = 0)
    # Per model: array with one latitude profile per seed.
    seed_mean_arrs = {model_name: np.array([zonal_time_mean(nn_precc[model_name][seed_number], nn_custom_weighting[model_name][seed_number])
                                            for seed_number in seed_numbers]) for model_name in model_keys}

    lat_ticks = [-60, -30, 0, 30, 60]
    lat_labels = ['60S', '30S', '0', '30N', '60N']
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(9, 3.7))

    # First plot: zonal-mean precipitation
    ax1 = axes[0]
    ax1.plot(lat_bin_mids, mmf_1_precc_mean, label='MMF', color='black', linestyle='-')
    for model_name in model_keys:
        model_mean_arr = seed_mean_arrs[model_name]
        if np.all(np.isnan(model_mean_arr)):
            continue
        ax1.fill_between(
            lat_bin_mids,
            np.nanmin(model_mean_arr, axis=0),
            np.nanmax(model_mean_arr, axis=0),
            color=color_dict[model_name],
            alpha=0.15
        )
        # Representative seed: smallest mean absolute deviation from the across-seed median profile.
        argidx = np.nanargmin(np.nanmean(np.abs(model_mean_arr - np.nanmedian(model_mean_arr, axis = 0)), axis=1))
        ax1.plot(lat_bin_mids, model_mean_arr[argidx, :], label = f"{model_names[model_name]}", color=color_dict[model_name], linestyle='--')

    ax1.set_xlabel('Latitude')
    ax1.set_ylabel('Precipitation (mm/day)')
    ax1.set_xticks(lat_ticks)
    ax1.set_xticklabels(lat_labels)
    ax1.set_title('(a) Mean Precipitation')
    ax1.set_ylim(0, 8)
    ax1.set_xlim(-90,90)
    # Both panels use the same colours, so the legend entries of panel (a) are split across the two panels below.
    handles, labels = ax1.get_legend_handles_labels()

    # Second plot: Weighted histogram of precipitation
    ax2 = axes[1]
    bins_lev = np.arange(-2,180,4)
    bin_centers = (bins_lev[:-1] + bins_lev[1:]) / 2

    mmf_hist, _ = np.histogram(mmf_1_hourly_prect_flat, bins=bins_lev, weights=mmf_flat_area_weights, density=True)
    ax2.plot(bin_centers, mmf_hist, label='MMF', color='black', linestyle='-', linewidth=2)
    for model_name in model_keys:
        if all(len(nn_hourly_prect[model_name][seed_number]) == 0 for seed_number in seed_numbers):
            continue
        if np.all(np.isnan(seed_mean_arrs[model_name])):
            continue
        # Seeds without hourly data are empty here and yield an all-NaN histogram row,
        # which the nan-aware reductions below ignore.
        nn_prect_hist = np.array([np.histogram(np.array(nn_hourly_prect[model_name][seed_number]).flatten(),
                                    bins=bins_lev,
                                    weights=np.tile(tile_area, np.array(nn_hourly_prect[model_name][seed_number]).shape[0]),
                                    density=True)[0] for seed_number in seed_numbers])
        ax2.fill_between(
            bin_centers,
            np.nanmin(nn_prect_hist, axis=0),
            np.nanmax(nn_prect_hist, axis=0),
            color=color_dict[model_name],
            alpha=0.15
        )
        argidx = np.nanargmin(np.nanmean(np.abs(nn_prect_hist - np.nanmedian(nn_prect_hist, axis = 0)), axis=1))
        ax2.plot(bin_centers, nn_prect_hist[argidx, :], label = f"{model_names[model_name]}", color=color_dict[model_name], linestyle='--')

    ax2.set_yscale('log')
    ax2.set_xlabel('Precipitation (mm/day)')
    ax2.set_ylabel('Frequency')
    ax2.set_title('(b) Histogram of Precipitation')
    ax2.set_ylim(1e-8,0.5)
    ax2.set_xlim(0,180)

    # Add legends to each subplot: first four entries on panel (a), the rest on panel (b)
    ax1.legend(handles[:4], labels[:4], loc='upper left')
    ax2.legend(handles[4:], labels[4:], loc='upper right')

    fig.suptitle(f'{prefix_str} Precipitation {num_years}-Year Distribution ({config_names[config_name]} Configuration)', fontsize=14.5)
    fig.subplots_adjust(top=0.85, wspace=0.3)  # Adjust the top space and width space between subplots

    if save_path:
        # Keep the historical file name for the default surface type; disambiguate the others.
        surface_suffix = '' if surface_type == 'global' else f'_{surface_type}'
        fig.savefig(os.path.join(save_path, f'online_{num_years}_year_precc_dist_model_comparison_{config_name}{surface_suffix}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

###

### Precipitation Distribution (Config Comparison)

In [None]:
def plot_precc_dist_config_comparison(num_years, model_name, surface_type = 'global', show = False, save_path = None):
    """Compare the precipitation distribution of all configurations for one model.

    Panel (a) shows the time-mean, latitude-binned PRECT (from
    ``data_v2_rh_mc.zonal_bin_weight_2d``) of the first MMF dataset and of each
    configuration. Panel (b) shows area-weighted, density-normalised
    histograms of hourly PRECT. For each configuration the shaded band spans
    the minimum to maximum over ``seed_numbers`` and the dashed line is the
    seed closest to the across-seed median. All precipitation values are
    scaled by ``86400 * 1000`` and labelled as mm/day.

    NN monthly runs that are missing or do not have ``12 * num_years`` time
    steps, and hourly runs that are missing or do not have
    ``365 * 24 * num_years`` time steps, are ignored.

    Args:
        num_years: Simulation length in years passed to the data readers.
        model_name: Key into ``model_names`` selecting the architecture.
        surface_type: One of 'global', 'land', 'ocean' or 'ice'. Non-global
            types weight by the matching surface fraction of the MMF dataset.
        show: If True call ``plt.show()``; otherwise the figure is closed.
        save_path: Directory in which to save the PNG, or None to skip saving.
            The file name gets a ``_<surface_type>`` suffix for non-global
            surface types so they do not overwrite the global figure.
    """
    assert surface_type in ['global', 'land', 'ocean', 'ice']
    # Surface-fraction variable used to weight each non-global surface type.
    surface_frac_vars = {'land': 'LANDFRAC', 'ocean': 'OCNFRAC', 'ice': 'ICEFRAC'}
    config_keys = list(config_names.keys())
    last_hourly_year = str(num_years + 2).zfill(4)

    ds_mmf_1, _ = read_mmf_online_data(num_years)
    mmf_1_precc = ds_mmf_1['PRECT'].values
    nn_precc = {config_name: {seed_number: read_nn_online_data(config_name, model_name, seed_number, num_years) for seed_number in seed_numbers} for config_name in config_keys}
    mmf_1_hourly_prect = xr.open_mfdataset('/pscratch/sd/z/zeyuanhu/hu_etal2024_data_v2/data_hourly/precip_hourly/mmf_ref/PRECT*nc')['PRECT'].sel(time = slice(None, last_hourly_year)).values * 86400 * 1000
    mmf_1_hourly_prect_flat = mmf_1_hourly_prect.flatten()

    def load_hourly_prect(config_name, seed_number):
        # Returns the scaled hourly array, or [] when the run is missing or not exactly num_years long.
        hourly = read_nn_online_precip_data(config_name, model_name, seed_number, num_years)
        if hourly is None:
            return []
        hourly = hourly.sel(time = slice(None, last_hourly_year)).values * 86400 * 1000
        return hourly if len(hourly) == 365 * 24 * num_years else []

    nn_hourly_prect = {config_name: {seed_number: load_hourly_prect(config_name, seed_number) for seed_number in seed_numbers} for config_name in config_keys}

    prefix_str = surface_type.capitalize()
    if surface_type == 'global':
        mmf_1_custom_weighting = None
        nn_custom_weighting = {config_name: {seed_number: None for seed_number in seed_numbers} for config_name in config_keys}
        tile_area = data_v2_rh_mc.grid_info_area
    else:
        mmf_1_custom_weighting = get_pressure_area_weights(ds_mmf_1, surface_type = surface_type)
        # Runs without data fall back to the MMF weights; their means are set to NaN below anyway.
        nn_custom_weighting = {config_name: {seed_number: get_pressure_area_weights(nn_precc[config_name][seed_number], surface_type = surface_type) if nn_precc[config_name][seed_number] is not None else mmf_1_custom_weighting
                                             for seed_number in seed_numbers} for config_name in config_keys}
        tile_area = data_v2_rh_mc.grid_info_area * np.mean(ds_mmf_1[surface_frac_vars[surface_type]], axis = 0)
    # One copy of the per-column weights for every hourly time step, matching the flattened data.
    mmf_flat_area_weights = np.tile(tile_area, mmf_1_hourly_prect.shape[0])

    def zonal_time_mean(ds_nn, custom_weighting):
        # NaN profile for missing runs or runs without exactly 12 * num_years monthly steps.
        if ds_nn is None or ds_nn['PRECT'].shape[0] != 12 * num_years:
            return np.full(lat_bin_mids.shape, np.nan)
        return np.mean(data_v2_rh_mc.zonal_bin_weight_2d(ds_nn['PRECT'].values, custom_weighting = custom_weighting) * 86400 * 1000, axis = 0)

    mmf_1_precc_mean = np.mean(data_v2_rh_mc.zonal_bin_weight_2d(mmf_1_precc, custom_weighting = mmf_1_custom_weighting) * 86400 * 1000, axis = 0)
    # Per configuration: array with one latitude profile per seed.
    seed_mean_arrs = {config_name: np.array([zonal_time_mean(nn_precc[config_name][seed_number], nn_custom_weighting[config_name][seed_number])
                                             for seed_number in seed_numbers]) for config_name in config_keys}

    lat_ticks = [-60, -30, 0, 30, 60]
    lat_labels = ['60S', '30S', '0', '30N', '60N']
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(9, 3.7))

    # First plot: zonal-mean precipitation
    ax1 = axes[0]
    ax1.plot(lat_bin_mids, mmf_1_precc_mean, label='MMF', color='black', linestyle='-')
    for config_name in config_keys:
        model_mean_arr = seed_mean_arrs[config_name]
        if np.all(np.isnan(model_mean_arr)):
            continue
        ax1.fill_between(
            lat_bin_mids,
            np.nanmin(model_mean_arr, axis=0),
            np.nanmax(model_mean_arr, axis=0),
            color=color_dict_config[config_name],
            alpha=0.15
        )
        # Representative seed: smallest mean absolute deviation from the across-seed median profile.
        argidx = np.nanargmin(np.nanmean(np.abs(model_mean_arr - np.nanmedian(model_mean_arr, axis = 0)), axis=1))
        ax1.plot(lat_bin_mids, model_mean_arr[argidx, :], label = f"{config_names[config_name]}", color=color_dict_config[config_name], linestyle='--')

    ax1.set_xlabel('Latitude')
    ax1.set_ylabel('Precipitation (mm/day)')
    ax1.set_xticks(lat_ticks)
    ax1.set_xticklabels(lat_labels)
    ax1.set_title('(a) Mean Precipitation')
    ax1.set_ylim(0, 8)
    ax1.set_xlim(-90,90)
    # Both panels use the same colours, so the legend entries of panel (a) are split across the two panels below.
    handles, labels = ax1.get_legend_handles_labels()

    # Second plot: Weighted histogram of precipitation
    ax2 = axes[1]
    bins_lev = np.arange(-2,180,4)
    bin_centers = (bins_lev[:-1] + bins_lev[1:]) / 2

    mmf_hist, _ = np.histogram(mmf_1_hourly_prect_flat, bins=bins_lev, weights=mmf_flat_area_weights, density=True)
    ax2.plot(bin_centers, mmf_hist, label='MMF', color='black', linestyle='-', linewidth=2)
    for config_name in config_keys:
        if all(len(nn_hourly_prect[config_name][seed_number]) == 0 for seed_number in seed_numbers):
            continue
        if np.all(np.isnan(seed_mean_arrs[config_name])):
            continue
        # Seeds without hourly data are empty here and yield an all-NaN histogram row,
        # which the nan-aware reductions below ignore.
        nn_prect_hist = np.array([np.histogram(np.array(nn_hourly_prect[config_name][seed_number]).flatten(),
                                    bins=bins_lev,
                                    weights=np.tile(tile_area, np.array(nn_hourly_prect[config_name][seed_number]).shape[0]),
                                    density=True)[0] for seed_number in seed_numbers])
        ax2.fill_between(
            bin_centers,
            np.nanmin(nn_prect_hist, axis=0),
            np.nanmax(nn_prect_hist, axis=0),
            color=color_dict_config[config_name],
            alpha=0.15
        )
        argidx = np.nanargmin(np.nanmean(np.abs(nn_prect_hist - np.nanmedian(nn_prect_hist, axis = 0)), axis=1))
        ax2.plot(bin_centers, nn_prect_hist[argidx, :], label = f"{config_names[config_name]}", color=color_dict_config[config_name], linestyle='--')

    ax2.set_yscale('log')
    ax2.set_xlabel('Precipitation (mm/day)')
    ax2.set_ylabel('Frequency')
    ax2.set_title('(b) Histogram of Precipitation')
    ax2.set_ylim(1e-8,0.5)
    ax2.set_xlim(0,180)

    # Add legends to each subplot: first four entries on panel (a), the rest on panel (b)
    ax1.legend(handles[:4], labels[:4], loc='upper left')
    ax2.legend(handles[4:], labels[4:], loc='upper right')

    fig.suptitle(f'{prefix_str} Precipitation {num_years}-Year Distribution ({model_names[model_name]})', fontsize=14.5)
    fig.subplots_adjust(top=0.85, wspace=0.3)  # Adjust the top space and width space between subplots

    if save_path:
        # Keep the historical file name for the default surface type; disambiguate the others.
        surface_suffix = '' if surface_type == 'global' else f'_{surface_type}'
        fig.savefig(os.path.join(save_path, f'online_{num_years}_year_precc_dist_config_comparison_{model_name}{surface_suffix}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)

### Precipitation Map plots

In [None]:
def generate_precip_map(num_years, model_name, config_name, seed_number, show = False, save_path = None):
    """Draw maps of time-mean precipitation for MMF and one NN run, plus their difference.

    Three Robinson-projection panels are produced: (a) the time-mean PRECT of
    the first MMF dataset, (b) the time-mean PRECT of the selected NN run and
    (c) NN minus MMF. The title of panel (c) reports the RMSE of the two
    time-mean fields weighted by ``area_weight``. PRECT is multiplied by
    ``86400 * 1000`` to convert it to mm/day.

    If ``read_nn_online_data`` returns nothing for the requested run, a message
    is printed and the function returns without creating a figure.

    Args:
        num_years: Simulation length in years passed to the data readers.
        model_name: Key into ``model_names`` selecting the architecture.
        config_name: Key into ``config_names`` selecting the configuration.
        seed_number: Seed of the NN run to plot.
        show: If True call ``plt.show()``; otherwise the figure is closed.
        save_path: Directory in which to save the PNG, or None to skip saving.
    """
    # Load the data before creating the figure so the early return below cannot leave an empty figure open.
    ds_mmf_1, _ = read_mmf_online_data(num_years)
    ds_nn = read_nn_online_data(config_name, model_name, seed_number, num_years)
    if not ds_nn:
        print(f"No data found for {model_name} with {config_name} configuration and seed {seed_number}. Skipping plot.")
        return

    ds_mmf_1_prect_mean = ds_mmf_1['PRECT'].mean(dim='time').values * 86400 * 1000  # Convert to mm/day
    ds_nn_prect_mean = ds_nn['PRECT'].mean(dim='time').values * 86400 * 1000  # Convert to mm/day
    prect_bias = ds_nn_prect_mean - ds_mmf_1_prect_mean
    nn_rmse = np.sqrt(np.average(prect_bias ** 2, weights = area_weight))

    precip_vmax = 15
    precip_vmin = 0
    precip_levels = np.linspace(precip_vmin, precip_vmax, 11, endpoint=True)
    bias_max = 4
    bias_levels = np.linspace(-bias_max, bias_max, 11, endpoint = True)

    fig, axs = plt.subplots(nrows=1, ncols=3,
                            subplot_kw={'projection': ccrs.Robinson(central_longitude=179.5)},
                            figsize=(18, 6))

    # Panels (a) and (b) share the colour map and contour levels, so one colorbar serves both.
    axs[0].set_global()
    contour_precip = axs[0].tricontourf(lons, lats, ds_mmf_1_prect_mean, cmap='YlGnBu',
                                        transform=ccrs.PlateCarree(),
                                        levels=precip_levels, extend='both')
    axs[0].coastlines()
    axs[0].set_title('(a) MMF Precipitation (mm/day)', fontsize = 14)

    axs[1].set_global()
    axs[1].tricontourf(lons, lats, ds_nn_prect_mean, cmap = 'YlGnBu',
                       transform=ccrs.PlateCarree(),
                       levels=precip_levels, extend='both')
    axs[1].coastlines()
    axs[1].set_title(f'(b) NN Precipitation (mm/day)', fontsize = 14)

    axs[2].set_global()
    contour_bias = axs[2].tricontourf(lons, lats, prect_bias, cmap='RdBu_r',
                                      transform=ccrs.PlateCarree(),
                                      levels=bias_levels, extend='both')
    axs[2].coastlines()
    axs[2].set_title(f'(c) Precipitation RMSE: {nn_rmse:.2f} (mm/day)', fontsize = 14)

    # Add a shared colorbar for precipitation (first two plots)
    cbar_ax1 = fig.add_axes([0.18, 0.25, 0.4, 0.02])  # [left, bottom, width, height]
    cbar1 = fig.colorbar(contour_precip, cax=cbar_ax1, orientation='horizontal')
    cbar1.set_label('', fontsize=12)

    # Add a separate colorbar for bias (third plot)
    cbar_ax2 = fig.add_axes([0.68, 0.25, 0.2, 0.02])  # [left, bottom, width, height]
    cbar2 = fig.colorbar(contour_bias, cax=cbar_ax2, orientation='horizontal')
    cbar2.set_label('', fontsize=12)

    fig.subplots_adjust(top = .9, wspace=0.1)
    fig.suptitle(f'{num_years}-Year Mean Precipitation and Bias for {model_names[model_name]} ({config_names[config_name]}, Seed {seed_number})', fontsize=16, y = .8)
    if save_path:
        fig.savefig(os.path.join(save_path, f'online_{num_years}_year_precip_and_bias_{model_name}_{config_name}_{seed_number}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
    

# Binning visualization

In [7]:
def plot_binning_model_comparison(config_name, show = True, save_path = None):
    """Plot seed-averaged, precipitation-binned moistening tendency bias, one panel per model.

    For the given configuration, the predictions of every model in
    ``model_names`` are averaged over ``seeds``. The difference from
    ``actual_target`` in columns 60:120 of the last axis is multiplied by
    ``offline_var_settings['DQ1PHYS']['scaling']``, weighted by each entry
    returned by ``get_offline_precip_area_weights`` and summed over the first
    two axes; the resulting profile is plotted against ``level``.

    Parameters
    ----------
    config_name
        Key into ``config_preds`` and ``config_names`` selecting the configuration.
    show : bool, default True
        If True the figure is shown with ``plt.show()``; otherwise it is closed.
    save_path : str or None, default None
        Directory in which the PNG is written. It is created if it does not
        exist. Nothing is saved when this is falsy.
    """
    model_keys = list(model_names.keys())
    n_models = len(model_keys)

    # Average the per-seed predictions for every model under this configuration.
    nn_preds = {model_name: np.mean(np.array([config_preds[config_name][model_name](seed) for seed in seeds]), axis = 0) for model_name in model_keys}
    # NOTE(review): columns 60:120 are presumably the moistening tendency on 60 levels — confirm against the output layout.
    moistening_diffs = {model_name: (nn_preds[model_name][:,:,60:120] - actual_target[:,:,60:120]) * offline_var_settings['DQ1PHYS']['scaling'] for model_name in model_keys}
    offline_precip_area_weights = {model_name: get_offline_precip_area_weights(nn_preds[model_name]) for model_name in model_keys}

    # 11 colour/alpha steps — presumably one per precipitation bin; TODO confirm.
    cmap = matplotlib.colormaps['viridis']
    colors = cmap(np.linspace(0, 1, 11))
    alphas = np.linspace(0.1, 1, 11)

    # Three panels per row; the number of rows follows the number of models
    # (six models give the original 2x3 grid at 16x10 inches).
    n_cols = 3
    n_rows = max(1, -(-n_models // n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(16, 5 * n_rows), constrained_layout=True, squeeze=False)
    ax_flat = axs.flatten()
    letter_labels = [f"({letter})" for letter in string.ascii_lowercase[:max(n_models, 6)]]

    for ax_idx, model_name in enumerate(model_keys):
        ax = ax_flat[ax_idx]
        weights_by_bin = offline_precip_area_weights[model_name]
        for precip_idx, precip_key in enumerate(weights_by_bin.keys()):
            profile = np.sum(np.sum(moistening_diffs[model_name] * weights_by_bin[precip_key], axis=1), axis=0)
            ax.plot(profile, level, label=precip_percentile_labels[precip_key],
                    linestyle='-', color=colors[precip_idx], alpha=alphas[precip_idx])
        ax.invert_yaxis()
        # Only the lowest panel of each column carries the x label.
        if ax_idx + n_cols >= n_models:
            ax.set_xlabel('Moistening Tendency Bias (g/kg/s)')
        # Only the first panel of each row carries the y label.
        if ax_idx % n_cols == 0:
            ax.set_ylabel('Hybrid Pressure Level (hPa)')
        ax.set_title(f'{letter_labels[ax_idx]} {model_names[model_name]}')
        ax.legend()
        ax.grid(True)
        ax.set_xlim(-1.4e-6, 2.2e-6)

    # Hide grid cells that received no model so they do not render as empty frames.
    for unused_ax in ax_flat[n_models:]:
        unused_ax.set_visible(False)

    fig.suptitle(f'Seed-Averaged Binned Offline Moistening Tendency Bias by Active Precipitation Percentile ({config_names[config_name]} Configuration)', fontsize=14.5)
    if save_path:
        # savefig does not create missing directories.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, f'offline_binned_moistening_tendency_bias_model_comparison_{config_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

In [8]:
# Save one model-comparison figure per configuration without displaying it.
for config_name in config_names.keys():
    plot_binning_model_comparison(
        config_name,
        show=False,
        save_path=os.path.join(
            climsim3_figures_save_path_offline,
            'offline_binned_moistening_bias_model_comparison',
        ),
    )

In [9]:
def plot_binning_config_comparison(model_name, show = True, save_path = None):
    """Plot seed-averaged, precipitation-binned moistening tendency bias, one panel per configuration.

    For the given model, the predictions under every configuration in
    ``config_names`` are averaged over ``seeds``. The difference from
    ``actual_target`` in columns 60:120 of the last axis is multiplied by
    ``offline_var_settings['DQ1PHYS']['scaling']``, weighted by each entry
    returned by ``get_offline_precip_area_weights`` and summed over the first
    two axes; the resulting profile is plotted against ``level``.

    The figure has five panels: three on the top row and two centred on the
    bottom row.

    Parameters
    ----------
    model_name
        Key into ``config_preds[config_name]`` and ``model_names`` selecting the model.
    show : bool, default True
        If True the figure is shown with ``plt.show()``; otherwise it is closed.
    save_path : str or None, default None
        Directory in which the PNG is written. It is created if it does not
        exist. Nothing is saved when this is falsy.
    """
    config_keys = list(config_names.keys())

    # Average the per-seed predictions of this model under every configuration.
    nn_preds = {config_name: np.mean(np.array([config_preds[config_name][model_name](seed) for seed in seeds]), axis = 0) for config_name in config_keys}
    # NOTE(review): columns 60:120 are presumably the moistening tendency on 60 levels — confirm against the output layout.
    moistening_diffs = {config_name: (nn_preds[config_name][:,:,60:120] - actual_target[:,:,60:120]) * offline_var_settings['DQ1PHYS']['scaling'] for config_name in config_keys}
    offline_precip_area_weights = {config_name: get_offline_precip_area_weights(nn_preds[config_name]) for config_name in config_keys}

    # 11 colour/alpha steps — presumably one per precipitation bin; TODO confirm.
    cmap = matplotlib.colormaps['viridis']
    colors = cmap(np.linspace(0, 1, 11))
    alphas = np.linspace(0.1, 1, 11)

    fig = plt.figure(figsize=(16, 10), constrained_layout=True)
    gs = fig.add_gridspec(2, 6)

    # Top row: three panels, each spanning two of the six grid columns.
    # Bottom row: two panels of the same width, centred by leaving grid
    # columns 0 and 5 empty.
    ax_flat = [
        fig.add_subplot(gs[0, 0:2]),
        fig.add_subplot(gs[0, 2:4]),
        fig.add_subplot(gs[0, 4:6]),
        fig.add_subplot(gs[1, 1:3]),
        fig.add_subplot(gs[1, 3:5]),
    ]
    letter_labels = [f"({letter})" for letter in string.ascii_lowercase[:len(ax_flat)]]

    for ax_idx, config_name in enumerate(config_keys):
        ax = ax_flat[ax_idx]
        weights_by_bin = offline_precip_area_weights[config_name]
        for precip_idx, precip_key in enumerate(weights_by_bin.keys()):
            profile = np.sum(np.sum(moistening_diffs[config_name] * weights_by_bin[precip_key], axis=1), axis=0)
            ax.plot(profile, level, label=precip_percentile_labels[precip_key],
                    linestyle='-', color=colors[precip_idx], alpha=alphas[precip_idx])
        ax.invert_yaxis()
        ax.set_xlabel('Moistening Tendency Bias (g/kg/s)')
        # Only the left-most panel of each row carries the y label.
        if ax_idx == 0 or ax_idx == 3:
            ax.set_ylabel('Hybrid Pressure Level (hPa)')
        ax.set_title(f'{letter_labels[ax_idx]} {config_names[config_name]}')
        ax.legend()
        ax.grid(True)
        ax.set_xlim(-1.4e-6, 2.2e-6)

    # Hide panels that received no configuration so they do not render as empty frames.
    for unused_ax in ax_flat[len(config_keys):]:
        unused_ax.set_visible(False)

    fig.suptitle(f'Seed-Averaged Binned Offline Moistening Tendency Bias by Precipitation Percentile ({model_names[model_name]})', fontsize=14.5)
    if save_path:
        # savefig does not create missing directories.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, f'offline_binned_moistening_tendency_bias_config_comparison_{model_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

In [10]:
# Save one configuration-comparison figure per model without displaying it.
for model_name in model_names.keys():
    plot_binning_config_comparison(
        model_name,
        show=False,
        save_path=os.path.join(
            climsim3_figures_save_path_offline,
            'offline_binned_moistening_bias_config_comparison',
        ),
    )