# <font color=green>Tasselled Cap Wetness Epoch Stats Summary Notebook</font>

### Loads surface reflectance data from the data cube, calculates tasselled cap indices, and outputs a netcdf file. Created by Bex Dunn modified by Vanessa Newey

In [1]:
#for writing to error files:
from __future__ import print_function
#get some libraries
import datacube
import xarray as xr
from datacube.storage import masking
#from datacube.storage.masking import mask_to_dict #think this is obsolete
import json
import pandas as pd
import shapely
from shapely.geometry import shape
import numpy as np #need this for pq fuser

#libraries for polygon and polygon mask
import fiona
import shapely.geometry
import rasterio.features
import rasterio
from datacube.utils import geometry
from datacube.helpers import ga_pq_fuser
from datacube.storage.masking import mask_invalid_data

#for writing to netcdf
from datacube.storage.storage import write_dataset_to_netcdf
#dealing with system commands
import sys
import os.path

#####These not needed for raijin::::
import matplotlib.pyplot as plt
from ipywidgets import interact
from IPython.display import display
import ipywidgets as widgets

#suppress warnings thrown when using inequalities in numpy (the threshold values!)
import warnings

def eprint(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)

In [2]:
def load_nbart(sensor,query,bands_of_interest): 
    '''loads nbart data for a sensor, masks using pq, then filters out terrain -999s
    function written 23-08-2017 based on dc v1.5.1'''  
    dataset = []
    product_name = '{}_{}_albers'.format(sensor, 'nbart')
    print('loading {}'.format(product_name))
    ds = dc.load(product=product_name, measurements=bands_of_interest,
                 group_by='solar_day', **query)
    #grab crs defs from loaded ds if ds exists
    if ds:
        crs = ds.crs
        affine = ds.affine
        print('loaded {}'.format(product_name))
        mask_product = '{}_{}_albers'.format(sensor, 'pq')
        sensor_pq = dc.load(product=mask_product, fuse_func=ga_pq_fuser,
                            group_by='solar_day', **query)
        if sensor_pq:
            print('making mask {}'.format(mask_product))
            cloud_free = masking.make_mask(sensor_pq.pixelquality,
                                           cloud_acca='no_cloud',
                                           cloud_shadow_acca = 'no_cloud_shadow',                           
                                           cloud_shadow_fmask = 'no_cloud_shadow',
                                           cloud_fmask='no_cloud',
                                           blue_saturated = False,
                                           green_saturated = False,
                                           red_saturated = False,
                                           nir_saturated = False,
                                           swir1_saturated = False,
                                           swir2_saturated = False,
                                           contiguous=True)
            ds = ds.where(cloud_free)
            ds.attrs['crs'] = crs
            ds.attrs['affine'] = affine
            print('masked {} with {} and filtered terrain'.format(product_name,mask_product))
            # nbarT is correctly used to correct terrain by replacing -999.0 with nan
            ds=ds.where(ds!=-999.0)
        else: 
            print('did not mask {} with {}'.format(product_name,mask_product))
    else:
        print ('did not load {}'.format(product_name)) 

    if len(ds)>0:
        return ds
    else:
        return None

In [3]:
def calc_wetness(sensor_data,sensor):
    '''This function multiplies band data by wetness coefficients to produce a "wetness" band.
    sensor_data is surface reflectance data loaded from the datacube
    sensor = 'ls5, 'ls7' or 'ls8'
    Coefficients are from Crist and Cicone 1984 for ls5 and ls7, and from Baig, Zhang, Shuai & Tong for ls8
    function written 23-08-2017 based on dc v1.5.1'''
    
    wetness_coeff ={'ls5':{'blue':0.151, 'green':0.179, 'red':0.330, 'nir':0.341, 'swir1':-0.711, 'swir2':-0.457},
                    'ls7':{'blue':0.151, 'green':0.179, 'red':0.330, 'nir':0.341, 'swir1':-0.711, 'swir2':-0.457},
                    'ls8':{'blue':0.1511,'green':0.1973,'red':0.3283,'nir':0.3407,'swir1':-0.7117,'swir2':-0.4559}}  
    
    #if there is sensor data for the time period
    if sensor_data is not None: 
        #make a deep copy of the sensor data
        wetness = sensor_data.copy(deep=True)
        #iterate over the spectral bands
        for band_name in sensor_data.data_vars:
            #multiply each band by the wetness transform coefficient to get a band-specific wetness value
            wetness_band = sensor_data[band_name]*wetness_coeff[sensor][band_name]
            #update the existing band data with the wetness data
            wetness.update({band_name:(['time','y','x'],wetness_band)})
        #finally, add a wetness data variable to the array that is the sum of the wetness "bands"    
        wetness['wetness']=wetness.blue+wetness.green+wetness.red+wetness.nir+wetness.swir1+wetness.swir2    
        print('calculated wetness for {}'.format(sensor))
        wetness = wetness.drop(('blue','green','red','nir','swir1','swir2'))
        return wetness
    
    else:
        print('did not calculate wetness for {}'.format(sensor))
        return None    

In [4]:
def calc_wetveg_overthresh(wetness,threshold=-400):
    '''Calculate the wetness values where wetness>threshold. Inputs are wetness array and threshold value, 
    default threshold is -400. Band for wetness>threshold is added to wetness. This is not the count.'''
    if wetness is not None:
        with warnings.catch_warnings():
            #suppress irritating behaviour in xarray.where
            warnings.simplefilter("ignore")
            #water_plus_wetveg is wetness values where wetness>threshold
            wetness['water_plus_wetveg'] = wetness.wetness.where(wetness.wetness>threshold)
            print('thresholded wetness added to array')
            return wetness
    else:
        print('did not calculate wetness overthreshold' )
        return None    

In [5]:
def count_wets(wetness):
    '''count the number of wetness scenes for each pixel,
    count the amount of times that water plus wet veg is above the threshold
    load both into memory (this assumes you are using dask),
    return a dictionary of wet count and threshold count'''
    if wetness is not None:
        #count the number of wetness scenes for each pixel
        wet_count = wetness.wetness.count(dim='time')

        #count the amount of times that water plus wet veg is above the threshold
        threshold_count= wetness.water_plus_wetveg.count(dim='time')
        
        #bring both counts into memory
        wet_count.load()
        threshold_count.load()
        
        #define dictionary of wet count and threshold count
        counts = {'wet count':wet_count, 'threshold count':threshold_count}
        print('counted')
        return counts
    else:
        print('did not count' )
        return None    

In [6]:
def write_your_netcdf(data, dataset_name, filename,crs):
    '''this function turns an xarray dataarray into a dataset so we can write it to netcdf. It adds on a crs definition
    from the original array. data = your xarray dataset, dataset_name is a string describing your variable'''    
    #turn array into dataset so we can write the netcdf
    dataset= data.to_dataset(name=dataset_name)
    #grab our crs attributes to write a spatially-referenced netcdf
    dataset.attrs['crs'] = crs
    #dataset.dataset_name.attrs['crs'] = crs
    try:
        write_dataset_to_netcdf(dataset, filename)
    except RuntimeError as err:
        print("RuntimeError: {0}".format(err))        

## Enter input shapefile, output file and start and end date

In [7]:
#save netcdf outputs to this folder:
netcdf_output_loc ='/g/data/r78/rjd547/groundwater_activities/Analysis/'

#netcdf_output_loc ='/g/data/r78/'

In [8]:
#code to work with a polygon input
shape_file = ('/home/547/rjd547/jupyter_notebooks/GWandDEA_bex_ness/Little_GW_AOI_for_demo/kEEP_ord/KEEP_AOI.shp')
style = {'description_width': 'initial'}
shape_file_text = widgets.Text(value=shape_file,placeholder='update this field',
    description='path to shape file',
    style = {'description_width': 'initial'},                          
    disabled=False,
    layout=widgets.Layout(width='70%'))
def handle_shape_file(sender):
    shape_file=shape_file_text.value
shape_file_text.observe(handle_shape_file)
display(shape_file_text)
# open all the shapes within the shape file
shapes = fiona.open(shape_file)

In [9]:
#i is the index of the shape file we have chosen
i =0 
#copy attributes from shapefile and define shape_name
geom_crs = geometry.CRS(shapes.crs_wkt)
geo = shapes[i]['geometry']
geom = geometry.Geometry(geo, crs=geom_crs)
geom_bs = shapely.geometry.shape(shapes[i]['geometry'])
shape_name = shape_file.split('/')[-1].split('.')[0]+'_'+str(i)


In [10]:
# #check if the file has already been written:
filename = netcdf_output_loc+shape_name+'test1.nc'
infoLabel = widgets.Label(value="Please enter the filename including the path to the output NetCDF file",
    color='Red')
display(infoLabel)

output_file_text = widgets.Text(value=filename,placeholder='update this field',
    description='path to output file',
    style = {'description_width': 'initial'},
    disabled=False,
    layout=widgets.Layout(width='70%'))
def handle_output_file(sender):
    filename=output_file_text.value
    if os.path.isfile(filename):
        infoLabel.value = '{} already exists please change filename'.format(filename)
       # display(infoLabel)
    else:
        infoLabel.value = '{} is the output filename'.format(filename)
output_file_text.on_submit(handle_output_file)
display(output_file_text)


    

In [11]:
#tell the datacube which app to use
dc = datacube.Datacube(app='dc-nbar')

In [12]:
#### DEFINE SPATIOTEMPORAL RANGE AND BANDS OF INTEREST
#Define temporal range

start_of_epoch = '2016-01-01'
end_of_epoch =  '2016-12-31'
#TODO Replace with datepicker widget when ipywidgets devs sort this out
from_date_picker = widgets.Text(value=start_of_epoch,placeholder='update this field',
    description='start date',
    disabled=False)
def handle_from_date(sender):
    start_of_epoch=from_date_picker.value
from_date_picker.observe(handle_from_date)
display(from_date_picker)

#TODO Replace with datepicker widget when ipywidgets devs sort this out
to_date_picker = widgets.Text(value = end_of_epoch,placeholder='update this field',
    description='end date',
    disabled=False)
def handle_to_date(sender):
    end_of_epoch = to_date_picker.value
to_date_picker.observe(handle_to_date)
display(to_date_picker)


In [17]:
GoButton= widgets.Button(description='Load Data and calculate wetness')
def handle_load_calc(b):
    #Define wavelengths/bands of interest, remove this kwarg to retrieve all bands
    bands_of_interest = ['blue',
                         'green',
                         'red',
                         'nir',
                         'swir1',
                         'swir2'
                         ]

    query = {
        'time': (start_of_epoch, end_of_epoch), 'geopolygon': geom
    }


    #this is done separately instead of in a loop because the datasets can be quite large.
    #currently this is a way of memory handling -there is probably a better way of doing it.
    sensor1_nbart=load_nbart('ls5',query,bands_of_interest)
    sensor2_nbart=load_nbart('ls7',query,bands_of_interest)
    sensor3_nbart=load_nbart('ls8',query,bands_of_interest)

    print('Calculate wetness for each timeslice')

    wetness_sensor1_nbart=calc_wetness(sensor1_nbart,'ls5')
    wetness_sensor2_nbart=calc_wetness(sensor2_nbart,'ls7')
    wetness_sensor3_nbart=calc_wetness(sensor3_nbart,'ls8')

    print('Calculate wetness over the threshold for each timeslice (remove values under the threshold)')

    water_plus_wetveg_1 =calc_wetveg_overthresh(wetness_sensor1_nbart)
    water_plus_wetveg_2 =calc_wetveg_overthresh(wetness_sensor2_nbart)
    water_plus_wetveg_3 =calc_wetveg_overthresh(wetness_sensor3_nbart)

    wetness_multi =water_plus_wetveg_3
    nbart_multi = sensor3_nbart
    #wetness_multi = xr.concat([water_plus_wetveg_1,water_plus_wetveg_2,water_plus_wetveg_3], dim='time')
    #nbart_multi = xr.concat([sensor1_nbart,sensor2_nbart,sensor3_nbart], dim ='time')
    # Set the percentage of good data that you'd like to display with pernan variable - 0.9 will return rows that have 90%
    # of valid values
    pernan = 0.8
    water_plus_veg_sum = wetness_multi#.dropna('time',  thresh = int(pernan*wetness_multi.wetness.isel(time=0).size))
    nbart_multi_drop = nbart_multi.where(nbart_multi.time  == water_plus_veg_sum.time)
    print('ran fine no issues')
    return water_plus_veg_sum
    return nbart_multi_drop
display(GoButton)
GoButton.on_click(handle_load_calc)

loading ls5_nbart_albers
did not load ls5_nbart_albers
loading ls7_nbart_albers
loaded ls7_nbart_albers
making mask ls7_pq_albers
masked ls7_nbart_albers with ls7_pq_albers and filtered terrain
loading ls8_nbart_albers
loaded ls8_nbart_albers
making mask ls8_pq_albers
masked ls8_nbart_albers with ls8_pq_albers and filtered terrain
Calculate wetness for each timeslice
did not calculate wetness for ls5
calculated wetness for ls7
calculated wetness for ls8
Calculate wetness over the threshold for each timeslice (remove values under the threshold)
did not calculate wetness overthreshold
thresholded wetness added to array
thresholded wetness added to array
ran fine no issues
loading ls5_nbart_albers
did not load ls5_nbart_albers
loading ls7_nbart_albers
loaded ls7_nbart_albers
making mask ls7_pq_albers
masked ls7_nbart_albers with ls7_pq_albers and filtered terrain
loading ls8_nbart_albers
loaded ls8_nbart_albers
making mask ls8_pq_albers
masked ls8_nbart_albers with ls8_pq_albers and filte

NameError: name 'thres' is not defined

In [24]:
plot_button = widgets.Button(description='Plot outputs')
def handle_plot_save(b):
    def f(n):
        try:
            plt.clf()
            fig, (ax1,ax2,ax3) = plt.subplots(figsize=(12,5),ncols=3)        
            water_plus_veg_sum.wetness.isel(time=n).plot(ax=ax1, cmap='Greens')                
            #print(nbart_multi_drop2.time)
            #nbart_multi_drop2.red.isel(time=n).plot(ax=ax2, cmap='Blues')
            rgb = nbart_multi_drop.isel(time =n).to_array(dim='color').sel(color=['swir1', 'nir', 'green']).transpose('y', 'x', 'color')
            #rgb = nbar_clean.isel(time =time_slice).to_array(dim='color').sel(color=['swir1', 'nir', 'green']).transpose('y', 'x', 'color')
            fake_saturation = 4500
            clipped_visible = rgb.where(rgb<fake_saturation).fillna(fake_saturation)
            max_val = clipped_visible.max(['y', 'x'])
            scaled = (clipped_visible / max_val)

            ax2.imshow(scaled, interpolation = 'nearest',
               extent=[scaled.coords['x'].min(), scaled.coords['x'].max(), 
                       scaled.coords['y'].min(), scaled.coords['y'].max()])

            date_ = nbart_multi_drop.isel(time=n).time.data
            #ax2.set_title(date_.astype('datetime64[D]'))
            ax2.set_title(str(date_))

            fig.text=str(water_plus_veg_sum.wetness.isel(time=n).time)


            plt.tight_layout()
            plt.show()
        except:
            print('timeslice ' + str(water_plus_veg_sum.wetness.isel(time=n).time.time) + ' has some null data')


    timeslices = len(water_plus_veg_sum.time)                                    
    interact(f,n=(0,timeslices-1),value=timeslices-1)
    display()

    print('Count number of wetness scenes and number of times tcw above threshold for each pixel')

    counts_sensor_1_nbart = count_wets(wetness_sensor1_nbart)
    counts_sensor_2_nbart = count_wets(wetness_sensor2_nbart)
    counts_sensor_3_nbart = count_wets(wetness_sensor3_nbart)

    test = counts_sensor_3_nbart['threshold count']/counts_sensor_3_nbart['wet count']
    test.plot(cmap ='gist_earth_r')
    plt.show()

    print('Divide the number of times wetness is seen per pixel by the number of wetness scenes per pixel to get a proportion of time that the pixel is wet')

    counts_list = [counts_sensor_1_nbart, counts_sensor_2_nbart,counts_sensor_3_nbart]
    threshold_list =[]
    wet_list=[]
    for acount in counts_list:
        #test for data existence
        if acount is not None:
            wet_count = acount['wet count']
            threshold = acount['threshold count']
            threshold_list.append(threshold)
            wet_list.append(wet_count)
    #times wetness is over threshold by pixel         
    threshold_allsensors = sum(threshold_list) 
    #number of wetness scenes by pixel
    wet_count_allsensors = sum(wet_list)        
    wet_proportion_allsensors = threshold_allsensors/wet_count_allsensors

    wet_proportion_allsensors.plot(cmap ='gist_earth_r')

    print('successfully ran TCW for '+shape_name+' polygon number '+str(i))

    ## this is to steal the crs from whichever wetness array actually has one

    if wetness_sensor1_nbart is not None:
        crs = wetness_sensor1_nbart.crs
    else:
        if wetness_sensor2_nbart is not None:
            crs = wetness_sensor2_nbart.crs
        else: 
            crs = wetness_sensor3_nbart.crs
    print(crs)     

    if wetness_sensor1_nbart is not None:
        crs = wetness_sensor1_nbart.crs
    else:
        if wetness_sensor2_nbart is not None:
            crs = wetness_sensor2_nbart.crs
        else: 
            crs = wetness_sensor3_nbart.crs
    print(crs)     

    write_your_netcdf(wet_proportion_allsensors,'tcw',filename=filename, crs=crs)
    print('successfully wrote tcw netcdf for '+shape_name+' polygon number '+str(i))
    eprint('successfully wrote tcw netcdf for for '+shape_name+' polygon number '+str(i))

    #overthresh is observations over our wetness threshold count per pixel
    write_your_netcdf(threshold_allsensors,'overthresh',filename=netcdf_output_loc+shape_name+'_overthresh.nc',crs=crs)
    print('successfully wrote overthresh netCDF for '+shape_name+' polygon number '+str(i))
    eprint('successfully wrote overthresh netCDFfor '+shape_name+' polygon number '+str(i))

    #clear_observations is count of wetness scenes at pixel
    write_your_netcdf(wet_count_allsensors,'clearobs',filename=netcdf_output_loc+shape_name+'_clearobs.nc',crs=crs)
    print('successfully wrote clearobs netCDF for '+shape_name+' polygon number '+str(i))
    eprint('successfully wrote clearobs netCDFfor '+shape_name+' polygon number '+str(i))
display(plot_button)
plot_button.on_click(handle_plot_save)

NameError: name 'water_plus_veg_sum' is not defined