# Data wrangling and plots

Notebook for doing some data wrangling and creating the plots for all the basins.

In [None]:
import xarray as xr
import cartopy.crs as ccrs
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from matplotlib import ticker as plticker
import matplotlib.dates as mdates
from matplotlib.ticker import StrMethodFormatter, ScalarFormatter
from gha.hydro import calc_SPEI
import seaborn as sns
from sklearn.metrics import r2_score
import geopandas as gpd
import pandas as pd
import re
import os
import numpy as np
import warnings
warnings.filterwarnings(action='ignore')

In [None]:
# Notebook-wide plotting defaults.
plt.rcParams.update({'figure.figsize': (10, 7)})
sns.set_theme()
# Output folder for saved figures; later cells append the file name to it.
fig_path = '/home/users/eholmgren/www_eholmgren/msc_thesis/plots/'

In [None]:
def get_figsize(columnwidth, wf=0.5, hf=(5.**0.5-1.0)/2.0, ):
    r"""Compute a matplotlib figure size from a LaTeX column width.

    Parameters:
      - columnwidth [float]: width of the column in LaTeX, in points. Get
                             this from LaTeX using \showthe\columnwidth
      - wf [float]:  width of the figure as a fraction of columnwidth.
      - hf [float]:  height of the figure as a fraction of the figure width.
                     Defaults to (sqrt(5) - 1) / 2, i.e. golden-ratio
                     proportions.
    Returns:  [fig_width, fig_height] in inches, to be given to matplotlib.
    """
    # The docstring is a raw string so the LaTeX backslashes are not parsed
    # as (invalid) escape sequences.
    fig_width_pt = columnwidth*wf
    inches_per_pt = 1.0/72.27               # Convert pt to inch
    fig_width = fig_width_pt*inches_per_pt  # width in inches
    fig_height = fig_width*hf               # height in inches
    return [fig_width, fig_height]

In [None]:
# Input tables: the basin shapefile (attributes + geometry) and the list of CMIP6 GCMs.
basins_shapefile = './data/glacier_basins.shp'
gcm_list_csv = '/home/www/oggm/cmip6/all_gcm_list.csv'
basins_df = gpd.read_file(basins_shapefile)
gcm_df = pd.read_csv(gcm_list_csv, index_col=0)

In [None]:
basins_df

In [None]:
# Unique GCM names from the list, with the entry at position 10 removed.
# NOTE(review): the position is hard-coded — confirm it still points at the
# model that should be excluded whenever the GCM list changes.
gcms = np.delete(gcm_df.gcm.unique(), 10)
gcms

We want to read in each scenario for each basin. Keep them in a dict.

In [None]:
# Root folder of the per-basin projection files on the cluster.
base_path = '/home/users/eholmgren/work/gha_basins/'
# SSP scenarios to process, in the order used throughout the notebook.
ssps = [f'ssp{code}' for code in ('126', '245', '370', '585')]

## Read in the basin data in nested dicts
This takes a while. TODO: store all of it in a single NetCDF file so it can be loaded in one go.

In [None]:
# Build the nested lookup basin_dict[mrbid][gcm][ssp] -> xarray Dataset.
basin_dict = {}
for basin in basins_df.MRBID:
    mrbid = str(basin)
    # All files of one basin are stored in a folder named after its id.
    basin_dir = os.path.join(base_path, mrbid)
    gcm_dict = {}
    for gcm in gcms:
        scenario_dict = {}
        for ssp in ssps:
            # Name of the file and full path to it.
            file = f'{mrbid}_discharge_proj_{gcm}_{ssp}.nc'
            path = os.path.join(basin_dir, file)
            try:
                with xr.open_dataset(path, use_cftime=True) as ds:
                    # Last year is not good: drop the final 12 time steps.
                    scenario_dict[ssp] = ds.isel(time=slice(0, -12))
            except FileNotFoundError:
                # No file for this GCM/scenario combination; skip it.
                continue
        # Only GCMs with at least one loaded scenario get an entry.
        if scenario_dict:
            gcm_dict[gcm] = scenario_dict
    # Put the gcm dict in the basin dict.
    basin_dict[mrbid] = gcm_dict

## Year of peak water
Calculate the year of peak water for every basin, scenario and GCM.

In [None]:
# Year of peak water for every basin, scenario and GCM:
#   peak_water_dict[mrbid][ssp] -> 1-D array with one peak year per GCM
#   for which data was loaded.
peak_water_dict = {}
# Loop all basins.
for basin in basins_df.MRBID:
    mrbid = str(basin)
    # Each ssp scenario under a separate key.
    ssp_dict = {}
    # Loop ssps.
    for ssp in ssps:
        ens_list = []
        # Loop the gcms
        for gcm in gcms:
            # We try because not all gcms have all ssps etc.
            try:
                df = basin_dict[mrbid][gcm][ssp]
                # Annual sums of glacier runoff.
                # NOTE(review): the loader cell already dropped the last 12
                # time steps, so this drops 12 more — confirm this is intended.
                annual_runoff = df.glacier_runoff.isel(time=slice(0, -12)).groupby('time.year').sum()
                # Smooth with a centred 11-year rolling mean; min_periods=1
                # keeps the years at both ends of the series.
                annual_runoff = annual_runoff.rolling(year=11, center=True, min_periods=1).mean()
                ens_list.append(annual_runoff)
            except KeyError:
                continue

        # When all is in the list, stack the members along a new 'gcm' dimension.
        ens_df = xr.concat(ens_list, dim='gcm')
        # For every member, the year in which its smoothed runoff is largest.
        year_means = ens_df.isel(year=ens_df.argmax(dim='year')).year.values
        ssp_dict[ssp] = year_means
    peak_water_dict[mrbid] = ssp_dict

In [None]:
# We want to put the peak water values in a large array for the image plot.
# Shape (basin, scenario, member); assumes 75 basins and at most 14 ensemble
# members per scenario. Unused member slots stay 0.
image_data = np.zeros((75, 4, 14))
for i, basin in enumerate(basins_df.MRBID):
    mrbid = str(basin)
    for j, ssp in enumerate(ssps):
        # np.sort returns a sorted copy, so the arrays stored in
        # peak_water_dict keep their original (per-GCM) order.
        values = np.sort(peak_water_dict[mrbid][ssp])
        image_data[i, j, :len(values)] = values

# Merge the scenario and member axes: one row per basin, scenarios side by side.
image_data = image_data.reshape(image_data.shape[:-2] + (-1,))

In [None]:
image_data

In [None]:
# Collapse the per-GCM peak years to two numbers per basin and scenario
# (ensemble mean and standard deviation), keyed by the basin display name.
peak_water_dict_csv = {}
for basin, scenarios in peak_water_dict.items():
    ssp_dict = {}
    for ssp, peak_years in scenarios.items():
        mean_year = peak_years.mean()
        std = peak_years.std()
        # int() truncates towards zero, it does not round.
        ssp_dict[f'{ssp}_mean'] = int(mean_year)
        ssp_dict[f'{ssp}_pm'] = int(std)
    # Basin name: the part before the first parenthesis, title-cased.
    river_basin = basins_df[basins_df.MRBID == int(basin)].RIVER_BASI.iloc[0]
    basin_name = re.split(r'[\(\)]', river_basin)[0].title()
    peak_water_dict_csv[basin_name] = ssp_dict

    

In [None]:
df_peak_water = pd.DataFrame.from_dict(peak_water_dict_csv, orient='index')

In [None]:
df_peak_water

In [None]:
df_peak_water.to_csv('/home/users/eholmgren/www_eholmgren/msc_thesis/peak_water_years.csv', sep=',')

### Peak water scatter plot

Plot the ensemble-mean year of peak water against the initially glaciated fraction of each basin (RGI glacier area divided by basin area).

In [None]:
basins_df.loc[basins_df.MRBID == int(mrbid)].geometry.area

In [None]:
basins_df.loc[basins_df.MRBID == 2103]

In [None]:
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15/2.54, 15/2.54), sharex=True, sharey=True)
font_size = 7
# One panel per SSP: ensemble-mean year of peak water (x) against the
# glaciated fraction of the basin (y), with a least-squares line and its R^2.

# Glaciated fraction in percent. It does not depend on the scenario, so it is
# computed once for all basins instead of re-filtering basins_df per basin.
areas = (basins_df.RGI_AREA / basins_df.AREA_CALC * 100).to_numpy()
# List to save one handle from each scatter.
handles = []
# Loop over ssps.
for i, ssp in enumerate(ssps):
    panel = ax.flat[i]
    # Ensemble-mean year of peak water per basin, truncated to a whole year.
    yops = np.asarray([int(peak_water_dict[str(mrbid)][ssp].mean())
                       for mrbid in basins_df.MRBID])
    # Scatter plot
    panel.scatter(yops, areas, color=f"C{i}", label=ssp, s=10)
    # Linear least-squares fit of glaciated fraction on year of peak water.
    fit = np.polyfit(yops, areas, deg=1)
    fit_x = np.poly1d(fit)
    # Plot the line.
    x = np.linspace(2020, 2100)
    panel.plot(x, fit_x(x), lw=0.8, ls="--", color="k")
    # Add stats.
    panel.text(2020, 23, f"R$^2$ = {r2_score(areas, fit_x(yops)):.3f}", fontsize=font_size)

    # Get one handle (the scatter is the only labelled artist on the panel).
    handles.append(panel.get_legend_handles_labels()[0][0])

    # Grid
    panel.grid(which='both', linewidth=0.5)
    # Label params
    panel.tick_params(axis="y", labelsize=font_size, pad=-4)
    panel.tick_params(axis="x", labelsize=font_size, pad=-2)

# Legend.
labels = ["SSP1-2.6", "SSP2-4.5", "SSP3-7.0", "SSP5-8.5"]
fig.legend(handles, labels, ncol=4, bbox_to_anchor=(0.5, 0.92), fontsize=font_size,
           columnspacing=0.5, handlelength=1, loc='center');

# Common x/y labels.
fig.text(0.5, 0.08, "Mean year of peak water", ha='center', fontsize=font_size)
fig.text(0.07, 0.5, 'Glaciated fraction [%]', va='center', rotation='vertical', fontsize=font_size)
# Title.
fig.text(0.5, 0.95, "Ensemble mean year of peak water compared to the basin\n initially glaciated area fraction.", ha='center', fontsize=font_size);
plt.subplots_adjust(wspace=0.05, hspace=0.05);

path = os.path.join(fig_path, "yop_scatter.pdf")
plt.savefig(path, bbox_inches="tight")

#### Scatter presentation

In [None]:
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(8/2.54, 8/2.54), sharex=True, sharey=True)
font_size = 4
# Presentation-sized version of the scatter figure. One panel per SSP:
# ensemble-mean year of peak water (x) against the glaciated fraction of the
# basin (y), with a least-squares line and its R^2.

# Glaciated fraction in percent. It does not depend on the scenario, so it is
# computed once for all basins instead of re-filtering basins_df per basin.
areas = (basins_df.RGI_AREA / basins_df.AREA_CALC * 100).to_numpy()
# List to save one handle from each scatter.
handles = []
# Loop over ssps.
for i, ssp in enumerate(ssps):
    panel = ax.flat[i]
    # Ensemble-mean year of peak water per basin, truncated to a whole year.
    yops = np.asarray([int(peak_water_dict[str(mrbid)][ssp].mean())
                       for mrbid in basins_df.MRBID])
    # Scatter plot
    panel.scatter(yops, areas, color=f"C{i}", label=ssp, s=5)
    # Linear least-squares fit of glaciated fraction on year of peak water.
    fit = np.polyfit(yops, areas, deg=1)
    fit_x = np.poly1d(fit)
    # Plot the line.
    x = np.linspace(2020, 2100)
    panel.plot(x, fit_x(x), lw=0.5, ls="--", color="k")
    # Add stats.
    panel.text(2020, 23, f"R$^2$ = {r2_score(areas, fit_x(yops)):.3f}", fontsize=font_size)

    # Get one handle (the scatter is the only labelled artist on the panel).
    handles.append(panel.get_legend_handles_labels()[0][0])

    # Grid
    panel.grid(which='both', linewidth=0.5)
    # Label params
    panel.tick_params(axis="y", labelsize=font_size, pad=-4)
    panel.tick_params(axis="x", labelsize=font_size, pad=-2)

# Legend.
labels = ["SSP1-2.6", "SSP2-4.5", "SSP3-7.0", "SSP5-8.5"]
fig.legend(handles, labels, ncol=4, bbox_to_anchor=(0.5, 0.92), fontsize=font_size,
           columnspacing=0.5, handlelength=1, loc='center');

# Common x/y labels.
fig.text(0.5, 0.07, "Mean year of peak water", ha='center', fontsize=font_size)
fig.text(0.06, 0.5, 'Glaciated fraction [%]', va='center', rotation='vertical', fontsize=font_size)
# Title.
fig.text(0.5, 0.95, "Ensemble mean year of peak water compared to the basin\n initially glaciated area fraction.", ha='center', fontsize=font_size);
plt.subplots_adjust(wspace=0.05, hspace=0.05);

path = os.path.join(fig_path, "yop_scatter_pres.pdf")
plt.savefig(path, bbox_inches="tight", facecolor="none")

In [None]:
fit

### Huss and Hock data

In [None]:
# Whitespace-separated table without a header row. The separator is a regex,
# written as a raw string so that "\s" is not treated as a string escape.
hh_df = pd.read_csv('./data/hh_peak_water.csv', sep=r'\s+', header=None)

In [None]:
hh_df

In [None]:
# Clean up the dataset a bit.
# NOTE(review): this cell modifies hh_df in place and renames its integer
# column labels, so re-running it without re-reading the CSV first will fail.
# Get the values
# Columns 1, 6 and 11 hold "<value>±<std>" strings, one column per RCP.
for i, (j, rcp) in enumerate(zip([1, 6, 11], ['RCP26', 'RCP45', 'RCP85'])):
    # Split into the value (part 0) and the spread after the "±" (part 1).
    values = hh_df[j].str.split('±', expand=True)
    hh_df[j] = values[0]
    # Insert the spread directly after column j. insert() takes a position,
    # and the i columns inserted in earlier iterations have shifted every
    # later column one slot to the right each, hence j+1+i.
    hh_df.insert(j+1+i, column=f'std-{j}', value=values[1])

    # Rename by label: column j is the value, the new column its spread and
    # the original column j+1 becomes 'std-years-<rcp>'.
    # (presumably a spread expressed in years — TODO confirm against the data source)
    hh_df.rename(columns={j: rcp, f'std-{j}': f'std-tot-{rcp}', j+1: f'std-years-{rcp}'}, inplace=True)
# Drop the remaining, unused columns (by their original integer labels).
hh_df.drop([3, 4, 5, 8, 9, 10, 13, 14, 15], axis=1, inplace=True)
# The first column holds the basin name.
hh_df.rename(columns={0: 'BASIN'}, inplace=True)

In [None]:
hh_df

In [None]:
# Change the names back: overwrite the first column (BASIN) of selected rows,
# addressed by row position, with the full basin name.
name_fixes = {
    3: 'ARAL SEA',
    11: 'SANTA CRUZ',
    24: 'JOKULSA A FJOLLUM',
    37: 'HUANG HE',
}
for row_pos, full_name in name_fixes.items():
    hh_df.iat[row_pos, 0] = full_name

In [None]:
hh_df.head()

In [None]:
df_peak_water.index

In [None]:
# Now we can compare them: for every basin in the Huss & Hock table, collect
# the scenarios where the two peak-water years differ by more than 10 years.
# Some name corrections (Huss & Hock spelling -> spelling in df_peak_water;
# trailing spaces are part of the index there).
basin_name_corr = {'tarim': 'tarim he ', 'balkhash': 'lake balkhash', 'issyk-kul': 'ysyk-kol ',
                   'oelfusa': 'olfusa', 'gloma': 'glomaa', 'lule': 'lulealven', 'amazon': 'amazon ',
                   'negro': 'negro ', 'kalixaelven': 'kalixalven', 'dramselv': 'dramselva'}
# RCP scenario paired with the SSP it is compared against.
scenario_pairs = list(zip(['RCP26', 'RCP45', 'RCP85'], ['ssp126', 'ssp245', 'ssp585']))
peak_w_comp_dict = {}
for _, basin in hh_df.iterrows():
    basin_name = basin.BASIN.title()
    # Use the corrected spelling when there is one.
    corrected = basin_name_corr.get(basin_name.lower())
    if corrected is not None:
        basin_name = corrected.title()

    # Does the basin exist in our results?
    matches = df_peak_water.loc[df_peak_water.index == basin_name.title()]
    if matches.empty:
        print('Not found', basin_name)
        continue
    basin_new = matches.iloc[0]

    # scenario dict: '<ssp>/<rcp>' -> (our year, their year)
    scenario_dict = {}
    for rcp, ssp in scenario_pairs:
        theirs = int(basin[rcp])
        ours = basin_new[f'{ssp}_mean']
        # Is the difference larger than 10 years?
        if abs(theirs - ours) > 10:
            scenario_dict[f'{ssp}/{rcp}'] = (ours, theirs)

    peak_w_comp_dict[basin_name] = scenario_dict

In [None]:
peak_w_comp_df = pd.DataFrame.from_dict(peak_w_comp_dict, orient='index')

In [None]:
peak_w_comp_df

In [None]:
peak_w_comp_df.to_csv('/home/users/eholmgren/www_eholmgren/msc_thesis/hh_compare.csv')

### Peak water image plot uncertainty

In [None]:
# Display names for the y axis: the part of RIVER_BASI before the first
# parenthesis, title-cased.
basin_names = []
for river_basin in basins_df.RIVER_BASI:
    basin_names.append(re.split(r'[\(\)]', river_basin)[0].title())

In [None]:
import matplotlib.transforms as transform

In [None]:
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

In [None]:
# Main figure
fig, ax = plt.subplots(figsize=(9/2.54, 28/2.54))
# color axes
divider = make_axes_locatable(ax)
cax = divider.append_axes('bottom', size='2%', pad=0.05)
# image data, mask bad values (0 marks an unused ensemble slot)
im_data_masked = np.ma.masked_where(image_data == 0, image_data)
# Some utility vars: the layout is derived from the data instead of
# hard-coding the number of basins and ensemble members.
xlen = image_data.shape[1]
ylen = image_data.shape[0]
ssp_labels = ['SSP1-2.6', 'SSP2-4.5', 'SSP3-7.0', 'SSP5-8.5']
# Number of image columns that belong to one scenario.
ncols_ssp = xlen // len(ssp_labels)
#cmap
cmap = plt.get_cmap('viridis')
cmap.set_bad(color='lightgrey')
# The main colormesh. Rows are reversed so the first basin ends up at the top.
im = ax.pcolormesh(im_data_masked[::-1], cmap=cmap, vmin=2020., vmax=2100.,
                   edgecolors='face', lw=0)
# Add labels to the pixels: ensemble-mean peak year of every basin/scenario
# group, written as years after 2000.
for i in range(ylen):
    for j in range(len(ssp_labels)):
        # Mean value
        mean = im_data_masked[i, j*ncols_ssp:(j+1)*ncols_ssp].mean()
        # A group without any data has no mean to print.
        if np.ma.is_masked(mean):
            continue
        ax.text(ncols_ssp/2 + ncols_ssp*j, (ylen-1-i)+0.4, int(mean - 2000),
                ha="center", va="center", color="w",
                fontsize=4, transform=ax.transData,
               )

# Ylabels and ticks
ax.set_yticks(np.arange(image_data.shape[0])[::-1]+0.5)
ax.set_yticklabels(basin_names)
# Xlabels and ticks. Loop cause I'm too lazy to deal with tickers.
ax.set_xticklabels([])
for i, ssp in enumerate(ssp_labels):
    # Centred above each scenario group, just above the top row.
    ax.text(ncols_ssp/2 + i*ncols_ssp, ylen + 0.5, ssp, ha='center', va='center',
           transform=ax.transData,
           fontsize=5)


# Remove ticks
ax.tick_params('x', top=False, pad=-5)
# Change y tick padding.
ax.tick_params('y', pad=-4)
# Labelsize of yticks.
ax.tick_params('y', labelsize=7)
# Add the colorbar.
fig.colorbar(im, cax=cax, orientation='horizontal')
# Colorbar tick params.
cax.tick_params('both', labelsize=7, width=0.5, length=4)
cax.set_xlabel('Year', fontsize=7)
# Turn off grid.
ax.grid(None)

# Title
ax.set_title('Peak water ensemble estimations for 75 large scale basins', fontsize=7,
            x=0.5, y=1.02, transform=ax.transAxes, ha='center')
# Save figure.
#fig.savefig(fig_path+'peak_years.pdf', bbox_inches='tight')

#### Peak water image presentation.

In [None]:
image_data.shape

In [None]:
image_data.reshape((56, 75)).shape

In [None]:
# Main figure (presentation-sized, with the text rotated by -90 degrees)
fig, ax = plt.subplots(figsize=(7.5/2.54, 13.5/2.54))
# color axes
divider = make_axes_locatable(ax)
cax = divider.append_axes('bottom', size='2%', pad=0.02)
# image data, mask bad values (0 marks an unused ensemble slot)
im_data_masked = np.ma.masked_where(image_data == 0, image_data)
# Some utility vars: the layout is derived from the data instead of
# hard-coding the number of basins and ensemble members.
xlen = image_data.shape[1]
ylen = image_data.shape[0]
ssp_labels = ['SSP1-2.6', 'SSP2-4.5', 'SSP3-7.0', 'SSP5-8.5']
# Number of image columns that belong to one scenario.
ncols_ssp = xlen // len(ssp_labels)
#cmap
cmap = plt.get_cmap('viridis')
cmap.set_bad(color='lightgrey')
# The main colormesh. Rows are reversed so the first basin ends up at the top.
im = ax.pcolormesh(im_data_masked[::-1], cmap=cmap, vmin=2020., vmax=2100.,
                   edgecolors='face', lw=0)
# Add labels to the pixels: ensemble-mean peak year of every basin/scenario
# group, written as years after 2000.
for i in range(ylen):
    for j in range(len(ssp_labels)):
        # Mean value
        mean = im_data_masked[i, j*ncols_ssp:(j+1)*ncols_ssp].mean()
        # A group without any data has no mean to print.
        if np.ma.is_masked(mean):
            continue
        ax.text(ncols_ssp/2 + ncols_ssp*j, (ylen-1-i)+0.4, int(mean - 2000),
                ha="center", va="center", color="w",
                fontsize=2, transform=ax.transData,
                rotation=-90
               )

# Ylabels and ticks
ax.set_yticks(np.arange(image_data.shape[0])[::-1]+0.5)
ax.set_yticklabels(basin_names)
# Xlabels and ticks. Loop cause I'm too lazy to deal with tickers.
ax.set_xticklabels([])
for i, ssp in enumerate(ssp_labels):
    # Centred above each scenario group, just above the top row.
    ax.text(ncols_ssp/2 + i*ncols_ssp, ylen + 0.5, ssp, ha='center', va='center',
           transform=ax.transData,
           fontsize=4)


# Remove ticks
ax.tick_params('x', top=False, pad=-5)
# Change y tick padding.
ax.tick_params('y', pad=-4)
# Labelsize of yticks.
ax.tick_params('y', labelsize=4)
# Add the colorbar.
fig.colorbar(im, cax=cax, orientation='horizontal')
# Colorbar tick params.
cax.tick_params('both', labelsize=4, width=0, length=2.5, rotation=-90, pad=0.05)
cax.set_xlabel('Year', fontsize=4)
# Turn off grid.
ax.grid(None)

# Title, placed vertically along the right edge of the axes.
ax.set_title('Peak water ensemble estimations for 75 large scale basins', fontsize=5,
            x=1.02, y=0.5, transform=ax.transAxes, ha='center', va="center", rotation=-90)
# Save figure.
fig.savefig(fig_path+'peak_years_presentation.pdf', bbox_inches='tight', facecolor="none")

## Plot prototype for ensemble

In [None]:
fig, ax = plt.subplots()

# Prototype: ensemble plot for a single basin.
basin = '2306'
# Peak value and peak year of the ensemble mean, one entry per scenario.
max_values = []
max_years = []
for ssp in ssps:
    # Collect the relative runoff change of every GCM that has this scenario.
    ens_list = []
    for gcm in gcms:
        try:
            df = basin_dict[basin][gcm][ssp]
            annual_runoff = df.glacier_runoff.isel(time=slice(0, -12)).groupby('time.year').sum()
            annual_runoff = annual_runoff.rolling(year=11, center=True, min_periods=1).mean()
            # Change relative to the first year of the smoothed series.
            first_year = annual_runoff.isel(year=0)
            annual_runoff = (annual_runoff - first_year) / first_year
            ens_list.append(annual_runoff)
        except KeyError:
            continue

    # Stack the members and derive the ensemble statistics.
    ens_df = xr.concat(ens_list, dim='gcm')
    ens_mean = ens_df.mean(dim='gcm')
    q25 = ens_df.quantile(0.25, dim='gcm')
    q75 = ens_df.quantile(0.75, dim='gcm')
    # Shade the interquartile range and draw the ensemble mean on top.
    ax.fill_between(q25.year, q25, q75, zorder=1, alpha=0.3)
    label = f'{ssp}: N={len(ens_df)}'
    ens_mean.plot(ax=ax, label=label)
    # Peak of the ensemble mean and the year it occurs in.
    max_value = ens_mean.max().values
    max_values.append(max_value)
    max_year = ens_mean.isel(year=ens_mean.argmax()).year
    max_years.append(max_year)
    # Current y limits; axvline below needs heights as axes fractions.
    ymin, ymax = ax.get_ylim()

# Dotted vertical line from the bottom of the axes up to each peak.
colors = ['C0', 'C1', 'C2', 'C3']
for value, year, color in zip(max_values, max_years, colors):
    ax.axvline(year, ymax=(value-ymin)/(ymax-ymin), zorder=3,
               ls=':', color=color, lw=2)
plt.legend(loc='best');

### Asia and NZ

## Europe
Create Europe selection.

### Europe

In [None]:
continent = 'Europe'
basins = basins_df.loc[basins_df.CONTINENT == continent]

# Subplot grid: the first `shift` axes are left empty (and hidden below),
# the basins fill the axes after them.
shift = 1
nbasins = len(basins)
ncols = 4
nrows = (nbasins // ncols) + shift
nplots = ncols*nrows
height_factor = 24/6
height = height_factor * nrows

fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
labels = []
# font size
font_size = 7
# One panel per basin.
for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    panel = ax.flat[idx]
    mrbid = str(basin)
    # Peak value / peak year of the ensemble mean per scenario, plus the
    # widest y limits seen while the panel was being drawn.
    max_values = []
    max_years = []
    ymin = 0
    ymax = 0
    for ssp in ssps:
        # Relative runoff change (%) of every GCM that has this scenario.
        ens_list = []
        for gcm in gcms:
            try:
                df = basin_dict[mrbid][gcm][ssp]
                annual_runoff = df.glacier_runoff.isel(time=slice(0, -12)).groupby('time.year').sum()
                annual_runoff = annual_runoff.rolling(year=11, center=True, min_periods=1).mean()
                first_year = annual_runoff.isel(year=0)
                annual_runoff_chg = (annual_runoff - first_year) / first_year
                annual_runoff_chg *= 100
                ens_list.append(annual_runoff_chg)
            except KeyError:
                continue

        # Stack the members and derive the ensemble statistics.
        ens_df = xr.concat(ens_list, dim='gcm')
        ens_mean = ens_df.mean(dim='gcm')
        q25 = ens_df.quantile(0.25, dim='gcm')
        q75 = ens_df.quantile(0.75, dim='gcm')
        # Shaded interquartile range.
        panel.fill_between(q25.year, q25, q75, zorder=1, alpha=0.3)
        # Legend entry: scenario and number of ensemble members.
        label = f'{ssp}\n N={len(ens_df)}'
        # Line plot of mean.
        ens_mean.plot(ax=panel, lw=0.8, label=label)
        # Stuff for vlines: indicate peak years.
        max_value = ens_mean.max().values
        max_values.append(max_value)
        max_year = ens_mean.isel(year=ens_mean.argmax()).year
        max_years.append(max_year)
        ymin_t, ymax_t = panel.get_ylim()
        ymin = min(ymin, ymin_t)
        ymax = max(ymax, ymax_t)

        # Clear the axis labels that the xarray plot call added, style the ticks.
        panel.set_ylabel('')
        panel.set_xlabel('')
        panel.tick_params(axis='y', labelsize=font_size, pad=-5)
        panel.tick_params(axis='x', labelsize=font_size, pad=-2)
        panel.set_xticks([2020., 2060., 2100.])
        panel.tick_params("x", labelrotation=30.)

        # Grid
        panel.grid(which='both', linewidth=0.5)

        # Panel title: basin name without the alternative names in parentheses.
        name = basins.iloc[i].RIVER_BASI
        name = re.split(r'[\(\)]', name)[0]
        # Two basins get a shortened, single-word title.
        if mrbid == '2910':
            name = name.split()[1]
        if mrbid == '6101':
            name = name.split()[0]
        panel.set_title(f'{name.title()}', size=font_size, pad=1.1)
        panel.tick_params('both')
    # Dotted vertical line from the bottom of the panel up to each peak;
    # axvline expects the height as a fraction of the axes.
    colors = ['C0', 'C1', 'C2', 'C3']
    for value, year, color in zip(max_values, max_years, colors):
        panel.axvline(year, ymax=(value-ymin)/(ymax-ymin), zorder=3,
                      ls=':', color=color, lw=1.)

# Hide the axes that hold no basin: the leading ones and the trailing ones.
for i in range(shift):
    ax.flat[i].set_visible(False)
for i in range(nbasins + 1, nplots):
    ax.flat[i].set_visible(False)

# Make a legend from the entries of one panel.
handles, labels = ax.flat[5].get_legend_handles_labels()
fig.legend(handles, labels, ncol=1, bbox_to_anchor=(0.2, 0.8), fontsize=font_size,
           columnspacing=0.5, handlelength=1, loc='center')
# Common x/y labels and figure title.
fig.text(0.5, 0.07, 'Year', ha='center', fontsize=font_size)
fig.text(0.07, 0.5, 'Relative change (%)', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.91, f'Glacier runoff projection ensembles for {continent}', ha='center', fontsize=font_size)


#fig.tight_layout()
plt.savefig(fig_path+f'peak_water_europe.pdf', bbox_inches='tight')

### South America

## Asia / SW pacific

In [None]:
continent = 'Asia'
basins = basins_df.loc[(basins_df.CONTINENT.eq('Asia')) | (basins_df.CONTINENT.eq('South-West Pacific'))]

# Subplot grid: `shift` leading axes are left empty, the basins fill the rest.
shift = 0
nbasins = len(basins)
ncols = 4
# Round up so there is an axes for every basin even when the number of
# basins is not a multiple of ncols.
nrows = int(np.ceil((nbasins + shift) / ncols))
nplots = ncols*nrows
height_factor = 24/6
height = height_factor * nrows

fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
labels = []
# font size
font_size = 7
# Plot all the basins, one panel each.
for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    # Saving v-lines props: peak value / peak year of the ensemble mean per
    # scenario, plus the widest y limits seen while drawing the panel.
    max_values = []
    max_years = []
    ymin = 0
    ymax = 0
    # Loop over the dicts.
    for ssp in ssps:
        # Relative runoff change (%) of every GCM that has this scenario.
        ens_list = []
        for gcm in gcms:
            try:
                df = basin_dict[mrbid][gcm][ssp]
                annual_runoff = df.glacier_runoff.isel(time=slice(0, -12)).groupby('time.year').sum()
                annual_runoff = annual_runoff.rolling(year=11, center=True, min_periods=1).mean()
                annual_runoff_chg = (annual_runoff - annual_runoff.isel(year=0)) / annual_runoff.isel(year=0)
                annual_runoff_chg *= 100
                ens_list.append(annual_runoff_chg)
            except KeyError:
                continue

        # When all is in the list. concat.
        ens_df = xr.concat(ens_list, dim='gcm')
        ens_mean = ens_df.mean(dim='gcm')
        q25 = ens_df.quantile(0.25, dim='gcm')
        q75 = ens_df.quantile(0.75, dim='gcm')
        # Shaded interquartile range and the ensemble mean on top.
        ax.flat[idx].fill_between(q25.year, q25, q75, zorder=1, alpha=0.3)
        label = f'{ssp}\n N={len(ens_df)}'
        ens_mean.plot(ax=ax.flat[idx], lw=0.8, label=label)
        # Stuff for vlines
        max_value = ens_mean.max().values
        max_values.append(max_value)
        max_year = ens_mean.isel(year=ens_mean.argmax()).year
        max_years.append(max_year)
        ymin_t, ymax_t = ax.flat[idx].get_ylim()
        if ymin_t < ymin:
            ymin = ymin_t
        if ymax_t > ymax:
            ymax = ymax_t

        # Axes labels: clear the ones added by the xarray plot call.
        ax.flat[idx].set_ylabel('')
        ax.flat[idx].set_xlabel('')
        ax.flat[idx].tick_params(axis='y', labelsize=font_size, pad=-5)
        ax.flat[idx].tick_params(axis='x', labelsize=font_size, pad=-2)
        ax.flat[idx].set_xticks([2020., 2060., 2100.])
        ax.flat[idx].tick_params("x", labelrotation=30.)

        # Grid
        ax.flat[idx].grid(which='both', linewidth=0.5)

        # Get the basin name.
        name = basins.iloc[i].RIVER_BASI
        # We don't need to plot the alternative names for now, so split it.
        name = re.split(r'[\(\)]', name)[0]
        # Two basins get a shortened, single-word title.
        if mrbid == '2910':
            name = name.split()[1]
        if mrbid == '6101':
            name = name.split()[0]
        ax.flat[idx].set_title(f'{name.title()}', size=font_size, pad=1.1)
        ax.flat[idx].tick_params('both')
    # Plot vlines: dotted line from the bottom of the panel up to each peak;
    # axvline expects the height as a fraction of the axes.
    colors = ['C0', 'C1', 'C2', 'C3']
    for value, year, color in zip(max_values, max_years, colors):
        ax.flat[idx].axvline(year, ymax=(value-ymin)/(ymax-ymin), zorder=3,
                   ls=':', color=color, lw=1)

# Hide the axes that hold no basin: the leading `shift` ones and everything
# after the last basin (which sits at index nbasins + shift - 1).
for i in range(shift):
    ax.flat[i].set_visible(False)
for i in range(nbasins + shift, nplots):
    ax.flat[i].set_visible(False)

# Make a legend from the entries of one panel.
handles, labels = ax.flat[5].get_legend_handles_labels()
fig.legend(handles, labels, ncol=4, bbox_to_anchor=(0.5, 0.92), fontsize=font_size,
           columnspacing=0.5, handlelength=1, loc='center')
# Common x/y labels and figure title.
fig.text(0.5, 0.08, 'Year', ha='center', fontsize=font_size)
fig.text(0.07, 0.5, 'Relative change (%)', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.95, f'Glacier runoff projection ensembles for Asia and New Zealand', ha='center', fontsize=font_size)


#fig.tight_layout()
plt.savefig(fig_path+f'peak_water_asia_nz.pdf', bbox_inches='tight')

In [None]:
continent = 'South America'
basins = basins_df.loc[(basins_df.CONTINENT == continent)]

# Subplot grid: `shift` leading axes are left empty, the basins fill the rest.
shift = 0
nbasins = len(basins)
ncols = 4
# Round up so there is an axes for every basin even when the number of
# basins is not a multiple of ncols.
nrows = int(np.ceil((nbasins + shift) / ncols))
nplots = ncols*nrows
height_factor = 24/6
height = height_factor * nrows

fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
labels = []
# font size
font_size = 7
# Plot all the basins, one panel each.
for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    # Saving v-lines props: peak value / peak year of the ensemble mean per
    # scenario, plus the widest y limits seen while drawing the panel.
    max_values = []
    max_years = []
    ymin = 0
    ymax = 0
    # Loop over the dicts.
    for ssp in ssps:
        # Relative runoff change (%) of every GCM that has this scenario.
        ens_list = []
        for gcm in gcms:
            try:
                df = basin_dict[mrbid][gcm][ssp]
                annual_runoff = df.glacier_runoff.isel(time=slice(0, -12)).groupby('time.year').sum()
                annual_runoff = annual_runoff.rolling(year=11, center=True, min_periods=1).mean()
                annual_runoff_chg = (annual_runoff - annual_runoff.isel(year=0)) / annual_runoff.isel(year=0)
                annual_runoff_chg *= 100
                ens_list.append(annual_runoff_chg)
            except KeyError:
                continue

        # When all is in the list. concat.
        ens_df = xr.concat(ens_list, dim='gcm')
        ens_mean = ens_df.mean(dim='gcm')
        q25 = ens_df.quantile(0.25, dim='gcm')
        q75 = ens_df.quantile(0.75, dim='gcm')
        # Shaded interquartile range and the ensemble mean on top.
        ax.flat[idx].fill_between(q25.year, q25, q75, zorder=1, alpha=0.3)
        label = f'{ssp}\n N={len(ens_df)}'
        ens_mean.plot(ax=ax.flat[idx], lw=0.8, label=label)
        # Stuff for vlines
        max_value = ens_mean.max().values
        max_values.append(max_value)
        max_year = ens_mean.isel(year=ens_mean.argmax()).year
        max_years.append(max_year)
        ymin_t, ymax_t = ax.flat[idx].get_ylim()
        if ymin_t < ymin:
            ymin = ymin_t
        if ymax_t > ymax:
            ymax = ymax_t

        # Labels: clear the ones added by the xarray plot call.
        ax.flat[idx].set_ylabel('')
        ax.flat[idx].set_xlabel('')
        ax.flat[idx].tick_params(axis='y', labelsize=font_size, pad=-5)
        ax.flat[idx].tick_params(axis='x', labelsize=font_size, pad=-2)
        ax.flat[idx].set_xticks([2020., 2060., 2100.])
        ax.flat[idx].tick_params("x", labelrotation=30.)

        # Grid
        ax.flat[idx].grid(which='both', linewidth=0.5)

        # Get the basin name.
        name = basins.iloc[i].RIVER_BASI
        # We don't need to plot the alternative names for now, so split it.
        name = re.split(r'[\(\)]', name)[0]
        # Two basins get a shortened, single-word title.
        if mrbid == '2910':
            name = name.split()[1]
        if mrbid == '6101':
            name = name.split()[0]
        ax.flat[idx].set_title(f'{name.title()}', size=font_size, pad=1.1)
        ax.flat[idx].tick_params('both')

    # Plot vlines: dotted line from the bottom of the panel up to each peak;
    # axvline expects the height as a fraction of the axes.
    colors = ['C0', 'C1', 'C2', 'C3']
    for value, year, color in zip(max_values, max_years, colors):
        ax.flat[idx].axvline(year, ymax=(value-ymin)/(ymax-ymin), zorder=3,
                   ls=':', color=color, lw=1)

# Hide the axes that hold no basin: the leading `shift` ones and everything
# after the last basin (which sits at index nbasins + shift - 1).
for i in range(shift):
    ax.flat[i].set_visible(False)
for i in range(nbasins + shift, nplots):
    ax.flat[i].set_visible(False)

# Make a legend from the entries of one panel.
handles, labels = ax.flat[5].get_legend_handles_labels()
fig.legend(handles, labels, ncol=4, bbox_to_anchor=(0.5, 0.92), fontsize=font_size,
           columnspacing=0.5, handlelength=1, loc='center')
# Common x/y labels and figure title.
fig.text(0.5, 0.09, 'Year', ha='center', fontsize=font_size)
fig.text(0.07, 0.5, 'Relative change (%)', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.95, f'Glacier runoff projection ensembles for South America', ha='center', fontsize=font_size)


#fig.tight_layout()
plt.savefig(fig_path+f'peak_water_s_america.pdf', bbox_inches='tight')

### N. America

In [None]:
# List the distinct continent names in the basin table; the cells below
# filter basins_df by comparing against these strings.
basins_df.CONTINENT.unique()

In [None]:
# Relative change in glacier runoff for the North American basins.
# One panel per basin; per scenario (SSP) the panel shows the GCM ensemble
# mean as a line, the 25th-75th percentile range across GCMs as a band, and
# a dotted vertical line from the bottom of the panel up to the maximum of
# the ensemble mean.
continent = 'North America, Central America and the Caribbean'
basins = basins_df.loc[(basins_df.CONTINENT == continent)]

# Number of leading subplots that are left empty.
shift = 0
nbasins = len(basins)
ncols = 4
# Ceiling division: the grid must hold nbasins + shift panels, also when
# that number is not a multiple of ncols.
nrows = -(-(nbasins + shift) // ncols)
nplots = ncols*nrows
# Height of one subplot row (cm); the figure grows with the number of rows.
height_factor = 24/6
height = height_factor * nrows

# Figure size is given in cm and converted to inches.
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
# font size
font_size = 7
# One colour per scenario for the vertical lines, in the order of the
# default colour cycle used by the mean lines.
colors = ['C0', 'C1', 'C2', 'C3']

# Plot all the basins
for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    panel = ax.flat[idx]
    mrbid = str(basin)
    # Maximum of the ensemble mean and the year it occurs in, per scenario.
    max_values = []
    max_years = []
    for ssp in ssps:
        ens_list = []
        for gcm in gcms:
            try:
                df = basin_dict[mrbid][gcm][ssp]
                # Drop the last 12 time steps, then sum per calendar year.
                annual_runoff = df.glacier_runoff.isel(time=slice(0, -12)).groupby('time.year').sum()
                # 11-year centred running mean.
                annual_runoff = annual_runoff.rolling(year=11, center=True, min_periods=1).mean()
                # Change relative to the first year, in percent.
                baseline = annual_runoff.isel(year=0)
                annual_runoff_chg = (annual_runoff - baseline) / baseline
                annual_runoff_chg *= 100
                ens_list.append(annual_runoff_chg)
            except KeyError:
                # This basin/GCM/scenario combination is not available.
                continue

        # When all is in the list. concat.
        ens_df = xr.concat(ens_list, dim='gcm')
        ens_mean = ens_df.mean(dim='gcm')
        q25 = ens_df.quantile(0.25, dim='gcm')
        q75 = ens_df.quantile(0.75, dim='gcm')
        panel.fill_between(q25.year, q25, q75, zorder=1, alpha=0.3)

        # N is the number of ensemble members (length of the 'gcm' dimension).
        label = f'{ssp}\n N={len(ens_df)}'
        ens_mean.plot(ax=panel, lw=0.8, label=label)

        # Stuff for vlines
        max_values.append(ens_mean.max().values)
        max_years.append(ens_mean.isel(year=ens_mean.argmax()).year)

    # Panel styling only has to be applied once, after all scenarios are
    # drawn (the xarray plot call sets its own axis labels).
    panel.set_ylabel('')
    panel.set_xlabel('')
    panel.tick_params(axis='y', labelsize=font_size, pad=-5)
    panel.tick_params(axis='x', labelsize=font_size, pad=-2)
    panel.set_xticks([2020., 2060., 2100.])
    panel.tick_params("x", labelrotation=30.)

    # Grid
    panel.grid(which='both', linewidth=0.5)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', size=font_size, pad=1.1)

    # Plot vlines. axvline takes its upper end as a fraction of the axes
    # height, so use the final y-limits, read after everything is drawn.
    ymin, ymax = panel.get_ylim()
    for value, year, color in zip(max_values, max_years, colors):
        panel.axvline(year, ymax=(value-ymin)/(ymax-ymin), zorder=3,
                      ls=':', color=color, lw=1)

# Hide the unused subplots before the first and after the last basin.
for i in list(range(shift)) + list(range(nbasins + shift, nplots)):
    ax.flat[i].set_visible(False)

# Make a legend from the handles of one panel.
handles, labels = ax.flat[5].get_legend_handles_labels()
fig.legend(handles, labels, ncol=4, bbox_to_anchor=(0.5, 0.92), fontsize=font_size,
           columnspacing=0.5, handlelength=1, loc='center')
# Common x/y labels and figure title
fig.text(0.5, 0.07, 'Year', ha='center', fontsize=font_size)
fig.text(0.07, 0.5, 'Relative change (%)', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.95, 'Glacier runoff projection ensembles for North America', ha='center', fontsize=font_size)


#fig.tight_layout()
plt.savefig(fig_path+'peak_water_n_america.pdf', bbox_inches='tight')

### N. America presentation

In [None]:
# Presentation version of the North America glacier runoff figure.
# One panel per basin; per scenario (SSP) the panel shows the GCM ensemble
# mean as a line, the 25th-75th percentile range across GCMs as a band, and
# a dotted vertical line from the bottom of the panel up to the maximum of
# the ensemble mean.
continent = 'North America, Central America and the Caribbean'
basins = basins_df.loc[(basins_df.CONTINENT == continent)]

# Number of leading subplots that are left empty.
shift = 0
nbasins = len(basins)
ncols = 4
# Ceiling division: the grid must hold nbasins + shift panels, also when
# that number is not a multiple of ncols.
nrows = -(-(nbasins + shift) // ncols)
nplots = ncols*nrows
height_factor = 24/6
# Fixed figure height (cm) for the presentation layout.
height = 9.8

# Figure size is given in cm and converted to inches.
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
# font size
font_size = 7
# One colour per scenario for the vertical lines, in the order of the
# default colour cycle used by the mean lines.
colors = ['C0', 'C1', 'C2', 'C3']

# Plot all the basins
for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    panel = ax.flat[idx]
    mrbid = str(basin)
    # Maximum of the ensemble mean and the year it occurs in, per scenario.
    max_values = []
    max_years = []
    for ssp in ssps:
        ens_list = []
        for gcm in gcms:
            try:
                df = basin_dict[mrbid][gcm][ssp]
                # Drop the last 12 time steps, then sum per calendar year.
                annual_runoff = df.glacier_runoff.isel(time=slice(0, -12)).groupby('time.year').sum()
                # 11-year centred running mean.
                annual_runoff = annual_runoff.rolling(year=11, center=True, min_periods=1).mean()
                # Change relative to the first year, in percent.
                baseline = annual_runoff.isel(year=0)
                annual_runoff_chg = (annual_runoff - baseline) / baseline
                annual_runoff_chg *= 100
                ens_list.append(annual_runoff_chg)
            except KeyError:
                # This basin/GCM/scenario combination is not available.
                continue

        # When all is in the list. concat.
        ens_df = xr.concat(ens_list, dim='gcm')
        ens_mean = ens_df.mean(dim='gcm')
        q25 = ens_df.quantile(0.25, dim='gcm')
        q75 = ens_df.quantile(0.75, dim='gcm')
        panel.fill_between(q25.year, q25, q75, zorder=1, alpha=0.3)

        # N is the number of ensemble members (length of the 'gcm' dimension).
        label = f'{ssp} (N={len(ens_df)})'
        ens_mean.plot(ax=panel, lw=0.8, label=label)

        # Stuff for vlines
        max_values.append(ens_mean.max().values)
        max_years.append(ens_mean.isel(year=ens_mean.argmax()).year)

    # Panel styling only has to be applied once, after all scenarios are
    # drawn (the xarray plot call sets its own axis labels).
    panel.set_ylabel('')
    panel.set_xlabel('')
    panel.tick_params(axis='y', labelsize=font_size, pad=-5)
    panel.tick_params(axis='x', labelsize=font_size, pad=-2)
    # The x coordinate is the integer 'year' produced by the groupby above,
    # not a date, so pick a few integer ticks instead of using a date locator.
    panel.xaxis.set_major_locator(plticker.MaxNLocator(nbins=4, integer=True))

    # Grid
    panel.grid(which='both', linewidth=0.5)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', size=font_size, pad=1.1)

    # Plot vlines. axvline takes its upper end as a fraction of the axes
    # height, so use the final y-limits, read after everything is drawn.
    ymin, ymax = panel.get_ylim()
    for value, year, color in zip(max_values, max_years, colors):
        panel.axvline(year, ymax=(value-ymin)/(ymax-ymin), zorder=3,
                      ls=':', color=color, lw=1.)

# Hide the unused subplots before the first and after the last basin.
for i in list(range(shift)) + list(range(nbasins + shift, nplots)):
    ax.flat[i].set_visible(False)

# Make a legend from the handles of one panel.
handles, labels = ax.flat[5].get_legend_handles_labels()
fig.legend(handles, labels, ncol=4, bbox_to_anchor=(0.5, 0.94), fontsize=font_size,
           columnspacing=0.5, handlelength=1, loc='center')
# Common x/y labels and figure title
fig.text(0.5, 0.06, 'Year', ha='center', fontsize=font_size)
fig.text(0.07, 0.5, 'Relative change (%)', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.98, 'Glacier runoff projection ensembles for North America', ha='center', fontsize=font_size)


#fig.tight_layout()
# facecolor='none' saves the figure with a transparent background.
plt.savefig(fig_path+'peak_water_n_america_pres.pdf', bbox_inches='tight', facecolor='none')

In [None]:
# Inspection only: show the distinct continent names in the basin table.
basins_df.CONTINENT.unique()

# SPEI
Plots for SPEI

In [None]:
# Load the SPEI dataset of every basin/GCM pair into a nested dict:
#   SPEI_dict[mrbid][gcm] -> xarray.Dataset
SPEI_dict = {}
for basin in basins_df.MRBID:
    gcm_dict = {}
    mrbid = str(basin)
    # Folder holding the files of this basin.
    basin_dir = os.path.join(base_path, mrbid)
    for gcm in gcms:
        # Name of the file.
        file = f'{mrbid}_SPEI_{gcm}_False_wref.nc'
        # Full path to the file.
        path = os.path.join(basin_dir, file)
        # Open the file with xarray. The context manager closes the file on
        # exit, so pull the data into memory first; the stored dataset must
        # not depend on a file handle that is already closed.
        with xr.open_dataset(path, use_cftime=True) as ds:
            gcm_dict[gcm] = ds.load()
    # Put the GCM dict in the basin dict.
    SPEI_dict[mrbid] = gcm_dict

## Utility function

In [None]:
def get_spei_ensemble(mrbid, ssp):
    '''Collect the SPEI series of all GCMs for one basin and scenario.

    For every GCM in ``gcms`` the variables ``SPEI_<ssp>`` and
    ``SPEI_adj_<ssp>`` are read from ``SPEI_dict``, cut to 2019-2100 and
    put on a common time axis. GCMs for which either variable is missing
    are skipped, so both ensembles always hold the same members.

    Args:
    -----
    mrbid: str
        Basin id, key into ``SPEI_dict``.
    ssp: str
        Scenario name used in the variable names, i.e. 'ssp126'.

    Returns:
    --------
    spei_ens, spei_adj_ens: xr.DataArray
        The two ensembles, concatenated along a new 'gcm' dimension.

    Raises:
    -------
    KeyError
        If no GCM provides both variables for this basin and scenario.
    '''
    period = slice('2019', '2100')

    def _load(gcm, var_name):
        # Select the period and swap the cftime index for a pandas
        # DatetimeIndex. .sel() returns a new object, so the dataset stored
        # in SPEI_dict is left untouched.
        series = SPEI_dict[mrbid][gcm][var_name].sel(time=period)
        series['time'] = series.indexes['time'].to_datetimeindex()
        return series

    # Try to conform all the series to the same calendar with use of one
    # master idx. Some of the gcms are on 16th instead of the 15th.
    idx_master = None
    spei_list = []
    spei_adj_list = []
    for gcm in gcms:
        try:
            spei = _load(gcm, f'SPEI_{ssp}')
            spei_adj = _load(gcm, f'SPEI_adj_{ssp}')
        except KeyError:
            # GCM (or one of its two variables) not available: skip the GCM
            # entirely so the two lists cannot get out of step.
            continue
        if idx_master is None:
            # The first available GCM defines the master time axis.
            idx_master = spei.indexes['time']
        spei_list.append(spei.reindex({'time': idx_master}, method='nearest'))
        spei_adj_list.append(spei_adj.reindex({'time': idx_master}, method='nearest'))

    if idx_master is None:
        raise KeyError(f'No GCM provides SPEI_{ssp} and SPEI_adj_{ssp} for basin {mrbid}')

    # Create the SPEI ensembles
    spei_ens = xr.concat(spei_list, dim='gcm')
    spei_adj_ens = xr.concat(spei_adj_list, dim='gcm')

    return spei_ens, spei_adj_ens

## Paper version of climatology plot

In [None]:
def _style_violin_bodies(bodies, facecolor, edgecolor, lw, alpha, half=None, zorder=None):
    '''Style the bodies returned by ``Axes.violinplot`` and optionally keep one half.

    ``half`` is None (full violin), 'left' or 'right'; the kept half is
    measured from the mean x position of the outline.
    '''
    for pc in bodies:
        if half is not None:
            verts = pc.get_paths()[0].vertices
            # get the center
            center = np.mean(verts[:, 0])
            # modify the path (in place) to not go past the center
            if half == 'left':
                verts[:, 0] = np.clip(verts[:, 0], -np.inf, center)
            else:
                verts[:, 0] = np.clip(verts[:, 0], center, np.inf)
        pc.set_facecolor(facecolor)
        pc.set_edgecolor(edgecolor)
        pc.set_linewidth(lw)
        pc.set_alpha(alpha)
        if zorder is not None:
            pc.set_zorder(zorder)


def plot_spei_climatology_paper(ssp, season=None, savefig=False, lat_sort=False):
    '''Plot the SPEI climatology (violins) of all basins for a certain ssp
    scenario. Paper version.

    Every basin gets one panel with filled violins of the GCM ensemble mean
    (SPEI and adjusted SPEI), plus outlined half violins of the 25th (left
    half) and 75th (right half) ensemble percentile.

    Args:
    -----
    ssp: str
        I.e. 'ssp370'
    season: str
        Choose a season to plot. 'winter' or 'summer'. None (default) uses
        all months.
    savefig: bool
        Save the figure as a pdf in ``fig_path``.
    lat_sort: bool
        Order the basins by the latitude of their centroid, north to south.
    '''
    selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                      'winter': {True: 'DJF', False: 'JJA'}}

    # font size
    font_size = 7
    # Line width of the violin outlines
    lw = 0.5
    # Number of leading subplots that are left empty.
    shift = 2
    # Create the figure
    fig, ax = plt.subplots(ncols=7, nrows=11, figsize=(15/2.54, 24/2.54),
                           sharex=True, sharey=False)
    # c0: SPEI, c1: adjusted SPEI
    c0 = 'C1'
    c1 = 'C0'
    alpha = 0.5
    basins = basins_df
    if lat_sort or season is not None:
        # The centroid latitude is needed for sorting and for telling the
        # hemispheres apart. Work on a copy so basins_df is not modified.
        basins = basins_df.assign(lat=basins_df['geometry'].centroid.y)
    if lat_sort:
        basins = basins.sort_values('lat', ascending=False)

    for i, basin in enumerate(basins.MRBID):
        idx = i + shift
        panel = ax.flat[idx]
        mrbid = str(basin)
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)

        if season is not None:
            # Check whether we are in NH (True) or SH (False)
            hemisphere = bool(basins.iloc[i]['lat'] > 0)
            # Select the season.
            selection = selection_dict[season][hemisphere]
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == selection))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == selection))

        # Add the mean
        v1 = panel.violinplot(spei_ens.mean(dim='gcm'), showextrema=False)
        v2 = panel.violinplot(spei_adj_ens.mean(dim='gcm'), showextrema=False)
        _style_violin_bodies(v1['bodies'], c0, 'k', lw, alpha, zorder=2)
        _style_violin_bodies(v2['bodies'], c1, 'k', lw, alpha, zorder=2)

        # Add the 25th quantile (left half) and the 75th quantile (right half)
        for q, half in ((0.25, 'left'), (0.75, 'right')):
            v3 = panel.violinplot(spei_ens.quantile(q=q, dim='gcm'), showextrema=False)
            v4 = panel.violinplot(spei_adj_ens.quantile(q=q, dim='gcm'), showextrema=False)
            _style_violin_bodies(v3['bodies'], 'none', c0, lw, 1, half=half)
            _style_violin_bodies(v4['bodies'], 'none', c1, lw, 1, half=half)

        # Get the basin name.
        name = basins.iloc[i].RIVER_BASI
        # We don't need to plot the alternative names for now, so split it.
        name = re.split(r'[\(\)]', name)[0]
        if mrbid == '2910':
            name = name.split()[1]
        if mrbid == '6101':
            name = name.split()[0]
        panel.set_title(f'{name}', fontsize=font_size, pad=2)
        # No tick labels on the individual panels.
        panel.set_ylabel('')
        panel.set_yticklabels([])
        panel.tick_params(axis='x', labelsize=font_size)
        panel.set_xticklabels([])
        # Add axis info to the first plot
        if i == 0:
            panel.set_ylabel('SPEI', fontsize=font_size, labelpad=1)
            panel.tick_params('y', pad=-4)

        panel.grid(which='both', linewidth=0.5)
    # Legend stuff
    patches = [mpatches.Patch(facecolor=c0, edgecolor='k', alpha=alpha),
               mpatches.Patch(facecolor=c1, edgecolor='k', alpha=alpha)]
    labels = ['W.o. glaciers',
              'W. glaciers']
    #  Add the legend.
    fig.legend(patches, labels, bbox_to_anchor=(0.30, 0.895),
               frameon=True, framealpha=1,  labelcolor='k',
               fontsize=font_size, title=f'SPEI distributions\nfor {ssp}', title_fontsize=font_size)

    # Make the first plots invisible.
    for i in range(shift):
        ax.flat[i].set_visible(False)
    fig.subplots_adjust(wspace=0.1, hspace=0.3)

    if savefig:
        fig_name = f'spei_clim_{ssp}_{season}_paper.pdf'
        path = os.path.join(fig_path, fig_name)
        fig.savefig(path, bbox_inches='tight', dpi=120)

In [None]:
# Draw and save the paper figure for ssp370, with the default season (None)
# and without latitude sorting.
plot_spei_climatology_paper('ssp370', savefig=True)

## SPEI annual cycle

In [None]:
# Show the row(s) of the basin table whose RIVER_BASI is "COPPER".
basins_df.loc[basins_df.RIVER_BASI == "COPPER"]

In [None]:
# SPEI and adjusted-SPEI ensembles for basin id "4408" under ssp126.
# NOTE(review): "4408" is presumably the MRBID of the Copper basin looked up above - confirm.
spei_ens, spei_ens_adj = get_spei_ensemble("4408", "ssp126")

In [None]:
# Mean annual cycle of SPEI: average over the GCMs first, then over all
# years for each calendar month.
fig, ax = plt.subplots(figsize=(10/2.54, 7/2.54))
font_size = 7
cycle = spei_ens.mean(dim="gcm").groupby("time.month").mean()
cycle_adj = spei_ens_adj.mean(dim="gcm").groupby("time.month").mean()
cycle.plot(ax=ax, lw=0.8, label="Prcp. only")
cycle_adj.plot(ax=ax, lw=0.8, label="Prcp. + glacier")
# Legend
ax.legend(fontsize=font_size)
# Labels
ax.set_ylabel("SPEI", fontsize=font_size)
ax.set_xlabel("Month", fontsize=font_size)
# Tick labels
ax.tick_params("both", labelsize=font_size)
# Title
ax.set_title("Mean annual cycle of SPEI in the Copper basin, N. America", fontsize=font_size)
# Grid: "both" selects major and minor grid lines, so it has to be passed
# as `which` (the first positional argument is the on/off flag).
ax.grid(which="both", lw=0.5)

path = os.path.join(fig_path, "spei_ann_cycle.pdf")
plt.savefig(path, bbox_inches="tight")

## Delta SPEI
These are the distribution plots.

In [None]:
# Delta-SPEI (adjusted SPEI minus SPEI) distributions for a single basin:
# one half violin per scenario for the GCM ensemble mean, with outlined
# half violins for the difference of the 25th and of the 75th ensemble
# percentiles on top.
basin = basins_df.iloc[24]
mrbid = str(basin.MRBID)

font_size = 7
# Horizontal distance between the violins of two scenarios.
shift = 0.2
positions = [i*shift for i in range(len(ssps))]
fig, ax = plt.subplots(figsize=(5/2.54, 5/2.54))

for i, ssp in enumerate(ssps):
    # Get the ensembles
    spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
    mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
    q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
    q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

    # (data, style) per layer: filled mean, dotted 25th, dashed 75th.
    layers = [
        (mean_d, dict(facecolor='C1', edgecolor='k', alpha=0.5, linewidth=0.5,
                      zorder=4-(i*0.1))),
        (q25_d, dict(facecolor='none', edgecolor='k', linewidth=0.5, linestyle=':',
                     alpha=0.5)),
        (q75_d, dict(facecolor='none', edgecolor='k', linewidth=0.5, linestyle='--',
                     alpha=0.5)),
    ]
    for data, style in layers:
        violin = ax.violinplot(data, showextrema=False, positions=[positions[i]])
        for pc in violin['bodies']:
            verts = pc.get_paths()[0].vertices
            # get the center
            m = np.mean(verts[:, 0])
            # keep only the half to the right of the center
            verts[:, 0] = np.clip(verts[:, 0], m, np.inf)
            pc.set(**style)

# Fix the xaxis: one tick per scenario, placed where its violin starts, so
# the labels are tied to the violins instead of to the automatic ticks.
ax.set_xlim(0, 0.9)
ax.set_xticks(positions)
ax.set_xticklabels(ssps, fontdict={'ha':'left', 'rotation': -25})
ax.tick_params('both', labelsize=font_size)
# Raw string: '\D' is not a valid escape sequence in a normal string.
plt.ylabel(r'$\Delta$-SPEI', fontsize=font_size)
plt.title(f'SPEI climatology {basin.RIVER_BASI}', fontsize=font_size)
plt.savefig(fig_path+'clim_test.pdf', bbox_inches='tight')

## Delta SPEI Europe
### All basins

In [None]:
# Delta-SPEI (adjusted SPEI minus SPEI) distributions for all European
# basins. Every basin gets one panel with, per scenario, a filled half
# violin of the GCM ensemble mean and outlined half violins for the
# difference of the 25th (dotted) and 75th (dashed) ensemble percentiles.
#
# season: None for all months, or 'winter' / 'summer' (resolved per
#         hemisphere through selection_dict).
# lat_sort: order the basins by centroid latitude, north to south.
selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                  'winter': {True: 'DJF', False: 'JJA'}}
lat_sort = False
season = None

basins = basins_df.loc[basins_df.CONTINENT == 'Europe']
if lat_sort or season is not None:
    # Centroid latitude, needed for sorting and for the hemisphere check.
    # assign() works on a copy, so basins_df itself is not modified.
    basins = basins.assign(lat=basins['geometry'].centroid.y)
if lat_sort:
    basins = basins.sort_values('lat', ascending=False)

# Number of leading subplots that are left empty (room for the legend).
shift = 1
nbasins = len(basins)
ncols = 4
nrows = (nbasins // ncols) + shift
nplots = ncols*nrows
height_factor = 24/6
height = height_factor * nrows

# Create the figure
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
# One colour per scenario
colors = ['C0', 'C1', 'C2', 'C3']
# font size
font_size = 7
alpha = 0.5
lw = 0.5
# Horizontal distance between the violins of two scenarios.
dist_shift = 0.2

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    panel = ax.flat[idx]
    mrbid = str(basin)

    selection = None
    if season is not None:
        # Check whether we are in NH (True) or SH (False)
        hemisphere = bool(basins.iloc[i]['lat'] > 0)
        # Select the season.
        selection = selection_dict[season][hemisphere]

    for j, ssp in enumerate(ssps):
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        if selection is not None:
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == selection))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == selection))

        mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

        # (data, x position, style) per layer: filled mean, dotted 25th,
        # dashed 75th. The mean sits 0.01 to the right of the outlines.
        layers = [
            (mean_d, 0.01+j*dist_shift,
             dict(facecolor=colors[j], edgecolor='k', alpha=0.5, linewidth=0.5,
                  zorder=4-(j*0.1))),
            (q25_d, j*dist_shift,
             dict(facecolor='none', edgecolor='k', linewidth=0.5, linestyle=':',
                  alpha=0.5)),
            (q75_d, j*dist_shift,
             dict(facecolor='none', edgecolor='k', linewidth=0.5, linestyle='--',
                  alpha=0.5)),
        ]
        for data, position, style in layers:
            violin = panel.violinplot(data, showextrema=False, positions=[position])
            for pc in violin['bodies']:
                verts = pc.get_paths()[0].vertices
                # get the center
                m = np.mean(verts[:, 0])
                # keep only the half to the right of the center
                verts[:, 0] = np.clip(verts[:, 0], m, np.inf)
                pc.set(**style)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name}', fontsize=font_size, pad=2)
    # yticks
    panel.set_ylabel('')
    panel.ticklabel_format(axis='y', scilimits=(-4, -2), useMathText=True)
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # xticks
    panel.set_xlim(0, 0.9)
    panel.set_xticklabels([])
    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    panel.grid(which='both', linewidth=0.5)
# Legend stuff: Patches
patches = [mpatches.Patch(facecolor=color, edgecolor='k', alpha=alpha, lw=0.6) for color in colors]
labels = [f'GCM mean {ssp}' for ssp in ssps]
# Lines
lines = [Line2D([0], [0], color='k', ls=':', lw=0.6, alpha=0.5),
         Line2D([0], [0], color='k', ls='--', lw=0.6, alpha=0.5)]
line_labels = ['25th percentile', '75th percentile']

#  Add the legend.
fig.legend(patches + lines, labels + line_labels, bbox_to_anchor=(0.30, 0.895),
           frameon=True, framealpha=1,  labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Make the first plots invisible.
for i in range(shift):
    ax.flat[i].set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.3, hspace=0.2)
# Labels (raw strings: '\D' is not a valid escape sequence)
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.92, r'$\Delta$-SPEI distributions in Europe', ha='center', va='center', fontsize=font_size)


fig_name = f'delta_spei_clim_europe_{season}_paper.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120)

### Selected basins

In [None]:
# Delta-SPEI (adjusted SPEI minus SPEI) distributions for the European
# basins where the standard deviation of the ensemble-mean delta reaches
# 0.1 for at least one scenario. Every selected basin gets one panel with,
# per scenario, a filled half violin of the GCM ensemble mean and outlined
# half violins for the difference of the 25th (dotted) and 75th (dashed)
# ensemble percentiles.
#
# season: None for all months, or 'winter' / 'summer' (resolved per
#         hemisphere through selection_dict).
# lat_sort: order the basins by centroid latitude, north to south.
selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                  'winter': {True: 'DJF', False: 'JJA'}}
lat_sort = False
season = None

basins = basins_df.loc[basins_df.CONTINENT == 'Europe']
# Select the basins and keep the ensembles of the selected ones, so they do
# not have to be built a second time for plotting.
basins_sel = []
ensembles = {}
for basin in basins.MRBID:
    basin_ensembles = {}
    selected = False
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        basin_ensembles[ssp] = (spei_ens, spei_adj_ens)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= 0.1:
            selected = True
    if selected:
        basins_sel.append(basin)
        ensembles[str(basin)] = basin_ensembles
basins_sel = list(set(basins_sel))
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]
if lat_sort or season is not None:
    # Centroid latitude, needed for sorting and for the hemisphere check.
    # assign() works on a copy, so basins_df itself is not modified.
    basins = basins.assign(lat=basins['geometry'].centroid.y)
if lat_sort:
    basins = basins.sort_values('lat', ascending=False)

# Number of leading subplots that are left empty (room for the legend).
shift = 1
nbasins = len(basins)
ncols = 3
nrows = (nbasins // ncols) + shift
nplots = ncols*nrows
height_factor = 24/6
height = height_factor * nrows

# Create the figure
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
# One colour per scenario
colors = ['C0', 'C1', 'C2', 'C3']
# font size
font_size = 7
alpha = 0.5
lw = 0.5
# Horizontal distance between the violins of two scenarios.
dist_shift = 0.2

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    panel = ax.flat[idx]
    mrbid = str(basin)

    selection = None
    if season is not None:
        # Check whether we are in NH (True) or SH (False)
        hemisphere = bool(basins.iloc[i]['lat'] > 0)
        # Select the season.
        selection = selection_dict[season][hemisphere]

    for j, ssp in enumerate(ssps):
        # Get the ensembles
        spei_ens, spei_adj_ens = ensembles[mrbid][ssp]
        if selection is not None:
            # .sel() returns new objects; the cached ensembles stay complete.
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == selection))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == selection))

        mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

        # (data, x position, style) per layer: filled mean, dotted 25th,
        # dashed 75th. The mean sits 0.01 to the right of the outlines.
        layers = [
            (mean_d, 0.01+j*dist_shift,
             dict(facecolor=colors[j], edgecolor='k', alpha=0.5, linewidth=0.5,
                  zorder=4-(j*0.1))),
            (q25_d, j*dist_shift,
             dict(facecolor='none', edgecolor='k', linewidth=0.5, linestyle=':',
                  alpha=0.5)),
            (q75_d, j*dist_shift,
             dict(facecolor='none', edgecolor='k', linewidth=0.5, linestyle='--',
                  alpha=0.5)),
        ]
        for data, position, style in layers:
            violin = panel.violinplot(data, showextrema=False, positions=[position])
            for pc in violin['bodies']:
                verts = pc.get_paths()[0].vertices
                # get the center
                m = np.mean(verts[:, 0])
                # keep only the half to the right of the center
                verts[:, 0] = np.clip(verts[:, 0], m, np.inf)
                pc.set(**style)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)
    # yticks
    panel.set_ylabel('')
    panel.ticklabel_format(axis='y', scilimits=(-4, -2), useMathText=True)
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # xticks
    panel.set_xlim(0, 0.9)
    panel.set_xticklabels([])
    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    panel.grid(which='both', linewidth=0.5)

    # Add hlines at 0.
    panel.axhline(0, alpha=0.7, lw=0.5)

# Legend stuff: Patches
patches = [mpatches.Patch(facecolor=color, edgecolor='k', alpha=alpha, lw=0.6) for color in colors]
labels = [f'GCM ens. mean {ssp}' for ssp in ssps]
# Lines
lines = [Line2D([0], [0], color='k', ls=':', lw=0.6, alpha=0.5),
         Line2D([0], [0], color='k', ls='--', lw=0.6, alpha=0.5)]
line_labels = ['Ens. 25th percentile', 'Ens. 75th percentile']

#  Add the legend.
fig.legend(patches + lines, labels + line_labels, bbox_to_anchor=(0.35, 0.895),
           frameon=True, framealpha=1,  labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Make the first plots invisible.
for i in range(shift):
    ax.flat[i].set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings: '\D', '\s' and '\g' are not valid escape sequences)
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.93, r'$\Delta$-SPEI distributions in Europe where $\sigma \geq 0.1$', ha='center', va='center', fontsize=font_size)


fig_name = f'delta_spei_clim_europe_selection_{season}_paper.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120)

### Selected basins temporal

In [None]:
'''Plot the rolling-mean delta-SPEI time series of the selected European basins.

One panel per basin. For every scenario in ``ssps`` the panel shows the
centred rolling mean of the ensemble-mean difference (``spei_adj_ens`` minus
``spei_ens``) as a line, and the band between the rolling means of the 25th
and 75th ensemble-percentile differences as a shaded area. Paper version.

The rolling window is 30*12 time steps (30 years if the series is monthly,
as the figure title states).
'''
basins = basins_df.loc[basins_df.CONTINENT == 'Europe']
# Keep the basins where, for at least one scenario, the standard deviation of
# the ensemble-mean delta-SPEI is >= 0.1. A bit inefficient to load the
# ensembles twice (here and when plotting) but...
basins_sel = []
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= 0.1:
            basins_sel.append(basin)
            # One qualifying scenario is enough; skip the remaining ones.
            break
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# Layout: the first `shift` axes are left free for the legend.
shift = 1
nbasins = len(basins)
ncols = 3
# Ceiling division so that every basin gets a panel, also when shift is 0.
nrows = -(-(nbasins + shift) // ncols)
height_factor = 24/6
height = height_factor * nrows

# Create the figure (sizes are given in cm and converted to inches).
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
# font size
font_size = 7
# Rolling window length in time steps.
rolling_window = 30*12

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    for ssp in ssps:
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        # Label
        label = f'{ssp}\n N={len(spei_adj_ens)}'
        # Add the mean
        rolling_mean = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        rolling_mean = rolling_mean.rolling(time=rolling_window, center=True, min_periods=1).mean()
        panel.plot(rolling_mean.time, rolling_mean, lw=0.8, label=label)

        # Difference of the 25th percentiles (not the percentile of the difference).
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q25_d = q25_d.rolling(time=rolling_window, center=True, min_periods=1).mean()
        # Difference of the 75th percentiles.
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')
        q75_d = q75_d.rolling(time=rolling_window, center=True, min_periods=1).mean()
        # Fill between the quantiles
        panel.fill_between(q25_d.time, q25_d, q75_d, zorder=1, alpha=0.3)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    # Two basins get a shortened title: keep a single word of the name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # yticks
    panel.set_ylabel('')
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # xticks
    panel.tick_params(axis='x', labelsize=font_size, pad=-2)
    panel.xaxis.set_major_locator(mdates.AutoDateLocator(minticks=2, maxticks=7))

    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    # Grid
    panel.grid(which='both', linewidth=0.5)

# Legend stuff: take the entries from one populated panel. Axes 5 is used when
# it holds a basin; otherwise fall back to the last populated panel.
legend_ax = ax.flat[min(5, shift + nbasins - 1)]
handles, labels = legend_ax.get_legend_handles_labels()
#  Add the legend.
fig.legend(handles, labels, bbox_to_anchor=(0.3, 0.895),
           frameon=True, framealpha=1,  labelcolor='k',
           fontsize=font_size, title_fontsize=font_size,
           )

# Make the first plots invisible.
for i in range(shift):
    ax.flat[i].set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings: the backslashes belong to the mathtext, not to Python).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.93, r'30-year rolling mean $\Delta$-SPEI in Europe (Selected basins)', ha='center', va='center', fontsize=font_size)
fig.text(0.5, 0.04, 'Year', ha='center', fontsize=font_size)


fig_name = 'delta_spei_rolling_europe_selection_paper.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120)

### Selected basins temporal presentation

In [None]:
'''Plot the rolling-mean delta-SPEI time series of the selected European basins.

One panel per basin. For every scenario in ``ssps`` the panel shows the
centred rolling mean of the ensemble-mean difference (``spei_adj_ens`` minus
``spei_ens``) as a line, and the band between the rolling means of the 25th
and 75th ensemble-percentile differences as a shaded area. Presentation
version: smaller figure and font, transparent background.

The rolling window is 30*12 time steps (30 years if the series is monthly,
as the figure title states).
'''
basins = basins_df.loc[basins_df.CONTINENT == 'Europe']
# Keep the basins where, for at least one scenario, the standard deviation of
# the ensemble-mean delta-SPEI is >= 0.1. A bit inefficient to load the
# ensembles twice (here and when plotting) but...
basins_sel = []
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= 0.1:
            basins_sel.append(basin)
            # One qualifying scenario is enough; skip the remaining ones.
            break
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# Layout: the first `shift` axes are left free for the legend.
shift = 1
nbasins = len(basins)
ncols = 3
# Ceiling division so that every basin gets a panel, also when shift is 0.
nrows = -(-(nbasins + shift) // ncols)
height_factor = 17.5/6
height = height_factor * nrows

# Create the figure (sizes are given in cm and converted to inches).
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(11.5/2.54, height/2.54),
                       sharex=True, sharey=False)
# font size
font_size = 5
# Rolling window length in time steps.
rolling_window = 30*12

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    for ssp in ssps:
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        # Label
        label = f'{ssp}\n N={len(spei_adj_ens)}'
        # Add the mean
        rolling_mean = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        rolling_mean = rolling_mean.rolling(time=rolling_window, center=True, min_periods=1).mean()
        panel.plot(rolling_mean.time, rolling_mean, lw=0.8, label=label)

        # Difference of the 25th percentiles (not the percentile of the difference).
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q25_d = q25_d.rolling(time=rolling_window, center=True, min_periods=1).mean()
        # Difference of the 75th percentiles.
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')
        q75_d = q75_d.rolling(time=rolling_window, center=True, min_periods=1).mean()
        # Fill between the quantiles
        panel.fill_between(q25_d.time, q25_d, q75_d, zorder=1, alpha=0.3)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    # Two basins get a shortened title: keep a single word of the name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # yticks
    panel.set_ylabel('')
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # xticks
    panel.tick_params(axis='x', labelsize=font_size, pad=-2)
    panel.xaxis.set_major_locator(mdates.AutoDateLocator(minticks=4, maxticks=6))

    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    # Grid
    panel.grid(which='both', linewidth=0.5)

# Legend stuff: take the entries from one populated panel. Axes 5 is used when
# it holds a basin; otherwise fall back to the last populated panel.
legend_ax = ax.flat[min(5, shift + nbasins - 1)]
handles, labels = legend_ax.get_legend_handles_labels()
#  Add the legend.
fig.legend(handles, labels, bbox_to_anchor=(0.3, 0.895),
           frameon=True, framealpha=1,  labelcolor='k',
           fontsize=font_size, title_fontsize=font_size,
           )

# Make the first plots invisible.
for i in range(shift):
    ax.flat[i].set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings: the backslashes belong to the mathtext, not to Python).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.93, r'30-year rolling mean $\Delta$-SPEI in Europe (Selected basins)', ha='center', va='center', fontsize=font_size)
fig.text(0.5, 0.04, 'Year', ha='center', fontsize=font_size)


fig_name = 'delta_spei_rolling_europe_selection_pres.pdf'
path = os.path.join(fig_path, fig_name)
# Transparent figure background for use on slides.
fig.savefig(path, bbox_inches='tight', dpi=120, facecolor="none")

### Selected basins presentation

In [None]:
'''Plot the delta-SPEI distributions of the selected European basins as violins.

One panel per basin. For every scenario in ``ssps`` a filled half violin shows
the distribution of the ensemble-mean difference (``spei_adj_ens`` minus
``spei_ens``); two unfilled outlines show the differences of the ensemble 25th
(dotted) and 75th (dashed) percentiles. Presentation version (transparent
background).

Settings:
---------
season: str or None
    'winter' or 'summer' restricts the ensembles to that season (DJF/JJA,
    swapped for basins whose centroid latitude is not positive). None uses
    all time steps.
lat_sort: bool
    Sort the panels by basin centroid latitude, north to south.
'''
selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                  'winter': {True: 'DJF', False: 'JJA'}}
lat_sort = False
season = None

basins = basins_df.loc[basins_df.CONTINENT == 'Europe']
# Keep the basins where, for at least one scenario, the standard deviation of
# the ensemble-mean delta-SPEI is >= 0.1. A bit inefficient to load the
# ensembles twice (here and when plotting) but...
basins_sel = []
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= 0.1:
            basins_sel.append(basin)
            # One qualifying scenario is enough; skip the remaining ones.
            break
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# The centroid latitude is needed both for sorting and for the hemisphere
# check of the season selection. `assign` avoids writing into a slice.
if lat_sort or season is not None:
    basins = basins.assign(lat=basins['geometry'].centroid.y)
if lat_sort:
    basins = basins.sort_values('lat', ascending=False)

# Layout: the first `shift` axes are left free for the legend.
shift = 1
nbasins = len(basins)
ncols = 3
# Ceiling division so that every basin gets a panel, also when shift is 0.
nrows = -(-(nbasins + shift) // ncols)
height_factor = 24/6
height = height_factor * nrows

# Create the figure (sizes are given in cm and converted to inches).
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
colors = ['C0', 'C1', 'C2', 'C3']
# font size
font_size = 7
alpha = 0.5
# Horizontal distance between the violins of consecutive scenarios.
dist_shift = 0.2

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    season_code = None
    if season is not None:
        # Check whether we are in NH (True) or SH (False)
        hemisphere = bool(basins.iloc[i]['lat'] > 0)
        # Select the season.
        season_code = selection_dict[season][hemisphere]

    for j, ssp in enumerate(ssps):
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        if season_code is not None:
            # Apply the season filter to the data that is actually plotted.
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == season_code))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == season_code))

        mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        # Differences of the percentiles (not percentiles of the difference).
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

        # (data, x position, face colour, line style, zorder) per violin.
        # The filled mean violin is nudged 0.01 to the right of the outlines.
        violins = (
            (mean_d, 0.01 + j*dist_shift, colors[j], None, 4 - (j*0.1)),
            (q25_d, j*dist_shift, 'none', ':', None),
            (q75_d, j*dist_shift, 'none', '--', None),
        )
        for data, position, facecolor, linestyle, zorder in violins:
            bodies = panel.violinplot(data, showextrema=False, positions=[position])['bodies']
            for pc in bodies:
                # Collapse everything left of the centre onto the centre line,
                # so only the right half of the violin remains.
                xs = pc.get_paths()[0].vertices[:, 0]
                pc.get_paths()[0].vertices[:, 0] = np.clip(xs, np.mean(xs), np.inf)
                pc.set_facecolor(facecolor)
                pc.set_edgecolor('k')
                pc.set_linewidth(0.5)
                if linestyle is not None:
                    pc.set_linestyle(linestyle)
                pc.set_alpha(alpha)
                if zorder is not None:
                    pc.set_zorder(zorder)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    # Two basins get a shortened title: keep a single word of the name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name}', fontsize=font_size, pad=2)

    # yticks
    panel.set_ylabel('')
    panel.ticklabel_format(axis='y', scilimits=(-4, -2), useMathText=True)
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # xticks: the x axis only separates the scenarios, so hide its labels.
    panel.set_xlim(0, 0.9)
    panel.set_xticklabels([])
    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    panel.grid(which='both', linewidth=0.5)

    # Add hlines at 0.
    panel.axhline(0, alpha=0.7, lw=0.5)

# Legend stuff: Patches (one per scenario, so handles and labels stay aligned)
patches = [mpatches.Patch(facecolor=color, edgecolor='k', alpha=alpha, lw=0.6)
           for color in colors[:len(ssps)]]
labels = [f'GCM ens. mean {ssp}'for ssp in ssps]
# Lines
lines = [Line2D([0], [0], color='k', ls=':', lw=0.6, alpha=0.5), Line2D([0], [0], color='k', ls='--', lw=0.6, alpha=0.5)]
line_labels = ['Ens. 25th percentile', 'Ens. 75th percentile']

#  Add the legend.
fig.legend(patches + lines, labels + line_labels, bbox_to_anchor=(0.35, 0.895),
           frameon=True, framealpha=1,  labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Make the first plots invisible.
for i in range(shift):
    ax.flat[i].set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings: the backslashes belong to the mathtext, not to Python).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.93, r'$\Delta$-SPEI distributions in Europe where $\sigma \geq 0.1$', ha='center', va='center', fontsize=font_size)


fig_name = f'delta_spei_clim_europe_selection_{season}_pres.pdf'
path = os.path.join(fig_path, fig_name)
# Transparent figure background for use on slides.
fig.savefig(path, bbox_inches='tight', facecolor='none')

## Delta SPEI South America
### All basins

In [None]:
'''Plot the delta-SPEI distributions of all South American basins as violins.

One panel per basin. For every scenario in ``ssps`` a filled half violin shows
the distribution of the ensemble-mean difference (``spei_adj_ens`` minus
``spei_ens``); two unfilled outlines show the differences of the ensemble 25th
(dotted) and 75th (dashed) percentiles. Paper version.

Settings:
---------
season: str or None
    'winter' or 'summer' restricts the ensembles to that season (DJF/JJA,
    swapped for basins whose centroid latitude is not positive). None uses
    all time steps.
lat_sort: bool
    Sort the panels by basin centroid latitude, north to south.
'''
selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                  'winter': {True: 'DJF', False: 'JJA'}}
lat_sort = False
season = None

basins = basins_df.loc[basins_df.CONTINENT == 'South America']

# The centroid latitude is needed both for sorting and for the hemisphere
# check of the season selection. `assign` avoids writing into a slice.
if lat_sort or season is not None:
    basins = basins.assign(lat=basins['geometry'].centroid.y)
if lat_sort:
    basins = basins.sort_values('lat', ascending=False)

# Layout: no axes are reserved for the legend here.
shift = 0
nbasins = len(basins)
ncols = 4
# Ceiling division: floor division would leave the last basins without a
# panel whenever the basin count is not a multiple of ncols.
nrows = -(-(nbasins + shift) // ncols)
height_factor = 24/6
height = height_factor * nrows

# Create the figure (sizes are given in cm and converted to inches).
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
colors = ['C0', 'C1', 'C2', 'C3']
# font size
font_size = 7
alpha = 0.5
# Horizontal distance between the violins of consecutive scenarios.
dist_shift = 0.2

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    season_code = None
    if season is not None:
        # Check whether we are in NH (True) or SH (False)
        hemisphere = bool(basins.iloc[i]['lat'] > 0)
        # Select the season.
        season_code = selection_dict[season][hemisphere]

    for j, ssp in enumerate(ssps):
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        if season_code is not None:
            # Apply the season filter to the data that is actually plotted.
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == season_code))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == season_code))

        mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        # Differences of the percentiles (not percentiles of the difference).
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

        # (data, x position, face colour, line style, zorder) per violin.
        # The filled mean violin is nudged 0.01 to the right of the outlines.
        violins = (
            (mean_d, 0.01 + j*dist_shift, colors[j], None, 4 - (j*0.1)),
            (q25_d, j*dist_shift, 'none', ':', None),
            (q75_d, j*dist_shift, 'none', '--', None),
        )
        for data, position, facecolor, linestyle, zorder in violins:
            bodies = panel.violinplot(data, showextrema=False, positions=[position])['bodies']
            for pc in bodies:
                # Collapse everything left of the centre onto the centre line,
                # so only the right half of the violin remains.
                xs = pc.get_paths()[0].vertices[:, 0]
                pc.get_paths()[0].vertices[:, 0] = np.clip(xs, np.mean(xs), np.inf)
                pc.set_facecolor(facecolor)
                pc.set_edgecolor('k')
                pc.set_linewidth(0.5)
                if linestyle is not None:
                    pc.set_linestyle(linestyle)
                pc.set_alpha(alpha)
                if zorder is not None:
                    pc.set_zorder(zorder)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    # Two basins get a shortened title: keep a single word of the name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # yticks
    panel.set_ylabel('')
    panel.ticklabel_format(axis='y', scilimits=(-4, -2), useMathText=True)
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # xticks: the x axis only separates the scenarios, so hide its labels.
    panel.set_xlim(0, 0.9)
    panel.set_xticklabels([])
    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    panel.grid(which='both', linewidth=0.5)

# Legend stuff: Patches (one per scenario, so handles and labels stay aligned)
patches = [mpatches.Patch(facecolor=color, edgecolor='k', alpha=alpha, lw=0.6)
           for color in colors[:len(ssps)]]
labels = [f'GCM mean {ssp}'for ssp in ssps]
# Lines
lines = [Line2D([0], [0], color='k', ls=':', lw=0.6, alpha=0.5), Line2D([0], [0], color='k', ls='--', lw=0.6, alpha=0.5)]
line_labels = ['25th percentile', '75th percentile']

#  Add the legend.
fig.legend(patches + lines, labels + line_labels, ncol=3, bbox_to_anchor=(0.50, 0.92), loc='center',
           frameon=True, framealpha=1,  labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Make the first plots invisible.
for i in range(shift):
    ax.flat[i].set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.3, hspace=0.2)
# Labels (raw strings: the backslashes belong to the mathtext, not to Python).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.95, r'$\Delta$-SPEI distributions in South America', ha='center', va='center', fontsize=font_size)


fig_name = f'delta_spei_clim_s_america_{season}_paper.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120)

### Selected basins

Also includes Asia.

In [None]:
'''Plot the delta-SPEI distributions of the selected Asian and South American
basins as violins.

One panel per basin. For every scenario in ``ssps`` a filled half violin shows
the distribution of the ensemble-mean difference (``spei_adj_ens`` minus
``spei_ens``); two unfilled outlines show the differences of the ensemble 25th
(dotted) and 75th (dashed) percentiles. Paper version.

Settings:
---------
season: str or None
    'winter' or 'summer' restricts the ensembles to that season (DJF/JJA,
    swapped for basins whose centroid latitude is not positive). None uses
    all time steps.
lat_sort: bool
    Sort the panels by basin centroid latitude, north to south.
'''
selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                  'winter': {True: 'DJF', False: 'JJA'}}
lat_sort = False
season = None

basins = basins_df.loc[(basins_df.CONTINENT == 'South America') | (basins_df.CONTINENT == 'Asia')]
# Keep the basins where, for at least one scenario, the standard deviation of
# the ensemble-mean delta-SPEI is >= 0.1. A bit inefficient to load the
# ensembles twice (here and when plotting) but...
basins_sel = []
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= 0.1:
            basins_sel.append(basin)
            # One qualifying scenario is enough; skip the remaining ones.
            break
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# The centroid latitude is needed both for sorting and for the hemisphere
# check of the season selection. `assign` avoids writing into a slice.
if lat_sort or season is not None:
    basins = basins.assign(lat=basins['geometry'].centroid.y)
if lat_sort:
    basins = basins.sort_values('lat', ascending=False)

# Layout: no axes are reserved for the legend here.
shift = 0
nbasins = len(basins)
ncols = 3
# Ceiling division: floor division would leave the last basins without a
# panel whenever the basin count is not a multiple of ncols.
nrows = -(-(nbasins + shift) // ncols)
height_factor = 24/6
height = height_factor * nrows

# Create the figure (sizes are given in cm and converted to inches).
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
colors = ['C0', 'C1', 'C2', 'C3']
# font size
font_size = 7
alpha = 0.5
# Horizontal distance between the violins of consecutive scenarios.
dist_shift = 0.2

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    season_code = None
    if season is not None:
        # Check whether we are in NH (True) or SH (False)
        hemisphere = bool(basins.iloc[i]['lat'] > 0)
        # Select the season.
        season_code = selection_dict[season][hemisphere]

    for j, ssp in enumerate(ssps):
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        if season_code is not None:
            # Apply the season filter to the data that is actually plotted.
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == season_code))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == season_code))

        mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        # Differences of the percentiles (not percentiles of the difference).
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

        # (data, x position, face colour, line style, zorder) per violin.
        # The filled mean violin is nudged 0.01 to the right of the outlines.
        violins = (
            (mean_d, 0.01 + j*dist_shift, colors[j], None, 4 - (j*0.1)),
            (q25_d, j*dist_shift, 'none', ':', None),
            (q75_d, j*dist_shift, 'none', '--', None),
        )
        for data, position, facecolor, linestyle, zorder in violins:
            bodies = panel.violinplot(data, showextrema=False, positions=[position])['bodies']
            for pc in bodies:
                # Collapse everything left of the centre onto the centre line,
                # so only the right half of the violin remains.
                xs = pc.get_paths()[0].vertices[:, 0]
                pc.get_paths()[0].vertices[:, 0] = np.clip(xs, np.mean(xs), np.inf)
                pc.set_facecolor(facecolor)
                pc.set_edgecolor('k')
                pc.set_linewidth(0.5)
                if linestyle is not None:
                    pc.set_linestyle(linestyle)
                pc.set_alpha(alpha)
                if zorder is not None:
                    pc.set_zorder(zorder)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    # Two basins get a shortened title: keep a single word of the name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # yticks
    panel.set_ylabel('')
    panel.ticklabel_format(axis='y', scilimits=(-4, -2), useMathText=True)
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # xticks: the x axis only separates the scenarios, so hide its labels.
    panel.set_xlim(0, 0.9)
    panel.set_xticklabels([])
    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    panel.grid(which='both', linewidth=0.5)

    # Add hlines at 0.
    panel.axhline(0, alpha=0.7, lw=0.5)

# Legend stuff: Patches (one per scenario, so handles and labels stay aligned)
patches = [mpatches.Patch(facecolor=color, edgecolor='k', alpha=alpha, lw=0.6)
           for color in colors[:len(ssps)]]
labels = [f'GCM ens. mean {ssp}'for ssp in ssps]
# Lines
lines = [Line2D([0], [0], color='k', ls=':', lw=0.6, alpha=0.5), Line2D([0], [0], color='k', ls='--', lw=0.6, alpha=0.5)]
line_labels = ['Ens. 25th percentile', 'Ens. 75th percentile']

#  Add the legend.
fig.legend(patches + lines, labels + line_labels, ncol=3, bbox_to_anchor=(0.50, 0.98), loc='center',
           frameon=True, framealpha=1,  labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Make the first plots invisible.
for i in range(shift):
    ax.flat[i].set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings: the backslashes belong to the mathtext, not to Python).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 1.07, r'$\Delta$-SPEI distributions in Asia and South America where $\sigma \geq 0.1$', ha='center', va='center', fontsize=font_size)


fig_name = f'delta_spei_clim_asia_s_america_selection_{season}_paper.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120)

### Presentation selected basins

In [None]:
'''Plot the delta-SPEI distributions of the selected Asian and South American
basins as violins.

One panel per basin. For every scenario in ``ssps`` a filled half violin shows
the distribution of the ensemble-mean difference (``spei_adj_ens`` minus
``spei_ens``); two unfilled outlines show the differences of the ensemble 25th
(dotted) and 75th (dashed) percentiles. Presentation version (transparent
background).

Settings:
---------
season: str or None
    'winter' or 'summer' restricts the ensembles to that season (DJF/JJA,
    swapped for basins whose centroid latitude is not positive). None uses
    all time steps.
lat_sort: bool
    Sort the panels by basin centroid latitude, north to south.
'''
selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                  'winter': {True: 'DJF', False: 'JJA'}}
lat_sort = False
season = None

basins = basins_df.loc[(basins_df.CONTINENT == 'South America') | (basins_df.CONTINENT == 'Asia')]
# Keep the basins where, for at least one scenario, the standard deviation of
# the ensemble-mean delta-SPEI is >= 0.1. A bit inefficient to load the
# ensembles twice (here and when plotting) but...
basins_sel = []
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= 0.1:
            basins_sel.append(basin)
            # One qualifying scenario is enough; skip the remaining ones.
            break
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# The centroid latitude is needed both for sorting and for the hemisphere
# check of the season selection. `assign` avoids writing into a slice.
if lat_sort or season is not None:
    basins = basins.assign(lat=basins['geometry'].centroid.y)
if lat_sort:
    basins = basins.sort_values('lat', ascending=False)

# Layout: no axes are reserved for the legend here.
shift = 0
nbasins = len(basins)
ncols = 3
# Ceiling division: floor division would leave the last basins without a
# panel whenever the basin count is not a multiple of ncols.
nrows = -(-(nbasins + shift) // ncols)
height_factor = 24/6
height = height_factor * nrows

# Create the figure (sizes are given in cm and converted to inches).
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
colors = ['C0', 'C1', 'C2', 'C3']
# font size
font_size = 7
alpha = 0.5
# Horizontal distance between the violins of consecutive scenarios.
dist_shift = 0.2

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    season_code = None
    if season is not None:
        # Check whether we are in NH (True) or SH (False)
        hemisphere = bool(basins.iloc[i]['lat'] > 0)
        # Select the season.
        season_code = selection_dict[season][hemisphere]

    for j, ssp in enumerate(ssps):
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        if season_code is not None:
            # Apply the season filter to the data that is actually plotted.
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == season_code))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == season_code))

        mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        # Differences of the percentiles (not percentiles of the difference).
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

        # (data, x position, face colour, line style, zorder) per violin.
        # The filled mean violin is nudged 0.01 to the right of the outlines.
        violins = (
            (mean_d, 0.01 + j*dist_shift, colors[j], None, 4 - (j*0.1)),
            (q25_d, j*dist_shift, 'none', ':', None),
            (q75_d, j*dist_shift, 'none', '--', None),
        )
        for data, position, facecolor, linestyle, zorder in violins:
            bodies = panel.violinplot(data, showextrema=False, positions=[position])['bodies']
            for pc in bodies:
                # Collapse everything left of the centre onto the centre line,
                # so only the right half of the violin remains.
                xs = pc.get_paths()[0].vertices[:, 0]
                pc.get_paths()[0].vertices[:, 0] = np.clip(xs, np.mean(xs), np.inf)
                pc.set_facecolor(facecolor)
                pc.set_edgecolor('k')
                pc.set_linewidth(0.5)
                if linestyle is not None:
                    pc.set_linestyle(linestyle)
                pc.set_alpha(alpha)
                if zorder is not None:
                    pc.set_zorder(zorder)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    # Two basins get a shortened title: keep a single word of the name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # yticks
    panel.set_ylabel('')
    panel.ticklabel_format(axis='y', scilimits=(-4, -2), useMathText=True)
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # xticks: the x axis only separates the scenarios, so hide its labels.
    panel.set_xlim(0, 0.9)
    panel.set_xticklabels([])
    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    panel.grid(which='both', linewidth=0.5)

    # Add hlines at 0.
    panel.axhline(0, alpha=0.7, lw=0.5)

# Legend stuff: Patches (one per scenario, so handles and labels stay aligned)
patches = [mpatches.Patch(facecolor=color, edgecolor='k', alpha=alpha, lw=0.6)
           for color in colors[:len(ssps)]]
labels = [f'GCM ens. mean {ssp}'for ssp in ssps]
# Lines
lines = [Line2D([0], [0], color='k', ls=':', lw=0.6, alpha=0.5), Line2D([0], [0], color='k', ls='--', lw=0.6, alpha=0.5)]
line_labels = ['Ens. 25th percentile', 'Ens. 75th percentile']

#  Add the legend.
fig.legend(patches + lines, labels + line_labels, ncol=3, bbox_to_anchor=(0.50, 0.98), loc='center',
           frameon=True, framealpha=1,  labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Make the first plots invisible.
for i in range(shift):
    ax.flat[i].set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings: the backslashes belong to the mathtext, not to Python).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 1.07, r'$\Delta$-SPEI distributions in Asia and South America where $\sigma \geq 0.1$', ha='center', va='center', fontsize=font_size)


fig_name = f'delta_spei_clim_asia_s_america_selection_{season}_pres.pdf'
path = os.path.join(fig_path, fig_name)
# Transparent figure background for use on slides.
fig.savefig(path, bbox_inches='tight', facecolor='none')

### Selected basins temporal

In [None]:
'''Plot the 30-year rolling mean of Delta-SPEI for selected basins (paper version).

Delta-SPEI is the difference between the two ensembles returned by
``get_spei_ensemble`` (``spei_adj_ens - spei_ens``). Basins in Asia and
South America are kept when the standard deviation of the ensemble-mean
Delta-SPEI reaches ``sigma_min`` for at least one scenario in ``ssps``.

Every panel shows, per scenario, the rolling mean of the ensemble-mean
difference as a line and the band between the rolling 25th and 75th ensemble
quantile differences.

The figure is saved in ``fig_path`` as
``delta_spei_rolling_asia_s_america_selection_paper.pdf``.
'''
basins = basins_df.loc[basins_df.CONTINENT.isin(['South America', 'Asia'])]

# Keep the basins whose ensemble-mean Delta-SPEI varies enough.
sigma_min = 0.1
basins_sel = set()
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= sigma_min:
            basins_sel.add(basin)
            # One qualifying scenario is enough, no need to load the others.
            break
basins_sel = list(basins_sel)
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# Layout. `shift` is the number of leading panels that are skipped and hidden.
shift = 0
nbasins = len(basins)
ncols = 3
# Round up so that every basin gets a panel, also when the number of basins
# is not a multiple of ncols (floor division raised an IndexError then).
nrows = (nbasins + shift + ncols - 1) // ncols
# Sizes are given in cm and converted to inches (2.54 cm per inch).
height_factor = 24/6
height = height_factor * nrows
font_size = 7
# Rolling window: 30 years, assuming monthly time steps.
window = 30*12

# Create the figure
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    for ssp in ssps:
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        label = f'{ssp}\n N={len(spei_adj_ens)}'

        # Rolling mean of the ensemble-mean difference.
        rolling_mean = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        rolling_mean = rolling_mean.rolling(time=window, center=True, min_periods=1).mean()
        line, = panel.plot(rolling_mean.time, rolling_mean, lw=0.8, label=label)

        # Rolling means of the 25th and 75th quantile differences.
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q25_d = q25_d.rolling(time=window, center=True, min_periods=1).mean()
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')
        q75_d = q75_d.rolling(time=window, center=True, min_periods=1).mean()
        # Fill between the quantiles. The band colour is tied to the line
        # explicitly instead of relying on two colour cycles staying in step.
        panel.fill_between(q25_d.time, q25_d, q75_d, zorder=1, alpha=0.3,
                           facecolor=line.get_color())

    # Get the basin name and drop any alternative name given in parentheses.
    name = basins.iloc[i].RIVER_BASI
    name = re.split(r'[\(\)]', name)[0]
    # Two basins are shortened to a single word of their name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # y axis
    panel.set_ylabel('')
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # x axis
    panel.tick_params(axis='x', labelsize=font_size, pad=-2)
    panel.xaxis.set_major_locator(mdates.AutoDateLocator(minticks=2, maxticks=7))

    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    # Grid
    panel.grid(which='both', linewidth=0.5)

# Legend entries are taken from one panel: the sixth basin as before, or the
# last basin when fewer than six are plotted.
legend_ax = ax.flat[shift + min(5, nbasins - 1)]
handles, labels = legend_ax.get_legend_handles_labels()
fig.legend(handles, labels, ncol=4, bbox_to_anchor=(0.50, 0.98), loc='center',
           frameon=True, framealpha=1, labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Hide the skipped leading panels and any unused trailing panels.
for unused in list(ax.flat[:shift]) + list(ax.flat[shift + nbasins:]):
    unused.set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings keep the LaTeX backslashes without invalid escapes).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 1.05, r'30-year rolling mean $\Delta$-SPEI in Asia and South America (Selected basins)', ha='center', fontsize=font_size)
fig.text(0.5, 0.04, 'Year', ha='center', fontsize=font_size)


fig_name = 'delta_spei_rolling_asia_s_america_selection_paper.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120)

### Selected basins temporal presentation

In [None]:
'''Plot the 30-year rolling mean of Delta-SPEI for selected basins (presentation version).

Delta-SPEI is the difference between the two ensembles returned by
``get_spei_ensemble`` (``spei_adj_ens - spei_ens``). Basins in Asia and
South America are kept when the standard deviation of the ensemble-mean
Delta-SPEI reaches ``sigma_min`` for at least one scenario in ``ssps``.

Every panel shows, per scenario, the rolling mean of the ensemble-mean
difference as a line and the band between the rolling 25th and 75th ensemble
quantile differences.

The figure is saved with a transparent background in ``fig_path`` as
``delta_spei_rolling_asia_s_america_selection_pres.pdf``.
'''
basins = basins_df.loc[basins_df.CONTINENT.isin(['South America', 'Asia'])]

# Keep the basins whose ensemble-mean Delta-SPEI varies enough.
sigma_min = 0.1
basins_sel = set()
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= sigma_min:
            basins_sel.add(basin)
            # One qualifying scenario is enough, no need to load the others.
            break
basins_sel = list(basins_sel)
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# Layout. `shift` is the number of leading panels that are skipped and hidden.
shift = 0
nbasins = len(basins)
ncols = 3
# Round up so that every basin gets a panel, also when the number of basins
# is not a multiple of ncols (floor division raised an IndexError then).
nrows = (nbasins + shift + ncols - 1) // ncols
# Sizes are given in cm and converted to inches (2.54 cm per inch).
height_factor = 17.5/6
height = height_factor * nrows
font_size = 5
# Rolling window: 30 years, assuming monthly time steps.
window = 30*12

# Create the figure
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(11.5/2.54, height/2.54),
                       sharex=True, sharey=False)

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    for ssp in ssps:
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        label = f'{ssp}\n N={len(spei_adj_ens)}'

        # Rolling mean of the ensemble-mean difference.
        rolling_mean = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        rolling_mean = rolling_mean.rolling(time=window, center=True, min_periods=1).mean()
        line, = panel.plot(rolling_mean.time, rolling_mean, lw=0.8, label=label)

        # Rolling means of the 25th and 75th quantile differences.
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q25_d = q25_d.rolling(time=window, center=True, min_periods=1).mean()
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')
        q75_d = q75_d.rolling(time=window, center=True, min_periods=1).mean()
        # Fill between the quantiles. The band colour is tied to the line
        # explicitly instead of relying on two colour cycles staying in step.
        panel.fill_between(q25_d.time, q25_d, q75_d, zorder=1, alpha=0.3,
                           facecolor=line.get_color())

    # Get the basin name and drop any alternative name given in parentheses.
    name = basins.iloc[i].RIVER_BASI
    name = re.split(r'[\(\)]', name)[0]
    # Two basins are shortened to a single word of their name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # y axis
    panel.set_ylabel('')
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # x axis
    panel.tick_params(axis='x', labelsize=font_size, pad=-2)
    panel.xaxis.set_major_locator(mdates.AutoDateLocator(minticks=3, maxticks=6))

    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    # Grid
    panel.grid(which='both', linewidth=0.5)

# Legend entries are taken from one panel: the sixth basin as before, or the
# last basin when fewer than six are plotted.
legend_ax = ax.flat[shift + min(5, nbasins - 1)]
handles, labels = legend_ax.get_legend_handles_labels()
fig.legend(handles, labels, ncol=1, bbox_to_anchor=(0.96, 0.72), loc='center',
           frameon=True, framealpha=1, labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Hide the skipped leading panels and any unused trailing panels.
for unused in list(ax.flat[:shift]) + list(ax.flat[shift + nbasins:]):
    unused.set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings keep the LaTeX backslashes without invalid escapes).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.95, r'30-year rolling mean $\Delta$-SPEI in Asia and South America (Selected basins)', ha='center', fontsize=font_size)
fig.text(0.5, 0.04, 'Year', ha='center', fontsize=font_size)


fig_name = 'delta_spei_rolling_asia_s_america_selection_pres.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120, facecolor="none")

## Delta SPEI Asia

In [None]:
'''Plot Delta-SPEI distributions (half violins) for all basins in Asia and
the South-West Pacific (paper version).

Delta-SPEI is the difference between the two ensembles returned by
``get_spei_ensemble`` (``spei_adj_ens - spei_ens``). For every basin and
every scenario in ``ssps`` three half violins are drawn: the ensemble-mean
difference (filled) and the 25th / 75th ensemble quantile differences
(dotted / dashed outlines).

Settings
--------
season : str or None
    ``None`` uses all time steps. ``'summer'`` or ``'winter'`` restricts the
    data to JJA or DJF, picked per basin from the sign of its centroid
    latitude via ``selection_dict``.
lat_sort : bool
    Order the panels by basin centroid latitude, north to south.
'''
selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                  'winter': {True: 'DJF', False: 'JJA'}}
lat_sort = False
season = None

basins = basins_df.loc[basins_df.CONTINENT.isin(['Asia', 'South-West Pacific'])]

# The centroid latitude is needed for sorting and for the hemisphere lookup.
# `assign` returns a new frame, so the slice of basins_df is not written to.
if lat_sort or season is not None:
    basins = basins.assign(lat=basins['geometry'].centroid.y)
if lat_sort:
    basins = basins.sort_values('lat', ascending=False)

# Layout. `shift` is the number of leading panels that are skipped and hidden.
shift = 0
nbasins = len(basins)
ncols = 4
# Round up so that every basin gets a panel, also when the number of basins
# is not a multiple of ncols (floor division raised an IndexError then).
nrows = (nbasins + shift + ncols - 1) // ncols
# Sizes are given in cm and converted to inches (2.54 cm per inch).
height_factor = 24/6
height = height_factor * nrows

# Create the figure
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
colors = ['C0', 'C1', 'C2', 'C3']
# font size
font_size = 7
alpha = 0.5
# Horizontal distance between the violins of consecutive scenarios.
dist_shift = 0.2

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    selection = None
    if season is not None:
        # Check whether we are in the NH (True) or the SH (False).
        hemisphere = bool(basins.iloc[i]['lat'] > 0)
        selection = selection_dict[season][hemisphere]

    for j, ssp in enumerate(ssps):
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        if selection is not None:
            # Apply the season to the ensembles that are actually plotted
            # (the old code filtered `spei`/`spei_adj`, which this cell never
            # defines). Assumes a datetime `time` coordinate.
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == selection))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == selection))

        mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

        # (data, x position, style) of the three half violins of this scenario.
        layers = [
            (mean_d, 0.01 + j*dist_shift,
             dict(facecolor=colors[j], alpha=alpha, zorder=4 - (j*0.1))),
            (q25_d, j*dist_shift, dict(facecolor='none', linestyle=':', alpha=0.5)),
            (q75_d, j*dist_shift, dict(facecolor='none', linestyle='--', alpha=0.5)),
        ]
        for data, position, style in layers:
            violin = panel.violinplot(data, showextrema=False, positions=[position])
            for pc in violin['bodies']:
                vertices = pc.get_paths()[0].vertices
                # get the center
                m = np.mean(vertices[:, 0])
                # Clip in place so the body never extends left of its center:
                # only the right half of the violin remains.
                vertices[:, 0] = np.clip(vertices[:, 0], m, np.inf)
                pc.set(edgecolor='k', linewidth=0.5, **style)

    # Get the basin name and drop any alternative name given in parentheses.
    name = basins.iloc[i].RIVER_BASI
    name = re.split(r'[\(\)]', name)[0]
    # Two basins are shortened to a single word of their name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # y axis
    panel.set_ylabel('')
    panel.ticklabel_format(axis='y', scilimits=(-4, -2), useMathText=True)
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # x axis: positions carry no meaning, so the tick labels are blanked.
    panel.set_xlim(0, 0.9)
    panel.set_xticklabels([])
    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    panel.grid(which='both', linewidth=0.5)

# Legend stuff: one patch per scenario, so handles and labels stay aligned.
patches = [mpatches.Patch(facecolor=color, edgecolor='k', alpha=alpha, lw=0.6)
           for color in colors[:len(ssps)]]
labels = [f'GCM mean {ssp}' for ssp in ssps]
# Lines
lines = [Line2D([0], [0], color='k', ls=':', lw=0.6, alpha=0.5), Line2D([0], [0], color='k', ls='--', lw=0.6, alpha=0.5)]
line_labels = ['25th percentile', '75th percentile']

#  Add the legend.
fig.legend(patches + lines, labels + line_labels, ncol=3, bbox_to_anchor=(0.50, 0.92), loc='center',
           frameon=True, framealpha=1, labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Hide the skipped leading panels and any unused trailing panels.
for unused in list(ax.flat[:shift]) + list(ax.flat[shift + nbasins:]):
    unused.set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.3, hspace=0.2)
# Labels (raw strings keep the LaTeX backslashes without invalid escapes).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.95, r'$\Delta$-SPEI distributions in Asia and New Zealand', ha='center', va='center', fontsize=font_size)


fig_name = f'delta_spei_clim_asia_{season}_paper.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120)

## Delta SPEI North America

In [None]:
'''Plot Delta-SPEI distributions (half violins) for all basins in North
America, Central America and the Caribbean (paper version).

Delta-SPEI is the difference between the two ensembles returned by
``get_spei_ensemble`` (``spei_adj_ens - spei_ens``). For every basin and
every scenario in ``ssps`` three half violins are drawn: the ensemble-mean
difference (filled) and the 25th / 75th ensemble quantile differences
(dotted / dashed outlines).

Settings
--------
season : str or None
    ``None`` uses all time steps. ``'summer'`` or ``'winter'`` restricts the
    data to JJA or DJF, picked per basin from the sign of its centroid
    latitude via ``selection_dict``.
lat_sort : bool
    Order the panels by basin centroid latitude, north to south.
'''
selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                  'winter': {True: 'DJF', False: 'JJA'}}
lat_sort = False
season = None

basins = basins_df.loc[(basins_df.CONTINENT == 'North America, Central America and the Caribbean')]

# The centroid latitude is needed for sorting and for the hemisphere lookup.
# `assign` returns a new frame, so the slice of basins_df is not written to.
if lat_sort or season is not None:
    basins = basins.assign(lat=basins['geometry'].centroid.y)
if lat_sort:
    basins = basins.sort_values('lat', ascending=False)

# Layout. `shift` is the number of leading panels that are skipped and hidden.
shift = 0
nbasins = len(basins)
ncols = 4
# Round up so that every basin gets a panel, also when the number of basins
# is not a multiple of ncols (floor division raised an IndexError then).
nrows = (nbasins + shift + ncols - 1) // ncols
# Sizes are given in cm and converted to inches (2.54 cm per inch).
height_factor = 24/6
height = height_factor * nrows

# Create the figure
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
colors = ['C0', 'C1', 'C2', 'C3']
# font size
font_size = 7
alpha = 0.5
# Horizontal distance between the violins of consecutive scenarios.
dist_shift = 0.2

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    selection = None
    if season is not None:
        # Check whether we are in the NH (True) or the SH (False).
        hemisphere = bool(basins.iloc[i]['lat'] > 0)
        selection = selection_dict[season][hemisphere]

    for j, ssp in enumerate(ssps):
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        if selection is not None:
            # Apply the season to the ensembles that are actually plotted
            # (the old code filtered `spei`/`spei_adj`, which this cell never
            # defines). Assumes a datetime `time` coordinate.
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == selection))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == selection))

        mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

        # (data, x position, style) of the three half violins of this scenario.
        layers = [
            (mean_d, 0.01 + j*dist_shift,
             dict(facecolor=colors[j], alpha=alpha, zorder=4 - (j*0.1))),
            (q25_d, j*dist_shift, dict(facecolor='none', linestyle=':', alpha=0.5)),
            (q75_d, j*dist_shift, dict(facecolor='none', linestyle='--', alpha=0.5)),
        ]
        for data, position, style in layers:
            violin = panel.violinplot(data, showextrema=False, positions=[position])
            for pc in violin['bodies']:
                vertices = pc.get_paths()[0].vertices
                # get the center
                m = np.mean(vertices[:, 0])
                # Clip in place so the body never extends left of its center:
                # only the right half of the violin remains.
                vertices[:, 0] = np.clip(vertices[:, 0], m, np.inf)
                pc.set(edgecolor='k', linewidth=0.5, **style)

    # Get the basin name and drop any alternative name given in parentheses.
    name = basins.iloc[i].RIVER_BASI
    name = re.split(r'[\(\)]', name)[0]
    # Two basins are shortened to a single word of their name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # y axis
    panel.set_ylabel('')
    panel.ticklabel_format(axis='y', scilimits=(-4, -2), useMathText=True)
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # x axis: positions carry no meaning, so the tick labels are blanked.
    panel.set_xlim(0, 0.9)
    panel.set_xticklabels([])
    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    panel.grid(which='both', linewidth=0.5)

# Legend stuff: one patch per scenario, so handles and labels stay aligned.
patches = [mpatches.Patch(facecolor=color, edgecolor='k', alpha=alpha, lw=0.6)
           for color in colors[:len(ssps)]]
labels = [f'GCM mean {ssp}' for ssp in ssps]
# Lines
lines = [Line2D([0], [0], color='k', ls=':', lw=0.6, alpha=0.5), Line2D([0], [0], color='k', ls='--', lw=0.6, alpha=0.5)]
line_labels = ['25th percentile', '75th percentile']

#  Add the legend.
fig.legend(patches + lines, labels + line_labels, ncol=3, bbox_to_anchor=(0.50, 0.93), loc='center',
           frameon=True, framealpha=1, labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Hide the skipped leading panels and any unused trailing panels.
for unused in list(ax.flat[:shift]) + list(ax.flat[shift + nbasins:]):
    unused.set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.3, hspace=0.2)
# Labels (raw strings keep the LaTeX backslashes without invalid escapes).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.5, 0.97, r'$\Delta$-SPEI distributions in North America', ha='center', va='center', fontsize=font_size)


fig_name = f'delta_spei_clim_n_america_{season}_paper.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120)

### Selected basins

In [None]:
'''Plot Delta-SPEI distributions (half violins) for selected basins in North
America, Central America and the Caribbean (paper version).

Delta-SPEI is the difference between the two ensembles returned by
``get_spei_ensemble`` (``spei_adj_ens - spei_ens``). Basins are kept when
the standard deviation of the ensemble-mean Delta-SPEI reaches ``sigma_min``
for at least one scenario in ``ssps``. For every kept basin and scenario
three half violins are drawn: the ensemble-mean difference (filled) and the
25th / 75th ensemble quantile differences (dotted / dashed outlines).

Settings
--------
season : str or None
    ``None`` uses all time steps. ``'summer'`` or ``'winter'`` restricts the
    plotted data to JJA or DJF, picked per basin from the sign of its
    centroid latitude via ``selection_dict``. The basin selection itself
    always uses all time steps.
lat_sort : bool
    Order the panels by basin centroid latitude, north to south.
'''
selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                  'winter': {True: 'DJF', False: 'JJA'}}
lat_sort = False
season = None

basins = basins_df.loc[(basins_df.CONTINENT == 'North America, Central America and the Caribbean')]

# Keep the basins whose ensemble-mean Delta-SPEI varies enough.
sigma_min = 0.1
basins_sel = set()
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= sigma_min:
            basins_sel.add(basin)
            # One qualifying scenario is enough, no need to load the others.
            break
basins_sel = list(basins_sel)
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# The centroid latitude is needed for sorting and for the hemisphere lookup.
# `assign` returns a new frame, so the slice of basins_df is not written to.
if lat_sort or season is not None:
    basins = basins.assign(lat=basins['geometry'].centroid.y)
if lat_sort:
    basins = basins.sort_values('lat', ascending=False)

# Layout. `shift` is the number of leading panels that are skipped and
# hidden; the legend and title are anchored in the upper left of the figure.
shift = 2
nbasins = len(basins)
ncols = 3
# Round up so that the skipped panels plus every basin fit in the grid
# (the old floor-division formula overflowed when nbasins % ncols == 2).
nrows = (nbasins + shift + ncols - 1) // ncols
# Sizes are given in cm and converted to inches (2.54 cm per inch).
height_factor = 24/6
height = height_factor * nrows

# Create the figure
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
colors = ['C0', 'C1', 'C2', 'C3']
# font size
font_size = 7
alpha = 0.5
# Horizontal distance between the violins of consecutive scenarios.
dist_shift = 0.2

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    selection = None
    if season is not None:
        # Check whether we are in the NH (True) or the SH (False).
        hemisphere = bool(basins.iloc[i]['lat'] > 0)
        selection = selection_dict[season][hemisphere]

    for j, ssp in enumerate(ssps):
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        if selection is not None:
            # Apply the season to the ensembles that are actually plotted
            # (the old code filtered `spei`/`spei_adj`, which this cell never
            # defines). Assumes a datetime `time` coordinate.
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == selection))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == selection))

        mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

        # (data, x position, style) of the three half violins of this scenario.
        layers = [
            (mean_d, 0.01 + j*dist_shift,
             dict(facecolor=colors[j], alpha=alpha, zorder=4 - (j*0.1))),
            (q25_d, j*dist_shift, dict(facecolor='none', linestyle=':', alpha=0.5)),
            (q75_d, j*dist_shift, dict(facecolor='none', linestyle='--', alpha=0.5)),
        ]
        for data, position, style in layers:
            violin = panel.violinplot(data, showextrema=False, positions=[position])
            for pc in violin['bodies']:
                vertices = pc.get_paths()[0].vertices
                # get the center
                m = np.mean(vertices[:, 0])
                # Clip in place so the body never extends left of its center:
                # only the right half of the violin remains.
                vertices[:, 0] = np.clip(vertices[:, 0], m, np.inf)
                pc.set(edgecolor='k', linewidth=0.5, **style)

    # Get the basin name and drop any alternative name given in parentheses.
    name = basins.iloc[i].RIVER_BASI
    name = re.split(r'[\(\)]', name)[0]
    # Two basins are shortened to a single word of their name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # y axis (on the plotted panel, i.e. index idx and not i)
    panel.set_ylabel('')
    panel.ticklabel_format(axis='y', scilimits=(-4, -2), useMathText=True)
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # x axis: positions carry no meaning, so the tick labels are blanked.
    panel.set_xlim(0, 0.9)
    panel.set_xticklabels([])
    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    panel.grid(which='both', linewidth=0.5)

    # Add hlines at 0.
    panel.axhline(0, alpha=0.7, lw=0.5)

# Legend stuff: one patch per scenario, so handles and labels stay aligned.
patches = [mpatches.Patch(facecolor=color, edgecolor='k', alpha=alpha, lw=0.6)
           for color in colors[:len(ssps)]]
labels = [f'GCM ens. mean {ssp}' for ssp in ssps]
# Lines
lines = [Line2D([0], [0], color='k', ls=':', lw=0.6, alpha=0.5), Line2D([0], [0], color='k', ls='--', lw=0.6, alpha=0.5)]
line_labels = ['Ens. 25th percentile', 'Ens. 75th percentile']

#  Add the legend.
fig.legend(patches + lines, labels + line_labels, ncol=2, bbox_to_anchor=(0.35, 0.78), loc='center',
           frameon=True, framealpha=1, labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Hide the skipped leading panels and any unused trailing panels.
for unused in list(ax.flat[:shift]) + list(ax.flat[shift + nbasins:]):
    unused.set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings keep the LaTeX backslashes without invalid escapes).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.35, 0.86, r'$\Delta$-SPEI distributions in North America where $\sigma \geq 0.1$', ha='center', va='center', fontsize=font_size)


fig_name = f'delta_spei_clim_n_america_selection_{season}_paper.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120)

### Presentation: Selected basins

In [None]:
'''Plot Delta-SPEI distributions (half violins) for selected basins in North
America, Central America and the Caribbean (presentation version).

Delta-SPEI is the difference between the two ensembles returned by
``get_spei_ensemble`` (``spei_adj_ens - spei_ens``). Basins are kept when
the standard deviation of the ensemble-mean Delta-SPEI reaches ``sigma_min``
for at least one scenario in ``ssps``. For every kept basin and scenario
three half violins are drawn: the ensemble-mean difference (filled) and the
25th / 75th ensemble quantile differences (dotted / dashed outlines).

The figure is saved with a transparent background and the basin names are
shown as stored (no title-casing).

Settings
--------
season : str or None
    ``None`` uses all time steps. ``'summer'`` or ``'winter'`` restricts the
    plotted data to JJA or DJF, picked per basin from the sign of its
    centroid latitude via ``selection_dict``. The basin selection itself
    always uses all time steps.
lat_sort : bool
    Order the panels by basin centroid latitude, north to south.
'''
selection_dict = {'summer': {True: 'JJA', False: 'DJF'},
                  'winter': {True: 'DJF', False: 'JJA'}}
lat_sort = False
season = None

basins = basins_df.loc[(basins_df.CONTINENT == 'North America, Central America and the Caribbean')]

# Keep the basins whose ensemble-mean Delta-SPEI varies enough.
sigma_min = 0.1
basins_sel = set()
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= sigma_min:
            basins_sel.add(basin)
            # One qualifying scenario is enough, no need to load the others.
            break
basins_sel = list(basins_sel)
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# The centroid latitude is needed for sorting and for the hemisphere lookup.
# `assign` returns a new frame, so the slice of basins_df is not written to.
if lat_sort or season is not None:
    basins = basins.assign(lat=basins['geometry'].centroid.y)
if lat_sort:
    basins = basins.sort_values('lat', ascending=False)

# Layout. `shift` is the number of leading panels that are skipped and
# hidden; the legend and title are anchored in the upper left of the figure.
shift = 2
nbasins = len(basins)
ncols = 3
# Round up so that the skipped panels plus every basin fit in the grid
# (the old floor-division formula overflowed when nbasins % ncols == 2).
nrows = (nbasins + shift + ncols - 1) // ncols
# Sizes are given in cm and converted to inches (2.54 cm per inch).
height_factor = 24/6
height = height_factor * nrows

# Create the figure
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
colors = ['C0', 'C1', 'C2', 'C3']
# font size
font_size = 7
alpha = 0.5
# Horizontal distance between the violins of consecutive scenarios.
dist_shift = 0.2

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    selection = None
    if season is not None:
        # Check whether we are in the NH (True) or the SH (False).
        hemisphere = bool(basins.iloc[i]['lat'] > 0)
        selection = selection_dict[season][hemisphere]

    for j, ssp in enumerate(ssps):
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        if selection is not None:
            # Apply the season to the ensembles that are actually plotted
            # (the old code filtered `spei`/`spei_adj`, which this cell never
            # defines). Assumes a datetime `time` coordinate.
            spei_ens = spei_ens.sel(time=(spei_ens['time.season'] == selection))
            spei_adj_ens = spei_adj_ens.sel(time=(spei_adj_ens['time.season'] == selection))

        mean_d = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')

        # (data, x position, style) of the three half violins of this scenario.
        layers = [
            (mean_d, 0.01 + j*dist_shift,
             dict(facecolor=colors[j], alpha=alpha, zorder=4 - (j*0.1))),
            (q25_d, j*dist_shift, dict(facecolor='none', linestyle=':', alpha=0.5)),
            (q75_d, j*dist_shift, dict(facecolor='none', linestyle='--', alpha=0.5)),
        ]
        for data, position, style in layers:
            violin = panel.violinplot(data, showextrema=False, positions=[position])
            for pc in violin['bodies']:
                vertices = pc.get_paths()[0].vertices
                # get the center
                m = np.mean(vertices[:, 0])
                # Clip in place so the body never extends left of its center:
                # only the right half of the violin remains.
                vertices[:, 0] = np.clip(vertices[:, 0], m, np.inf)
                pc.set(edgecolor='k', linewidth=0.5, **style)

    # Get the basin name and drop any alternative name given in parentheses.
    name = basins.iloc[i].RIVER_BASI
    name = re.split(r'[\(\)]', name)[0]
    # Two basins are shortened to a single word of their name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(name, fontsize=font_size, pad=2)

    # y axis (on the plotted panel, i.e. index idx and not i)
    panel.set_ylabel('')
    panel.ticklabel_format(axis='y', scilimits=(-4, -2), useMathText=True)
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # x axis: positions carry no meaning, so the tick labels are blanked.
    panel.set_xlim(0, 0.9)
    panel.set_xticklabels([])
    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    panel.grid(which='both', linewidth=0.5)

    # Add hlines at 0.
    panel.axhline(0, alpha=0.7, lw=0.5)

# Legend stuff: one patch per scenario, so handles and labels stay aligned.
patches = [mpatches.Patch(facecolor=color, edgecolor='k', alpha=alpha, lw=0.6)
           for color in colors[:len(ssps)]]
labels = [f'GCM ens. mean {ssp}' for ssp in ssps]
# Lines
lines = [Line2D([0], [0], color='k', ls=':', lw=0.6, alpha=0.5), Line2D([0], [0], color='k', ls='--', lw=0.6, alpha=0.5)]
line_labels = ['Ens. 25th percentile', 'Ens. 75th percentile']

#  Add the legend.
fig.legend(patches + lines, labels + line_labels, ncol=2, bbox_to_anchor=(0.35, 0.78), loc='center',
           frameon=True, framealpha=1, labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Hide the skipped leading panels and any unused trailing panels.
for unused in list(ax.flat[:shift]) + list(ax.flat[shift + nbasins:]):
    unused.set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings keep the LaTeX backslashes without invalid escapes).
fig.text(0.07, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.35, 0.86, r'$\Delta$-SPEI distributions in North America where $\sigma \geq 0.1$', ha='center', va='center', fontsize=font_size)


fig_name = f'delta_spei_clim_n_america_selection_{season}_pres.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', facecolor='none')

### Selected basins temporal

In [None]:
'''Plot the 30-year rolling mean of the delta-SPEI for the selected North
American basins. Paper version.

Delta-SPEI is ``spei_adj_ens`` minus ``spei_ens`` as returned by
``get_spei_ensemble``. Every basin gets one panel; every scenario in ``ssps``
contributes a line (difference of the ensemble means) and a shaded band
(between the differences of the 25th and 75th ensemble quantiles).
'''
basins = basins_df.loc[(basins_df.CONTINENT == 'North America, Central America and the Caribbean')]
# Keep only the basins where the standard deviation of the ensemble-mean
# delta-SPEI reaches 0.1 for at least one scenario. The same selection is
# repeated in other cells, which is a bit inefficient.
basins_sel = []
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= 0.1:
            basins_sel.append(basin)
            # One qualifying scenario is enough, skip the remaining loads.
            break
basins_sel = list(set(basins_sel))
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# Layout. The first `shift` panels carry no data and are hidden further down.
shift = 2
nbasins = len(basins)
ncols = 3
# Ceil division: enough rows for the hidden leading panels plus one panel per
# basin. (The previous `nbasins // ncols + shift - 1` came up one row short
# whenever nbasins % ncols == 2.)
nrows = (nbasins + shift + ncols - 1) // ncols
nplots = ncols*nrows
height_factor = 24/6
height = height_factor * nrows

# Create the figure. Sizes are given in cm and converted to inches.
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15/2.54, height/2.54),
                       sharex=True, sharey=False)
# Style settings. Only font_size is used in this cell; the other names are
# kept because other cells may read these notebook-level variables.
colors = ['C0', 'C1', 'C2', 'C3']
font_size = 7
c0 = 'C1'
c1 = 'C0'
c2 = 'C2'
alpha = 0.5
lw = 0.5
dist_shift = 0.2
# Rolling window length in time steps (30 years, assuming monthly data).
window = 30*12

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    for ssp in ssps:
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        # Legend label. N is the length of the first dimension of the
        # ensemble (presumably the number of GCMs - TODO confirm).
        label = f'{ssp} N={len(spei_adj_ens)}'
        # Rolling mean of the difference between the ensemble means.
        rolling_mean = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        rolling_mean = rolling_mean.rolling(time=window, center=True, min_periods=1).mean()
        panel.plot(rolling_mean.time, rolling_mean, lw=0.8, label=label)

        # Difference of the 25th quantiles.
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q25_d = q25_d.rolling(time=window, center=True, min_periods=1).mean()
        # Difference of the 75th quantiles.
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')
        q75_d = q75_d.rolling(time=window, center=True, min_periods=1).mean()
        # Fill between the quantiles
        panel.fill_between(q25_d.time, q25_d, q75_d, zorder=1, alpha=0.3)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    # Two basins are shortened to a single word of their name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # y axis
    panel.set_ylabel('')
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # x axis
    panel.tick_params(axis='x', labelsize=font_size, pad=-2)
    panel.xaxis.set_major_locator(mdates.AutoDateLocator(minticks=2, maxticks=7))

    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    # Grid
    panel.grid(which='both', linewidth=0.5)

# Legend stuff
# NOTE(review): the handles are taken from the hard-coded panel 5; this
# assumes at least four selected basins and that its labels represent all panels.
handles, labels = ax.flat[5].get_legend_handles_labels()
#  Add the legend.
fig.legend(handles, labels , ncol=2, bbox_to_anchor=(0.35, 0.78), loc='center',
           frameon=True, framealpha=1,  labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Make the first plots invisible.
for i in range(shift):
    ax.flat[i].set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.2, hspace=0.2)
# Labels (raw strings so the LaTeX backslashes are not read as escapes).
fig.text(0.05, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.35, 0.86, r'30-year rolling mean $\Delta$-SPEI in North America (Selected basins)', ha='center', fontsize=font_size)
fig.text(0.5, 0.06, 'Year', ha='center', fontsize=font_size)


fig_name = 'delta_spei_rolling_n_america_selection_paper.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120)

### Selected basins temporal presentation

In [None]:
'''Plot the 30-year rolling mean of the delta-SPEI for the selected North
American basins. Presentation version (smaller figure and fonts, transparent
background).

Delta-SPEI is ``spei_adj_ens`` minus ``spei_ens`` as returned by
``get_spei_ensemble``. Every basin gets one panel; every scenario in ``ssps``
contributes a line (difference of the ensemble means) and a shaded band
(between the differences of the 25th and 75th ensemble quantiles).
'''
basins = basins_df.loc[(basins_df.CONTINENT == 'North America, Central America and the Caribbean')]
# Keep only the basins where the standard deviation of the ensemble-mean
# delta-SPEI reaches 0.1 for at least one scenario. The same selection is
# repeated in other cells, which is a bit inefficient.
basins_sel = []
for basin in basins.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= 0.1:
            basins_sel.append(basin)
            # One qualifying scenario is enough, skip the remaining loads.
            break
basins_sel = list(set(basins_sel))
# Select the basins that fit the criteria
basins = basins[basins['MRBID'].isin(basins_sel)]

# Layout. The first `shift` panels carry no data and are hidden further down.
shift = 2
nbasins = len(basins)
ncols = 3
# Ceil division: enough rows for the hidden leading panels plus one panel per
# basin. (The previous `nbasins // ncols + shift - 1` came up one row short
# whenever nbasins % ncols == 2.)
nrows = (nbasins + shift + ncols - 1) // ncols
nplots = ncols*nrows
height_factor = 17.5/6
height = height_factor * nrows

# Create the figure. Sizes are given in cm and converted to inches.
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(11.5/2.54, height/2.54),
                       sharex=True, sharey=False)
# Style settings. Only font_size is used in this cell; the other names are
# kept because other cells may read these notebook-level variables.
colors = ['C0', 'C1', 'C2', 'C3']
font_size = 5
c0 = 'C1'
c1 = 'C0'
c2 = 'C2'
alpha = 0.5
lw = 0.5
dist_shift = 0.2
# Rolling window length in time steps (30 years, assuming monthly data).
window = 30*12

for i, basin in enumerate(basins.MRBID):
    idx = i + shift
    mrbid = str(basin)
    panel = ax.flat[idx]

    for ssp in ssps:
        # Get the ensembles
        spei_ens, spei_adj_ens = get_spei_ensemble(mrbid, ssp)
        # Legend label. N is the length of the first dimension of the
        # ensemble (presumably the number of GCMs - TODO confirm).
        label = f'{ssp} N={len(spei_adj_ens)}'
        # Rolling mean of the difference between the ensemble means.
        rolling_mean = spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')
        rolling_mean = rolling_mean.rolling(time=window, center=True, min_periods=1).mean()
        panel.plot(rolling_mean.time, rolling_mean, lw=0.8, label=label)

        # Difference of the 25th quantiles.
        q25_d = spei_adj_ens.quantile(q=0.25, dim='gcm') - spei_ens.quantile(q=0.25, dim='gcm')
        q25_d = q25_d.rolling(time=window, center=True, min_periods=1).mean()
        # Difference of the 75th quantiles.
        q75_d = spei_adj_ens.quantile(q=0.75, dim='gcm') - spei_ens.quantile(q=0.75, dim='gcm')
        q75_d = q75_d.rolling(time=window, center=True, min_periods=1).mean()
        # Fill between the quantiles
        panel.fill_between(q25_d.time, q25_d, q75_d, zorder=1, alpha=0.3)

    # Get the basin name.
    name = basins.iloc[i].RIVER_BASI
    # We don't need to plot the alternative names for now, so split it.
    name = re.split(r'[\(\)]', name)[0]
    # Two basins are shortened to a single word of their name.
    if mrbid == '2910':
        name = name.split()[1]
    if mrbid == '6101':
        name = name.split()[0]
    panel.set_title(f'{name.title()}', fontsize=font_size, pad=2)

    # y axis
    panel.set_ylabel('')
    panel.yaxis.offsetText.set_fontsize(5)
    panel.get_yaxis().get_offset_text().set_x(-0.15)
    # x axis
    panel.tick_params(axis='x', labelsize=font_size, pad=-2)
    panel.xaxis.set_major_locator(mdates.AutoDateLocator(minticks=3, maxticks=6))

    panel.tick_params(axis='both', labelsize=font_size)
    panel.tick_params('y', pad=-4)

    # Grid
    panel.grid(which='both', linewidth=0.5)

# Legend stuff
# NOTE(review): the handles are taken from the hard-coded panel 5; this
# assumes at least four selected basins and that its labels represent all panels.
handles, labels = ax.flat[5].get_legend_handles_labels()
#  Add the legend.
fig.legend(handles, labels , ncol=2, bbox_to_anchor=(0.35, 0.78), loc='center',
           frameon=True, framealpha=1,  labelcolor='k',
           fontsize=font_size, title_fontsize=font_size)

# Make the first plots invisible.
for i in range(shift):
    ax.flat[i].set_visible(False)

# Margins
plt.subplots_adjust(wspace=0.25, hspace=0.2)
# Labels (raw strings so the LaTeX backslashes are not read as escapes).
fig.text(0.05, 0.5, r'$\Delta$-SPEI', va='center', rotation='vertical', fontsize=font_size)
fig.text(0.35, 0.86, r'30-year rolling mean $\Delta$-SPEI in North America (Selected basins)', ha='center', fontsize=font_size)
fig.text(0.5, 0.06, 'Year', ha='center', fontsize=font_size)


fig_name = 'delta_spei_rolling_n_america_selection_pres.pdf'
path = os.path.join(fig_path, fig_name)
fig.savefig(path, bbox_inches='tight', dpi=120, facecolor="none")

# Counting droughts

In [None]:
# Select, across all continents, the basins where the standard deviation of
# the ensemble-mean delta-SPEI (spei_adj_ens minus spei_ens) reaches 0.1 for
# at least one scenario. A bit inefficient to repeat this selection, but...
basins_sel = []
for basin in basins_df.MRBID:
    for ssp in ssps:
        spei_ens, spei_adj_ens = get_spei_ensemble(str(basin), ssp)
        delta_value = (spei_adj_ens.mean(dim='gcm') - spei_ens.mean(dim='gcm')).std()
        if delta_value >= 0.1:
            basins_sel.append(basin)
            # One qualifying scenario is enough, skip the remaining loads.
            break
# Drop duplicates in case an MRBID occurs more than once in basins_df.
basins_sel = list(set(basins_sel))
# Select the basins that fit the criteria
basins = basins_df[basins_df['MRBID'].isin(basins_sel)]

In [None]:
# Display the selected basins for inspection.
basins

In [None]:
def count_droughts(basins_df, threshold=-1.0):
    '''Count the drought months per basin and scenario for both SPEI cases.

    For every basin in ``basins_df`` and every scenario in the notebook-level
    ``ssps``, the GCM-ensemble mean of each SPEI series returned by
    ``get_spei_ensemble`` is compared with ``threshold`` times its own
    standard deviation. The number of time steps below that level is stored
    as the count. The result has one row per basin and two columns per
    scenario: ``months_<ssp>`` (from ``spei_ens``) and ``months_adj_<ssp>``
    (from ``spei_adj_ens``).

    Args:
    -----
    basins_df: geopandas dataframe
        Contains the basin data. Only the ``MRBID`` column is read.
    threshold: float
        Drought level as a multiple of the standard deviation of the
        ensemble-mean SPEI. The default of -1.0 counts the values lying more
        than one standard deviation below zero.

    Returns:
    --------
    Pandas dataframe (float) with the counts, indexed by MRBID.
    '''
    cases = ['months_', 'months_adj_']

    cols = [case + ssp for ssp in ssps for case in cases]
    # Float frame so downstream arithmetic does not run on object columns.
    df = pd.DataFrame(index=basins_df.MRBID, columns=cols, dtype=float)
    for mrbid in basins_df.MRBID:
        for ssp in ssps:
            # Select data and reduce over the ensemble once per case.
            spei_ens, spei_adj_ens = get_spei_ensemble(str(mrbid), ssp)
            spei_mean = spei_ens.mean(dim='gcm')
            spei_adj_mean = spei_adj_ens.mean(dim='gcm')
            # Count the time steps below the drought level.
            no_gl_count = (spei_mean < threshold * spei_mean.std()).sum()
            gl_count = (spei_adj_mean < threshold * spei_adj_mean.std()).sum()

            # Add the data to df. A single .loc[row, col] call is used since
            # chained indexing may write to a temporary copy.
            df.loc[mrbid, f'months_{ssp}'] = float(no_gl_count.values)
            df.loc[mrbid, f'months_adj_{ssp}'] = float(gl_count.values)

    return df

In [None]:
# Scratch check: standard deviation of a single ensemble member (position 6
# along `gcm`) for the selected basin at position 11 under ssp245.
mrbid = basins.MRBID.iloc[11]
ssp = 'ssp245'
spei_ens, spei_adj_ens = get_spei_ensemble(str(mrbid), ssp)
spei_adj_ens.isel(gcm=6).std()

In [None]:
# Display the drought-month counts for the selected basins.
count_droughts(basins)

In [None]:
def drought_diff(basins, threshold=-1.0):
    '''Absolute difference in drought-month counts between the two SPEI cases.

    Returns a dataframe with one row per basin and one column per scenario
    in ``ssps``.
    '''
    counts = count_droughts(basins, threshold)
    # Differencing along the columns leaves, in every second column, the
    # adjusted count minus the unadjusted count of the same scenario.
    paired = counts.diff(axis=1).iloc[:, 1::2].abs()
    # Label the remaining columns with the scenario names.
    paired.columns = ssps
    return paired

In [None]:
# Absolute drought-month differences for the selected basins.
diff_df = drought_diff(basins)

**Minimum and maximum drought numbers**

In [None]:
# Largest difference per scenario.
diff_df.max()

In [None]:
# Smallest difference per scenario.
diff_df.min()

### Drought count thesis

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_drought_diff(basins_df, threshold=-1.0, savefig=False):
    '''Plot the differences in drought months for all basins in basins_df.

    Draws a dot matrix with one row per basin and one column per scenario.
    Both the radius (0.2 to 0.4) and the viridis colour of a dot scale with
    the value from ``drought_diff`` between the smallest and largest value
    in the table.

    Args:
    -----
    basins_df: geopandas dataframe
        Basins to plot. The ``RIVER_BASI`` column provides the row labels.
    threshold: float
        Passed on to ``drought_diff``.
    savefig: bool
        If True, save the figure as 'delta_dots_selection.pdf' in ``fig_path``.

    Note: text sizes come from the notebook-level ``font_size``, so the
    result depends on which cell last set it.
    '''
    # Get the counted droughts df
    drought_count_diff = drought_diff(basins_df, threshold)
    rows = basins_df.shape[0]
    # Scenario labels, one per plotted column.
    fancy_ssps = ['1-2.6', '2-4.5', '3-7.0', '5-8.5']

    # Value range shared by the size and colour scales.
    vmin = drought_count_diff.min().min()
    vmax = drought_count_diff.max().max()
    cmap = cm.viridis
    norm = mcolors.Normalize(vmin, vmax)

    # Set up the plot
    fig, ax = plt.subplots(figsize=(6/2.54, 24/2.54))
    for i in range(len(fancy_ssps)):
        for j in range(rows):
            value = drought_count_diff.iloc[j, i]
            # Get the radius for the circle.
            radius = np.interp(value, [vmin, vmax], [0.2, 0.4])
            circle = mpatches.CirclePolygon((i, j-0.2), color=cmap(norm(value)),
                                            radius=radius, resolution=30)
            ax.add_artist(circle)
    ax.axis('equal')
    # Axis limits; the y axis is inverted so the first basin is at the top.
    ax.set_xlim(-1.5, 3.5)
    ax.set_ylim(rows, -0.9)
    # Invisible line spanning the wanted extent.
    ax.plot([-0.9, 3.5], [rows, -0.9], alpha=0)

    # Lets get some annotations in here as well.
    for i in range(rows):
        name = basins_df.iloc[i].RIVER_BASI
        name = name.title()
        # Drop the alternative names given in parentheses.
        name = re.split(r'[\(\)]', name)[0]
        # One word per line.
        name = re.sub(r'\s', '\n', name)
        # NOTE(review): row 14 is hyphenated every 7 characters; this is tied
        # to the current basin selection - confirm if the selection changes.
        if i == 14:
            name = re.sub('(.{7})', '\\1-\n', name)
        ax.annotate(xy=(-1.5, i-0.2), text=name, size=font_size,
                    horizontalalignment='left', verticalalignment='center')
    # xlabels...
    for i, rcp in enumerate(fancy_ssps):
        ax.annotate(xy=(i, rows-0.3), text=rcp,
                    horizontalalignment='center', size=font_size)
    ax.annotate(xy=(-0.8, rows-0.3), text='SSP',
                horizontalalignment='center', size=font_size)
    # Colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('bottom', size='2%', pad='1%')
    cb = fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap),
                      cax=cax, orientation='horizontal')
    cb.ax.tick_params(labelsize=font_size)
    cb.set_label(label='Month count difference', size=font_size)
    # Styling
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.grid(False)
    cax.set_facecolor('none')
    ax.set_facecolor('none')

    # Title
    ax.set_title('Difference in the number of months with\n drought for the selected basins', fontsize=font_size, y=0.98)
    if savefig:
        fig_name = 'delta_dots_selection.pdf'
        path = os.path.join(fig_path, fig_name)
        fig.savefig(path, bbox_inches='tight')

In [None]:
# Draw and save the thesis version of the drought-difference dot plot.
plot_drought_diff(basins, savefig=True) 

### Drought count presentation

In [None]:
# Scratch check of the reversed column order used in the presentation plot.
for i in range(3, -1, -1):
    print(i)

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_drought_diff_pres(basins_df, threshold=-1.0, savefig=False):
    '''Plot the differences in drought months for all basins in basins_df.
    Presentation version.

    Draws a dot matrix with one row per basin and one column per scenario,
    with the scenario columns in reversed order. All labels are rotated.
    Both the radius (0.2 to 0.4) and the viridis colour of a dot scale with
    the value from ``drought_diff`` between the smallest and largest value
    in the table.

    Args:
    -----
    basins_df: geopandas dataframe
        Basins to plot. The ``RIVER_BASI`` column provides the row labels.
    threshold: float
        Passed on to ``drought_diff``.
    savefig: bool
        If True, save the figure as 'delta_dots_selection_pres.pdf' in
        ``fig_path``.
    '''
    font_size = 4
    # Get the counted droughts df
    drought_count_diff = drought_diff(basins_df, threshold)
    rows = basins_df.shape[0]
    # Scenario labels in data-column order; they are plotted reversed.
    fancy_ssps = ['1-2.6', '2-4.5', '3-7.0', '5-8.5']
    n_ssps = len(fancy_ssps)

    # Value range shared by the size and colour scales.
    vmin = drought_count_diff.min().min()
    vmax = drought_count_diff.max().max()
    cmap = cm.viridis
    norm = mcolors.Normalize(vmin, vmax)

    # Set up the plot
    fig, ax = plt.subplots(figsize=(9/2.54, 13/2.54))
    # Plot position i shows data column i_rev.
    for i, i_rev in zip(range(n_ssps), reversed(range(n_ssps))):
        for j in range(rows):
            value = drought_count_diff.iloc[j, i_rev]
            # Get the radius for the circle.
            radius = np.interp(value, [vmin, vmax], [0.2, 0.4])
            circle = mpatches.CirclePolygon((i, j-0.2), color=cmap(norm(value)),
                                            radius=radius, resolution=30)
            ax.add_artist(circle)
    ax.axis('equal')
    # Axis limits; the y axis is inverted so the first basin is at the top.
    ax.set_xlim(-1.7, 3.5)
    ax.set_ylim(rows, -1.5)
    # Invisible line spanning the wanted extent.
    ax.plot([-1.7, 3.5], [rows, -1.5], color="none")

    # Lets get some annotations in here as well.
    for i in range(rows):
        name = basins_df.iloc[i].RIVER_BASI
        name = name.title()
        # Drop the alternative names given in parentheses.
        name = re.split(r'[\(\)]', name)[0]
        if name == "Jokulsa A Fjollum":
            name = "Jokulsa A\nFjollum"
        # NOTE(review): row 14 is hyphenated every 7 characters; this is tied
        # to the current basin selection - confirm if the selection changes.
        if i == 14:
            name = re.sub('(.{7})', '\\1-\n', name)
        ax.annotate(xy=(-1.6, i-0.2), text=name, size=font_size,
                    ha='left', va='center', rotation=-45)
    # xlabels...
    for i, rcp in enumerate(fancy_ssps[::-1]):
        ax.annotate(xy=(i, -1.4), text=rcp,
                    ha='center', va="top", size=font_size, rotation=-90)
    ax.annotate(xy=(3.4, -1.4), text='SSP',
                ha='left', va="top", size=font_size, rotation=-90)
    # Colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('bottom', size='2%', pad=-0.05)
    cb = fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap),
                      cax=cax, orientation='horizontal')
    cb.ax.tick_params(labelsize=font_size, width=0, pad=0.05, rotation=-90)
    # Pull the x tick labels towards the bar (overrides the pad set above).
    cb.ax.xaxis.set_tick_params(pad=-3)
    cb.set_label(label='Month drought count diff.', size=font_size)
    # Styling
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.grid(False)
    cax.set_facecolor('none')
    ax.set_facecolor('none')
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

    # Title, rotated and placed to the right of the axes.
    ax.set_title('Difference in the number of months with drought\nfor the selected basins', fontsize=font_size, x=1.05, y=0.5, ha="center", va="center", rotation=-90)
    if savefig:
        fig_name = 'delta_dots_selection_pres.pdf'
        path = os.path.join(fig_path, fig_name)
        fig.savefig(path, bbox_inches='tight', facecolor="none", transparent=True)

In [None]:
# Draw and save the presentation version of the drought-difference dot plot.
plot_drought_diff_pres(basins, savefig=True) 