2025.05.12, Zhang Chao

In [1]:
import numpy as np
import xarray as xr
import regionmask
import geopandas as gpd
import matplotlib.pyplot as plt
import os
import sys
import scipy
sys.path.append('/home/climate/chaoz/code/utils/')
from plot_utils import plot_settings

In [2]:
# Work from the processed-data directory so the relative paths below resolve.
os.chdir('/home/climate/chaoz/project/03Irr_Ts_CN/processed/')

# The two delta-LE products (PMLv2 and MOD16A2GF, per the file names),
# opened in the same order as before.
dLE_PMLv2, dLE_MODIS = (
    xr.open_dataset(nc_file)
    for nc_file in (
        'delta_LE_PMLv2v017_Yr_CN_2001_2020.nc',
        'delta_LE_MOD16A2GF_Yr_CN_2001_2020.nc',
    )
)

# Climate-zone polygons used to split the grid into regions.
shp_climzone = gpd.read_file('../shapefile/ClimateZone_3.shp')

In [3]:
def getSigFlag(p):
    """Map a p-value to the star marker used in plot annotations.

    Args:
        p (float): p-value

    Returns:
        str: '**' when p < 0.05, '*' when 0.05 <= p < 0.1, '' otherwise
    """
    # Thresholds are checked from strictest to loosest; the first match wins.
    for threshold, marker in ((0.05, '**'), (0.1, '*')):
        if p < threshold:
            return marker
    return ''
    

def stats_regionmean(ds,varname,shp):
    """Build a table of national and per-region mean time series.

    Args:
        ds (xarray.Dataset): dataset with 'time', 'lat' and 'lon' dimensions
        varname (str): name of the variable in ``ds`` to summarise
        shp (geopandas.GeoDataFrame): region polygons, as returned by
            geopandas.read_file()

    Returns:
        pandas.DataFrame: one row per time step with the columns
        'time', 'China', 'Arid', 'Semi' and 'Humid'
    """
    # National series: plain mean over every lat/lon cell (no area weighting).
    national_series = ds.mean(dim=('lat', 'lon'))[varname].values

    # Rasterize the polygons onto the dataset grid, then average per region.
    zone_mask = regionmask.mask_geopandas(shp, ds.lon, ds.lat)
    zone_means = ds.groupby(zone_mask).mean()
    long_df = zone_means.to_dataframe().reset_index()

    # Replace the numeric region ids in the 'mask' column with zone names.
    zone_names = {0.0:'Humid', 1.0:'Arid', 2.0:'Semi'}
    long_df['mask'] = long_df['mask'].replace(zone_names)

    # Long -> wide: one column per zone, one row per time step.
    wide_df = long_df.pivot_table(
        index='time',
        columns='mask',
        values=varname,
    ).reset_index()

    # The national series is attached by position, so it must line up with
    # the time rows of the pivoted table.
    wide_df['China'] = national_series

    return wide_df[['time','China','Arid','Semi','Humid']]

In [4]:
# Rasterize the climate-zone polygons onto the PMLv2 lon/lat grid. The cells of
# this mask are compared against 0, 1 and 2 further down to select each zone
# (presumably the row order of ClimateZone_3.shp -- confirm against the shapefile).
mask_region = regionmask.mask_geopandas(shp_climzone, dLE_PMLv2.lon, dLE_PMLv2.lat)

In [5]:
# Put both LE products into one dataset on the PMLv2 coordinates.
# NOTE(review): the MODIS values are attached positionally, so this assumes
# both files share the same (time, lat, lon) grid -- confirm.
ds = xr.Dataset(
    data_vars={
        'PMLv2': (('time', 'lat', 'lon'), dLE_PMLv2.LE.values),
        'MODIS': (('time', 'lat', 'lon'), dLE_MODIS.LE.values),
    },
    coords={axis: dLE_PMLv2[axis] for axis in ('time', 'lat', 'lon')},
)

In [6]:
def _pixel_frame(field):
    """Flatten a lat/lon field into one row per grid cell, dropping rows with NaN."""
    return field.to_dataframe().reset_index().dropna().drop(columns=['lat', 'lon'])


# Time-mean field of both products, then one pixel table for the whole domain
# and one per climate zone (mask ids: 1 = Arid, 2 = Semi, 0 = Humid).
ds_mean = ds.mean(dim='time')
df_China = _pixel_frame(ds_mean)
df_Arid = _pixel_frame(ds_mean.where(mask_region == 1))
df_Semi = _pixel_frame(ds_mean.where(mask_region == 2))
df_Humid = _pixel_frame(ds_mean.where(mask_region == 0))

In [7]:
def plot_curves(ax,y1,y2,ylim1):
    """Plot two series and their linear trend lines on one axis.

    Each series is drawn as a solid line with a dashed least-squares trend
    line in the same colour. Only the trend of ``y1`` is annotated: its slope
    multiplied by 10, followed by the significance marker from getSigFlag().

    Args:
        ax (matplotlib.axes.Axes): axis to draw on
        y1 (array-like): first series; must expose ``.shape``
        y2 (array-like): second series, same length as ``y1``
        ylim1: y-axis limits, passed straight to ``ax.set_ylim``

    Returns:
        matplotlib.axes.Axes: the axis that was passed in
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.stats` is loaded on older SciPy releases.
    from scipy import stats

    line_colors = ('#2A557F', '#45BC9C')
    x = np.arange(y1.shape[0])

    # First series and its trend; the p-value feeds the annotation below.
    ax.plot(x, y1, color=line_colors[0])
    parameter1 = np.polyfit(x, y1, 1)
    _, sig1 = stats.pearsonr(x, y1)
    ax.plot(x, np.poly1d(parameter1)(x), ls='--', c=line_colors[0])

    # Second series and its trend. The trend line must use the fit to y2;
    # previously it was built from parameter1 and so repeated the y1 trend.
    ax.plot(x, y2, color=line_colors[1])
    parameter2 = np.polyfit(x, y2, 1)
    ax.plot(x, np.poly1d(parameter2)(x), ls='--', c=line_colors[1])

    # Slope is scaled by 10 (per decade if the series is annual).
    ax.text(0.10,0.90, "slope = %.2f$^{%s}$" %
             (parameter1[0]*10, getSigFlag(sig1)),
             transform=ax.transAxes, c= 'k',fontsize=12)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Tick labels assume index 0 corresponds to the year 2001.
    ax.set_xticks([4,14],['2005','2015'])
    ax.set_ylim(ylim1)

    return ax

def plot_scatter(ax,x,y,xlabel,ylabel,title):
    """Scatter ``y`` against ``x`` with a linear fit line and an R^2 label.

    Args:
        ax (matplotlib.axes.Axes): axis to draw on
        x (array-like): values for the horizontal axis
        y (array-like): values for the vertical axis, same length as ``x``
        xlabel (str): x-axis label
        ylabel (str): y-axis label
        title (str): panel title

    Returns:
        matplotlib.axes.Axes: the axis that was passed in
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.stats` is loaded on older SciPy releases.
    from scipy import stats

    ax.scatter(x, y, color='grey', alpha=0.3, s=0.8)

    # Least-squares line, drawn across the observed x range only.
    fit_line = np.poly1d(np.polyfit(x, y, 1))
    r, sig = stats.pearsonr(x, y)
    line_x = np.linspace(np.min(x), np.max(x), 20)
    ax.plot(line_x, fit_line(line_x), ls='--', c='k')

    # Squared Pearson r with the significance marker as a superscript.
    ax.text(0.10,0.9, "R$^{2}$=%.2f$^{%s}$" %
             (r*r,getSigFlag(sig)),
             transform=ax.transAxes, c= 'k',fontsize=12)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    return ax

In [None]:
# Four side-by-side panels: the whole domain, then each climate zone.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3), sharex=False)
plot_settings()  # project helper from plot_utils; its effect is not visible here
ylim1 = [0, 15]
ylim2 = [-0.19, 0.79]  # neither limit is used by the scatter panels below

# (pixel table, y-label, title) per panel; only the first panel has a y-label.
scatter_panels = (
    (df_China, 'MODIS [W/m$^{2}$]', 'China'),
    (df_Arid, '', 'Arid'),
    (df_Semi, '', 'Semi-arid/humid'),
    (df_Humid, '', 'Humid'),
)
for panel_ax, (panel_df, panel_ylabel, panel_title) in zip(axes, scatter_panels):
    plot_scatter(
        panel_ax,
        panel_df['PMLv2'],
        panel_df['MODIS'],
        'PMLv2 [W/m$^{2}$]',
        panel_ylabel,
        panel_title,
    )

plt.tight_layout()
plt.savefig('../figures/Figure_S16.png', dpi=300)