In [None]:
import numpy as np
import xarray as xr
import datacube 

import dask
from dask.distributed import Client

import sys
sys.path.append('src')
import DEADataHandling
import query_from_shp

### Info for parallel processing with Dask
1. If reading netCDF files, make sure each worker has one thread.
2. `memory_limit` is per worker, not per cluster of workers.
3. When launching multiple workers on the same node (needed when reading netCDF files), you have to supply a memory limit; otherwise every worker will assume it has all the memory.

In [None]:
# Create the Dask client with 4 workers, one thread each, and a 5GB memory limit
# (the limit applies per worker, as described in the notes above).
client = Client(n_workers=4, threads_per_worker=1, memory_limit='5GB')
# Bare expression so the notebook renders the client summary.
client

### User Inputs

In [None]:
#If not using a polygon then enter your AOI coords
#below:
lat, lon = -34.294, 146.037
# Offset subtracted from / added to lat and lon in the (commented-out) lat/lon query
latLon_adjust = 0.10

# Start and end of the time range passed to query_from_shp
start = '2019-03-01'
end = '2019-05-31'

# Shapefile handed to query_from_shp to define the area of interest
# NOTE(review): absolute path on a specific filesystem - confirm it exists where this notebook runs
shp_fpath = "/g/data/r78/cb3058/dea-notebooks/dcStats/data/spatial/griffith_MSAVI_test.shp"
# Size of the dask chunks along both x and y when loading the data
chunk_size = 800

### Load data

In [None]:
# Build the query from the shapefile path and the start/end dates;
# it is passed as `query=` to load_clearlandsat in the next cell.
query = query_from_shp.query_from_shp(shp_fpath, start, end)
# Bare expression so the notebook displays the query.
query

In [None]:
# Alternative query built from the lat/lon point and latLon_adjust (disabled):
# query = {'lon': (lon - latLon_adjust, lon + latLon_adjust),
#          'lat': (lat - latLon_adjust, lat + latLon_adjust),
#         'time': ('2014-12-01', '2019-05-31'), }

# NOTE(review): identical to the call in the previous cell, so the query is rebuilt here.
query = query_from_shp.query_from_shp(shp_fpath, start, end)
dc = datacube.Datacube(app='load_clearlandsat')
# Load the nir and red bands for the ls5/ls7/ls8 sensors as dask arrays chunked
# chunk_size x chunk_size in x/y. masked_prop and mask_pixel_quality are forwarded
# to load_clearlandsat; their exact semantics are not visible here - see DEADataHandling.
ds = DEADataHandling.load_clearlandsat(dc=dc, query=query, sensors=['ls5','ls7','ls8'], bands_of_interest=['nir', 'red'],
                                       dask_chunks = {'x': chunk_size, 'y': chunk_size}, masked_prop=0.25, mask_pixel_quality=True)

In [None]:
def msavi_func(nir, red):
    """Evaluate the MSAVI formula element-wise.

    Computes ``(2*nir + 1 - sqrt((2*nir + 1)**2 - 8*(nir - red))) / 2``.

    Parameters
    ----------
    nir, red : scalar or array-like
        Near-infrared and red values; anything numpy arithmetic and
        ``np.sqrt`` accept.

    Returns
    -------
    Same kind of object as the arithmetic on the inputs produces.
    """
    linear_term = 2*nir+1
    under_root = linear_term**2 - 8*(nir-red)
    return (linear_term - np.sqrt(under_root))/2

def msavi_ufunc(ds):
    """Apply ``msavi_func`` to the ``nir`` and ``red`` variables of ``ds``.

    The call goes through ``xr.apply_ufunc`` with ``dask='parallelized'``
    and a declared float output dtype.

    Parameters
    ----------
    ds : object exposing ``nir`` and ``red`` attributes
        In this notebook, the dataset returned by the load step.

    Returns
    -------
    The result of ``xr.apply_ufunc`` for the two bands.
    """
    ufunc_options = dict(dask='parallelized', output_dtypes=[float])
    return xr.apply_ufunc(msavi_func, ds.nir, ds.red, **ufunc_options)

def compute_seasonal(data):
    """Compute seasonal MSAVI anomalies.

    Steps:
      1. MSAVI for every observation via ``msavi_ufunc``.
      2. Monthly means (``resample(time='M')``).
      3. Seasonal means: the monthly series resampled to quarters starting
         in December (``'QS-DEC'``).
      4. Seasonal climatology: the mean of those seasonal means for each
         season label.
      5. Anomalies: seasonal means minus the climatology of their season.

    The climatology is derived from the seasonal means, not from the
    monthly values. Averaging the monthly values directly gives a different
    number whenever seasons contain unequal counts of valid months, and the
    anomalies of a season then do not average to zero.

    Parameters
    ----------
    data : dataset with ``nir`` and ``red`` variables and a ``time`` dimension.

    Returns
    -------
    Seasonal anomalies, indexed by the start of each season.
    """
    #calculate the MSAVI
    msavi = msavi_ufunc(data)
    #resample to monthly means
    msavi_monthly = msavi.resample(time='M').mean('time')
    #resample monthly msavi to seasonal means
    msavi_seasonalMeans = msavi_monthly.resample(time='QS-DEC').mean('time')
    #calculate seasonal climatology from the same seasonal means the anomalies use
    msavi_seasonalClimatology = msavi_seasonalMeans.groupby('time.season').mean('time')
    #calculate anomalies
    msavi_anomalies = msavi_seasonalMeans.groupby('time.season') - msavi_seasonalClimatology
    return msavi_anomalies

# Compute the seasonal anomalies for the loaded dataset and write them to netCDF.
# NOTE(review): assumes the results/ directory already exists - confirm before running.
a = compute_seasonal(ds)
a.to_netcdf('results/test.nc')  

In [None]:
#reopen without dask chunks and it'll plot quickly
b = xr.open_dataarray('results/test.nc')
b

In [None]:
# One map per time step of the reloaded anomalies, five panels per row,
# on a diverging colour map with the range fixed to [-0.5, 0.5].
b.plot(x='x',y='y', col='time', col_wrap=5, vmin=-0.5,vmax=0.5, cmap='RdBu', figsize=(15,15))

In [None]:
# Average the anomalies over x and y and plot the resulting series,
# with the y-axis limited to [-0.25, 0.25].
b.mean(['x', 'y']).plot(figsize=(12,5), ylim=(-0.25, 0.25))

### Compute monthly MSAVI anomalies

In [None]:
def compute_monthly(data):
    """Compute monthly MSAVI anomalies and the monthly climatology.

    MSAVI is evaluated from ``data.nir`` and ``data.red``, wrapped in a
    DataArray carrying the coordinates of ``data`` and its ``crs`` as an
    attribute, then averaged to monthly means. The climatology is the mean
    of those monthly means for each calendar month; the anomalies are the
    monthly means minus the climatology of their calendar month.

    Parameters
    ----------
    data : dataset with ``nir`` and ``red`` variables, a ``time`` dimension
        and a ``crs`` attribute.

    Returns
    -------
    (anomalies, climatology)
    """
    #MSAVI for every observation
    msavi_values = (2*data.nir+1-np.sqrt((2*data.nir+1)**2 - 8*(data.nir-data.red)))/2
    msavi_raw = xr.DataArray(data=msavi_values,
                             coords=data.coords,
                             attrs=dict(crs=data.crs))

    #monthly means
    msavi_monthly = msavi_raw.resample(time='M').mean('time')

    #mean of each calendar month over all years
    monthly_climatology = msavi_monthly.groupby('time.month').mean('time')
    #departure of every month from the climatology of its calendar month
    monthly_anomalies = msavi_monthly.groupby('time.month') - monthly_climatology

    return monthly_anomalies, monthly_climatology

# x: monthly anomalies, y: monthly climatology (the order compute_monthly returns them in)
x, y = compute_monthly(ds)

In [None]:
# One map per calendar month of the climatology, four panels per row,
# with the colour range fixed to [0, 1].
y.plot(x='x', y='y', col='month',col_wrap=4, vmin=0.0,vmax=1.0, figsize=(15,10), cmap='plasma')

In [None]:
# Average the monthly anomalies over x and y and plot the resulting series,
# with the y-axis limited to [-0.25, 0.25].
x.mean(['x', 'y']).plot(figsize=(12,5), ylim=(-0.25, 0.25))

In [None]:
# 2018 anomalies
# x.isel(time=range(-17,-5)).plot(x='x',y='y',col='time',col_wrap=4,figsize=(13,10), vmin=-0.5, vmax=0.5, cmap='BrBG')

In [None]:
# 2010 anomalies
# x.isel(time=range(265,277)).plot(x='x',y='y',col='time',col_wrap=4,figsize=(13,10), vmin=-0.5, vmax=0.5, cmap='BrBG')


### Seasonal MSAVI anomalies

Not sure this is working as expected: the mean of the anomalies does not appear to be zero, as you would expect it to be.

In [None]:
 def compute_seasonal(data):
     """Compute seasonal MSAVI anomalies and the seasonal climatology.

     MSAVI is evaluated from ``data.nir`` and ``data.red``, wrapped in a
     DataArray carrying the coordinates of ``data`` and its ``crs`` as an
     attribute, averaged to monthly means, and then resampled to quarters
     starting in December (``'QS-DEC'``) to obtain seasonal means.

     The climatology is the mean of those seasonal means for each season
     label, so it is the same statistic the anomalies are measured against.
     Averaging the monthly values directly gives a different number whenever
     seasons contain unequal counts of valid months, and the anomalies of a
     season then do not average to zero.

     Parameters
     ----------
     data : dataset with ``nir`` and ``red`` variables, a ``time`` dimension
         and a ``crs`` attribute.

     Returns
     -------
     (anomalies, climatology)
     """
     #calculate the MSAVI
     msavi = xr.DataArray(data = (2*data.nir+1-np.sqrt((2*data.nir+1)**2 - 8*(data.nir-data.red)))/2,
               coords=data.coords,attrs=dict(crs=data.crs))

     #resample to monthly means
     msavi = msavi.resample(time='M').mean('time')

     #resample monthly msavi to seasonal means
     msavi_seasonalMeans = msavi.resample(time='QS-DEC').mean('time')

     #calculate seasonal climatology from the same seasonal means the anomalies use
     msavi_seasonalClimatology = msavi_seasonalMeans.groupby('time.season').mean('time')

     #calculate anomalies
     msavi_anomalies = msavi_seasonalMeans.groupby('time.season') - msavi_seasonalClimatology

     return msavi_anomalies, msavi_seasonalClimatology

# a: seasonal anomalies, b: seasonal climatology.
# NOTE: this rebinds `a` and `b`, which earlier cells used for the saved/reloaded anomalies.
a,b=compute_seasonal(ds)

In [None]:
# One map per season of the climatology, two panels per row,
# with the colour range fixed to [0, 1].
b.plot(x='x',y='y', col='season',col_wrap=2, vmin=0,vmax=1.0, figsize=(10,7), cmap='plasma')

In [None]:
# Average the seasonal anomalies over x and y and plot the resulting series,
# with the y-axis limited to [-0.25, 0.25].
a.mean(['x', 'y']).plot(figsize=(12,5), ylim=(-0.25, 0.25))

In [None]:
# 2018 anomalies
# a.isel(time=range(-6,-2)).plot(x='x',y='y',col='time',col_wrap=2,figsize=(13,10), vmin=-0.5, vmax=0.5, cmap='BrBG')

In [None]:
# 2010 anomalies
# a.isel(time=range(-38,-34)).plot(x='x',y='y',col='time',col_wrap=2,figsize=(13,10), vmin=-0.5, vmax=0.5, cmap='BrBG')

### CUTS

In [None]:
def msavi_func(nir, red):
    """Return ``(2*nir + 1 - sqrt((2*nir + 1)**2 - 8*(nir - red))) / 2``.

    Works element-wise on scalars or arrays accepted by numpy arithmetic.
    """
    two_nir_plus_one = 2*nir+1
    root = np.sqrt(two_nir_plus_one**2 - 8*(nir-red))
    numerator = two_nir_plus_one - root
    return numerator/2

def msavi_ufunc(ds):
    """Run ``msavi_func`` over ``ds.nir`` and ``ds.red`` with ``xr.apply_ufunc``.

    Uses ``dask='parallelized'`` and declares a float output dtype.
    """
    nir_band = ds.nir
    red_band = ds.red
    return xr.apply_ufunc(msavi_func, nir_band, red_band,
                          dask='parallelized', output_dtypes=[float])

# Evaluate MSAVI and bring the result into memory with .compute().
# NOTE(review): `ds_mo` is not defined anywhere in the visible notebook, so this cell
# raises NameError on a fresh kernel - confirm where it was meant to come from.
msavi = msavi_ufunc(ds_mo).compute()

In [None]:
# Mean over time for each calendar month of the `msavi` array from the previous cell
climatology = msavi.groupby('time.month').mean('time')

# Each time step minus the climatology of its calendar month
anomalies = msavi.groupby('time.month') - climatology

In [None]:
    #Functions for weighting months to help with seasonal climatology
    # Days per month for each calendar name. Index 0 is a placeholder so the lists
    # can be indexed directly by month number (get_dpm looks up cal_days[month]).
    # Only February differs between the non-leap and all-leap calendars;
    # '360_day' uses 30 days for every month.
    dpm = {'noleap': [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31],
           '365_day': [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31],
           'standard': [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31],
           'gregorian': [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31],
           'proleptic_gregorian': [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31],
           'all_leap': [0, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31],
           '366_day': [0, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31],
           '360_day': [0, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30]}

    def leap_year(year, calendar='standard'):
        """Determine if year is a leap year.

        Rules by calendar name:
          * 'julian': every year divisible by 4.
          * 'proleptic_gregorian': divisible by 4, except century years
            not divisible by 400.
          * 'standard' / 'gregorian': the Julian rule before 1583 and the
            Gregorian century exception from 1583 onwards.
          * any other calendar name: always False (fixed-length calendars
            carry their February length in the days-per-month table).

        Parameters
        ----------
        year : int
        calendar : str, default 'standard'

        Returns
        -------
        bool
        """
        if calendar not in ('standard', 'gregorian', 'proleptic_gregorian', 'julian'):
            return False
        if year % 4 != 0:
            return False
        if calendar == 'julian':
            return True

        # Century years are leap years only when divisible by 400
        century_exception = (year % 100 == 0) and (year % 400 != 0)
        if calendar == 'proleptic_gregorian':
            return not century_exception

        # Mixed calendar: the century exception only exists once the Gregorian
        # rules are in force. The original tested `year < 1583`, which applied
        # the exception to the Julian era and made 1900 and 2100 leap years.
        return not (century_exception and year >= 1583)

    def get_dpm(time, calendar='standard'):
        """
        Return an array of days per month, one entry per element of `time`.

        Parameters
        ----------
        time : sized object exposing parallel ``month`` and ``year`` sequences
            (for example a pandas DatetimeIndex).
        calendar : str, default 'standard'
            Key into the ``dpm`` days-per-month table.

        Returns
        -------
        numpy integer array of length ``len(time)``.

        Raises
        ------
        KeyError
            If `calendar` has no entry in ``dpm``.
        """
        # The builtin int replaces the `np.int` alias, which NumPy has removed
        month_length = np.zeros(len(time), dtype=int)

        cal_days = dpm[calendar]

        for i, (month, year) in enumerate(zip(time.month, time.year)):
            month_length[i] = cal_days[month]
            # Only February gains a day in a leap year; the original added
            # the extra day to every month of a leap year
            if month == 2 and leap_year(year, calendar=calendar):
                month_length[i] += 1
        return month_length

    def season_mean(ds, calendar='standard'):
        """Day-weighted mean of monthly data for each season.

        Every time step is weighted by the number of days in its month,
        normalised so that the weights within a season sum to one. Missing
        values are excluded and the remaining weights renormalised, so a
        season with gaps is averaged over its valid months and not pulled
        towards zero. Where a season has no valid data the result is NaN.

        Parameters
        ----------
        ds : xarray object with a ``time`` coordinate of monthly timestamps.
        calendar : str, default 'standard'
            Calendar name forwarded to ``get_dpm``.

        Returns
        -------
        Object like `ds` with ``time`` replaced by a ``season`` dimension.
        """
        # Make a DataArray with the number of days in each month, size = len(time)
        month_length = xr.DataArray(get_dpm(ds.time.to_index(), calendar=calendar),
                                    coords=[ds.time], name='month_length')
        # Calculate the weights by grouping by 'time.season'
        weights = month_length.groupby('time.season') / month_length.groupby('time.season').sum()

        # Test that the sum of the weights for each season is 1.0; comparing against
        # a scalar keeps the check valid when fewer than four seasons are present
        np.testing.assert_allclose(weights.groupby('time.season').sum().values, 1.0)

        # Weighted sum; NaNs are skipped by the reduction
        weighted_sum = (ds * weights).groupby('time.season').sum(dim='time')
        # Total weight of the valid (non-null) samples in each season
        valid_weight = (ds.notnull() * weights).groupby('time.season').sum(dim='time')

        # Calculate the weighted average over the valid samples only
        return weighted_sum / valid_weight

    #calculate the seasonal climatology
#     msavi_seasonalClimatology = season_mean(msavi)