# Code for creating the paper figures

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.feature as cfeature
from cartopy.feature import NaturalEarthFeature
import cartopy.crs as ccrs
import matplotlib as mpl
import matplotlib.colors as mcolors
import math
import matplotlib.patches as mpatches
import os

# Root directory of the project data. Some paths built later append to it directly
# (f'{datadir}MK_test/...'), so the trailing slash is significant.
datadir = '/g/data/w97/mg5624/RF_project/'

## Figure 1: Map of Natural Resource Management (NRM) Regions

In [None]:
# Load the gridded NRM cluster IDs; cells with ID 0 are set to NaN so they are not drawn.
NRM_clusters = xr.open_dataset(
    '/g/data/w97/amu561/Steven_CABLE_runs/shapefiles/NRM/NRM_clusters.nc'
)['NRM_cluster']
NRM_clusters = NRM_clusters.where(NRM_clusters != 0)

# Region key -> cluster ID. The keys are iterated by the Figure 7/8 cells and passed to
# the barchart function, which uses them in file names and to filter on 'NRM_region'.
# NOTE(review): 'Wet Tropics' contains a space while every other key uses underscores -
# confirm this matches the file names / CSV entries it is compared against.
NRM_REGIONS = {
    'Central_Slopes': 1,
    'East_Coast': 2,
    'Murray_Basin': 4,
    'Monsoonal_North': 5,
    'Rangelands': 6,
    'Southern_Slopes': 7,
    'S_SW_Flatlands': 8,
    'Wet Tropics': 9,
}

# Cluster ID -> display name used for the legend labels
REGION_NAMES = {
    1: 'Central Slopes',
    2: 'East Coast', 
    4: 'Murray Basin',
    5: 'Monsoonal North',
    6: 'Rangelands',
    7: 'Southern Slopes',
    8: 'S/SW Flatlands',
    9: 'Wet Tropics'
}

# Cluster ID -> colour of the legend patch.
# NOTE(review): the map itself is coloured by the 'tab10' colormap, not by this
# dictionary - confirm the two agree for the cluster IDs present in the data.
region_colors = {
    1: '#1f77b4',  # blue
    2: '#ff7f0e',  # orange
    4: '#d62728',  # red
    5: '#8c564b',  # brown
    6: '#e377c2',  # pink
    7: '#7f7f7f',  # gray
    8: '#bcbd22',  # yellow-green
    9: '#17becf'   # turquoise
}

# Order in which the regions are listed in the legend
legend_order = [5, 9, 6, 2, 1, 8, 4, 7]  # MN, WT, RL, EC, CS, SSW, MB, SS

# Create a smaller figure size
fig, ax = plt.subplots(figsize=(8, 6), subplot_kw={'projection': ccrs.PlateCarree()})

# Remove the border around the map
ax.spines['geo'].set_visible(False)

# Add features and plot the data (the white ocean layer is added at zorder=2)
ax.add_feature(cfeature.OCEAN, facecolor='white', zorder=2)
NRM_clusters.plot(ax=ax, cmap='tab10', add_colorbar=False, transform=ccrs.PlateCarree())

# Create the legend to the right of the map
handles = [mpatches.Patch(color=region_colors[i], label=REGION_NAMES[i]) for i in legend_order]
plt.legend(
    handles=handles,
    loc='center left',
    bbox_to_anchor=(1.1, 0.5),   # Position the legend slightly further to the right
    ncol=1,                      # Single column
    frameon=False,               # Remove border around the legend
    fontsize=14,                 # Increase font size
    handleheight=1.5             # Increase spacing between legend entries
)

# Use tight layout to reduce unnecessary space
plt.tight_layout()

plt.show()

## Figure 2: Time under drought trends for traditional drought metrics

### Colormap creation functions

In [1]:
def create_colormap(cmap, bin_number, end_cutoffs, norm=True, season=None):
    """
    Creates a custom diverging colormap with a grey centre bin and optional normalization.

    The named colormap is sampled, split into a lower and an upper half, and
    `end_cutoffs` colours are dropped from both ends of each half (this removes the
    very light colours next to the centre and the very dark ones at the extremes).
    A light-grey colour is inserted between the two halves.

    Args:
        cmap (str): The name of the matplotlib colormap to sample.
        bin_number (int): The number of bins for the colormap. Ignored when norm is True,
            in which case it is derived from the fixed bin edges defined in this function.
        end_cutoffs (int): The number of colours to cut off from each end of each half.
            0 keeps every sampled colour.
        norm (bool, optional): Whether to also build a BoundaryNorm. Defaults to True.
        season (str, optional): If provided (and norm is True), the bin edges are divided
            by 4 for seasonal data. Defaults to None.

    Returns:
        custom_cmap, norm (tuple): The custom colormap and the BoundaryNorm if norm is True.
        custom_cmap (ListedColormap): The custom colormap if norm is False.
    """
    if norm:
        middle_bins = np.array([-0.3666, -0.2333, -0.1, 0.1, 0.2333, 0.3666])
        positive_end = np.arange(0.5, 2.166667, 0.166667)
        negative_end = np.sort(positive_end * (-1))
        bin_edges = np.concatenate((negative_end, middle_bins, positive_end))
        if season is not None:
            bin_edges = bin_edges/4

        bin_number = len(bin_edges) - 1

    brewer_cmap = mpl.colormaps[cmap]
    # Each half loses end_cutoffs colours at both of its ends, hence the factor of 4
    extra_bin_number = bin_number + (end_cutoffs * 4)
    colors_array = brewer_cmap(np.linspace(0, 1, extra_bin_number))
    grey_RGB = 230
    grey_color = [grey_RGB/255, grey_RGB/255, grey_RGB/255, 1]

    # Remove the very light colors close to zero to give extra contrast and super dark ones at the ends.
    # The stop index is written explicitly because [n:-n] is an empty slice when n == 0.
    half = extra_bin_number // 2
    lower_colors = colors_array[:half]
    upper_colors = colors_array[half:]
    first_half_colors = lower_colors[end_cutoffs:len(lower_colors) - end_cutoffs]
    second_half_colors = upper_colors[end_cutoffs:len(upper_colors) - end_cutoffs]
    new_colors_array = np.vstack([first_half_colors, grey_color, second_half_colors])

    custom_cmap = mcolors.ListedColormap(new_colors_array)

    if not norm:
        return custom_cmap

    # The colormap holds bin_number + 1 colours (the halves plus the grey centre) while the
    # norm is built with ncolors=bin_number; this is kept exactly as before.
    boundary_norm = mcolors.BoundaryNorm(bin_edges, ncolors=bin_number, clip=False)
    return custom_cmap, boundary_norm


def create_mean_var_colormap():
    """
    Builds the three-colour colormap and matching normalization used for the
    mean/variability contribution data.

    Returns:
        cmap, norm (tuple): A ListedColormap with three colours and a BoundaryNorm
        whose boundaries are 0, 49.9999, 50.0001 and 100.
    """
    # Three bands: below 50, a hair-thin band around 50, above 50
    band_edges = [0, 49.9999, 50.0001, 100]
    three_band_cmap = mcolors.ListedColormap(['#5B9BD5', '#F4C542', '#E06666'])
    return three_band_cmap, mcolors.BoundaryNorm(band_edges, three_band_cmap.N)

### Loading data

In [None]:
def load_trends_df_as_dictionary(vars, years, season=None, agg=5, scale=3, intensity=False):
    """
    Loads the trend file for every (variable, year range) pair and applies the land mask.

    Args:
        vars (list of str): list of variable names; each must be a key of the path table below
        years (list of str): list of year ranges of the trends (e.g. '1911-2020')
        season (str): season to load trend of (DJF, MAM, JJA, SON) or None if loading annual trend.
            The season is prefixed to the file name.
        agg (int): how many years aggregated over before taking trend
        scale (int): the months that droughts were defined over; any value other than 3 adds a
            '_{scale}-month' suffix to the file name of 'drought' variables
        intensity (bool): whether to load drought intensity trends, if False loads time under
            drought trends. Only percentile-based '_baseline' files have an intensity variant;
            for all other files the regular path is used.

    Returns:
        trends (dict): nested dictionary, trends[var][year_range] is the masked xr.Dataset.
            The year_range keys always use '-' as the separator.

    Raises:
        KeyError: if a variable name is not in the path table.
    """
    trends = {}
    # Opened on first use and shared by every dataset, so empty inputs never touch the disk
    full_mask = None
    for var in vars:
        years_dict = {}
        for year_range in years:
            # The contribution files are named with '_' between the two years
            if 'contribution' in var:
                year_range = year_range.replace('-', '_')
            path_dict = {
                'precip_drought':
                    f'{datadir}MK_test/drought_metrics/Aus/precip_percentile/MK_yue_wang/{year_range}/Aus_{year_range}_{agg}_year_MK_yue_wang_test_precip_percentile_baseline_1911_2020.nc',
                'runoff_drought':
                    f'{datadir}MK_test/drought_metrics/Aus/runoff_percentile/MK_yue_wang/{year_range}/Aus_{year_range}_{agg}_year_MK_yue_wang_test_runoff_percentile_baseline_1911_2020.nc',
                'soil_moisture_drought':
                    f'{datadir}MK_test/drought_metrics/Aus/soil_moisture_percentile/MK_yue_wang/{year_range}/Aus_{year_range}_{agg}_year_MK_yue_wang_test_soil_moisture_percentile_baseline_1911_2020.nc',
                'impacts_drought':
                    f'{datadir}MK_test/RF_droughts/drought_events/1911_model/MK_yue_wang/{year_range}_{agg}_year_MK_yue_wang_test_drought_events_1911_model.nc',
                'precip':
                    f'{datadir}/MK_test/RF_predictors/Precipitation/MK_yue_wang/{year_range}/{year_range}_{agg}_year_MK_yue_wang_test_Precipitation_1911_model.nc',
                'runoff':
                    f'{datadir}/MK_test/RF_predictors/Runoff/MK_yue_wang/{year_range}/{year_range}_{agg}_year_MK_yue_wang_test_Runoff_1911_model.nc',
                'soil_moisture':
                    f'{datadir}/MK_test/RF_predictors/Soil_Moisture/MK_yue_wang/{year_range}/{year_range}_{agg}_year_MK_yue_wang_test_Soil_Moisture_1911_model.nc',
                'impacts':
                    f'{datadir}/MK_test/RF_droughts/drought_events/1911_model/MK_yue_wang/{year_range}_{agg}_year_yue_wang_MK_test_drought_events_1911_model.nc',
                'precip_drought_mean_contribution':
                    f'/scratch/w97/mg5624/data/mean_var_contribution/precip_drought/precip_drought_mean_contribution_{year_range}.nc',
                'runoff_drought_mean_contribution':
                    f'/scratch/w97/mg5624/data/mean_var_contribution/runoff_drought/runoff_drought_mean_contribution_{year_range}.nc',
                'soil_moisture_drought_mean_contribution':
                    f'/scratch/w97/mg5624/data/mean_var_contribution/soil_moisture_drought/soil_moisture_drought_mean_contribution_{year_range}.nc',
            }

            # Always start from the table entry so filepath is defined on every path through
            # the options below (previously intensity=True with a non-percentile file left it
            # unset, or silently reused the file of the previous iteration).
            filepath = path_dict[var]

            # if we want the seasonal trend, add this onto the filename string
            if season is not None:
                dirpath, filename = os.path.split(filepath)
                filepath = os.path.join(dirpath, f"{season}_{filename}")
            if intensity and 'percentile' in filepath and '_baseline' in filepath:
                filepath = filepath.replace('_baseline', '_intensity_baseline')
            if scale != 3 and 'drought' in var:
                filepath = f'{filepath[:-3]}_{scale}-month.nc'

            trend_data = xr.open_dataset(filepath)
            if full_mask is None:
                full_mask = xr.open_dataarray(f'{datadir}/masks/regridded_awra_awap_mask.nc')

            # Crop the mask to the extent of the trend data
            lon_min, lon_max = trend_data.lon.min().item(), trend_data.lon.max().item()
            lat_min, lat_max = trend_data.lat.min().item(), trend_data.lat.max().item()
            mask = full_mask.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))

            # Put the data onto the mask grid when the latitudes do not line up
            if not trend_data.lat.equals(mask.lat):
                trend_data = trend_data.interp_like(mask, method='nearest')

            years_dict[year_range.replace('_', '-')] = trend_data.where(mask)
        trends[var] = years_dict

    return trends

### General function for plotting trend maps (and mean/variability contribution)

In [None]:
def plot_MK_trendtest_results(vars, years, season=None, agg=5, scale=3, plotting_trend=True, cbar=True, cbar_lim=None, intensity=False, impacts=False, find_percentages=False):
    """
    Plots results from the MK trendtest as a grid of maps (one row per entry of years, one
    column per entry of vars). Hatching derived from the MK_trend field is overlaid on the
    trend maps. Also works to plot the mean/variability contribution, in this case set
    plotting_trend to False.

    Args:
        vars (list of str): name of vars to plot trends for. When impacts is True this must
            instead hold the year ranges (they become the columns).
        years (list of str): the year ranges that the trends are over. When impacts is True this
            must instead hold the variable name(s) (they become the rows).
        season (str): season to load trend of (DJF, MAM, JJA, SON), only used when plotting_trend=True
        agg (int): how many years aggregated over before taking trend, only used when
            plotting_trend=True. The string 'LogReg' selects a different colorbar label.
        scale (int): the months that droughts were defined over, only used when plotting_trend=True
        plotting_trend (bool): set to False if plotting mean/variability contribution, true otherwise
        cbar (bool): whether to plot with colorbars or not
        cbar_lim (list of float): one entry per row, indicates the cbar max for that row. If None
            the limit is taken from the data of each panel.
        intensity (bool): whether to plot drought intensity trends, if False plots time under drought trends
        impacts (bool): True if plotting trend in impact-based drought metric
        find_percentages (bool): if true, prints out percentages of significant/positive/negative trends
    """
    if impacts:
        # impacts plotted in 1 row, 3 columns; each column is a different time period so the
        # caller passes the arguments swapped and they are swapped back for loading
        data_arrays = load_trends_df_as_dictionary(years, vars, season, agg, scale, intensity)
    else:
        data_arrays = load_trends_df_as_dictionary(vars, years, season, agg, scale, intensity)
        if not plotting_trend:
            # Significance comes from the trend files of the underlying drought metric
            signif_vars = [var.removesuffix('_mean_contribution') for var in vars]
            signif_data_arrays = load_trends_df_as_dictionary(signif_vars, years, season, agg, scale, intensity)

    mpl.rcParams['hatch.linewidth'] = 0.35
    figsize = (15, 4) if len(years) == 1 else (15, 10)
    # squeeze=False keeps axes two-dimensional for every grid shape (including a single
    # row or a single panel), so it can always be indexed as axes[row, col]
    fig, axes = plt.subplots(
        len(years), len(vars), figsize=figsize, squeeze=False,
        subplot_kw={'projection': ccrs.PlateCarree()}
    )
    for row, year in enumerate(years):
        for col, var in enumerate(vars):
            if impacts:
                ds = data_arrays[year][var]
            else:
                ds = data_arrays[var][year]

            da = ds.MK_slope
            if plotting_trend:
                trend = ds.MK_trend
                significant = trend != 0
                # Keeps only the cells where MK_trend == 0; all other cells become NaN
                trend_masked = trend.where(~significant)

                if find_percentages:
                    all_points = trend.count().item()
                    print('----------SIGNIFICANCE PERCENTAGES------------')
                    print(year, var)
                    print('no trend points: ', (trend.where(trend == 0).count().item() / all_points) * 100)
                    print('pos trend points: ', (trend.where(trend == 1).count().item() / all_points) * 100)
                    print('neg trend points: ', (trend.where(trend == -1).count().item() / all_points) * 100)
                    print('signif points: ', (trend.where(trend != 0).count().item() / all_points) * 100)
            else:
                signif_var = var.removesuffix('_mean_contribution')
                signif_ds = signif_data_arrays[signif_var][year]
                trend = signif_ds.MK_trend
                significant = trend != 0
                significant, da = xr.align(significant, da)
                trend_masked = trend.where(~significant)
                # da_count keeps NaN outside significant cells (used for the statistics),
                # da gets -1 there (used for plotting)
                da_count = da.where(significant)
                da = da.where(significant, -1)
                if find_percentages:
                    all_points = da_count.count().item()
                    print('----------DOMINATING PERCENTAGES------------')
                    print(year, var)
                    print('equal contribution: ', (da_count.where(da_count == 50).count().item() / all_points) * 100)
                    print('mean dominates for: ', (da_count.where(da_count > 50).count().item() / all_points) * 100)
                    print('variability dominates for: ', (da_count.where(da_count < 50).where(da_count > -1).count().item() / all_points) * 100)
                    print('mean fully causing: ', (da_count.where(da_count == 100).count().item() / all_points) * 100)
                    print('variability fully causing: ', (da_count.where(da_count == 0).count().item() / all_points) * 100)

            ax = axes[row, col]

            for spine in ax.spines.values():
                spine.set_visible(False)
            if 'mean_contribution' in var:
                cmap, norm = create_mean_var_colormap()
            elif 'drought' in var or impacts:
                cmap = create_colormap('BrBG_r', 30, 3, norm=False)
            else:
                # The norm returned here is not used by the plotting calls below
                cmap, _ = create_colormap('BrBG', 30, 3, season=season)
            if not plotting_trend:
                cbar_max = 100
                cbar_min = 0
            elif cbar_lim is None:
                # Symmetric limits around zero taken from this panel's data
                max_slope = da.max()
                min_slope = da.min()
                cbar_max = max(abs(max_slope), abs(min_slope))
                cbar_min = -cbar_max
            else:
                cbar_max = cbar_lim[row]
                cbar_min = -cbar_max
            # NOTE(review): the mean-contribution maps are drawn without `norm`, while the
            # colorbar below uses it - confirm that this difference is intended.
            da.plot(ax=ax, cmap=cmap, vmin=cbar_min, vmax=cbar_max, add_colorbar=False)

            if plotting_trend:
                # NOTE(review): trend_masked only holds the MK_trend == 0 cells - confirm
                # whether the hatching is meant to mark significant or non-significant trends.
                trend_masked.plot.contourf(ax=ax, hatches=[5*'/'], colors='none',  levels=(0.5, 1), add_colorbar=False)
            else:
                # Paint the MK_trend == 0 cells in a flat grey
                non_significant_cmap = mcolors.ListedColormap(['#D9D9D9'])
                trend_masked.plot(ax=ax, cmap=non_significant_cmap, add_colorbar=False)

            ax.add_feature(cfeature.OCEAN, zorder=100, facecolor='white', edgecolor='k')
            states_provinces = NaturalEarthFeature(category='cultural',
                                                   name='admin_1_states_provinces_lines',
                                                   scale='50m',
                                                   edgecolor='black',
                                                   facecolor='none',
                                                   alpha=0.5)
            ax.add_feature(states_provinces, linestyle='-', linewidth=1)

    # The shared colorbar is built from the colormap and limits of the last panel drawn
    if 'mean_contribution' in var:
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    else:
        sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_array([])
    sm.set_clim(vmin=cbar_min, vmax=cbar_max)
    if cbar:
        cbar_ax = fig.add_axes([0.04, 0.1, 0.72, 0.03])
        extend_arg = 'both' if plotting_trend else 'neither'

        colorbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal', extend=extend_arg)
        colorbar.ax.tick_params(labelsize=15)
        if intensity:
            cbar_label = f'Change in drought intensity per {agg} years (%)'
        elif agg == 'LogReg':
            cbar_label = 'Change in the probability of drought each month per month'
        elif 'mean_contribution' in var:
            # The tick labels carry the meaning, so no axis label is needed
            cbar_label = ''
            tick_positions = [25, 50, 75]
            tick_labels = [
                'Variability Dominates Trend', 'Equal Contribution', 'Mean Dominates Trend'
            ]
            colorbar.set_ticks(tick_positions)
            colorbar.set_ticklabels(tick_labels)
        else:
            cbar_label = f'Change in no. of drought months per {agg} years'

        colorbar.ax.tick_params(which="minor", length=0)  # Turn off minor ticks
        colorbar.ax.tick_params(which="major", length=5)  # Customize major tick length
        colorbar.set_label(cbar_label, fontsize=18)
    plt.tight_layout(rect=[0.002, 0.15, 0.8, 1])

    column_titles = {
        'precip_drought_mean_contribution': 'Meteorological Drought',
        'soil_moisture_drought_mean_contribution': 'Agricultural Drought',
        'runoff_drought_mean_contribution': 'Hydrological Drought',
        'precip_drought': 'Meteorological Drought',
        'soil_moisture_drought': 'Agricultural Drought',
        'runoff_drought': 'Hydrological Drought',
    }
    if impacts:
        # Columns are the year ranges here, so they are used as titles directly
        for j, column_label in enumerate(vars):
            axes[0, j].set_title(column_label, fontsize=20)
    else:
        for i, row_label in enumerate(years):
            fig.text(0.02, len(years) * 0.286 - (i * len(years) * 0.09225), row_label, va='center', rotation='vertical', fontsize=20)
        for j, column_var in enumerate(vars):
            # Fall back to the variable name when no display title is defined for it
            axes[0, j].set_title(column_titles.get(column_var, column_var), fontsize=20)


### Plotting

In [None]:
# Trend periods (one map row each) and drought metrics (one map column each);
# the later figure cells reuse both names.
years = ['1911-2020', '1951-2020', '1971-2020']
vars = ['precip_drought', 'soil_moisture_drought', 'runoff_drought']

# Annual time-under-drought trends, every row sharing a colour range of -2 to 2
plot_MK_trendtest_results(vars, years, agg=5, cbar_lim=[2] * 3)

## Figure 3: Time under drought trends for impact-based drought metric

In [None]:
# The impact-based metric is drawn as one row with a column per period, which the
# function expects as swapped arguments: the periods go in `vars`, the metric in `years`.
plot_MK_trendtest_results(
    vars=years,
    years=['impacts_drought'],
    agg=5,
    cbar_lim=[2, 2, 2],
    impacts=True,
)

## Figure 4: Area under drought timeseries

### See area_under_drought_plots.py for Figure 4 code

## Figure 5: DJF and JJA seasonal time under drought trends for traditional drought metrics

In [None]:
# DJF seasonal trends, every row sharing a colour range of -0.5 to 0.5
plot_MK_trendtest_results(vars, years, 'DJF', 5, cbar_lim=[0.5] * 3)

In [None]:
# JJA seasonal trends, every row sharing a colour range of -0.5 to 0.5
plot_MK_trendtest_results(vars, years, 'JJA', 5, cbar_lim=[0.5] * 3)

## Figure 6: Contributions of mean and variability changes

In [None]:
# One mean-contribution variable per drought metric, in the same column order as before
vars = [
    f'{metric}_mean_contribution'
    for metric in ('precip_drought', 'soil_moisture_drought', 'runoff_drought')
]

# plotting_trend=False switches the function to the mean/variability contribution maps
plot_MK_trendtest_results(vars=vars, years=years, plotting_trend=False)

## Figure 7: Contribution of hydrometeorological variables to agricultural drought trends

In [None]:
def create_var_import_barcharts_for_all_seasons(drought_type, region):
    """
    Creates a grouped barchart with all seasonal variable importance scores for the
    specified NRM region and drought type.

    For every season the importance CSV is read and averaged over its rows; the variables
    are ordered by their mean importance across seasons (largest first). Unless region is
    'Season_Average', each legend entry also shows the region's MK slope for that season,
    with a '*' appended when MK_trend is non-zero.

    Args:
        drought_type (str): variable name of the drought type (used in the file paths)
        region (str): NRM region that the barchart is for, or 'Season_Average'
    """
    colors = ['red', 'orange', 'skyblue', 'orchid']
    seasons = ['DJF', 'MAM', 'JJA', 'SON']
    legend_detail = {}

    # The trend table is the same for every season, so read it once up front
    trend_df = None
    if region != 'Season_Average':
        trend_path = f'/scratch/w97/mg5624/data/trend_analysis/NRM_mean_trends/{drought_type}_drought/1981-2020/'
        # NOTE(review): the file name says '1981-1981' while the directory and the importance
        # files say '1981-2020' - confirm this is the real file name on disk.
        trend_file = f'NRM_mean_{drought_type}_drought_trends_1981-1981.csv'
        trend_df = pd.read_csv(trend_path + trend_file)

    season_means = []
    for season in seasons:
        filepath = f'/scratch/w97/mg5624/data/trend_analysis/variable_importance/{drought_type}_drought/{season}/Random_Forest_Regression/'
        filename = f'{season}_{region}_{drought_type}_drought_1981-2020_trend_analysis_Random_Forest_Regression_std_dev.csv'
        df = pd.read_csv(filepath + filename, index_col=0)
        df.columns = df.columns.str.replace('_', ' ')
        # One row per season holding the column means of the importance scores
        season_means.append(df.mean().to_frame().T)

        if trend_df is not None:
            trend_df_filtered = trend_df.loc[(trend_df['NRM_region'] == region) & (trend_df['Season'] == season)]
            slope = trend_df_filtered['MK_slope'].iat[0]
            trend = trend_df_filtered['MK_trend'].iat[0]
            slope_string = str(np.round(slope, 2))
            if trend != 0:
                slope_string += '*'
            legend_detail[season] = slope_string

    all_seasons_df = pd.concat(season_means, ignore_index=True)
    all_seasons_var_import_mean = all_seasons_df.mean(axis=0)
    dataframe_sorted = all_seasons_df[all_seasons_var_import_mean.sort_values(ascending=False).index]

    fig, ax = plt.subplots(figsize=(16, 12))

    # Transpose the DataFrame to make it easier to plot: rows are variables, columns are seasons
    dataframe_sorted = dataframe_sorted.T

    # Create positions for the bars; each variable gets one bar per season plus a one-bar gap
    bar_width = 0.2
    num_vars = len(dataframe_sorted.index)
    positions = np.arange(num_vars) * (len(seasons) + 1) * bar_width

    for i, season in enumerate(seasons):
        if region == 'Season_Average':
            season_label = season
        else:
            season_label = f'{season} ({legend_detail[season]})'
        ax.bar(positions + i * bar_width, dataframe_sorted.iloc[:, i], width=bar_width, color=colors[i], label=season_label)

    # Centre each tick under its group of bars
    ax.set_xticks(positions + bar_width * (len(seasons) - 1) / 2)
    ax.set_xticklabels(dataframe_sorted.index.str.replace('_', ' '), ha='right', rotation=35, fontsize=40)
    # Fix the tick locations before replacing their labels
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels([f'{y:.2f}' for y in ax.get_yticks()], fontsize=32)
    plt.subplots_adjust(left=0.2, right=0.97, bottom=0.35, top=0.9)
    region_title = region.replace('_', ' ')
    plt.title(region_title, fontsize=46)

    # Add legend
    ax.legend(fontsize=32)

In [None]:
# Barcharts for fig. 7: one agricultural (soil moisture) drought chart per NRM region
for region in NRM_REGIONS:
    create_var_import_barcharts_for_all_seasons('soil_moisture', region)

## Figure 8: Contribution of hydrometeorological variables to hydrological drought trends

In [None]:
# Barcharts for fig. 8: one hydrological (runoff) drought chart per NRM region
for region in NRM_REGIONS:
    create_var_import_barcharts_for_all_seasons('runoff', region)

## Figure S1: Information on drought impact reports data