In [1]:
%matplotlib inline       
import scipy.io as scio
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.feature as cfeature
import cartopy.crs as ccrs
import netCDF4 as nc
import numpy as np
from Load_data import Data_from_nc
import xarray as xr
import math
import gc
import os
from scipy.signal import savgol_filter
# mpl.use('Agg')

def load_nc_data(filepath, var_name, skip_lat=29, invalid_threshold=1e6):
    """
    Load one variable from a NetCDF file and return it as a cleaned NumPy array.

    - Points masked by netCDF4 (declared _FillValue / missing_value / valid range)
      become NaN.
    - Points whose absolute value exceeds ``invalid_threshold`` also become NaN
      (fallback for fill values that are not declared in the file).
    - The first ``skip_lat`` entries along axis 1 (latitude for a
      (time, lat, lon) variable) are dropped.
    - Singleton dimensions are squeezed out of the result.

    Integer variables are promoted to float64 so NaN can be stored; floating
    variables keep their original dtype.
    """
    with nc.Dataset(filepath) as ds:
        # Indexing returns a masked array; np.array(variable) would discard the mask
        # and expose the raw fill values instead.
        raw = np.ma.asarray(ds.variables[var_name][:])
    dtype = raw.dtype if np.issubdtype(raw.dtype, np.floating) else np.float64
    data = np.ma.filled(raw.astype(dtype), np.nan)
    # slice off unwanted latitudes
    data = data[:, skip_lat:, :]
    # mask remaining out-of-range values (NaN compares False, so it stays NaN)
    with np.errstate(invalid='ignore'):
        data[np.abs(data) > invalid_threshold] = np.nan
    return np.squeeze(data)


def load_mat_data(filepath, var_name, transpose=False, skip_first=None):
    """
    Read a single variable from a MATLAB .mat file.

    filepath   : path of the .mat file.
    var_name   : name of the variable to extract.
    transpose  : when True, return the transpose of the (possibly trimmed) array.
    skip_first : when not None, drop the first ``skip_first`` columns (axis 1)
                 before any transpose is applied.
    """
    values = scio.loadmat(filepath)[var_name]
    trimmed = values if skip_first is None else values[:, skip_first:]
    if not transpose:
        return trimmed
    return trimmed.T


def calc_global_sum(data, area, region_mask=None):
    """
    Area-weighted spatial sum of a gridded field, one value per time step.

    data        : (time, lat, lon) array.
    area        : 2D cell-area array laid out (lon, lat); it is transposed here to
                  (lat, lon) and divided by 1e6 before weighting
                  (presumably m^2 -> km^2 — confirm the units of the input).
    region_mask : optional boolean (lat, lon) array; cells where it is not True
                  are excluded.
    Returns a 1D array over time; NaN cells are ignored by the sums.
    """
    scaled_area = area.T / 1e6
    weighted = data * scaled_area
    if region_mask is not None:
        keep = region_mask[np.newaxis, ...]
        weighted = np.where(keep, weighted, np.nan)
    # reduce lon first, then lat
    per_lat = np.nansum(weighted, axis=-1)
    return np.nansum(per_lat, axis=-1)


def calc_area(global_map, region_mask=None):
    """
    Total irrigated area per time step, optionally restricted to a region.

    global_map  : (time, lat, lon) map; NaN cells are ignored by the sums.
    region_mask : optional (lat, lon) mask; cells where it is falsy are excluded.
                  Non-boolean masks (e.g. 0/1 integers) are converted to bool.
    Returns a 1D array over time.
    """
    arr = global_map
    if region_mask is not None:
        keep = np.asarray(region_mask, dtype=bool)
        # np.where builds a new array (input untouched) and promotes integer maps
        # to float so NaN can be written, matching calc_global_sum's masking style.
        arr = np.where(keep[np.newaxis, ...], arr, np.nan)
    return np.nansum(np.nansum(arr, axis=-1), axis=-1)


# === Constants & Filepaths ===
# Set BASE_DIR to the folder that contains `plotting_tools/`,
# `area_quipped_for_irrigation/` and `irrigation_water_withdrawal/`.
BASE_DIR = '' # Fill the path to the files

# Number of leading latitude rows dropped from every gridded field, so that all
# arrays line up with the series returned by `load_nc_data` (skip_lat=29).
SKIP_LAT = 29

# os.path.join keeps paths relative when BASE_DIR is '' (an f-string would
# produce '/plotting_tools/...', i.e. a path at the filesystem root).
AREA_FILE = os.path.join(BASE_DIR, 'plotting_tools/AREA.mat')
REGION_FILE = os.path.join(BASE_DIR, 'plotting_tools/ar6_region.mat')
GLOBAL_MAP_FILE = os.path.join(BASE_DIR, 'area_quipped_for_irrigation/global_irri_land_map.mat')
IRR_DIFF_FILE = os.path.join(BASE_DIR, 'plotting_tools/irr_diff_out.mat')
SURFACE_FILE = os.path.join(BASE_DIR, 'plotting_tools/surfdata_irrigation_method.nc')

# Load static datasets once, both returned as [lat, lon] on the trimmed grid.
# NOTE(review): assumes AREA.mat is stored [lon, lat] on the same untrimmed grid
# as ar6_region.mat — verify; the shape assertion below catches a mismatch.
AREA = load_mat_data(AREA_FILE, 'AREA', transpose=True, skip_first=SKIP_LAT)
AR6_REGION = load_mat_data(REGION_FILE, 'ar6_region', transpose=True, skip_first=SKIP_LAT)
ar6_region = AR6_REGION  # lowercase alias kept for existing code

# Load global irrigation land map once; reorder [lon, lat, time] -> [time, lat, lon]
# and drop the same leading latitudes.
global_map = load_mat_data(GLOBAL_MAP_FILE, 'global_irri_land_map')
global_map = np.transpose(global_map, (2, 1, 0))[:, SKIP_LAT:, :]

assert AREA.shape == AR6_REGION.shape == global_map.shape[1:], (
    f'grid mismatch: AREA {AREA.shape}, AR6_REGION {AR6_REGION.shape}, '
    f'global_map {global_map.shape[1:]}'
)

# Calculate global irrigated land over time
global_irri_land = calc_area(global_map)

# AR6 region IDs of interest (codes as stored in ar6_region.mat)
AR6_REGIONS = {
    'WNA': 4,  'CNA': 5,  'NCA': 7,  'MED': 20,
    'WCA': 33, 'EAS': 36, 'SAS': 38, 'SEA': 39
}

# Irrigated area time series per region
regional_irri_land = {
    name: calc_area(global_map, region_mask=(AR6_REGION == rid))
    for name, rid in AR6_REGIONS.items()
}

# Example: loading additional data
irr_diff = load_mat_data(IRR_DIFF_FILE, 'irr_diff_out', transpose=True, skip_first=SKIP_LAT)

# NetCDF surface data carrying the grid's lat/lon coordinates
surface = Data_from_nc(SURFACE_FILE)
irrigation_method = surface.load_variable('irrigation_method')
lon = surface.load_variable('LONGXY')[0, :]
lat = surface.load_variable('LATIXY')[SKIP_LAT:, 0]


In [3]:
# === AR6 Region IDs ===
# AR6 region codes (values in ar6_region.mat) for which regional IWW series are
# computed; not every region listed here is plotted.
REGION_IDS = dict(WNA=4, CNA=5, NCA=7, MED=20,
                  WCA=33, EAS=36, SAS=38, SEA=39)

# === IWW Calculation Helpers ===
def compute_global_iww(series: np.ndarray, area=None) -> np.ndarray:
    """
    Compute the area-weighted global IWW time series.

    series : [time, lat, lon] field.
    area   : optional [lat, lon] cell-area array; defaults to the module-level AREA.

    Cells that are NaN in the first time step are excluded from the numerator at
    every time step. The denominator is the nan-summed area of the whole grid
    (excluded cells included).
    Returns a 1D array over time.
    """
    if area is None:
        area = AREA
    missing = np.isnan(series[0])
    weighted = np.where(missing[None], np.nan, series * area)
    return np.nansum(np.nansum(weighted, axis=2), axis=1) / np.nansum(area)


def compute_regional_iww(series: np.ndarray, region_id: int,
                         area=None, region_map=None) -> np.ndarray:
    """
    Compute the area-weighted IWW time series for one AR6 region.

    series     : [time, lat, lon] field.
    region_id  : AR6 region code to select.
    area       : optional [lat, lon] cell-area array; defaults to module-level AREA.
    region_map : optional [lat, lon] array of region codes; defaults to AR6_REGION.

    Cells that are NaN in the first time step are excluded from the numerator.
    The denominator is the nan-summed area of all cells in the region, so a
    region code that does not occur yields 0/0 (NaN, with a RuntimeWarning).
    Returns a 1D array over time.
    """
    if area is None:
        area = AREA
    if region_map is None:
        region_map = AR6_REGION
    missing = np.isnan(series[0])
    area_masked = np.where(region_map == region_id, area, np.nan)
    weighted = np.where(missing[None], np.nan, series * area_masked)
    return np.nansum(np.nansum(weighted, axis=2), axis=1) / np.nansum(area_masked)

# === Experiment Configurations ===
# One row per model: (model key, sub-directory under BASE_DIR, NetCDF variable name,
# transient-irrigation file (*_tranirr_*), fixed-1901-irrigation file (*_1901irr_*) or None).
EXPERIMENTS = [
    ('CESM2',   'irrigation_water_withdrawal/CESM2',          'QIRRIG',      'CESM2_tranirr_timeseries_1901_2014.nc',            'CESM2_1901irr_timeseries_1901_2014.nc'),
    ('CESM2_gw','irrigation_water_withdrawal/CESM2_gw',       'QIRRIG',      'CESM2_gw_tranirr_timeseries_1901_2014.nc',         'CESM2_gw_1901irr_timeseries_1901_2014.nc'),
    ('NorESM',  'irrigation_water_withdrawal/NorESM2',        'QIRRIG',      'NorESM2_tranirr_timeseries_1901_2014.nc',          'NorESM2_1901irr_timeseries_1901_2014.nc'),
    ('IPSL',    'irrigation_water_withdrawal/IPSL-CM6',       'irr',         'IPSL-CM6_tranirr_timeseries_1901_2014.nc',         None),
    ('CNRM',    'irrigation_water_withdrawal/CNRM-CM6-1',     'airrww',      'CNRM-CM6-1_tranirr_timeseries_1901_2014.nc',       'CNRM-CM6-1_1901irr_timeseries_1901_2014.nc'),
    ('E3SM',    'irrigation_water_withdrawal/E3SMv2',         'QIRRIG_REAL', 'E3SMv2_tranirr_timeseries_1901_2014.nc',           'E3SMv2_1901irr_timeseries_1901_2014.nc'),
    ('MIROC',   'irrigation_water_withdrawal/MIROC-INTEG-ES', 'irrac',       'MIROC-INTEG-ES_tranirr_timeseries_1901_2014.nc',   'MIROC-INTEG-ES_1901irr_timeseries_1901_2014.nc'),
]

# === Pipeline Execution ===
# `results` maps '{model}_global_{scen}' and '{model}_{region}_{scen}' to 1D time
# series, scen being 'irr' (tranirr file) or 'noi' (1901irr file). Models without a
# 1901irr file only get 'irr' entries.
results = {}
for model, subdir, var, irr_file, noi_file in EXPERIMENTS:
    folder = os.path.join(BASE_DIR, subdir)
    scenario_files = [('irr', irr_file)]
    if noi_file:
        scenario_files.append(('noi', noi_file))
    for scen, fname in scenario_files:
        series = load_nc_data(os.path.join(folder, fname), var)
        results[f'{model}_global_{scen}'] = compute_global_iww(series)
        results.update({
            f'{model}_{name}_{scen}': compute_regional_iww(series, rid)
            for name, rid in REGION_IDS.items()
        })


In [4]:
# 1) Model-specific post-processing settings
#   slice_end: applied as arr[:slice_end]; -1 drops the last time step, None keeps all.
#              np.vstack below requires every series in a stack to have the same length.
#   mult:      factor applied to the series (MIROC: 365*86400, presumably a per-second
#              flux converted to per-year — confirm the model's units).
#   has_noi:   whether the model provides the 1901irr ('noi') experiment.
model_params = {
    'CESM2':    {'slice_end': -1,   'mult': 1,           'has_noi': True},
    'CESM2_gw': {'slice_end': -1,   'mult': 1,           'has_noi': True},
    'NorESM':   {'slice_end': -1,   'mult': 1,           'has_noi': True},
    'IPSL':     {'slice_end': None, 'mult': 1,           'has_noi': False},
    'E3SM':     {'slice_end': -1,   'mult': 1,           'has_noi': True},
    'MIROC':    {'slice_end': None, 'mult': 365*86400,   'has_noi': True},
    'CNRM':     {'slice_end': None, 'mult': 1,           'has_noi': True},
}

regions   = ['SAS', 'MED', 'CNA', 'WCA']
scenarios = ['irr', 'noi']

# 2) Stack the per-model series of each (region, scenario) into a [model, time] array
IWW_stack = {}
for region in regions:
    for scen in scenarios:
        rows = []
        for model, params in model_params.items():
            # skip NOI for models without NOI
            if scen == 'noi' and not params['has_noi']:
                continue
            key = f"{model}_{region}_{scen}"
            if key not in results:
                continue  # model not present in this run: skip it
            arr = results[key]
            if params['slice_end'] is not None:
                arr = arr[:params['slice_end']]
            rows.append(arr * params['mult'])
        if not rows:
            raise ValueError(
                f"No model series found in `results` for region {region!r}, scenario {scen!r}"
            )
        IWW_stack[(region, scen)] = np.vstack(rows)

# 3) Ensemble statistics across models for each (region, scenario).
# The spread is the full ensemble range (min..max), not the interquartile range.
# It is stored under 'min'/'max' and, for existing callers, under the historical
# keys 'p25'/'p75' as well (so those two keys hold the minimum and maximum).
IWW_stats = {}
for (region, scen), data in IWW_stack.items():
    ens_min = np.min(data, axis=0)
    ens_max = np.max(data, axis=0)
    IWW_stats[(region, scen)] = {
        'p25':    ens_min,
        'median': np.percentile(data, 50, axis=0),
        'p75':    ens_max,
        'mean':   np.mean(data, axis=0),
        'min':    ens_min,
        'max':    ens_max,
    }

# Example: South Asia IRR median
SAS_irr_median = IWW_stats[('SAS', 'irr')]['median']

In [None]:
def plot_area(ax, index, data1, data2, title, title1, label, color, ylabel, xlabel, marker,
              years=None):
    """
    Shade the band between ``data1`` and ``data2`` on ``ax`` (used for AEI panels).

    index       : bold panel letter drawn above the top-left corner.
    title/title1: right- and left-aligned axes titles.
    label/color : legend label and fill colour.
    ylabel/xlabel: axis labels.
    marker      : accepted for call compatibility; unused (only the fill is drawn).
    years       : optional x values (a sequence); defaults to range(1901, 2015).
                  The x-limits span from the first year to one past the last.

    Everything is drawn on ``ax`` itself, never on pyplot's "current" axes.
    """
    if years is None:
        years = range(1901, 2015)
    ax.text(-0.05, 1.05, index, color='dimgrey', fontsize=14, transform=ax.transAxes, weight='bold')
    ax.fill_between(years, data1, data2, label=label, color=color, linewidth=1, alpha=0.8)
    ax.set_title(title, loc='right')
    ax.set_title(title1, loc='left')
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=14)
    ax.tick_params(axis='both', labelsize=14)
    ax.set_xlim(years[0], years[-1] + 1)

# === Plotting Function (no smoothing applied here) ===
def plot_water(ax, index, data_mean, data_min, data_max,
              title_right='', title_left='', label='', color='k',
              ylabel='', xlabel='', marker=None):
    """
    Draw one IWW curve on ``ax`` for the years 1901-2014: a line for
    ``data_mean`` plus a translucent band from ``data_min`` to ``data_max``.

    Also styles the axes (panel letter, optional titles, 'year' x-label,
    tick sizes, x-limits 1901-2015, legend, dashed grid).
    ``ylabel`` and ``xlabel`` are accepted but not applied: no y-label is set
    here and the x-label is always 'year'.
    """
    timeline = np.arange(1901, 2015)
    ax.text(-0.05, 1.05, index, color='dimgrey', fontsize=14,
            transform=ax.transAxes, weight='bold')
    line_style = dict(label=label, color=color, linewidth=2,
                      marker=marker, markersize=4, alpha=0.8)
    ax.plot(timeline, data_mean, **line_style)
    ax.fill_between(timeline, data_min, data_max, color=color, alpha=0.2)
    for text, side in ((title_right, 'right'), (title_left, 'left')):
        if text:
            ax.set_title(text, loc=side, fontsize=16)
    ax.set_xlabel('year', fontsize=16)
    ax.tick_params(axis='both', labelsize=14)
    ax.set_xlim(1901, 2015)
    ax.legend(loc='upper left', fontsize=16)
    ax.grid(linestyle='--', alpha=0.5)

# === Create Figure ===
# Top row (a-d): irrigated area (AEI) per region.
# Bottom row (e-h): smoothed IWW of the tranirr ('irr') and 1901irr ('noi')
# experiments; the shaded band uses the 'p25'/'p75' stats entries
# (computed as the ensemble min and max).
SMOOTH_WINDOW = 19  # Savitzky-Golay window length (time steps)
SMOOTH_ORDER = 2    # Savitzky-Golay polynomial order

AEI_YLABEL = r'AEI ($\mathregular{10^6}$ $\mathregular{km^2}$)'
IWW_YLABEL = r'IWW (mm/year)'

# (region key, title, AEI panel letter, IWW panel letter, AEI y-max, IWW y-max or None = autoscale)
PANELS = [
    ('SAS', 'South Asia',            'a', 'e', 0.8,  None),
    ('MED', 'Mediterranean',         'b', 'f', 0.25, None),
    ('CNA', 'Central North America', 'c', 'g', 0.15, 30),
    ('WCA', 'West Central Asia',     'd', 'h', 0.4,  None),
]

# (scenario, legend label, colour, marker); the first curve carries the panel letter/title
IWW_CURVES = (
    ('irr', 'tranirr', 'dodgerblue', '^'),
    ('noi', '1901irr', 'brown',      'o'),
)


def _smooth(series):
    """Savitzky-Golay smoothing applied to every IWW curve in the figure."""
    return savgol_filter(series, SMOOTH_WINDOW, SMOOTH_ORDER)


fig, axes = plt.subplots(2, 4, figsize=(20, 8), dpi=300)
fig.subplots_adjust(hspace=0.4, wspace=0.3,
                    left=0.05, right=0.95,
                    top=0.95, bottom=0.05)

for col, (region, region_title, aei_letter, iww_letter, aei_ymax, iww_ymax) in enumerate(PANELS):
    # --- Top row: irrigated area ---
    ax = axes[0, col]
    # plot_area may draw through pyplot's current axes, so make `ax` current first.
    plt.sca(ax)
    plot_area(ax, aei_letter, 0, regional_irri_land[region] / 1e6, '', '', '', 'orange',
              AEI_YLABEL if col == 0 else '', 'year', '^')
    ax.grid(linestyle='--')
    ax.set_ylim(0, aei_ymax)
    ax.legend(loc='upper left')
    ax.set_title(region_title, loc='right', fontsize=16)

    # --- Bottom row: irrigation water withdrawal ---
    ax = axes[1, col]
    for i, (scen, label, color, marker) in enumerate(IWW_CURVES):
        stats = IWW_stats[(region, scen)]
        plot_water(
            ax, iww_letter if i == 0 else '',
            _smooth(stats['mean']), _smooth(stats['p25']), _smooth(stats['p75']),
            title_right=region_title if i == 0 else '',
            label=label, color=color, marker=marker,
        )
    if col == 0:
        ax.set_ylabel(IWW_YLABEL, fontsize=14)
    if iww_ymax is not None:
        ax.set_ylim(0, iww_ymax)

plt.show()