# Comments #

This notebook imports data from one of three publicly available sources (ETCCDI, HADEX_ref1961-1990, and HADEX_ref1981-2010) for weather observations, and **is designed 
to show trends**. 
For that purpose, weather data are grouped into logical categories ('Warm', 'Cold', 'Wet', 'Dry', and 'Extremes').

The user is then able to zoom into individual states, as well as select one of the categories via radio buttons.

**I have deliberately not displayed any y-values**, because the main purpose is to display trends, not specific values.
Because the different datasets within each category all represent different measurements (e.g. mm of rain, number of days, temperature), they naturally occupy significantly different ranges of y-values.

Therefore **I have chosen to normalise each dataset** so its values range from 0 to 1, and then plot them stacked (non-additive), so they don't overlap. Where possible, I have plotted a faint third-degree polynomial regression curve to emphasise clear trends if they exist.
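
As a rough illustration of the normalise-and-stack idea, the sketch below rescales each series to the range 0 to 1 and then shifts it by a constant offset so the curves sit above one another. The function name `normalise_and_stack`, the `delta` parameter and the sample values are assumptions made for this example; the notebook's plotting code does the equivalent inline.

```python
import numpy as np

def normalise_and_stack(series_list, delta=1.0):
    """Min-max normalise each series to [0, 1], then offset it so curves don't overlap.

    The first series gets the highest offset, so the curve order matches the legend order.
    A constant series (min == max) cannot be normalised and is returned as None.
    """
    stacked = []
    offset = delta * len(series_list)
    for values in series_list:
        offset -= delta  # corresponds to the value 0 of the current curve
        y = np.asarray(values, dtype=float)
        mn, mx = y.min(), y.max()
        if mn == mx:
            stacked.append(None)
        else:
            stacked.append((y - mn) / (mx - mn) + offset)
    return stacked

rain_mm = [300.0, 450.0, 600.0]   # hypothetical annual rainfall in mm
frost_days = [10.0, 5.0, 0.0]     # hypothetical number of frost days
top, bottom = normalise_and_stack([rain_mm, frost_days])
```

After this step both series span exactly one unit on the y-axis, which is why the absolute y-values no longer carry any meaning.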

The original data contain significant noise, which made it harder to discern trends. I am therefore plotting a 10-year rolling average, which smooths the lines a fair bit. If the purpose of these graphs were to detect outliers and extremes, this would of course not be an acceptable technique.
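
The smoothing step can be sketched with pandas as follows. This example uses a plain (unweighted) 10-year window on made-up data; the plotting code in this notebook passes `win_type='triang'`, which additionally needs SciPy installed.

```python
import pandas as pd

# Made-up annual values: a slow upward trend plus an alternating disturbance
years = range(1950, 1980)
values = [0.1 * i + (1.0 if i % 2 else -1.0) for i in range(30)]
series = pd.Series(values, index=years)

# 10-year rolling mean; the first 9 entries are NaN because the window is not yet full
smoothed = series.rolling(10).mean()
```

Because each window holds five positive and five negative disturbances, they cancel out and only the underlying trend remains.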

The original data are provided as directories with thousands of individual csv files, one for each weather-observation category and each weather station.
To import these files takes a significant amount of time. I have therefore provided a csv file which contains my finished DataFrame which is much, much faster to import.
To demonstrate, however, that I constructed that DataFrame from the original data by combining many different data files, I have also provided a method `importAllAndStore()` to import from the original data (and then save the final DataFrame into a csv file).
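
The "import once, then reuse the exported csv" workflow described above could look roughly like this. `load_or_build` and `build_small_frame` are hypothetical helpers written for this sketch, not functions from this notebook; in the notebook the slow path is `importAllAndStore()`.

```python
import os
import tempfile
import pandas as pd

def load_or_build(csv_path, build):
    """Return the cached DataFrame if csv_path exists, else build it and cache it."""
    if os.path.exists(csv_path):
        return pd.read_csv(csv_path, index_col=0)
    df = build()
    df.to_csv(csv_path)
    return df

def build_small_frame():
    # stand-in for the slow import from thousands of individual files
    return pd.DataFrame({'Year': [1990, 1991], 'CDD': [0.5, 1.0]}, index=[1001, 1001])

with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, 'cache.csv')
    first = load_or_build(path, build_small_frame)    # builds the frame and writes the file
    second = load_or_build(path, build_small_frame)   # reads the file back
    cached_file_exists = os.path.exists(path)
```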

This code could be fairly easily adapted to display climate trends from other countries, since the data format seems to be identical internationally (as far as I checked). Apart from various constants, the main method that would need to be adapted is `loadWeatherStations()`, which is necessary to group the weather stations into states.

**To test simply run all cells top to bottom.**

The code as it is imports and uses data from the HADEX_ref1981-2010 dataset, but that can easily be changed as this is fully parameterised.

To do a **quick test** with the prepared and exported dataframe, place the file `Climate-Trends-Australia-HADEX_1961.csv` into the working directory, and run the code as is.

**To test the original import**, create a `Data` folder in the working directory, expand the downloaded archive, and place the expanded data into it. Because the data files are large, I compressed them separately for each data source, so you only need to download the one you choose to use. Execute `importAllAndStore()` once, then run the notebook again top to bottom (except for `importAllAndStore()`).

Folder structure for import from original data:

`Data
   ETCCDI Climate Change indices
       Australia
           CDD
           CSDI
           ...
           WSDI
   HADEX_ref1961-1990
       stn-indices
           CDD_6190
           CSDI_6190
           ...
           WSDI_6190
   HADEX_ref1981-2010
       stn-indices
           CSDI_8110
           R95p_8110
           ...
           WSDI_8110
   Stations.csv`

If using this code, please acknowledge original author, **Angelika Sajani**


# URLS #

Explanation of observed indices:  https://www.climdex.org/learn/indices/#index-FD 

Hadex Datasets: https://www.climdex.org/access/   (I found the 1961-1990 dataset to be the most comprehensive) 

ETCCDI Datasets: http://etccdi.pacificclimate.org/data.shtml   

Australian Weather Stations (as part of downloading other data): http://www.bom.gov.au/metadata/catalogue/19115/ANZCW0503900447#dataset-constraints





# Imports and Constants

In [1]:
%matplotlib
import pandas as pd
import pandas.io.excel as excl
import numpy as np
import matplotlib as mp
import matplotlib.pyplot as plt
from matplotlib.widgets import Button, CheckButtons, RadioButtons, AxesWidget

import math
import os


### Constants to select the data source #############################
HADEX_1961 = 'HADEX_1961'
HADEX_1981 = 'HADEX_1981'
ETCCDI = 'ETCCDI'

HADEX_COUNTRY_CODE = 'AS' # all files from Australia in the HADEX_~ data sources start with these letters


Using matplotlib backend: MacOSX


# Import Weather Stations

In [2]:
# Returns a dataframe containing information about weather stations, indexed by the station ID ('Site')
# The IDs are read as numbers, so they carry no leading zeroes (unlike the six-digit IDs in the data file names)
# Extracted columns: 'Site Name', 'Longitude', and 'State'

def loadWeatherStations():
    
    rowsToSkip = [0, 1, 3] + list(range(19399, 19406))
    useCols = [0, 2, 6, 8]
    df = pd.read_csv('./Data/stations.csv', skiprows=rowsToSkip, usecols=useCols)

    # remove sites from islands and Antarctica, as their climate would be very different from the mainland
    df = df[(df['STA'] != 'ISL') & (df['STA'] != 'ANT')]

    # index by weather station ID
    df.set_index('Site', inplace=True)
        
    df.columns = ['Site Name', 'Longitude', 'State']
    return df

In [3]:
####### Tester #######
# _ws = loadWeatherStations()
# print(_ws.head())
# print(_ws.tail())

# Data Source Specific Values

In [4]:
# selectors:             required key parameters       result        comments
# ----------------------------------------------------------------------------------------------------------------
#  'isValidDataFileName' fileName and eventCode        boolean       True if fileName is a valid weather data file
#  'dataFolderPath'      eventCode                     str           Relative  (to pwd) folder path for data files
#  'annualColumnWidths'  none                          [int]         list of two column widths to import data files
#                                                                       which contain one annual value per row
#  'monthlyColumnWidths' none                          [int]         list of 14 column widths to import data files
#                                                                       which contain 12 monthly values and one annual value per row
#  'siteIdFromFileName'  fileName                      str           six-character weather station ID (all digits)
#  'eventDictionary'     none                          dictionary    A dictionary describing available weather events 


def getForSource(selector, source, fileName=None, eventCode=None, basePath=None):
    
    # -------------------------------------------------------------------------------------------------------------    
    # values that are at the moment independent of the datasource but may not always be
    
    if (selector == 'annualColumnWidths'):
        return  [4, 7] # the column widths in each file, 4 characters for the year, the rest is the value
 
    elif (selector == 'monthlyColumnWidths'):
        return  [4] + [8] * 13 # 4 characters for the year, then 13 value columns of 8 characters each

    elif (selector == 'eventDictionary'):
        return  defineEvents(source)  # in its own function for readability

    # -------------------------------------------------------------------------------------------------------------    
    if source == HADEX_1961:
        if selector == 'isValidDataFileName':
            if (fileName is None) or (eventCode is None):
                raise ValueError('You must pass values for key-parameters fileName and eventCode')    
            return fileName.startswith(HADEX_COUNTRY_CODE) and fileName.endswith('_6190_' + eventCode + '.txt')
  
        elif selector == 'dataFolderPath':
            return './Data/HADEX_ref1961-1990/stn-indices/' + eventCode + '_6190/'
        
        elif selector == 'siteIdFromFileName':
            if (fileName is None):
                raise ValueError('You must pass a value for key-parameter fileName')    
            return fileName[5:11] 
   
    # -------------------------------------------------------------------------------------------------------------    
    elif source == HADEX_1981:
        if selector == 'isValidDataFileName':
            if (fileName is None) or (eventCode is None):
                raise ValueError('You must pass values for key-parameters fileName and eventCode')
            return fileName.startswith(HADEX_COUNTRY_CODE) and fileName.endswith('_8110_' + eventCode + '.txt')

        elif selector == 'dataFolderPath':
            return './Data/HADEX_ref1981-2010/stn-indices/' + eventCode + '_8110/'
        
        elif selector == 'siteIdFromFileName':
            if (fileName is None):
                raise ValueError('You must pass a value for key-parameter fileName')    
            return fileName[5:11]   

    # -------------------------------------------------------------------------------------------------------------    
    elif source == ETCCDI:
        if selector == 'isValidDataFileName':
            if (fileName is None) or (eventCode is None):
                raise ValueError('You must pass values for key-parameters fileName and eventCode')
            return fileName.endswith('.' + eventCode)
 
        elif selector == 'dataFolderPath':
            return './Data/ETCCDI Climate Change indices/Australia/' + eventCode + '/'
        
        elif selector == 'siteIdFromFileName':
            if (fileName is None):
                raise ValueError('You must pass a value for key-parameter fileName')    
            return fileName[0:6]

    # -------------------------------------------------------------------------------------------------------------    
    else:
        raise ValueError(f"Invalid value for parameter 'source' ({source}).")    
    
 
    # if we get here, we didn't have a valid selector, as for every valid selector there is a return statement
    raise ValueError(f"Invalid value for parameter 'selector' ({selector}).")
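
To make the file-name conventions concrete, here is a small standalone sketch of the checks performed for the HADEX 1981-2010 source. The sample file name is the one used in the tester cell; the helper names `is_valid_hadex_1981` and `site_id_from_hadex_name` are made up for this illustration.

```python
COUNTRY_CODE = 'AS'  # same value as HADEX_COUNTRY_CODE in the notebook

def is_valid_hadex_1981(file_name, event_code):
    """True if the name looks like an Australian HADEX 1981-2010 file for event_code."""
    return file_name.startswith(COUNTRY_CODE) and file_name.endswith('_8110_' + event_code + '.txt')

def site_id_from_hadex_name(file_name):
    """Characters 5 to 10 of the file name hold the six-digit station ID."""
    return file_name[5:11]

name = 'ASH30018044_8110_R95pTOT.txt'
valid = is_valid_hadex_1981(name, 'R95pTOT')    # matching event code
wrong_event = is_valid_hadex_1981(name, 'CDD')  # same file, different event code
site_id = site_id_from_hadex_name(name)
site_number = int(site_id)  # the import code converts the ID to an integer
```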
    

In [5]:
####### Tester #######
# source = 'ETCCDI'
# print(getForSource('isValidDataFileName', source, fileName='ASH30018044_8110_R95pTOT.txt', eventCode='R95pTOT'))
# print(getForSource('dataFolderPath', source, eventCode='CDD'))
# print(getForSource('annualColumnWidths',  source))
# print(getForSource('monthlyColumnWidths', source))


# Climate Events to Examine

In [6]:
def defineEvents(source):
 
    eventCodes = None
    
    allEvents = {
            'CDD': {
                'Label': 'Consecutive Dry Days',
                'colour': 'sienna',
                'monthly': False,
                'category': 'Dry'
            },  
            'CSDI': {
                'Label': 'Cold Spell Duration', 
                'colour': 'blue',
                'monthly': False,
                'category': 'Cold'
            },  
            'CWD': {
                'Label': 'Consecutive Wet Days',
                'colour': 'darkgreen',
                'monthly': False,
                'category': 'Wet'
            },  
            'DTR': {
                'Label': 'Daily Temp. Range',
                'LongLabel': 'Daily Temperature Range',
                'colour': 'magenta',
                'monthly': True,
                'category': 'Extremes'
            },  
            'ETR': {
                'Label': 'Extreme Temp. Range',
                'LongLabel': 'Extreme Temperature Range',
                'colour': 'darkviolet',
                'monthly': True,
                'category': 'Extremes'
            },  
            'FD': {
                'Label': 'Frost Days',
                'colour': 'deepskyblue',
                'monthly': False,
                'category': 'Cold'
            },   
            'GSL': {
                'Label': 'Growing Season Dur.',
                'LongLabel': 'Growing Season Duration',
                'colour': 'tomato',
                'monthly': False,
                'category': 'Warm'
            },  
            'PRCPTOT': {
                'Label': 'Annual Total Rain',
                'LongLabel': 'Annual Total Precipitation',
                'colour': 'green',
                'monthly': False,
                'category': 'Wet'
            }, 
            'R95p': {
                'Label': 'Wet Days',
                'colour': 'lime',
                'monthly': False,
                'category': 'Wet'
            },
            'R95pTOT': {
                'Label': 'Rain from Wet Days',
                'colour': 'mediumseagreen',
                'monthly': False,
                'category': 'Wet'
            },
            'R99p': {
                'Label': 'Very Wet Days',
                'colour': 'green',
                'monthly': False,
                'category': 'Extremes'
            }, 
            'R99pTOT': {
                'Label': 'Rain Very Wet Days', 
                'LongLabel': 'Rain from Very Wet Days', 
                'colour': 'seagreen',
                'monthly': False,
                'category': 'Extremes'
            },
            'SDII': {
                'Label': 'Rain Intensity',
                'LongLabel': 'Precipitation Intensity Index',
                'colour': 'teal',
                'monthly': False,
                'category': 'Extremes'
            }, 
            'SU': {
                'Label': 'Summer Days',
                'colour': 'red',
                'monthly': False,
                'category': 'Warm'
            },  
            'TN10p': {
                'Label': 'Cold Nights',
                'colour': 'cyan',
                'monthly': False,
                'category': 'Cold'
            }, 
            'TN90p': {
                'Label': 'Warm Nights', 
                'colour': 'indianred',
                'monthly': False,
                'category': 'Warm'
            }, 
            'TR': {
                'Label': 'Tropical Nights',
                'colour': 'crimson',
                'monthly': False,
                'category': 'Warm'
            }, 
            'TX10p': {
                'Label': 'Cold Days',
                'colour': 'mediumturquoise',
                'monthly': False,
                'category': 'Cold'
            },
            'TX90p': {
                'Label': 'Hot Days',
                'colour': 'darkred',
                'monthly': False,
                'category': 'Warm'
            },
            'WSDI': {
                'Label': 'Warm Spell Dur.',
                'LongLabel': 'Warm Spell Duration',
                'colour': 'orangered',
                'monthly': False,
                'category': 'Warm'
            } 
        }
        
    if source == HADEX_1961:
        return allEvents
    
    elif source == HADEX_1981:
        # eventCodes = ['CSDI', 'R95p', 'R95pTOT', 'R99p', 'R99pTOT', 'TN10p', 'TN90p', 'TX10p', 'TX90p', 'WSDI']
        eventCodes = ['CSDI', 'R95pTOT', 'R99pTOT', 'TN10p', 'TN90p', 'TX10p', 'TX90p', 'WSDI']

    elif source == ETCCDI: 
        result = allEvents
        result.pop('ETR', None)
        result.pop('R95pTOT', None)
        result.pop('R99pTOT', None)
        return result 


    
    # If we get here, and don't have an eventCodes list,
    #   we didn't have a valid source
    if eventCodes is None:
        raise ValueError(f"Invalid value for parameter 'source' ({source}).")

    # all good, create a dictionary with all the keys in eventCodes from dictionary allEvents
    result = dict()
    for eventCode in eventCodes:
        result[eventCode] = allEvents[eventCode]
        
    return result 
        

In [7]:
####### Tester #######
# x = defineEvents(HADEX_1961)
# e = x['CDD']
# e.get('LongLabel', e['Label'])
# print()
# print(defineEvents(HADEX_1961))
# print()
# print(defineEvents(HADEX_1981))
# print()
# print(defineEvents(ETCCDI))


In [8]:
def getFilteredEvents(dataSource, categoryFlags):

    # allEvents is a dictionary, key is eventCode, 'category' contains 'Warm', 'Cold', 'Wet', 'Dry', or 'Extremes'
    allEvents = defineEvents(dataSource)
  
    # categoryFlags is a simple dictionary, key is a category, value is a boolean
    selectedCategories = []
    for category, flag in categoryFlags.items():
        if flag:
            selectedCategories.append(category)
    
    result = dict()
    if allEvents is not None:
        for eventCode, eventData in allEvents.items():
            if eventData['category'] in selectedCategories:
                result[eventCode] = eventData # add this event to the result
    return result       
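
The same filtering can also be written more compactly with comprehensions. The sketch below is an alternative formulation on a cut-down, made-up event dictionary; `filter_events` is a hypothetical name and is not used elsewhere in this notebook.

```python
def filter_events(all_events, category_flags):
    """Keep only the events whose category is switched on in category_flags."""
    selected = {category for category, flag in category_flags.items() if flag}
    return {code: data for code, data in all_events.items() if data['category'] in selected}

# cut-down event dictionary, same shape as the one returned by defineEvents()
events = {
    'CDD': {'Label': 'Consecutive Dry Days', 'category': 'Dry'},
    'FD': {'Label': 'Frost Days', 'category': 'Cold'},
    'SU': {'Label': 'Summer Days', 'category': 'Warm'},
}
flags = {'Warm': True, 'Cold': False, 'Wet': False, 'Dry': True, 'Extremes': False}
chosen = filter_events(events, flags)
```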

In [9]:
####### Tester #######
# categoryFlags = {'Warm': True, 'Cold': False, 'Wet': False, 'Dry': False}
# getFilteredEvents(HADEX_1961, categoryFlags)


In [10]:
# Returns a dataframe with a single value column, indexed by 'Site' and 'Year'
# - the value column is float, its column name is the same as the event code, e.g. 'CDD' for 'consecutive dry days'
# - the values are normalised per weather station (divided by the station's maximum)
# - 'Site' is the weather station ID as an integer (the six-digit ID from the file name, without leading zeroes)

# eventData is the entry for eventCode from the eventDictionary

def importOneEvent(eventCode, eventData, source):

    result = None
    folderPath = getForSource('dataFolderPath', source, eventCode=eventCode)
    monthlyData = eventData['monthly']
    columnWidths = getForSource('monthlyColumnWidths', source) if monthlyData else getForSource('annualColumnWidths', source)

    # Data for each weather station are stored in separate fixed width files
    # convention for file name:
    #   siteID with 6 digits (prefixed with zeroes if necessary)
    #   extension is the same as the event code, i.e. '.CDD' for 'consecutive dry days'
    rowsToSkip = list(range(0, 10)) # skip rows 0 to 9
   
    fileNames = None
    try:
        fileNames = os.listdir(folderPath)
    except OSError:
        print(f"ERROR: Could not find the directory for event code '{eventCode}' and data source '{source}'.")
        print(f"PATH:  {folderPath}")
 
    if fileNames is not None:

        # first pass: count valid files so we can show import-percentage
        n = 0
        for fileName in fileNames:
            if getForSource('isValidDataFileName', source, fileName=fileName, eventCode=eventCode):
                n += 1

        # second pass, do the actual import
        counter = 0
        for fileName in fileNames:
            if getForSource('isValidDataFileName', source, fileName=fileName, eventCode=eventCode):
                counter += 1
                print(f'\r   {counter / n:>3.0%}  {fileName:30}', end='')
                # determine site ID from file name
                siteID = getForSource('siteIdFromFileName', source, fileName=fileName)
                # read the file (fixed width)
                filePath = os.path.join(folderPath, fileName)
                df = pd.read_fwf(filePath, widths=columnWidths, skiprows=rowsToSkip)
                
                if monthlyData: # we only need the first and last column
                    df = df.iloc[:, [0, 13]]

                # standardise the name of the column containing the event-code, and capitalise the 'year' column
                df.columns = ['Year', eventCode] # the column is already the code but not capitalised -> ensure 100% standardisation
                # filter out rows that don't contain a value
                df = df[df[eventCode] >= 0].copy()  # empty rows are marked with -99.0 or -99.9
                # filter out years before the period we are interested in
                # df = df[df['Year'] >= FIRST_YEAR]

                # add a column to store the site ID and create a Site:Year combined index
                df['Site'] = int(siteID)
                df.set_index(['Site', 'Year'], inplace=True)
                
                # normalise so that all curves roughly have the same scale
                mx = df[eventCode].max()
                if mx != 0.0:
                    df[eventCode] /= mx

                # if this is the first iteration, the result is the new data frame
                # subsequently, append the new data frame to the existing result
                if result is None:
                    result = df
                else:
                    result = pd.concat([result, df])
    print('')
    return result
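
A minimal sketch of the fixed-width import step for one annual file, using an in-memory sample instead of a real data file. The sample rows are invented and the ten header rows that the real files carry are left out, so `header=None` and `names` are used here where the notebook uses `skiprows`; the column widths `[4, 7]` and the `>= 0` filter for missing values follow the code above. `read_annual` is a hypothetical helper.

```python
import io
import pandas as pd

# Three invented rows: 4 characters for the year, 7 characters for the value
SAMPLE = (
    "1990   12.0\n"
    "1991  -99.9\n"
    "1992   30.0\n"
)

def read_annual(text, event_code):
    """Read fixed-width rows, drop missing values, and normalise by the maximum."""
    df = pd.read_fwf(io.StringIO(text), widths=[4, 7], header=None, names=['Year', event_code])
    df = df[df[event_code] >= 0].copy()   # rows marked -99.9 hold no value
    mx = df[event_code].max()
    if mx != 0.0:
        df[event_code] /= mx
    return df

cdd = read_annual(SAMPLE, 'CDD')
```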


In [11]:
####### Tester #######
# _eventCode = 'SDII'
# _evdict = getFilteredEvents(ETCCDI, {'Extremes': True})
# # print(_evdict)
# x = importOneEvent(_eventCode, _evdict[_eventCode], ETCCDI)
# if x is None:
#     print('No data')
# else:
#     print(x.shape)
#     print(x.head())
#     print(x.tail())
#     print(x.describe())


In [12]:
def importEvents(eventDictionary, df_stations, source):

    result = None
    
    # counters are just for progress feedback
    n = len(eventDictionary)
    i = 0
    for eventCode, eventData in eventDictionary.items():
        i += 1
        print(f"{i}/{n} -- Importing '{eventData['Label']}' data")
 
       
        df = importOneEvent(eventCode, eventData, source)
        if df is None:
            pass
        elif result is None:
            result = df
        else:
            # index in both dfs is Site and Year
            result = pd.merge(result, df, how='outer', left_index=True, right_index=True)

    # reset the index to 'Site' alone, so we can easily join it with the weather station data
    if result is None:
        print('Could not find any data for any of the events.')
    else:
        result = result.reset_index().set_index('Site')
        result = pd.merge(result, df_stations, how='inner', left_index=True, right_index=True)

        return result   
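
The way the per-event frames are combined can be illustrated with two tiny made-up frames. With `how='outer'`, a station-year that exists for only one event is kept, and the other event's column is filled with NaN. `toy_event` and all values below are hypothetical.

```python
import pandas as pd

def toy_event(code, rows):
    """Build a small frame indexed by (Site, Year) with one value column named after the event."""
    df = pd.DataFrame(rows, columns=['Site', 'Year', code])
    return df.set_index(['Site', 'Year'])

cdd = toy_event('CDD', [(1001, 1990, 0.5), (1001, 1991, 1.0)])
fd = toy_event('FD', [(1001, 1991, 0.2), (1002, 1991, 0.8)])

# index in both frames is (Site, Year), so we merge on the index
merged = pd.merge(cdd, fd, how='outer', left_index=True, right_index=True)
```

Only the station-year (1001, 1991) has values for both events; the other two rows each have one NaN.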



## Plotting ##

In [13]:
def plotAveragesByState(fig, df, whichWeatherEvents, thisStateOnly=None):

    N = 10 # years over which the running average is calculated
    DKG = '#303030'
    
    if thisStateOnly is None:
        COLS = 3
        DELTA = 0.6 # offset per curve when plotting them on top of each other
        zoomed = False
        statesToPlot = ['Australia'] + sorted(df.State.unique())
        superTitle = 'Australia - Climate Trends'
    else:
        DELTA = 1 # offset per curve when plotting them on top of each other
        COLS = 1
        zoomed = True
        statesToPlot = [thisStateOnly]
        superTitle = thisStateOnly + ' - Climate Trends'

    nStates = len(statesToPlot)
    ROWS = math.ceil(nStates / COLS)
       
    fig.suptitle(superTitle, x=0.05, ha='left', fontsize=11, fontweight='demibold', c=DKG)
    
    plotIndex = 1    # 1 .. ROWS*COLS
    stateIndex = 0   # 0 .. nStates
    axesToPlaceLegend = None
    axesToGetLegend = None
    axes = None
    for row in range(0, ROWS):
        for col in range(0, COLS):
                      
            # if plotting all states, jump over the top-right quadrant to draw the legend
            if (zoomed): 
                doPlot = True
            elif stateIndex >= nStates: # index runs from 0..nStates-1
                doPlot = False
            elif (row == 0) and (col == COLS-1):  # this is where we place the legend
                doPlot = False
            else:
                doPlot = True

            if doPlot:
                axes = plt.subplot(ROWS, COLS, plotIndex, sharex=axes, sharey=axes)
                axes.yaxis.set_ticks_position('none')

                if not zoomed: 
                    if (stateIndex == 0):   # if not zoomed, first one is always Australia
                        axesToGetLegend = axes
                    if (row == 0) and (col == COLS-2):
                        axesToPlaceLegend = axes
                
                state = statesToPlot[stateIndex]
                if (state == 'Australia'):
                    df_toPlot = df
                else:
                    df_toPlot = df[df['State'] == state]
                if not zoomed:
                    axes.set_title(state, y=0.98, x=0.5, ha='center', fontsize=9, c=DKG)
                    
                # group by year, calculating mean
                # numeric_only skips text columns such as 'State', which cannot be averaged
                df_toPlot = df_toPlot.groupby('Year').mean(numeric_only=True)

                nCurves = len(whichWeatherEvents)
                curveIndex = 0
                offset = DELTA * nCurves
                
                # We will plot one curve for each weather event
                for eventCode, eventData in whichWeatherEvents.items():
                    x = df_toPlot.index  # x-data
                    y = df_toPlot[eventCode] # y-data
                    curveIndex += 1
                    offset -= DELTA  # this corresponds to the value 0 in the current curve
                    
                    # ------------------------------------------------------------------------------------------------
                    # Because they may have very different scales, we normalise each curve to values between 0 and 1.
                    # Then, to avoid them overlapping in the graph, we offset them each by a given y-constant DELTA
                    # Make offset decrease down to zero, so curve order corresponds with legend order
                    # ------------------------------------------------------------------------------------------------
                    
                    # Plot a faint, thin, grey line between each curve, so that any up- or down-trend is better visible
                    if zoomed and offset > 0:
                        axes.axhline(y=offset, xmin=0, xmax=1, lw=0.5, alpha=0.3, c='grey')

                    # minimum and maximum required for normalisation
                    mn = y.min()
                    mx = y.max()
                    
                    # for a few events, i.e. frost days, all values are zero for most states
                    # rather than plot a weird looking straight line, don't plot anything
                    if mn == mx: 
                        if zoomed:
                            plt.text(1920, offset + DELTA * 0.45, 'No ' + eventData['Label'], ha='left', fontsize=20, c='grey')
 
                    # normal case
                    else: 
                        y = (y - mn) / (mx - mn) + offset  # normalisation

                        # Try and plot a regression curve (third-degree polynomial).
                        # Plot faint but double thick line to make it better visible
                        #  without detracting from the actual curve
                        try: 
                            a, b, c, d = np.polyfit(x, y, 3)
                            axes.plot(x, a*x**3 + b*x*x + c*x + d, \
                                      c=eventData['colour'], alpha=0.3, linestyle='--', lw=2.0, \
                                      zorder=curveIndex * 100)
                        except Exception:
                            pass
                        
                        # Do a line plot
                        # use an N-year average to minimise noise and make trends easier to see
                        yRoll = y.rolling(N, win_type='triang').mean()
                        axes.plot(x, yRoll, label=eventData['Label'], \
                                  color = eventData['colour'], zorder=curveIndex)
                        
                    
                # ------------------------------------------------------------------------------------------------
                # Back to iterating over plots (states)  -- beautify
                # ------------------------------------------------------------------------------------------------

                # instead of a legend which would waste a lot of space, directly plot labels 
                if zoomed:
                    _, xMax = axes.get_xlim()
                    _, yMax = axes.get_ylim()
                    yPos = DELTA * nCurves
                    i = 0
                    for _, eventData in whichWeatherEvents.items():
                        if i == 0:
                            yForText = yMax
                        else:
                            yForText = yPos
                        plt.text(xMax * 0.999, yForText - 0.05, \
                                 eventData.get('LongLabel', eventData['Label']), \
                                 ha='right', va='top', fontsize=9, c=DKG)
                        yPos -= DELTA
                        i += 1

                # Hide y ticklabels as they are meaningless:
                # some curves refer to degrees, others to mm, others to counts
                for label in axes.get_yticklabels():
                    label.set_visible(False)

                # Since x-axis is shared, hide x-ticklabels for all but the bottom row
                #  but leave the ticklines
                if row < ROWS - 1:
                    for label in axes.get_xticklabels():
                        label.set_visible(False)   

                # Make frame, ticks and tick labels dark grey
                for label in axes.get_xticklabels():
                    label.set_color(DKG)
                for line in axes.get_xticklines():
                    line.set_color(DKG)
                for pos in ['bottom', 'top', 'right', 'left']:
                    axes.spines[pos].set_color(DKG)
    
                # important: must be inside of    if doPlot:
                stateIndex += 1

            # important: must be outside of    if doPlot:
            plotIndex += 1
 
    if zoomed:
        plt.subplots_adjust(left = 0.05, bottom=0.05, right=0.8, top=0.89) 
    else:
        # we need room for state labels above, and legend to the right
        plt.subplots_adjust(left = 0.05, bottom=0.05, right=0.79, top=0.90)    
 
    # When showing all states show a shared legend
    # If zoomed plot labels for each state
    if not zoomed and axesToPlaceLegend is not None:
        handles, labels = axesToGetLegend.get_legend_handles_labels()
        axesToPlaceLegend.legend(handles, labels, \
                             bbox_to_anchor=(1.7, 1.0), loc='upper center', borderaxespad=0., \
                             fontsize=9, labelcolor=DKG)
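
A small sketch of the cubic trend line, assuming made-up data. Unlike the plotting function above, it centres the years before fitting, which is one common way to keep `np.polyfit` well conditioned when x holds four-digit years; `np.polyval` then evaluates the fitted polynomial.

```python
import numpy as np

years = np.arange(1950, 2011)
xc = years - years.mean()                # centred years, roughly -30 .. 30

# an exact cubic (made-up coefficients), so the fit should reproduce it
y = 2e-5 * xc**3 - 1e-3 * xc + 0.5

coeffs = np.polyfit(xc, y, 3)            # four coefficients, highest power first
trend = np.polyval(coeffs, xc)           # fitted curve at the same x positions
```

On real, noisy data the fitted curve will of course not match the points exactly; it only summarises the overall direction.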

  

In [14]:
def createFigure():
    fig = plt.figure(num=42, clear=True)
    fig.tight_layout()
    return fig


# Navigation and selection of which events to show #

In [15]:
class Controller:

    def __init__(self, figure, df, dataSource, initialFlags, \
                 buttonDimensions=[0.05, 0.9, 0.1, 0.05], \
                 checkboxDimensions=[0.6, 0.5, 0.3, 0.2]):

        # everything we need to plot
        self.figure = figure
        self.df = df
        self.currentState = None
        self.dataSource = dataSource
        # for backbutton
        self.buttonDimensions = buttonDimensions  # set of four floats, xpos, ypos, width, height
        # for checkboxes
        self.categoryFlags = initialFlags
        self.checkboxDimensions = checkboxDimensions  # set of four floats, xpos, ypos, width, height
        
        # references to objects we have to prevent from being garbage collected
        self.cid = None
        self.backButton = None
        self.checkBoxes = None
        self.radioButtons = None
 
        self.connectZoom() 
        self.connectSelectButtons()

       
    def connectZoom(self):
        self.cid = self.figure.canvas.mpl_connect('button_press_event', self.zoom)

        
    def disconnectZoom(self):
        if self.cid is not None:
            self.figure.canvas.mpl_disconnect(self.cid)
        self.cid = None

        
    def connectBackButton(self):
        # for a button to remain responsive, we need to keep a reference to it
        # therefore it is easier to create it here rather than inside plotAveragesByState
        self.backButton = Button(plt.axes(self.buttonDimensions), 'Back')
        self.backButton.on_clicked(self.back)
 

    def disconnectBackButton(self):
        self.backButton = None # allow this button to be garbage collected


    def connectSelectButtons(self):
        labels = list(self.categoryFlags.keys())
        flags = list(self.categoryFlags.values())
#         self.checkBoxes = CheckButtons(plt.axes(self.checkboxDimensions), labels, flags)
#         self.checkBoxes.on_clicked(self.clickedCB)
                                 
        # index of the first active category; fall back to the first entry if none is set
        i = next((k for k, flag in enumerate(flags) if flag), 0)
        self.radioButtons = RadioButtons(plt.axes(self.checkboxDimensions), labels, active=i, activecolor='#505050')
        self.radioButtons.on_clicked(self.clickedRB)
    
    
    def disconnectSelectButtons(self):
        # allow both widget types to be garbage collected
        self.checkBoxes = None
        self.radioButtons = None

        
    def zoom(self, event):
        axes = event.inaxes
        if axes is not None:
            state = axes.get_title() # we must do this before we clear the figure
            if state is not None and state != '': 
                self.disconnectZoom()
                self.currentState = state
                self.plotState()
  

    def back(self, event):
        self.disconnectBackButton()
        self.currentState = None
        self.plotAll()
        
                    
    def plotState(self):
        self.figure.clear()
        
        plotAveragesByState(self.figure, self.df, self.get_selectedEvents(), thisStateOnly=self.currentState)
        self.connectBackButton()
        self.connectSelectButtons()
        self.figure.canvas.draw()
              
                
    def plotAll(self):
        self.figure.clear()
        
        plotAveragesByState(self.figure, self.df, self.get_selectedEvents())
        self.connectZoom()
        self.connectSelectButtons()
        self.figure.canvas.draw()
       
    
    def clickedCB(self, label):
        self.categoryFlags[label] = not self.categoryFlags[label]
        
        if self.currentState is None:
            self.plotAll()
        else:
            self.plotState()

    def clickedRB(self, label):
        for category in self.categoryFlags.keys():
            self.categoryFlags[category] = (category == label)
        
        if self.currentState is None:
            self.plotAll()
        else:
            self.plotState()


    def get_selectedEvents(self):
        return getFilteredEvents(self.dataSource, self.categoryFlags)
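
The radio-button handling in `clickedRB` and `connectSelectButtons` boils down to keeping exactly one category flag switched on and finding its position in the label list. A minimal sketch of that logic in plain Python, independent of Matplotlib; the helper names `select_only` and `active_index` are hypothetical and do not appear in the notebook:

```python
# Hypothetical helpers illustrating the one-active-category logic of the Controller.

def select_only(flags, label):
    """Return a new flag dict in which only `label` is switched on."""
    if label not in flags:
        raise KeyError(label)
    return {category: category == label for category in flags}


def active_index(flags, default=0):
    """Index of the first category that is on, or `default` if none is."""
    for i, is_on in enumerate(flags.values()):
        if is_on:
            return i
    return default


flags = {'Warm': True, 'Cold': False, 'Wet': False, 'Dry': False, 'Extremes': False}
flags = select_only(flags, 'Wet')   # what a click on the 'Wet' radio button does
wet_index = active_index(flags)     # position to pass as `active` to the radio buttons
print(wet_index)
```

Returning a default index when no flag is set avoids passing an out-of-range position to the widget.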
     

## For Convenience - import the original data once and store in one .csv file ##

In [16]:
# Because the original data are stored in thousands of separate files
# import is greatly sped up if we import them once and then store the entire df in a csv file

def makeFullFileName(dataSource, partFileName):
    return partFileName + '-' + dataSource + '.csv'
    

def importAllAndStore(dataSource, fileName='Climate-Trends-Australia'):

    # load the weather stations
    df_stations = loadWeatherStations()
    
    # Load the data for all weather events, stations and years, as available
    # The resulting DataFrame will be indexed by SiteID
    allEvents = defineEvents(dataSource)
    df = importEvents(allEvents, df_stations, dataSource)

    if df is None:
        print('We could not find any data for this combination of weather events.')
    else:
        try:
            df.to_csv(makeFullFileName(dataSource, fileName))
            print('Done')
        except OSError as err:
            # report why the file could not be written instead of hiding the cause
            print('An error occurred writing that file:', err)
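
The import-once-then-cache idea behind `importAllAndStore()` can be sketched with the standard library alone. This is an illustration under assumptions, not the notebook's implementation: `load_or_build` and `slow_import` are hypothetical names, the rows are made-up placeholders rather than real observations, and the standard `csv` module stands in for pandas.

```python
import csv
import os
import tempfile


def make_full_file_name(data_source, part_file_name):
    # Same naming scheme as makeFullFileName in the notebook
    return part_file_name + '-' + data_source + '.csv'


def load_or_build(path, build_rows):
    """Return cached rows from `path` if it exists; otherwise build, store, and return them."""
    if os.path.exists(path):
        with open(path, newline='') as f:
            return list(csv.DictReader(f)), 'cache'
    rows = build_rows()  # the slow step: stands in for reading thousands of files
    with open(path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    return rows, 'built'


def slow_import():
    # Made-up placeholder rows, not real observations
    return [{'SiteID': '1', 'Year': '2000', 'Value': '0.5'}]


with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, make_full_file_name('HADEX_1961', 'Climate-Trends-Australia'))
    first_rows, first_origin = load_or_build(path, slow_import)    # builds and writes the file
    second_rows, second_origin = load_or_build(path, slow_import)  # reads the stored file
```

Note that values read back from a csv file arrive as strings, so a real cache needs to restore column types (pandas infers them in `read_csv`).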
            


In [17]:
### Tester ###

# importAllAndStore(HADEX_1961)

# Main Method # 

In [18]:
# If a non-empty partialFileName is passed,
# the main DataFrame is imported from the csv file rather than from the original data.
# The full filename will be <partialFileName>-<dataSource>.csv

def loadAndPlotClimateTrends(dataSource=ETCCDI, partialFileName=''):

    # load the weather stations
    df_stations = loadWeatherStations()

    # define which weather events we investigate; this will subsequently be controlled by radio buttons
    categoryFlags = {'Warm': True, 'Cold': False, 'Wet': False, 'Dry': False, 'Extremes': False}
    selectedEvents = getFilteredEvents(dataSource, categoryFlags)

    # Load the data for all weather events, stations and years, as available
    # The resulting DataFrame will be indexed by SiteID
    df = None
    if partialFileName == '':
        allEvents = defineEvents(dataSource)
        df = importEvents(allEvents, df_stations, dataSource)
        if df is None:
            print('We could not find any data for this combination of weather events.')

    else:
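        try:
            df = pd.read_csv(makeFullFileName(dataSource, partialFileName))
        except (OSError, ValueError) as err:
            # missing/unreadable file (OSError) or malformed/empty csv (ValueError subclasses)
            print('Could not import csv file:', err)

    # without data there is nothing to plot
    if df is None:
        return

    # create a figure, and draw the initial plot
    fig = createFigure()
    plotAveragesByState(fig, df, selectedEvents)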

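    # create interactive controllers
    controller = Controller(fig, df, dataSource,
                            initialFlags=categoryFlags,
                            buttonDimensions=[0.05, 0.905, 0.1, 0.04],
                            checkboxDimensions=[0.835, 0.698, 0.15, 0.2])

    # return the controller so the caller can keep a reference to it;
    # otherwise it may be garbage collected and the widgets stop responding
    return controller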
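
The comments in `Controller` stress keeping references to widgets so they are not garbage collected. The same applies to the controller itself: Matplotlib's `mpl_connect` documentation notes that the canvas keeps only weak references to bound-method callbacks. A minimal standard-library sketch of that effect follows; `Handler` and `WeakRegistry` are hypothetical stand-ins for the notebook's `Controller` and for a canvas, not Matplotlib classes.

```python
import gc
import weakref


class Handler:
    """Hypothetical stand-in for the notebook's Controller."""

    def on_click(self, label):
        return 'clicked ' + label


class WeakRegistry:
    """Toy registry that, like a canvas, does not keep its handlers alive."""

    def __init__(self):
        self._callbacks = []

    def connect(self, bound_method):
        # store only a weak reference to the bound method
        self._callbacks.append(weakref.WeakMethod(bound_method))

    def fire(self, label):
        results = []
        for ref in self._callbacks:
            callback = ref()  # None once the owning object has been collected
            if callback is not None:
                results.append(callback(label))
        return results


registry = WeakRegistry()


def connect_without_keeping():
    handler = Handler()  # local only, like a controller created inside a function
    registry.connect(handler.on_click)


connect_without_keeping()
gc.collect()                      # make collection explicit on non-refcounting interpreters
lost = registry.fire('Wet')       # the handler is gone, so no callback runs

kept_handler = Handler()          # a module-level reference keeps this one alive
registry.connect(kept_handler.on_click)
kept = registry.fire('Wet')       # only the kept handler responds
```

This is why an object that owns callbacks should be stored somewhere that outlives the function that created it.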


# Entry Point #

In [19]:
# HADEX_1961 = 'HADEX_1961' ### best of the three data sources
# HADEX_1981 = 'HADEX_1981'
# ETCCDI = 'ETCCDI'  

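# keep a reference to whatever is returned so interactive widgets are not garbage collected
controller = loadAndPlotClimateTrends(dataSource=HADEX_1961, partialFileName='Climate-Trends-Australia')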
# loadAndPlotClimateTrends(dataSource=HADEX_1961, partialFileName='')

