In [79]:
%matplotlib inline
import scipy.io as scio
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.feature as cfeature
import cartopy.crs as ccrs
import netCDF4 as nc
import numpy as np
from Load_data import Data_from_nc
import xarray as xr
import math
import gc
import shapefile

def load_mat_data(filepath, var_name, transpose=False, skip_first=29):
    """
    Load one variable from a MATLAB .mat file, optionally trimming and transposing it.

    Parameters
    ----------
    filepath : str
        Path to the .mat file.
    var_name : str
        Name of the variable stored in the file.
    transpose : bool, default False
        If True, return the transpose of the (trimmed) matrix.
    skip_first : int or None, default 29
        Number of leading COLUMNS of the stored matrix to drop. Trimming
        happens before the transpose, so with ``transpose=True`` these become
        the leading rows of the result. ``None`` disables trimming.

    Returns
    -------
    numpy.ndarray
    """
    mat = scio.loadmat(filepath)[var_name]
    if skip_first is not None:
        # Trim along axis 1 of the array as stored, i.e. before any transpose.
        mat = mat[:, skip_first:]
    return mat.T if transpose else mat

def load_nc_data(filepath, var_name, skip_lat=29, invalid_threshold=1e6):
    """
    Load and clean a variable from a NetCDF file.

    Parameters
    ----------
    filepath : str
        Path to the NetCDF file.
    var_name : str
        Name of the variable to read.
    skip_lat : int, default 29
        Number of leading entries dropped along the second-to-last axis
        (latitude for data laid out as (..., lat, lon)).
    invalid_threshold : float, default 1e6
        Values whose absolute value exceeds this are treated as missing.

    Returns
    -------
    numpy.ndarray
        Floating-point array in which masked entries and values above the
        threshold are NaN, with length-1 axes squeezed out.
    """
    with nc.Dataset(filepath) as ds:
        # With netCDF4 auto-masking (on by default) fill / out-of-range values
        # come back masked; np.ma.asarray also accepts a plain ndarray.
        raw = np.ma.asarray(ds.variables[var_name][:])
    # NaN needs a floating dtype: keep float32/float64 as-is, promote the rest.
    if not np.issubdtype(raw.dtype, np.floating):
        raw = raw.astype(np.float64)
    # Turn masked entries into NaN instead of exposing the raw fill values.
    data = np.ma.filled(raw, np.nan)
    # Drop the first `skip_lat` latitudes; `...` keeps this valid for 2-D fields.
    data = data[..., skip_lat:, :]
    # Also catch huge sentinel values that were not flagged as masked.
    data[np.abs(data) > invalid_threshold] = np.nan
    return np.squeeze(data)

# Root directory that holds plotting_tools/, water_fluxes/ and
# irrigation_water_withdrawal/. Fill this in before running: with the empty
# default every f'{BASE_DIR}/...' path resolves against the filesystem root.
# get_forcings() reads BASE_DIR when it is called, so it must be set before
# the forcing cells run.
BASE_DIR = ''

# Number of leading latitude rows dropped from every gridded field; equals the
# default skip used by load_mat_data / load_nc_data.
SKIP_LAT = 29

REGION_FILE = f'{BASE_DIR}/plotting_tools/ar6_region.mat'
ar6_region = load_mat_data(REGION_FILE, 'ar6_region', transpose=True, skip_first=SKIP_LAT)

IRR_DIFF_FILE = f'{BASE_DIR}/plotting_tools/irr_diff_out.mat'
irr_diff = load_mat_data(IRR_DIFF_FILE, 'irr_diff_out', transpose=True, skip_first=SKIP_LAT)

# Grid coordinates from the surface dataset. Assumes LONGXY/LATIXY are 2-D
# (lat, lon) grids, so row 0 gives longitudes and column 0 gives latitudes
# (TODO confirm against the file).
nc_file = f'{BASE_DIR}/surfdata_irrigation_method.nc'
surface = Data_from_nc(nc_file)
irrigation_method = surface.load_variable('irrigation_method')
data_lon = surface.load_variable('LONGXY')[0, :]
data_lat = surface.load_variable('LATIXY')[SKIP_LAT:, 0]

In [62]:
# Seconds in a 365-day year. get_forcings multiplies by this to turn
# per-second fluxes into per-year amounts (presumably mm/s -> mm/year).
F_S2Y = 365 * 86400

# -------------------------------------------------------------------
# 1) Per-model file-naming rules
# -------------------------------------------------------------------
# Each entry tells get_forcings how to build a file path:
#   loader   : function(path, var_name) -> array
#   template : str.format pattern filled with base_dir, var, run, suffix and
#              the keys of the selected exp_map entry
#   runs     : ensemble-member labels substituted for {run}
#   suffix   : file-name tail for the 'early' (1901-1930) and 'late'
#              (1985-2014) averaging periods
#   exp_map  : per-experiment ('irr' / 'noi') keys merged into the format dict
# Extra exp_map keys that a template does not reference (NorESM 'exp_dir',
# MIROC 'dir_prefix') are simply ignored by str.format.
MODEL_CONFIG = {
    'CESM2': {
        'loader': load_nc_data,
        'template':
            "{base_dir}/water_fluxes/CESM2/"
            "CESM2_{exp}{run}_1901_2014_{var}{suffix}",
        'runs': ['01', '02', '03'],
        'suffix': {
            'early': "_yearmean_1901_1930_timmean",
            'late':  "_yearmean_1985_2014_timmean"
        },
        'exp_map': {
            'irr': {'exp': "IRR"},
            'noi': {'exp': "NOI"}
        }
    },

    'CESM2_gw': {
        'loader': load_nc_data,
        'template':
            "{base_dir}/water_fluxes/CESM2_gw/"
            "CESM2_gw_{exp}{run}_1901_2014_{var}{suffix}",
        'runs': ['01', '02', '03'],
        'suffix': {
            'early': "_yearmean_1901_1930_timmean",
            'late':  "_yearmean_1985_2014_timmean"
        },
        'exp_map': {
            'irr': {'exp': "IRR"},
            'noi': {'exp': "NOI"}
        }
    },

    'E3SM': {
        'loader': load_nc_data,
        'template':
            "{base_dir}/water_fluxes/E3SMv2/"
            "E3SM_{exp}{run}_1901_2014_{var}{suffix}",
        'runs': ['01', '02'],
        'suffix': {
            'early': "_yearmean_1901_1930_timmean_0.9x1.25",
            'late':  "_yearmean_1985_2014_timmean_0.9x1.25"
        },
        'exp_map': {
            'irr': {'exp': "IRR"},
            'noi': {'exp': "NOI"}
        }
    },

    'NorESM': {
        'loader': load_nc_data,
        'template':
            "{base_dir}/water_fluxes/NorESM/"
            "NorESM_{exp}{run}_1901_2014_{var}{suffix}",
        'runs': ['01', '02', '03'],
        'suffix': {
            'early': "_yearmean_1901_1930_timmean",
            'late':  "_yearmean_1985_2014_timmean"
        },
        'exp_map': {
            'irr': {'exp': "IRR", 'exp_dir': "IRR_"},
            'noi': {'exp': "NOI", 'exp_dir': "NOI_"}
        }
    },

    'IPSL': {
        'loader': load_nc_data,
        'template':
            "{base_dir}/water_fluxes/IPSL-CM6/"
            "{exp}{run}_{var}{suffix}",
        'runs': ['01'],   # only one realization
        'suffix': {
            'early': "_1901_2014_Month.nc_yearmean_1901_1930_timmean_0.9x1.25",
            # 'late' is the 1985-2014 time mean, like every other model.
            'late':  "_1901_2014_Month.nc_yearmean_1985_2014_timmean_0.9x1.25"
        },
        'exp_map': {
            'irr': {'exp': "IRR"},
            'noi': {'exp': "NOI"}
        }
    },

    'CNRM': {
        'loader': load_nc_data,
        'template':
            "{base_dir}/water_fluxes/CNRM-CM6-1/"
            "{var}_{exp}{suffix}",
        'runs': [''],    # no run number
        'suffix': {
            'early': "_1901_1930_timmean_0.9x1.25",
            'late':  "_1985_2014_timmean_0.9x1.25"
        },
        'exp_map': {
            'irr': {'exp': "IRR"},
            'noi': {'exp': "NOI"}
        }
    },

    'MIROC': {
        'loader': load_nc_data,
        'template':
            "{base_dir}/water_fluxes/MIROC-INTEG-ES/"
            "{var}_mon_MIROC_{exp}{run}{suffix}",
        'runs': ['01', '02', '03'],
        'suffix': {
            'early': "_1901-2014.nc_0.9x1.25_yearmean_1901_1930_timmean",
            'late':  "_1901-2014.nc_0.9x1.25_yearmean_1985_2014_timmean"
        },
        'exp_map': {
            'irr': {'exp': "IRR", 'dir_prefix': "tranirr-"},
            'noi': {'exp': "NOI", 'dir_prefix': "1901irr-"}
        }
    }
}

# -------------------------------------------------------------------
# 2) One function to load & compute all three forcings
# -------------------------------------------------------------------
def get_forcings(model_name, variable):
    """
    Ensemble-mean forcing maps of one variable for one model.

    File paths come from MODEL_CONFIG[model_name]['template'], filled with
    BASE_DIR, the variable name, each run label, the period suffix and the
    experiment's exp_map entry. For every run the IRR and NOI experiments are
    loaded for the 'early' and 'late' periods, and

        all = IRR_late - NOI_early
        oth = NOI_late - NOI_early
        irr = IRR_late - NOI_late

    Each difference is averaged over runs and multiplied by F_S2Y.

    Returns
    -------
    tuple
        (all_forcings, oth_forcings, irr_forcings)
    """
    config = MODEL_CONFIG[model_name]
    read = config['loader']
    root = BASE_DIR
    periods = ('early', 'late')

    loaded = {(exp_key, period): [] for exp_key in ('irr', 'noi') for period in periods}
    for exp_key in ('irr', 'noi'):
        exp_fields = config['exp_map'][exp_key]
        for run in config['runs']:
            for period in periods:
                # experiment keys override the base keys; the period suffix is set last
                fields = {
                    'base_dir': root,
                    'var': variable,
                    'run': run,
                    **exp_fields,
                    'suffix': config['suffix'][period],
                }
                path = config['template'].format(**fields)
                loaded[(exp_key, period)].append(read(path, variable))

    # one (n_runs, ...) array per experiment/period
    irr_early, irr_late, noi_early, noi_late = (
        np.stack(loaded[key])
        for key in (('irr', 'early'), ('irr', 'late'), ('noi', 'early'), ('noi', 'late'))
    )

    def run_mean(diff):
        # average over the run axis, then scale by F_S2Y
        return np.mean(diff, axis=0) * F_S2Y

    return (
        run_mean(irr_late - noi_early),
        run_mean(noi_late - noi_early),
        run_mean(irr_late - noi_late),
    )

In [63]:
# Precipitation fields per model. When several fields are listed (rain and
# snow), get_composite_forcings adds up their forcing maps.
PR_COMPONENTS = {
    'CESM2': ['RAIN_FROM_ATM', 'SNOW_FROM_ATM'],
    'CESM2_gw': ['RAIN_FROM_ATM', 'SNOW_FROM_ATM'],
    'NorESM': ['RAIN_FROM_ATM', 'SNOW_FROM_ATM'],
    'E3SM': ['RAIN', 'SNOW'],
    'IPSL': ['pr'],
    'CNRM': ['pr'],
    'MIROC': ['pr'],
}

def get_composite_forcings(model_name, vars):
    """
    Add up get_forcings(model_name, v) over every variable name in `vars`.

    Returns the summed (all, oth, irr) maps, or (None, None, None) when
    `vars` is empty. The arrays returned for the first variable serve as
    accumulators and are updated in place.
    """
    totals = None
    for component in vars:
        all_part, oth_part, irr_part = get_forcings(model_name, component)
        if totals is None:
            totals = [all_part, oth_part, irr_part]
            continue
        for k, part in enumerate((all_part, oth_part, irr_part)):
            totals[k] += part
    if totals is None:
        return None, None, None
    return tuple(totals)

# Composite precipitation forcings for every model: {model: (all, oth, irr)}.
results_pr = {mdl: get_composite_forcings(mdl, comps) for mdl, comps in PR_COMPONENTS.items()}

# Individual maps can be unpacked when needed, e.g.
# all_forcings_cesm2_pr, oth_forcings_cesm2_pr, irr_forcings_cesm2_pr = results_pr['CESM2']


In [64]:
# Evapotranspiration fields per model; listed fields are summed by
# get_composite_forcings (e.g. E3SM soil evaporation + canopy evaporation +
# transpiration terms).
# NOTE(review): the CMIP `evspsbl` variable is defined to include
# transpiration; adding MIROC `tran` on top may double count unless MIROC's
# evspsbl excludes it -- confirm.
ET_COMPONENTS = {
    'CESM2': ['QFLX_EVAP_TOT'],
    'CESM2_gw': ['QFLX_EVAP_TOT'],
    'NorESM': ['QFLX_EVAP_TOT'],
    'E3SM': ['QSOIL', 'QVEGE', 'QVEGT'],
    'IPSL': ['evspsbl'],
    'CNRM': ['evspsbl'],
    'MIROC': ['evspsbl', 'tran'],
}


# Composite evapotranspiration forcings for every model: {model: (all, oth, irr)}.
results_et = {mdl: get_composite_forcings(mdl, comps) for mdl, comps in ET_COMPONENTS.items()}

# Individual maps can be unpacked when needed, e.g.
# all_forcings_cesm2_et, oth_forcings_cesm2_et, irr_forcings_cesm2_et = results_et['CESM2']


In [65]:
# P - ET per model, component by component (all, oth, irr).
results_pr_et = {
    model: tuple(pr_map - et_map for pr_map, et_map in zip(pr_triple, results_et[model]))
    for model, pr_triple in results_pr.items()
}

In [66]:
# Runoff field per model.
R_COMPONENTS = {
    'CESM2': ['QRUNOFF'],
    'CESM2_gw': ['QRUNOFF'],
    'NorESM': ['QRUNOFF'],
    'E3SM': ['QRUNOFF'],
    'IPSL': ['mrro'],
    'CNRM': ['mrro'],
    'MIROC': ['mrro'],
}


# Composite runoff forcings for every model: {model: (all, oth, irr)}.
results_r = {mdl: get_composite_forcings(mdl, comps) for mdl, comps in R_COMPONENTS.items()}

# Individual maps can be unpacked when needed, e.g.
# all_forcings_cesm2_r, oth_forcings_cesm2_r, irr_forcings_cesm2_r = results_r['CESM2']


In [67]:
def compute_consistency(arrays, threshold=20):
    """
    Count, per grid cell, how many models agree with the ensemble-mean change.

    A cell is only counted when the ensemble mean itself reaches the
    threshold:
      * mean >= +threshold -> number of arrays >= +threshold
      * mean <= -threshold -> number of arrays <= -threshold
      * otherwise (including NaN means) -> 0

    Parameters
    ----------
    arrays : sequence of array_like
        Same-shaped maps, one per model.
    threshold : float, default 20
        Magnitude a change must reach to count.

    Returns
    -------
    numpy.ndarray of int
        Agreement count with the shape of one input array.
    """
    stack = np.stack(arrays, axis=0)                # shape: (n_models, ...)
    mean_map = np.mean(stack, axis=0)
    pos_counts = np.sum(stack >= threshold, axis=0)
    neg_counts = np.sum(stack <= -threshold, axis=0)
    # Pick the count matching the sign of the mean, so a cell where most
    # models oppose the mean change is not reported as consistent.
    return np.where(mean_map >= threshold, pos_counts,
                    np.where(mean_map <= -threshold, neg_counts, 0))


def get_consistency(results, threshold=20):
    """
    Model-agreement maps for the three forcing decompositions.

    Parameters
    ----------
    results : dict
        {model_name: (all_arr, oth_arr, irr_arr)}. Must contain 'CESM2',
        'CESM2_gw', 'NorESM', 'E3SM', 'MIROC', 'CNRM' and 'IPSL'; any other
        keys are ignored.
    threshold : float, default 20
        Passed on to compute_consistency.

    Returns
    -------
    tuple
        (all_consis, oth_consis, irr_consis): compute_consistency applied to
        each component across the seven models.
    """
    ensemble = ('CESM2', 'CESM2_gw', 'NorESM', 'E3SM', 'MIROC', 'CNRM', 'IPSL')
    triples = [results[name] for name in ensemble]
    all_consis, oth_consis, irr_consis = (
        compute_consistency([triple[idx] for triple in triples], threshold)
        for idx in range(3)
    )
    return all_consis, oth_consis, irr_consis


# Model-agreement counts for each water-budget term (threshold 20 mm/year).
all_consis_pr, oth_consis_pr, irr_consis_pr = get_consistency(results_pr, threshold=20)
all_consis_et, oth_consis_et, irr_consis_et = get_consistency(results_et, threshold=20)
all_consis_pr_et, oth_consis_pr_et, irr_consis_pr_et = get_consistency(results_pr_et, threshold=20)
all_consis_r, oth_consis_r, irr_consis_r = get_consistency(results_r, threshold=20)

In [68]:
# Multi-model ensemble mean of each precipitation forcing map.
# Per-component lists and (n_models, ny, nx) stacks stay available as
# module names (the following cells overwrite them).
all_list, oth_list, irr_list = ([forcings[k] for forcings in results_pr.values()] for k in range(3))
all_stack, oth_stack, irr_stack = (np.stack(parts, axis=0) for parts in (all_list, oth_list, irr_list))
mean_all_pr, mean_oth_pr, mean_irr_pr = (np.mean(stacked, axis=0) for stacked in (all_stack, oth_stack, irr_stack))

In [69]:
# Multi-model ensemble mean of each evapotranspiration forcing map.
# Per-component lists and (n_models, ny, nx) stacks stay available as
# module names (the following cells overwrite them).
all_list, oth_list, irr_list = ([forcings[k] for forcings in results_et.values()] for k in range(3))
all_stack, oth_stack, irr_stack = (np.stack(parts, axis=0) for parts in (all_list, oth_list, irr_list))
mean_all_et, mean_oth_et, mean_irr_et = (np.mean(stacked, axis=0) for stacked in (all_stack, oth_stack, irr_stack))

In [70]:
# Multi-model ensemble mean of each P-ET forcing map.
# Per-component lists and (n_models, ny, nx) stacks stay available as
# module names (the following cells overwrite them).
all_list, oth_list, irr_list = ([forcings[k] for forcings in results_pr_et.values()] for k in range(3))
all_stack, oth_stack, irr_stack = (np.stack(parts, axis=0) for parts in (all_list, oth_list, irr_list))
mean_all_pr_et, mean_oth_pr_et, mean_irr_pr_et = (np.mean(stacked, axis=0) for stacked in (all_stack, oth_stack, irr_stack))

In [71]:
# Multi-model ensemble mean of each runoff forcing map.
# Per-component lists and (n_models, ny, nx) stacks stay available as
# module names (the following cells overwrite them).
all_list, oth_list, irr_list = ([forcings[k] for forcings in results_r.values()] for k in range(3))
all_stack, oth_stack, irr_stack = (np.stack(parts, axis=0) for parts in (all_list, oth_list, irr_list))
mean_all_r, mean_oth_r, mean_irr_r = (np.mean(stacked, axis=0) for stacked in (all_stack, oth_stack, irr_stack))

In [None]:
def plot_fre_single_custom_div(ax, data_xarray,data_signal, title, metric, cmap, levels, unit):
    """
    Draw one map panel: discrete diverging shading, a horizontal colorbar,
    selected IPCC AR6 region outlines and hatching where models agree.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Target axes (callers create it with a PlateCarree projection).
    data_xarray : xarray.DataArray
        Field to shade, with 'y'/'x' coordinates.
    data_signal : array_like
        Agreement count on the data_lat x data_lon grid; cells between 4.5
        and 7.5 are hatched.
    title : str
        Right-aligned panel title.
    metric : str
        Left-aligned panel label.
    cmap : str
        Matplotlib colormap name; 9 discrete colors are sampled from it.
    levels : list of float
        Contour boundaries for the shading.
    unit : str
        Colorbar label.

    Uses the module-level ``data_lon`` / ``data_lat`` for the hatch overlay.
    """
    # Region outlines are read from the shapefile on the first call only and
    # cached on the function object.
    borders = getattr(plot_fre_single_custom_div, '_ipcc_borders', None)
    if borders is None:
        region_ids = [32, 19, 37, 4]      # whichever polygons you need
        with shapefile.Reader('IPCC-WGI-reference-regions-v4.shp') as sf:
            borders = [np.array(sf.shape(i).points).T for i in region_ids]
        plot_fre_single_custom_div._ipcc_borders = borders

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    base_cmap = plt.get_cmap(cmap)
    colors = [base_cmap(0.25), base_cmap(0.3), base_cmap(0.35), base_cmap(0.4), 'white',
              base_cmap(0.6), base_cmap(0.65), base_cmap(0.7), base_cmap(0.75)]
    listed_cmap = mpl.colors.ListedColormap(colors)
    # NOTE(review): vmin/vmax (0, 0.05) look unrelated to `levels`; kept
    # unchanged -- confirm whether they are needed.
    im = data_xarray.plot(ax=ax, cmap=listed_cmap, vmin=0, vmax=0.05, levels=levels,
                          extend='both', add_colorbar=False, add_labels=False)

    cb = plt.colorbar(im, ax=ax, fraction=0.3, pad=0.04, extend='both', orientation='horizontal')
    cb.ax.tick_params(labelsize=20)
    cb.set_label(label=unit, fontsize=22)

    ax.coastlines(color='black', linewidth=0.5)
    ax.add_feature(cfeature.OCEAN, color='lightgrey')
    ax.set_title(title, loc='right', fontsize=26)
    ax.set_title(metric, loc='left', fontsize=20)

    for bx, by in borders:
        ax.plot(bx, by, 'k-', lw=1, alpha=0.5, transform=ccrs.PlateCarree())
    # Hatch cells whose agreement count lies in (4.5, 7.5).
    ax.contourf(data_lon, data_lat, data_signal, levels=[4.5, 7.5], hatches=['//'],
                colors='none', transform=ccrs.PlateCarree())


    
# Figure: ensemble-mean change in P-ET for the three forcing decompositions,
# hatched where enough models agree (see compute_consistency).
proj = ccrs.PlateCarree()
fig = plt.figure(figsize=(12, 24), dpi=300)
fig.subplots_adjust(hspace=0.1, wspace=0.1, left=0.05, right=0.95, top=0.95, bottom=0.05)

# Raw string: '\D' is an invalid escape sequence (SyntaxWarning on Python 3.12+).
unit = r'$\Delta$(P-ET) (mm/year)'
cmap_div = 'RdBu'
levels_div = [-50.0, -40.0, -30.0, -20.0, 20.0, 30.0, 40.0, 50.0]

# (panel label, ensemble-mean map, agreement count, title), top to bottom
pr_et_panels = [
    ('a', mean_all_pr_et, all_consis_pr_et, 'all forcings'),
    ('c', mean_oth_pr_et, oth_consis_pr_et, 'other forcings'),
    ('e', mean_irr_pr_et, irr_consis_pr_et, 'irrigation expansion'),
]
for row, (variable_for_show, mean_map, consis_map, title) in enumerate(pr_et_panels, start=1):
    data_xr_1 = xr.DataArray(mean_map, coords={'y': data_lat, 'x': data_lon}, dims=["y", "x"])
    ax = fig.add_subplot(3, 1, row, projection=proj, frameon=True)
    plot_fre_single_custom_div(ax, data_xr_1, consis_map, title, variable_for_show, cmap_div, levels_div, unit)
    ax.add_feature(cfeature.OCEAN, color='whitesmoke')

In [73]:
def get_composite_forcings_tws(model_name, vars):
    """
    Add up get_forcings(model_name, v) over `vars` for storage variables.

    TWS is plotted in mm (a storage, not a flux), so each term is divided by
    365 * 86400 to undo the seconds-per-year factor that get_forcings
    applies. Returns (None, None, None) when `vars` is empty.
    """
    summed = None
    for name in vars:
        all_part, oth_part, irr_part = get_forcings(model_name, name)
        unscaled = [part / 365 / 86400 for part in (all_part, oth_part, irr_part)]
        if summed is None:
            summed = unscaled
        else:
            for k in range(3):
                summed[k] += unscaled[k]
    if summed is None:
        return None, None, None
    return tuple(summed)

# Terrestrial water storage field per model (only models that provide it).
TWS_COMPONENTS = {
    'CESM2_gw': ['TWS'],
    'E3SM': ['TWS'],
    'IPSL': ['mrtws'],
    'CNRM': ['mrtws'],
    'MIROC': ['tws'],
}


# Composite TWS changes for every model that provides TWS: {model: (all, oth, irr)}.
results_tws = {mdl: get_composite_forcings_tws(mdl, comps) for mdl, comps in TWS_COMPONENTS.items()}

# Individual maps can be unpacked when needed, e.g.
# all_forcings_cesm2_gw_tws, oth_forcings_cesm2_gw_tws, irr_forcings_cesm2_gw_tws = results_tws['CESM2_gw']

In [74]:
def compute_consistency_tws(arrays, threshold=50):
    """
    Count, per grid cell, how many models agree with the ensemble-mean TWS change.

    Same rule as compute_consistency, with a larger default threshold:
      * mean >= +threshold -> number of arrays >= +threshold
      * mean <= -threshold -> number of arrays <= -threshold
      * otherwise (including NaN means) -> 0

    Parameters
    ----------
    arrays : sequence of array_like
        Same-shaped maps, one per model.
    threshold : float, default 50
        Magnitude a change must reach to count.

    Returns
    -------
    numpy.ndarray of int
        Agreement count with the shape of one input array.
    """
    stack = np.stack(arrays, axis=0)                # shape: (n_models, ...)
    mean_map = np.mean(stack, axis=0)
    pos_counts = np.sum(stack >= threshold, axis=0)
    neg_counts = np.sum(stack <= -threshold, axis=0)
    # Pick the count matching the sign of the mean, so a cell where most
    # models oppose the mean change is not reported as consistent.
    return np.where(mean_map >= threshold, pos_counts,
                    np.where(mean_map <= -threshold, neg_counts, 0))


def get_consistency_tws(results, threshold=50):
    """
    Model-agreement maps of the TWS change for the three forcing decompositions.

    Parameters
    ----------
    results : dict
        {model_name: (all_arr, oth_arr, irr_arr)}. Must contain 'CESM2_gw',
        'E3SM', 'MIROC', 'CNRM' and 'IPSL'; any other keys are ignored.
    threshold : float, default 50
        Passed on to compute_consistency_tws.

    Returns
    -------
    tuple
        (all_consis, oth_consis, irr_consis): compute_consistency_tws applied
        to each component across the five models.
    """
    ensemble = ('CESM2_gw', 'E3SM', 'MIROC', 'CNRM', 'IPSL')
    triples = [results[name] for name in ensemble]
    all_consis, oth_consis, irr_consis = (
        compute_consistency_tws([triple[idx] for triple in triples], threshold)
        for idx in range(3)
    )
    return all_consis, oth_consis, irr_consis


# Model-agreement counts for the TWS change (threshold 50 mm).
all_consis_tws, oth_consis_tws, irr_consis_tws = get_consistency_tws(results_tws, threshold=50)

In [75]:
# Multi-model ensemble mean of each TWS change map.
# Per-component lists and (n_models, ny, nx) stacks stay available as module names.
all_list, oth_list, irr_list = ([forcings[k] for forcings in results_tws.values()] for k in range(3))
all_stack, oth_stack, irr_stack = (np.stack(parts, axis=0) for parts in (all_list, oth_list, irr_list))
mean_all_tws, mean_oth_tws, mean_irr_tws = (np.mean(stacked, axis=0) for stacked in (all_stack, oth_stack, irr_stack))

In [None]:
def plot_fre_single_custom_div_tws(ax, data_xarray,data_signal, title, metric, cmap, levels, unit):
    """
    Draw one TWS map panel: discrete diverging shading, a horizontal colorbar,
    selected IPCC AR6 region outlines and hatching where models agree.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Target axes (callers create it with a PlateCarree projection).
    data_xarray : xarray.DataArray
        Field to shade, with 'y'/'x' coordinates.
    data_signal : array_like
        Agreement count on the data_lat x data_lon grid; cells between 3.5
        and 5.5 are hatched.
    title : str
        Right-aligned panel title.
    metric : str
        Left-aligned panel label.
    cmap : str
        Matplotlib colormap name; 9 discrete colors are sampled from it.
    levels : list of float
        Contour boundaries for the shading.
    unit : str
        Colorbar label.

    Uses the module-level ``data_lon`` / ``data_lat`` for the hatch overlay.
    """
    # Region outlines are read from the shapefile on the first call only and
    # cached on the function object.
    borders = getattr(plot_fre_single_custom_div_tws, '_ipcc_borders', None)
    if borders is None:
        region_ids = [32, 19, 37, 4]      # whichever polygons you need
        with shapefile.Reader('IPCC-WGI-reference-regions-v4.shp') as sf:
            borders = [np.array(sf.shape(i).points).T for i in region_ids]
        plot_fre_single_custom_div_tws._ipcc_borders = borders

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    base_cmap = plt.get_cmap(cmap)
    colors = [base_cmap(0.25), base_cmap(0.3), base_cmap(0.35), base_cmap(0.4), 'white',
              base_cmap(0.6), base_cmap(0.65), base_cmap(0.7), base_cmap(0.75)]
    listed_cmap = mpl.colors.ListedColormap(colors)
    # NOTE(review): vmin/vmax (0, 0.05) look unrelated to `levels`; kept
    # unchanged -- confirm whether they are needed.
    im = data_xarray.plot(ax=ax, cmap=listed_cmap, vmin=0, vmax=0.05, levels=levels,
                          extend='both', add_colorbar=False, add_labels=False)

    cb = plt.colorbar(im, ax=ax, fraction=0.3, pad=0.04, extend='both', orientation='horizontal')
    cb.ax.tick_params(labelsize=20)
    cb.set_label(label=unit, fontsize=22)

    ax.coastlines(color='black', linewidth=0.5)
    ax.add_feature(cfeature.OCEAN, color='lightgrey')
    ax.set_title(title, loc='right', fontsize=26)
    ax.set_title(metric, loc='left', fontsize=20)

    for bx, by in borders:
        ax.plot(bx, by, 'k-', lw=1, alpha=0.5, transform=ccrs.PlateCarree())
    # Hatch cells whose agreement count lies in (3.5, 5.5).
    ax.contourf(data_lon, data_lat, data_signal, levels=[3.5, 5.5], hatches=['///'],
                colors='none', transform=ccrs.PlateCarree())




# Figure: ensemble-mean TWS change for the three forcing decompositions,
# hatched where enough models agree (see compute_consistency_tws).
fig = plt.figure(figsize=(12, 24), dpi=300)
fig.subplots_adjust(hspace=0.1, wspace=0.1, left=0.05, right=0.95, top=0.95, bottom=0.05)

proj = ccrs.PlateCarree()
# Raw string: '\D' is an invalid escape sequence (SyntaxWarning on Python 3.12+).
unit = r'$\Delta$TWS (mm)'
cmap_div = 'RdBu'
levels_div = [-300.0, -200.0, -100.0, -50, 50, 100.0, 200.0, 300.0]

# (panel label, ensemble-mean map, agreement count, title), top to bottom
tws_panels = [
    ('b', mean_all_tws, all_consis_tws, 'all forcings'),
    ('d', mean_oth_tws, oth_consis_tws, 'other forcings'),
    ('f', mean_irr_tws, irr_consis_tws, 'irrigation expansion'),
]
for row, (variable_for_show, mean_map, consis_map, title) in enumerate(tws_panels, start=1):
    data_xr_1 = xr.DataArray(mean_map, coords={'y': data_lat, 'x': data_lon}, dims=["y", "x"])
    ax = fig.add_subplot(3, 1, row, projection=proj, frameon=True)
    plot_fre_single_custom_div_tws(ax, data_xr_1, consis_map, title, variable_for_show, cmap_div, levels_div, unit)
    ax.add_feature(cfeature.OCEAN, color='whitesmoke')

In [None]:
# Figure: 3x3 grid of ensemble-mean changes. Columns are P, ET and runoff;
# rows are all / other / irrigation forcings. Panels are hatched where enough
# models agree (see compute_consistency).
fig = plt.figure(figsize=(24, 16), dpi=300)
fig.subplots_adjust(hspace=0.1, wspace=0.1, left=0.05, right=0.95, top=0.95, bottom=0.05)

proj = ccrs.PlateCarree()
cmap_div = 'RdBu'
levels_div = [-50.0, -40.0, -30.0, -20.0, 20.0, 30.0, 40.0, 50.0]
forcing_titles = ['all forcings', 'other forcings', 'irrigation expansion']

# Raw strings: '\D' is an invalid escape sequence (SyntaxWarning on Python 3.12+).
# Each column: (colorbar unit, [(panel label, mean map, agreement count) per row])
flux_columns = [
    (r'$\Delta$P mm/year', [
        ('a', mean_all_pr, all_consis_pr),
        ('d', mean_oth_pr, oth_consis_pr),
        ('g', mean_irr_pr, irr_consis_pr),
    ]),
    (r'$\Delta$ET (mm/year)', [
        ('b', mean_all_et, all_consis_et),
        ('e', mean_oth_et, oth_consis_et),
        ('h', mean_irr_et, irr_consis_et),
    ]),
    (r'$\Delta$R (mm/year)', [
        ('c', mean_all_r, all_consis_r),
        ('f', mean_oth_r, oth_consis_r),
        ('i', mean_irr_r, irr_consis_r),
    ]),
]

for col, (unit, column_panels) in enumerate(flux_columns):
    for row, (variable_for_show, mean_map, consis_map) in enumerate(column_panels):
        data_xr_1 = xr.DataArray(mean_map, coords={'y': data_lat, 'x': data_lon}, dims=["y", "x"])
        # subplot index runs left-to-right, top-to-bottom (1..9)
        ax = fig.add_subplot(3, 3, 3 * row + col + 1, projection=proj, frameon=True)
        plot_fre_single_custom_div(ax, data_xr_1, consis_map, forcing_titles[row],
                                   variable_for_show, cmap_div, levels_div, unit)
        ax.add_feature(cfeature.OCEAN, color='whitesmoke')

In [81]:
# Root of the IRRMIP output tree. Fall back to the cluster path only when the
# configuration cell left BASE_DIR empty, so a value configured there is not
# silently replaced. get_forcings() reads BASE_DIR at call time, so an
# unconditional reassignment here would also change what earlier cells load
# on a re-run.
BASE_DIR = BASE_DIR or '/dodrio/scratch/projects/2022_200/project_output/cesm/yi_yao_IRRMIP'
IWW_DIR = f'{BASE_DIR}/irrigation_water_withdrawal'

# Irrigation water withdrawal (IWW), time means over 1901-1930 and 1985-2014.
IWW_IRR_1901_1930_CESM2 = load_nc_data(f'{IWW_DIR}/CESM2/CESM2_IRR_YEARLYQIRRIG_timmean_1901_1930.nc', 'QIRRIG')
IWW_IRR_1985_2014_CESM2 = load_nc_data(f'{IWW_DIR}/CESM2/CESM2_IRR_YEARLYQIRRIG_timmean_1985_2014.nc', 'QIRRIG')

IWW_IRR_1901_1930_CESM2_gw = load_nc_data(f'{IWW_DIR}/CESM2_gw/CESM2_gw_IRR_YEARLYQIRRIG_timmean_1901_1930.nc', 'QIRRIG')
IWW_IRR_1985_2014_CESM2_gw = load_nc_data(f'{IWW_DIR}/CESM2_gw/CESM2_gw_IRR_YEARLYQIRRIG_timmean_1985_2014.nc', 'QIRRIG')

IWW_IRR_1901_1930_NorESM = load_nc_data(f'{IWW_DIR}/NorESM2/NorESM_IRR_YEARLYQIRRIG_timmean_1901_1930.nc', 'QIRRIG')
IWW_IRR_1985_2014_NorESM = load_nc_data(f'{IWW_DIR}/NorESM2/NorESM_IRR_YEARLYQIRRIG_timmean_1985_2014.nc', 'QIRRIG')

IWW_IRR_1901_1930_IPSL = load_nc_data(f'{IWW_DIR}/IPSL-CM6/IRR01_1901_2014_irr_Day_Month.nc_1901_1930_timmean_YEARLY_0.9x1.25', 'irr')
IWW_IRR_1985_2014_IPSL = load_nc_data(f'{IWW_DIR}/IPSL-CM6/IRR01_1901_2014_irr_Day_Month.nc_1985_2014_timmean_YEARLY_0.9x1.25', 'irr')

IWW_IRR_1901_1930_E3SM = load_nc_data(f'{IWW_DIR}/E3SMv2/IRR_QIRRIG_REAL_timmean_1901_1930.nc_0.9x1.25_YEARLY', 'QIRRIG_REAL')
IWW_IRR_1985_2014_E3SM = load_nc_data(f'{IWW_DIR}/E3SMv2/IRR_QIRRIG_REAL_timmean_1985_2014.nc_0.9x1.25_YEARLY', 'QIRRIG_REAL')

IWW_IRR_1901_1930_CNRM = load_nc_data(f'{IWW_DIR}/CNRM-CM6-1/airrww_mon_CNRM-CM6-1_hist-irr_r1i1p1f2_gr_190101-201412.nc_yearmean_YEARLY_1901_1930_timmean_0.9x1.25', 'airrww')
IWW_IRR_1985_2014_CNRM = load_nc_data(f'{IWW_DIR}/CNRM-CM6-1/airrww_mon_CNRM-CM6-1_hist-irr_r1i1p1f2_gr_190101-201412.nc_yearmean_YEARLY_1985_2014_timmean_0.9x1.25', 'airrww')

# MIROC-INTEG-ES was added to the analysis later and its ensemble members were
# never merged, so each member of the IRR and NOI experiments is loaded separately.
MIROC_IWW_DIR = f'{IWW_DIR}/MIROC-INTEG-ES'


def _load_miroc_irrac(member_dir, member_tag, period):
    """Load MIROC `irrac` for one member (e.g. 'tranirr-01', 'IRR01') and period ('1901_1930' or '1985_2014')."""
    return load_nc_data(
        f'{MIROC_IWW_DIR}/{member_dir}/after_remap/monthly/'
        f'irrac_mon_MIROC_{member_tag}_1901-2014.nc_0.9x1.25_yearmean_{period}_timmean',
        'irrac')


IWW_IRR01_1901_1930_MIROC = _load_miroc_irrac('tranirr-01', 'IRR01', '1901_1930')
IWW_IRR01_1985_2014_MIROC = _load_miroc_irrac('tranirr-01', 'IRR01', '1985_2014')

IWW_NOI01_1901_1930_MIROC = _load_miroc_irrac('1901irr-01', 'NOI01', '1901_1930')
IWW_NOI01_1985_2014_MIROC = _load_miroc_irrac('1901irr-01', 'NOI01', '1985_2014')

IWW_IRR02_1901_1930_MIROC = _load_miroc_irrac('tranirr-02', 'IRR02', '1901_1930')
IWW_IRR02_1985_2014_MIROC = _load_miroc_irrac('tranirr-02', 'IRR02', '1985_2014')

IWW_NOI02_1901_1930_MIROC = _load_miroc_irrac('1901irr-02', 'NOI02', '1901_1930')
IWW_NOI02_1985_2014_MIROC = _load_miroc_irrac('1901irr-02', 'NOI02', '1985_2014')

IWW_IRR03_1901_1930_MIROC = _load_miroc_irrac('tranirr-03', 'IRR03', '1901_1930')
IWW_IRR03_1985_2014_MIROC = _load_miroc_irrac('tranirr-03', 'IRR03', '1985_2014')

IWW_NOI03_1901_1930_MIROC = load_nc_data(f'{BASE_DIR}/irrigation_water_withdrawal/MIROC-INTEG-ES/1901irr-03/after_remap/monthly/irrac_mon_MIROC_NOI03_1901-2014.nc_0.9x1.25_yearmean_1901_1930_timmean', 'irrac')
IWW_NOI03_1985_2014_MIROC = load_nc_data(f'{BASE_DIR}/irrigation_water_withdrawal/MIROC-INTEG-ES/1901irr-03/after_remap/monthly/irrac_mon_MIROC_NOI03_1901-2014.nc_0.9x1.25_yearmean_1985_2014_timmean', 'irrac')

IWW_IRR_1901_1930_MIROC = (IWW_IRR01_1901_1930_MIROC + IWW_IRR02_1901_1930_MIROC + IWW_IRR03_1901_1930_MIROC) / 3 * 365 * 86400
IWW_IRR_1985_2014_MIROC = (IWW_IRR01_1985_2014_MIROC + IWW_IRR02_1985_2014_MIROC + IWW_IRR03_1985_2014_MIROC) / 3 * 365 * 86400

IWW_IRR_1901_1930 = (IWW_IRR_1901_1930_CESM2+IWW_IRR_1901_1930_CESM2_gw+IWW_IRR_1901_1930_NorESM+IWW_IRR_1901_1930_MIROC+IWW_IRR_1901_1930_E3SM+IWW_IRR_1901_1930_CNRM+IWW_IRR_1901_1930_IPSL)/7
IWW_IRR_1985_2014 = (IWW_IRR_1985_2014_CESM2+IWW_IRR_1985_2014_CESM2_gw+IWW_IRR_1985_2014_NorESM+IWW_IRR_1985_2014_MIROC+IWW_IRR_1985_2014_E3SM+IWW_IRR_1985_2014_CNRM+IWW_IRR_1985_2014_IPSL)/7

globe_irr_land_1901 = load_mat_data(f'{BASE_DIR}/area_quipped_for_irrigation/irrigated_GRID_1901.mat', 'irrigated_GRID_1901')
globe_irr_land_1941 = load_mat_data(f'{BASE_DIR}/area_quipped_for_irrigation/irrigated_GRID_1941.mat', 'irrigated_GRID_1941')
globe_irr_land_1981 = load_mat_data(f'{BASE_DIR}/area_quipped_for_irrigation/irrigated_GRID_1981.mat', 'irrigated_GRID_1981')
globe_irr_land_2014 = load_mat_data(f'{BASE_DIR}/area_quipped_for_irrigation/irrigated_GRID_2014.mat', 'irrigated_GRID_2014')

globe_irr_land_1901[np.isnan(IWW_IRR_1901_1930_CESM2.T)] = np.nan # mask ocean grid cells
globe_irr_land_1941[np.isnan(IWW_IRR_1901_1930_CESM2.T)] = np.nan
globe_irr_land_1981[np.isnan(IWW_IRR_1901_1930_CESM2.T)] = np.nan
globe_irr_land_2014[np.isnan(IWW_IRR_1901_1930_CESM2.T)] = np.nan

In [87]:
# Shapefile of the IPCC WGI reference regions, and the region IDs whose outlines
# are drawn on every map panel.
_REGIONS_SHP = 'IPCC-WGI-reference-regions-v4.shp'
_HIGHLIGHT_REGION_IDS = (32, 19, 37, 4)

# Colormap positions sampled for bins 2..10. Bin 1 (below levels[0]) is given separately.
_BIN_COLOR_POSITIONS = (0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)


def _get_cmap(name):
    """Return the named Matplotlib colormap.

    ``matplotlib.cm.get_cmap`` is deprecated since Matplotlib 3.7 and removed in
    3.9. The ``matplotlib.colormaps`` registry (Matplotlib >= 3.5) is preferred
    when it exists.
    """
    registry = getattr(mpl, 'colormaps', None)
    return registry[name] if registry is not None else mpl.cm.get_cmap(name)


def _region_borders(shp_path=_REGIONS_SHP, region_ids=_HIGHLIGHT_REGION_IDS):
    """Read the outline points of the selected shapefile records.

    Returns one ``2 x N`` array per region ID, in the given order. Row 0 holds x
    and row 1 holds y, so each array unpacks as ``bx, by``. The shapefile is
    closed after reading.
    """
    with shapefile.Reader(shp_path) as sf:
        return [np.array(sf.shape(i).points).T for i in region_ids]


def _plot_discrete_map(ax, data_xarray, title, metric, *, cmap_name, levels,
                       lowest_color=None, ticklabels=None,
                       coast_color='white', border_style='g-'):
    """Draw one discrete-level map panel with colorbar, coastlines and region outlines.

    Parameters
    ----------
    ax : axes to draw on (a cartopy GeoAxes, since coastlines/features are added).
    data_xarray : 2-D DataArray drawn with ``DataArray.plot``.
    title, metric : right- and left-aligned panel titles.
    cmap_name : Matplotlib colormap sampled for the bin colors.
    levels : bin edges (9 values). With ``extend='both'`` this gives 10 bins,
        matching the 10 colors built here.
    lowest_color : color of the bin below ``levels[0]``. Defaults to the
        colormap's own lowest color.
    ticklabels : optional colorbar labels, one per entry of ``levels``.
    coast_color : coastline color.
    border_style : Matplotlib format string for the region outlines.
    """
    borders = _region_borders()

    base = _get_cmap(cmap_name)
    first = base(0) if lowest_color is None else lowest_color
    cmap = mpl.colors.ListedColormap([first] + [base(p) for p in _BIN_COLOR_POSITIONS])

    # A list of `levels` fixes the color range, so vmin/vmax are not passed
    # (xarray would replace them with levels[0] / levels[-1] anyway).
    im = data_xarray.plot(ax=ax, cmap=cmap, levels=levels, extend='both',
                          add_colorbar=False, add_labels=False)

    cb = ax.figure.colorbar(im, ax=ax, fraction=0.02, pad=0.04, extend='both')
    cb.ax.tick_params(labelsize=16)
    if ticklabels is not None:
        # Pin the ticks to the level edges so each label sits on its level.
        cb.set_ticks(levels)
        cb.set_ticklabels(ticklabels)

    ax.coastlines(color=coast_color, linewidth=0.5)
    ax.add_feature(cfeature.OCEAN, color='lightgrey')
    ax.set_title(title, loc='right', fontsize=18)
    ax.set_title(metric, loc='left', fontsize=18)

    for bx, by in borders:
        ax.plot(bx, by, border_style, lw=3, alpha=0.5, transform=ccrs.PlateCarree())


def plot_fre_single(ax, data_xarray, title, metric):
    """Blue discrete map with levels 1-250 (used for the withdrawal panels)."""
    _plot_discrete_map(
        ax, data_xarray, title, metric,
        cmap_name='Blues',
        levels=[1, 10, 25, 50, 75, 100, 150, 200, 250],
        coast_color='white', border_style='g-',
    )


def plot_fre_single_2(ax, data_xarray, title, metric):
    """Red discrete map for fractions 0.01-0.40, labelled 1%-40%."""
    _plot_discrete_map(
        ax, data_xarray, title, metric,
        cmap_name='Reds',
        levels=[0.01, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40],
        lowest_color='white',
        ticklabels=['1%', '5%', '10%', '15%', '20%', '25%', '30%', '35%', '40%'],
        coast_color='white', border_style='g-',
    )


def plot_land_single(ax, data_xarray, title, metric):
    """Green discrete map for fractions 0.01-0.8, labelled 1%-80%."""
    _plot_discrete_map(
        ax, data_xarray, title, metric,
        cmap_name='Greens',
        levels=[0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
        lowest_color='white',
        ticklabels=['1%', '10%', '20%', '30%', '40%', '50%', '60%', '70%', '80%'],
        coast_color='dimgray', border_style='r-',
    )


In [None]:
# Overview figure (4 x 2 grid of maps):
#   a, b  irrigated fraction in 1901 and 2014
#   c     irrigated fraction change, 2014-1901
#   d     GRACE terrestrial water storage (TWS) trend
#   e, f  multi-model mean irrigation water withdrawal, 1901-1930 and 1985-2014
# NOTE(review): the setup cell loads plotting_tools/irr_diff_out.mat, whereas this
# cell reloads irr_diff_out.mat from BASE_DIR — confirm they hold the same data.
irr_diff = load_mat_data(f'{BASE_DIR}/irr_diff_out.mat', 'irr_diff_out').T

proj = ccrs.PlateCarree()
fig = plt.figure(figsize=(20, 18), dpi=300)
fig.subplots_adjust(hspace=0.2, wspace=0.05, left=0.05, right=0.95, top=0.95, bottom=0.05)


def _model_grid_map(values):
    """Wrap a (lat, lon) array on the model grid as a y/x DataArray for plotting."""
    return xr.DataArray(values, coords={'y': data_lat, 'x': data_lon}, dims=["y", "x"])


def _finish_panel(ax):
    """Apply the ocean fill and coastline styling shared by every panel."""
    ax.add_feature(cfeature.OCEAN, color='whitesmoke')
    ax.coastlines(color='dimgray', linewidth=0.5)


# (a) irrigated fraction 1901. The .mat grids are transposed relative to (lat, lon).
data_xr_1 = _model_grid_map(globe_irr_land_1901.T)
ax = plt.subplot(421, projection=proj, frameon=True)
plot_land_single(ax, data_xr_1, 'irrigated fraction 1901', 'a')
_finish_panel(ax)

# (b) irrigated fraction 2014
data_xr_1 = _model_grid_map(globe_irr_land_2014.T)
ax = plt.subplot(422, projection=proj, frameon=True)
plot_land_single(ax, data_xr_1, 'irrigated fraction 2014', 'b')
_finish_panel(ax)

# (c) irrigated fraction change.
# Cells where the multi-model withdrawal field is NaN are masked explicitly;
# adding and subtracting that field would also add rounding error.
# The /100 presumably converts percent to a fraction.
irr_diff = np.where(np.isnan(IWW_IRR_1901_1930), np.nan, irr_diff) / 100
data_xr_1 = _model_grid_map(irr_diff)
ax = plt.subplot(423, projection=proj, frameon=True)
plot_fre_single_2(ax, data_xr_1, 'irrigated fraction 2014-1901', 'c')
_finish_panel(ax)

# (e) withdrawal 1901-1930. Raw strings keep the mathtext backslash without an
# invalid escape-sequence warning.
data_xr_1 = _model_grid_map(IWW_IRR_1901_1930)
ax = plt.subplot(425, projection=proj, frameon=True)
plot_fre_single(ax, data_xr_1, r'Irrigation water withdrawal ($\mathregular{mm/yr}$) tranirr 1901-1930', 'e')
_finish_panel(ax)

# (f) withdrawal 1985-2014
data_xr_1 = _model_grid_map(IWW_IRR_1985_2014)
ax = plt.subplot(426, projection=proj, frameon=True)
plot_fre_single(ax, data_xr_1, r'Irrigation water withdrawal ($\mathregular{mm/yr}$) tranirr 1985-2014', 'f')
_finish_panel(ax)

# (d) GRACE TWS trend, on the GRACE grid (its own lat/lon vectors).
TWS = scio.loadmat(f'{BASE_DIR}/grace_datasets/tes_trend.mat')['tws_trend']
lat = scio.loadmat(f'{BASE_DIR}/grace_datasets/lat.mat')['lat'][:, 0]
lon = scio.loadmat(f'{BASE_DIR}/grace_datasets/lon.mat')['lon'][:, 0]
trend = scio.loadmat(f'{BASE_DIR}/grace_datasets/trend.mat')['trend']
Rsquare = scio.loadmat(f'{BASE_DIR}/grace_datasets/Rsquare.mat')['Rsquare']
pvalue = scio.loadmat(f'{BASE_DIR}/grace_datasets/pvalue.mat')['pvalue']

with nc.Dataset(f'{BASE_DIR}/grace_datasets/land_mask.nc') as land_ds:
    LANDMASK = np.array(land_ds.variables['LANDMASK'])
# Rotate the mask by 720 columns along its last (longitude) axis, presumably to
# match the GRACE longitude ordering. Without `axis`, np.roll flattens the array,
# so the wrapped half of every latitude row would spill into the next row.
LANDMASK = np.roll(LANDMASK, 720, axis=-1)

trend[pvalue > 0.1] = np.nan      # drop trends with p > 0.1
trend[LANDMASK.T == 0] = np.nan   # drop non-land cells
# Optional stricter filter (disabled): trend[Rsquare < 0.6] = np.nan
# Rows from index 108 onward only. The 365 * 10 factor rescales the stored trend
# to the mm/yr shown in the title (presumably from cm/day; TODO confirm units).
data_xr_1 = xr.DataArray(trend.T[108:, :] * 365 * 10, coords={'y': lat[108:], 'x': lon}, dims=["y", "x"])

ax = plt.subplot(424, projection=proj, frameon=True)
rdbu = mpl.colormaps['RdBu'] if hasattr(mpl, 'colormaps') else mpl.cm.get_cmap('RdBu')
colors = [rdbu(0.2), rdbu(0.25), rdbu(0.3), rdbu(0.35), rdbu(0.4), 'white',
          rdbu(0.6), rdbu(0.65), rdbu(0.7), rdbu(0.75), rdbu(0.8)]
levels = [-20, -10, -5, -2, -1, 1, 2, 5, 10, 20]
cmap = mpl.colors.ListedColormap(colors)
# The explicit `levels` fix the color range, so no vmin/vmax is passed.
# NaN cells stay unshaded. xarray builds its own discrete colormap from `cmap`
# during .plot(), so calling cmap.set_bad() afterwards would have no effect.
im = data_xr_1.plot(ax=ax, cmap=cmap, levels=levels, extend='both', add_colorbar=False, add_labels=False)
cb = plt.colorbar(im, fraction=0.02, pad=0.04, extend='both')
cb.ax.tick_params(labelsize=16)

REGION_IDS = [32, 19, 37, 4]  # IPCC WGI reference regions outlined on the map
with shapefile.Reader('IPCC-WGI-reference-regions-v4.shp') as sf:
    borders = [np.array(sf.shape(i).points).T for i in REGION_IDS]
for bx, by in borders:
    ax.plot(bx, by, 'r-', lw=3, alpha=0.5, transform=ccrs.PlateCarree())

ax.coastlines(color='black', linewidth=0.5)
ax.set_title('GRACE TWS trend (mm/yr) (04/2002-02/2024)', loc='right', fontsize=18)
ax.set_title('d', loc='left', fontsize=18)
_finish_panel(ax)