In [1]:
import numpy as np
import xarray as xr
import regionmask
import geopandas as gpd
import matplotlib.pyplot as plt
import os
import sys
import scipy
sys.path.append('/home/climate/chaoz/code/utils/')
from plot_utils import plot_settings

In [2]:
# All relative paths below (data, shapefile, and the figure output in the last
# cell) are resolved against this working directory.
os.chdir('/home/climate/chaoz/project/03Irr_Ts_CN/processed/')


def _load_time_mean(fname):
    """Open a NetCDF file, average it over 'time', and close the file.

    ``.load()`` is called while the file is still open so that every variable
    of the result (including any lazily read non-index coordinates) is in
    memory before the handle is released.
    """
    with xr.open_dataset(fname) as raw:
        return raw.mean(dim='time').load()


# Observed day/night ΔLST, averaged over all years in the file
dLSTd = _load_time_mean('delta_LSTday_Yr_CN_2001_2020.nc')
dLSTn = _load_time_mean('delta_LSTnight_Yr_CN_2001_2020.nc')

# ΔLST contributions attributed to latent heat, sensible heat and shortwave (PMLv2 files)
dLSTle = _load_time_mean('delta_LSTle_Yr_CN_2001_2020_PMLv2.nc')
dLSThe = _load_time_mean('delta_LSThe_Yr_CN_2001_2020_PMLv2.nc')
dLSTsw = _load_time_mean('delta_LSTsw_Yr_CN_2001_2020_PMLv2.nc')

# "Observed" ΔLST is the mean of day and night; "predicted" is the sum of the components
dLSTobs = (dLSTd + dLSTn) / 2
dLSTpre = dLSTle + dLSThe + dLSTsw

shp_climzone = gpd.read_file('../shapefile/ClimateZone_3.shp')

In [3]:
def stats_regionmean(ds,varname,shp):
    """Calculate the national and per-climate-zone mean time series.

    Means are plain (unweighted) averages over the ``lat``/``lon`` grid
    cells; no area weighting is applied.

    Args:
        ds (xarray.Dataset): time-series dataset with ``time``, ``lat`` and
            ``lon`` dimensions.
        varname (str): the variable name in the xarray dataset
        shp (geopandas.GeoDataFrame): shapefile loaded by
            geopandas.read_file(); rasterised onto the ``ds`` grid with
            regionmask. Mask codes 0/1/2 are labelled Humid/Arid/Semi.

    Returns:
        pandas.DataFrame: one row per time step, columns
        ``['time', 'China', 'Arid', 'Semi', 'Humid']``.
    """
    # National mean per time step, kept as a time-indexed Series so it can be
    # attached to the regional table by timestamp rather than by position.
    stat_cn = ds.mean(dim=('lat', 'lon'))[varname].to_series()

    mask_region = regionmask.mask_geopandas(shp, ds.lon, ds.lat)
    stat_region = ds.groupby(mask_region).mean().to_dataframe().reset_index()

    mask_mapping = {0.0: 'Humid', 1.0: 'Arid', 2.0: 'Semi'}
    stat_region['mask'] = stat_region['mask'].replace(mask_mapping)

    # One column per region. pivot_table sorts by time and (dropna=True by
    # default) may drop time steps whose regional means are all NaN, so its
    # rows are not guaranteed to match ds.time one-to-one.
    pivoted_df = stat_region.pivot_table(index='time', columns='mask', values=varname).reset_index()

    # Align the national means on the timestamps actually present.
    pivoted_df['China'] = stat_cn.reindex(pivoted_df['time']).to_numpy()

    # Re-order the columns
    return pivoted_df[['time', 'China', 'Arid', 'Semi', 'Humid']]

In [4]:
# Pair observed and predicted ΔLST on the observation grid. `.values` strips
# coordinates, so `pre` is placed on the obs lat/lon purely by array position.
# NOTE(review): assumes both Ts fields share the same grid ordered (lat, lon) — confirm.
ds = xr.Dataset(
    data_vars=dict(
        obs=(('lat', 'lon'), dLSTobs.Ts.values),
        pre=(('lat', 'lon'), dLSTpre.Ts.values),
    ),
    coords=dict(lat=dLSTobs.lat, lon=dLSTobs.lon),
)

In [5]:
mask_region = regionmask.mask_geopandas(shp_climzone, ds.lon, ds.lat)


def _valid_points(dataset):
    """Flatten a gridded dataset to a table of (obs, pre) rows without NaNs.

    The lat/lon columns are dropped; only the variable values are kept.
    """
    flat = dataset.to_dataframe().reset_index().dropna()
    return flat.drop(columns=['lat', 'lon'])


# All grid cells, then one table per climate-zone mask code
# (1 = Arid, 2 = Semi-arid/humid, 0 = Humid).
df_China = _valid_points(ds)
df_Arid, df_Semi, df_Humid = (
    _valid_points(ds.where(mask_region == code)) for code in (1, 2, 0)
)

In [6]:
def getSigFlag(p):
    """Map a p-value to the star marker used in figure annotations.

    Args:
        p (float): p-value of a statistical test.

    Returns:
        str: ``'**'`` if p < 0.05, ``'*'`` if 0.05 <= p < 0.1, else ``''``
        (NaN also yields ``''`` because every comparison with NaN is False).
    """
    for threshold, stars in ((0.05, '**'), (0.1, '*')):
        if p < threshold:
            return stars
    return ''


def plot_scatter(ax,x,y,ylim,xlabel,ylabel,title):
    """Scatter y against x with a least-squares line and an R² annotation.

    Args:
        ax (matplotlib Axes): axes to draw on.
        x (array-like): x values (NaNs are not handled here; the callers in
            this notebook drop them beforehand).
        y (array-like): y values paired with ``x``.
        ylim (sequence): (ymin, ymax) passed to ``ax.set_ylim``.
        xlabel (str): x-axis label.
        ylabel (str): y-axis label.
        title (str): axes title.

    Returns:
        The same ``ax``, for chaining.
    """
    # `import scipy` alone does not guarantee the stats submodule is loaded.
    from scipy import stats

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    ax.scatter(x, y, color='grey', alpha=0.3, s=0.8)

    # A fit line and Pearson correlation need at least two points; with fewer,
    # np.polyfit / pearsonr raise, so just draw the (empty) panel.
    if x.size >= 2:
        fit = np.poly1d(np.polyfit(x, y, 1))
        r, p_value = stats.pearsonr(x, y)
        x_line = np.linspace(x.min(), x.max(), 20)
        ax.plot(x_line, fit(x_line), ls='--', c='k')
        # R² of the linear fit (= r² for simple regression) plus significance stars
        ax.text(0.10, 0.90, "R$^{2}$=%.2f$^{%s}$" % (r * r, getSigFlag(p_value)),
                transform=ax.transAxes, c='k', fontsize=12)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    return ax

In [None]:
# Apply the project plot style before creating the figure, so that any
# rcParams it sets are picked up by the new axes.
# NOTE(review): assumes plot_settings() only adjusts global style — confirm.
plot_settings()
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))

ylim = [-8, 8]
xlabel, ylabel = 'Observation [K]', 'Prediction [K]'
panels = [
    (df_China, 'China'),
    (df_Arid, 'Arid'),
    (df_Semi, 'Semi-arid/humid'),
    (df_Humid, 'Humid'),
]
for i, (ax, (df_panel, title)) in enumerate(zip(axes, panels)):
    # Only the leftmost panel carries the y-axis label
    plot_scatter(ax, df_panel['obs'], df_panel['pre'], ylim, xlabel,
                 ylabel if i == 0 else '', title)

plt.tight_layout()
# Relative to the working directory set by os.chdir in the data-loading cell
plt.savefig('../figures/Figure_S17.png', dpi=300)