In [1]:

import rasterio
import xarray as xr
import sys
import datacube
import numpy as np

from multiprocessing import Pool
import dask.array as da
import os
import dask
from dask.distributed import Client


sys.path.append('../lib')
sys.path.append('../../Scripts')
import tsmask_func as tsf
import testpair as cym
import dea_datahandling as ddh
from dea_dask import create_local_dask_cluster

   


In [2]:
def partition_blocks(irow, icol, prow, pcol):
    """Split an ``irow`` x ``icol`` grid into ``prow`` x ``pcol`` rectangular blocks.

    Parameters
    ----------
    irow, icol : int
        Number of rows and columns of the full grid.
    prow, pcol : int
        Number of blocks along the row and column directions.

    Returns
    -------
    py, px : int
        Nominal block height and width (ceiling of ``irow / prow`` and
        ``icol / pcol``, never smaller than 1).
    blocklist : list of [y1, y2, x1, x2]
        ``prow * pcol`` half-open windows in row-major order. Every window
        satisfies ``0 <= y1 <= y2 <= irow`` and ``0 <= x1 <= x2 <= icol``, and
        together they tile the grid exactly once. Windows may be empty when
        there are more blocks than rows/columns.
    """
    # Ceiling division: the smallest block size whose `prow` (resp. `pcol`)
    # repetitions cover the axis. `n // parts + 1` overshoots when the blocks
    # over-cover the axis and produced out-of-range or inverted windows.
    py = max(1, -(-irow // prow))
    px = max(1, -(-icol // pcol))

    blocklist = []
    for i in range(prow):
        # Clamp both edges so no window ever reaches beyond the grid.
        y1 = min(i * py, irow)
        y2 = irow if i == prow - 1 else min((i + 1) * py, irow)

        for j in range(pcol):
            x1 = min(j * px, icol)
            x2 = icol if j == pcol - 1 else min((j + 1) * px, icol)

            blocklist.append([y1, y2, x1, x2])

    return py, px, blocklist
        

In [4]:
def bs_tsmask(blue, green, red, nir, swir1, swir2):
    """Forward the six band arrays to ``cym.tsmask_lastdim`` and return its result.

    The wrapper exists so that ``gen_tsmask`` has a plain function to pass to
    ``xr.apply_ufunc``.
    """
    return cym.tsmask_lastdim(blue, green, red, nir, swir1, swir2)

def gen_tsmask(chblue, chgreen, chred, chnir, chswir1, chswir2):
    """Run ``bs_tsmask`` over six band arrays through ``xr.apply_ufunc``.

    Every input contributes its ``time`` dimension as the core dimension. The
    result gets a new core dimension ``indices`` of length 4 and is declared as
    float32. Dask-backed inputs are processed with ``dask='parallelized'``.
    """
    bands = (chblue, chgreen, chred, chnir, chswir1, chswir2)
    # One ["time"] core-dimension entry per input band.
    time_core_dims = [["time"] for _ in bands]

    return xr.apply_ufunc(
        bs_tsmask,
        *bands,
        input_core_dims=time_core_dims,
        output_core_dims=[["indices"]],
        dask="parallelized",
        dask_gufunc_kwargs={"output_sizes": {"indices": 4}},
        output_dtypes=[np.float32],
    )



In [5]:
# Folder that receives the output GeoTIFFs, and the label used in their file names.
dirc = '/home/jovyan/nmask_testdata/cbr_dask_run/indices'
loc_str = 'canberra'

# Number of dask worker processes to start.
ncpu = 14

# Bounding box of the target area, passed to the loader together with `crs`.
# NOTE(review): confirm this extent really corresponds to `loc_str`.
y1, y2 = -32.53284301899998, -33.52310232399998
x1, x2 = 121.934694247, 123.105109264
crs = 'EPSG:4326'

# Time range handed to the loader.
start_of_epoch = '2017-01-01'
end_of_epoch = '2020-12-31'

# Output CRS: the 'UTM' placeholder is replaced by whatever tsf.utm_code
# returns for the longitude bounds.
out_crs = 'UTM'
if out_crs == 'UTM':
    out_crs = tsf.utm_code(x1, x2)

# Local dask cluster: `ncpu` worker processes with two threads each.
client = Client(n_workers=ncpu, threads_per_worker=2, processes=True)
client



0,1
Client  Scheduler: tcp://127.0.0.1:39371  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 14  Cores: 28  Memory: 128.85 GB


In [6]:
 
# Connect to the datacube.
dc = datacube.Datacube(app='load_clearsentinel')

# First pass over the data, used only to read the array dimensions.
# The dict is presumably the dask chunk spec (one time step per chunk) -- confirm
# against tsf.load_s2_nbart_dask.
s2_ds = tsf.load_s2_nbart_dask(
    dc, y1, y2, x1, x2, start_of_epoch, end_of_epoch,
    {"time": 1}, crs, out_crs,
)

# Number of rows (y), columns (x) and time steps.
irow, icol, tn = (s2_ds[dim].size for dim in ('y', 'x', 'time'))

s2_ds.close()

print(tn, irow, icol)


  username=username, password=password,


Finding datasets
    s2a_ard_granule
    s2b_ard_granule
Returning 485 time steps as a dask array
485 5515 5497


In [7]:
# Split the scene into a 3 x 3 grid of blocks; chy / chx are the nominal block
# height and width returned by partition_blocks.
chy, chx, blocklist = partition_blocks(irow, icol, 3, 3)

# Load the data again, this time passing the block size for the y and x
# dimensions alongside the one-step time entry.
tg_ds = tsf.load_s2_nbart_dask(
    dc, y1, y2, x1, x2, start_of_epoch, end_of_epoch,
    {"time": 1, "y": chy, "x": chx}, crs, out_crs,
)


Finding datasets
    s2a_ard_granule
    s2b_ard_granule
Returning 485 time steps as a dask array


In [8]:
# Divide the whole target area into a set of smaller blocks so that the data
# volume required for each iteration can be accommodated by the system memory.

# Bands in the positional order expected by gen_tsmask.
band_names = ("blue", "green", "red", "nir", "swir1", "swir2")

# Spatial chunk size used when re-chunking each block for the per-pixel
# time-series computation (loop-invariant, so defined once).
rchy = 64
rchx = 64

# Output array: one 4-element vector per pixel.
indices = np.zeros((irow, icol, 4), dtype=np.float32)

for block in blocklist:
    # Use dedicated names for the pixel window. Unpacking into y1, y2, x1, x2
    # would overwrite the bounding-box variables of the same name that the
    # load_s2_nbart_dask calls use, so re-running those cells afterwards would
    # silently pass pixel indices as coordinates.
    row0, row1, col0, col1 = block
    print("loading data for ", block)

    # Persist the six bands for this window.
    persisted = [
        tg_ds[band][:, row0:row1, col0:col1].persist() for band in band_names
    ]

    # Re-chunk so that every chunk spans the full time axis ("time": -1) over a
    # small rchy x rchx spatial tile.
    rechunked = [
        band.chunk({"time": -1, "y": rchy, "x": rchx}) for band in persisted
    ]

    am = gen_tsmask(*rechunked)

    indices[row0:row1, col0:col1, :] = am.compute()
    print("Finish computing indices for ", block)
    
    
 

loading data for  [0, 1839, 0, 1833]
Finish computing indices for  [0, 1839, 0, 1833]
loading data for  [0, 1839, 1833, 3666]
Finish computing indices for  [0, 1839, 1833, 3666]
loading data for  [0, 1839, 3666, 5497]
Finish computing indices for  [0, 1839, 3666, 5497]
loading data for  [1839, 3678, 0, 1833]
Finish computing indices for  [1839, 3678, 0, 1833]
loading data for  [1839, 3678, 1833, 3666]
Finish computing indices for  [1839, 3678, 1833, 3666]
loading data for  [1839, 3678, 3666, 5497]
Finish computing indices for  [1839, 3678, 3666, 5497]
loading data for  [3678, 5515, 0, 1833]
Finish computing indices for  [3678, 5515, 0, 1833]
loading data for  [3678, 5515, 1833, 3666]
Finish computing indices for  [3678, 5515, 1833, 3666]
loading data for  [3678, 5515, 3666, 5497]
Finish computing indices for  [3678, 5515, 3666, 5497]


In [9]:
# Print the index array followed by its shape, one per line.
print(indices, indices.shape, sep='\n')

[[[ 0.10530532 -0.5940883   0.14705923  0.45274743]
  [ 0.104082   -0.59441453  0.14665574  0.44977713]
  [ 0.10713799 -0.60134995  0.15053304  0.45919058]
  ...
  [ 0.11404356 -0.5907778   0.17154269  0.6007948 ]
  [ 0.10746598 -0.6051685   0.17719258  0.59506375]
  [ 0.10711598 -0.60066396  0.17867848  0.5899254 ]]

 [[ 0.10842401 -0.57209826  0.13972102  0.47483483]
  [ 0.10418689 -0.5970373   0.14236468  0.4641497 ]
  [ 0.10801766 -0.5889217   0.14521173  0.48964843]
  ...
  [ 0.10650909 -0.6031923   0.18129069  0.60714114]
  [ 0.10863233 -0.5896865   0.17715238  0.59213847]
  [ 0.10436713 -0.5806248   0.17231117  0.5821918 ]]

 [[ 0.10516712 -0.5921474   0.14890698  0.45380196]
  [ 0.1034426  -0.6032516   0.14505716  0.46297264]
  [ 0.10569072 -0.59957224  0.1435893   0.4772514 ]
  ...
  [ 0.10040749 -0.612911    0.18792157  0.59667003]
  [ 0.10573873 -0.59501356  0.1834635   0.5940693 ]
  [ 0.09932663 -0.61523765  0.18416291  0.5955776 ]]

 ...

 [[ 0.19492462 -0.49236044  0.2462

In [10]:
# Georeferencing taken from the loaded dataset's geobox: the affine transform
# in GDAL form and the CRS as WKT.
geotrans = tg_ds.geobox.transform.to_gdal()
prj = tg_ds.geobox.crs.wkt

# File-name label for each layer along the last axis of `indices`, in order.
indices_list = ['s6m', 'mndwi', 'msavi', 'whi']

print("Begin writing long term mean of indices files")

# One GeoTIFF per layer: <dirc>/<loc_str>_<index>_<start>_<end>.tif
for i, indname in enumerate(indices_list):
    stem = '_'.join((loc_str, indname, start_of_epoch, end_of_epoch))
    fname = dirc + '/' + stem + '.tif'
    ddh.array_to_geotiff(fname, indices[..., i], geotrans, prj)

print("Finish writing long term mean of indices files")

tg_ds.close()

Begin writing long term mean of indices files
Finish writing long term mean of indices files
