In [None]:
import xarray as xa
from matplotlib import pyplot as plt
import cartopy
import glob
import re
# Natural-sort key: splits a string on runs of digits and turns those runs into
# ints, so that e.g. "r2" orders before "r10". A raw string is used for the
# pattern because "\d" is not a valid escape sequence in a normal string literal.
natsort = lambda s: [int(t) if t.isdigit() else t.lower() for t in re.split(r"(\d+)", s)]
import xcdat as xc
import numpy as np
from matplotlib import colors
import matplotlib as mpl
mpl.rc('font', family='DejaVu Serif') 
# mpl.rc('font', serif='Helvetica Neue') 
import pickle
from scipy.stats import pearsonr
from sklearn.metrics import r2_score
import os
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
mpl.rc('font', family='DejaVu Serif') 
mpl.rcParams['mathtext.rm']='DejaVu Serif'
mpl.rcParams['mathtext.fontset'] = 'custom'
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

In [None]:
def get_forced(var, std=False, models=['CESM2']):
    """Load monthly anomalies for every ensemble member of the requested models.

    Parameters
    ----------
    var : str
        One of 'tos', 'pr' or 'monmaxpr'. Selects the data sub-directory, the
        first year of the analysis period and the variable id read from the files.
    std : bool, optional
        If True, the anomalies of each calendar month are divided by their
        standard deviation over time.
    models : sequence of str, optional
        Model directory names below ``<root>/Training/<category>/<var>/``.

    Returns
    -------
    dict
        Maps each model name to its loaded anomaly dataset. Members are
        concatenated along a 'member' dimension (a model with a single member
        is returned without that dimension). Every dataset carries the time
        axis of the first member that was read.

    Raises
    ------
    ValueError
        If ``var`` is unknown, no files are found for a model, a member does
        not map to exactly one file, or a member's time axis does not match
        the reference time axis (length, year or month).
    """
    # (sub-directory category, first year, variable id inside the files)
    var_settings = {
        'tos': ('Omon', 1950, 'tos'),
        'pr': ('Amon', 1979, 'pr'),
        'monmaxpr': ('Aday', 1979, 'pr'),
    }
    if var not in var_settings:
        raise ValueError("Unknown variable %r; expected one of %s" % (var, sorted(var_settings)))
    cat, start_year, vid = var_settings[var]

    root_dir = "/p/lustre3/shiduan/ForceSMIP"  # path to forcesmip data (NCAR)
    # climatological period (for anomaly calculations)
    reference_period = (str(start_year)+"-01-01", "2022-12-31")
    tv_time_period = (str(start_year)+"-01-01", "2023-01-01")

    all_data = {}
    ref_time = None  # time axis of the first member read; shared by all members and models
    for model in models:
        print(model)
        mpath = root_dir + '/Training/' + cat + '/' + var + '/' + model
        mfiles = glob.glob(mpath + '/*')
        if not mfiles:
            raise ValueError("No files found for model " + model + " under " + mpath)
        # member labels are parsed from the file names; CESM2 uses a different naming scheme
        if model == "CESM2":
            members = [p.split("ssp370_")[-1].split(".1880")[0] for p in mfiles]
        else:
            members = [p.split("_")[-1].split(".")[0] for p in mfiles]
        members.sort(key=natsort)

        member_datasets = []
        for member in members:
            # print member progress
            print('.', end='')
            # get member filename and make sure it is unique
            fn = glob.glob(mpath + "/*_" + member + ".*.nc")
            if len(fn) != 1:
                raise ValueError("Unexpected number of model members")
            fn = fn[0]
            # load data
            ds = xc.open_dataset(fn)
            ds = ds.bounds.add_missing_bounds(axes=['T'])
            # remove singleton dimensions
            ds = ds.squeeze()
            # subset data to user-specified time period
            ds = ds.sel(time=slice(tv_time_period[0], tv_time_period[1]))
            # calculate departures (relative to user-specified reference time period)
            ds = ds.temporal.departures(vid, freq='month', reference_period=reference_period)
            if std:
                ds = ds.groupby(ds.time.dt.month)/ds.groupby(ds.time.dt.month).std(dim='time')
            if 'file_qf' in ds.variables:
                ds = ds.drop_vars('file_qf')
            if ref_time is None:
                ref_time = ds.time
            # every member must cover exactly the same year/month sequence
            if ds.time.size != ref_time.size:
                raise ValueError("model time and reference time do not match")
            for t, rt in zip(ds.time.values, ref_time.values):
                if (rt.year != t.year) or (rt.month != t.month):
                    raise ValueError("model time and reference time do not match")
            ds["time"] = ref_time.copy()
            member_datasets.append(ds)

        # concatenate once instead of re-copying the growing ensemble for every member
        if len(member_datasets) == 1:
            ds_model = member_datasets[0]
        else:
            ds_model = xa.concat(member_datasets, dim='member')
        all_data[model] = ds_model.load()
    return all_data

In [None]:
def get_snr(model, all_data, vid, month=None):
    """Linear trends of the ensemble mean and of every member, plus their spread.

    Parameters
    ----------
    model : str
        Key into ``all_data``.
    all_data : dict
        Maps model name to a dataset with 'member' and 'time' dimensions.
    vid : str
        Name of the data variable to fit.
    month : int or None, optional
        If given, only time steps of that calendar month are used.

    Returns
    -------
    trend_forced
        Trend per decade of the member mean.
    trend_all : list
        Trend per decade of each member, in member order.
    trend_std
        Standard deviation of the member trends across the ensemble.

    Notes
    -----
    The time axis is replaced by a unit-spaced index before fitting. One step
    is one month when all months are used and one year when a single calendar
    month is selected, so the slope is scaled by 120 or by 10 respectively.
    Earlier revisions used ``np.linspace(1, N+1, N)`` (spacing N/(N-1)) and
    always multiplied by 120; both only changed the trends by a constant
    factor that cancels in trend_forced / trend_std.
    """
    data_in = all_data[model]
    steps_per_decade = 12*10 if month is None else 10

    def _trend_per_decade(field):
        # Fit a straight line along time and return its slope per decade.
        if month is not None:
            field = field.sel(time=field.time.dt.month==month)
        n_steps = field.sizes['time']
        field = field.assign_coords(time=('time', np.arange(1, n_steps+1, dtype=float)))
        fit = field.polyfit(dim='time', deg=1, skipna=False)
        return fit.polyfit_coefficients.sel(degree=1)*steps_per_decade

    # Forced trend: trend of the ensemble mean
    trend_forced = _trend_per_decade(data_in[vid].mean(dim='member'))
    # individual trends, one per member
    trend_all = [_trend_per_decade(data_in[vid].isel(member=i))
                 for i in range(len(data_in.member))]
    trend_std = xa.concat(trend_all, dim='ens').std(dim='ens')
    return trend_forced, trend_all, trend_std

# Pr

In [None]:
def get_data(model, var='pr', redo=False, month=None):
    """Return cached trend fields for one model, computing the cache if needed.

    Parameters
    ----------
    model : str
        Model name passed to ``get_forced`` and used in the cache file names.
    var : str, optional
        Variable passed to ``get_forced`` and used in the cache file names.
        For 'monmaxpr' the data variable inside the files is 'pr'.
    redo : bool, optional
        If True, recompute and overwrite the cache for all months.
    month : int or None, optional
        Calendar month (1-12) to return, or None for the all-month result.

    Returns
    -------
    tuple
        ``(trend_forced_std, trend_std_std, trend_forced, trend_std)`` for the
        requested ``month``: forced trend and across-member trend spread of
        the standardized anomalies, then the same for the raw anomalies.

    Notes
    -----
    The result is always read back from the cache files for the requested
    ``month``. Previously a (re)compute returned whatever the last iteration
    of the monthly loop had produced (month 12), whatever ``month`` was.
    """
    path = '/p/lustre2/shiduan/ForceSMIP/LocalSNR/'
    vid = 'pr' if var == 'monmaxpr' else var
    kinds = ('trend_forced_std', 'trend_std_std', 'trend_forced', 'trend_std')

    def _cache_file(kind, m):
        # all-month files carry no suffix; monthly files end in _month<m>
        suffix = '' if m is None else '_month'+str(m)
        return path+model+'-'+var+'-'+kind+suffix+'.nc'

    cache_missing = any(not os.path.exists(_cache_file(kind, month)) for kind in kinds)
    if redo or cache_missing:
        data_raw = get_forced(var=var, std=False, models=[model])
        data_standardized = get_forced(var=var, std=True, models=[model])
        for m in [None] + list(range(1, 13)):
            forced_std, _, spread_std = get_snr(model, data_standardized, vid=vid, month=m)
            forced, _, spread = get_snr(model, data_raw, vid=vid, month=m)
            for kind, field in zip(kinds, (forced_std, spread_std, forced, spread)):
                field.to_netcdf(_cache_file(kind, m))
    # load_dataarray reads into memory and closes the file, so a later redo can overwrite it
    return tuple(xa.load_dataarray(_cache_file(kind, month)) for kind in kinds)

In [None]:
# CESM2, all months: SNR = |forced trend| / across-member trend spread
(trend_forced_std,
 trend_std_std,
 trend_forced,
 trend_std) = get_data('CESM2')
snr_std_cesm2_pr = np.abs(trend_forced_std) / trend_std_std  # standardized anomalies
snr_cesm2_pr = np.abs(trend_forced) / trend_std              # raw anomalies

In [None]:
# CanESM5, all months: SNR = |forced trend| / across-member trend spread
(trend_forced_std,
 trend_std_std,
 trend_forced,
 trend_std) = get_data('CanESM5')
snr_std_canesm5_pr = np.abs(trend_forced_std) / trend_std_std  # standardized anomalies
snr_canesm5_pr = np.abs(trend_forced) / trend_std              # raw anomalies

In [None]:
# MIROC6, all months: SNR = |forced trend| / across-member trend spread
(trend_forced_std,
 trend_std_std,
 trend_forced,
 trend_std) = get_data('MIROC6')
snr_std_miroc6_pr = np.abs(trend_forced_std) / trend_std_std  # standardized anomalies
snr_miroc6_pr = np.abs(trend_forced) / trend_std              # raw anomalies

In [None]:
# MPI-ESM1-2-LR, all months: SNR = |forced trend| / across-member trend spread
(trend_forced_std,
 trend_std_std,
 trend_forced,
 trend_std) = get_data('MPI-ESM1-2-LR')
snr_std_mpi_pr = np.abs(trend_forced_std) / trend_std_std  # standardized anomalies
snr_mpi_pr = np.abs(trend_forced) / trend_std              # raw anomalies

In [None]:
# MIROC-ES2L, all months: SNR = |forced trend| / across-member trend spread
(trend_forced_std,
 trend_std_std,
 trend_forced,
 trend_std) = get_data('MIROC-ES2L')
snr_std_miroces2l_pr = np.abs(trend_forced_std) / trend_std_std  # standardized anomalies
snr_miroces2l_pr = np.abs(trend_forced) / trend_std              # raw anomalies

In [None]:
import seaborn as sns

In [None]:
# Preview the "rocket_r" palette as a colormap; the value is only displayed as the
# cell output and is not assigned to a name.
sns.color_palette("rocket_r", as_cmap=True)

# Does standardization change the SNR?

In [None]:
def _monthly_snr_correlation(model_name):
    """Pearson r between the standardized-anomaly and raw-anomaly SNR maps, one value per month.

    Returns a list of 12 correlations ordered January to December. Each SNR
    map is |forced trend| divided by the across-member trend spread, flattened
    over all grid points.
    """
    monthly_r = []
    for month in range(1, 13):
        forced_std, spread_std, forced, spread = get_data(model_name, month=month)
        snr_raw = np.abs(forced)/spread
        snr_standardized = np.abs(forced_std)/spread_std
        r_value, _ = pearsonr(snr_standardized.data.flatten(), snr_raw.data.flatten())
        monthly_r.append(r_value)
    return monthly_r


# One call per model replaces five copy-pasted loops that all reused the
# `snr_cesm2_pr_m` names regardless of the model being processed.
cesm2_corr = _monthly_snr_correlation('CESM2')
canesm5_corr = _monthly_snr_correlation('CanESM5')
miroc6_corr = _monthly_snr_correlation('MIROC6')
miroces2l_corr = _monthly_snr_correlation('MIROC-ES2L')
mpi_corr = _monthly_snr_correlation('MPI-ESM1-2-LR')

In [None]:
# Heat map of the monthly correlations.
# Rows (top to bottom): CanESM5, CESM2, MPI-ESM1-2-LR, MIROC6, MIROC-ES2L; columns: Jan..Dec.
corrs = np.array([canesm5_corr, cesm2_corr, mpi_corr, miroc6_corr, miroces2l_corr])
print(corrs.shape)
fig, ax = plt.subplots()
image = ax.imshow(corrs)
fig.colorbar(image, ax=ax, shrink=.4)
ax.set_ylabel('Model')
plt.show()

In [None]:
# All-month (month=None) SNR maps per model, and the R^2 obtained when the
# raw-anomaly SNR is the reference (y_true) and the standardized-anomaly SNR
# is the prediction (y_pred). Only the CESM2 score is printed.

# CESM2
trend_forced_std, trend_std_std, trend_forced, trend_std = get_data('CESM2', month=None)
snr_std_cesm2_pr = np.abs(trend_forced_std) / trend_std_std
snr_cesm2_pr = np.abs(trend_forced) / trend_std
cesm2_r2 = r2_score(y_true=snr_cesm2_pr.data.flatten(),
                    y_pred=snr_std_cesm2_pr.data.flatten())
print(cesm2_r2)

# CanESM5
trend_forced_std, trend_std_std, trend_forced, trend_std = get_data('CanESM5', month=None)
snr_std_canesm_pr = np.abs(trend_forced_std) / trend_std_std
snr_canesm_pr = np.abs(trend_forced) / trend_std
canesm5_r2 = r2_score(y_true=snr_canesm_pr.data.flatten(),
                      y_pred=snr_std_canesm_pr.data.flatten())

# MIROC6
trend_forced_std, trend_std_std, trend_forced, trend_std = get_data('MIROC6', month=None)
snr_std_miroc6_pr = np.abs(trend_forced_std) / trend_std_std
snr_miroc6_pr = np.abs(trend_forced) / trend_std
miroc6_r2 = r2_score(y_true=snr_miroc6_pr.data.flatten(),
                     y_pred=snr_std_miroc6_pr.data.flatten())

# MPI-ESM1-2-LR
trend_forced_std, trend_std_std, trend_forced, trend_std = get_data('MPI-ESM1-2-LR', month=None)
snr_std_mpi_pr = np.abs(trend_forced_std) / trend_std_std
snr_mpi_pr = np.abs(trend_forced) / trend_std
mpi_r2 = r2_score(y_true=snr_mpi_pr.data.flatten(),
                  y_pred=snr_std_mpi_pr.data.flatten())

# MIROC-ES2L
trend_forced_std, trend_std_std, trend_forced, trend_std = get_data('MIROC-ES2L', month=None)
snr_std_miroces2l_pr = np.abs(trend_forced_std) / trend_std_std
snr_miroces2l_pr = np.abs(trend_forced) / trend_std
miroces2l_r2 = r2_score(y_true=snr_miroces2l_pr.data.flatten(),
                        y_pred=snr_std_miroces2l_pr.data.flatten())

In [None]:
def _monthly_snr_r2(model_name):
    """R^2 of the standardized-anomaly SNR map against the raw-anomaly SNR map, one value per month.

    Returns a list of 12 scores ordered January to December. The raw-anomaly
    SNR is passed to ``r2_score`` as the reference and the standardized one as
    the prediction.
    """
    monthly_r2 = []
    for month in range(1, 13):
        forced_std, spread_std, forced, spread = get_data(model_name, month=month)
        snr_raw = np.abs(forced)/spread
        snr_standardized = np.abs(forced_std)/spread_std
        monthly_r2.append(r2_score(snr_raw.data.flatten(), snr_standardized.data.flatten()))
    return monthly_r2


# One call per model replaces five copy-pasted loops that all reused the
# `snr_cesm2_pr_m` names regardless of the model being processed.
cesm2_r2s = _monthly_snr_r2('CESM2')
canesm5_r2s = _monthly_snr_r2('CanESM5')
miroc6_r2s = _monthly_snr_r2('MIROC6')
miroces2l_r2s = _monthly_snr_r2('MIROC-ES2L')
mpi_r2s = _monthly_snr_r2('MPI-ESM1-2-LR')

In [None]:
# Heat map of the monthly R^2 scores.
# Rows (top to bottom): CanESM5, CESM2, MPI-ESM1-2-LR, MIROC6, MIROC-ES2L; columns: Jan..Dec.
corrs = np.array([canesm5_r2s, cesm2_r2s, mpi_r2s, miroc6_r2s, miroces2l_r2s])
print(corrs.shape)
fig, ax = plt.subplots()
image = ax.imshow(corrs, cmap='Reds')
fig.colorbar(image, ax=ax, shrink=.4)
ax.set_ylabel('Model')
plt.show()

In [None]:
# Per-model bar chart of the monthly R^2 (bars) with the all-month R^2 as a
# horizontal line. One panel per model; the sixth, unused panel is removed.
fig, axes = plt.subplots(figsize=(12, 6), ncols=3, nrows=2, sharex=True, sharey=True)
panels = [
    ('CanESM5', canesm5_r2s, canesm5_r2),
    ('CESM2', cesm2_r2s, cesm2_r2),
    ('MIROC6', miroc6_r2s, miroc6_r2),
    ('MPI-ESM1-2-LR', mpi_r2s, mpi_r2),
    ('MIROC-ES2L', miroces2l_r2s, miroces2l_r2),
]
# the y axis is shared, so setting the limits on the first panel applies to all
axes.flatten()[0].set_ylim([0.85, 1.01])
for ax, (title, monthly_r2, all_month_r2) in zip(axes.flatten(), panels):
    ax.bar(x=np.arange(1, 13), height=monthly_r2, color='tab:blue')
    ax.axhline(all_month_r2, color='black')
    ax.set_title(title)
# the x axis is shared, so the ticks set on the last used panel apply to all
ax.set_xticks([1, 4, 7, 10])
ax.set_xticklabels(['Jan', 'Apr', 'Jul', 'Oct'])
fig.delaxes(axes.flatten()[-1])
# The legend describes only what is drawn here: a black line (all-month value)
# and blue bars (month-by-month values). The previous handles listed red and
# blue-line entries that do not exist in this figure.
custom_lines = [Line2D([0], [0], color='black', label='AllM'),
                Patch(facecolor='tab:blue', edgecolor='none',
                      label='MbyM')]
fig.legend(handles=custom_lines, bbox_to_anchor=[.87, 0.5])
plt.tight_layout()
plt.show()

# Correlation between EOF pattern and SNR

In [None]:
path = '/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/'+str(1979)+'_2022/'+'pr'+'-solver-stand-False-month-True-unforced-False-joint-False'
# NOTE(review): pickle.load can execute arbitrary code; only load solver files from a trusted source.
with open(path, 'rb') as pfile:
    solver = pickle.load(pfile)

# |EOF1| per calendar month, flattened in (lat, lon) order. The solver is the
# same for every model, so this is done once instead of once per model.
abs_eof1_by_month = []
for month in range(1, 13):
    pattern = solver[month-1].eofs().isel(mode=0)
    pc = solver[month-1].pcs().isel(mode=0)
    # `slope` rather than `m`, so the month index is not overwritten inside the loop
    slope, b = np.polyfit(np.arange(pc.shape[0]), pc, deg=1)
    if slope<0:
        # orient the EOF so its PC trends upward; this cannot change the
        # correlation below because only the absolute pattern is used
        pattern = -pattern
        print('flip', month)
    pattern = pattern.transpose('lat', 'lon')
    abs_eof1_by_month.append(np.abs(pattern.data.flatten()))


def _monthly_eof1_snr_correlation(model_name):
    """Pearson r between |EOF1| and the raw-anomaly SNR map of one model, for each month (Jan..Dec)."""
    monthly_r = []
    for month in range(1, 13):
        _, _, forced, spread = get_data(model_name, month=month)
        snr_raw = np.abs(forced)/spread
        r_value, _ = pearsonr(abs_eof1_by_month[month-1], snr_raw.data.flatten())
        monthly_r.append(r_value)
    return monthly_r


cesm2_corr = _monthly_eof1_snr_correlation('CESM2')
canesm5_corr = _monthly_eof1_snr_correlation('CanESM5')
miroc6_corr = _monthly_eof1_snr_correlation('MIROC6')
miroces2l_corr = _monthly_eof1_snr_correlation('MIROC-ES2L')
mpi_corr = _monthly_eof1_snr_correlation('MPI-ESM1-2-LR')

In [None]:
path = '/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/'+str(1979)+'_2022/'+'pr'+'-solver-stand-True-month-True-unforced-False-joint-False'
# NOTE(review): pickle.load can execute arbitrary code; only load solver files from a trusted source.
with open(path, 'rb') as pfile:
    solver = pickle.load(pfile)

# |EOF1| per calendar month, flattened in (lat, lon) order. The solver is the
# same for every model, so this is done once instead of once per model.
abs_eof1_std_by_month = []
for month in range(1, 13):
    pattern = solver[month-1].eofs().isel(mode=0)
    pc = solver[month-1].pcs().isel(mode=0)
    # `slope` rather than `m`, so the month index is not overwritten inside the loop
    slope, b = np.polyfit(np.arange(pc.shape[0]), pc, deg=1)
    if slope<0:
        # orient the EOF so its PC trends upward; this cannot change the
        # correlation below because only the absolute pattern is used
        pattern = -pattern
        print('flip', month)
    pattern = pattern.transpose('lat', 'lon')
    abs_eof1_std_by_month.append(np.abs(pattern.data.flatten()))


def _monthly_eof1_std_snr_correlation(model_name):
    """Pearson r between |EOF1| and the standardized-anomaly SNR map of one model, for each month (Jan..Dec)."""
    monthly_r = []
    for month in range(1, 13):
        forced_std, spread_std, _, _ = get_data(model_name, month=month)
        snr_standardized = np.abs(forced_std)/spread_std
        r_value, _ = pearsonr(abs_eof1_std_by_month[month-1], snr_standardized.data.flatten())
        monthly_r.append(r_value)
    return monthly_r


cesm2_corr_std = _monthly_eof1_std_snr_correlation('CESM2')
canesm5_corr_std = _monthly_eof1_std_snr_correlation('CanESM5')
miroc6_corr_std = _monthly_eof1_std_snr_correlation('MIROC6')
miroces2l_corr_std = _monthly_eof1_std_snr_correlation('MIROC-ES2L')
mpi_corr_std = _monthly_eof1_std_snr_correlation('MPI-ESM1-2-LR')

In [None]:
# Two heat maps of the monthly pattern/SNR correlations: raw anomalies first,
# standardized anomalies second.
# Rows (top to bottom): CanESM5, CESM2, MPI-ESM1-2-LR, MIROC6, MIROC-ES2L; columns: Jan..Dec.
corrs = np.array([canesm5_corr, cesm2_corr, mpi_corr, miroc6_corr, miroces2l_corr])
print(corrs.shape)
fig, ax = plt.subplots()
image = ax.imshow(corrs)
fig.colorbar(image, ax=ax, shrink=.4)
ax.set_ylabel('Model')
plt.show()

corrs_std = np.array([canesm5_corr_std, cesm2_corr_std, mpi_corr_std, miroc6_corr_std, miroces2l_corr_std])
fig, ax = plt.subplots()
image = ax.imshow(corrs_std)
fig.colorbar(image, ax=ax, shrink=.4)
ax.set_ylabel('Model')
plt.show()

In [None]:
path = ('/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/'+str(1979)+'_2022/'+'pr'+
        '-solver-stand-False-month-False-unforced-False-joint-False')
with open(path, 'rb') as pfile:
    solver = pickle.load(pfile)
# EOF with mode index 1 of the all-month solver, laid out as (lat, lon)
pattern = solver[0].eofs().isel(mode=1).transpose('lat', 'lon')
abs_pattern = np.abs(pattern.data.flatten())
# print the correlation of |EOF| with each model's all-month SNR map, in this order:
# CESM2, CanESM5, MIROC6, MPI-ESM1-2-LR, MIROC-ES2L
for snr_field in (snr_cesm2_pr, snr_canesm5_pr, snr_miroc6_pr, snr_mpi_pr, snr_miroces2l_pr):
    r, p = pearsonr(abs_pattern, snr_field.data.flatten())
    print(r)

In [None]:
# Quick-look plot of whatever field the name `pattern` currently holds.
pattern.plot()

In [None]:
path = '/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/'+str(1979)+'_2022/'+'pr'+'-solver-stand-False-month-False-unforced-False-joint-False'
with open(path, 'rb') as pfile:
    solver = pickle.load(pfile)
# leading EOF and PC of the all-month solver
pc = solver[0].pcs().isel(mode=0)
pattern = solver[0].eofs().isel(mode=0)
# flip the EOF sign when its PC has a negative linear trend
m, b = np.polyfit(np.arange(pc.shape[0]), pc, deg=1)
if m<0:
    print('flip')
    pattern = -pattern
pattern = pattern.transpose('lat', 'lon')
abs_pattern = np.abs(pattern.data.flatten())
# correlation of |EOF| with each model's all-month raw-anomaly SNR map
cesm2_r, p = pearsonr(abs_pattern, snr_cesm2_pr.data.flatten())
canesm5_r, p = pearsonr(abs_pattern, snr_canesm5_pr.data.flatten())
miroc6_r, p = pearsonr(abs_pattern, snr_miroc6_pr.data.flatten())
mpi_r, p = pearsonr(abs_pattern, snr_mpi_pr.data.flatten())
miroces2l_r, p = pearsonr(abs_pattern, snr_miroces2l_pr.data.flatten())
for r_value in (cesm2_r, canesm5_r, miroc6_r, mpi_r, miroces2l_r):
    print(r_value)

In [None]:
path = '/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/'+str(1979)+'_2022/'+'pr'+'-solver-stand-True-month-False-unforced-False-joint-False'
with open(path, 'rb') as pfile:
    solver = pickle.load(pfile)
# leading EOF and PC of the all-month solver
pc = solver[0].pcs().isel(mode=0)
pattern = solver[0].eofs().isel(mode=0)
# flip the EOF sign when its PC has a negative linear trend
m, b = np.polyfit(np.arange(pc.shape[0]), pc, deg=1)
if m<0:
    print('flip')
    pattern = -pattern
pattern = pattern.transpose('lat', 'lon')
abs_pattern = np.abs(pattern.data.flatten())
# correlation of |EOF| with each model's all-month standardized-anomaly SNR map
cesm2_std_r, p = pearsonr(abs_pattern, snr_std_cesm2_pr.data.flatten())
canesm5_std_r, p = pearsonr(abs_pattern, snr_std_canesm5_pr.data.flatten())
miroc6_std_r, p = pearsonr(abs_pattern, snr_std_miroc6_pr.data.flatten())
mpi_std_r, p = pearsonr(abs_pattern, snr_std_mpi_pr.data.flatten())
miroces2l_std_r, p = pearsonr(abs_pattern, snr_std_miroces2l_pr.data.flatten())
for r_value in (cesm2_std_r, canesm5_std_r, miroc6_std_r, mpi_std_r, miroces2l_std_r):
    print(r_value)

In [None]:
# Per-model panels: month-by-month correlations as bars and all-month
# correlations as horizontal lines, blue for raw and red for standardized
# anomalies. The sixth, unused panel is removed.
fig, axes = plt.subplots(figsize=(12, 6), ncols=3, nrows=2, sharex=True, sharey=True)
month_positions = np.arange(1, 13)
# (title, monthly standardized, monthly raw, all-month raw, all-month standardized)
panels = (
    ('CanESM5', canesm5_corr_std, canesm5_corr, canesm5_r, canesm5_std_r),
    ('CESM2', cesm2_corr_std, cesm2_corr, cesm2_r, cesm2_std_r),
    ('MIROC6', miroc6_corr_std, miroc6_corr, miroc6_r, miroc6_std_r),
    ('MPI-ESM1-2-LR', mpi_corr_std, mpi_corr, mpi_r, mpi_std_r),
    ('MIROC-ES2L', miroces2l_corr_std, miroces2l_corr, miroces2l_r, miroces2l_std_r),
)
for ax, (title, monthly_std, monthly_raw, all_month_raw, all_month_std) in zip(axes.flatten(), panels):
    ax.axhline(0, color='black', linestyle='-.')
    ax.bar(x=month_positions, height=monthly_std, color='tab:red', alpha=.8)
    ax.bar(x=month_positions, height=monthly_raw, color='tab:blue')
    ax.axhline(all_month_raw, color='tab:blue', linewidth=2)
    ax.axhline(all_month_std, color='tab:red', linewidth=2)
    ax.set_title(title)
# `ax` is now the fifth panel; the x axis is shared, so these ticks apply to every panel
ax.set_xticks([1, 4, 7, 10])
ax.set_xticklabels(['Jan', 'Apr', 'Jul', 'Oct'])
fig.delaxes(axes.flatten()[-1])
custom_lines = [
    Line2D([0], [0], color='tab:blue', lw=2, label='AllM Anomaly'),
    Line2D([0], [0], color='tab:red', lw=2, label='AllM StdAnomaly'),
    Patch(facecolor='tab:blue', edgecolor='none', label='MbyM Anomaly'),
    Patch(facecolor='tab:red', edgecolor='none', label='MbyM StdAnomaly'),
]
fig.legend(handles=custom_lines, bbox_to_anchor=[.87, 0.5])
plt.tight_layout()
plt.show()

# Second EOF

## AllM

In [None]:
path = '/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/'+str(1979)+'_2022/'+'pr'+'-solver-stand-False-month-False-unforced-False-joint-False'
with open(path, 'rb') as pfile:
    solver = pickle.load(pfile)
# second EOF and PC (mode index 1) of the all-month solver
pc = solver[0].pcs().isel(mode=1)
pattern = solver[0].eofs().isel(mode=1)
# flip the EOF sign when its PC has a negative linear trend
m, b = np.polyfit(np.arange(pc.shape[0]), pc, deg=1)
if m<0:
    pattern = -pattern
    print('flip')
pattern = pattern.transpose('lat', 'lon')
abs_pattern = np.abs(pattern.data.flatten())
# correlation of |EOF| with each model's all-month raw-anomaly SNR map
cesm2_r_eof2, p = pearsonr(abs_pattern, snr_cesm2_pr.data.flatten())
canesm5_r_eof2, p = pearsonr(abs_pattern, snr_canesm5_pr.data.flatten())
miroc6_r_eof2, p = pearsonr(abs_pattern, snr_miroc6_pr.data.flatten())
mpi_r_eof2, p = pearsonr(abs_pattern, snr_mpi_pr.data.flatten())
miroces2l_r_eof2, p = pearsonr(abs_pattern, snr_miroces2l_pr.data.flatten())
for r_value in (cesm2_r_eof2, canesm5_r_eof2, miroc6_r_eof2, mpi_r_eof2, miroces2l_r_eof2):
    print(r_value)

In [None]:
path = '/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/'+str(1979)+'_2022/'+'pr'+'-solver-stand-True-month-False-unforced-False-joint-False'
with open(path, 'rb') as pfile:
    solver = pickle.load(pfile)
# second EOF (mode index 1) of the all-month solver, laid out as (lat, lon), and its PC
pattern = solver[0].eofs().isel(mode=1).transpose('lat', 'lon')
pc = solver[0].pcs().isel(mode=1)
# flip the EOF sign when its PC has a negative linear trend
m, b = np.polyfit(np.arange(pc.shape[0]), pc, deg=1)
if m<0:
    pattern = -pattern
    print('flip')
abs_pattern = np.abs(pattern.data.flatten())
# correlation of |EOF| with each model's all-month standardized-anomaly SNR map
cesm2_std_r_eof2, p = pearsonr(abs_pattern, snr_std_cesm2_pr.data.flatten())
canesm5_std_r_eof2, p = pearsonr(abs_pattern, snr_std_canesm5_pr.data.flatten())
miroc6_std_r_eof2, p = pearsonr(abs_pattern, snr_std_miroc6_pr.data.flatten())
mpi_std_r_eof2, p = pearsonr(abs_pattern, snr_std_mpi_pr.data.flatten())
miroces2l_std_r_eof2, p = pearsonr(abs_pattern, snr_std_miroces2l_pr.data.flatten())
for r_value in (cesm2_std_r_eof2, canesm5_std_r_eof2, miroc6_std_r_eof2,
                mpi_std_r_eof2, miroces2l_std_r_eof2):
    print(r_value)

## MbyM

In [None]:
# Load the pickled EOF solvers fitted on standardized precipitation anomalies.
# `solver` is indexed by calendar month below (solver[0] is used for m=1, ...).
# SECURITY NOTE: pickle.load can execute arbitrary code stored in the file;
# only load solver files that you produced yourself or otherwise trust.
path = '/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/'+str(1979)+'_2022/'+'pr'+'-solver-stand-True-month-True-unforced-False-joint-False'
with open(path, 'rb') as pfile:
    solver = pickle.load(pfile)

# For every model and calendar month, correlate the absolute value of the
# second EOF (mode index 1) with the local signal-to-noise ratio built from the
# first two fields returned by get_data (the `*_std` pair).
# The models are processed in the same order as the original copy-pasted loops,
# so the printed 'flip' messages appear in the same sequence.
eof2_corr_std_by_model = {}
for model_name in ('CESM2', 'CanESM5', 'MIROC6', 'MIROC-ES2L', 'MPI-ESM1-2-LR'):
    monthly_corr = []
    for m in range(1, 13):
        trend_forced_std_m, trend_std_std_m, trend_forced_m, trend_std_m = get_data(model_name, month=m)
        # Local SNR for this model/month: |forced trend| / trend spread.
        snr_std_pr_m = np.abs(trend_forced_std_m)/trend_std_std_m
        pattern = solver[m-1].eofs().isel(mode=1)
        pc = solver[m-1].pcs().isel(mode=1)
        # Orient the mode so that its principal component has a non-negative
        # linear trend. The correlation below uses |pattern|, so the flip does
        # not change `r`; it is kept for the diagnostic print and so that
        # `pattern` carries a consistent sign.
        slope, b = np.polyfit(np.arange(pc.shape[0]), pc, deg=1)
        if slope<0:
            pattern = -pattern
            print('flip', m)
        pattern = pattern.transpose('lat', 'lon')
        # NOTE(review): assumes the SNR field is also ordered (lat, lon) on the
        # same grid, otherwise the flattened arrays do not line up - confirm.
        r, p = pearsonr(np.abs(pattern.data.flatten()), snr_std_pr_m.data.flatten())
        monthly_corr.append(r)
    eof2_corr_std_by_model[model_name] = monthly_corr

# Per-model lists of 12 monthly correlation coefficients (names unchanged).
cesm2_corr_std_eof2 = eof2_corr_std_by_model['CESM2']
canesm5_corr_std_eof2 = eof2_corr_std_by_model['CanESM5']
miroc6_corr_std_eof2 = eof2_corr_std_by_model['MIROC6']
miroces2l_corr_std_eof2 = eof2_corr_std_by_model['MIROC-ES2L']
mpi_corr_std_eof2 = eof2_corr_std_by_model['MPI-ESM1-2-LR']

In [None]:
# Load the pickled EOF solvers fitted on raw (non-standardized) precipitation
# anomalies. `solver` is indexed by calendar month below (solver[m-1]).
# SECURITY NOTE: pickle.load can execute arbitrary code stored in the file;
# only load solver files that you produced yourself or otherwise trust.
path = '/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/'+str(1979)+'_2022/'+'pr'+'-solver-stand-False-month-True-unforced-False-joint-False'
with open(path, 'rb') as pfile:
    solver = pickle.load(pfile)

# For every model and calendar month, correlate the absolute value of the
# second EOF (mode index 1) with the local signal-to-noise ratio built from the
# last two fields returned by get_data (the non-`_std` pair).
# The models are processed in the same order as the original copy-pasted loops,
# so the printed 'flip' messages appear in the same sequence.
eof2_corr_by_model = {}
for model_name in ('CESM2', 'CanESM5', 'MIROC6', 'MIROC-ES2L', 'MPI-ESM1-2-LR'):
    monthly_corr = []
    for m in range(1, 13):
        trend_forced_std_m, trend_std_std_m, trend_forced_m, trend_std_m = get_data(model_name, month=m)
        # Local SNR for this model/month: |forced trend| / trend spread.
        snr_pr_m = np.abs(trend_forced_m)/trend_std_m
        pattern = solver[m-1].eofs().isel(mode=1)
        pc = solver[m-1].pcs().isel(mode=1)
        # Orient the mode so that its principal component has a non-negative
        # linear trend. The correlation below uses |pattern|, so the flip does
        # not change `r`; it is kept for the diagnostic print and so that
        # `pattern` carries a consistent sign.
        slope, b = np.polyfit(np.arange(pc.shape[0]), pc, deg=1)
        if slope<0:
            pattern = -pattern
            print('flip', m)
        pattern = pattern.transpose('lat', 'lon')
        # NOTE(review): assumes the SNR field is also ordered (lat, lon) on the
        # same grid, otherwise the flattened arrays do not line up - confirm.
        r, p = pearsonr(np.abs(pattern.data.flatten()), snr_pr_m.data.flatten())
        monthly_corr.append(r)
    eof2_corr_by_model[model_name] = monthly_corr

# Per-model lists of 12 monthly correlation coefficients (names unchanged).
cesm2_corr_eof2 = eof2_corr_by_model['CESM2']
canesm5_corr_eof2 = eof2_corr_by_model['CanESM5']
miroc6_corr_eof2 = eof2_corr_by_model['MIROC6']
miroces2l_corr_eof2 = eof2_corr_by_model['MIROC-ES2L']
mpi_corr_eof2 = eof2_corr_by_model['MPI-ESM1-2-LR']

# Month tick positions used in the figures below: 1, 4, 7, 10 (Jan, Apr, Jul, Oct)

In [None]:
def plot_snr_spatial(ax1, ax2, snr1, snr2, model_name, cmap, norm):
    """Draw two SNR maps for one model and annotate them.

    Parameters
    ----------
    ax1, ax2 : map axes that receive `snr1` and `snr2` respectively.
    snr1, snr2 : 2-D fields exposing `.lon`, `.lat` and `.data`.
    model_name : title placed on `ax1`.
    cmap, norm : colormap and normalization shared by both filled-contour plots.

    The title of `ax2` shows `r2_score(snr1, snr2)` (computed on the flattened
    data, with `snr1` passed as the first argument) rounded to two decimals.
    """
    lonlat_crs = cartopy.crs.PlateCarree()
    for axis, field in ((ax1, snr1), (ax2, snr2)):
        axis.contourf(
            field.lon, field.lat, field,
            transform=lonlat_crs, cmap=cmap, norm=norm, extend='both',
        )

    ax1.set_title(model_name)
    score = r2_score(snr1.data.flatten(), snr2.data.flatten())
    ax2.set_title(r'$R^2=$' + str(round(score, 2)))

    # Semi-transparent coastlines on both maps.
    for axis in (ax1, ax2):
        axis.add_feature(cartopy.feature.COASTLINE, alpha=.5)

In [None]:
# Eight-color seaborn "muted" palette. `sns_colors` is a notebook-level name
# that the plotting helpers and legend handles below index into.
sns_colors = sns.color_palette("muted", 8)
# Bare expression: shown as the cell output.
sns_colors

In [None]:
def plot_correlations(corr, corr_std, corr_month, corr_month_std, corr_eof2, corr_month_eof2, ax):
    """Draw monthly correlation bars and reference lines on `ax`.

    Parameters
    ----------
    corr, corr_std, corr_eof2 : scalar levels drawn as horizontal lines
        (colors sns_colors[0], sns_colors[1], sns_colors[-1]; the last dashed).
    corr_month, corr_month_std, corr_month_eof2 : 12 values drawn as bars at
        x = 1..12 in the matching colors.
    ax : axes to draw on.

    Reads the notebook-level palette `sns_colors`. A dotted black line marks zero.
    """
    month_positions = np.arange(1, 13)

    # Drawing order matters: later bar series are painted over earlier ones.
    bar_series = (
        (corr_month_std, sns_colors[1], .8),
        (corr_month, sns_colors[0], .8),
        (corr_month_eof2, sns_colors[-1], .5),
    )
    for heights, shade, opacity in bar_series:
        ax.bar(x=month_positions, height=heights, color=shade, alpha=opacity)

    reference_lines = (
        (corr, dict(color=sns_colors[0], linewidth=2)),
        (corr_std, dict(color=sns_colors[1], linewidth=2)),
        (corr_eof2, dict(color=sns_colors[-1], linewidth=2, linestyle='--')),
        (0, dict(color='black', linestyle='dotted')),
    )
    for level, line_style in reference_lines:
        ax.axhline(level, **line_style)
    

In [None]:
# Legend handles for the bar series of the correlation panels.
# The Line2D handles for the horizontal lines (S-INV Anomaly EOF1 / EOF2 and
# S-INV StdAnomaly) were commented out in the original cell and stay disabled.
custom_lines = [
    Patch(label='RawAnomaly EOF1', facecolor=sns_colors[0], edgecolor='none', alpha=.8),
    Patch(label='RawAnomaly EOF2', facecolor=sns_colors[-1], edgecolor='none', alpha=.5, hatch='//'),
    Patch(label='StdAnomaly', facecolor=sns_colors[1], edgecolor='none', alpha=.8),
]

In [None]:
# Summary figure with one column per model and four populated rows:
#   a) map of the SNR field            (gs row 0)
#   b) map of the standardized SNR     (gs row 1, titled with R^2 against row a)
#   c) month-by-month R^2(SNR, SNR-std) bars (gs row 2)
#   d) correlation between fingerprint and SNR (gs row 4)
# Grid row 3 (height ratio 0.01) is never populated; it only separates c and d.
fig = plt.figure(figsize=(16, 9))
cmap = sns.color_palette("rocket_r", as_cmap=True)
color = sns.color_palette("muted", 8)[-3]
norm_pr = colors.Normalize(vmin=0, vmax=5)
gs = GridSpec(5, 5, figure=fig, wspace=0.05, height_ratios=[1.3, 1.3, 1, 0.01, 1])

# Everything that differs between the five columns, in left-to-right order.
# `corr` holds the keyword arguments forwarded to plot_correlations.
model_panels = [
    dict(name='CanESM5', snr=snr_canesm5_pr, snr_std=snr_std_canesm5_pr, r2s=canesm5_r2s,
         corr=dict(corr=canesm5_r, corr_std=canesm5_std_r, corr_month=canesm5_corr,
                   corr_month_std=canesm5_corr_std, corr_eof2=canesm5_r_eof2,
                   corr_month_eof2=canesm5_corr_eof2)),
    dict(name='CESM2', snr=snr_cesm2_pr, snr_std=snr_std_cesm2_pr, r2s=cesm2_r2s,
         corr=dict(corr=cesm2_r, corr_std=cesm2_std_r, corr_month=cesm2_corr,
                   corr_month_std=cesm2_corr_std, corr_eof2=cesm2_r_eof2,
                   corr_month_eof2=cesm2_corr_eof2)),
    dict(name='MIROC6', snr=snr_miroc6_pr, snr_std=snr_std_miroc6_pr, r2s=miroc6_r2s,
         corr=dict(corr=miroc6_r, corr_std=miroc6_std_r, corr_month=miroc6_corr,
                   corr_month_std=miroc6_corr_std, corr_eof2=miroc6_r_eof2,
                   corr_month_eof2=miroc6_corr_eof2)),
    dict(name='MPI-ESM1-2-LR', snr=snr_mpi_pr, snr_std=snr_std_mpi_pr, r2s=mpi_r2s,
         corr=dict(corr=mpi_r, corr_std=mpi_std_r, corr_month=mpi_corr,
                   corr_month_std=mpi_corr_std, corr_eof2=mpi_r_eof2,
                   corr_month_eof2=mpi_corr_eof2)),
    dict(name='MIROC-ES2L', snr=snr_miroces2l_pr, snr_std=snr_std_miroces2l_pr, r2s=miroces2l_r2s,
         corr=dict(corr=miroces2l_r, corr_std=miroces2l_std_r, corr_month=miroces2l_corr,
                   corr_month_std=miroces2l_corr_std, corr_eof2=miroces2l_r_eof2,
                   corr_month_eof2=miroces2l_corr_eof2)),
]

# Rows a and b: SNR maps on Robinson projections centered on 180E.
for col, panel in enumerate(model_panels):
    ax1 = fig.add_subplot(gs[0, col], projection=cartopy.crs.Robinson(central_longitude=180))
    ax2 = fig.add_subplot(gs[1, col], projection=cartopy.crs.Robinson(central_longitude=180))
    plot_snr_spatial(ax1, ax2, panel['snr'], panel['snr_std'], panel['name'], cmap, norm_pr)

# Row c: MbyM r2(SNR, SNR-std). The first column owns ticks, labels and limits;
# the remaining columns share both of its axes and hide their y tick labels.
for col, panel in enumerate(model_panels):
    if col == 0:
        ax = fig.add_subplot(gs[2, col])
        ax.bar(x=np.arange(1, 13), height=panel['r2s'], color=color, alpha=.8)
        ax.set_yticks([.9, .95, 1])
        ax.set_yticklabels([.9, .95, 1])
        ax.set_ylabel(r'$R^2$')
        ax.set_ylim([.89, 1.02])
        ax.set_xticks([1, 4, 7, 10])
        ax.set_xticklabels(['Jan', 'Apr', 'Jul', 'Oct'])
    else:
        ax1 = fig.add_subplot(gs[2, col], sharey=ax, sharex=ax)
        ax1.bar(x=np.arange(1, 13), height=panel['r2s'], color=color, alpha=.8)
        plt.setp(ax1.get_yticklabels(), visible=False)

# Row d: correlation between fingerprint and SNR, same sharing scheme as row c.
# `axes` ends up holding the five row-d axes, as in the original cell.
axes = []
for col, panel in enumerate(model_panels):
    if col == 0:
        ax = fig.add_subplot(gs[4, col])
        axes.append(ax)
        plot_correlations(ax=ax, **panel['corr'])
        ax.set_yticks([-.4, 0, .4])
        ax.set_yticklabels([-.4, 0, .4])
        ax.set_ylabel('Correlation\nCoefficients')
        ax.set_xticks([1, 4, 7, 10])
        ax.set_xticklabels(['Jan', 'Apr', 'Jul', 'Oct'])
    else:
        ax1 = fig.add_subplot(gs[4, col], sharey=ax, sharex=ax)
        axes.append(ax1)
        plot_correlations(ax=ax1, **panel['corr'])
        plt.setp(ax1.get_yticklabels(), visible=False)

plt.tight_layout()

# Shared colorbar for the maps, placed by hand to the right of rows a/b.
cbar_ax = fig.add_axes([.92, 0.6, 0.01, 0.2])
cb = fig.colorbar(mappable=mpl.cm.ScalarMappable(norm=norm_pr, cmap=cmap), cax=cbar_ax, extend='both', label='SNR')
cbar_ax.yaxis.set_ticks_position('right')

# Row labels, positioned in figure-fraction coordinates.
for row_label, y_pos in (('a', .79), ('b', .58), ('c', .39), ('d', .19)):
    plt.annotate(row_label, xy=(.06, y_pos), xycoords='figure fraction', fontsize=12, weight='bold')
fig.legend(handles=custom_lines, bbox_to_anchor=[.91, 0.12], loc='lower left')

# gs.tight_layout(fig)
# Saved under its own name: the simplified figure later in this notebook writes
# 'localSNR-pattern.png', and sharing that name made it overwrite this figure.
plt.savefig('localSNR-pattern-full.png', bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# Quick-look panel for CanESM5 only: monthly correlation bars plus the
# corresponding scalar levels as horizontal lines. The `*_std_eof2` series
# (canesm5_corr_std_eof2 / canesm5_std_r_eof2) was commented out in the
# original cell and is not drawn here either.
fig, ax = plt.subplots()
month_positions = np.arange(1, 13)
for heights, shade in (
    (canesm5_corr_std, sns_colors[0]),
    (canesm5_corr, sns_colors[1]),
    (canesm5_corr_eof2, sns_colors[2]),
):
    ax.bar(x=month_positions, height=heights, color=shade, alpha=.8)
ax.axhline(canesm5_r, linewidth=2, color=sns_colors[1])
ax.axhline(canesm5_std_r, linewidth=2, color=sns_colors[0])
ax.axhline(canesm5_r_eof2, linewidth=2, linestyle='--', color=sns_colors[2])
# Zero reference line; kept as the final bare expression of the cell.
ax.axhline(0, linestyle='dotted', color='black')

# Simplified figure: same layout, without the monthly EOF2 bars

In [None]:
# Re-display the shared palette (bare expression, shown as the cell output).
sns_colors

In [None]:
def plot_correlations_simple(corr, corr_std, corr_month, corr_month_std, corr_eof2, corr_month_eof2, ax):
    """Simplified correlation panel: two bar series and three reference lines.

    Parameters
    ----------
    corr, corr_std, corr_eof2 : scalar levels drawn as solid horizontal lines
        in sns_colors[-1], sns_colors[1] and sns_colors[0] respectively; each
        value is also printed.
    corr_month, corr_month_std : 12 values drawn as bars at x = 1..12
        (sns_colors[0] and sns_colors[1]).
    corr_month_eof2 : accepted so the signature matches plot_correlations,
        but not drawn.
    ax : axes to draw on.

    Reads the notebook-level palette `sns_colors`. A dotted black line marks zero.
    """
    month_positions = np.arange(1, 13)

    # Standardized series first so the raw series is painted on top of it.
    for heights, shade in ((corr_month_std, sns_colors[1]), (corr_month, sns_colors[0])):
        ax.bar(x=month_positions, height=heights, color=shade, alpha=.8)

    for level, shade in ((corr, sns_colors[-1]), (corr_std, sns_colors[1]), (corr_eof2, sns_colors[0])):
        ax.axhline(level, color=shade, linewidth=2)
    ax.axhline(0, color='black', linestyle='dotted')

    # Echo the scalar correlations for quick inspection.
    for tag, value in (('corr: ', corr), ('corr_std: ', corr_std), ('corr_eof2: ', corr_eof2)):
        print(tag, value)
    

In [None]:
# Legend handles for the simplified figure: three line handles for the
# horizontal levels, then two patch handles for the monthly bar series.
# The 'RawAnomaly EOF2' patch was commented out in the original cell and
# stays disabled.
custom_lines = [
    Line2D([0], [0], label=r'$\mathrm{EOF}_{1, \mathrm{S-INV}}$', color=sns_colors[-1], lw=2),
    Line2D([0], [0], label=r'$\mathrm{EOF}_{2, \mathrm{S-INV}}$', color=sns_colors[0], lw=2),
    Line2D([0], [0], label=r'$\mathrm{EOF}_{1, \mathrm{S-INV, StdA}}$', color=sns_colors[1], lw=2),
    Patch(label=r'$\mathrm{EOF}_{1, \mathrm{S-VAR}}$', facecolor=sns_colors[0], edgecolor='none', alpha=.8),
    Patch(label=r'$\mathrm{EOF}_{1, \mathrm{S-VAR, StdA}}$', facecolor=sns_colors[1], edgecolor='none', alpha=.8),
]

In [None]:
# Simplified summary figure with one column per model and four populated rows:
#   a) map of the SNR field            (gs row 0)
#   b) map of the standardized SNR     (gs row 1, titled with R^2 against row a)
#   c) month-by-month R^2(SNR, SNR-std) bars (gs row 2)
#   d) correlation between fingerprint and SNR (gs row 4), drawn with
#      plot_correlations_simple, which also prints the scalar correlations
# Grid row 3 (height ratio 0.01) is never populated; it only separates c and d.
fig = plt.figure(figsize=(16, 9))
cmap = sns.color_palette("rocket_r", as_cmap=True)
color = sns.color_palette("muted", 8)[-3]
norm_pr = colors.Normalize(vmin=0, vmax=5)
gs = GridSpec(5, 5, figure=fig, wspace=0.05, height_ratios=[1.3, 1.3, 1, 0.01, 1])

# Everything that differs between the five columns, in left-to-right order.
# `corr` holds the keyword arguments forwarded to plot_correlations_simple.
model_panels = [
    dict(name='CanESM5', snr=snr_canesm5_pr, snr_std=snr_std_canesm5_pr, r2s=canesm5_r2s,
         corr=dict(corr=canesm5_r, corr_std=canesm5_std_r, corr_month=canesm5_corr,
                   corr_month_std=canesm5_corr_std, corr_eof2=canesm5_r_eof2,
                   corr_month_eof2=canesm5_corr_eof2)),
    dict(name='CESM2', snr=snr_cesm2_pr, snr_std=snr_std_cesm2_pr, r2s=cesm2_r2s,
         corr=dict(corr=cesm2_r, corr_std=cesm2_std_r, corr_month=cesm2_corr,
                   corr_month_std=cesm2_corr_std, corr_eof2=cesm2_r_eof2,
                   corr_month_eof2=cesm2_corr_eof2)),
    dict(name='MIROC6', snr=snr_miroc6_pr, snr_std=snr_std_miroc6_pr, r2s=miroc6_r2s,
         corr=dict(corr=miroc6_r, corr_std=miroc6_std_r, corr_month=miroc6_corr,
                   corr_month_std=miroc6_corr_std, corr_eof2=miroc6_r_eof2,
                   corr_month_eof2=miroc6_corr_eof2)),
    dict(name='MPI-ESM1-2-LR', snr=snr_mpi_pr, snr_std=snr_std_mpi_pr, r2s=mpi_r2s,
         corr=dict(corr=mpi_r, corr_std=mpi_std_r, corr_month=mpi_corr,
                   corr_month_std=mpi_corr_std, corr_eof2=mpi_r_eof2,
                   corr_month_eof2=mpi_corr_eof2)),
    dict(name='MIROC-ES2L', snr=snr_miroces2l_pr, snr_std=snr_std_miroces2l_pr, r2s=miroces2l_r2s,
         corr=dict(corr=miroces2l_r, corr_std=miroces2l_std_r, corr_month=miroces2l_corr,
                   corr_month_std=miroces2l_corr_std, corr_eof2=miroces2l_r_eof2,
                   corr_month_eof2=miroces2l_corr_eof2)),
]

# Rows a and b: SNR maps on Robinson projections centered on 180E.
for col, panel in enumerate(model_panels):
    ax1 = fig.add_subplot(gs[0, col], projection=cartopy.crs.Robinson(central_longitude=180))
    ax2 = fig.add_subplot(gs[1, col], projection=cartopy.crs.Robinson(central_longitude=180))
    plot_snr_spatial(ax1, ax2, panel['snr'], panel['snr_std'], panel['name'], cmap, norm_pr)

# Row c: MbyM r2(SNR, SNR-std). The first column owns ticks, labels and limits;
# the remaining columns share both of its axes and hide their y tick labels.
for col, panel in enumerate(model_panels):
    if col == 0:
        ax = fig.add_subplot(gs[2, col])
        ax.bar(x=np.arange(1, 13), height=panel['r2s'], color=color, alpha=.8)
        ax.set_yticks([.9, .95, 1])
        ax.set_yticklabels([.9, .95, 1])
        ax.set_ylabel(r'$R^2$')
        ax.set_ylim([.89, 1.02])
        ax.set_xticks([1, 4, 7, 10])
        ax.set_xticklabels(['Jan', 'Apr', 'Jul', 'Oct'])
    else:
        ax1 = fig.add_subplot(gs[2, col], sharey=ax, sharex=ax)
        ax1.bar(x=np.arange(1, 13), height=panel['r2s'], color=color, alpha=.8)
        plt.setp(ax1.get_yticklabels(), visible=False)

# Row d: correlation between fingerprint and SNR, same sharing scheme as row c.
# `axes` ends up holding the five row-d axes, as in the original cell.
axes = []
for col, panel in enumerate(model_panels):
    if col == 0:
        ax = fig.add_subplot(gs[4, col])
        axes.append(ax)
        plot_correlations_simple(ax=ax, **panel['corr'])
        ax.set_yticks([-.4, 0, .4])
        ax.set_yticklabels([-.4, 0, .4])
        ax.set_ylabel('Correlation\nCoefficients')
        ax.set_xticks([1, 4, 7, 10])
        ax.set_xticklabels(['Jan', 'Apr', 'Jul', 'Oct'])
    else:
        ax1 = fig.add_subplot(gs[4, col], sharey=ax, sharex=ax)
        axes.append(ax1)
        plot_correlations_simple(ax=ax1, **panel['corr'])
        plt.setp(ax1.get_yticklabels(), visible=False)

plt.tight_layout()

# Shared colorbar for the maps, placed by hand to the right of rows a/b.
cbar_ax = fig.add_axes([.92, 0.6, 0.01, 0.2])
cb = fig.colorbar(mappable=mpl.cm.ScalarMappable(norm=norm_pr, cmap=cmap), cax=cbar_ax, extend='both', label='SNR')
cbar_ax.yaxis.set_ticks_position('right')

# Row labels, positioned in figure-fraction coordinates.
for row_label, y_pos in (('a', .79), ('b', .58), ('c', .39), ('d', .19)):
    plt.annotate(row_label, xy=(.06, y_pos), xycoords='figure fraction', fontsize=12, weight='bold')
fig.legend(handles=custom_lines, bbox_to_anchor=[.91, 0.09], loc='lower left')

# gs.tight_layout(fig)
plt.savefig('localSNR-pattern.png', bbox_inches='tight', dpi=300)
plt.show()