# Tabulation of Fractional Cover data within shapefile polygons

**What does this notebook do?**

This notebook is a pilot collaboration between Geoscience Australia and the Australian Bureau of Statistics. It uses shapefile polygon boundaries to load the Fractional Cover dataset, computes zonal statistics within each polygon, and tabulates the results.

**Requirements**

Run the following commands from the command line before launching Jupyter Notebook from the same terminal, so that the required libraries and paths are set:

`module use /g/data/v10/public/modules/modulefiles`

`module load dea`


**Background**

Data from the Landsat 5, 7 and 8 satellite missions are accessible through Digital Earth Australia (DEA). The code snippets in this notebook let you retrieve and plot the Fractional Cover (FC25) data stored in DEA.


**How to use this notebook**

A basic understanding of any programming language is desirable, but you don't need to be an expert Python programmer to modify the code to load and display the data. This notebook applies to the following Landsat satellites, Fractional Cover bands and the WOfS dataset:

- Landsat 5
- Landsat 7
- Landsat 8
- PV - Photosynthetic vegetation
- NPV - Non-Photosynthetic vegetation
- BS - Bare Soil
- UE - Unmixing Error
- Water Observations from Space (WOfS)
- WOfS Feature Layer (WOFL)

**Bugs still to fix**

- Memory errors for large (~2000 km^2) polygons
- Tidy and document the `fc_polygon_tabulation()` module; there is an opportunity to create additional modules to declutter it.
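To see why large polygons run out of memory, here is a rough per-time-slice estimate. It is a sketch that assumes 25 m pixels, the three FC bands held as float64 after masking, and a loaded extent equal to the polygon area (in practice datacube loads the polygon's bounding box, so the real figure is larger). `estimate_slice_bytes` is a hypothetical helper, not part of this notebook.

```python
def estimate_slice_bytes(area_km2, pixel_size_m=25, n_bands=3, bytes_per_value=8):
    """Rough bytes for one time slice of `n_bands` float bands covering `area_km2`."""
    # Number of pixels = area in m^2 / area of one pixel in m^2
    n_pixels = area_km2 * 1_000_000 / (pixel_size_m * pixel_size_m)
    return int(n_pixels * n_bands * bytes_per_value)


for area in (2000, 56843.6928):
    print(f"{area:>10} km^2 -> {estimate_slice_bytes(area) / 1e9:.2f} GB per time slice")
```

For a ~2000 km² polygon this gives about 77 MB per observation; the 56,844 km² catchment processed at the end of this notebook needs over 2 GB per observation before any intermediate copies are made.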

**Errors or bugs**

If you find an error or bug in this notebook, please contact erin.telfer@ga.gov.au.


## Import Libraries and define functions

In [23]:
%matplotlib inline

from datetime import time, datetime
import os.path

from matplotlib import pyplot as plt
import pandas
import numpy
import xarray as xr
import rasterio
import rasterio.features
import fiona
import dask
import dask.array  # makes dask.array.from_array available in geometry_mask()
from dask.delayed import delayed
from dask.distributed import LocalCluster, Client
import tempfile

import datacube
from datacube import Datacube
from datacube.virtual import construct, construct_from_yaml
from datacube.ui.task_app import year_splitter


### Set up a local dask cluster
A local dask cluster runs several worker processes in parallel and caps the memory each one may use.

It also provides a dashboard for monitoring the workers while they run.
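The `Worker exceeded 95% memory budget` warning quoted further down comes from dask's per-worker memory thresholds. This sketch works through them for the cluster configured in the next cell. The fractions are assumed to be `distributed`'s defaults (`target` 0.6, `spill` 0.7, `pause` 0.8, `terminate` 0.95); the values in your installation can be checked with `dask.config.get('distributed.worker.memory')`. `worker_thresholds` is a hypothetical helper.

```python
# Fractions of a worker's memory_limit at which dask acts
# (assumed defaults for distributed.worker.memory.*).
MEMORY_FRACTIONS = {
    "target": 0.60,     # keep managed memory below this, spilling to disk
    "spill": 0.70,      # spill to disk based on process memory
    "pause": 0.80,      # stop running new tasks on this worker
    "terminate": 0.95,  # the nanny restarts the worker
}


def worker_thresholds(memory_limit_bytes, fractions=MEMORY_FRACTIONS):
    """Absolute byte thresholds for one worker."""
    return {name: frac * memory_limit_bytes for name, frac in fractions.items()}


n_workers, memory_limit = 3, 6e9  # the values passed to LocalCluster in this notebook
print(f"total: {n_workers * memory_limit / 1e9:.1f} GB")
for name, limit in worker_thresholds(memory_limit).items():
    print(f"{name:>9}: {limit / 1e9:.2f} GB")
```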

In [2]:
cluster = LocalCluster(local_dir=tempfile.gettempdir(), 
                       n_workers=3, 
                       memory_limit=6e9)
client = Client(cluster)  # a new Client makes itself dask's default scheduler
client

Client  Scheduler: tcp://127.0.0.1:46631  Dashboard: http://127.0.0.1:8787/status
Cluster  Workers: 3  Cores: 9  Memory: 18.00 GB


In [3]:
dc = Datacube()

### Construct virtual product

In [4]:
LS7_BROKEN_DATE = datetime(2003, 5, 31)  # Landsat 7 Scan Line Corrector (SLC) failure

def is_pre_slc_failure(dataset):
    return dataset.center_time < LS7_BROKEN_DATE
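Because the predicate only reads `center_time`, it can be checked without a datacube connection. In this sketch, `types.SimpleNamespace` objects stand in for datacube `Dataset`s (an assumption for illustration only); the function is equivalent to the one defined above.

```python
from datetime import datetime
from types import SimpleNamespace

# Landsat 7's Scan Line Corrector failed on 31 May 2003; later scenes are striped.
LS7_BROKEN_DATE = datetime(2003, 5, 31)


def is_pre_slc_failure(dataset):
    """True if the dataset was acquired before the SLC failure."""
    return dataset.center_time < LS7_BROKEN_DATE


# Stand-ins for datacube Dataset objects: only `center_time` is used.
before = SimpleNamespace(center_time=datetime(2002, 7, 1))
after = SimpleNamespace(center_time=datetime(2004, 1, 15))
print(is_pre_slc_failure(before), is_pre_slc_failure(after))  # True False
```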

In [5]:
def wofls_fuser(dest, src):
    # Bit 0 of a WOfL is the "no data" flag; copy src wherever it holds valid data
    src_has_data = (src & 1) == 0
    numpy.copyto(dest, src, where=src_has_data)
    return dest
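To see what the fuser does when two WOfLs overlap on the same solar day, here is a small numpy sketch with made-up flag values. It assumes, as the fuser does, that bit 0 (value 1) is the WOfS "no data" flag; 128 is used here as a water flag value. The function is copied under a different name so the notebook's own `wofls_fuser` is not redefined.

```python
import numpy as np


def wofls_fuser_sketch(dest, src):
    # Assumed WOfS convention: bit 0 set (value 1) means "no data"
    src_has_data = (src & 1) == 0
    # Overwrite dest only where the incoming WOfL has valid data
    np.copyto(dest, src, where=src_has_data)
    return dest


first = np.array([1, 128, 0, 1], dtype=np.uint8)   # pixels 0 and 3 are no-data
second = np.array([0, 1, 128, 1], dtype=np.uint8)  # pixels 1 and 3 are no-data
fused = wofls_fuser_sketch(first.copy(), second)
print(fused.tolist())  # [0, 128, 128, 1]
```

Where both inputs are valid, the later `src` overwrites `dest`; where `src` is no-data, the value already in `dest` is kept.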

In [6]:
fc_land_only_yaml = """
    transform: apply_mask
    mask_measurement_name: water
    preserve_dtype: false
    input:
        juxtapose:
          - collate:
              - transform: apply_mask
                mask_measurement_name: pixelquality
                preserve_dtype: false
                input:
                    juxtapose:
                      - product: ls5_fc_albers
                        group_by: solar_day
                        measurements: [PV, NPV, BS]
                      - transform: make_mask
                        input:
                            product: ls5_pq_albers
                            group_by: solar_day
                            fuse_func: datacube.helpers.ga_pq_fuser
                        flags:
                            ga_good_pixel: true
                        mask_measurement_name: pixelquality
              - transform: apply_mask
                mask_measurement_name: pixelquality
                preserve_dtype: false
                input:
                    juxtapose:
                      - product: ls7_fc_albers
                        group_by: solar_day
                        measurements: [PV, NPV, BS]
                        # dataset_predicate: __main__.is_pre_slc_failure
                      - transform: make_mask
                        input:
                            product: ls7_pq_albers
                            group_by: solar_day
                            fuse_func: datacube.helpers.ga_pq_fuser
                        flags:
                            ga_good_pixel: true
                        mask_measurement_name: pixelquality
              - transform: apply_mask
                mask_measurement_name: pixelquality
                preserve_dtype: false
                input:
                    juxtapose:
                      - product: ls8_fc_albers
                        group_by: solar_day
                        measurements: [PV, NPV, BS]
                      - transform: make_mask
                        input:
                            product: ls8_pq_albers
                            group_by: solar_day
                            fuse_func: datacube.helpers.ga_pq_fuser
                        flags:
                            ga_good_pixel: true
                        mask_measurement_name: pixelquality
          - transform: make_mask
            input:
                product: wofs_albers
                group_by: solar_day
                fuse_func: __main__.wofls_fuser
            flags:
                water_observed: false
            mask_measurement_name: water
"""
fc_land_only = construct_from_yaml(fc_land_only_yaml)
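The virtual product above keeps a pixel only if the pixel-quality mask marks it as good and the WOfL reports no observed water. The numpy sketch below mimics that chain for one band. It assumes water is flagged by bit 7 (value 128) of the WOfL and that `preserve_dtype: false` means masked pixels become NaN in a float array; `land_only` is a hypothetical helper, not the datacube implementation.

```python
import numpy as np

WATER_OBSERVED_BIT = 1 << 7  # assumption: WOfS bit 7 flags observed water


def land_only(fc_band, pq_good, wofl):
    """Keep pixels that pass PQ and are not water; others become NaN."""
    not_water = (wofl & WATER_OBSERVED_BIT) == 0
    keep = pq_good & not_water
    # Promote to float so masked pixels can hold NaN
    return np.where(keep, fc_band.astype(np.float64), np.nan)


pv = np.array([40, 55, 70, 10], dtype=np.uint8)
pq_good = np.array([True, False, True, True])
wofl = np.array([0, 0, 128, 0], dtype=np.uint8)
print(land_only(pv, pq_good, wofl).tolist())  # [40.0, nan, nan, 10.0]
```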

### Set up geometry functions

In [7]:
def geometry_mask(geoms, geobox, all_touched=False, invert=False, chunks=None):
    """
    Create a mask from shapes.

    By default, mask is intended for use as a
    numpy mask, where pixels that overlap shapes are False.
    :param list[Geometry] geoms: geometries to be rasterized
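    :param datacube.utils.GeoBox geobox: grid (shape, affine transform and CRS) to rasterize onto
    :param bool all_touched: If True, all pixels touched by geometries will be burned in. If
                             False, only pixels whose center is within the polygon or that
                             are selected by Bresenham's line algorithm will be burned in.
    :param bool invert: If True, mask will be True for pixels that overlap shapes.
    :param dict chunks: optional mapping of dimension name to dask chunk sizes; if given,
                        the mask is returned as a dask array with those chunks.
    :return: xarray.DataArray of booleans on the geobox's coordinates
    """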
    data = rasterio.features.geometry_mask([geom.to_crs(geobox.crs) for geom in geoms],
                                           out_shape=geobox.shape,
                                           transform=geobox.affine,
                                           all_touched=all_touched,
                                           invert=invert)
    if chunks is not None:
        data = dask.array.from_array(data, chunks=tuple(chunks[d] for d in geobox.dims))
        
    coords = [xr.DataArray(data=coord.values, name=dim, dims=[dim], attrs={'units': coord.units}) 
              for dim, coord in geobox.coords.items()]
    return xr.DataArray(data, coords=coords)
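With `all_touched=False`, polygon pixels are burned in when their centre falls inside the polygon. The pure-numpy sketch below illustrates that rule with an even-odd ray-casting test on pixel centres. It is a stand-in for explanation only: rasterio relies on GDAL's rasterizer, which may treat pixels lying exactly on a boundary differently, and `point_in_polygon` and `centre_mask` are hypothetical names. Like `invert=True`, it returns True inside the polygon.

```python
import numpy as np


def point_in_polygon(x, y, ring):
    """Even-odd ray casting test of one point against a ring of (x, y) vertices."""
    inside = False
    n = len(ring)
    for i in range(n):
        x1, y1 = ring[i]
        x2, y2 = ring[(i + 1) % n]
        # Does this edge straddle the horizontal line through y?
        if (y1 > y) != (y2 > y):
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def centre_mask(ring, x0, y0, res, shape):
    """True where a pixel centre is inside the ring (north-up grid, top-left at x0, y0)."""
    rows, cols = shape
    mask = np.zeros(shape, dtype=bool)
    for r in range(rows):
        for c in range(cols):
            cx = x0 + (c + 0.5) * res
            cy = y0 - (r + 0.5) * res
            mask[r, c] = point_in_polygon(cx, cy, ring)
    return mask


# A 100 m x 50 m rectangle on a 4 x 4 grid of 25 m pixels with top-left corner (0, 100)
ring = [(0, 100), (100, 100), (100, 50), (0, 50)]
m = centre_mask(ring, x0=0, y0=100, res=25, shape=(4, 4))
print(m.sum())                               # 8 pixels: the top two rows
print(m.sum() * 25 * 25 / 1_000_000, "km^2")  # 0.005 km^2
```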

In [8]:
def get_shapes(shape_file):
    with fiona.open(shape_file) as shapes:
        crs = datacube.utils.geometry.CRS(shapes.crs_wkt)
        for shape in shapes:
            geom = datacube.utils.geometry.Geometry(shape['geometry'], crs=crs)
            yield geom, shape['properties']

In [11]:
def fc_summary(data,mask_int):
    fc = data[['BS', 'PV', 'NPV']].sum(dim=('x', 'y'))
#     fc_sum = fc.to_array('variable').sum(dim='variable')

    area = fc * (25 * 25 / 1_000_000)
    area = area.rename({'BS': 'BS_area', 'PV': 'PV_area', 'NPV': 'NPV_area'})
    for da in area.data_vars.values():
        da.attrs['units'] = 'km2'

    fc = fc / mask_int * 100 * 100
#     fc = fc * 100 / fc_sum
    for da in fc.data_vars.values():
        da.attrs['units'] = '%'
        
    fc = fc.merge(area)
    
    return fc
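As a cross-check on the unit conversions in `fc_summary`, here is a sketch of the same two summaries for a single band and time step. It assumes FC values are per-pixel percentages (0 to 100). Under that assumption, the polygon-mean cover is the sum divided by the polygon pixel count, and the equivalent area is the sum divided by 100, multiplied by the 25 m pixel area (0.000625 km²). `cover_summary` is a hypothetical helper.

```python
import numpy as np

PIXEL_AREA_KM2 = 25 * 25 / 1_000_000  # one 25 m pixel = 0.000625 km^2


def cover_summary(pct, n_polygon_pixels):
    """Return (mean percent cover over the polygon, equivalent area in km^2).

    pct: per-pixel cover in percent (0-100), NaN where masked or outside the polygon.
    n_polygon_pixels: number of pixels inside the polygon.
    """
    total = np.nansum(pct)                       # NaN pixels contribute nothing
    mean_pct = total / n_polygon_pixels
    area_km2 = (total / 100) * PIXEL_AREA_KM2    # percent -> fraction of each pixel
    return mean_pct, area_km2


pv = np.array([[50.0, 100.0], [np.nan, 30.0]])
print(cover_summary(pv, n_polygon_pixels=4))
```

Dividing by the polygon pixel count, as `fc_summary` does with `mask_int`, counts masked (cloud, water or no-data) pixels as zero cover; dividing by the number of finite pixels instead would average over observed pixels only.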

In [12]:
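def keepna(a, dim=None, thresh=None):
    """Keep slices along `dim` with at least `thresh` finite values (any, if thresh is None); set the rest to NaN."""
    if isinstance(a, xr.Dataset):
        return a.apply(keepna, keep_attrs=True, dim=dim, thresh=thresh)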
    
    keep_dim = [] if dim is None else [dim]
    dims = [d for d in a.dims if d not in keep_dim]
    if thresh is None:
        keep = numpy.isfinite(a).sum(dim=dims) > 0
    else:
        keep = numpy.isfinite(a).sum(dim=dims) >= thresh
    return a.where(keep, other=numpy.nan)
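In the processing loop `keepna` is called with a threshold of 90% of the polygon's pixel count, so a month survives only if at least 90% of the polygon has a valid value. The numpy sketch below shows the same rule on a small `(time, y, x)` array; `keep_slices` is a hypothetical stand-in for the xarray version.

```python
import numpy as np


def keep_slices(a, thresh=None):
    """Set to NaN every time slice of a (time, y, x) array with too few finite pixels."""
    finite_counts = np.isfinite(a).reshape(a.shape[0], -1).sum(axis=1)
    keep = finite_counts > 0 if thresh is None else finite_counts >= thresh
    # Broadcast the per-slice decision over y and x
    return np.where(keep[:, None, None], a, np.nan)


a = np.array([
    [[1.0, 2.0], [3.0, 4.0]],             # 4 finite pixels
    [[1.0, np.nan], [np.nan, np.nan]],    # 1 finite pixel
    [[np.nan] * 2, [np.nan] * 2],         # no finite pixels
])
out = keep_slices(a, thresh=0.9 * 4)      # 90% of a 4-pixel polygon
print(np.isfinite(out).sum(axis=(1, 2)).tolist())  # [4, 0, 0]
```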

In [16]:
def plot_stacked(monthly_data, catchment_id, show=True):
    """Save a stacked-area plot of the BS, NPV and PV percentages for one catchment."""
    if not show:
        plt.ioff()  # don't display the figure inline

    data = monthly_data.dropna(dim='time')
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.stackplot(data.time.data, data.BS, data.NPV, data.PV,
                 colors=['tan', 'olive', 'darkolivegreen'],
                 labels=['BS', 'NPV', 'PV'])
    ax.legend(loc='upper center', ncol=3)
    ax.set_title(f'FC Components: Catchment ID {catchment_id}', size=12)
    ax.set_ylabel('Percentage (%)', size=12)
    ax.set_xlabel('Date', size=12)
    fig.savefig(f'/g/data/r78/ext547/abs/output/{catchment_id}_monthly_plot.png')

    if show:
        plt.show()
    else:
        plt.close(fig)
        plt.ion()  # turn interactive mode back on

### Process the query
For each year and polygon, query the product, apply the geometry mask and compute the fractional cover statistics.

Passing the monthly and annual results to a single `client.compute()` call lets dask reuse the monthly results when calculating the annual results.
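The monthly-then-annual aggregation (monthly means via `resample(time='1MS')`, then an annual step `monthly_data.resample(time='1YS')` on those monthly results) weights every month equally within a year. The plain-Python sketch below, with made-up values, shows that two-stage mean; the helper names are hypothetical.

```python
from collections import defaultdict
from datetime import date
from statistics import mean


def monthly_means(observations):
    """Group (date, value) pairs by calendar month and average each group."""
    groups = defaultdict(list)
    for day, value in observations:
        groups[(day.year, day.month)].append(value)
    return {key: mean(vals) for key, vals in sorted(groups.items())}


def annual_from_monthly(monthly):
    """Average the monthly means within each year, so each month counts equally."""
    groups = defaultdict(list)
    for (year, _month), value in monthly.items():
        groups[year].append(value)
    return {year: mean(vals) for year, vals in sorted(groups.items())}


obs = [(date(2000, 1, 3), 20.0), (date(2000, 1, 19), 30.0),
       (date(2000, 2, 4), 40.0), (date(2001, 1, 6), 10.0)]
monthly = monthly_means(obs)
print(monthly)                       # {(2000, 1): 25.0, (2000, 2): 40.0, (2001, 1): 10.0}
print(annual_from_monthly(monthly))  # {2000: 32.5, 2001: 10.0}
```

Averaging the three raw 2000 observations directly would give 30.0 rather than 32.5, because January contributes two observations.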

In [25]:
shape_file = os.path.expanduser('../input/SA_2016_threepolygons_3577.shp')
# shape_file = os.path.expanduser('../input/SA_2016_twopolygons_3577.shp')
shapes = list(get_shapes(shape_file))

In [18]:
start_year, end_year = 2000, 2002
time_range = (str(start_year), str(end_year))
time_range

('2000', '2002')
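`year_splitter` from `datacube.ui.task_app` breaks `time_range` into one query per calendar year, which keeps each load to a single year of data. The sketch below is a hypothetical re-implementation for illustration. It produces strings in the same format as those printed in the processing log and assumes both the start and end years are included.

```python
def split_years(start_year, end_year):
    """Yield one (start, end) pair of time strings per calendar year, inclusive."""
    for year in range(int(start_year), int(end_year) + 1):
        yield (f"{year}-01-01 00:00:00", f"{year}-12-31 23:59:59.999999999")


for sub_range in split_years('2000', '2002'):
    print(sub_range)
```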

In [19]:
# # Use this list instead of `shapes` to process just the large outback South Australian area
# s2 = [(g,p) for g, p in shapes if str(p['SA2_MAIN16']) == '406021141']

In [20]:
# s2

If we have enough resources, we can start the query and calculation of the next year's data while the previous year is still being calculated. `by_slice=False` is faster but uses more memory.

For larger areas, `by_slice` needs to be `True` so that the compute cluster does not become overwhelmed.

If you get the error:
> `distributed.nanny - WARNING - Worker exceeded 95% memory budget. Restarting`

then you will need to set `by_slice=True`
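The chunking used in the loop, `dask_chunks={'time': 1, 'y': 4000, 'x': 4000}`, sets how much data each task holds at once. This sketch estimates that amount, assuming three float64 bands per chunk. Masking, resampling and intermediate copies multiply the figure, which helps explain why large polygons need `by_slice=True`. `chunk_bytes` is a hypothetical helper.

```python
def chunk_bytes(chunks, n_variables=3, bytes_per_value=8):
    """Bytes for one chunk of each of `n_variables` float64 variables."""
    n_values = 1
    for size in chunks.values():
        n_values *= size
    return n_values * n_variables * bytes_per_value


per_chunk = chunk_bytes({'time': 1, 'y': 4000, 'x': 4000})
print(f"{per_chunk / 1e9:.3f} GB per chunk set")  # 0.384 GB
# How many such chunk sets fit below one 6 GB worker's 95% restart threshold
print(int(0.95 * 6e9 // per_chunk))  # 14
```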

In [21]:
by_slice=True

In [26]:
for geometry, properties in shapes:
    catchment_id = str(properties['SA2_MAIN16'])
    print(f"Catchment ID: {catchment_id}, size: {properties['AREASQKM16']}km^2, time: {time_range}")
          
    monthly_values = []
    annual_values = []
    mask = None
          
    for sub_time_range in year_splitter(time_range[0], time_range[-1]):
        print(f'  lazy loading {sub_time_range}...')  
        data = fc_land_only.load(dc, dask_chunks={'time': 1, 'y': 4000, 'x': 4000}, 
                                 time=sub_time_range, 
                                 geopolygon=geometry)
        print(f'    lazy loaded {sub_time_range}')

        if mask is None:
            # Rasterize the polygon once; the same geopolygon gives the same grid each year
            mask = geometry_mask([geometry], data.geobox, invert=True, chunks=data.chunks)
            mask_int = int(mask.sum())  # number of pixels inside the polygon
        data = data.where(mask)

        # Monthly means; drop months where fewer than 90% of polygon pixels are valid
        data = data.resample(time='1MS').mean(dim='time', skipna=True)
        data = keepna(data, dim='time', thresh=0.9 * mask_int)

        monthly_data = fc_summary(data, mask_int)
        annual_data = monthly_data.resample(time='1YS').mean(dim='time', skipna=True)

        print(f"    calculating for {dict(monthly_data.sizes)}")
        # One compute call, so the annual results reuse the monthly ones
        monthly_data, annual_data = client.compute([monthly_data, annual_data], sync=by_slice)

        print("    compute submitted")
          
        monthly_values.append(monthly_data)
        annual_values.append(annual_data)
    
    if not by_slice:
        print("  all years queried, hard load data")
        monthly_values = client.gather(monthly_values)
        annual_values = client.gather(annual_values)

    monthly_values = xr.concat(monthly_values, dim='time').dropna(dim='time')
    plot_stacked(monthly_values, catchment_id, show=False)
          
    annual_values = xr.concat(annual_values, dim='time').dropna(dim='time')
          
#     print("  all data loaded, save to csv")
#     monthly_values.to_dataframe().to_csv(f"/g/data/r78/ext547/abs/output/{catchment_id}_monthly.csv")
#     annual_values.to_dataframe().to_csv(f"/g/data/r78/ext547/abs/output/{catchment_id}_annual.csv")
          
    print(f"  Catchment {catchment_id} done")

Catchment ID: 105011092, size: 56843.6928km^2, time: ('2000', '2002')
  lazy loading ('2000-01-01 00:00:00', '2000-12-31 23:59:59.999999999')...
    lazy loaded ('2000-01-01 00:00:00', '2000-12-31 23:59:59.999999999')
    calculating for {'time': 12}




KilledWorker: ('dataset-f8030b1972ad4914bf324b1b15bc60ab', 'tcp://127.0.0.1:43168')