# PCA Plots

In [None]:
## Plot the cumulative fraction of variance explained by the leading EOF modes
def get_var_exp(nmodes=25):
    """Plot the cumulative fraction of variance explained by the leading EOF modes.

    Reads the module-level ``solver``. Its ``explained_variance()`` output is
    divided by 100 here, i.e. treated as percentages (NOTE(review): confirm
    against the solver library in use).

    Parameters
    ----------
    nmodes : int, optional
        Number of modes to show on the x-axis (default 25, the value that was
        previously hard-coded). It is clipped to the number of modes the
        solver actually provides. Previously, a solver with fewer modes
        raised a shape-mismatch error in ``plt.plot``.
    """
    var_frac = solver.explained_variance() / 100
    # Prepend 0 so that x = 0 means "no modes" and x = k is the cumulative
    # fraction explained by the first k modes.
    cum_var = np.concatenate(([0], np.cumsum(var_frac)))
    n_points = min(nmodes, len(cum_var) - 1) + 1
    x = np.arange(n_points)

    plt.figure(figsize=(12, 6), dpi=200)
    plt.plot(x, cum_var[:n_points], marker='o', color='r')

    plt.axhline(y=1, color='black', linestyle='dotted')
    plt.xticks(x)
    plt.ylim([0, 1.05])

    # Annotate the summed percentage for the modes labelled 1..2.
    # NOTE(review): ``sel(mode=slice(1, 2))`` is label-based and inclusive.
    # Whether these are the first two modes depends on whether the solver's
    # ``mode`` coordinate starts at 0 or 1 -- confirm.
    lead_pct = np.round(var_frac.sel(mode=slice(1, 2)).sum().values * 100, 1)
    plt.text(0.5, 0.77, str(lead_pct) + '%', color='r')

    plt.xlabel('EOF Mode')
    plt.ylabel('Fraction of Variance Explained')
    plt.title('Variance Explained by Each EOF Mode')

In [None]:
## Compare SST PCs with nino3.4 time-series
def make_nino_plot():
    """Compare the Nino3.4 SST series with its reconstruction from the first two PCs.

    Calls ``regOnPcs`` on ``pc_patterns`` (restricted to the non-NaN Nino3.4
    times) and ``nino34_orig_nonan``, and takes the returned ``coefs``. It
    rebuilds Nino3.4 as ``coefs.dot(pc_patterns)`` and plots:

    - the original series;
    - the reconstruction;
    - the two coefficient-scaled PCs (mode 0 -> "PC1", mode 1 -> "PC2").

    Finally it prints the correlation and the fitted model equation.

    Reads module-level globals: ``regOnPcs``, ``pc_patterns``, ``pc_SST``,
    ``nino34_orig`` and ``nino34_orig_nonan``.
    """
    nino34_reg = regOnPcs(pc_patterns.sel(time=nino34_orig_nonan.time), nino34_orig_nonan, 1, 1).coefs
    # Reconstruction = coefficient-weighted sum over the PC modes.
    nino34_recon = nino34_reg.dot(pc_patterns)
    nino34_corr = xr.corr(nino34_orig.sst, nino34_recon)

    # NOTE(review): the reconstruction above uses ``pc_patterns``, but the
    # component curves below use ``pc_SST``. If these are not the same PCs,
    # the dashed curves will not add up to the purple reconstruction.
    scaled_pcs = nino34_reg * pc_SST
    coef_pc1 = str(np.around(nino34_reg.sel(mode=0).values, 2))
    coef_pc2 = str(np.around(nino34_reg.sel(mode=1).values, 2))

    fig, ax = plt.subplots(1, 1, figsize=(18, 6), dpi=200)

    ax.plot(nino34_orig.sst.time, nino34_orig.sst, color='k', label='Nino3.4')
    ax.plot(nino34_recon.time, nino34_recon, color='purple', label='Nino3.4$_{recon}$ \t' + str(nino34_corr.values))
    ax.plot(nino34_recon.time, scaled_pcs.sel(mode=0), color='r', linestyle='dashed', label=coef_pc1 + 'PC1')
    ax.plot(nino34_recon.time, scaled_pcs.sel(mode=1), color='b', linestyle='dashed', label=coef_pc2 + 'PC2')

    ax.set_ylabel('K')
    ax.set_xlabel('Year')
    ax.axhline(y=0, color='k', linestyle='dotted')
    ax.set_title('Correlation of nino3.4 Time Series vs Reconstructed with Two PCs')
    ax.legend()

    print('Correlation of nino3.4 and reconstructed nino3.4 is: ' + str(nino34_corr.values))
    print('The Model Equation is: ' + coef_pc1 + 'PC1 + ' + coef_pc2 + 'PC2')

# Spatial Maps

In [None]:
## Create lat, lon maps of SST and radiation EOFs
def spatial_maps(eofs, nmap, ds_list, plt_type, nmap_title):
    """Draw one row of six Robinson-projection maps with significance hatching.

    Parameters
    ----------
    eofs : mapping
        Indexed by the entries of ``ds_list``. Each entry must expose
        ``longitude`` and ``latitude`` and be contourable as a 2-D field.
    nmap : int
        Row index within the composite figure:

        - rows 0 and 2 get panel titles;
        - row 3 gets the two horizontal colorbars;
        - every row gets ``nmap_title[nmap]`` as a vertical row label.
    ds_list : sequence
        12 keys into ``eofs``, interleaved per panel: ``ds_list[2*k]`` is the
        field for panel ``k`` and ``ds_list[2*k + 1]`` its significance mask.
        Hatching is drawn where the mask is False (below the 95% level).
    plt_type : {'CERES', 'CCF_CLR'}
        Selects the panel titles. It is only needed when ``nmap`` is 0 or 2.
    nmap_title : sequence of str
        Row labels, indexed by ``nmap``.

    Raises
    ------
    ValueError
        If titles are needed (``nmap`` is 0 or 2) and ``plt_type`` is not
        recognised. Previously this surfaced as a NameError.
    """
    panel_titles = {
        'CERES': ['SST', 'All Sky', 'Clear Sky', 'Net CRE', 'Low Cloud CRE', 'High Cloud CRE'],
        'CCF_CLR': ['SST', 'SST CRE', 'EIS CRE', 'Tadv CRE', 'Planck+LR', 'Water Vapor'],
    }.get(plt_type)
    show_titles = nmap == 0 or nmap == 2
    if show_titles and panel_titles is None:
        raise ValueError(f"unknown plt_type {plt_type!r}; expected 'CERES' or 'CCF_CLR'")

    fig, ax = plt.subplots(1, 6, figsize=(24, 3),
                           subplot_kw={'projection': ccrs.Robinson(central_longitude=180)},
                           dpi=200)
    fig.tight_layout(w_pad=5)

    vmax = 4
    rad_levels = np.linspace(-vmax, vmax, 10)   # radiative panels (colorbar in W/m^2/K)
    sst_levels = np.linspace(-1, 1, 10)         # SST panel (colorbar in K/K)
    cmap = plt.get_cmap('RdBu').reversed()

    mappables = []
    for k, axis in enumerate(ax):
        field = eofs[ds_list[2 * k]]
        mask = eofs[ds_list[2 * k + 1]]

        mappables.append(axis.contourf(field.longitude, field.latitude, field,
                                       transform=ccrs.PlateCarree(), cmap=cmap,
                                       levels=sst_levels if k == 0 else rad_levels,
                                       extend='both'))

        # Hatch where the mask is False, i.e. not significant at 95%.
        # ``== False`` (not ``~``) is kept so non-boolean masks behave as
        # before. Explicit 0/1 levels map False -> '' and True -> '//'.
        # Previously this mapping relied on the parity of the 10 colour levels.
        axis.contourf(mask.longitude, mask.latitude, mask == False,  # noqa: E712
                      transform=ccrs.PlateCarree(), colors='none',
                      levels=[-0.5, 0.5, 1.5], hatches=['', '//'])

        axis.coastlines()
        axis.set_aspect('auto')

    if show_titles:
        for axis, title in zip(ax, panel_titles):
            axis.set_title(title, fontsize=20)
    else:
        for axis in ax:
            axis.set_title(None)

    if nmap == 3:
        # Create each colorbar once; they used to be re-created for every panel.
        sst_cax = fig.add_axes([0.01, -0.1, 0.15, 0.05])
        sst_cbar = fig.colorbar(mappables[0], cax=sst_cax, orientation='horizontal',
                                extend='both', format='%g', ticks=[-1, 0, 1])
        sst_cbar.ax.tick_params(labelsize=20)
        sst_cbar.set_label(label='$K/K$', size=20)

        rad_cax = fig.add_axes([0.18, -0.1, 0.81, 0.05])
        rad_cbar = fig.colorbar(mappables[1], cax=rad_cax, orientation='horizontal',
                                extend='both', format='%g', ticks=np.round(rad_levels, 1))
        rad_cbar.ax.tick_params(labelsize=20)
        rad_cbar.set_label(label='$W/m^2/K$', size=20)

    fig.suptitle(nmap_title[nmap], y=0.7, x=-0.01, rotation='vertical', fontsize=20)
        
#     fig.savefig('spatial_map_'+str(nmap)+'.png',bbox_inches='tight')

# Time-series Analysis

In [None]:
## Calculate the autocorrelation-adjusted (effective) degrees of freedom
def dof(data):
    """Return the autocorrelation-adjusted degrees of freedom of ``data``.

    The result is ``N / (2 * tau)``, where:

    - ``N`` is the length of the ``time`` coordinate;
    - ``tau`` is the value returned by the module-level ``autocorrelation``
      helper for lags 0..35 (presumably a decorrelation time scale -- confirm).
    """
    lags = np.arange(0, 36)
    tau = autocorrelation(data, lags)
    n_time = len(data.time)
    return n_time / (2 * tau)

## Create plots of global-mean and reconstructed radiation time-series 
def ts_plot(ts_data, i, j, title, freq):
    """Plot global-mean, equatorial-Pacific and reconstructed radiation on ``ax``.

    Draws onto the module-level ``ax`` (it is not a parameter). Also reads the
    module-level ``rad_type_name``, ``highCld_cre``, ``sm``, ``st`` and ``dof``.

    Parameters
    ----------
    ts_data : sequence
        Three series, each selectable by ``rad_type`` and carrying a ``time``
        dimension:

        - ``ts_data[0]``: global-mean radiation;
        - ``ts_data[1]``: reconstructed radiation;
        - ``ts_data[2]``: equatorial-Pacific radiation.
    i : int
        Unused; kept so existing call sites keep working.
    j : int
        Index into ``rad_type_name``. Panels with ``j != 0`` hide their y tick
        labels and y label.
    title : str
        Panel title.
    freq : {'mon', 'yr'}
        ``'mon'`` plots the series as-is. ``'yr'`` first applies a centred
        12-sample running mean and drops the NaN edges.

    Raises
    ------
    ValueError
        If ``freq`` is neither ``'mon'`` nor ``'yr'``. Previously this
        surfaced as an UnboundLocalError.
    """
    if freq not in ('mon', 'yr'):
        raise ValueError(f"freq must be 'mon' or 'yr', got {freq!r}")

    rad_type = rad_type_name[j]

    def _prepare(series):
        # Pick this panel's radiation type and, for annual plots, smooth it.
        picked = series.sel(rad_type=rad_type)
        if freq == 'yr':
            picked = picked.rolling(time=12, center=True).mean().dropna(dim='time')
        return picked

    data_gm = _prepare(ts_data[0])
    data_recon = _prepare(ts_data[1])
    data_trop = _prepare(ts_data[2])

    # Regress each observed series on the reconstruction. No constant column
    # is added, so (assuming ``sm`` is statsmodels.api) this is a fit through
    # the origin. NOTE(review): confirm that is intended.
    gm_model = sm.OLS(data_gm.values, data_recon.values).fit()
    trop_model = sm.OLS(data_trop.values, data_recon.values).fit()

    # Upper-tail t p-value of the slope, using the autocorrelation-adjusted dof.
    gm_p = st.t.sf(gm_model.tvalues, df=dof(data_gm))[0]
    trop_p = st.t.sf(trop_model.tvalues, df=dof(data_trop))[0]

    # Raw strings for the LaTeX parts avoid invalid-escape SyntaxWarnings;
    # '\t' stays in normal strings because a real tab is intended.
    gm_label = (r'$\Delta R_{gm}$ $r^2 = $' + str(np.round(gm_model.rsquared, 3))
                + ',\t p = ' + str(np.round(gm_p, 3)))
    trop_label = (r'$\Delta R_{eq \, pac}$ $r^2 = $' + str(np.round(trop_model.rsquared, 3))
                  + ',\tp = ' + str(np.round(trop_p, 3)))

    data_gm.plot(ax=ax, label=gm_label, color='k', linewidth=1)
    data_trop.plot(ax=ax, label=trop_label, color='grey', linewidth=1)
    data_recon.plot(ax=ax, label=r'$\Delta R_{recon}$', color='r', linewidth=1)

    ax.axhline(y=0, linestyle='dotted', color='k', alpha=0.5)
    ax.set_ylim([-3, 3])
    ax.set_yticks([-3, 0, 3])
    ax.set_yticklabels([-3, 0, 3], size=7)
    ax.set_ylabel('$W/m^2$', size=7)

    # Set ticks before limits: set_xticks widens the view to include every
    # tick, so setting the limits last guarantees they are honoured.
    # Tick-label sizes are set on ``ax`` directly. The former
    # plt.xticks/plt.yticks calls acted on the pyplot-current axes instead.
    ax.set_xticks(['2003', '2011', '2019'])
    ax.set_xticklabels(['2003', '2011', '2019'], size=7, rotation=0, ha='center')
    ax.set_xlim([highCld_cre.time[0], highCld_cre.time[-1]])
    ax.set_xlabel('Year', size=7)

    ax.legend(prop={'size': 5}, loc='upper right', frameon=False)
    ax.set_title(title)

    if j != 0:
        ax.set_yticklabels([])
        ax.set_ylabel(None)

In [None]:
## Plot the coherence and phase lag between two radiation time series
def coh_plot(ts_data, i, j, title):
    """Plot coherence and cross-spectral lag between two radiation series on ``ax``.

    Draws onto the module-level ``ax`` and on a twin y-axis for the lag (in
    months). Also reads the module-level ``rad_type_name``, ``n``,
    ``get_coherence`` and ``signal`` (presumably ``scipy.signal``).

    Parameters
    ----------
    ts_data : sequence
        ``ts_data[0]`` and ``ts_data[1]`` are compared. Each is selectable by
        ``rad_type``.
    i : int
        Unused; kept so existing call sites keep working.
    j : int
        Index into ``rad_type_name``. Only ``j == 0`` keeps the coherence
        label, and only ``j == n - 1`` keeps the lag label and lead/lag arrows.
    title : str
        Panel title.
    """
    rad_type = rad_type_name[j]
    x = ts_data[0].sel(rad_type=rad_type)
    y = ts_data[1].sel(rad_type=rad_type)

    f, Cxy = get_coherence(x, y)

    # Convert the cross-spectral phase (degrees) to a time lag: phase / (360 f).
    # NOTE(review): ``f`` comes from get_coherence, not from signal.csd. The
    # two frequency grids only match if get_coherence also segments with
    # nperseg=120, noverlap=60 -- confirm.
    _, Pxy = signal.csd(x, y, nperseg=120, noverlap=60)
    with np.errstate(divide='ignore', invalid='ignore'):
        # A zero-frequency bin gives inf/nan here; only f[2:6] is plotted.
        angle = np.angle(Pxy, deg=True) / 360 / f

    # 95% significance threshold, using 3 as the number of segments.
    # NOTE(review): the sqrt form is the threshold for coherence *amplitude*.
    # If get_coherence returns magnitude-squared coherence, the sqrt should be
    # dropped -- confirm.
    sig_level = np.sqrt(1 - (0.05) ** (1 / (3 - 1)))

    ax.semilogx(f, Cxy, color='k', marker='.', linewidth=1, markersize=3)
    ax2 = ax.twinx()
    # NOTE(review): linestyle=None means "default" (a solid line), not "no
    # line". Pass 'none' if markers only were intended.
    ax2.semilogx(f[2:6], angle[2:6], color='r', marker='.', linestyle=None, markersize=3)

    ax.axvline(x=0.0167, color='k', linestyle='dotted', linewidth=1)
    ax.axvline(x=0.04167, color='k', linestyle='dotted', linewidth=1)
    ax.axhline(y=sig_level, color='k', linestyle='dashed', linewidth=1)
    ax2.axhline(y=0, color='r', linestyle='dashed', linewidth=1)

    # Ticks before limits: set_xticks widens the view to cover every tick.
    ax.set_xticks([0.004167, 0.00833, 0.0167, 0.04167, 0.0833, 0.167, 0.333])
    ax.set_xticklabels(['1/20', '1/10', '1/5', '1/2', '1/1', '1/0.5', '1/0.25'], size=7)
    ax.set_xlim([0.00833, f[-1]])
    ax.set_xlabel('Frequency (1/year)', size=7)

    ax.set_yticks([0, 0.5, 1])
    ax.set_ylim([0, 1])
    ax.set_yticklabels(['0', '0.5', '1'], size=7)
    ax.set_ylabel('Coherence', size=7)
    # Explicit axes. plt.minorticks_off() only reached whatever pyplot
    # considered current, which is the twin after twinx().
    for axis in (ax, ax2):
        axis.minorticks_off()

    lag_ticks = [-15, 0, 15]
    ax2.set_yticks(lag_ticks)
    ax2.set_yticklabels(['-15', '0', '15'], size=7)
    ax2.set_ylim([lag_ticks[0], lag_ticks[-1]])
    ax2.set_ylabel('Lag (months)', color='r', size=7)
    ax2.spines['right'].set_color('red')
    ax2.tick_params(axis='y', colors='red', size=7)

    ax.set_title(title)

    # Keep the y labels only on the outer columns; annotate lead/lag direction on the last one.
    if j != 0:
        ax.set_ylabel(None)
    if j != n - 1:
        ax2.set_ylabel(None)
    if j == n - 1:
        ax2.arrow(2, 1, 0, 10, length_includes_head=True, head_width=0.25, head_length=2,
                  clip_on=False, facecolor='r', edgecolor='r')
        ax2.arrow(2, -1, 0, -10, length_includes_head=True, head_width=0.25, head_length=2,
                  clip_on=False, facecolor='r', edgecolor='r')
        ax2.text(1.5, -15, 'R Leads', size=5, color='r')
        ax2.text(1.57, 14, 'R Lags', size=5, color='r')

In [None]:
## Plot the power spectral densities of the radiation time series and Nino3.4
def psd_plot(ts_data, i, j, title):
    """Plot normalised PSDs (with confidence bands) of three radiation series and Nino3.4.

    Draws onto the module-level ``ax``. Also reads the module-level
    ``rad_type_name``, ``get_signal``, ``norm``, ``conf_int`` and ``nino_sig``.
    ``get_signal`` is expected to return one ``(frequencies, power)`` pair per
    input series.

    Parameters
    ----------
    ts_data : sequence
        Three series, each selectable by ``rad_type``.
    i : int
        Unused; kept so existing call sites keep working.
    j : int
        Index into ``rad_type_name``. Panels with ``j != 0`` drop the y label.
    title : str
        Panel title.
    """
    rad_type = rad_type_name[j]
    # Named ``spectra`` (formerly ``signal``) so the local no longer shadows
    # the module-level ``signal`` used elsewhere in this notebook.
    spectra = get_signal([ts_data[0].sel(rad_type=rad_type),
                          ts_data[1].sel(rad_type=rad_type),
                          ts_data[2].sel(rad_type=rad_type)])

    # (color, legend label, extra plot kwargs) for spectra[0], [1], [2].
    # NOTE(review): ts_plot treats ts_data[0] as the global-mean and
    # ts_data[1] as the reconstruction. Here spectra[0] is labelled R_recon
    # and spectra[1] R_full. Confirm the callers' ordering, or whether
    # get_signal reorders.
    series_styles = (
        ('purple', '$R_{recon}$', {'linestyle': 'solid'}),
        ('blue', '$R_{full}$', {}),
        ('orangered', '$R_{full,trop}$', {}),
    )
    for k, (color, label, extra) in enumerate(series_styles):
        freqs, power = spectra[k][0], spectra[k][1]
        normed = norm(power)
        ci_lower, ci_upper = conf_int(normed, 90)
        ax.semilogx(freqs, normed, color=color, alpha=0.5, label=label, marker='.', linewidth=1, **extra)
        ax.fill_between(freqs, ci_upper, ci_lower, color=color, alpha=0.1, linestyle='dotted', linewidth=1)

    nino_normed = norm(nino_sig[1])
    ci_nino_lower, ci_nino_upper = conf_int(nino_normed, 90)
    ax.semilogx(nino_sig[0], nino_normed, color='k', label='nino3.4', marker='.', linewidth=1)
    ax.fill_between(nino_sig[0], ci_nino_upper, ci_nino_lower, color='k', alpha=0.1, linestyle='dotted', linewidth=1)

    xticks = [0.004167, 0.00833, 0.0167, 0.04167, 0.0833, 0.167, 0.333]
    xlabels = ['1/20', '1/10', '1/5', '1/2', '1/1', '1/0.5', '1/0.25']
    # Ticks must come before the limits. set_xticks expands the view to
    # include every tick (here down to 1/20), which silently overrode the
    # 1/10 lower limit when the limit was set first.
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels, size=7)
    ax.set_xlim([0.00833, nino_sig[0][-1]])
    ax.set_xlabel('Frequency (1/year)', size=7)

    ax.set_ylabel('Power Spectral Density', size=7)
    ax.set_ylim([0, 0.3])
    ax.set_title(title)

    # Operate on ``ax`` explicitly rather than the pyplot-current axes.
    ax.tick_params(axis='y', labelsize=7)
    ax.minorticks_off()

    ax.axvline(x=0.0167, color='k', linestyle='dotted', linewidth=1)
    ax.axvline(x=0.04167, color='k', linestyle='dotted', linewidth=1)
    ax.legend(prop={'size': 5}, loc='upper right', frameon=False)

    if j != 0:
        ax.set_ylabel(None)

# Lagged Regression Plots

In [None]:
## Plot global-mean radiation vs lag for EP and CP ENSO events
def lagged_gm_plot():
    """Plot global-mean lag-regression coefficients vs lag for both ENSO patterns.

    The plot is a 2x2 grid:

    - top row: ``mode=0``, titled E-Pattern;
    - bottom row: ``mode=1``, titled C-Pattern;
    - left column: all-sky, clear-sky, net, low- and high-cloud CRE series;
    - right column: SST/EIS/Tadv CRE, Planck+LR and water-vapour series.

    Reads the module-level ``*lag`` arrays, ``global_mean`` and ``lag_list``.
    """
    fig, ax = plt.subplots(2, 2, figsize=(12, 12), dpi=300)

    # (data, color, legend label) for each column. The first label used to be
    # 'All Sky EP' in the E-pattern panel only, inconsistent with every other
    # label; the row title already identifies the pattern.
    left_series = [
        (alag, 'k', 'All Sky'),
        (clag, 'g', 'Clear Sky'),
        (nlag, 'purple', 'Net CRE'),
        (llag, 'r', 'Low Cloud CRE'),
        (hlag, 'b', 'High Cloud CRE'),
    ]
    right_series = [
        (sclag, 'darkred', 'SST CRE'),
        (eclag, 'orangered', 'EIS CRE'),
        (tclag, 'lightcoral', 'Tadv CRE'),
        (drdtlag, 'darkgreen', 'Planck+LR'),
        (drdqlag, 'limegreen', 'Water Vapor'),
    ]
    row_titles = ['E-Pattern Lag Coefficients', 'C-Pattern Lag Coefficients']

    for row, mode in enumerate((0, 1)):
        for col, series in enumerate((left_series, right_series)):
            for data, color, label in series:
                global_mean(data).sel(mode=mode).plot(ax=ax[row, col], color=color, marker='.',
                                                      linewidth=1, label=label)

    # Styling is applied after plotting so it overrides xarray's auto-titles.
    for row in range(2):
        for col in range(2):
            panel = ax[row, col]
            panel.set_title(row_titles[row], size=10)

            panel.set_xlabel('Lag (month)', size=10)
            panel.set_xlim([lag_list[0], lag_list[-1]])
            panel.tick_params(axis='x', labelsize=5)
            panel.set_ylim([-0.25, 0.25])
            panel.tick_params(axis='y', labelsize=5)

            panel.axvline(x=0, linestyle='dotted', color='k', alpha=0.75)
            panel.axhline(y=0, linestyle='dotted', color='k', alpha=0.75)
            panel.legend(prop={'size': 7}, loc='upper left', frameon=False)

    # fig.savefig('lag_coefs_grad_descent.png')

In [None]:
## Plot the global-mean EOF-based (`*_EOF`) lag coefficients vs lag for EP and CP ENSO events
def lagged_EOF_plot():
    """Plot global-mean EOF-based lag-regression coefficients vs lag for both ENSO patterns.

    Same layout as the non-EOF lag plot, but reads the ``*lag_EOF`` arrays:

    - top row: ``mode=0``, titled E-Pattern;
    - bottom row: ``mode=1``, titled C-Pattern;
    - left column: all-sky, clear-sky, net, low- and high-cloud CRE series;
    - right column: SST/EIS/Tadv CRE, Planck+LR and water-vapour series.

    Also reads the module-level ``global_mean`` and ``lag_list``.
    """
    fig, ax = plt.subplots(2, 2, figsize=(12, 12), dpi=300)

    # (data, color, legend label) for each column. The first label used to be
    # 'All Sky EP' in the E-pattern panel only, inconsistent with every other
    # label; the row title already identifies the pattern.
    left_series = [
        (alag_EOF, 'k', 'All Sky'),
        (clag_EOF, 'g', 'Clear Sky'),
        (nlag_EOF, 'purple', 'Net CRE'),
        (llag_EOF, 'r', 'Low Cloud CRE'),
        (hlag_EOF, 'b', 'High Cloud CRE'),
    ]
    right_series = [
        (sclag_EOF, 'darkred', 'SST CRE'),
        (eclag_EOF, 'orangered', 'EIS CRE'),
        (tclag_EOF, 'lightcoral', 'Tadv CRE'),
        (drdtlag_EOF, 'darkgreen', 'Planck+LR'),
        (drdqlag_EOF, 'limegreen', 'Water Vapor'),
    ]
    row_titles = ['E-Pattern Lag Coefficients', 'C-Pattern Lag Coefficients']

    for row, mode in enumerate((0, 1)):
        for col, series in enumerate((left_series, right_series)):
            for data, color, label in series:
                global_mean(data).sel(mode=mode).plot(ax=ax[row, col], color=color, marker='.',
                                                      linewidth=1, label=label)

    # Styling is applied after plotting so it overrides xarray's auto-titles.
    for row in range(2):
        for col in range(2):
            panel = ax[row, col]
            panel.set_title(row_titles[row], size=10)

            panel.set_xlabel('Lag (month)', size=10)
            panel.set_xlim([lag_list[0], lag_list[-1]])
            panel.tick_params(axis='x', labelsize=5)
            panel.set_ylim([-0.25, 0.25])
            panel.tick_params(axis='y', labelsize=5)

            panel.axvline(x=0, linestyle='dotted', color='k', alpha=0.75)
            panel.axhline(y=0, linestyle='dotted', color='k', alpha=0.75)
            panel.legend(prop={'size': 7}, loc='upper left', frameon=False)

    # NOTE(review): same filename as lagged_gm_plot's commented-out savefig;
    # re-enabling both would overwrite one figure with the other.
    # fig.savefig('lag_coefs_grad_descent.png')

In [None]:
## Plot full and reconstructed radiation spatial maps vs lag for EP and CP ENSO events
def lagged_spatial_maps(full_rad_data, EOFs, nlag, name, first_lag=-6, last_lag=6):
    """Draw one row of four maps comparing full vs EOF lag regressions at lag ``nlag``.

    The panels, left to right, are:

    - ``full_rad_data`` at mode 0 (E-pattern);
    - ``EOFs`` at mode 0;
    - ``full_rad_data`` at mode 1 (C-pattern);
    - ``EOFs`` at mode 1.

    Both inputs must support ``.sel(mode=..., lag=...)`` and expose
    ``longitude``/``latitude``.

    Parameters
    ----------
    full_rad_data, EOFs : xarray-like
        Lagged regression maps.
    nlag : int
        Lag to plot. It also picks the row decoration: titles when
        ``nlag == first_lag``, the shared colorbar when ``nlag == last_lag``,
        nothing otherwise.
    name : str
        Only used by the commented-out savefig below.
    first_lag, last_lag : int, optional
        Lags treated as the first and last rows of the composite (defaults
        -6 and 6, matching the previous hard-coded values).
        NOTE(review): lagged_maps_tot stacks lags -9..9 and makes the lag-9
        panel taller, so callers may want first_lag=-9, last_lag=9 -- confirm.
    """
    fig, ax = plt.subplots(1, 4, figsize=(12, 2),
                           subplot_kw={'projection': ccrs.Robinson(central_longitude=180)},
                           dpi=200)
    fig.tight_layout(w_pad=10)

    vmax = 4
    levels = np.linspace(-vmax, vmax, 10)
    cmap = plt.get_cmap('RdBu').reversed()
    # Raw strings: '\D' is an invalid escape (SyntaxWarning on Python 3.12+).
    titles = [r'$\Delta R_{full}$ E-Pattern', 'EOF E-Pattern',
              r'$\Delta R_{full}$ C-Pattern', 'EOF C-Pattern']
    panels = [(full_rad_data, 0), (EOFs, 0), (full_rad_data, 1), (EOFs, 1)]

    mappables = []
    for axis, (data, mode) in zip(ax, panels):
        field = data.sel(mode=mode)
        mappables.append(axis.contourf(field.longitude, field.latitude, field.sel(lag=nlag),
                                       transform=ccrs.PlateCarree(), cmap=cmap,
                                       levels=levels, extend='both'))
        axis.coastlines()
        axis.set_aspect('auto')

    if nlag == first_lag:
        for axis, title in zip(ax, titles):
            axis.set_title(title, fontsize=20)
    else:
        for axis in ax:
            axis.set_title(None)
        if nlag == last_lag:
            # Create the shared colorbar once; it used to be re-created per panel.
            cbar_ax = fig.add_axes([0.01, -0.1, 0.99, 0.05])
            cbar = fig.colorbar(mappables[0], cax=cbar_ax, orientation='horizontal',
                                extend='both', format='%g', ticks=np.round(levels, 2))
            cbar.ax.tick_params(labelsize=20)
            cbar.set_label(label='$W/m^2/K$', size=20)

    fig.suptitle('Lag = ' + str(nlag), y=0.75, x=-0.01, rotation='vertical', fontsize=15)
        
#         fig.savefig('lag_plots/lagged_regress_comparison_'+name+'_lag_'+str(nlag)+'.png',bbox_inches='tight')

In [None]:
def lagged_maps_tot(name, lags=(-9, -6, -3, 0, 3, 6, 9),
                    path='/data/keeling/a/tjhanke2/enso-energy-budget/Code/eof_analysis/lag_plots/'):
    """Stack the saved per-lag comparison PNGs into one tall figure and save it.

    For each lag, top to bottom, this reads
    ``path + 'lagged_regress_comparison_<name>_lag_<lag>.png'`` and shows it
    with the axes hidden. The combined figure is saved as
    ``path + 'lagged_regress_<name>_tot.png'``. The figure is always closed,
    even if reading or saving fails.

    Parameters
    ----------
    name : str
        Tag used in the input and output file names.
    lags : sequence of int, optional
        Lags to stack, in display order. The last panel gets 1.4x height,
        presumably to make room for the colorbar row.
    path : str, optional
        Directory prefix, including the trailing slash, for inputs and output.

    Raises
    ------
    ValueError
        If ``lags`` is empty.
    """
    if not lags:
        raise ValueError('lags must contain at least one lag')

    n_panels = len(lags)
    fig, ax = plt.subplots(n_panels, 1, figsize=(24, 28), dpi=200, squeeze=False,
                           gridspec_kw={'height_ratios': [1] * (n_panels - 1) + [1.4]})
    try:
        fig.tight_layout(h_pad=-1)

        for axis, lag in zip(ax[:, 0], lags):
            img = plt.imread(path + 'lagged_regress_comparison_' + name + '_lag_' + str(lag) + '.png')
            axis.imshow(img, aspect='auto')
            axis.axis('off')

        fig.savefig(path + 'lagged_regress_' + name + '_tot.png', bbox_inches='tight')
    finally:
        # Close this specific figure (not just the pyplot-current one), even on error.
        plt.close(fig)

In [None]:
def r2_plots():
    """Plot R^2 vs lag for the E-pattern (mode 0) and C-pattern (mode 1) regressions.

    The plot is a 2x2 grid:

    - top row: mode 0;
    - bottom row: mode 1;
    - left column: all-sky, clear-sky, net, low- and high-cloud CRE series;
    - right column: SST/EIS/Tadv CRE, Planck+LR and water-vapour series.

    Reads the module-level ``r2_*`` arrays (each with a ``mode`` dimension)
    and ``lag_list``. A gray ``r2_sst`` reference line used to be available
    (commented out); add it to ``left_series``/``right_series`` if needed.
    """
    fig, ax = plt.subplots(2, 2, figsize=(12, 6), dpi=300)
    fig.tight_layout(h_pad=3)

    # (data, color, legend label) for each column.
    left_series = [
        (r2_all_sky, 'k', 'All Sky'),
        (r2_clr_sky, 'g', 'Clear Sky'),
        (r2_net_cre, 'purple', 'Net CRE'),
        (r2_loCld_cre, 'r', 'Low Cloud CRE'),
        (r2_highCld_cre, 'b', 'High Cloud CRE'),
    ]
    right_series = [
        (r2_SST_CRE, 'darkred', 'SST CRE'),
        (r2_EIS_CRE, 'orangered', 'EIS CRE'),
        (r2_Tadv_CRE, 'lightcoral', 'Tadv CRE'),
        (r2_dR_dT, 'darkgreen', 'Planck+LR'),
        (r2_dR_dq, 'limegreen', 'Water Vapor'),
    ]
    # These panels show R^2, not lag coefficients; the old titles were copied
    # from the coefficient plots.
    row_titles = ['E-Pattern $R^2$', 'C-Pattern $R^2$']

    for row, mode in enumerate((0, 1)):
        for col, series in enumerate((left_series, right_series)):
            for data, color, label in series:
                ax[row, col].plot(lag_list, data.sel(mode=mode), color=color, marker='.',
                                  linewidth=1, label=label)

    for row in range(2):
        for col in range(2):
            panel = ax[row, col]
            panel.set_title(row_titles[row], size=10)

            panel.set_xlabel('Lag (month)', size=10)
            panel.set_xlim([lag_list[0], lag_list[-1]])
            panel.tick_params(axis='x', labelsize=5)

            panel.set_ylabel('$R^2$', size=5)
            panel.set_ylim([0, 1])
            panel.tick_params(axis='y', labelsize=5)

            panel.axvline(x=0, linestyle='dotted', color='k', alpha=0.75)
            panel.axhline(y=0, linestyle='dotted', color='k', alpha=0.75)
            panel.legend(prop={'size': 7}, loc='upper left', frameon=False)

#     fig.savefig('lag_plots/r2.png',bbox_inches='tight')