In [None]:
import numpy as np
import xarray as xa
import xcdat as xc
from matplotlib import pyplot as plt
import cartopy
import pickle
from matplotlib import colors
import pandas as pd
import matplotlib as mpl
mpl.rc('font', family='DejaVu Serif') 
import seaborn as sns
from matplotlib.gridspec import GridSpec

In [None]:
# Load the regridded GPCP files, one per year for 1983-2020, and stack them
# along the time axis.
gpcp_yearly = [
    xc.open_dataset(f'/p/lustre3/shiduan/GPCP/regrid/{year}.nc')
    for year in range(1983, 2021)
]
gpcp_ds = xa.concat(gpcp_yearly, dim='time')
print(gpcp_ds.time)
# Keep only the data variable, ordered (time, lon, lat); missing values become 0.
gpcp = (
    gpcp_ds["__xarray_dataarray_variable__"]
    .transpose('time', 'lon', 'lat')
    .fillna(0)
)


# Precipitation relative to the per-longitude median / mean

In [None]:
# One panel per calendar month: grid cells whose time-mean GPCP value exceeds
# the median taken over latitude at the same longitude.
fig = plt.figure(figsize=(12, 6))
for month in range(1, 13):
    month_clim = gpcp.sel(time=gpcp.time.dt.month == month).mean(dim='time')
    lat_median = month_clim.median(dim=['lat'])
    above_median = (month_clim > lat_median).transpose('lat', 'lon')

    ax = fig.add_subplot(
        3, 4, month,
        projection=cartopy.crs.Robinson(central_longitude=180))
    filled = ax.contourf(
        above_median.lon, above_median.lat, above_median,
        transform=cartopy.crs.PlateCarree(), cmap='BrBG')
    ax.add_feature(cartopy.feature.COASTLINE)
    plt.colorbar(filled, ax=ax, shrink=.5)
    ax.set_title(str(month))
plt.tight_layout()
plt.show()

In [None]:
# Same twelve-panel layout, but the threshold is the mean over latitude at each
# longitude and the maps use a plate carree projection.
fig = plt.figure(figsize=(12, 6))
for month in range(1, 13):
    month_clim = gpcp.sel(time=gpcp.time.dt.month == month).mean(dim='time')
    lat_mean = month_clim.mean(dim=['lat'])
    above_mean = (month_clim > lat_mean).transpose('lat', 'lon')

    ax = fig.add_subplot(
        3, 4, month,
        projection=cartopy.crs.PlateCarree(central_longitude=180))
    filled = ax.contourf(
        above_mean.lon, above_mean.lat, above_mean,
        transform=cartopy.crs.PlateCarree(), cmap='BrBG')
    ax.add_feature(cartopy.feature.COASTLINE)
    plt.colorbar(filled, ax=ax, shrink=.5)
    ax.set_title(str(month))
plt.tight_layout()
plt.show()

In [None]:
# Load two pickled solver lists for the chosen variable. The two file names
# differ only in the "month-False" / "month-True" flag.
variable = 'pr'
eof_start = 1979
eof_dir = f'/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/{eof_start}_2022/'
with open(f'{eof_dir}{variable}-solver-stand-True-month-False-unforced-False-joint-False', 'rb') as pfile:
    solver_list_stand = pd.read_pickle(pfile)
with open(f'{eof_dir}{variable}-solver-stand-True-month-True-unforced-False-joint-False', 'rb') as pfile:
    solver_list_month_stand = pd.read_pickle(pfile)

In [None]:
# Leading EOF (mode 0) of the first solver in each list, laid out as (lat, lon).
finger = solver_list_stand[0].eofs().isel(mode=0).transpose('lat', 'lon')
finger_jan = solver_list_month_stand[0].eofs().isel(mode=0).transpose('lat', 'lon')

In [None]:
# Stacked bar per calendar month: fraction of grid cells in each combination of
# EOF sign (+1 / -1) and wet/dry mask (+1 / -1).
all_months = []
for i in range(12):
    eof_finger = solver_list_month_stand[i].eofs().isel(mode=0).transpose('lat', 'lon')
    eof_sign = xa.where(eof_finger>0, 1, -1)
    # Wet (+1) where the time-median field exceeds its mean over latitude at
    # the same longitude, dry (-1) elsewhere.
    gpcp_month = gpcp.sel(time=gpcp.time.dt.month==i+1).median(dim='time')
    gpcp_month_longitude = gpcp_month.mean(dim=['lat'])
    mask = (gpcp_month>gpcp_month_longitude).transpose('lat', 'lon')
    mask = xa.where(mask>0, 1, -1)
    total = mask.shape[0]*mask.shape[1]
    eof_up, eof_down = eof_sign == 1, eof_sign == -1
    is_wet, is_dry = mask == 1, mask == -1
    ww = (eof_up & is_wet).sum()/(total)
    dd = (eof_down & is_dry).sum()/(total)
    dw = (eof_up & is_dry).sum()/(total)
    wd = (eof_down & is_wet).sum()/(total)
    wwdd = ww+dd
    # Stack the four classes bottom-up; only the first month carries legend labels.
    base = 0
    for share, colour, name in (
        (ww, 'green', 'wet-wetter'),
        (dd, 'brown', 'dry-drier'),
        (wd, 'blue', 'wet-drier'),
        (dw, 'red', 'dry-wetter'),
    ):
        plt.bar(i+1, share, bottom=base, color=colour, label=name if i==0 else '')
        base = base + share
plt.legend()
plt.show()

In [None]:
# Classification maps for solver indices 0, 3, 6 and 9 (calendar months 1, 4, 7, 10).
fig = plt.figure(figsize=(12, 5))
for ind, i in enumerate([0, 3, 6, 9]):
    # +1 where the leading EOF is positive, -1 elsewhere.
    eof_finger = solver_list_month_stand[i].eofs().isel(mode=0).transpose('lat', 'lon')
    eof_sign = xa.where(eof_finger>0, 1, -1)
    # +1 where the time-median GPCP field exceeds its mean over latitude at the
    # same longitude, -1 elsewhere.
    gpcp_month = gpcp.sel(time=gpcp.time.dt.month==i+1).median(dim='time')
    gpcp_month_longitude = gpcp_month.mean(dim=['lat'])
    mask = (gpcp_month>gpcp_month_longitude).transpose('lat', 'lon')
    mask = xa.where(mask>0, 1, -1)
    eof_up, eof_down = eof_sign == 1, eof_sign == -1
    is_wet, is_dry = mask == 1, mask == -1
    ww = eof_up & is_wet
    dd = eof_down & is_dry
    dw = eof_up & is_dry
    wd = eof_down & is_wet
    # Encode the classes as 1 (ww), -1 (dd), 0.5 (wd), -0.5 (dw), 0 otherwise.
    # Built from the lowest-priority class upwards.
    classification = xa.where(dw, -0.5, 0)
    classification = xa.where(wd, 0.5, classification)
    classification = xa.where(dd, -1, classification)
    classification = xa.where(ww, 1, classification)

    ax = fig.add_subplot(
        2, 2, ind+1, projection=cartopy.crs.PlateCarree(central_longitude=180))
    c = ax.contourf(
        mask.lon, mask.lat, classification,
        transform=cartopy.crs.PlateCarree(),
        cmap='BrBG', levels=[-1, -0.5, 0, 0.5, 1])
    ax.add_feature(cartopy.feature.COASTLINE)
    # One tick in the middle of each colour band, labelled with the class name
    # and its share of the 72*144 grid cells.
    cbar = plt.colorbar(
        c, ax=ax, orientation='vertical', shrink=0.5,
        pad=0.05, ticks=[-.75, -0.25, 0.25, .75])
    cbar.ax.set_yticklabels([
        f"{tag} {cells.sum().data / (72*144) * 100:.2f}%"
        for tag, cells in (
            ("dd (dry-drier)", dd),
            ("dw (dry-wetter)", dw),
            ("wd (wet-drier)", wd),
            ("ww (wet-wetter)", ww),
        )
    ], fontsize=12)
    ax.set_title('Month: '+str(i+1))
plt.tight_layout()
plt.show()


In [None]:
from matplotlib.colors import ListedColormap, BoundaryNorm

# Numeric codes used in the classification maps and their display names.
class_values = [-1, -0.5, 0.5, 1]
class_labels = ['dry-drier', 'dry-wetter', 'wet-drier', 'wet-wetter']

# Four colours from the BrBG palette, assigned to the class codes in order.
palette = sns.color_palette("BrBG", n_colors=4)
color_dict = dict(zip(class_values, palette))

# Four colour bins. Note the edges are not symmetric about zero: the third bin
# spans -0.25 to 0.75.
levels = [-1.5, -0.75, -0.25, 0.75, 1.5]
cmap = ListedColormap(list(color_dict.values()))
norm = BoundaryNorm(levels, ncolors=cmap.N)

# Land

In [None]:
# Build a mask from variable `p` of the REGEN mask file: 1 where `p` has a
# value, NaN where it is missing.
maskfile = "/p/lustre1/shiduan/REGEN/REGEN_mask_forcesmip.nc"
missing_data_maskx = xa.open_dataset(maskfile)
p_is_missing = np.isnan(missing_data_maskx.p)
# xarray version (keeps coordinates) ...
missing_xa = xa.where(p_is_missing, np.nan, 1)
# ... and a plain NumPy version laid out as (lon, lat).
missing_data = np.where(p_is_missing.transpose('lon', 'lat'), np.nan, 1)
missing_xa.shape

In [None]:
# Per-month solver list from the file tagged "REGEN-mask".
land_solver_path = (
    f'/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/{eof_start}_2022/'
    f'{variable}-solver-stand-True-month-True-unforced-False-joint-False-REGEN-mask'
)
with open(land_solver_path, 'rb') as pfile:
    solver_list_month_stand_land = pd.read_pickle(pfile)

In [None]:
# Sum of the mask. Its entries are 1 or NaN, so (with xarray's default NaN
# skipping) this is the number of cells that have data.
missing_xa.sum()

In [None]:
# Quick-look plot of the mask (1 where data exist, NaN elsewhere).
missing_xa.plot()

## Longitude
* global longitude + land mask

In [None]:
# Land-only classification maps for solver indices 0, 3, 6 and 9 (calendar
# months 1, 4, 7, 10), using the EOFs from the "REGEN-mask" solver list and the
# GPCP wet/dry mask restricted to cells where missing_xa is 1.
#
# Denominator for the percentage labels.
# NOTE(review): 2203 is presumably the number of valid cells in missing_xa
# (see missing_xa.sum()) -- confirm before reusing with another mask.
n_land_cells = 2203
fig = plt.figure(figsize=(12, 5))
for ind, i in enumerate([0, 3, 6, 9]):
    # +1 where the leading EOF is positive, -1 elsewhere.
    eof_finger = solver_list_month_stand_land[i].eofs().isel(mode=0)
    eof_finger = eof_finger.transpose('lat', 'lon')
    eof_sign = xa.where(eof_finger>0, 1, -1)
    # +1 where the time-mean GPCP field exceeds its median over latitude at the
    # same longitude, -1 elsewhere; NaN outside the land mask.
    gpcp_month = gpcp.sel(time=gpcp.time.dt.month==i+1).mean(dim='time')
    gpcp_month_longitude = gpcp_month.median(dim=['lat'])
    mask = gpcp_month>gpcp_month_longitude
    mask = mask.transpose('lat', 'lon')
    mask = xa.where(mask>0, 1, -1)*missing_xa
    # The mask already carries NaN outside land, so the comparisons below are
    # False there; multiplying by missing_xa a second time is unnecessary.
    ww = (eof_sign == 1) & (mask == 1)
    dd = (eof_sign == -1) & (mask == -1)
    dw = (eof_sign == 1) & (mask == -1)
    wd = (eof_sign == -1) & (mask == 1)
    # Class codes: 1 (ww), -1 (dd), 0.5 (wd), -0.5 (dw); NaN outside the land mask.
    classification = xa.where(ww, 1,
                    xa.where(dd, -1,
                    xa.where(wd, 0.5,
                    xa.where(dw, -0.5, 0))))*missing_xa

    ax = fig.add_subplot(2, 2, ind+1, projection=cartopy.crs.PlateCarree(central_longitude=180))
    c = ax.contourf(mask.lon, mask.lat, classification, transform=cartopy.crs.PlateCarree(),
                    cmap='BrBG', levels=[-1, -0.5, 0, 0.5, 1])
    ax.add_feature(cartopy.feature.COASTLINE)
    # One tick in the middle of each colour band.
    cbar = plt.colorbar(c, ax=ax, orientation='vertical', shrink=0.3,
                        pad=0.05, ticks=[-.75, -0.25, 0.25, .75])
    cbar.set_label('Classification')
    # Class names match the other figures in this notebook ("dry-drier", not
    # "dry-dry"): first word is the climatological state, second the EOF sign.
    cbar.ax.set_yticklabels([
        f"dd (dry-drier) {dd.sum().data / n_land_cells * 100:.2f}%",
        f"dw (dry-wetter) {dw.sum().data / n_land_cells * 100:.2f}%",
        f"wd (wet-drier) {wd.sum().data / n_land_cells * 100:.2f}%",
        f"ww (wet-wetter) {ww.sum().data / n_land_cells * 100:.2f}%"
    ])
    ax.set_title('Month: '+str(i+1))
plt.tight_layout()
plt.show()


# Combine

# Land vs. global EOF sign

In [None]:
# Sign of the leading EOF for one calendar month from the land-only solver list
# (eof_sign1) and from the global solver list (eof_sign2), both restricted to
# the cells where missing_xa is 1.
#
# The month index is now explicit. This cell used to index with `i`, i.e. with
# whatever value a loop in a previously executed cell happened to leave behind
# (9 when the notebook is run top to bottom).
compare_month_index = 9
eof_finger1 = solver_list_month_stand_land[compare_month_index].eofs().isel(mode=0)
eof_finger1 = eof_finger1.transpose('lat', 'lon')
eof_sign1 = xa.where(eof_finger1>0, 1, -1)*missing_xa
eof_finger2 = solver_list_month_stand[compare_month_index].eofs().isel(mode=0)
eof_finger2 = eof_finger2.transpose('lat', 'lon')
eof_sign2 = xa.where(eof_finger2>0, 1, -1)*missing_xa

In [None]:
# Difference of the two sign maps: 0 where global and land-only EOFs agree in
# sign, +2 or -2 where they disagree, NaN outside the mask.
(eof_sign2-eof_sign1).plot()

# Combine into one figure

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Combined GPCP figure, saved as DDWW_grid_ONE.png.
#   a) rows 0-1: classification maps for Jan/Apr/Jul/Oct from the global solver
#      list, each with a narrow strip on its left showing the zonal-mean
#      fraction of every class;
#   b) row 2: stacked monthly bar charts of class percentages, global solver
#      list on the left, land solver list restricted to missing_xa on the right.
fig = plt.figure(figsize=(15, 9))
gs = GridSpec(nrows=3, ncols=3, figure=fig, height_ratios=[1, 1, .5], width_ratios=[1, .1, 1], wspace=0.2)
months = ['Jan', 'Apr', 'Jul', 'Oct']
# (color_dict key, legend label) in the order the classes are stacked in the bars.
bar_classes = [(1, 'wet-wetter'), (-1, 'dry-drier'), (0.5, 'wet-drier'), (-0.5, 'dry-wetter')]

# a) Maps (global EOFs)
for ind, i in enumerate([0, 3, 6, 9]):
    # +1 where the leading EOF is positive, -1 elsewhere.
    eof_finger = solver_list_month_stand[i].eofs().isel(mode=0)
    eof_finger = eof_finger.transpose('lat', 'lon')
    eof_sign = xa.where(eof_finger>0, 1, -1)
    # +1 where the time-mean GPCP field exceeds its median over latitude at the
    # same longitude, -1 elsewhere.
    gpcp_month = gpcp.sel(time=gpcp.time.dt.month==i+1).mean(dim='time')
    gpcp_month_longitude = gpcp_month.median(dim=['lat'])
    mask = gpcp_month>gpcp_month_longitude
    mask = mask.transpose('lat', 'lon')
    mask = xa.where(mask>0, 1, -1)
    ww = (eof_sign == 1) & (mask == 1)
    dd = (eof_sign == -1) & (mask == -1)
    dw = (eof_sign == 1) & (mask == -1)
    wd = (eof_sign == -1) & (mask == 1)
    # Class codes: 1 (ww), -1 (dd), 0.5 (wd), -0.5 (dw).
    classification = xa.where(ww, 1,
                    xa.where(dd, -1,
                    xa.where(wd, 0.5,
                    xa.where(dw, -0.5, 0))))
    # Two maps per row, placed in the outer grid columns (0 and 2).
    # The row used to be computed as (ind+1)//2, which is wrong for ind=1 and
    # ind=3; it went unnoticed because the subplot call recomputed ind//2.
    row = ind//2
    column = 2*(ind%2)
    print(row, column)
    ax = fig.add_subplot(gs[row, column], projection=cartopy.crs.Robinson(central_longitude=180))

    c = ax.contourf(mask.lon, mask.lat, classification, transform=cartopy.crs.PlateCarree(),
                    cmap=cmap, levels=levels)
    ax.add_feature(cartopy.feature.COASTLINE)
    # Colour bar with one tick per class.
    cbar = fig.colorbar(c, ax=ax, orientation='vertical', shrink=0.7,
                        pad=0.02, ticks=[-1.1, -0.5, .25, 1.1])
    # Same class flags restricted to the land mask (NaN outside it, skipped by sum()).
    dd_land = dd*missing_xa
    ww_land = ww*missing_xa
    dw_land = dw*missing_xa
    wd_land = wd*missing_xa
    # Tick labels: class name, then "share of all 72*144 cells | share of the
    # 2203 land cells".
    cbar.ax.set_yticklabels([
        f"dry-drier\n{dd.sum().data / (72*144) * 100:.2f}%|{dd_land.sum().data / (2203) * 100:.2f}%",
        f"dry-wetter\n{dw.sum().data / (72*144) * 100:.2f}%|{dw_land.sum().data / (2203) * 100:.2f}%",
        f"wet-drier\n{wd.sum().data / (72*144) * 100:.2f}%|{wd_land.sum().data / (2203) * 100:.2f}%",
        f"wet-wetter\n{ww.sum().data / (72*144) * 100:.2f}%|{ww_land.sum().data / (2203) * 100:.2f}%"
    ], fontsize=10)
    ax.set_title(months[ind], fontsize=12)
    pos = ax.get_position()
    zonal_ax = fig.add_axes([
        pos.x0 - 0.04,  # shifted slightly to the left of the map
        pos.y0,
        0.02,           # strip width (figure fraction), chosen by eye
        pos.height
    ])
    # Fraction of each class along every latitude circle, stacked horizontally.
    dd_zonal = dd.mean(dim='lon')
    ww_zonal = ww.mean(dim='lon')
    dw_zonal = dw.mean(dim='lon')
    wd_zonal = wd.mean(dim='lon')
    zonal_ax.fill_betweenx(ww_zonal.lat, x1=0, x2=ww_zonal, color=color_dict[1])
    zonal_ax.fill_betweenx(dd_zonal.lat, x1=ww_zonal, x2=dd_zonal+ww_zonal, color=color_dict[-1])
    zonal_ax.fill_betweenx(wd_zonal.lat, x1=dd_zonal+ww_zonal, x2=wd_zonal+dd_zonal+ww_zonal, color=color_dict[.5])
    zonal_ax.fill_betweenx(dw_zonal.lat, x1=wd_zonal+dd_zonal+ww_zonal, x2=dw_zonal+wd_zonal+dd_zonal+ww_zonal, color=color_dict[-.5])
    # Address the strip directly instead of relying on pyplot's current axes.
    zonal_ax.invert_xaxis()
    zonal_ax.set_ylabel('Latitude', fontsize=11)
    zonal_ax.set_yticks(np.arange(-90, 91, 30))
    zonal_ax.set_aspect('auto')

# b) Monthly bars, global EOFs
ax = fig.add_subplot(gs[2, 0])
all_months = []
for i in range(12):
    eof_finger = solver_list_month_stand[i].eofs().isel(mode=0)
    eof_finger = eof_finger.transpose('lat', 'lon')
    eof_sign = xa.where(eof_finger>0, 1, -1)
    gpcp_month = gpcp.sel(time=gpcp.time.dt.month==i+1).mean(dim='time')
    gpcp_month_longitude = gpcp_month.median(dim=['lat'])
    mask = gpcp_month>gpcp_month_longitude
    mask = mask.transpose('lat', 'lon')
    mask = xa.where(mask>0, 1, -1)
    total = mask.shape[0]*mask.shape[1]
    ww = ((eof_sign == 1) & (mask == 1)).sum()/(total)*100
    dd = ((eof_sign == -1) & (mask == -1)).sum()/(total)*100
    dw = ((eof_sign == 1) & (mask == -1)).sum()/(total)*100
    wd = ((eof_sign == -1) & (mask == 1)).sum()/(total)*100
    wwdd = ww+dd
    # Stack the classes bottom-up; only the first month carries legend labels.
    bottom = 0
    for (key, label), share in zip(bar_classes, (ww, dd, wd, dw)):
        ax.bar(i+1, share, bottom=bottom, color=color_dict[key], label=label if i==0 else '')
        bottom = bottom + share
    all_months.append(wwdd)
ax.set_xlabel('Month', fontsize=11)
ax.set_ylabel('Percentage', fontsize=11)
ax.set_xticks([1, 4, 7, 10])
ax.set_xticklabels(months)
# The title reports the spread across months of the ww+dd percentage.
ax.set_title(f'Global (std={np.std(all_months):.2f})', fontsize=12)
ax.legend(loc='upper right', bbox_to_anchor=(1.28, 1.01))

# b) Monthly bars, land EOFs restricted to the land mask
ax = fig.add_subplot(gs[2, 2])
all_months = []
for i in range(12):
    eof_finger = solver_list_month_stand_land[i].eofs().isel(mode=0)
    eof_finger = eof_finger.transpose('lat', 'lon')
    eof_sign = xa.where(eof_finger>0, 1, -1)*missing_xa
    gpcp_month = gpcp.sel(time=gpcp.time.dt.month==i+1).mean(dim='time')
    gpcp_month_longitude = gpcp_month.median(dim=['lat'])
    mask = gpcp_month>gpcp_month_longitude
    mask = mask.transpose('lat', 'lon')
    mask = xa.where(mask>0, 1, -1)*missing_xa
    # NOTE(review): 2203 is presumably the number of valid cells in missing_xa -- confirm.
    total = 2203
    ww = ((eof_sign == 1) & (mask == 1)).sum()/(total)*100
    dd = ((eof_sign == -1) & (mask == -1)).sum()/(total)*100
    dw = ((eof_sign == 1) & (mask == -1)).sum()/(total)*100
    wd = ((eof_sign == -1) & (mask == 1)).sum()/(total)*100
    wwdd = ww+dd
    bottom = 0
    for (key, label), share in zip(bar_classes, (ww, dd, wd, dw)):
        ax.bar(i+1, share, bottom=bottom, color=color_dict[key], label=label if i==0 else '')
        bottom = bottom + share
    all_months.append(wwdd)
ax.set_xlabel('Month', fontsize=11)
ax.set_ylabel('Percentage', fontsize=11)
ax.set_xticks([1, 4, 7, 10])
ax.set_xticklabels(months)
ax.set_title(f'Land (std={np.std(all_months):.2f})', fontsize=12)
ax.legend(loc='upper right', bbox_to_anchor=(1.28, 1.01))

# Panel letters, positioned in figure-fraction coordinates.
plt.annotate('a)', xy=(0.04, 0.825), xycoords='figure fraction', ha='center', fontsize=14, weight='bold')
plt.annotate('b)', xy=(0.04, 0.22), xycoords='figure fraction', ha='center', fontsize=14, weight='bold')
fig.savefig('DDWW_grid_ONE.png', dpi=150, bbox_inches='tight')

# MSWEP

In [None]:
# Load MSWEP monthly precipitation, fill missing values with 0, keep the time
# range selected by slice('1983-01-01', '2021-01-01'), reorder the data variable
# to (time, lon, lat), add any missing bounds and keep the data variable only.
mswep_var = "__xarray_dataarray_variable__"
mswep_ds = xc.open_dataset(
    '/p/lustre3/shiduan/MSWEP/MSWEP-V280-Past-v20231102-monpr-forcesmip.nc')
mswep_ds = mswep_ds.fillna(0).sel(time=slice('1983-01-01', '2021-01-01'))
mswep_ds[mswep_var] = mswep_ds[mswep_var].transpose('time', 'lon', 'lat')
mswep = mswep_ds.bounds.add_missing_bounds()[mswep_var]

In [None]:
# Combined MSWEP figure, saved as DDWW_grid_ONE_MSWEP.png. Same layout as the
# GPCP figure: four classification maps with zonal strips on top, two stacked
# monthly bar charts (global / land) at the bottom.
fig = plt.figure(figsize=(15, 9))
gs = GridSpec(
    nrows=3, ncols=3, figure=fig,
    height_ratios=[1, 1, .5], width_ratios=[1, .1, 1], wspace=0.2)
months = ['Jan', 'Apr', 'Jul', 'Oct']

# ---- maps (global solver list) ----
for ind, i in enumerate([0, 3, 6, 9]):
    # +1 where the leading EOF is positive, -1 elsewhere.
    eof_finger = solver_list_month_stand[i].eofs().isel(mode=0).transpose('lat', 'lon')
    eof_sign = xa.where(eof_finger>0, 1, -1)
    # +1 where the time-mean MSWEP field exceeds its median over latitude at
    # the same longitude, -1 elsewhere.
    mswep_month = mswep.sel(time=mswep.time.dt.month==i+1).mean(dim='time')
    mswep_month_longitude = mswep_month.median(dim=['lat'])
    mask = (mswep_month>mswep_month_longitude).transpose('lat', 'lon')
    mask = xa.where(mask>0, 1, -1)
    eof_up, eof_down = eof_sign == 1, eof_sign == -1
    is_wet, is_dry = mask == 1, mask == -1
    ww = eof_up & is_wet
    dd = eof_down & is_dry
    dw = eof_up & is_dry
    wd = eof_down & is_wet
    # Class codes 1 (ww), -1 (dd), 0.5 (wd), -0.5 (dw), built from the
    # lowest-priority class upwards.
    classification = xa.where(dw, -0.5, 0)
    classification = xa.where(wd, 0.5, classification)
    classification = xa.where(dd, -1, classification)
    classification = xa.where(ww, 1, classification)
    # Two maps per row in the outer grid columns (0 and 2).
    row, column = divmod(ind, 2)
    column = column*2
    print(row, column)
    ax = fig.add_subplot(
        gs[row, column], projection=cartopy.crs.Robinson(central_longitude=180))

    c = ax.contourf(
        mask.lon, mask.lat, classification,
        transform=cartopy.crs.PlateCarree(), cmap=cmap, levels=levels)
    ax.add_feature(cartopy.feature.COASTLINE)
    cbar = plt.colorbar(
        c, ax=ax, orientation='vertical', shrink=0.7,
        pad=0.02, ticks=[-1.1, -0.5, .25, 1.1])
    # Class flags restricted to the land mask (NaN outside it).
    dd_land, ww_land = dd*missing_xa, ww*missing_xa
    dw_land, wd_land = dw*missing_xa, wd*missing_xa
    # Tick labels: class name, then "share of 72*144 cells | share of 2203 land cells".
    cbar.ax.set_yticklabels([
        f"{name}\n{everywhere.sum().data / (72*144) * 100:.2f}%|{on_land.sum().data / (2203) * 100:.2f}%"
        for name, everywhere, on_land in (
            ('dry-drier', dd, dd_land),
            ('dry-wetter', dw, dw_land),
            ('wet-drier', wd, wd_land),
            ('wet-wetter', ww, ww_land),
        )
    ], fontsize=10)
    ax.set_title(months[ind], fontsize=12)
    # Narrow strip just left of the map for the zonal-mean class fractions.
    pos = ax.get_position()
    zonal_ax = fig.add_axes([pos.x0 - 0.04, pos.y0, 0.02, pos.height])
    dd_zonal = dd.mean(dim='lon')
    ww_zonal = ww.mean(dim='lon')
    dw_zonal = dw.mean(dim='lon')
    wd_zonal = wd.mean(dim='lon')
    for lat, left, right, key in (
        (ww_zonal.lat, 0, ww_zonal, 1),
        (dd_zonal.lat, ww_zonal, dd_zonal+ww_zonal, -1),
        (wd_zonal.lat, dd_zonal+ww_zonal, wd_zonal+dd_zonal+ww_zonal, .5),
        (dw_zonal.lat, wd_zonal+dd_zonal+ww_zonal, dw_zonal+wd_zonal+dd_zonal+ww_zonal, -.5),
    ):
        zonal_ax.fill_betweenx(lat, x1=left, x2=right, color=color_dict[key])
    zonal_ax.invert_xaxis()
    zonal_ax.set_yticks(np.arange(-90, 91, 30))
    zonal_ax.set_aspect('auto')
    zonal_ax.set_ylabel('Latitude', fontsize=11)

# ---- monthly bars, global solver list ----
ax = fig.add_subplot(gs[2, 0])
all_months = []
for i in range(12):
    eof_finger = solver_list_month_stand[i].eofs().isel(mode=0).transpose('lat', 'lon')
    eof_sign = xa.where(eof_finger>0, 1, -1)
    mswep_month = mswep.sel(time=mswep.time.dt.month==i+1).mean(dim='time')
    mswep_month_longitude = mswep_month.median(dim=['lat'])
    mask = (mswep_month>mswep_month_longitude).transpose('lat', 'lon')
    mask = xa.where(mask>0, 1, -1)
    total = mask.shape[0]*mask.shape[1]
    eof_up, eof_down = eof_sign == 1, eof_sign == -1
    is_wet, is_dry = mask == 1, mask == -1
    ww = (eof_up & is_wet).sum()/(total)*100
    dd = (eof_down & is_dry).sum()/(total)*100
    dw = (eof_up & is_dry).sum()/(total)*100
    wd = (eof_down & is_wet).sum()/(total)*100
    wwdd = ww+dd
    # Stack bottom-up; only the first month carries legend labels.
    base = 0
    for share, key, name in (
        (ww, 1, 'wet-wetter'),
        (dd, -1, 'dry-drier'),
        (wd, 0.5, 'wet-drier'),
        (dw, -0.5, 'dry-wetter'),
    ):
        ax.bar(i+1, share, bottom=base, color=color_dict[key], label=name if i==0 else '')
        base = base + share
    all_months.append(wwdd)
ax.set_xlabel('Month', fontsize=11)
ax.set_ylabel('Percentage', fontsize=11)
ax.set_xticks([1, 4, 7, 10])
ax.set_xticklabels(months)
ax.set_title(f'Global (std={np.std(all_months):.2f})', fontsize=12)
plt.legend(loc='upper right', bbox_to_anchor=(1.28, 1.01))

# ---- monthly bars, land solver list restricted to the land mask ----
ax = fig.add_subplot(gs[2, 2])
all_months = []
for i in range(12):
    eof_finger = solver_list_month_stand_land[i].eofs().isel(mode=0).transpose('lat', 'lon')
    eof_sign = xa.where(eof_finger>0, 1, -1)*missing_xa
    mswep_month = mswep.sel(time=mswep.time.dt.month==i+1).mean(dim='time')
    mswep_month_longitude = mswep_month.median(dim=['lat'])
    mask = (mswep_month>mswep_month_longitude).transpose('lat', 'lon')
    mask = xa.where(mask>0, 1, -1)*missing_xa
    total = 2203
    eof_up, eof_down = eof_sign == 1, eof_sign == -1
    is_wet, is_dry = mask == 1, mask == -1
    ww = (eof_up & is_wet).sum()/(total)*100
    dd = (eof_down & is_dry).sum()/(total)*100
    dw = (eof_up & is_dry).sum()/(total)*100
    wd = (eof_down & is_wet).sum()/(total)*100
    wwdd = ww+dd
    base = 0
    for share, key, name in (
        (ww, 1, 'wet-wetter'),
        (dd, -1, 'dry-drier'),
        (wd, 0.5, 'wet-drier'),
        (dw, -0.5, 'dry-wetter'),
    ):
        ax.bar(i+1, share, bottom=base, color=color_dict[key], label=name if i==0 else '')
        base = base + share
    all_months.append(wwdd)
ax.set_xlabel('Month', fontsize=11)
ax.set_ylabel('Percentage', fontsize=11)
ax.set_xticks([1, 4, 7, 10])
ax.set_xticklabels(months)
ax.set_title(f'Land (std={np.std(all_months):.2f})', fontsize=12)
plt.legend(loc='upper right', bbox_to_anchor=(1.28, 1.01))

# Panel letters in figure-fraction coordinates, then save.
for letter, y in (('a)', 0.825), ('b)', 0.22)):
    plt.annotate(letter, xy=(0.04, y), xycoords='figure fraction', ha='center', fontsize=14, weight='bold')
plt.savefig('DDWW_grid_ONE_MSWEP.png', dpi=150, bbox_inches='tight')

# Use model-specific EOFs

In [None]:
# Denominator for the land percentages.
# NOTE(review): 2203 is presumably the number of valid cells in missing_xa -- confirm.
_LAND_CELL_COUNT = 2203
# (color_dict key, legend label) in the order the classes are stacked in the bars.
_CLASS_STACK = ((1, 'wet-wetter'), (-1, 'dry-drier'), (0.5, 'wet-drier'), (-0.5, 'dry-wetter'))


def _oriented_eof_sign(solver, announce=False):
    """Return the (lat, lon) sign map (+1 / -1) of the solver's leading EOF.

    The EOF is negated first when a straight-line fit to its principal
    component has a negative slope; with ``announce`` set, 'reversed' is
    printed whenever that happens.
    """
    eof_finger = solver.eofs().isel(mode=0)
    pc = solver.pcs().isel(mode=0)
    m, b = np.polyfit(np.arange(pc.shape[0]), pc, deg=1)
    if m<0:
        eof_finger = -eof_finger
        if announce:
            print('reversed')
    eof_finger = eof_finger.transpose('lat', 'lon')
    return xa.where(eof_finger>0, 1, -1)


def _wet_dry_mask(precip, month):
    """Return a (lat, lon) map: +1 where the time-mean field for calendar month
    ``month`` exceeds its median over latitude at the same longitude, else -1."""
    precip_month = precip.sel(time=precip.time.dt.month==month).mean(dim='time')
    precip_month_longitude = precip_month.median(dim=['lat'])
    mask = precip_month>precip_month_longitude
    mask = mask.transpose('lat', 'lon')
    return xa.where(mask>0, 1, -1)


def _class_flags(eof_sign, mask):
    """Return boolean maps (ww, dd, dw, wd) for the four sign/mask combinations."""
    ww = (eof_sign == 1) & (mask == 1)
    dd = (eof_sign == -1) & (mask == -1)
    dw = (eof_sign == 1) & (mask == -1)
    wd = (eof_sign == -1) & (mask == 1)
    return ww, dd, dw, wd


def _plot_class_bars(ax, precip, solvers, land_only, announce):
    """Draw one stacked bar per calendar month on ``ax`` and return the twelve
    ww+dd percentages. With ``land_only`` both maps are restricted to
    missing_xa and percentages are relative to _LAND_CELL_COUNT."""
    agreement = []
    for i in range(12):
        eof_sign = _oriented_eof_sign(solvers[i], announce=announce)
        mask = _wet_dry_mask(precip, i+1)
        if land_only:
            eof_sign = eof_sign*missing_xa
            mask = mask*missing_xa
            total = _LAND_CELL_COUNT
        else:
            total = mask.shape[0]*mask.shape[1]
        ww, dd, dw, wd = (flags.sum()/(total)*100 for flags in _class_flags(eof_sign, mask))
        # Stack bottom-up; only the first month carries legend labels.
        bottom = 0
        for (key, label), share in zip(_CLASS_STACK, (ww, dd, wd, dw)):
            ax.bar(i+1, share, bottom=bottom, color=color_dict[key], label=label if i==0 else '')
            bottom = bottom + share
        agreement.append(ww+dd)
    return agreement


def plot_model_precip(fig, gs, precip, solver_list_month_stand, row_inc = 0):
    """Draw the classification maps and monthly bar charts for one precipitation data set.

    Three grid rows are filled, starting at ``row_inc``: two rows of maps
    (Jan/Apr, Jul/Oct) in grid columns 0 and 2, each with a narrow zonal-fraction
    strip to its left, and one row with the global (column 0) and land
    (column 2) stacked bar charts.

    Parameters
    ----------
    fig : matplotlib figure that receives all axes.
    gs : GridSpec indexed as ``gs[row, column]``; rows ``row_inc`` to
        ``row_inc + 2`` and columns 0 and 2 are used.
    precip : DataArray with ``time``, ``lat`` and ``lon`` dimensions.
    solver_list_month_stand : sequence of twelve objects providing ``eofs()``
        and ``pcs()``, one per calendar month.
    row_inc : first grid row to draw into (default 0).

    Notes
    -----
    Reads the notebook-level names ``cmap``, ``levels``, ``color_dict`` and
    ``missing_xa``. Prints the grid position of every map, 'reversed' whenever
    an EOF is sign-flipped (maps and global bars only), and the standard
    deviation across months of the ww+dd percentage for both bar charts.
    Returns None.
    """
    months = ['Jan', 'Apr', 'Jul', 'Oct']
    # Maps
    for ind, i in enumerate([0, 3, 6, 9]):
        eof_sign = _oriented_eof_sign(solver_list_month_stand[i], announce=True)
        mask = _wet_dry_mask(precip, i+1)
        ww, dd, dw, wd = _class_flags(eof_sign, mask)
        # Class codes: 1 (ww), -1 (dd), 0.5 (wd), -0.5 (dw).
        classification = xa.where(ww, 1,
                        xa.where(dd, -1,
                        xa.where(wd, 0.5,
                        xa.where(dw, -0.5, 0))))
        # Two maps per row in the outer grid columns (0 and 2).
        row = ind//2
        column = 2*(ind%2)
        print(row, column)
        ax = fig.add_subplot(gs[row+row_inc, column], projection=cartopy.crs.PlateCarree(central_longitude=180))

        c = ax.contourf(mask.lon, mask.lat, classification, transform=cartopy.crs.PlateCarree(),
                        cmap=cmap, levels=levels)
        ax.add_feature(cartopy.feature.COASTLINE)
        # Draw on the figure that was passed in rather than pyplot's current one.
        cbar = fig.colorbar(c, ax=ax, orientation='vertical', shrink=0.7,
                            pad=0.02, ticks=[-1.1, -0.5, .25, 1.1])
        # Global share uses the actual grid size (as the bar chart does) instead
        # of a hard-coded 72*144; land share uses the land cell count.
        n_cells = mask.shape[0]*mask.shape[1]
        cbar.ax.set_yticklabels([
            f"{name}\n{flags.sum().data / n_cells * 100:.2f}%|{(flags*missing_xa).sum().data / _LAND_CELL_COUNT * 100:.2f}%"
            for name, flags in (
                ('dry-drier', dd),
                ('dry-wetter', dw),
                ('wet-drier', wd),
                ('wet-wetter', ww),
            )
        ], fontsize=10)
        ax.set_title(months[ind], fontsize=12)
        pos = ax.get_position()
        zonal_ax = fig.add_axes([
            pos.x0 - 0.04,  # shifted slightly to the left of the map
            pos.y0,
            0.02,           # strip width (figure fraction), chosen by eye
            pos.height
        ])
        # Fraction of each class along every latitude circle, stacked horizontally.
        dd_zonal = dd.mean(dim='lon')
        ww_zonal = ww.mean(dim='lon')
        dw_zonal = dw.mean(dim='lon')
        wd_zonal = wd.mean(dim='lon')
        zonal_ax.fill_betweenx(ww_zonal.lat, x1=0, x2=ww_zonal, color=color_dict[1])
        zonal_ax.fill_betweenx(dd_zonal.lat, x1=ww_zonal, x2=dd_zonal+ww_zonal, color=color_dict[-1])
        zonal_ax.fill_betweenx(wd_zonal.lat, x1=dd_zonal+ww_zonal, x2=wd_zonal+dd_zonal+ww_zonal, color=color_dict[.5])
        zonal_ax.fill_betweenx(dw_zonal.lat, x1=wd_zonal+dd_zonal+ww_zonal, x2=dw_zonal+wd_zonal+dd_zonal+ww_zonal, color=color_dict[-.5])
        zonal_ax.invert_xaxis()
        zonal_ax.set_yticks(np.arange(-90, 91, 30))
        zonal_ax.set_aspect('auto')
        zonal_ax.set_ylabel('Latitude', fontsize=11)

    # Monthly bars, all grid cells
    ax = fig.add_subplot(gs[2+row_inc, 0])
    all_months = _plot_class_bars(ax, precip, solver_list_month_stand, land_only=False, announce=True)
    ax.set_xlabel('Month', fontsize=11)
    ax.set_ylabel('Percentage', fontsize=11)
    ax.set_xticks([1, 4, 7, 10])
    ax.set_xticklabels(months)
    ax.set_title(f'Global (std={np.std(all_months):.2f})', fontsize=12)
    print('Global std', np.std(all_months))

    # Monthly bars, land cells only (same solvers, restricted to missing_xa)
    ax = fig.add_subplot(gs[2+row_inc, 2])
    all_months = _plot_class_bars(ax, precip, solver_list_month_stand, land_only=True, announce=False)
    ax.set_xlabel('Month', fontsize=11)
    ax.set_ylabel('Percentage', fontsize=11)
    ax.set_xticks([1, 4, 7, 10])
    ax.set_xticklabels(months)
    ax.set_title(f'Land (std={np.std(all_months):.2f})', fontsize=12)
    ax.legend(loc='upper right', bbox_to_anchor=(1.28, 1.01))
    print('Land std', np.std(all_months))

In [None]:
def get_model(model, stand, month):
    """Load and return the pickled 'CMIP-metrics' object for one model.

    ``model``, ``stand`` and ``month`` are converted with ``str()`` and inserted
    into the file name; the period (1979_2022) and variable (pr) are fixed.

    NOTE(review): pickle.load executes arbitrary code from the file -- only use
    on files you trust.
    """
    path = (
        '/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/1979_2022/'
        f'model_{model!s}/pr-CMIP-metrics-stand-{stand!s}-month-{month!s}'
        '-unforced-False-joint-False'
    )
    with open(path, 'rb') as pfile:
        return pickle.load(pfile)

In [None]:
# CanESM5: load the per-month solver list and draw the GPCP block (grid rows
# 0-2) above the MSWEP block (grid rows 4-6); row 3 is a thin spacer.
model = 'CanESM5'
solver_path = (
    f'/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/1979_2022/model_{model}/'
    'pr-solver-stand-True-month-True-unforced-False-joint-False'
)
with open(solver_path, 'rb') as pfile:
    solver = pickle.load(pfile)
fig = plt.figure(figsize=(15, 19))
gs = GridSpec(
    nrows=7, ncols=3, figure=fig,
    height_ratios=[1, 1, .8, .01, 1, 1, .8], width_ratios=[1, .1, 1], wspace=0.2)
for precip_obs, first_row in ((gpcp, 0), (mswep, 4)):
    plot_model_precip(fig, gs, precip_obs, solver, row_inc=first_row)
for text, y in (('GPCP', 0.80), ('MSWEP', 0.395)):
    plt.annotate(text, xy=(0.05, y), xycoords='figure fraction', ha='center', fontsize=14, weight='bold')
plt.savefig(f'WWDD-{model}.png', dpi=180, bbox_inches='tight')

In [None]:
# CESM2: load the per-month solver list and draw the GPCP block (grid rows
# 0-2) above the MSWEP block (grid rows 4-6); row 3 is a thin spacer.
model = 'CESM2'
solver_path = (
    f'/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/1979_2022/model_{model}/'
    'pr-solver-stand-True-month-True-unforced-False-joint-False'
)
with open(solver_path, 'rb') as pfile:
    solver = pickle.load(pfile)
fig = plt.figure(figsize=(15, 19))
gs = GridSpec(
    nrows=7, ncols=3, figure=fig,
    height_ratios=[1, 1, .8, .01, 1, 1, .8], width_ratios=[1, .1, 1], wspace=0.2)
for precip_obs, first_row in ((gpcp, 0), (mswep, 4)):
    plot_model_precip(fig, gs, precip_obs, solver, row_inc=first_row)
for text, y in (('GPCP', 0.80), ('MSWEP', 0.395)):
    plt.annotate(text, xy=(0.05, y), xycoords='figure fraction', ha='center', fontsize=14, weight='bold')
plt.savefig(f'WWDD-{model}.png', dpi=180, bbox_inches='tight')

In [None]:
# Inspect entry 4 of the currently loaded `solver` list: plot the leading
# principal component and report whether its linear trend is negative (the
# condition under which plot_model_precip flips the EOF sign).
checked_solver = solver[4]
pattern = checked_solver.eofs().isel(mode=0)
pc = checked_solver.pcs().isel(mode=0)
plt.plot(pc)
time_index = np.arange(pc.shape[0])
m, b = np.polyfit(time_index, pc, deg=1)
if not m >= 0 and m < 0:
    print('reversed')

In [None]:
# MIROC6: load the per-month solver list and draw the GPCP block (grid rows
# 0-2) above the MSWEP block (grid rows 4-6); row 3 is a thin spacer.
model = 'MIROC6'
solver_path = (
    f'/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/1979_2022/model_{model}/'
    'pr-solver-stand-True-month-True-unforced-False-joint-False'
)
with open(solver_path, 'rb') as pfile:
    solver = pickle.load(pfile)
fig = plt.figure(figsize=(15, 19))
gs = GridSpec(
    nrows=7, ncols=3, figure=fig,
    height_ratios=[1, 1, .8, .01, 1, 1, .8], width_ratios=[1, .1, 1], wspace=0.2)
for precip_obs, first_row in ((gpcp, 0), (mswep, 4)):
    plot_model_precip(fig, gs, precip_obs, solver, row_inc=first_row)
for text, y in (('GPCP', 0.80), ('MSWEP', 0.395)):
    plt.annotate(text, xy=(0.05, y), xycoords='figure fraction', ha='center', fontsize=14, weight='bold')
plt.savefig(f'WWDD-{model}.png', dpi=180, bbox_inches='tight')

In [None]:
# MPI-ESM1-2-LR: load the per-month solver list and draw the GPCP block (grid
# rows 0-2) above the MSWEP block (grid rows 4-6); row 3 is a thin spacer.
model = 'MPI-ESM1-2-LR'
solver_path = (
    f'/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/1979_2022/model_{model}/'
    'pr-solver-stand-True-month-True-unforced-False-joint-False'
)
with open(solver_path, 'rb') as pfile:
    solver = pickle.load(pfile)
fig = plt.figure(figsize=(15, 19))
gs = GridSpec(
    nrows=7, ncols=3, figure=fig,
    height_ratios=[1, 1, .8, .01, 1, 1, .8], width_ratios=[1, .1, 1], wspace=0.2)
for precip_obs, first_row in ((gpcp, 0), (mswep, 4)):
    plot_model_precip(fig, gs, precip_obs, solver, row_inc=first_row)
for text, y in (('GPCP', 0.80), ('MSWEP', 0.395)):
    plt.annotate(text, xy=(0.05, y), xycoords='figure fraction', ha='center', fontsize=14, weight='bold')
plt.savefig(f'WWDD-{model}.png', dpi=180, bbox_inches='tight')

In [None]:
# MIROC-ES2L: load the per-month solver list and draw the GPCP block (grid
# rows 0-2) above the MSWEP block (grid rows 4-6); row 3 is a thin spacer.
model = 'MIROC-ES2L'
solver_path = (
    f'/p/lustre2/shiduan/ForceSMIP/EOF/modes_all/1979_2022/model_{model}/'
    'pr-solver-stand-True-month-True-unforced-False-joint-False'
)
with open(solver_path, 'rb') as pfile:
    solver = pickle.load(pfile)
fig = plt.figure(figsize=(15, 19))
gs = GridSpec(
    nrows=7, ncols=3, figure=fig,
    height_ratios=[1, 1, .8, .01, 1, 1, .8], width_ratios=[1, .1, 1], wspace=0.2)
for precip_obs, first_row in ((gpcp, 0), (mswep, 4)):
    plot_model_precip(fig, gs, precip_obs, solver, row_inc=first_row)
for text, y in (('GPCP', 0.80), ('MSWEP', 0.395)):
    plt.annotate(text, xy=(0.05, y), xycoords='figure fraction', ha='center', fontsize=14, weight='bold')
plt.savefig(f'WWDD-{model}.png', dpi=180, bbox_inches='tight')