In [None]:
import datacube
import rasterio
import boto3
import xarray as xr
import numpy as np
import re
from datacube.utils.dask import start_local_dask
from datacube import Datacube
from osgeo import ogr, gdal, osr
from scipy.stats import norm
import pandas as pd
import matplotlib.pyplot as plt
import os
import scipy.stats as sps
import awswrangler as wr
import rioxarray as rioxr
import geopandas as gpd
from datetime import datetime

from datacube.utils.dask import start_local_dask
from datacube import Datacube

from dea_tools.spatial import xr_vectorize, xr_rasterize
from matplotlib.lines import Line2D

In [None]:
# Start a local Dask cluster with a single worker. The thread count and the
# memory limit are hard-coded here; adjust them to the host running the notebook.
client = start_local_dask(
    n_workers=1,
    threads_per_worker=60,
    memory_limit='400GB',
)
client

In [None]:
def generate_seamask(shape_file, data_shape, data_crs, orig_coords, resolution):
    """
        Create a mask without oceans on the grid of the loaded data.

        Every feature whose FEAT_CODE is not 'sea' is burnt with value 1 into an
        in-memory, single-band byte raster; all remaining pixels keep the value 0.

        input:
            shape_file: the shape file of Australia coastline
            data_shape: the (rows, cols) shape of loaded data to be masked upon
            data_crs: EPSG code of the data CRS (any value accepted by int())
            orig_coords: (x, y) origin of the image for gdal to decide the transform;
                         it is shifted by half a pixel below, i.e. it is treated as
                         the centre of the origin pixel
            resolution: pixel size with signs, e.g., (30, -30) for C3 and (25, -25) for C2
        output:
            a numpy array of mask, where valid pixels = 1
        raises:
            FileNotFoundError: if OGR cannot open shape_file
    """
    source_ds = ogr.Open(shape_file)
    # ogr.Open signals failure by returning None (unless GDAL exceptions are
    # enabled); fail here with a clear message instead of an AttributeError below.
    if source_ds is None:
        raise FileNotFoundError(f"Unable to open shape file: {shape_file}")
    source_layer = source_ds.GetLayer()
    source_layer.SetAttributeFilter("FEAT_CODE!='sea'")

    yt, xt = data_shape
    xres = resolution[0]
    yres = resolution[1]
    no_data = 0

    xcoord, ycoord = orig_coords
    # Move from the pixel centre to the pixel corner expected by the geotransform.
    geotransform = (xcoord - (xres*0.5), xres, 0, ycoord - (yres*0.5), 0, yres)

    target_ds = gdal.GetDriverByName('MEM').Create('', xt, yt, gdal.GDT_Byte)
    target_ds.SetGeoTransform(geotransform)
    albers = osr.SpatialReference()
    albers.ImportFromEPSG(int(data_crs))
    target_ds.SetProjection(albers.ExportToWkt())
    band = target_ds.GetRasterBand(1)
    band.SetNoDataValue(no_data)

    gdal.RasterizeLayer(target_ds, [1], source_layer, burn_values=[1])
    return band.ReadAsArray()

In [None]:
def derive_polygons(wo_annual_dataset, au_shape, o_file):
    """
    Generate polygons of waterbodies with annual wo summary

    The wet frequency (summed count_wet / summed count_clear over time) is
    thresholded twice, with fq_low and fq_high. Low-threshold polygons are kept
    only when they intersect a high-threshold polygon; the result is merged with
    the high-threshold polygons and written to o_file as GeoJSON.

    Reads the module-level settings fq_low, fq_high, t_pixels, min_area and
    max_area at call time.

    input:
        wo_annual_dataset: dataset exposing count_clear and count_wet with a
                           'time' dimension, plus x and y coordinates
        au_shape: coastline shape file handed to generate_seamask
        o_file: path of the GeoJSON file to write
    output:
        None; the polygons are written to o_file
    """
    # Land mask on the data grid (1 = not sea). The CRS (EPSG 3577) and the
    # pixel size (30, -30) are hard-coded for this call.
    c3_land_raster = generate_seamask(au_shape, wo_annual_dataset.count_clear.shape[1:], 3577,
                                  (wo_annual_dataset.x.data.min(), wo_annual_dataset.y.data.max()), (30, -30))
    # Values that are not > -999 (presumably nodata) count as 0 in the sums.
    count_clear = wo_annual_dataset.count_clear.where(wo_annual_dataset.count_clear > -999, 0).sum(dim='time').load()
    count_wet = wo_annual_dataset.count_wet.where(wo_annual_dataset.count_wet > -999, 0).sum(dim='time').load()
    frequency = count_wet/count_clear
    polygons_low = xr_vectorize((frequency > fq_low) & (count_clear >= t_pixels) & c3_land_raster)
    polygons_high = xr_vectorize((frequency > fq_high) & (count_clear >= t_pixels) & c3_land_raster)
    # attribute > 0 keeps only the shapes vectorized from non-zero pixels.
    polygons_low = polygons_low[(polygons_low.attribute>0) & (polygons_low.geometry.area <= max_area)]
    polygons_high = polygons_high[(polygons_high.geometry.area >= min_area) & (polygons_high.geometry.area <= max_area)
                                       & (polygons_high.attribute>0)]
    # Keep each low-threshold polygon once if it touches any high-threshold polygon.
    filtered_polygons = gpd.sjoin(left_df=polygons_low, 
                 right_df=polygons_high, 
                 how="inner", predicate="intersects").reset_index().drop_duplicates(subset=["index"])
    # Merge everything into one geometry, apply a tiny buffer, then split the
    # result back into individual polygons.
    filtered_polygons = pd.concat([filtered_polygons, polygons_high]).dissolve().buffer(1e-4).explode(index_parts=True).reset_index().drop(columns=["level_0", "level_1"])
    filtered_polygons.to_file(o_file, driver='GeoJSON')

In [None]:
def explode_large_polygons(large_polygons, pp_thresh=0.005):
    """
    Explode large polygons according to a compactness threshold.

    Compactness is measured as 4 * pi * area / perimeter ** 2. Polygons at or
    below pp_thresh are repeatedly eroded with negative buffers, split into
    their parts and dilated again; parts whose compactness exceeds pp_thresh
    are collected, the others go through another round.

    input:
        large_polygons: geopandas object exposing .geometry, boolean indexing
                        and .buffer
        pp_thresh: compactness threshold separating "done" from "split further"
    output:
        a GeoDataFrame holding the collected parts

    NOTE(review): input polygons that are already above pp_thresh are never
    appended to the result, so they are absent from the output - confirm this
    is intended.
    NOTE(review): if nothing is ever collected, pd.concat receives an empty
    list - confirm callers always pass at least one polygon below the threshold.
    """
    tmp_polygons = large_polygons
    exploded_polygons = []
    # Erosion / dilation step, in the units of the geometry coordinates.
    b_size = 50
    pp_values = tmp_polygons.geometry.area * 4 *np.pi / (tmp_polygons.geometry.length ** 2)
    # Only the non-compact polygons are processed; they start eroded by one step.
    tmp_polygons = tmp_polygons[pp_values <= pp_thresh].buffer(-b_size)
    while tmp_polygons.size > 0:
        print(tmp_polygons)
        # i counts the erosion steps of this round (the erosion applied before
        # entering the round is counted as the first one).
        i = 1
        # Keep eroding while any geometry is still reported as a single "Polygon".
        # NOTE(review): confirm that geometries eroded down to nothing stop
        # being reported as "Polygon", otherwise this loop may not terminate.
        while (tmp_polygons.type == "Polygon").any():
            i += 1
            print(i*b_size)
            tmp_polygons = tmp_polygons.buffer(-b_size)
        # Split multi-part geometries into parts and grow them back by i steps.
        tmp_polygons = tmp_polygons.explode(index_parts=True).buffer(i*b_size)
        pp_values = tmp_polygons.geometry.area * 4 *np.pi / (tmp_polygons.geometry.length ** 2)
        exploded_polygons += [tmp_polygons[pp_values > pp_thresh]]
        # Parts still below the threshold are eroded by this round's i steps
        # before the next round.
        # NOTE(review): i is reset to 1 at the top of the next round, so the
        # dilation there may not match the total erosion - confirm.
        tmp_polygons = tmp_polygons[pp_values <= pp_thresh].buffer(-i*b_size)
    return gpd.GeoDataFrame(geometry=pd.concat(exploded_polygons)).reset_index().drop(columns=["level_0", "level_1"])

In [None]:
# Collect the second-level directories found under the bucket location.
# NOTE(review): `wo_bucket` is not defined in the visible cells - confirm it is
# set before this cell runs.
wo_x_dirs = wr.s3.list_directories(wo_bucket)
wo_file_dirs = [
    sub_dir
    for x_dir in wo_x_dirs
    for sub_dir in wr.s3.list_directories(x_dir)
]

In [None]:
# Local vector inputs, given relative to the notebook's working directory.
# landsat_shape and grid_shape are not referenced in the visible cells.
landsat_shape = "./landsat_au/landsat_au.shp"
grid_shape = "./au-grid.geojson"
# Coastline shape file passed to derive_polygons to build the land mask.
au_shape = "./aus_map/cstauscd_r_3577.shp"

In [None]:
# Band names used both to find the per-year .tif files and as the measurements to load.
wo_bands = ['count_clear', 'count_wet']

In [None]:
# Settings read by derive_polygons at call time.
t_pixels = 128              # minimum summed count_clear for a pixel to be used
fq_low = 0.05               # wet-frequency threshold of the low (wider) polygons
fq_high = 0.1               # wet-frequency threshold of the high (core) polygons
min_area = 3125             # lower bound on the area of high-threshold polygons
max_area = 5_000_000_000    # upper bound on the area of any polygon

In [None]:
# generate polygons from dilated wo annual summary
# NOTE(review): `session` (passed as boto3_session below) is not defined in the
# visible cells - confirm it is created before this cell runs.
for f_dir in wo_file_dirs:
    print(f_dir)
    # Output name is built from the last two path components of the directory.
    o_file = '_'.join(f_dir.split("/")[-3:-1])+'_dilation_6.geojson'
    if os.path.exists(o_file):
        # Already processed in an earlier run.
        continue
    yearly_datasets = []
    for i in range(1987, 2023):
        dataset = None
        for band in wo_bands:
            non_empty_list = wr.s3.list_objects(f_dir + str(i) + "--P1Y", boto3_session=session, suffix=[band+'.tif'])
            if non_empty_list == []:
                continue
            # If several objects match, the last one listed is the one kept.
            for o in non_empty_list:
                data = rioxr.open_rasterio(o, chunks={'x':3200, 'y':3200})
                data.name = band
            # Turn the single 'band' axis into a 'time' axis labelled with the year.
            tmp_set = data.to_dataset()
            tmp_set = tmp_set.rename_dims({'band': 'time'})
            tmp_set = tmp_set.rename_vars({'band': 'time'})
            tmp_set.time.data[0] = i
            if dataset is None:
                dataset = tmp_set
            else:
                dataset = xr.merge([dataset, tmp_set])
        if dataset is None:
            # No file for any band this year; passing None on to xr.concat would fail.
            continue
        yearly_datasets.append(dataset)
    if not yearly_datasets:
        # Nothing was found for this directory at all, so there is nothing to vectorize.
        print(f"no data found for {f_dir}, skipped")
        continue
    # Concatenate all years once instead of growing the dataset pairwise.
    wo_dataset = xr.concat(yearly_datasets, dim='time')
    derive_polygons(wo_dataset, au_shape, o_file)

In [None]:
# generate polygons from non-dilated wo summary
dc = datacube.Datacube()
for f_dir in wo_file_dirs:
    # The last two path components identify the tile.
    tile_parts = f_dir.split("/")[-3:-1]
    o_file = '_'.join(tile_parts) + '.geojson'
    if os.path.exists(o_file):
        # Result already on disk.
        continue
    # Region code is the two components glued together without a separator.
    region_code = ''.join(tile_parts)
    print(region_code)
    datasets = dc.find_datasets(product='ga_ls_wo_fq_cyear_3', region_code=region_code)
    load_options = {
        "datasets": datasets,
        "measurements": wo_bands,
        "output_crs": "EPSG:3577",
        "resolution": (-30, 30),
        "dask_chunks": {"time": 1},
    }
    wo_dataset = dc.load(**load_options)
    derive_polygons(wo_dataset, au_shape, o_file)

In [None]:
# Per-tile comparison of the dilated and non-dilated polygons. Three figures are
# written per tile: <tile>_poly.png, <tile>_diff.png and <tile>_area.png.
# Each figure is closed right after saving so that figures do not pile up in
# memory while looping over the tiles.
w_polygons_dilation = None
w_polygons = None
for f_dir in wo_file_dirs:
    print(f_dir)
    # File prefix built from the last two path components of the directory.
    tile_prefix = '_'.join(f_dir.split("/")[-3:-1])
    if os.path.exists(tile_prefix + "_poly.png"):
        # Tile already plotted in an earlier run.
        continue
    w_polygons_dilation = gpd.GeoDataFrame.from_file(tile_prefix + '_dilation_6.geojson')
    w_polygons = gpd.GeoDataFrame.from_file(tile_prefix + '.geojson')
    print(f"load finish {datetime.now()}")
    # Areas covered by one set of polygons but not by the other.
    u_dilation = w_polygons_dilation.overlay(w_polygons, how='difference')
    u_current = w_polygons.overlay(w_polygons_dilation, how='difference')
    print(f"diff finish {datetime.now()}")

    # Figure 1: both polygon sets on top of each other.
    fig, ax = plt.subplots(figsize=(16, 16))
    w_polygons_dilation.plot(ax=ax, color='hotpink', label='dilation 6')
    w_polygons.plot(ax=ax, color='steelblue', alpha=0.5, label='no dilation')
    # Build the legend by hand from the collections that were just drawn.
    lines = [
        Line2D([0], [0], linestyle="none", marker="s", markersize=10, markerfacecolor=t.get_facecolor())
        for t in ax.collections[-2:]
    ]
    labels = [t.get_label() for t in ax.collections[-2:]]
    ax.legend(lines, labels)
    fig.savefig(tile_prefix + "_poly.png")
    plt.close(fig)

    # Figure 2: the two differences. figsize is not passed to .plot() because
    # the axes are supplied explicitly.
    fig, ax = plt.subplots(figsize=(16, 16))
    u_dilation.plot(ax=ax, edgecolor='hotpink', cmap='Set1', label="dilation-current")
    u_current.plot(ax=ax, alpha=0.5, edgecolor='steelblue', cmap='Set1', label="current-dilation")
    lines = [
        Line2D([0], [0], linestyle="none", marker="s", markersize=10, markerfacecolor=t.get_edgecolor())
        for t in ax.collections
    ]
    labels = [t.get_label() for t in ax.collections]
    ax.legend(lines, labels)
    fig.savefig(tile_prefix + "_diff.png")
    plt.close(fig)

    # Figure 3: histogram of the difference areas, divided by 30*30; both
    # histograms share the bins of the first one.
    fig = plt.figure(figsize=(16, 9))
    c_dens, c_bins, _ = plt.hist(u_dilation.geometry.area/(30*30), bins=30, label="dilation-current", color="pink")
    plt.hist(u_current.geometry.area/(30*30), bins=c_bins, alpha=.5, label="current-dilation", color="steelblue")
    plt.legend(loc="upper right")
    fig.savefig(tile_prefix + "_area.png")
    plt.close(fig)

In [None]:
# load and further process all the polygons saved in the geojson earlier
o_dilation = []
o_current = []

for f_dir in wo_file_dirs:
    tile_prefix = '_'.join(f_dir.split("/")[-3:-1])
    w_polygons_dilation = gpd.GeoDataFrame.from_file(tile_prefix + '_dilation_6.geojson')
    w_polygons = gpd.GeoDataFrame.from_file(tile_prefix + '.geojson')
    o_dilation.append(w_polygons_dilation)
    o_current.append(w_polygons)

# One frame per variant, then split the non-compact polygons.
o_dilation = pd.concat(o_dilation)
o_current = pd.concat(o_current)
exploded_polygons_dilation = explode_large_polygons(o_dilation)
exploded_polygons_current = explode_large_polygons(o_current)

# Largest polygons first.
sorted_current = exploded_polygons_current.sort_values(
    by="geometry",
    key=lambda col: col.values.area,
    ascending=False,
)
sorted_dilation = exploded_polygons_dilation.sort_values(
    by="geometry",
    key=lambda col: col.values.area,
    ascending=False,
)

# Differences in both directions, merged and split back into single parts.
diff_current = exploded_polygons_current.overlay(exploded_polygons_dilation, how='difference')
u_current = (
    diff_current
    .dissolve()
    .explode(index_parts=True)
    .reset_index()
    .drop(columns=["level_0", "level_1"])
)
diff_dilation = exploded_polygons_dilation.overlay(exploded_polygons_current, how='difference')
u_dilation = (
    diff_dilation
    .dissolve()
    .explode(index_parts=True)
    .reset_index()
    .drop(columns=["level_0", "level_1"])
)

In [None]:
# plot the distribution of polygon area
fig, ax1 = plt.subplots(figsize=(16, 9))
num_boxes = 2
labels = ["dilation", "current"]
# Log of the polygon area after dividing by 30*30.
log_areas = [
    np.log(exploded_polygons_dilation.geometry.area/(30*30)),
    np.log(exploded_polygons_current.geometry.area/(30*30)),
]
# Same arguments as the positional form (0, 'o', 0), spelled out by name.
ax1.boxplot(log_areas, notch=0, sym='o', vert=0)
ax1.set_ylim(0.5, num_boxes + 0.5)
ax1.set_yticklabels(labels, rotation=0, fontsize=12)
plt.savefig("all_tiles_area.png")

In [None]:
# plot the difference of polygon in area
fig, ax1 = plt.subplots(figsize=(16, 9))
num_boxes = 2
labels = ["dilation-current", "current-dilation"]
# Log of the area of each difference polygon.
log_diff_areas = [
    np.log(u_dilation.geometry.area),
    np.log(u_current.geometry.area),
]
# Same arguments as the positional form (0, 'o', 0), spelled out by name.
ax1.boxplot(log_diff_areas, notch=0, sym='o', vert=0)
ax1.set_ylim(0.5, num_boxes + 0.5)
ax1.set_yticklabels(labels, rotation=0, fontsize=12)
plt.savefig("all_tiles_area_diff.png")

In [None]:
# plot histogram of polygons area
fig = plt.figure(figsize=(16, 9))
log_area_dilation = np.log(exploded_polygons_dilation.geometry.area)
log_area_current = np.log(exploded_polygons_current.geometry.area)
# The bins computed for the first histogram are reused for the second one.
c_dens, c_bins, _ = plt.hist(
    log_area_dilation,
    bins=50,
    density=True,
    label="dilation",
    color="pink",
)
plt.hist(
    log_area_current,
    bins=c_bins,
    density=True,
    alpha=.5,
    label="current",
    color="steelblue",
)
plt.legend(loc="upper right")
plt.savefig("all_tiles_area_hist.png")
plt.show()

In [None]:
# plot the histogram of difference of polygons in area
fig = plt.figure(figsize=(16, 9))
log_diff_dilation = np.log(u_dilation.geometry.area)
log_diff_current = np.log(u_current.geometry.area)
# The bins computed for the first histogram are reused for the second one.
c_dens, c_bins, _ = plt.hist(
    log_diff_dilation,
    bins=50,
    density=True,
    label="dilation-current",
    color="pink",
)
plt.hist(
    log_diff_current,
    bins=c_bins,
    density=True,
    alpha=.5,
    label="current-dilation",
    color="steelblue",
)
plt.legend(loc="upper right")
plt.savefig("all_tiles_area_diff_hist.png")
plt.show()

In [None]:
# plot the difference of top 20 polygons
# The dilated polygon of rank i is drawn together with the current polygon of
# the same rank, except for ranks 11 and 12 where the current polygons are
# swapped (hard-coded pairing).
# Each figure is closed after saving so that the 20 large figures do not stay
# open in memory.
swapped_current_rank = {11: 12, 12: 11}
for i in range(20):
    fig, ax = plt.subplots(figsize=(16, 16))
    j = swapped_current_rank.get(i, i)
    sorted_dilation.iloc[i:i+1].plot(ax=ax, color='hotpink', label='dilation 6')
    sorted_current.iloc[j:j+1].plot(ax=ax, color="steelblue", alpha=.5, label='current')
    # Build the legend by hand from the two collections that were just drawn.
    lines = [
        Line2D([0], [0], linestyle="none", marker="s", markersize=10, markerfacecolor=t.get_facecolor())
        for t in ax.collections[-2:]
    ]
    labels = [t.get_label() for t in ax.collections[-2:]]
    ax.legend(lines, labels)
    fig.savefig("large_poly_"+str(i)+".png")
    plt.close(fig)