In [24]:
import xarray

import sys
sys.path.append(sys.path[0] + '/..')
import utils._indexDefinitions as _index

In [25]:
def applyCriteria(indexDa, critDa):
    """Flag fire-promoting and non-fire-promoting climate-mode events.

    An event is positive where ``indexDa > critDa`` and negative where
    ``indexDa < -critDa``. Events are then grouped by their fire impact using the
    index-name lists ``_index.firePos`` and ``_index.fireNeg`` (presumably: indices
    whose positive phase promotes fire, and indices whose negative phase does).

    Parameters
    ----------
    indexDa : xarray.DataArray
        Index values. After comparison with ``critDa`` the result must be
        selectable by every name on ``critDa.index``.
    critDa : xarray.DataArray
        Event thresholds, with an ``index`` coordinate naming the indices to assess.

    Returns
    -------
    firePosDa, fireNegDa : xarray.DataArray
        Boolean event flags stacked along ``index``:

        - ``firePosDa`` holds the positive events of ``firePos`` indices, followed by
          the negative events of ``fireNeg`` indices.
        - ``fireNegDa`` holds the opposite phases.

        Within each group, indices keep the order they have on ``critDa.index``.
        No summing happens here; counting is left to the caller.
    """

    indexNames = list(critDa.index.values)

    # Preserve critDa's index order. list(set(...)) would order names by string
    # hash, which is randomised per interpreter session and so is not reproducible.
    firePosNames = set(_index.firePos)
    fireNegNames = set(_index.fireNeg)
    firePos = [name for name in indexNames if name in firePosNames]
    fireNeg = [name for name in indexNames if name in fireNegNames]

    # Compare once, then restack one index at a time along a fresh 'index' dimension
    aboveDa = indexDa > critDa
    belowDa = indexDa < -1 * critDa
    posEvents = xarray.concat([aboveDa.sel(index=name) for name in indexNames], 'index')
    negEvents = xarray.concat([belowDa.sel(index=name) for name in indexNames], 'index')

    # Fire promoting: a positive event of a firePos index, or a negative event of a fireNeg index
    firePosDa = xarray.concat(
        (posEvents.sel(index=firePos), negEvents.sel(index=fireNeg)),
        'index')

    # Not fire promoting: a negative event of a firePos index, or a positive event of a fireNeg index
    fireNegDa = xarray.concat(
        (negEvents.sel(index=firePos), posEvents.sel(index=fireNeg)),
        'index')

    return firePosDa, fireNegDa

In [26]:
def compound(fireDa):
    """Append compound-event rows (all three drivers, and each pair) to ``fireDa``.

    Exactly one ENSO, one IOD and one SAM index must be present on the ``index``
    coordinate. The candidates are the names listed in ``_index.enso``,
    ``_index.iod`` and ``_index.sam``. Supporting several indices per driver is
    out of scope for now.

    The returned array is ``fireDa`` with these rows appended along ``index``:

    * ``'count'`` and ``'all3'``, as added by ``compoundThree``. NOTE: that count
      runs over every index present in ``fireDa``, not only the three validated here.
    * One ``'<a>+<b>'`` row per pair of the three indices. It is the product of the
      two index rows, masked to zero where ``'all3'`` is set, so a three-way
      compound is not also counted as a pair.

    Two attributes are added:

    * ``'all3'``: string form of the list of the three index names.
    * ``'pairs'``: list of the pair labels.

    Raises
    ------
    ValueError
        If the number of ENSO, IOD or SAM indices present is not exactly one.
    """
    from itertools import combinations

    # Which indices are we using?
    available = fireDa.index.values
    ensoIndex = set(_index.enso).intersection(available)
    iodIndex = set(_index.iod).intersection(available)
    samIndex = set(_index.sam).intersection(available)

    # For now, limiting ourselves to one index for each driver
    if len(ensoIndex) != 1:
        raise ValueError('number of enso indeces is not 1')
    if len(iodIndex) != 1:
        raise ValueError('number of iod indeces is not 1')
    if len(samIndex) != 1:
        raise ValueError('number of sam indeces is not 1')

    # Plain str (not numpy str_) so pair labels and the 'all3' attribute read cleanly
    indexNames = [str(name) for name in (*ensoIndex, *iodIndex, *samIndex)]

    # Which years have all three?
    fireDa = compoundThree(fireDa)
    fireDa = fireDa.assign_attrs({**fireDa.attrs, 'all3': str(indexNames)})

    # Elementwise "not all three". Compare with 0 rather than using ~: the rows can be
    # integer after compoundThree concatenates integer counts, and ~ would then be a
    # bitwise not rather than a logical one.
    notAll3 = fireDa.sel(index='all3') == 0

    # Each index is paired with those later in the list: (0,1), (0,2), (1,2)
    pairDas = []
    pairs = []
    for first, second in combinations(indexNames, 2):
        label = first + '+' + second
        # A pair compound occurs if both indices have an event, excluding three-way compounds
        pairDa = fireDa.sel(index=first) * fireDa.sel(index=second) * notAll3
        pairDas.append(pairDa.assign_coords(index=label).expand_dims('index'))
        pairs.append(label)

    # Concatenate once rather than growing fireDa inside the loop
    fireDa = xarray.concat([fireDa, *pairDas], 'index')

    # Write the names of pairs into attributes for neatness
    fireDa = fireDa.assign_attrs({**fireDa.attrs, 'pairs': pairs})

    return fireDa

In [27]:
def compoundThree(fireDa):
    """Append ``'count'`` and ``'all3'`` rows to ``fireDa`` along ``index``.

    ``'count'`` is the sum of ``fireDa`` over every entry of ``index``, i.e. how
    many indices flag an event at each remaining coordinate (presumably per model
    and year). ``'all3'`` is ``True`` where that count equals exactly 3.
    """
    # Number of events across all indices, labelled as a new 'count' row
    countRow = (
        fireDa.sum(dim='index')
        .assign_coords(index='count')
        .expand_dims('index')
    )
    withCount = xarray.concat([fireDa, countRow], 'index')

    # Where all three modes occur together, labelled as a new 'all3' row
    allThreeRow = (
        (withCount.sel(index='count') == 3)
        .assign_coords(index='all3')
        .expand_dims('index')
    )
    return xarray.concat([withCount, allThreeRow], 'index')

In [28]:
def overlappingBinSum(da, binSize=30, interval=10):
    """Return the sum of events for each index in overlapping year bins.

    Bins are ``binSize`` years long and start every ``interval`` years, beginning at
    the first year of ``da``. Only complete bins are produced: the bin starting at
    ``firstYear`` covers ``firstYear <= year < firstYear + binSize`` and must not
    extend past the last year in ``da``. For example, years 850-2100 with the
    defaults give 123 bins with midpoint labels 865, 875, ..., 2085.

    Parameters
    ----------
    da : xarray.DataArray
        Event data with a ``year`` dimension/coordinate of integer years.
    binSize : int, optional
        Length of each bin in years (default 30).
    interval : int, optional
        Spacing between consecutive bin starts in years (default 10).

    Returns
    -------
    xarray.DataArray
        For each bin, ``da`` summed over ``year``; values outside the bin are masked
        with ``where`` before summing. The bins are stacked along a new ``year``
        dimension labelled with ``int((firstYear + lastYear) / 2)``, where
        ``lastYear = firstYear + binSize`` is the exclusive end of the bin. The
        attributes of ``da`` are kept, and ``'Bins'`` and ``'Year'`` attributes are
        added.

    Raises
    ------
    ValueError
        If the record is shorter than one bin.
    """
    # min/max so the full record is spanned even if the year coordinate is unsorted
    startYear = int(da.year.min())
    endYear = int(da.year.max())

    # A complete bin needs firstYear + binSize <= endYear + 1 (endYear is inclusive).
    # Integer division so a span that is not a whole number of intervals does not
    # produce a trailing partial bin.
    numberOfBins = (endYear + 1 - startYear - binSize) // interval + 1
    if numberOfBins < 1:
        raise ValueError(
            f'record {startYear}-{endYear} is shorter than one {binSize} year bin'
        )

    binMid = []
    binSum = []
    for iBin in range(numberOfBins):
        # bin start advances by `interval` years each step
        firstYear = startYear + iBin * interval
        # exclusive upper bound of the bin
        lastYear = firstYear + binSize
        # label/midPoint for bin
        binMid.append(int((firstYear + lastYear) / 2))
        # sum over the years inside this bin only
        inBin = (da.year >= firstYear) & (da.year < lastYear)
        binSum.append(da.where(inBin).sum(dim='year'))

    # Stack the per-bin sums along a new 'year' dimension labelled by bin midpoint
    overlapBinDa = xarray.concat(binSum, 'year')
    overlapBinDa = overlapBinDa.assign_coords(year=binMid)
    # add some attributes for reference
    overlapBinDa = overlapBinDa.assign_attrs({
        **da.attrs,
        'Bins': f'Overlapping {binSize} year bins, seperating by {interval} year intervals',
        'Year': 'Midpoint of bin',
    })

    return overlapBinDa