# Code for creating the paper figures

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.feature as cfeature
from cartopy.feature import NaturalEarthFeature
import cartopy.crs as ccrs
import matplotlib as mpl
import matplotlib.colors as mcolors
import math
import matplotlib.patches as mpatches
import os
from shapely.geometry import Point
import geopandas as gpd
import matplotlib.cm as mpl_cm

# Root directory of the project data. Note the trailing slash: several path
# f-strings below append to it directly (e.g. f'{datadir}MK_test/...').
datadir = '/g/data/w97/mg5624/RF_project/'

## Figure 1: Map of Natural Resource Management (NRM) Regions

In [None]:
# Gridded NRM cluster ids; cells with id 0 are blanked out (set to NaN) so they are not drawn.
NRM_clusters = xr.open_dataset(
    '/g/data/w97/amu561/Steven_CABLE_runs/shapefiles/NRM/NRM_clusters.nc'
)['NRM_cluster']
NRM_clusters = NRM_clusters.where(NRM_clusters != 0)

# Region key -> cluster id. The keys are also used further down to build
# variable-importance file names and to filter the NRM trend table.
# NOTE(review): 'Wet Tropics' is the only key written with a space instead of an
# underscore -- confirm it matches the file names / table labels it is used with.
NRM_REGIONS = {
    'Central_Slopes': 1,
    'East_Coast': 2,
    'Murray_Basin': 4,
    'Monsoonal_North': 5,
    'Rangelands': 6,
    'Southern_Slopes': 7,
    'S_SW_Flatlands': 8,
    'Wet Tropics': 9,
}

# Cluster id -> display name used in the legend.
REGION_NAMES = {
    1: 'Central Slopes',
    2: 'East Coast',
    4: 'Murray Basin',
    5: 'Monsoonal North',
    6: 'Rangelands',
    7: 'Southern Slopes',
    8: 'S/SW Flatlands',
    9: 'Wet Tropics'
}

# Cluster id -> legend patch colour.
# NOTE(review): the map itself is coloured through cmap='tab10', not through this
# dictionary; the two presumably line up for ids 1-9 -- verify against the figure.
region_colors = {
    1: '#1f77b4',  # blue
    2: '#ff7f0e',  # orange
    4: '#d62728',  # red
    5: '#8c564b',  # brown
    6: '#e377c2',  # pink
    7: '#7f7f7f',  # gray
    8: '#bcbd22',  # yellow-green
    9: '#17becf'   # turquoise
}

# Order in which the regions are listed in the legend
legend_order = [5, 9, 6, 2, 1, 8, 4, 7]  # MN, WT, RL, EC, CS, SSW, MB, SS

# Single map panel on a PlateCarree projection
fig, ax = plt.subplots(figsize=(8, 6), subplot_kw={'projection': ccrs.PlateCarree()})

# Remove the border around the map
ax.spines['geo'].set_visible(False)

# Paint the ocean white and draw the cluster ids
ax.add_feature(cfeature.OCEAN, facecolor='white', zorder=2)
NRM_clusters.plot(ax=ax, cmap='tab10', add_colorbar=False, transform=ccrs.PlateCarree())

# Build one legend patch per region and place the legend to the right of the map
handles = [mpatches.Patch(color=region_colors[i], label=REGION_NAMES[i]) for i in legend_order]
plt.legend(
    handles=handles,
    loc='center left',
    bbox_to_anchor=(1.1, 0.5),   # Position the legend slightly further to the right
    ncol=1,                      # Single column
    frameon=False,               # Remove border around the legend
    fontsize=14,                 # Increase font size
    handleheight=1.5             # Increase spacing between legend entries
)

# Use tight layout to reduce unnecessary space
plt.tight_layout()

plt.show()

## Figure 2: Time under drought trends for traditional drought metrics

### Colormap creation functions

In [1]:
def create_trends_colormap(cmap, bin_number, end_cutoffs, norm=True, season=None):
    """
    Creates a diverging colormap with a grey central bin and, optionally, a matching BoundaryNorm.

    The named colormap is sampled at ``bin_number + 4 * end_cutoffs`` evenly spaced points and split
    into a lower and an upper half. ``end_cutoffs`` colours are then dropped from both ends of each
    half, and a light grey colour is inserted between the two halves.

    Args:
        cmap (str): The name of the registered matplotlib colormap to sample from.
        bin_number (int): The number of bins for the colormap. Ignored when norm is True, in which
            case it is derived from the fixed bin edges defined in this function.
        end_cutoffs (int): The number of colours to drop from each end of each half of the colormap.
            May be 0, in which case nothing is dropped.
        norm (bool, optional): Whether to also build and return a BoundaryNorm. Defaults to True.
        season (str, optional): If not None, the fixed bin edges are divided by 4. Only used when
            norm is True. Defaults to None.

    Returns:
        custom_cmap, norm (tuple): The custom colormap and the BoundaryNorm, if norm is True.
        custom_cmap (ListedColormap): The custom colormap only, if norm is False.
    """
    if norm:
        # Fixed, symmetric bin edges: narrow bins around zero and wider bins towards the ends.
        middle_bins = np.array([-0.3666, -0.2333, -0.1, 0.1, 0.2333, 0.3666])
        positive_end = np.arange(0.5, 2.166667, 0.166667)
        negative_end = np.sort(positive_end * (-1))
        bin_edges = np.concatenate((negative_end, middle_bins, positive_end))
        if season is not None:
            bin_edges = bin_edges/4

        bin_number = len(bin_edges) - 1

    brewer_cmap = mpl.colormaps[cmap]
    extra_bin_number = bin_number + (end_cutoffs * 4)
    colors_array = brewer_cmap(np.linspace(0, 1, extra_bin_number))
    grey_RGB = 230
    grey_color = [grey_RGB/255, grey_RGB/255, grey_RGB/255, 1]

    # Remove the very light colors close to zero to give extra contrast and super dark ones at the ends.
    # The stop index is written as len(...) - end_cutoffs rather than -end_cutoffs, because a stop of
    # -0 would produce an empty slice when end_cutoffs is 0.
    midpoint = extra_bin_number // 2
    lower_half = colors_array[:midpoint]
    upper_half = colors_array[midpoint:]
    first_half_colors = lower_half[end_cutoffs:len(lower_half) - end_cutoffs]
    second_half_colors = upper_half[end_cutoffs:len(upper_half) - end_cutoffs]
    new_colors_array = np.vstack([first_half_colors, grey_color, second_half_colors])

    custom_cmap = mcolors.ListedColormap(new_colors_array)

    if norm:
        boundary_norm = mcolors.BoundaryNorm(bin_edges, ncolors=bin_number, clip=False)
        return custom_cmap, boundary_norm

    return custom_cmap


def create_mean_var_colormap():
    """
    Builds the three-colour colormap and boundary normalization used for the mean/variability
    contribution maps.

    The boundaries put a very thin middle bin around 50, with one bin below it (down to 0)
    and one above it (up to 100).

    Returns:
        cmap, norm (tuple): the ListedColormap and the BoundaryNorm built from the boundaries.
    """
    palette = mcolors.ListedColormap(['#5B9BD5', '#F4C542', '#E06666'])
    boundaries = [0, 49.9999, 50.0001, 100]
    return palette, mcolors.BoundaryNorm(boundaries, palette.N)

### Loading data

In [None]:
def load_trends_df_as_dictionary(vars, years, season=None, agg=5, scale=3, intensity=False):
    """
    Loads the masked trend datasets for every combination of variable and year range.

    Args:
        vars (list of str): list of variable names (keys of the path table inside this function)
        years (list of str): list of year ranges of the trends, e.g. '1911-2020'
        season (str): season to load trend of (DJF, MAM, JJA, SON) or None if loading annual trend;
            when given, it is prefixed to the file name
        agg (int): how many years aggregated over before taking trend
        scale (int): the months that droughts were defined over; values other than 3 add a
            '_{scale}-month' suffix to the file name of drought variables
        intensity (bool): whether to load drought intensity trends, if False loads time under drought trends

    Returns:
        trends (dict): nested dictionary ``{var: {year_range: xr.Dataset}}`` of the masked trend
        datasets. The year-range keys always use a hyphen ('1911-2020').

    Raises:
        ValueError: if intensity is True for a variable whose file path has no intensity variant
            (i.e. is not a '..._percentile_baseline...' file).
    """
    # The mask file is the same for every variable and period, so it is opened only once.
    full_mask = xr.open_dataarray(f'{datadir}/masks/regridded_awra_awap_mask.nc')

    trends = {}
    for var in vars:
        years_dict = {}
        for year_range in years:
            # The mean-contribution files use an underscore in the year range
            if 'contribution' in var:
                year_range = year_range.replace('-', '_')
            path_dict = {
                'precip_drought':
                    f'{datadir}MK_test/drought_metrics/Aus/precip_percentile/MK_yue_wang/{year_range}/Aus_{year_range}_{str(agg)}_year_MK_yue_wang_test_precip_percentile_baseline_1911_2020.nc',
                'runoff_drought':
                    f'{datadir}MK_test/drought_metrics/Aus/runoff_percentile/MK_yue_wang/{year_range}/Aus_{year_range}_{str(agg)}_year_MK_yue_wang_test_runoff_percentile_baseline_1911_2020.nc',
                'soil_moisture_drought':
                    f'{datadir}MK_test/drought_metrics/Aus/soil_moisture_percentile/MK_yue_wang/{year_range}/Aus_{year_range}_{str(agg)}_year_MK_yue_wang_test_soil_moisture_percentile_baseline_1911_2020.nc',
                'impacts_drought':
                    f'{datadir}MK_test/RF_droughts/drought_events/1911_model/MK_yue_wang/{year_range}_{str(agg)}_year_MK_yue_wang_test_drought_events_1911_model.nc',
                'precip':
                    f'{datadir}/MK_test/RF_predictors/Precipitation/MK_yue_wang/{year_range}/{year_range}_{str(agg)}_year_MK_yue_wang_test_Precipitation_1911_model.nc',
                'runoff':
                    f'{datadir}/MK_test/RF_predictors/Runoff/MK_yue_wang/{year_range}/{year_range}_{str(agg)}_year_MK_yue_wang_test_Runoff_1911_model.nc',
                'soil_moisture':
                    f'{datadir}/MK_test/RF_predictors/Soil_Moisture/MK_yue_wang/{year_range}/{year_range}_{str(agg)}_year_MK_yue_wang_test_Soil_Moisture_1911_model.nc',
                'impacts':
                    f'{datadir}/MK_test/RF_droughts/drought_events/1911_model/MK_yue_wang/{year_range}_{str(agg)}_year_yue_wang_MK_test_drought_events_1911_model.nc',
                'precip_drought_mean_contribution':
                    f'/scratch/w97/mg5624/data/mean_var_contribution/precip_drought/precip_drought_mean_contribution_{year_range}.nc',
                'runoff_drought_mean_contribution':
                    f'/scratch/w97/mg5624/data/mean_var_contribution/runoff_drought/runoff_drought_mean_contribution_{year_range}.nc',
                'soil_moisture_drought_mean_contribution':
                    f'/scratch/w97/mg5624/data/mean_var_contribution/soil_moisture_drought/soil_moisture_drought_mean_contribution_{year_range}.nc',
            }

            filepath = path_dict[var]

            # if we want the seasonal trend, add this onto the filename string
            if season is not None:
                dirpath, filename = os.path.split(filepath)
                filepath = os.path.join(dirpath, f"{season}_{filename}")

            # Intensity trends only exist as a variant of the percentile-baseline files. Previously
            # any other variable left `filepath` unset (or silently reused the previous file).
            if intensity:
                if 'percentile' not in filepath or '_baseline' not in filepath:
                    raise ValueError(
                        f"No drought intensity trend file is defined for variable '{var}'."
                    )
                filepath = filepath.replace('_baseline', '_intensity_baseline')

            if scale != 3 and 'drought' in var:
                filepath = f'{filepath[:-3]}_{scale}-month.nc'

            trend_data = xr.open_dataset(filepath)

            # Crop the mask to the extent of the trend data
            lon_min, lon_max = trend_data.lon.min().item(), trend_data.lon.max().item()
            lat_min, lat_max = trend_data.lat.min().item(), trend_data.lat.max().item()
            mask = full_mask.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))

            # Put the trend data onto the mask grid first if the latitudes do not match exactly
            if not trend_data.lat.equals(mask.lat):
                trend_data = trend_data.interp_like(mask, method='nearest')
            trend_data_masked = trend_data.where(mask)

            years_dict[year_range.replace('_', '-')] = trend_data_masked
        trends[var] = years_dict

    return trends

### General function for plotting trend maps (and mean/variability contribution)

In [None]:
def plot_MK_trendtest_results(vars, years, season=None, agg=5, scale=3, plotting_trend=True, cbar=True, cbar_lim=None, intensity=False, impacts=False, find_percentages=False):
    """
    Plots results from the MK trendtest as a grid of maps (one row per entry of `years`, one column
    per entry of `vars`). Hatching is intended to indicate where the trends are significant. Also
    works to plot the mean/variability contribution, in this case set plotting_trend to False.

    When impacts is True the roles of the first two arguments are swapped by the caller: `vars`
    holds the year ranges (one per column) and `years` holds the single impact variable name.

    Args:
        vars (list of str): name of vars to plot trends for
        years (list of str): the year ranges that the trends are over
        season (str): season to load trend of (DJF, MAM, JJA, SON)
        agg (int): how many years aggregated over before taking trend
        scale (int): the months that droughts were defined over
        plotting_trend (bool): set to False if plotting mean/variability contribution, True otherwise
        cbar (bool): whether to plot with a colorbar or not
        cbar_lim (list of float): one entry per row, indicates the cbar max for that row
        intensity (bool): whether to plot drought intensity trends, if False plots time under drought trends
        impacts (bool): True if plotting trend in impact-based drought metric
        find_percentages (bool): if True, prints out percentages of significant/positive/negative trends
    """
    if impacts:
        # impacts plotted in 1 row, 3 columns; each column is a different time period, so the
        # arguments are handed to the loader in the opposite order
        data_arrays = load_trends_df_as_dictionary(years, vars, season, agg, scale, intensity)
    else:
        data_arrays = load_trends_df_as_dictionary(vars, years, season, agg, scale, intensity)
        if not plotting_trend:
            # Significance for the contribution maps comes from the trend test of the matching drought metric
            signif_vars = [var.removesuffix('_mean_contribution') for var in vars]
            signif_data_arrays = load_trends_df_as_dictionary(signif_vars, years, season, agg, scale, intensity)

    mpl.rcParams['hatch.linewidth'] = 0.35

    # squeeze=False keeps `axes` two-dimensional, so axes[row, col] is valid even for a single
    # row and/or a single column.
    figsize = (15, 4) if len(years) == 1 else (15, 10)
    fig, axes = plt.subplots(
        len(years), len(vars), figsize=figsize, squeeze=False,
        subplot_kw={'projection': ccrs.PlateCarree()}
    )

    # The same state-border feature is drawn on every panel
    states_provinces = NaturalEarthFeature(category='cultural',
                                           name='admin_1_states_provinces_lines',
                                           scale='50m',
                                           edgecolor='black',
                                           facecolor='none',
                                           alpha=0.5)

    for row, year in enumerate(years):
        for col, var in enumerate(vars):
            if impacts:
                ds = data_arrays[year][var]
            else:
                ds = data_arrays[var][year]

            da = ds.MK_slope
            if plotting_trend:
                trend = ds.MK_trend
                significant = trend != 0
                # NOTE(review): this keeps only the cells where MK_trend == 0, so no value appears to
                # fall inside the (0.5, 1) contour level used for the hatching below -- confirm
                # whether the hatching is actually drawn as intended.
                trend_masked = trend.where(~significant)

                if find_percentages:
                    all_points = trend.count().item()
                    print('----------SIGNIFICANCE PERCENTAGES------------')
                    print(year, var)
                    print('no trend points: ', (trend.where(trend == 0).count().item() / all_points) * 100)
                    print('pos trend points: ', (trend.where(trend == 1).count().item() / all_points) * 100)
                    print('neg trend points: ', (trend.where(trend == -1).count().item() / all_points) * 100)
                    print('signif points: ', (trend.where(trend != 0).count().item() / all_points) * 100)
            else:
                signif_var = var.removesuffix('_mean_contribution')
                signif_ds = signif_data_arrays[signif_var][year]
                trend = signif_ds.MK_trend
                significant = trend != 0
                significant, da = xr.align(significant, da)
                trend_masked = trend.where(~significant)
                # da_count keeps NaN outside significant cells (used for the statistics);
                # da gets -1 there instead (those cells are covered by the grey overlay below)
                da_count = da.where(significant)
                da = da.where(significant, -1)
                if find_percentages:
                    all_points = da_count.count().item()
                    print('----------DOMINATING PERCENTAGES------------')
                    print(year, var)
                    print('equal contribution: ', (da_count.where(da_count == 50).count().item() / all_points) * 100)
                    print('mean dominates for: ', (da_count.where(da_count > 50).count().item() / all_points) * 100)
                    print('variability dominates for: ', (da_count.where(da_count < 50).where(da_count > -1).count().item() / all_points) * 100)
                    print('mean fully causing: ', (da_count.where(da_count == 100).count().item() / all_points) * 100)
                    print('variability fully causing: ', (da_count.where(da_count == 0).count().item() / all_points) * 100)

            ax = axes[row, col]

            for spine in ax.spines.values():
                spine.set_visible(False)
            if 'mean_contribution' in var:
                cmap, norm = create_mean_var_colormap()
            elif 'drought' in var or impacts:
                cmap = create_trends_colormap('BrBG_r', 30, 3, norm=False)
            else:
                cmap, norm = create_trends_colormap('BrBG', 30, 3, season=season)
            if not plotting_trend:
                cbar_max = 100
                cbar_min = 0
            elif cbar_lim is None:
                # Symmetric limits around zero taken from the data
                max_slope = da.max()
                min_slope = da.min()
                cbar_max = max(abs(max_slope), abs(min_slope))
                cbar_min = -cbar_max
            else:
                cbar_max = cbar_lim[row]
                cbar_min = -cbar_max
            # NOTE(review): `norm` is not passed here, only vmin/vmax, so the map colours may not
            # follow the BoundaryNorm that the colorbar uses for the contribution maps -- confirm.
            da.plot(ax=ax, cmap=cmap, vmin=cbar_min, vmax=cbar_max, add_colorbar=False)

            if plotting_trend:
                trend_masked.plot.contourf(ax=ax, hatches=[5*'/'], colors='none',  levels=(0.5, 1), add_colorbar=False)
            else:
                # Grey out the cells without a significant trend
                non_significant_cmap = mcolors.ListedColormap(['#D9D9D9'])
                trend_masked.plot(ax=ax, cmap=non_significant_cmap, add_colorbar=False)

            ax.add_feature(cfeature.OCEAN, zorder=100, facecolor='white', edgecolor='k')
            ax.add_feature(states_provinces, linestyle='-', linewidth=1)

    # The shared colorbar is built from the settings of the last panel drawn
    if 'mean_contribution' in var:
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    else:
        sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_array([])
    sm.set_clim(vmin=cbar_min, vmax=cbar_max)
    if cbar:
        cbar_ax = fig.add_axes([0.04, 0.1, 0.72, 0.03])
        extend_arg = 'both' if plotting_trend else 'neither'

        colorbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal', extend=extend_arg)
        colorbar.ax.tick_params(labelsize=15)
        if intensity:
            cbar_label = f'Change in drought intensity per {agg} years (%)'
        elif agg == 'LogReg':
            cbar_label = 'Change in the probability of drought each month per month'
        elif 'mean_contribution' in var:
            cbar_label = ''
            tick_positions = [25, 50, 75]
            tick_labels = [
                'Variability Dominates Trend', 'Equal Contribution', 'Mean Dominates Trend'
            ]
            colorbar.set_ticks(tick_positions)
            colorbar.set_ticklabels(tick_labels)
        else:
            cbar_label = f'Change in no. of drought months per {agg} years'

        colorbar.ax.tick_params(which="minor", length=0)  # Turn off minor ticks
        colorbar.ax.tick_params(which="major", length=5)  # Customize major tick length
        colorbar.set_label(cbar_label, fontsize=18)
    plt.tight_layout(rect=[0.002, 0.15, 0.8, 1])
    row_titles_dict = {
        'precip_drought_mean_contribution': 'Meteorological Drought',
        'soil_moisture_drought_mean_contribution': 'Agricultural Drought',
        'runoff_drought_mean_contribution': 'Hydrological Drought',
        'precip_drought': 'Meteorological Drought',
        'soil_moisture_drought': 'Agricultural Drought',
        'runoff_drought': 'Hydrological Drought',
    }
    if impacts:
        # Here `vars` holds the year ranges, which become the column titles
        for j, column_title in enumerate(vars):
            axes[0, j].set_title(column_title, fontsize=20)
    else:
        # Year ranges as vertical row labels on the left-hand side of the figure
        for i, year_label in enumerate(years):
            fig.text(0.02, len(years) * 0.286 - (i * len(years) * 0.09225), year_label, va='center', rotation='vertical', fontsize=20)
        # Variables without a dedicated title fall back to their own name
        for j, var_name in enumerate(vars):
            axes[0, j].set_title(row_titles_dict.get(var_name, var_name), fontsize=20)


### Plotting

In [None]:
# Trend periods and drought metrics for Figure 2. Both lists are module-level names
# that the later figure cells reuse.
years = ['1911-2020', '1951-2020', '1971-2020']
vars = ['precip_drought', 'soil_moisture_drought', 'runoff_drought']

# Annual time-under-drought trends; every row shares the same colorbar limit of 2.
plot_MK_trendtest_results(vars, years, cbar_lim=[2] * 3)

## Figure 3: Time under drought trends for impact-based drought metric

In [None]:
# With impacts=True the first two arguments are deliberately swapped: the year ranges
# go in the `vars` slot (one column each) and the variable name goes in the `years` slot.
plot_MK_trendtest_results(years, ['impacts_drought'], impacts=True, cbar_lim=[2] * 3)

## Figure 4: Area under drought timeseries

### See area_under_drought_plots.py for Figure 4 code

## Figure 5: DJF and JJA seasonal time under drought trends for traditional drought metrics

In [None]:
# DJF seasonal trends, colorbar limited to +/-0.5 on every row
plot_MK_trendtest_results(vars, years, season='DJF', cbar_lim=[0.5] * 3)

In [None]:
# JJA seasonal trends, colorbar limited to +/-0.5 on every row
plot_MK_trendtest_results(vars, years, season='JJA', cbar_lim=[0.5] * 3)

## Figure 6: Contributions of mean and variability changes

In [None]:
# Mean-contribution variable names, in the same column order as the earlier figures
vars = [
    f'{metric}_drought_mean_contribution'
    for metric in ('precip', 'soil_moisture', 'runoff')
]

# plotting_trend=False switches the plotting function to the mean/variability contribution maps
plot_MK_trendtest_results(vars, years, plotting_trend=False)

## Figure 7: Contribution of hydrometeorological variables to agricultural drought trends

In [None]:
def create_var_import_barcharts_for_all_seasons(drought_type, region):
    """
    Creates a grouped barchart of the seasonal variable importance scores for the specified
    NRM region and drought type. Variables are ordered by their importance averaged over seasons.

    Unless region is 'Season_Average', each season's legend entry also shows the region's MK slope
    for that season (rounded to 2 decimals), followed by '*' when the MK trend flag is non-zero.

    Args:
        drought_type (str): variable name of the drought type (used to build the file paths)
        region (str): NRM region that the barchart is for, or 'Season_Average'
    """
    seasons = ['DJF', 'MAM', 'JJA', 'SON']
    colors = ['red', 'orange', 'skyblue', 'orchid']

    # The regional trend table does not depend on the season, so it is read only once.
    trend_df = None
    if region != 'Season_Average':
        trend_path = f'/scratch/w97/mg5624/data/trend_analysis/NRM_mean_trends/{drought_type}_drought/1981-2020/'
        # NOTE(review): the file name ends in '1981-1981' although the directory is '1981-2020' --
        # confirm this matches the file on disk.
        trend_file = f'NRM_mean_{drought_type}_drought_trends_1981-1981.csv'
        trend_df = pd.read_csv(trend_path + trend_file)

    seasonal_means = []
    legend_detail = {}
    for season in seasons:
        filepath = f'/scratch/w97/mg5624/data/trend_analysis/variable_importance/{drought_type}_drought/{season}/Random_Forest_Regression/'
        filename = f'{season}_{region}_{drought_type}_drought_1981-2020_trend_analysis_Random_Forest_Regression_std_dev.csv'
        df = pd.read_csv(filepath + filename, index_col=0)
        df.columns = df.columns.str.replace('_', ' ')
        # One row per season holding the column means of that season's table
        seasonal_means.append(df.mean().to_frame().T)

        if trend_df is not None:
            trend_df_filtered = trend_df.loc[(trend_df['NRM_region'] == region) & (trend_df['Season'] == season)]
            slope = trend_df_filtered['MK_slope'].iat[0]
            trend = trend_df_filtered['MK_trend'].iat[0]
            slope_string = str(np.round(slope, 2))
            if trend != 0:
                slope_string += '*'
            legend_detail[season] = slope_string

    all_seasons_df = pd.concat(seasonal_means, ignore_index=True)
    all_seasons_var_import_mean = all_seasons_df.mean(axis=0)
    dataframe_sorted = all_seasons_df[all_seasons_var_import_mean.sort_values(ascending=False).index]

    fig, ax = plt.subplots(figsize=(16, 12))

    # Transpose the DataFrame so that rows are variables and columns are seasons
    dataframe_sorted = dataframe_sorted.T

    # Create positions for the bars: one group per variable, with a one-bar gap between groups
    bar_width = 0.2
    num_vars = len(dataframe_sorted.index)
    positions = np.arange(num_vars) * (len(seasons) + 1) * bar_width

    for i, season in enumerate(seasons):
        if region == 'Season_Average':
            season_label = season
        else:
            season_label = f'{season} ({legend_detail[season]})'
        ax.bar(positions + i * bar_width, dataframe_sorted.iloc[:, i], width=bar_width, color=colors[i], label=season_label)

    # Centre each tick under its group of bars
    ax.set_xticks(positions + bar_width * (len(seasons) - 1) / 2)
    ax.set_xticklabels(dataframe_sorted.index.str.replace('_', ' '), ha='right', rotation=35, fontsize=40)
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels([f'{y:.2f}' for y in ax.get_yticks()], fontsize=32)
    plt.subplots_adjust(left=0.2, right=0.97, bottom=0.35, top=0.9)
    region_title = region.replace('_', ' ')
    plt.title(region_title, fontsize=46)

    # Add legend
    ax.legend(fontsize=32)

In [None]:
# Fig. 7: one soil-moisture (agricultural) drought barchart per NRM region
for region in NRM_REGIONS:
    create_var_import_barcharts_for_all_seasons('soil_moisture', region)

## Figure 8: Contribution of hydrometeorological variables to hydrological drought trends

In [None]:
# Fig. 8: one runoff (hydrological) drought barchart per NRM region
for region in NRM_REGIONS:
    create_var_import_barcharts_for_all_seasons('runoff', region)

## Figure S1: Information on drought impact reports data

### Load drought impact data

In [None]:
# Table of drought impact reports. The cells below read its 'Year', 'Longitude'
# and 'Latitude' columns.
drought_impact_filepath = '/g/data/w97/mg5624/RF_project/training_data/training_data.csv'
drought_impact_data = pd.read_csv(drought_impact_filepath)
# Optional column subset, currently disabled:
# drought_impact_data = drought_impact_data[['Year', 'Latitude', 'Longitude', 'Drought']]

### Create timeseries detailing the year of the drought impact reports

In [None]:
# Number of impact reports in each year that has at least one report
data_count_per_year = drought_impact_data.groupby('Year').size()

# Every year between the first and last report year, inclusive
first_year = drought_impact_data['Year'].min()
last_year = drought_impact_data['Year'].max()
all_years = range(first_year, last_year + 1)

# Years without any report get an explicit count of 0
data_count_per_year_full = data_count_per_year.reindex(all_years).fillna(0).astype(int)
no_report_years = data_count_per_year_full == 0

# Softer colours for the line and for the zero-report markers
subtle_blue = '#6fa3e1'  # Softer blue
subtle_red = '#e57373'   # Softer red

# Line plot of the yearly counts
plt.figure(figsize=(10, 6))
plt.plot(
    data_count_per_year_full.index,
    data_count_per_year_full.values,
    marker='o',
    linestyle='-',
    color=subtle_blue,
)

# Mark the years without reports on top of the line
plt.scatter(
    data_count_per_year_full.index[no_report_years],
    data_count_per_year_full[no_report_years],
    color=subtle_red,
    label='No Reports',
    zorder=5,
)

# Axis labels (no title)
plt.xlabel('Year', fontsize=16)
plt.ylabel('Number of Reports', fontsize=16)

# One x tick every second year; larger tick labels on both axes
plt.xticks(range(first_year, last_year + 1, 2), fontsize=14)
plt.yticks(fontsize=14)

# No gridlines
plt.grid(False)

# Legend for the 'No Reports' markers
plt.legend(fontsize=14)

# Tight layout to avoid clipping, then show
plt.tight_layout()
plt.show()

### Create map detailing locations of drought impact reports

In [None]:
def count_repeated_locations_in_training_data(drought_impact_data):
    """
    Counts the number of times each location is used for an impact report.

    The input dataframe is left untouched (no helper columns are added to it).

    Args:
        drought_impact_data (pd.DataFrame): impact report data with 'Longitude' and 'Latitude' columns

    Returns:
        unique_points (pd.DataFrame): one row per unique (Longitude, Latitude) pair with columns
        'Longitude', 'Latitude' and 'Duplicate_Count' (the number of reports at that location).
        Each row keeps the index label of the first report at that location.
    """
    coordinate_columns = ['Longitude', 'Latitude']

    # Per-row count of reports sharing the same coordinates, aligned to the input index
    reports_at_location = (
        drought_impact_data.groupby(coordinate_columns)['Longitude'].transform('count')
    )

    # assign() works on a new frame, so the caller's dataframe is not modified
    unique_points = (
        drought_impact_data[coordinate_columns]
        .assign(Duplicate_Count=reports_at_location)
        .drop_duplicates()
    )

    return unique_points


def create_plot_of_impact_reports(drought_impact_data, bounding_coords):
    """
    Creates plot of location of impact reports, with color of point dependent on number of reports.

    Args:
        drought_impact_data (pd.DataFrame): impact report data with 'Longitude' and 'Latitude' columns
        bounding_coords (list): the coordinates to plot between in order of [lon_min, lon_max, lat_min, lat_max]
    """
    unique_points = count_repeated_locations_in_training_data(drought_impact_data)

    # One point geometry per unique report location
    geometry = [Point(xy) for xy in zip(unique_points['Longitude'], unique_points['Latitude'])]
    unique_points_gdf = gpd.GeoDataFrame(unique_points, geometry=geometry)

    fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()})

    # Discrete colour scale: 'cividis' resampled to `bins` colours over 0..vmax.
    # The colormap registry is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    base_cmap = mpl.colormaps['cividis']
    vmax = 30
    bins = 6

    # Plot points with a color scale
    plot = unique_points_gdf.plot(
        ax=ax,
        vmin=0,
        vmax=vmax,
        markersize=10,
        legend=False,
        column='Duplicate_Count',
        cmap=base_cmap.resampled(bins),
    )

    # Add horizontal colorbar at the bottom
    cbar = fig.colorbar(
        plot.collections[0],  # Access the plotted collection
        ax=ax,
        orientation='horizontal',
        fraction=0.05,  # Adjust size of the colorbar
        pad=0.1,  # Add space between plot and colorbar
    )
    cbar.set_label('Number of Impact Reports per Location')  # Add label to colorbar

    # Add background features from Cartopy
    ax.add_feature(cfeature.COASTLINE, linewidth=0.8, edgecolor='black')
    states_provinces = NaturalEarthFeature(
        category='cultural',
        name='admin_1_states_provinces_lines',
        scale='50m',
        edgecolor='black',
        facecolor='none',
        alpha=0.5,
    )
    ax.add_feature(states_provinces, linestyle='-', linewidth=1)
    ax.add_feature(cfeature.LAND, facecolor='#f7f5f5')

    # Hide the axis frame and restrict the view to the bounding coordinates
    ax.set_axis_off()
    ax.set_xlim(bounding_coords[0], bounding_coords[1])
    ax.set_ylim(bounding_coords[2], bounding_coords[3])

    plt.show()


# Map extent as [lon_min, lon_max, lat_min, lat_max]
bounds = [138, 155, -40, -26]
create_plot_of_impact_reports(drought_impact_data, bounding_coords=bounds)

## Figure S2: Random forest skill score for predicting impact-based drought metric

In [None]:
def create_performance_metric_bar_chart(performance_df, ylim=None, plot_name=None):
    """
    Creates a bar chart plot of the performance metrics for the RF model that is trained to predict drought impacts.
    Each bar shows the mean performance metric value of the various iterations, with error bars showing the standard deviation.

    Args:
        performance_df (pd.DataFrame): Performance metrics stored in a DataFrame.
        ylim (float or list, optional): Either a single number, used as the upper y-axis limit with
            the lower limit fixed at 0, or a pair [ymin, ymax]. If None, the limits are left unchanged.
        plot_name (str, optional): Title for the plot.
    """
    # Set up the figure and axis
    fig, ax = plt.subplots(figsize=(10, 6))

    # Create the bar plot with a uniform pastel blue color
    sns.barplot(
        performance_df,
        ax=ax,
        errorbar='sd',
        capsize=0.5,
        color='lightblue',  # Use a single color for all bars
        err_kws={'color': 'gray'}
    )

    # Set y-axis limits if provided; a scalar is the upper limit, a pair is used as given
    if ylim is not None:
        if np.ndim(ylim) == 0:
            ax.set_ylim([0, ylim])
        else:
            ax.set_ylim(ylim)

    # Customize x-axis labels
    ax.set_xticklabels(ax.get_xticklabels(), rotation=35, ha='right', fontsize=16)

    # Fix the current y ticks, then relabel them with one decimal place
    ax.set_yticks(ax.get_yticks())
    ax.set_yticklabels([f"{y:.1f}" for y in ax.get_yticks()], fontsize=14)

    # Annotate bars with their values, slightly above the top of each bar
    for p in ax.patches:
        ax.annotate(format(p.get_height(), '.3f'),
                    (p.get_x() + p.get_width() / 2., p.get_height()),
                    ha='center', va='center',
                    xytext=(0, 11),
                    textcoords='offset points',
                    fontsize=14)

    # Set plot title if provided
    if plot_name:
        ax.set_title(plot_name, fontsize=20)

    # Adjust layout for better appearance
    plt.tight_layout()
    plt.show()

In [None]:
# Performance metrics of the final 1911 model; y-axis drawn from 0 to 1.05
impact_metric_performance = pd.read_csv(
    f'{datadir}model_analytics/performance_metric/1911final/performance_metric_1911final.csv'
)
create_performance_metric_bar_chart(impact_metric_performance, ylim=1.05)

## Figure S3: S/N ratio and KS test for time under drought

In [None]:
def direction_of_change(drought_type, baseline, number_final_years):
    """
    Determines the sign of the change in the annual drought metric between the baseline
    period and the last `number_final_years` time steps of the 1911-2020 record.

    Args:
        drought_type (str): name of drought_type ('precip', 'soil_moisture' or 'runoff')
        baseline (list of str): list of form [first_baseline_year, last_baseline_year]
        number_final_years (int): the number of time steps at the end of the data to compare to the baseline

    Returns:
        sign_of_change (xr.DataArray): sign (np.sign) of the end-period mean minus the baseline mean
    """
    source_by_type = {'precip': 'AGCD', 'soil_moisture': 'AWRA', 'runoff': 'AWRA'}
    span_by_type = {'precip': '1900-2021', 'soil_moisture': '1911-2020', 'runoff': '1911-2020'}

    # Assemble the path of the annual drought metric file for this drought type
    metric_dir = f'/scratch/w97/mg5624/data/drought_metric/{drought_type}_percentile/'
    metric_file = (
        f'{source_by_type[drought_type]}_{drought_type}_percentile_drought_metric_annual_'
        f'{span_by_type[drought_type]}_baseline_1911_2020.nc'
    )
    drought_metric = xr.open_dataarray(metric_dir + metric_file).sel(time=slice('1911', '2020'))

    # Time means over the baseline period and over the final time steps
    baseline_mean = drought_metric.sel(time=slice(baseline[0], baseline[-1])).mean(dim='time')
    final_years_mean = drought_metric.isel(time=slice(-number_final_years, None)).mean(dim='time')

    difference = final_years_mean - baseline_mean

    return xr.DataArray(np.sign(difference), coords=difference.coords, dims=difference.dims)


def fully_emerged_check(emergence_test_data, emerged_length, threshold, threshold_condition='<'):
    """
    Checks if the variable in question has fully emerged from the variability. Works for
    KS test dataarray (significance) or signal to noise ratio test.

    The signal counts as fully emerged only where the threshold condition holds at every one of the
    final `emerged_length` timesteps. NaN values never satisfy the condition, so any NaN within
    that period gives False.

    Args:
        emergence_test_data (xr.DataArray): dataarray of the emergence test in question, with a
            'time' dimension
        emerged_length (int): the number of timesteps at the end of the data for which the signal
            has to have emerged for it to be fully emerged. Must be positive.
        threshold (float): the threshold value of emergence_test_data which defines the signal to have emerged
        threshold_condition ('<' or '>'): the condition against the threshold which the emergence test must meet

    Returns:
        fully_emerged (xr.DataArray): dataarray of boolean data, True where signal has emerged, False otherwise

    Raises:
        ValueError: if threshold_condition is not '<' or '>', or if emerged_length is not positive.
    """
    if threshold_condition not in ('<', '>'):
        raise ValueError(
            f'Threshold_condition does not accept {threshold_condition} as an argument. Use \'<\' or \'>\' instead.'
            )
    # slice(-n, None) only means "the last n timesteps" for n > 0; with n == 0 it would select the
    # whole series, so an invalid length is rejected rather than silently checked over the wrong period.
    if emerged_length <= 0:
        raise ValueError(
            f'emerged_length must be a positive number of timesteps, got {emerged_length}.'
        )

    data_over_emergence_period = emergence_test_data.isel(time=slice(-emerged_length, None))
    if threshold_condition == '<':
        emerged_timesteps = data_over_emergence_period < threshold
    else:
        emerged_timesteps = data_over_emergence_period > threshold

    # Fully emerged means the condition is met at every timestep of the period
    fully_emerged = emerged_timesteps.all(dim='time')

    return fully_emerged


def plot_trend_emergence(emergence_test, emerged_length, baseline):
    """
    Plots the trend emergence for different types of droughts.

    One map is drawn per drought type (precip, soil_moisture, runoff). A pixel is shown as
    'Increased' (1) or 'Decreased' (-1) when the emergence test is met at every one of the final
    `emerged_length` timesteps (see fully_emerged_check), with the sign taken from
    direction_of_change; every other pixel kept by the mask is 'Not Emerged' (0). The percentage
    of counted pixels in each emerged class is printed for every drought type.

    Note: this sets the global matplotlib font size through plt.rcParams, which stays in effect
    after the call.

    Args:
        emergence_test (str): The type of emergence test to use ('ks_test' or 'signal_to_noise').
        emerged_length (int): The length of the period to consider for emergence.
        baseline (list of str): The baseline period to compare against, as [first_year, last_year].

    Raises:
        ValueError: if emergence_test is not one of the supported test names.
    """
    # Emergence criterion for each test: [threshold value, comparison applied against it]
    threshold = {
        'ks_test': [0.05, '<'],
        'signal_to_noise': [1, '>']
    }
    if emergence_test not in threshold:
        raise ValueError(
            f'emergence_test must be one of {list(threshold)}, got {emergence_test!r}.'
        )
    threshold_value, threshold_condition = threshold[emergence_test]

    colors = ['#5B9BD5', '#D9D9D9', '#E06666']  # Blue, Grey, Red
    cmap = mcolors.ListedColormap(colors)
    bounds = [-1.5, -0.5, 0.5, 1.5]  # Boundaries for the values
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # Titles for the subplots, in plotting order
    titles = {
        'precip': 'Meteorological Drought',
        'soil_moisture': 'Agricultural Drought',
        'runoff': 'Hydrological Drought',
    }

    # The mask is identical for every panel, so open it once rather than once per drought type
    mask = xr.open_dataarray(
        '/g/data/w97/mg5624/RF_project/masks/regridded_awra_awap_mask.nc'
    )

    # Set global font size
    plt.rcParams.update({'font.size': 14})  # Increase font size for readability

    # Create the figure and subplots
    fig, axes = plt.subplots(
        nrows=1, ncols=3, figsize=(15, 6), subplot_kw={"projection": ccrs.PlateCarree()}
    )

    for ax, (drought_type, title) in zip(axes, titles.items()):
        # Load data
        emergence_test_data = xr.open_dataarray(
            f'/scratch/w97/mg5624/data/time_of_emergence/{emergence_test}/{emergence_test}_{drought_type}_time_under_drought_1911_2020_baseline_{baseline[0]}_{baseline[-1]}.nc'
        )
        if emergence_test == 'ks_test':
            emergence_test_data = emergence_test_data.isel(time=slice(None, -10)) # remove the last ten years as they don't have 20-year windows
        # Only the magnitude is compared with the threshold; the direction comes from direction_of_change
        emergence_test_data = abs(emergence_test_data)

        fully_emerged = fully_emerged_check(
            emergence_test_data, emerged_length, threshold_value, threshold_condition
        )
        sign_of_change = direction_of_change(drought_type, baseline, emerged_length)
        sign_of_change, fully_emerged = xr.align(sign_of_change, fully_emerged)
        # Keep the sign where the signal has emerged, 0 elsewhere, then drop pixels outside the mask
        fully_emerged_and_sign = sign_of_change.where(fully_emerged, 0)
        fully_emerged_and_sign = fully_emerged_and_sign.where(mask)

        # count() ignores NaN, so the percentages are relative to the non-NaN pixels only
        total_pixels = fully_emerged_and_sign.count()
        increasing_pixels = (fully_emerged_and_sign == 1).sum()
        decreasing_pixels = (fully_emerged_and_sign == -1).sum()

        percent_increasing = (increasing_pixels / total_pixels) * 100
        percent_decreasing = (decreasing_pixels / total_pixels) * 100

        print('----------', drought_type, '----------')
        print('Emerged increasing percent: ', percent_increasing.values)
        print('Emerged decreasing percent: ', percent_decreasing.values)
        # Plot on the current axis
        im = fully_emerged_and_sign.plot(
            ax=ax,
            cmap=cmap,
            norm=norm,
            add_colorbar=False,  # We'll add a shared colorbar
        )

        # Customize the map
        ax.set_title(title, fontsize=24)  # Add the title
        ax.coastlines()  # Add coastlines
        ax.set_xlabel('')  # Remove longitude label
        ax.set_ylabel('')  # Remove latitude label

    # Add a shared horizontal colorbar at the bottom
    cbar = fig.colorbar(
        im,
        ax=axes,
        orientation='horizontal',
        fraction=0.1,  # Make the colorbar longer
        pad=0.1,
        ticks=[-1, 0, 1],  # Ticks for the colorbar
    )
    cbar.ax.set_xticklabels(
        ['Decreased', 'Not Emerged', 'Increased'], fontsize=20
    )

    # Adjust layout
    plt.tight_layout(rect=[0, 0.25, 1, 1])  # Ensure plots fit well with the colorbar
    plt.show()

### Signal to Noise Ratio

In [None]:
# Signal-to-noise emergence maps: final 20 timesteps, baseline years '1911' to '1961'.
plot_trend_emergence(
    emergence_test='signal_to_noise',
    emerged_length=20,
    baseline=['1911', '1961'],
)

### KS test

In [None]:
# KS-test emergence maps: final 20 timesteps, baseline years '1911' to '1961'.
plot_trend_emergence(
    emergence_test='ks_test',
    emerged_length=20,
    baseline=['1911', '1961'],
)

## Figure S4: Time under drought trend for 2-year block aggregated periods

In [None]:
# Trend-test maps for 2-year block aggregation (agg=2), with no season selected.
# NOTE(review): `vars` and `years` are presumably assigned in an earlier cell (`vars` also
# shadows the Python builtin of that name) - confirm before running this cell on its own.
plot_MK_trendtest_results(
    vars, years, season=None, agg=2, scale=3, cbar_lim=[0.25, 0.25, 0.25]
)

## Figure S5: Time under drought trend for 3-year block aggregated periods

In [None]:
# Trend-test maps for 3-year block aggregation (agg=3), with no season selected.
# NOTE(review): `vars` and `years` are presumably assigned in an earlier cell (`vars` also
# shadows the Python builtin of that name) - confirm before running this cell on its own.
plot_MK_trendtest_results(
    vars, years, season=None, agg=3, scale=3, cbar_lim=[0.6, 0.6, 0.6]
)

## Figure S6: Time under drought trend for 7-year block aggregated periods

In [None]:
# Trend-test maps for 7-year block aggregation (agg=7).
# NOTE(review): `vars` and `years` are presumably assigned in an earlier cell (`vars` also
# shadows the Python builtin of that name) - confirm before running this cell on its own.
plot_MK_trendtest_results(
    vars, years, agg=7, scale=3, cbar_lim=[3, 3, 3]
)

## Figure S7: Time under drought trend for logistic regression model of drought probability

In [None]:
# Trend-test maps for the logistic regression model of drought probability (agg='LogReg').
# NOTE(review): `vars` and `years` are presumably assigned in an earlier cell (`vars` also
# shadows the Python builtin of that name) - confirm before running this cell on its own.
plot_MK_trendtest_results(
        vars, years, agg='LogReg', scale=3, cbar_lim=[0.0006, 0.0006, 0.0006]
    )

## Figure S8: Time under drought trend for drought metrics defined over 12-month period, instead of 3-month

In [None]:
# Trend-test maps with scale=12 instead of the scale=3 used for the other figures
# (presumably the drought-metric period in months - confirm against plot_MK_trendtest_results).
# NOTE(review): `vars` and `years` are presumably assigned in an earlier cell (`vars` also
# shadows the Python builtin of that name) - confirm before running this cell on its own.
plot_MK_trendtest_results(
    vars, years, agg=5, scale=12, cbar_lim=[2,2,2]
)

## Figure S9: Drought intensity trends

In [None]:
# Trend-test maps with intensity=True (drought intensity trends); no `scale` is passed here,
# so the callee's default applies.
# NOTE(review): `vars` and `years` are presumably assigned in an earlier cell (`vars` also
# shadows the Python builtin of that name) - confirm before running this cell on its own.
plot_MK_trendtest_results(
    vars, years, agg=5, cbar_lim=[50, 50, 50], intensity=True
)

## Figure S10: S/N ratio for area under drought

In [None]:
# Maps each drought-variable key used in the data file names to the descriptive drought type
# shown as the figure suptitle in plot_trend_emergence_for_are_under_drought.
drought_titles = {
    'precip': 'Meteorological Drought',
    'soil_moisture': 'Agricultural Drought',
    'runoff': 'Hydrological Drought'
}

def plot_trend_emergence_for_are_under_drought(emergence_test_type):
    """
    Plot the trend emergence tests for areas under drought for each NRM region.

    One figure (2 x 4 panels, one panel per NRM region) is produced for each drought type
    ('precip', 'soil_moisture', 'runoff'). Each panel shows the emergence test time series with
    red dotted lines at the emergence threshold(s) and a dashed black vertical line at 2000-01-01.

    Args:
        emergence_test_type (str): The type of emergence test to use. Options are 'ks_test' or 'signal_to_noise'.
            Any other value draws no threshold lines and labels the y-axis 'S/N ratio'.
    """
    # File-name stem of each NRM region -> title shown on its panel. The S/SW Flatlands title is
    # spelled out so it matches REGION_NAMES rather than rendering as 'S SW Flatlands'.
    area_titles = {
        'Central_Slopes': 'Central Slopes',
        'East_Coast': 'East Coast',
        'Murray_Basin': 'Murray Basin',
        'Monsoonal_North': 'Monsoonal North',
        'Rangelands': 'Rangelands',
        'Southern_Slopes': 'Southern Slopes',
        'S_SW_Flatlands': 'S/SW Flatlands',
        'Wet_Tropics': 'Wet Tropics',
    }
    n_cols = 4
    y_label = 'p-value' if emergence_test_type == 'ks_test' else 'S/N ratio'
    # Vertical marker at the year 2000; assumes the time axis holds datetime values
    threshold_year = np.datetime64('2000-01-01')

    for drought_type in ['precip', 'soil_moisture', 'runoff']:
        # Create a figure for each drought type and test combination
        fig, axes = plt.subplots(
            nrows=2, ncols=n_cols, figsize=(12, 6), sharex=True, sharey=True
        )

        # Flatten the 2D array of axes so panels fill row by row
        for panel_index, (ax, (area, area_title)) in enumerate(zip(axes.ravel(), area_titles.items())):
            # Load the data
            emergence_test = xr.open_dataarray(
                f'/scratch/w97/mg5624/data/time_of_emergence/{emergence_test_type}/aud/{area}_{emergence_test_type}_{drought_type}_area_under_drought_1911_2020_baseline_1911_1961.nc'
            )

            if emergence_test_type == 'ks_test':
                # Drop the last 10 timesteps and relabel the remaining values with the time
                # labels from index 10 onward, i.e. shift every value 10 timesteps later
                emergence_test_time = emergence_test.coords['time'][10:]
                emergence_test = emergence_test.isel(time=slice(None, -10))
                emergence_test.coords['time'] = emergence_test_time
            # Plot the data
            emergence_test.plot(ax=ax, add_legend=False)

            # Add the red dotted threshold lines
            if emergence_test_type == 'ks_test':
                ax.axhline(0.05, color='red', linestyle='dotted')  # Line at 0.05
            elif emergence_test_type == 'signal_to_noise':
                ax.axhline(1, color='red', linestyle='dotted')  # Line at 1
                ax.axhline(-1, color='red', linestyle='dotted')  # Line at -1
                # Set the y-axis limits for signal_to_noise test
                ax.set_ylim(-1.6, 1.4)

            # Add the black vertical line at the year 2000
            ax.axvline(threshold_year, color='k', linestyle='--', linewidth=1)
            ax.set_title(area_title, fontsize=18)  # Larger title size

            # Set fewer x-axis ticks
            if 'time' in emergence_test.dims:
                ax.set_xticks(ax.get_xticks()[::2])  # Keep every second tick

            # Customize tick label sizes
            ax.tick_params(axis='both', which='major', labelsize=14)  # Larger tick labels

            # Remove the axis labels for better spacing; only the first panel of each row
            # keeps a y-axis label
            ax.set_xlabel("")
            ax.set_ylabel(y_label if panel_index % n_cols == 0 else "", fontsize=18)

        # Add a suptitle for the descriptive drought type
        fig.suptitle(
            drought_titles[drought_type],  # Get the descriptive title
            fontsize=22,  # Larger font size for suptitle
            y=0.98  # Adjust position of the suptitle
        )

        # Adjust layout to prevent overlap
        fig.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust for suptitle

        # Show the figure
        plt.show()


In [None]:
# Signal-to-noise time series of area under drought, one figure per drought type.
plot_trend_emergence_for_are_under_drought(emergence_test_type='signal_to_noise')

## Figure S11: KS test for area under drought

In [None]:
# KS-test time series of area under drought, one figure per drought type.
plot_trend_emergence_for_are_under_drought(emergence_test_type='ks_test')

## Figure S12: MAM and SON time under drought trends

In [None]:
# MAM: time under drought trends restricted to season='MAM' (presumably March-April-May).
# NOTE(review): `vars` and `years` are presumably assigned in an earlier cell (`vars` also
# shadows the Python builtin of that name) - confirm before running this cell on its own.
plot_MK_trendtest_results(
    vars, years, season='MAM', agg=5, cbar_lim=[0.5, 0.5, 0.5]
)

In [None]:
# SON: time under drought trends restricted to season='SON' (presumably September-October-November).
# NOTE(review): `vars` and `years` are presumably assigned in an earlier cell (`vars` also
# shadows the Python builtin of that name) - confirm before running this cell on its own.
plot_MK_trendtest_results(
    vars, years, season='SON', agg=5, cbar_lim=[0.5, 0.5, 0.5]
)

## Figure S13: Hydrological drought verification

In [None]:
def load_runoff_streamflow_drought_data(years):
    """
    Loads the hydrological drought trend results for a given period, both per catchment and gridded.

    The per-catchment table holds trends for AWRA model data (labelled runoff drought) and for
    observed streamflow (labelled streamflow drought); it is joined with each catchment's centroid
    and area before being split into one table per trend.

    Args:
        years (list): A list whose first two elements are the start year and the end year.

    Returns:
        tuple: A tuple containing three elements:
            - streamflow_drought_slope (pd.DataFrame): 'Streamflow_MK_slope' with catchment ID, centroid and area.
            - runoff_drought_slope (pd.DataFrame): 'Runoff_MK_slope' with catchment ID, centroid and area.
            - gridded_runoff_drought_slope (xr.DataArray): the 'MK_slope' variable of the gridded runoff trend file.
    """
    year_from, year_to = years[0], years[1]
    location_columns = ['lat_centroid', 'long_centroid', 'catchment_area']

    # Per-catchment trend test results for the requested period
    catchment_trends = pd.read_csv(
        f'/g/data/w97/mg5624/RF_project/MK_test/streamflow_catchments/MK_yue_wang/{year_from}-{year_to}/'
        f'runoff_streamflow_drought_MK_yue_wang_trend_test_{year_from}-{year_to}_bymonth.csv'
    )

    # Catchment metadata, keyed by the same ID column name as the trend table
    catchment_info = pd.read_csv(
        '/g/data/w97/mg5624/RF_project/Streamflow/02_location_boundary_area/location_boundary_area.csv'
    ).rename(columns={'station_id': 'CatchID'})

    # Left join keeps every catchment of the trend table, with or without matching metadata
    catchment_trends = catchment_trends.merge(
        catchment_info[['CatchID'] + location_columns], on='CatchID', how='left'
    )

    streamflow_drought_slope = catchment_trends[['CatchID', 'Streamflow_MK_slope'] + location_columns]
    runoff_drought_slope = catchment_trends[['CatchID', 'Runoff_MK_slope'] + location_columns]

    gridded_trend_file = xr.open_dataset(
        f'/g/data/w97/mg5624/RF_project/MK_test/drought_metrics/Aus/runoff_percentile/MK_yue_wang/{year_from}-{year_to}/'
        f'Aus_{year_from}-{year_to}_5_year_MK_yue_wang_test_runoff_percentile_baseline_{year_from}_{year_to}.nc'
    )
    gridded_runoff_drought_slope = gridded_trend_file['MK_slope']

    return streamflow_drought_slope, runoff_drought_slope, gridded_runoff_drought_slope


def plot_trends_at_catchments(trends_at_catchment_df, gridded_runoff_trend, years, vmin=None, vmax=None):
    """
    Plot drought trends at catchments on a map of Australia.

    Each catchment is drawn as a point at its centroid, coloured by its drought trend and sized by
    its catchment area (area / 10, capped at 200).

    Args:
        trends_at_catchment_df (pd.DataFrame): DataFrame containing drought trend data for each catchment.
            Must include columns 'CatchID', 'lat_centroid', 'long_centroid', 'catchment_area', and the drought
            trend values (the first remaining column is used).
        gridded_runoff_trend (xr.DataArray): Gridded runoff trend data. It is masked here but not drawn.
        years (str): String representing the range of years for the trends (e.g., '1980-2020'); used in the title.
        vmin (float, optional): Minimum value for color normalization. If None, it is set to -L, where L is
            the smaller of |min| and |max| of the trend values.
        vmax (float, optional): Maximum value for color normalization. If None, it is set to L as defined above.
    """

    mask = xr.open_dataarray('/g/data/w97/mg5624/RF_project/masks/regridded_awra_awap_mask.nc')
    # NOTE(review): the masked gridded trend is never drawn below, even though the notebook section
    # describes catchment trends "on top of gridded runoff drought trends" - confirm whether a
    # gridded layer should be plotted here.
    gridded_runoff_trend = gridded_runoff_trend.where(mask)
    # Create a GeoDataFrame from the DataFrame
    gdf = gpd.GeoDataFrame(trends_at_catchment_df, geometry=gpd.points_from_xy(trends_at_catchment_df.long_centroid, trends_at_catchment_df.lat_centroid))

    # Plot the map of Australia with PlateCarree projection
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={'projection': ccrs.PlateCarree()})
    ax.set_extent([110, 155, -45, -10], crs=ccrs.PlateCarree())

    # Add coastline feature first
    ax.add_feature(cfeature.COASTLINE, zorder=1)

    # The trend column is whichever column is left once the catchment metadata is removed
    slope_column = trends_at_catchment_df.drop(columns=['CatchID', 'lat_centroid', 'long_centroid', 'catchment_area']).columns.tolist()[0]
    slopes = trends_at_catchment_df[slope_column]

    # Fill in only the limit(s) the caller left out. Passing a None limit on to Normalize would let
    # it autoscale from the first single value it is called with, and overwriting an explicitly
    # given limit would ignore the caller's request.
    if vmin is None or vmax is None:
        cbar_lim = min(abs(slopes.min()), abs(slopes.max()))
        if vmax is None:
            vmax = cbar_lim
        if vmin is None:
            vmin = -cbar_lim
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap = create_trends_colormap('BrBG_r', 30, 3, norm=False)

    # Plot the points with size based on the area and color based on drought trend
    marker_size = (gdf['catchment_area'] / 10).clip(upper=200)
    gdf.plot(ax=ax, markersize=marker_size, color=[cmap(norm(val)) for val in slopes],
             edgecolor='black', linewidth=0.5, transform=ccrs.PlateCarree(), zorder=3)

    # Add color bar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, orientation='horizontal', fraction=0.046, pad=0.1, extend='both')
    cbar.set_label(r'Change in no. of drought months per 5 years', fontsize=20, labelpad=10)
    cbar.ax.tick_params(labelsize=16)

    plt.title(f'Runoff vs Streamflow Drought Trends \n {years}', fontsize=20)
    plt.show()


def scatterplot_streamflow_vs_runoff_drought_trends(runoff_drought_trend, streamflow_drought_trend):
    """
    Generates a scatter plot comparing hydrological drought trends from runoff and streamflow data.

    Runoff trends are on the x-axis and streamflow trends on the y-axis. A least-squares straight
    line is drawn through the points and labelled with the Pearson correlation coefficient.
    Catchments where either trend is missing or non-finite are left out of the plot, the fit and
    the correlation.

    Args:
        runoff_drought_trend (pd.DataFrame): DataFrame containing the runoff drought trend data with a column 'Runoff_MK_slope'.
        streamflow_drought_trend (pd.DataFrame): DataFrame containing the streamflow drought trend data with a column 'Streamflow_MK_slope'.
            Rows are paired with those of runoff_drought_trend by position.
    """
    from scipy.stats import pearsonr
    import matplotlib.ticker as mticker

    # Work on plain arrays so the two trends are paired by position, then keep only the pairs
    # where both values are finite: a single NaN would otherwise break pearsonr and np.polyfit.
    streamflow_values = np.asarray(streamflow_drought_trend['Streamflow_MK_slope'], dtype=float)
    runoff_values = np.asarray(runoff_drought_trend['Runoff_MK_slope'], dtype=float)
    valid = np.isfinite(streamflow_values) & np.isfinite(runoff_values)
    streamflow_values = streamflow_values[valid]
    runoff_values = runoff_values[valid]

    pearson = pearsonr(streamflow_values, runoff_values)[0]

    plt.scatter(runoff_values, streamflow_values)

    # Add the least-squares regression line, labelled with the Pearson coefficient
    m, b = np.polyfit(runoff_values, streamflow_values, 1)
    plt.plot(runoff_values, m*runoff_values + b, color='k', label=f'Pearson coefficient = {pearson:.2f}')

    # Add horizontal and vertical dashed lines at 0
    plt.axhline(0, color='k', linestyle='dashed')
    plt.axvline(0, color='k', linestyle='dashed')

    plt.xlabel('Hydrological Drought Trend\n(from AWRA-L model)', fontsize=16)
    plt.ylabel('Hydrological Drought Trend\n(from observed streamflow)', fontsize=16)
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.gca().yaxis.set_major_formatter(mticker.FormatStrFormatter('%.1f'))
    plt.legend(fontsize=11.75)
    plt.show()

### Map of streamflow drought trends on top of gridded runoff drought trends for 1981-2020

In [None]:
# Load the 1981-2020 catchment and gridded trends, then map the streamflow drought trends.
(
    streamflow_drought_slope_1981,
    runoff_drought_slope_1981,
    gridded_runoff_drought_trend_1981,
) = load_runoff_streamflow_drought_data(['1981', '2020'])

plot_trends_at_catchments(
    streamflow_drought_slope_1981,
    gridded_runoff_drought_trend_1981,
    '1981-2020',
    vmin=-3,
    vmax=3,
)


### Map of streamflow drought trends on top of gridded runoff drought trends for 1951-2020

In [None]:
# Load the 1951-2020 catchment and gridded trends, then map the streamflow drought trends.
(
    streamflow_drought_slope_1951,
    runoff_drought_slope_1951,
    gridded_runoff_drought_trend_1951,
) = load_runoff_streamflow_drought_data(['1951', '2020'])

plot_trends_at_catchments(
    streamflow_drought_slope_1951,
    gridded_runoff_drought_trend_1951,
    '1951-2020',
    vmin=-3,
    vmax=3,
)

### Scatterplot of runoff vs streamflow drought trend for 1981-2020

In [None]:
# Compare the 1981-2020 runoff and streamflow drought trends catchment by catchment.
scatterplot_streamflow_vs_runoff_drought_trends(
    runoff_drought_trend=runoff_drought_slope_1981,
    streamflow_drought_trend=streamflow_drought_slope_1981,
)

### Scatterplot of runoff vs streamflow drought trend for 1951-2020

In [None]:
# Compare the 1951-2020 runoff and streamflow drought trends catchment by catchment.
scatterplot_streamflow_vs_runoff_drought_trends(
    runoff_drought_trend=runoff_drought_slope_1951,
    streamflow_drought_trend=streamflow_drought_slope_1951,
)