# Masking and Scaling <img align="right" src="../resources/csiro_easi_logo.png"> 

#### Index
- [Overview](#Overview)
- [Setup](#Setup)
   - [Imports](#Imports)
   - [Dask](#Dask)
   - [AWS configuration](#AWS-configuration)
   - [Example query](#Example-query)
- [Datacube masking library](#Datacube-masking-library)
- [Mask by no-data](#Mask-by-no-data)
- [Apply the mask](#Apply-the-mask)
- [Flag definition](#Flag-definition)
- [Mask by "on/off" values](#Mask-by-"on/off"-values)
- [Mask by enumeration values](#Mask-by-enumeration-values)
- [Mask by measurement value](#Mask-by-measurement-value)
- [Mask by shapefile](#Mask-by-shapefile)

## Overview

The ODC has methods for handling bit-mask layers and their associated flag descriptions (e.g., quality information). Flag descriptions are defined in the product document for any bit-mask layers. There are functions for creating a binary array mask from a selected set of flag values or from a `nodata` value. The binary mask can then be applied to the measurement layers.

A binary mask is an array with `True` values where data values are to be kept and `False` values where data values are to be replaced. The binary mask is the same size as a data array that it will be applied to.


Notes:
- A `nodata` value for _invalid_ or _missing_ data should be defined for each measurement layer as this is part of the [product document](https://datacube-core.readthedocs.io/en/latest/ops/product.html). The `nodata` value will correspond to the data type of the measurement layer (data array).
- The `nodata` values in an array can be replaced with "Not a Number" (`NaN`, `np.nan`), in which case the array will change data type to `float64`. This may have implications for memory usage or for further calculations.
- Many `Xarray` and `Numpy` functions handle `NaN` values "naturally" (appropriately ignored). Likewise the plotting libraries will tend to ignore `NaN`s natively.
- If using a `nodata` value other than `NaN` then this may need to be handled specifically in your code (e.g., passed to the `xarray`/`numpy` or plotting routines).
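
The notes above can be illustrated with a small sketch using plain `numpy`. The array values and the `nodata` value of `-9999` are made up for illustration; they are not taken from any product definition.

```python
import numpy as np

# Hypothetical integer band with a nodata value of -9999 (illustrative values only)
nodata = -9999
band = np.array([[100, 200], [nodata, 400]], dtype="int32")

# Binary mask: True where data are kept, False where they are replaced
valid = band != nodata

# Replacing nodata with NaN promotes the integer array to a floating-point type
band_nan = np.where(valid, band, np.nan)
print(band.dtype, "->", band_nan.dtype)

# NaN-aware functions skip the masked pixel
mean_nan = np.nanmean(band_nan)

# Without NaN, the nodata value has to be excluded explicitly
mean_explicit = band[valid].mean()
print(mean_nan, mean_explicit)
```

Both means agree because the `nodata` pixel is excluded in each case; the difference is only in whether the exclusion is carried by `NaN` or handled explicitly in the code.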


Further information:

- https://xarray.pydata.org/en/stable/user-guide/data-structures.html
- https://xarray.pydata.org/en/stable/user-guide/computation.html#missing-values
- https://numpy.org/doc/stable/reference/generated/numpy.where.html

## Setup

#### Imports
This is a standard set of imports that we use across many notebooks.

In [None]:
# Data tools
import numpy as np
import xarray as xr
import pandas as pd
from datetime import datetime

# Datacube
import datacube
from datacube.utils import masking  # https://github.com/opendatacube/datacube-core/blob/develop/datacube/utils/masking.py
from odc.algo import enum_to_bool   # https://github.com/opendatacube/odc-tools/blob/develop/libs/algo/odc/algo/_masking.py
from datacube.utils.rio import configure_s3_access

# Holoviews, Datashader and Bokeh
import hvplot.pandas
import hvplot.xarray
import holoviews as hv
import colorcet as cc
hv.extension("bokeh", logo=False)

# Python
import sys, os, re
os.environ['USE_PYGEOS'] = '0'

# Optional EASI tools
sys.path.append(os.path.expanduser('../scripts'))
from easi_tools import EasiDefaults
import notebook_utils
easi = EasiDefaults()

#### Dask
We have now put the Dask cluster code into a [_notebook_utils_](../scripts/notebook_utils.py) function.

In [None]:
cluster, client = notebook_utils.initialize_dask(use_gateway=False)
display(cluster if cluster else client)
print(notebook_utils.localcluster_dashboard(client, server=easi.hub))

#### AWS configuration
To use data in public requester-pays buckets, run the following code (once per dask cluster):

In [None]:
configure_s3_access(aws_unsigned=False, requester_pays=True, client=client)

#### Example query

In [None]:
# This configuration is read from the defaults for this system. 
# Examples are provided in a commented line to show how to set these manually.

study_area_lat = easi.latitude
# study_area_lat = (39.2, 39.3)

study_area_lon = easi.longitude
# study_area_lon = (-76.7, -76.6)

product = easi.product('landsat')
# product = 'landsat8_c2l2_sr'

set_time = easi.time
# set_time = ('2020-08-01', '2020-12-01')

set_crs = easi.crs('landsat')
# set_crs = 'EPSG:32618'

set_resolution = easi.resolution('landsat')
# set_resolution = (-30, 30)

In [None]:
query = {
    'product': product,
    'x': study_area_lon,
    'y': study_area_lat,
    'time': set_time,
    'output_crs': set_crs,
    'resolution': set_resolution,
    'group_by': 'solar_day',
    'dask_chunks': {'x': 2048, 'y': 2048}
}

dc = datacube.Datacube()
data = dc.load(**query)
data

## Datacube masking library
There are two datacube masking libraries that we use (both included in the imports above):
- [datacube masking library](https://github.com/opendatacube/datacube-core/blob/develop/datacube/utils/masking.py): Primary masking functions
- [odc_tools masking library](https://github.com/opendatacube/odc-tools/blob/develop/libs/algo/odc/algo/_masking.py): Enumeration and morphological operators

The datacube masking functions can be used to identify, describe and create masks from `flags_definition` measurements and from `nodata` attributes.

When applied to a Dataset, the datacube masking functions select and use the first measurement that has a `flags_definition` property. A dataset may contain several `flags_definition` measurements, so the code below separates the data and flag measurements and selects one flag measurement explicitly.


In [None]:
# Measurements for the selected product
measurements = dc.list_measurements().loc[query['product']]

# Separate lists of measurement data names and flag names
data_names = measurements[pd.isnull(measurements.flags_definition)].index
flag_names = measurements[pd.notnull(measurements.flags_definition)].index

# Select one for use below
flag_name = flag_names[0]

## Mask by no-data

There are two convenience functions that create a mask from the `nodata` attribute of each measurement. Both can be applied to a Dataset, in which case they act on every layer (including any bit-masks!), so select the measurement layers first.

In [None]:
# Apply to measurement layers

# Under the hood: data.where(data != data.nodata) -> replace False with NaN -> float64
valid_data = masking.mask_invalid_data(data[data_names], keep_attrs=True)

# Under the hood: data != data.nodata -> bool
valid_mask = masking.valid_data_mask(data[data_names])

## Apply the mask

Use the xarray `where()` method to apply a mask array to measurement arrays; it is closely related to [numpy.where()](https://numpy.org/doc/stable/reference/generated/numpy.where.html) and keeps values where the mask is `True`. Shown here are two ways to apply a mask to the measurement arrays. The first replaces the `False` values with `NaN` and changes the data type to `float64`. The second replaces the `False` values with a given value and retains the data type if possible.

In [None]:
# Apply to measurement layers: Change dtype to float64

data_masked = data[data_names].where(valid_mask)  # Default: Where False replace with NaN -> convert dtype to float64
display(data_masked)  # Type: float64

# Equivalent to valid_data above
# data_masked.equals(valid_data)

data_masked[data_names[0]].plot(col='time', robust=True, col_wrap=4)

In [None]:
# Apply to measurement layers: Retain the data type (dtype) - note that the data type stays as int32

nodata = -9999  # A new nodata value
data_masked = data[data_names].where(valid_mask, nodata)  # Where False values are replaced with the nodata -> retain dtype if compatible

# Update the nodata value in the data variable attributes
for var in data_masked.data_vars:
    data_masked[var].attrs['nodata'] = nodata
    
display(data_masked)

data_masked[data_names[0]].plot(col='time', robust=True, col_wrap=4)

## Flag definition

There are two ways that bit-values can be defined and used in a `flags_definition`:
- "On/Off" values and labels for each binary bit(s)
   - A combination of bit names and values creates the mask. See [Mask by "on/off" values](#Mask-by-"on/off"-values)
- List of discrete integer values and labels (enumeration)
   - A combination of integer values creates a mask. See [Mask by enumeration values](#Mask-by-enumeration-values)

The following `flags_definition`s have flag names (_index_ column), bit values (_bits_ column), "on/off" values and labels (_values_ column), and an optional description (_description_ column).
 

In [None]:
# Pandas table. First flags_definition measurement found
masking.describe_variable_flags(data)

In [None]:
# Pandas table. Select a different flags_definition measurement
masking.describe_variable_flags(data.qa_aerosol)

In [None]:
# Simple dictionary. Select a flags_definition attribute
flags_def = data[flag_name].flags_definition
flags_def

## Mask by "on/off" values

The `make_mask()` function returns a mask where `True` corresponds to the selected bits and values. These may be considered as _good_ or _bad_ pixel flag selections depending on the application and the `flags_definition`.

Define a dictionary of ___good___ pixel flags using values shown in the variable flags above `{'flag': 'value'}`.

>__NOTE:__ The examples below are designed to work with the Landsat flags. Other products will have different flag definitions.

In [None]:
good_pixel_flags = {
    'nodata': False,
    'cloud': 'not_high_confidence',
    'cloud_shadow': 'not_high_confidence',
}

Make a mask corresponding to the `good_pixel_flags` and plot the result.

In [None]:
mask = masking.make_mask(data[flag_name], **good_pixel_flags)  # expand dictionary of pixel flags

good_data = data[data_names[0]].where(mask)
good_data.plot(col='time', robust=True, col_wrap=4)

The sum of the mask gives the number of good pixels for each time layer, expressed here as a percentage of all pixels.

In [None]:
pixels = mask.shape[1] * mask.shape[2]
percent = mask.sum(axis=[1,2]) / pixels *100
list(map('{:.2f}%'.format,percent.values))

The `create_mask_value()` function can be used to verify how the pixel flags are combined to create the mask.

In [None]:
# make_mask: variable & bitmask == bitvalue
bitmask, bitvalue = masking.create_mask_value(data[flag_name].flags_definition, **good_pixel_flags)
display(bin(bitmask), bitvalue)
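
The `variable & bitmask == bitvalue` test noted in the cell above can be sketched with plain `numpy`. The bit positions below are hypothetical and chosen only for illustration; they are not the Landsat flag layout, which should be read from the `flags_definition` of the product.

```python
import numpy as np

# Hypothetical bit layout, for illustration only (not the Landsat definition):
#   bit 0 = nodata, bit 3 = cloud, bit 4 = cloud shadow
NODATA_BIT, CLOUD_BIT, SHADOW_BIT = 0, 3, 4

# Bits to inspect, and the value they must have for a pixel to be kept (all "off")
bitmask = (1 << NODATA_BIT) | (1 << CLOUD_BIT) | (1 << SHADOW_BIT)
bitvalue = 0

# Four example pixels: clear, cloud, shadow, nodata.
# Bit 6 is set on the first pixel to show that bits outside the bitmask are ignored.
qa = np.array([0b1000000, 0b0001000, 0b0010000, 0b0000001], dtype="uint16")

# Parentheses are required: in Python, == binds more tightly than &
good = (qa & bitmask) == bitvalue
print(bin(bitmask), good)
```

Only the first pixel is kept, because each of the others has at least one of the inspected bits set.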

## Mask by enumeration values

_Work in progress_

Enumerated masks contain discrete integer values in the `flags_definition`. Examples include:
- Sentinel-2 "SCL" quality layer
- Geoscience Australia's Landsat and Sentinel-2 ARD products ('ga_ls5t_ard_3', 'ga_ls7e_ard_3', 'ga_ls8c_ard_3', 's2a_ard_granule', 's2b_ard_granule', 's2a_nrt_granule', 's2b_nrt_granule')
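
The idea behind an enumeration mask can be sketched with plain `numpy`, independently of `enum_to_bool()`. The category values and labels below are hypothetical and do not correspond to any of the products listed above; the real values should be read from the product's `flags_definition`.

```python
import numpy as np

# Hypothetical enumeration of integer values and labels, for illustration only
categories = {0: "nodata", 1: "valid", 2: "cloud", 3: "shadow", 4: "water"}

# Labels to keep, translated to their integer values
good_labels = ("valid", "water")
good_values = [value for value, label in categories.items() if label in good_labels]

# Small classification layer with made-up values
classification = np.array([[1, 2, 4], [0, 3, 1]], dtype="uint8")

# True where the pixel value is one of the selected categories
enum_mask = np.isin(classification, good_values)
print(enum_mask)
```

The resulting boolean array can be applied to a measurement layer in the same way as the masks in the earlier sections.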
   

In [None]:
# Load data from an enumeration product

# Select a list of pixel flags

In [None]:
# enum_mask = enum_to_bool(data[flag_name], enum_pixel_flags)  # tuple or list of pixel flags

# good_data = data[data_names[0]].where(enum_mask)
# good_data.plot(col='time', robust=True, col_wrap=4)

## Mask by measurement value

_Work in progress_

Use xarray functions, such as comparison operators and `where()`, to build a mask directly from the measurement values, e.g. http://xarray.pydata.org/en/stable/api.html#id5
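
A minimal sketch of masking by value on a synthetic `xarray.Dataset`. The band names and values are made up for illustration. The range of 1 to 65455 is the one quoted in this section for Landsat Collection 2 reflectance bands and should be checked against the product in use.

```python
import numpy as np
import xarray as xr

# Hypothetical two-band dataset with made-up values, for illustration only
ds = xr.Dataset(
    {
        "red": (("y", "x"), np.array([[0, 7500], [12000, 65535]], dtype="uint16")),
        "nir": (("y", "x"), np.array([[9000, 0], [65500, 20000]], dtype="uint16")),
    }
)

# Assumed valid range (check against the product documentation)
valid_min, valid_max = 1, 65455

# Comparison on a Dataset returns a Dataset of boolean arrays, one per variable
mask_by_value = (ds >= valid_min) & (ds <= valid_max)

# where() matches variables by name, so each band is masked by its own boolean array
ds_valid = ds.where(mask_by_value)
print(ds_valid)
```

Values outside the range become `NaN`, and the integer bands are converted to a floating-point data type as a result.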

In [None]:
# TODO: Update this example

# # Set all invalid to `NaN` - the valid range for Landsat Collection 2 reflectance bands is 1 to 65455
# mask_by_value = (data[data_names] >= 1) & (data[data_names] <= 65455)
# display(mask_by_value)  # Type: bool

# # TODO: confirm that 'where' applies each boolean array variable-wise
# valid_data = data[data_names].where(mask_by_value)
# valid_data

## Mask by shapefile

_Work in progress_

**Please refer to the [time_series_masking_with_shapefile.ipynb](../Case%20studies/time_series_masking_with_shapefile.ipynb) notebook in the `/Case studies` folder for an example of masking using a shapefile**

Available masks
- Aust inland waters
- Aust coast
- VNSC coastlines
- Catchments (Aust, VN)
- https://www.soest.hawaii.edu/pwessel/gshhg/

***More details to come***