Zhang Chao, 2025.04.25<br>
Plotting the changes in energy fluxes (ΔLE, ΔHE, ΔSWout and ΔLWout) over China

In [1]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import pandas as pd
import numpy as np
import xarray as xr
import geopandas as gpd
import regionmask
import matplotlib.patches as patches
import sys
import matplotlib as mpl
import os
from cartopy.feature import ShapelyFeature
import matplotlib.colors as mcolors
sys.path.append('/home/climate/chaoz/code/utils/')
from plot_utils import plot_settings, uneven_cmap

In [2]:
os.chdir('/home/climate/chaoz/project/03Irr_Ts_CN/processed/')

# Change fields whose file names mark them as means ("Mean"/"mean");
# they are used as-is for the regional statistics below.
dLE, dHE, dLWo, dSWo = (
    xr.open_dataset(fname)
    for fname in (
        'delta_LE_Yr_CN_Mean_PMLv2.nc',
        'delta_HE_Yr_CN_Mean_PMLv2.nc',
        'delta_LWRadout_Yr_CN_mean_PMLv2.nc',
        'delta_SWRadout_Yr_CN_mean_PMLv2.nc',
    )
)

# 2001-2020 annual "05d" fields (presumably the 0.5-degree grid - confirm),
# collapsed to their mean over the 'time' dimension for the maps.
dLE_05d, dHE_05d, dLWo_05d, dSWo_05d = (
    xr.open_dataset(fname).mean(dim='time')
    for fname in (
        'delta_LE_Yr_CN_2001_2020_PMLv2_05d.nc',
        'delta_HE_Yr_CN_2001_2020_PMLv2_05d.nc',
        'delta_LWRadout_Yr_CN_2001_2020_PMLv2_05d.nc',
        'delta_SWRadout_Yr_CN_2001_2020_PMLv2_05d.nc',
    )
)

# Boundary shapefiles: national outline, South China Sea, climate zones
shp_cn, shp_nanhai, shp_climzone = (
    gpd.read_file(path)
    for path in (
        '../shapefile/ChinaAll.shp',
        '../shapefile/Nanhai.shp',
        '../shapefile/ClimateZone_3.shp',
    )
)

In [3]:
def stats_regionmean_95p(ds, varname, shp):
    """Calculate national and regional spatial means with 95% CI half-widths.

    The 95% CI half-width is approximated as ``1.96 * std / sqrt(n)``, where
    ``std`` is the standard deviation of the grid cells inside the
    nation/region (xarray default, ddof=0) and ``n`` is the number of non-null
    grid cells there.

    Args:
        ds (xarray.Dataset): dataset containing ``varname`` on 'lat' and 'lon'
            dimensions. The national statistics are reduced with ``.item()``,
            so the variable must collapse to a single value over lat/lon
            (no extra non-singleton dimension such as 'time').
        varname (str): the variable name in the xarray dataset.
        shp (geopandas.GeoDataFrame): regions loaded with
            ``geopandas.read_file()``. If it has a 'name' column, those names
            label the regions; otherwise a default Humid/Arid/Semi mapping of
            region numbers 0/1/2 is used (and a warning is printed).

    Returns:
        pandas.DataFrame: indexed by region label ('China' first, followed by
        the regions in mapping order) with columns 'mean' and '95p'.
    """
    z_95 = 1.96  # two-sided 95% quantile of the standard normal distribution

    # Work with the single variable as a DataArray
    da = ds[varname]

    # --- National statistics (all grid cells) ---
    national_count = da.notnull().sum().item()
    cn_mean = da.mean(dim=('lat', 'lon')).item()
    # NOTE: xarray's std defaults to ddof=0 (population standard deviation).
    cn_std_dev = da.std(dim=('lat', 'lon')).item()
    if national_count > 0:
        cn_95p = (cn_std_dev / np.sqrt(national_count)) * z_95
    else:
        cn_95p = np.nan  # no valid grid cells -> CI undefined

    # --- Regional statistics ---
    mask_region = regionmask.mask_geopandas(shp, da.lon, da.lat)
    grouped = da.groupby(mask_region)
    region_mean = grouped.mean()
    region_std_dev = grouped.std()
    # Number of non-null cells per region
    region_count = da.notnull().groupby(mask_region).sum()
    # Regions without valid cells become NaN instead of dividing by zero
    regional_95p = (region_std_dev / np.sqrt(region_count)).where(region_count > 0) * z_95

    region_mean_df = region_mean.to_dataframe(name='mean').reset_index()
    region_95p_df = regional_95p.to_dataframe(name='95p').reset_index()

    # Map the numeric mask values to region names
    if 'name' in shp.columns:
        # mask_geopandas numbers regions by the GeoDataFrame index by default,
        # so key the mapping on the index rather than on row position.
        mask_mapping = dict(zip(shp.index, shp['name']))
    else:
        # Insertion order (Arid, Semi, Humid) defines the output row order.
        mask_mapping = {1.0: 'Arid', 2.0: 'Semi', 0.0: 'Humid'}
        print("Warning: Using default mask mapping. Verify it matches your shapefile regions.")

    for frame in (region_mean_df, region_95p_df):
        frame['mask'] = frame['mask'].replace(mask_mapping)

    # --- Combine regional and national results ---
    df1 = pd.merge(region_mean_df[['mask', 'mean']],
                   region_95p_df[['mask', '95p']],
                   on='mask')
    df2 = pd.DataFrame({'mask': ['China'],
                        'mean': [cn_mean],
                        '95p': [cn_95p]})
    df3 = pd.concat([df1, df2], axis=0, ignore_index=True).set_index('mask')

    # Order rows as China + mapped regions. Unmapped regions are dropped
    # first because reindex() refuses an index with duplicate labels.
    region_order = ['China'] + list(mask_mapping.values())
    df3 = df3[df3.index.isin(region_order)]
    df3 = df3.reindex(region_order)

    return df3


def stats_regionmean(ds,varname,shp):
    """Calculate national and regional spatial-mean time series.

    Args:
        ds (xarray.Dataset): dataset holding ``varname`` on 'time', 'lat'
            and 'lon' dimensions.
        varname (str): the variable name in the xarray dataset.
        shp (geopandas.GeoDataFrame): regions loaded with
            ``geopandas.read_file()``; region numbers 0/1/2 are labelled
            'Humid'/'Arid'/'Semi'.

    Returns:
        pandas.DataFrame: one row per time step with columns
        'time', 'China', 'Arid', 'Semi', 'Humid'. A region with no data at
        all yields a NaN column instead of raising.
    """
    # National mean per time step as a time-indexed Series, so it can be
    # attached by timestamp rather than by row position.
    stat_cn = ds[varname].mean(dim=('lat', 'lon')).to_series()

    mask_region = regionmask.mask_geopandas(shp, ds.lon, ds.lat)
    # Only average the requested variable, not every variable in ds
    stat_region = ds[[varname]].groupby(mask_region).mean().to_dataframe().reset_index()

    mask_mapping = {0.0: 'Humid', 1.0: 'Arid', 2.0: 'Semi'}
    stat_region['mask'] = stat_region['mask'].replace(mask_mapping)

    # One column per region, one row per time step (pivot_table sorts by time)
    pivoted_df = stat_region.pivot_table(index='time', columns='mask', values=varname).reset_index()
    # Align the national means on the time values of the pivoted rows
    pivoted_df['China'] = pivoted_df['time'].map(stat_cn)

    # Fixed column order; reindex tolerates a region dropped by pivot_table
    return pivoted_df.reindex(columns=['time', 'China', 'Arid', 'Semi', 'Humid'])


In [4]:
# Regional means and 95% CI half-widths for each flux change, per climate zone
df_dLE, df_dHE, df_dLWo, df_dSWo = (
    stats_regionmean_95p(dataset, var, shp_climzone)
    for dataset, var in ((dLE, 'LE'), (dHE, 'H'), (dLWo, 'Lr'), (dSWo, 'Sr'))
)



In [5]:
def plot_mean_map(fig, pos, ds, extent,shape_cn,shape_nanhai,levels, mycmap, No,title,fameon_sing,nanhai_sign):
    """Draw a gridded field on a Lambert Conformal map of China.

    Args:
        fig: matplotlib Figure to draw into.
        pos: [left, bottom, width, height] of the map axes in figure coords.
        ds: xarray DataArray on a lon/lat grid (plotted with a PlateCarree
            transform).
        extent: [lon_min, lon_max, lat_min, lat_max] of the visible area.
        shape_cn, shape_nanhai: GeoDataFrames whose outlines are drawn.
        levels: colour levels passed to ``DataArray.plot``.
        mycmap: colormap passed to ``DataArray.plot``.
        No: panel label written at the top-left corner.
        title: axes title.
        fameon_sing: ``frame_on`` flag for the main map axes.
        nanhai_sign: if true, add a South China Sea inset at the lower-right
            corner of the map.

    Returns:
        The main map axes.
    """
    geo_crs = ccrs.PlateCarree()
    # Lambert Conformal projection customised for China
    proj = ccrs.LambertConformal(central_longitude=105, central_latitude=35,
                                 standard_parallels=(25, 47))
    ax = fig.add_axes(pos, projection=proj, frame_on=fameon_sing)

    ds.plot(ax=ax, levels=levels, transform=geo_crs, cmap=mycmap,
            add_colorbar=False, rasterized=True)
    ax.set_extent(extent, crs=geo_crs)

    # Outline features, reused by the inset below
    outlines = [ShapelyFeature(shape.geometry, crs=geo_crs)
                for shape in (shape_cn, shape_nanhai)]
    for outline in outlines:
        ax.add_feature(outline, facecolor='none', linewidth=.6)

    ax.text(0.01, 1.0, No, transform=ax.transAxes, c='k', weight='bold', fontsize=14)
    ax.set_title(title)

    if not nanhai_sign:
        return ax

    # South China Sea inset anchored to the lower-right corner of the map
    inset_w, inset_h = 0.0381, 0.10
    box = ax.get_position()
    ax_scs = fig.add_axes([box.x1 - inset_w, box.y0, inset_w, inset_h],
                          projection=proj, frame_on=True)
    for outline in outlines:
        ax_scs.add_feature(outline, facecolor='none', linewidth=.6)
    ax_scs.set_extent([107, 120, 3, 23], crs=geo_crs)
    return ax


def plot_lat_stats(fig,pos,ds,xlim,ylim):
    """Plot the across-longitude distribution of a field at each latitude.

    A grey band spans the 25th-75th percentiles over 'lon', a black line
    shows the median, and a dashed vertical line marks zero.

    Args:
        fig: matplotlib Figure to draw into.
        pos: [left, bottom, width, height] of the new axes in figure coords.
        ds: xarray DataArray with 'lat' and 'lon' dimensions.
        xlim: x-axis (value) limits.
        ylim: y-axis (latitude) limits.

    Returns:
        The new axes.
    """
    ax = fig.add_axes(pos)
    quantiles = ds.quantile([0.25, 0.50, 0.75], dim='lon')
    lower, median, upper = (quantiles.sel(quantile=q) for q in (0.25, 0.50, 0.75))

    ax.fill_betweenx(quantiles.lat, lower, upper, color='gray', alpha=0.3)
    ax.plot(median, quantiles.lat, color='k', linewidth=1.2)
    ax.axvline(0, 0, 1, color='k', linewidth=1, linestyle='--')

    # Latitude labels on the right; hide the top and left spines
    ax.yaxis.tick_right()
    ax.set_ylim(ylim)
    ax.set_xlim(xlim)
    ax.set_yticks([20, 30, 40, 50], ['20N', '30N', '40N', '50N'])
    for side in ('top', 'left'):
        ax.spines[side].set_visible(False)

    return ax


def plot_bars(fig,pos,mean,std,No,title,left_flag):
    """Draw a bar chart with a broken y-axis.

    The panel is split vertically into an upper axis (top 1/8, y in
    [10.5, 11.5]) and a lower axis (bottom 7/8, y in [-5.5, 3.5]). Every bar
    is drawn on both, so each axis shows the part of the bars inside its
    y-range.

    Args:
        fig: matplotlib Figure to draw into.
        pos: [left, bottom, width, height] of the whole panel in figure coords.
        mean: bar heights, one per group (coloured China/Arid/Semi/Humid).
        std: error-bar half-lengths (this notebook passes 95% CI half-widths).
        No: panel label drawn at the top-left.
        title: title of the upper axis.
        left_flag: if true, add the unit label on the left side.

    Returns:
        The upper axis.
    """
    n_bars = len(mean)
    xs = range(1, n_bars + 1)

    hi1, hi2 = pos[3] * 7 / 8, pos[3] * 1 / 8
    pos1 = [pos[0], pos[1] + hi1 + 0.008, pos[2], hi2 - 0.008]  # upper part
    pos2 = [pos[0], pos[1], pos[2], hi1]                       # lower part
    # Two axes that together form the broken y-axis
    ax1, ax2 = fig.add_axes(pos1), fig.add_axes(pos2)
    ax1.set_ylim([10.5, 11.5])
    ax2.set_ylim([-5.5, 3.5])
    # xmin/xmax of axhline are axes fractions: 0..1 spans the full width
    ax2.axhline(0, 0, 1, color='k', linestyle='--', linewidth=0.8)
    colors = ['gray', '#FFA726', '#33A02C', '#1E88E5']

    err_attri = dict(elinewidth=.8, ecolor='k', capsize=4)
    ax1.bar(xs, mean, yerr=std, color=colors, error_kw=err_attri)
    ax2.bar(xs, mean, yerr=std, color=colors, error_kw=err_attri)

    ax1.set_yticks([11], [11])
    ax2.set_yticks([-5, 0], [-5, 0])
    ax1.set_xticks([])
    # Keep tick marks but blank labels; one label per bar (set_xticks
    # requires the label count to match the tick count).
    ax2.set_xticks(xs, [''] * n_bars)
    ax1.text(-0.25, 1.01, No, transform=ax1.transAxes, fontsize=14, weight='bold')

    # Diagonal break marks. The upper axis is much shorter than the lower
    # one, so its mark is stretched in y (+0.4) to appear at a similar slope.
    d = .015
    kwargs = dict(transform=ax1.transAxes, color='k', linewidth=1, clip_on=False)
    ax1.plot((-d, +d), (-d, +d + 0.4), **kwargs)  # top-left diagonal
    kwargs.update(transform=ax2.transAxes)        # switch to the lower axes
    ax2.plot((-d, +d), (1 - d, 1 + d), **kwargs)  # bottom-left diagonal

    ax1.spines["bottom"].set_visible(False)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax2.spines["top"].set_visible(False)
    ax2.spines['right'].set_visible(False)
    if left_flag:
        ax1.set_ylabel('[W/m$^{2}$]')
        # Label sits on ax1 but is pushed down to centre it on the whole panel
        ax1.yaxis.set_label_coords(-0.16, -8)

    ax1.set_title(title)
    return ax1


# Function to add color bar for the main maps
def add_colorbar_map1(ax, mycmap, mynorm, levels,ticklabels, label='[W/m$^{2}$]'):
    """Draw a vertical colorbar for the map panels into ``ax``.

    Args:
        ax: axes that will host the colorbar.
        mycmap: colormap of the maps.
        mynorm: norm paired with ``mycmap``.
        levels: tick positions (the colour boundaries).
        ticklabels: one label per entry in ``levels``.
        label: y-axis label; defaults to the flux unit used by the maps.

    Returns:
        matplotlib.colorbar.Colorbar: the created colorbar, so callers can
        adjust it further (the return value was previously discarded).
    """
    cb1 = mpl.colorbar.ColorbarBase(ax=ax, cmap=mycmap, norm=mynorm,
                                    orientation='vertical',
                                    ticks=levels, extend='neither')
    ax.set_yticklabels(ticklabels)
    ax.set_ylabel(label, labelpad=5)
    # Zero-length ticks: only the labels are shown
    ax.tick_params(axis='y', right=True, length=0)
    return cb1
    
    
def add_bar_legend(ax):
    """Turn ``ax`` into a bare panel holding the bar charts' colour legend.

    All spines and ticks are hidden, and one patch is added per group
    (China, Arid, Semi, Humid) in the colours used by the bar charts.
    """
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

    palette = {
        'China': 'gray',
        'Arid': '#FFA726',
        'Semi': '#33A02C',
        'Humid': '#1E88E5',
    }
    handles = [patches.Patch(color=colour, label=group)
               for group, colour in palette.items()]

    ax.legend(handles=handles, ncol=1, loc='lower right', frameon=False,
              handlelength=1, fontsize=10, bbox_to_anchor=[0.55, -0.1])
        

In [None]:
# Diverging colormap from the two outer parts of BrBG (0.00-0.40 and
# 0.60-1.00), skipping its pale centre.
colors1 = plt.cm.BrBG(np.linspace(0.00, 0.40, 128))
colors2 = plt.cm.BrBG(np.linspace(0.60, 1.00, 128))
colors = np.vstack((colors1, colors2))
mycmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', colors)
# Symmetric uneven levels: [-50, -10, -8, -6, -4, -2, -1, 0, 1, ..., 10, 50]
plevel = np.array([0, 1, 2, 4, 6, 8, 10, 50])
nlevel = -1 * plevel[1:]
levels1 = np.concatenate([nlevel[::-1], plevel])
cmap1, norm1 = uneven_cmap(levels1, cmap=mycmap)


fig = plt.figure(figsize=(8, 7))
plot_settings()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'

# 2x2 grid of maps: (x0, y0) origin, (xi, yi) gaps, (hx1, vx1) panel size
x0, y0 = 0.01, 0.02
xi, yi = 0.145, 0.05
hx1, vx1 = 0.32, 0.35

pos1 = [x0 + (hx1 + xi) * 0, y0 + (vx1 + yi) * 1, hx1, vx1]
pos2 = [x0 + (hx1 + xi) * 1, y0 + (vx1 + yi) * 1, hx1, vx1]
pos3 = [x0 + (hx1 + xi) * 0, y0 + (vx1 + yi) * 0, hx1, vx1]
pos4 = [x0 + (hx1 + xi) * 1, y0 + (vx1 + yi) * 0, hx1, vx1]

extent = [80, 127, 15, 54]

# (field, panel label, title) for the maps; raw strings keep "\D" literal
map_panels = [
    (dLE_05d.LE, 'e', r'$\Delta$LE'),
    (dHE_05d.H, 'f', r'$\Delta$HE'),
    (dSWo_05d.Sr, 'g', r'$\Delta$SWout'),
    (dLWo_05d.Lr, 'h', r'$\Delta$LWout'),
]
ax1, ax2, ax3, ax4 = [
    plot_mean_map(fig, pos, field, extent, shp_cn, shp_nanhai, levels1, cmap1,
                  panel_no, panel_title, False, True)
    for pos, (field, panel_no, panel_title) in zip((pos1, pos2, pos3, pos4), map_panels)
]

# Latitudinal statistics to the right of each map
lat_xlim = [-7, 8]
lat_ylim = [extent[2], extent[3]]
ax1_r, ax2_r, ax3_r, ax4_r = [
    plot_lat_stats(fig, [box.x1 + 0.01, box.y0, 0.07, box.height],
                   field, lat_xlim, lat_ylim)
    for box, (field, _, _) in zip((ax.get_position() for ax in (ax1, ax2, ax3, ax4)),
                                  map_panels)
]

# Row of broken-axis bar charts above the maps
hx = 0.17
bar_x0 = ax1.get_position().x0 + 0.05
bar_y0 = ax1.get_position().y1 + 0.07
bar_panels = [
    (df_dLE, 'a', r'$\Delta$LE'),
    (df_dHE, 'b', r'$\Delta$HE'),
    (df_dSWo, 'c', r'$\Delta$SWout'),
    (df_dLWo, 'd', r'$\Delta$LWout'),
]
for i_bar, (df_stats, panel_no, panel_title) in enumerate(bar_panels):
    pos_bar = [bar_x0 + (hx + 0.05) * i_bar, bar_y0, hx, 0.13]
    # Only the first (left-most) panel carries the unit label
    plot_bars(fig, pos_bar, df_stats['mean'], df_stats['95p'],
              panel_no, panel_title, i_bar == 0)

# Colorbars to the right of each map row
ticklabels1 = ['', '-10', '-8', '-6', '-4', '-2', '-1', '0',
               '1', '2', '4', '6', '8', '10', '']
for ax_r in (ax2_r, ax4_r):
    box = ax_r.get_position()
    add_colorbar_map1(fig.add_axes([box.x1 + 0.05, box.y0, 0.01, box.height]),
                      cmap1, norm1, levels1, ticklabels1)

# Legend in the slot after the fourth bar chart
pos_barlegend = [bar_x0 + (hx + 0.05) * 4, bar_y0, 0.12, 0.1]
add_bar_legend(fig.add_axes(pos_barlegend))

plt.savefig('../figures/Figure_S04.png', dpi=300)