In [None]:
##TCI code implementing dask for maximum spatial extent and tracking memory usage

In [1]:
#!/usr/bin/env python
#TCI_shapefile1.py

''' 
This code loads in surface reflectance data from the data cube, calculates 
tasselled cap indices, and outputs a netcdf file.
Created by Bex Dunn 08/05/2017
'''
#get some libraries
import datacube
import xarray as xr
from datacube.storage import masking
from datacube.storage.masking import mask_to_dict
import json
import pandas as pd
import shapely
from shapely.geometry import shape
import numpy as np #need this for pq fuser

#libraries for polygon and polygon mask
import fiona
import shapely.geometry
import rasterio.features
import rasterio
from datacube.utils import geometry
from datacube.storage.masking import mask_valid_data as mask_invalid_data

#for writing to netcdf
from datacube.storage.storage import write_dataset_to_netcdf
#dealing with system commands
import sys

#suppress warnings thrown when using inequalities in numpy (the threshold values!)
import warnings

In [48]:
# Line-by-line memory profiling with memory_profiler's %mprun magic; it needs
# the extension loaded first (the `%load_ext` line below is commented out).
# NOTE(review): `-f` expects a function name and %mprun expects a statement to
# run (e.g. `%mprun -T 'mptest.txt' -f func func()`); both are missing here.
#%load_ext memory_profiler
%mprun -T 'mptest.txt' -f 




*** Profile printout saved to text file 'mptest.txt'. 


In [38]:
# Development leftover: shows the %mprun help text; safe to delete.
%mprun?

In [14]:
# Work with a polygon from a shapefile rather than a lat/long box.
# Pick the shapefile and open all of its shapes (features).
shape_file = '/g/data/r78/rjd547/groundwater_activities/Burdekin_shapefiles/burd_dam/burd_dam_noZ.shp'
shapes = fiona.open(shape_file)

# Index of the polygon to process (the script version of this code took it
# as a command-line argument).
i = 0
print('i is :' + str(i))
# Valid feature indices are 0 .. len(shapes) - 1. The previous `i > len(shapes)`
# test let i == len(shapes) through, which then failed on shapes[i].
if not 0 <= i < len(shapes):
    print('index not in the range for the shapefile: ' + str(i) + ' not in ' + str(len(shapes)))
    # Non-zero status: an out-of-range index is an error, not a successful run.
    sys.exit(1)

# Copy the CRS and geometry of polygon i and derive a name for outputs.
geom_crs = geometry.CRS(shapes.crs_wkt)
geo = shapes[i]['geometry']
geom = geometry.Geometry(geo, crs=geom_crs)
# NOTE(review): geom_bs is not used by the later cells shown here; the data
# load below uses only the polygon's bounding box (spatial_q), not a polygon mask.
geom_bs = shapely.geometry.shape(geo)
shape_name = shape_file.split('/')[-1].split('.')[0] + '_' + str(i)

i is :0


In [15]:
# Bounding box of the selected polygon, in the shapefile's CRS; it defines spatial_q below.
geom.boundingbox

BoundingBox(left=402430.4822999984, bottom=-2240082.9014999997, right=435008.5001000017, top=-2201792.9518999998)

In [16]:
# Spatial part of the datacube query: the polygon's bounding box, expressed in
# the polygon's own CRS (passed as WKT).
spatial_q = dict(
    x=(geom.boundingbox.left, geom.boundingbox.right),
    y=(geom.boundingbox.top, geom.boundingbox.bottom),
    crs=geom.crs.wkt,
)
spatial_q

{'crs': 'PROJCS["GABWRA_Albers_Equal_Area_Conic",GEOGCS["GCS_WGS_1984",DATUM["WGS_1984",SPHEROID["WGS_84",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Albers_Conic_Equal_Area"],PARAMETER["False_Easting",0.0],PARAMETER["False_Northing",0.0],PARAMETER["longitude_of_center",143.0],PARAMETER["Standard_Parallel_1",-21.0],PARAMETER["Standard_Parallel_2",-29.0],PARAMETER["latitude_of_center",0.0],UNIT["Meter",1.0]]',
 'x': (402430.4822999984, 435008.5001000017),
 'y': (-2201792.9518999998, -2240082.9014999997)}

In [17]:
import dask
# NOTE(review): presumably selects dask's single-threaded synchronous scheduler
# (`dask.get` in the old `set_options` API), which makes memory usage easier to
# track; newer dask spells this `dask.config.set(scheduler='synchronous')` —
# confirm against the installed dask version.
dask.set_options(get=dask.get)

<dask.context.set_options at 0x7f5f5b70d8d0>

In [None]:
# Connect to the datacube; 'dc-nbar' is the application name for this session.
dc = datacube.Datacube(app='dc-nbar')

#### SPATIOTEMPORAL RANGE AND BANDS OF INTEREST
# The spatial extent comes from spatial_q (the polygon's bounding box); only the
# time range is set here.
# TODO: replace the fixed end date with a rolling "latest observation".
start_of_epoch, end_of_epoch = '1987-01-01', '2016-12-31'

# Surface-reflectance bands to load (drop the `measurements` kwarg to get all).
bands_of_interest = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2']

# Product-name prefixes for the Landsat 5, 7 and 8 sensors.
sensor1, sensor2, sensor3 = 'ls5', 'ls7', 'ls8'

# dc.load keyword arguments shared by every product: the time range, lazy dask
# loading in chunks of 5 time steps, and the spatial extent. (A
# 'geopolygon': geom entry would be the alternative to the bounding box.)
query = {
    'time': (start_of_epoch, end_of_epoch),
    'dask_chunks': {'time': 5},
}
query.update(spatial_q)

# PQ is grouped by solar day (to avoid idiosyncrasies of N/S scene overlap in
# the PQ algorithm) and fused with pq_fuser, which needs the bit index of the
# 'contiguous' flag; read it from the PQ product definition.
pq_albers_product = dc.index.products.get_by_name(sensor1 + '_pq_albers')
pixelquality_flags = pq_albers_product.measurements['pixelquality']['flags_definition']
valid_bit = pixelquality_flags['contiguous']['bits']
del pixelquality_flags

def pq_fuser(dest, src):
    """Fuse the PQ array ``src`` into ``dest`` in place (used as dc.load's ``fuse_func``).

    The bit at position ``valid_bit`` (module-level; the PQ 'contiguous' flag
    bit) marks pixels that hold data:

    * where ``dest`` lacks that bit, ``dest`` takes ``src``'s value;
    * then, where both ``dest`` and ``src`` have it, ``dest`` becomes the
      bitwise AND of the two, so a flag survives only if both set it.
    """
    contiguous = 1 << valid_bit
    dest_empty = (dest & contiguous) == 0
    np.copyto(dest, src, where=dest_empty)
    # Evaluated on dest *after* the fill above.
    both_present = (dest & src & contiguous) != 0
    np.copyto(dest, dest & src, where=both_present)

# Tasselled cap wetness coefficients per sensor, applied to the bands in the
# order (blue, green, red, nir, swir1, swir2).
wetness_coeff = {}
wetness_coeff['ls5'] = (0.151, 0.179, 0.330, 0.341, -0.711, -0.457)
wetness_coeff['ls7'] = (0.151, 0.179, 0.330, 0.341, -0.711, -0.457)
#wetness_coeff['ls7'] = (0.2626, 0.2141, 0.0926, 0.0656, -0.7629, -0.5388)
wetness_coeff['ls8'] = (0.1511, 0.1973, 0.3283, 0.3407, -0.7117, -0.4559)

# PQ flag values a pixel must have to be kept: no cloud and no cloud shadow
# (both the ACCA and Fmask tests), no saturated band, and contiguous data.
_GOOD_PQ_FLAGS = dict(
    cloud_acca='no_cloud',
    cloud_shadow_acca='no_cloud_shadow',
    cloud_shadow_fmask='no_cloud_shadow',
    cloud_fmask='no_cloud',
    blue_saturated=False,
    green_saturated=False,
    red_saturated=False,
    nir_saturated=False,
    swir1_saturated=False,
    swir2_saturated=False,
    contiguous=True,
)


def _load_nbar_and_pq(sensor):
    """Load NBAR-T reflectance and PQ for one sensor and build its good-data mask.

    Both products are grouped by solar day; PQ for the same day is merged with
    ``pq_fuser``. Uses the module-level ``dc``, ``query``, ``bands_of_interest``,
    ``start_of_epoch`` and ``end_of_epoch``.

    Returns:
        (nbar, pq, cloud_free, good_data): the unmasked NBAR-T dataset, the PQ
        dataset, the flag mask built from ``_GOOD_PQ_FLAGS``, and that mask's
        ``pixelquality`` variable restricted to the epoch.
    """
    nbar = dc.load(product=sensor + '_nbart_albers', group_by='solar_day',
                   measurements=bands_of_interest, **query)
    pq = dc.load(product=sensor + '_pq_albers', group_by='solar_day',
                 fuse_func=pq_fuser, **query)
    cloud_free = masking.make_mask(pq, **_GOOD_PQ_FLAGS)
    good_data = cloud_free.pixelquality.loc[start_of_epoch:end_of_epoch]
    return nbar, pq, cloud_free, good_data


def _apply_pq_mask(nbar, good_data, crs, affine):
    """Return ``nbar`` with every pixel not flagged good set to NaN.

    ``crs`` and ``affine`` are written back into the result's ``attrs``
    (presumably because ``where`` does not carry them over).
    """
    masked = nbar.where(good_data)
    masked.attrs['crs'] = crs
    masked.attrs['affine'] = affine
    return masked


def _tasselled_cap_wetness(nbar, sensor):
    """Tasselled cap wetness: band-weighted sum using ``wetness_coeff[sensor]``."""
    c = wetness_coeff[sensor]
    return ((nbar.blue * c[0]) + (nbar.green * c[1]) +
            (nbar.red * c[2]) + (nbar.nir * c[3]) +
            (nbar.swir1 * c[4]) + (nbar.swir2 * c[5]))


## PQ and index preparation: load, mask and compute wetness for each sensor.
sensor1_nbar, sensor1_pq, s1_cloud_free, s1_good_data = _load_nbar_and_pq(sensor1)
# Geo attributes are taken from the first sensor (before masking) and
# re-attached to every masked dataset; all sensors share the same `query`.
crs = sensor1_nbar.crs
crswkt = sensor1_nbar.crs.wkt
affine = sensor1_nbar.affine
sensor1_nbar = _apply_pq_mask(sensor1_nbar, s1_good_data, crs, affine)

sensor2_nbar, sensor2_pq, s2_cloud_free, s2_good_data = _load_nbar_and_pq(sensor2)
sensor2_nbar = _apply_pq_mask(sensor2_nbar, s2_good_data, crs, affine)

sensor3_nbar, sensor3_pq, s3_cloud_free, s3_good_data = _load_nbar_and_pq(sensor3)
sensor3_nbar = _apply_pq_mask(sensor3_nbar, s3_good_data, crs, affine)

# All masked observations from the three sensors, in chronological order.
nbar_clean = xr.concat([sensor1_nbar, sensor2_nbar, sensor3_nbar], dim='time')
time_sorted = nbar_clean.time.argsort()
nbar_clean = nbar_clean.isel(time=time_sorted)

# Tasselled cap wetness, one result per sensor.
wetness_sensor1_nbar = _tasselled_cap_wetness(sensor1_nbar, sensor1)
wetness_sensor2_nbar = _tasselled_cap_wetness(sensor2_nbar, sensor2)
wetness_sensor3_nbar = _tasselled_cap_wetness(sensor3_nbar, sensor3)

In [None]:
# Tasselled cap wetness (TCW) threshold: observations whose wetness exceeds it
# count as open water or wet vegetation. May need adapting to the
# aridity/location of the study area.
wet_thresh = -400

# sensor1 (ls5): per-pixel number of wetness observations that survived the
# PQ mask (i.e. are not NaN).
wet_count_1 = wetness_sensor1_nbar.count(dim='time')

# Keep only observations wetter than the threshold; the rest become NaN.
# Comparing NaNs can emit "RuntimeWarning: invalid value encountered in
# greater", which is silenced here. NOTE(review): the data is dask-backed
# (dask_chunks in the query), so the comparison probably runs later at compute
# time, outside this filter — confirm whether the warning still appears.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    wat_plus_wetv_1 = wetness_sensor1_nbar.where(
        wetness_sensor1_nbar > wet_thresh)

# Per-pixel number of observations above the threshold.
threshold_count_1 = wat_plus_wetv_1.count(dim='time')

# Proportion of valid observations in which the pixel is water or wet
# vegetation (NaN where the pixel has no valid observations).
new_wet_count_1 = threshold_count_1 / wet_count_1


In [None]:
# Quick check of the per-pixel result's shape (the time dimension has been counted away).
new_wet_count_1.shape

In [None]:
##Do it all again, because concatenating it makes it hard to do the mask filter 'where'
#count the number of wetness scenes for each pixel
wet_count_2 = wetness_sensor2_nbar.count(dim = 'time')

#set wetness data threshold. catch warning about using numpy inequalities.. RuntimeWarning: invalid value encountered in greaterif not reflexive
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    wat_plus_wetv_2=wetness_sensor2_nbar.where(wetness_sensor2_nbar>wet_thresh)

#count the amount of times that water plus wet veg is above the threshold for each pixel
threshold_count_2=wat_plus_wetv_2.count(dim='time')

#divide the number of times wetness is seen by the number of wetness scenes to get a proportion of time that the 
#pixel is wet or wet veg'd:
new_wet_count_2= threshold_count_2/wet_count_2

In [None]:
##Do it all again, because concatenating it makes it hard to do the mask filter 'where'
#count the number of wetness scenes for each pixel
wet_count_3 = wetness_sensor3_nbar.count(dim = 'time')

#set wetness data threshold. catch warning about using numpy inequalities.. RuntimeWarning: invalid value encountered in greaterif not reflexive
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    wat_plus_wetv_3=wetness_sensor3_nbar.where(wetness_sensor3_nbar>wet_thresh)

#count the amount of times that water plus wet veg is above the threshold for each pixel
threshold_count_3=wat_plus_wetv_3.count(dim='time')

#divide the number of times wetness is seen by the number of wetness scenes to get a proportion of time that the 
#pixel is wet or wet veg'd:
new_wet_count_3= threshold_count_3/wet_count_3

In [None]:
# Combine the three sensors by pooling their per-pixel counts: the total number
# of above-threshold observations divided by the total number of valid
# observations. Summing the per-sensor proportions (new_wet_count_1..3) instead
# is not a proportion: it can exceed 1, weights a sensor with a few scenes the
# same as one with hundreds, and is NaN wherever any single sensor has no data.
# The result is NaN only where no sensor has a valid observation.
new_wet_count = ((threshold_count_1 + threshold_count_2 + threshold_count_3)
                 / (wet_count_1 + wet_count_2 + wet_count_3))

In [None]:
# Plot new_wet_count. Converting the lazy (dask-backed) array for display
# presumably triggers the actual computation, loading it into memory so it can
# also be written to netCDF.
import matplotlib as mpl
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
im = ax.imshow(new_wet_count, cmap='gist_earth_r')
fig.colorbar(im, ax=ax, label='new_wet_count')
ax.set_title('Wetness frequency (TCW > {}) - {}'.format(wet_thresh, shape_name))
ax.set_xlabel('x (pixel column)')
ax.set_ylabel('y (pixel row)')
plt.show()

In [None]:
# NOTE(review): this disabled export is stale — `wet_count` is not defined in
# the cells shown here (the combined result is `new_wet_count`), and `ncpath`
# points at a Galilee Basin folder while the input shapefile is from the
# Burdekin. Fix both before uncommenting.
# #Output wet_count to netCDF

# #get the original dataset attributes (crs)
# #set up variable attributes to hold the attributes from sensor1_nbar
# attrs = sensor1_nbar
# #dump the extra data to just get the attributes
# nbar_data = attrs.data_vars.keys()
# for j in nbar_data:
#     #drop band data, retaining just the attributes
#     attrs =attrs.drop(j)
# #set up new variable called wet_vars, and assign attributes to it in a dictionary
# wet_vars = {'wet_count':''}
# wet_count_data = attrs.assign(**wet_vars)
# wet_count_data['wet_count'] = wet_count

# ncpath = '/g/data/r78/rjd547/groundwater_activities/GalileeBasin/Gal_AOI_5k_Raijin/'

# try:
#     write_dataset_to_netcdf(wet_count_data,variable_params={'wet_count': {'zlib':True}}, filename=ncpath+shape_name+'_run01.nc')
# except RuntimeError as err:
#     print("RuntimeError: {0}".format(err))
    
    
# print('successfully ran TCI for '+shape_name+' polygon number '+str(i))


In [42]:
import time

# Memory-profiler demo. NOTE(review): `profile` is not defined in this file; it
# is expected to come from memory_profiler. Per the cell output, %mprun cannot
# profile functions defined in IPython cells, so these belong in a .py file.


@profile
def test1():
    """Build a 10,000-element list, pause for one second, and return the list."""
    items = [1] * 10000
    time.sleep(1)
    return items


@profile
def test2():
    """Build a 100,000-element list, pause for one second, and return the list."""
    items = [1] * 100000
    time.sleep(1)
    return items


if __name__ == "__main__":
    # Run the smaller allocation first, then the larger one.
    test1()
    test2()

ERROR: Could not find file <ipython-input-42-f6c03c205705>
NOTE: %mprun can only be used on functions defined in physical files, and not in the IPython environment.
ERROR: Could not find file <ipython-input-42-f6c03c205705>
NOTE: %mprun can only be used on functions defined in physical files, and not in the IPython environment.
