# Compare geomedian summaries produced with opening and dilation mask operations <img align="right" src="../Supplementary_data/dea_logo.jpg">

* [**Sign up to the DEA Sandbox**](https://docs.dea.ga.gov.au/setup/sandbox.html) to run this notebook interactively from a browser
* **Compatibility:** Notebook currently compatible with the `DEA Sandbox` environment
* **Products used:** 
`ga_ls8c_ard_3`, `ga_ls9c_ard_3`

### Assumptions

In [None]:
import datacube
import matplotlib.pyplot as plt
from odc.algo import mask_cleanup, erase_bad, enum_to_bool, to_f32, to_f32, xr_geomedian, int_geomedian
from datacube.utils import masking
from datacube.utils.cog import write_cog
import fiona
import rioxarray

import sys
sys.path.insert(1, '../Tools/')
from dea_tools.datahandling import wofs_fuser
from dea_tools.dask import create_local_dask_cluster

from datacube.utils.masking import make_mask
from datacube.utils.geometry import CRS, Geometry, GeoBox
from dea_tools.plotting import rgb

import os
import sys
import fiona
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from dea_tools.datahandling import load_ard

sys.path.insert(1, "../Tools/")
from dea_tools.dask import create_local_dask_cluster


# Range of mask buffer sizes (in pixels) to test
start_buffering, end_buffering = 0, 10

# Dataset maturity level passed to the datacube query
maturity = "final"

# Time period to analyse; alternatives kept for reference:
#   time_period = None
#   time_period = ('2023-01-01', '2023-03-01')
time_period = "2022"

# Bands to load; oa_fmask is the band the cloud/shadow masks are built from
bands = [
    "nbart_blue",
    "nbart_red",
    "nbart_green",
    "nbart_nir",
    "nbart_swir_1",
    "nbart_swir_2",
    "oa_fmask",
]

# Landsat products to analyse ("ga_ls5t_ard_3" and "ga_ls7e_ard_3" disabled)
products = ["ga_ls8c_ard_3", "ga_ls9c_ard_3"]

# Validation path/rows; other candidates are kept commented out
path_row_list = [
    # "089084",  # R
    "090084",
    "090086",
    "091086",
    # "093086",  # R
    # "096072",
    # "098071",  # R
    # "099079",
    # "105069",
]

# Number of threads handed to the geomedian calculation
geomedian_threads = 10

### Data Load Queries

In [None]:
def landsat_scene_poly(path_row, radius=None):
    """
    Get geometry for a given landsat path row

    Parameters
    ----------
    path_row : string
        Path row to search for, e.g. "090084"
    radius : int, optional
        If provided, the centroid of the path row geom will be buffered
        by `radius` metres to provide a smaller geometry located within
        the path row, reducing memory and processing.

    Returns
    -------
    geometry : Geometry
        The path row footprint in EPSG:4326 or, when `radius` is given,
        the buffered centroid in EPSG:3577.

    Raises
    ------
    ValueError
        If `path_row` is not numeric or is not present in the path row grid.
    """

    # Landsat path/row grid, read directly over HTTPS
    landsat_shape = "https://data.dea.ga.gov.au/derivative/ga_ls_path_row_grid.geojson"

    # Landsat pathrows don't include the leading 0, so compare against the
    # int form of the path row (e.g. "090084" -> 90084)
    target_pr = int(path_row)

    # Select feature
    with fiona.open(landsat_shape) as all_shapes:
        for s in all_shapes:
            if s["properties"].get("PR") == target_pr:
                # Extract geom
                geom = Geometry(s["geometry"], crs=CRS("EPSG:4326"))

                # Buffer centroid by `radius` metres (in Albers) and return geom
                if radius is not None:
                    geom = geom.to_crs("EPSG:3577").centroid.buffer(radius)

                return geom

    # Falling through used to return None, which would be passed on as
    # `geopolygon=None` in the query; fail loudly instead
    raise ValueError(f"Path row {path_row!r} not found in {landsat_shape}")

In [None]:
# for full path-row geopolygon
def define_query_params(path_row, time_period, maturity, radius=None):
    """
    Build the datacube query parameters for a single Landsat path row.

    Parameters
    ----------
    path_row : string
        Path row to search for, e.g. "090084".
    time_period : string, tuple or None
        Time range passed through as the ``time`` query argument
        (e.g. "2022" or ('2023-01-01', '2023-03-01')).
    maturity : string
        The dataset maturity level to include in the analysis.
    radius : int, optional
        If provided, the centroid of the path row geom is buffered by
        `radius` metres so only a smaller area inside the path row is
        queried, reducing memory and processing.

    Returns
    -------
    query_params : dictionary
        Query params to use for odc load.
    """
    return {
        "geopolygon": landsat_scene_poly(path_row, radius),
        "time": time_period,
        "region_code": path_row,
        "dataset_maturity": maturity,
    }


# for small scale fast tests

# Sandy Beaches: query_params = dict(x=(115.16, 115.23), y=(-30.75, -30.83), dataset_maturity=maturity, time=time_period)
# Alpine Snow: query_params = dict(x=(148.22, 148.39), y=(-36.37, -36.50), dataset_maturity=maturity, time=time_period)
# Urban Area: query_params = dict(x=(144.77, 145.02), y=(-37.68, -37.85), dataset_maturity=maturity, time=time_period)
# Salt Lakes: query_params = dict(x=(135.72, 135.94), y=(-31.26, -31.38), dataset_maturity=maturity, time=time_period)

def define_query_params_lat_lon_test(
    time_period, maturity, x=(115.16, 115.23), y=(-30.75, -30.83)
):
    """
    Create query params for a small lon/lat test extent (fast tests).

    Parameters
    ----------
    time_period : string, tuple or None
        Time range passed through as the ``time`` query argument.
    maturity : string
        The dataset maturity level to include in the analysis.
    x : tuple of float, optional
        Longitude range; defaults to the "Sandy Beaches" test area.
    y : tuple of float, optional
        Latitude range; defaults to the "Sandy Beaches" test area.
        Other test extents used with this notebook:
        Alpine Snow x=(148.22, 148.39), y=(-36.37, -36.50);
        Urban Area x=(144.77, 145.02), y=(-37.68, -37.85);
        Salt Lakes x=(135.72, 135.94), y=(-31.26, -31.38).

    Returns
    -------
    query_params : dictionary
        Query params to use for odc load.
    """
    query_params = dict(x=x, y=y, dataset_maturity=maturity, time=time_period)
    return query_params

def define_load_params(bands, load_product, query_params):
    """
    Build the ``dc.load`` keyword arguments for one product.

    Parameters
    ----------
    bands : list
        Measurement bands to load.
    load_product : string
        ODC product name.
    query_params: dictionary
        ODC query parameters, used here to look up matching datasets.

    Returns
    -------
    load_params : dict
        Measurements, output CRS/resolution/alignment, grouping, Dask
        chunking and broken-dataset handling for ``dc.load``.
    """
    # Look up the datasets the query matches so their CRS can be reused
    matched = dc.find_datasets(product=load_product, **query_params)

    # Load in the native CRS of the first match; "EPSG:3577" is only a
    # fallback so an empty search does not raise an IndexError
    output_crs = matched[0].crs if matched else "EPSG:3577"

    return {
        "measurements": bands,
        "output_crs": output_crs,
        # Native 30 m resolution; the 15 m alignment is required for it
        "resolution": (-30, 30),
        "align": (15, 15),
        "group_by": "solar_day",
        "dask_chunks": {},
        # Around one to three timesteps failed on S3 read errors for ls5
        "skip_broken_datasets": True,
    }

In [None]:
def load_data(load_params, load_product, query_params):
    """
    Load ODC data for one product using the global ``dc`` datacube.

    Parameters
    ----------
    load_params : dictionary
        Load parameters dictionary (see ``define_load_params``).
    load_product : string
        ODC product.
    query_params: dictionary
        ODC query parameters.

    Returns
    -------
    ds : dataset
        The loaded dataset; lazy (Dask-backed) when `load_params`
        contains ``dask_chunks``.
    """
    return dc.load(product=load_product, **query_params, **load_params)


def load_data_geomedian(load_params, load_product, query_params):
    """
    Load ODC data for a geomedian through ``load_ard``.

    ``load_ard`` is called with its default pixel-quality options; no
    ``min_gooddata`` filter is passed, so that geomedian comparisons see
    every timestep.

    Parameters
    ----------
    load_params : dictionary
        Load parameters dictionary.
    load_product : string
        ODC product (wrapped in a single-item list for ``load_ard``).
    query_params: dictionary
        ODC query parameters.

    Returns
    -------
    ds : dataset
        Geospatial satellite data dataset.
    """
    print(load_product)
    # min_gooddata=0.90 intentionally omitted for geomedian comparisons
    return load_ard(
        dc=dc,
        products=[load_product],
        **query_params,
        **load_params,
    )

In [None]:
def apply_geomedian(ds, geomedian_threads):
    """
    Compute the geomedian of a dataset with ``int_geomedian``.

    Parameters
    ----------
    ds : dataset
        ODC data to summarise.
    geomedian_threads: integer
        Number of threads passed to ``int_geomedian`` as ``num_threads``.

    Returns
    -------
    geomedian : dataset
        The geomedian, evaluated eagerly via ``.compute()``.
    """
    # rgb(geomedian, size=10)  # optional quick-look plot
    return int_geomedian(ds, num_threads=geomedian_threads).compute()

In [None]:
def calc_cloud_shadow_mask(ds):
    """
    Derive boolean masks from the ``oa_fmask`` band.

    Parameters
    ----------
    ds : dataset
        Data containing an ``oa_fmask`` variable.

    Returns
    -------
    cloud_shadow_mask : xr.DataArray
        True where fmask flags "cloud" or "shadow".
    nodata_mask: xr.DataArray
        True where fmask flags "nodata".
    """
    fmask = ds.oa_fmask

    # "nodata" is kept separate from the combined cloud + shadow mask
    nodata_mask = enum_to_bool(fmask, categories=["nodata"])
    cloud_shadow_mask = enum_to_bool(fmask, categories=["cloud", "shadow"])

    return cloud_shadow_mask, nodata_mask


# Plot
# cloud_shadow_mask.isel(time=slice(4, 12)).plot(col="time", col_wrap=4)

In [None]:
def plot_std_gradient_buffer(
    context, std_buffer_df, path_row, product, start_buffering, end_buffering, time_period, operation, export_figure=True
):
    """
    Plot a per-buffer statistic, its gradient and the band-mean gradient.

    Parameters
    ----------
    context : string
        Name of the statistic (e.g. "Standard Deviation"); used in axis
        labels, titles and the exported file name.
    std_buffer_df : pandas.DataFrame
        Statistic indexed by buffer size, one column per band. Must
        contain an ``nbart_blue`` column, which is annotated in the
        middle panel.
    path_row : string
        Path row, used in the title and file name.
    product : string
        Product name, used in the title and file name.
    start_buffering : int
        First buffer size; first x-tick.
    end_buffering : int
        Last buffer size (inclusive); last x-tick.
    time_period : string
        Used in the exported file name.
    operation : string
        Mask operation label used in the exported file name.
    export_figure : bool, optional
        If True (default), save the figure as a JPEG under ``output_data/``.
    """
    # Set up three panel fig
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
    plt.subplots_adjust(wspace=0.3)

    # Apply numpy gradient to each column (band) independently
    std_gradient_df = std_buffer_df.apply(np.gradient, axis=0)

    # Plot standard deviation
    std_buffer_df.plot(
        ax=axes[0],
        xlabel="Buffer distance",
        ylabel=f"{context}",
        title=f"{context} per buffer pixel",
        legend=False,
    )

    # Plot gradient
    std_gradient_df.plot(
        ax=axes[1],
        xlabel="Buffer distance",
        ylabel=f"{context} gradient",
        title=f"{context} gradient per buffer pixel",
    )

    # Label every second point of the nbart_blue gradient
    for index in std_gradient_df.index[::2]:
        axes[1].text(
            index,
            std_gradient_df.nbart_blue.loc[index],
            round(std_gradient_df.nbart_blue.loc[index], 2),
            size=8,
        )

    # Plot mean of all gradients
    std_gradient_df_mean = std_gradient_df.mean(axis=1).to_frame("All bands")
    std_gradient_df_mean.plot(
        ax=axes[2],
        xlabel="Buffer distance",
        ylabel=f"{context} gradient",
        title=f"{context} gradient per buffer pixel",
    )

    # Add labels to every second item
    for index in std_gradient_df_mean.index[::2]:
        axes[2].text(
            index,
            std_gradient_df_mean["All bands"].loc[index],
            round(std_gradient_df_mean["All bands"].loc[index], 2),
            size=8,
        )

    # Set grid and x-ticks; "+ 1" so the final buffer size (end_buffering)
    # also gets a tick, matching the inclusive buffer loops
    for ax in axes:
        ax.grid(alpha=0.1)
        ax.set_xticks(list(range(start_buffering, end_buffering + 1)))

    # Add title above subplots
    fig.suptitle(f"Path/row {path_row}, {product}", fontsize=14)

    # Optionally export figure. Save before showing, since some backends
    # (e.g. the Jupyter inline backend) close figures once displayed
    if export_figure:
        output_path = f"output_data/{path_row}_{product}_{operation}_{context}_{time_period}_plots.jpg"
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        fig.savefig(output_path, bbox_inches="tight")

    plt.show()

    # Release the figure; this is called once per path row and product
    plt.close(fig)

### Load data

In [None]:
# Datacube connection; the load helpers above read this global `dc`
dc = datacube.Datacube(app="Geomedian_cloud_buffering")

# Start a local Dask cluster and keep a handle on its client
client = create_local_dask_cluster(return_client=True)

### Geomedian analysis: dilation of the cloud and shadow mask

In [None]:
from datacube.utils.masking import mask_invalid_data

# Loop through each validation path row
for path_row in path_row_list:
    for product in products:

        output_dir = f"/gdata1/projects/cloud_masking/geomedian_dilation_results/{path_row}"
        os.makedirs(output_dir, exist_ok=True)

        # Create query parameters (50 km buffer around the path row centroid)
        query_params = define_query_params(
            path_row, time_period, maturity, radius=50000
        )

        # Small Test Area
        # query_params = define_query_params_lat_lon_test(time_period, maturity)

        # Create Load Parameters
        load_params = define_load_params(bands, product, query_params)

        # Load ARD Data; persisted because it is reused for every buffer size
        ds = load_data(load_params, product, query_params).persist()

        # Create cloud and shadow mask (nodata_mask is not used in this cell;
        # nodata pixels are handled by mask_invalid_data below)
        cloud_shadow_mask, nodata_mask = calc_cloud_shadow_mask(ds)

        output_dict = {}

        # Loop through each buffer radius (end_buffering inclusive)
        for buffer in range(start_buffering, end_buffering + 1):

            dilation = buffer

            # create cloud and shadow masked datasets with specified buffer
            clouds_shadows_dilation = mask_cleanup(cloud_shadow_mask, mask_filters=[("dilation", buffer)])
            ds_buffer = ds.where(~clouds_shadows_dilation)

            # Set invalid nodata pixels to NaN
            ds_buffer_valid = mask_invalid_data(ds_buffer)

            print(f"Calculating geomedian with {buffer} pixel buffer applied")
            ds_geomedian = apply_geomedian(ds_buffer_valid, geomedian_threads)

            # Save an RGB preview and one GeoTIFF per loaded band
            rgb(ds_geomedian, size=10, savefig_path=f"{output_dir}/geomedian_{path_row}_{product}_dilation_{dilation}_{str(time_period)}.tif")
            for band in bands:
                ds_geomedian[band].rio.to_raster(
                    f"{output_dir}/geomedian_{path_row}_{product}_dilation_{dilation}_pixels_{time_period}_{band}.tif"
                )

            # drop fmask from std results (drop_vars replaces deprecated drop)
            ds_geomedian = ds_geomedian.drop_vars("oa_fmask")

            print("Calc standard deviation of the difference between 0 and X buffer")
            std_ds = ds_geomedian.std().mean().compute()
            std_df = std_ds.to_array().to_dataframe(name="std")
            output_dict[buffer] = std_df
            # Print results
            print(
                f"Buffer in pixels: {buffer}, {': '.join(std_df.round(1).to_string(index_names=False).split())}"
            )

        # Concatenate outputs into a single dataframe, then unstack to wide
        # format with each variable as a column
        std_geomedian_df = pd.concat(output_dict, names=["pixel_buffer", "variable"])[
            "std"
        ].unstack("variable")

        # Export results as csv with a "product" and "path_row" column.
        # NOTE(review): `dilation` here is the last loop value (end_buffering),
        # although the CSV holds results for every buffer size
        std_geomedian_df.assign(product=product, path_row=path_row).to_csv(
            f"{output_dir}/geomedian_std_buffer{path_row}_{product}_dilation_{dilation}_{str(time_period)}.csv", index=True
        )

        # Plot the standard deviation and gradient results
        plot_std_gradient_buffer(
            "Standard Deviation",
            std_geomedian_df,
            path_row,
            product,
            start_buffering,
            end_buffering,
            str(time_period),
            f"dilation_{dilation}",
            export_figure=True,
        )

### Geomedian analysis: opening on the cloud mask, then dilation of the combined cloud and shadow mask

In [None]:
from datacube.utils.masking import mask_invalid_data

# Loop through each validation path row
for path_row in path_row_list:
    for product in products:

        output_dir = f"/gdata1/projects/cloud_masking/geomedian_opening_dilation_results/{path_row}"
        os.makedirs(output_dir, exist_ok=True)

        # Create query parameters
        # small test example
        # query_params = define_query_params_lat_lon_test(time_period, maturity)

        # full specified area (50 km buffer around the path row centroid)
        query_params = define_query_params(
            path_row, time_period, maturity, radius=50000
        )

        load_params = define_load_params(bands, product, query_params)

        # Load ARD Data; persisted because it is reused for every buffer size
        ds = load_data(load_params, product, query_params).persist()

        # Create three separate masks for cloud, shadow and no data
        cloud_mask = enum_to_bool(ds.oa_fmask, categories=["cloud"])
        shadow_mask = enum_to_bool(ds.oa_fmask, categories=["shadow"])
        nodata_mask = enum_to_bool(ds.oa_fmask, categories=["nodata"])

        output_dict = {}

        # Fixed dilation (pixels) applied after the varying opening
        dilation = 6

        # Loop through each buffer radius (end_buffering inclusive)
        for buffer in range(start_buffering, end_buffering + 1):

            print(f"Creating masks with {buffer} pixel buffer applied")

            # Opening on cloud mask with varying level of pixels applied
            opening_cloud_mask = mask_cleanup(cloud_mask, mask_filters=[("opening", buffer)])

            # Combine mask with cloud and shadow
            combined_opening_cloud_mask = shadow_mask | opening_cloud_mask

            # Run combined mask against fixed dilation on cloud + shadow
            combined_opening_dilation_mask = mask_cleanup(combined_opening_cloud_mask, mask_filters=[("dilation", dilation)])

            # Apply no data mask
            combined_mask = combined_opening_dilation_mask | nodata_mask

            # Run mask on dataset
            ds_buffer = ds.where(~combined_mask)

            # Set invalid nodata pixels to NaN
            ds_buffer_valid = mask_invalid_data(ds_buffer)

            # To plot the individual masked time steps, uncomment:
            # rgb(ds_buffer, size=10, col="time")

            # Apply geomedian
            print(f"Calculating geomedian with {buffer} pixel buffer applied")
            ds_geomedian = apply_geomedian(ds_buffer_valid, geomedian_threads)

            # (Currently disabled) convert to float, setting all nodata pixels
            # to `np.nan`:
            # ds_geomedian = to_f32(ds_geomedian)

            # Save an RGB preview and one GeoTIFF per loaded band
            rgb(ds_geomedian, size=10, savefig_path=f"{output_dir}/geomedian_{path_row}_{product}_opening_{buffer}_dilation_{dilation}_{str(time_period)}.tif")
            for band in bands:
                ds_geomedian[band].rio.to_raster(
                    f"{output_dir}/geomedian_{path_row}_{product}_opening_{buffer}_dilation_{dilation}_pixels_{time_period}_{band}.tif"
                )

            # drop fmask from std results (drop_vars replaces deprecated drop)
            ds_geomedian = ds_geomedian.drop_vars("oa_fmask")

            print("Calc standard deviation of the difference between 0 and X buffer")
            std_ds = ds_geomedian.std().mean().compute()
            std_df = std_ds.to_array().to_dataframe(name="std")
            output_dict[buffer] = std_df
            # Print results
            print(
                f"Buffer in pixels: {buffer}, {': '.join(std_df.round(1).to_string(index_names=False).split())}"
            )

        # Concatenate outputs into a single dataframe, then unstack to wide
        # format with each variable as a column
        std_geomedian_df = pd.concat(output_dict, names=["pixel_buffer", "variable"])[
            "std"
        ].unstack("variable")

        # Export results as csv with a "product" and "path_row" column.
        # NOTE(review): `buffer` here is the last loop value (end_buffering),
        # although the CSV holds results for every opening size
        std_geomedian_df.assign(product=product, path_row=path_row).to_csv(
            f"{output_dir}/geomedian_std_buffer{path_row}_{product}_opening_{buffer}_dilation_{dilation}_{str(time_period)}.csv", index=True
        )

        # Plot the standard deviation and gradient results
        plot_std_gradient_buffer(
            "Standard Deviation",
            std_geomedian_df,
            path_row,
            product,
            start_buffering,
            end_buffering,
            str(time_period),
            f"opening_dilation_{dilation}",
            export_figure=True,
        )